Skip to content

Commit 931fdb0

Browse files
Davide-MiottiFilippoOlivo
authored andcommitted
implement autoregressive solver
Co-authored-by: GiovanniCanali <giovanni.canali98@yahoo.it> fixed solver test Remove old test fix test data manager Add pytest fixture for cleaning work directory remove problem.input_pts fix batch size fix batch_size bug
1 parent ab2c5c3 commit 931fdb0

21 files changed

+792
-123
lines changed

docs/source/_rst/_code.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ Solvers
8282
DeepEnsembleSupervisedSolver <solver/ensemble_solver/ensemble_supervised>
8383
ReducedOrderModelSolver <solver/supervised_solver/reduced_order_model.rst>
8484
GAROM <solver/garom.rst>
85+
AutoregressiveSolverInterface <solver/autoregressive_solver/autoregressive_solver_interface.rst>
86+
AutoregressiveSolver <solver/autoregressive_solver/autoregressive_solver.rst>
8587

8688

8789
Models
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Autoregressive Solver
2+
======================
3+
.. currentmodule:: pina.solver.autoregressive_solver.autoregressive_solver
4+
5+
.. autoclass:: pina._src.solver.autoregressive_solver.autoregressive_solver.AutoregressiveSolver
6+
:members:
7+
:show-inheritance:
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Autoregressive Solver Interface
2+
=================================
3+
.. currentmodule:: pina.solver.autoregressive_solver.autoregressive_solver_interface
4+
5+
.. autoclass:: pina._src.solver.autoregressive_solver.autoregressive_solver_interface.AutoregressiveSolverInterface
6+
:members:
7+
:show-inheritance:

pina/_src/data/creator.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,16 @@ def _compute_batch_sizes(self, datasets):
7979
"""
8080
batch_sizes = {}
8181
if self.batching_mode == "common_batch_size":
82+
83+
if self.batch_size is None:
84+
batch_size = max(
85+
dataset.length for dataset in datasets.values()
86+
)
87+
else:
88+
batch_size = self.batch_size
89+
8290
for name in datasets.keys():
83-
if self.batch_size is None:
84-
batch_sizes[name] = len(datasets[name])
85-
else:
86-
batch_sizes[name] = min(
87-
self.batch_size, len(datasets[name])
88-
)
91+
batch_sizes[name] = min(batch_size, len(datasets[name]))
8992
return batch_sizes
9093
if self.batching_mode == "proportional":
9194
return self._compute_proportional_batch_sizes(datasets)
@@ -168,8 +171,12 @@ def __call__(self, datasets):
168171
dataloaders = {}
169172
if self.batching_mode == "common_batch_size":
170173
max_len = max(len(dataset) for dataset in datasets.values())
174+
print(batch_sizes)
171175
for name, dataset in datasets.items():
172-
if self.batching_mode == "common_batch_size":
176+
if (
177+
self.batching_mode == "common_batch_size"
178+
and dataset.length != batch_sizes[name]
179+
):
173180
dataset.max_len = max_len
174181
dataloaders[name] = self.conditions[name].create_dataloader(
175182
dataset=dataset,

pina/_src/problem/abstract_problem.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,21 +43,21 @@ def __init__(self):
4343
self.domains[cond_name] = cond.domain
4444
cond.domain = cond_name
4545

46-
# back compatibility 0.1
47-
@property
48-
def input_pts(self):
49-
"""
50-
Return a dictionary mapping condition names to their corresponding
51-
input points. If some domains are not sampled, they will not be returned
52-
and the corresponding condition will be empty.
53-
54-
:return: The input points of the problem.
55-
:rtype: dict
56-
"""
57-
to_return = {}
58-
for cond_name, data in self.collected_data.items():
59-
to_return[cond_name] = data["input"]
60-
return to_return
46+
# # back compatibility 0.1
47+
# @property
48+
# def input_pts(self):
49+
# """
50+
# Return a dictionary mapping condition names to their corresponding
51+
# input points. If some domains are not sampled, they will not be returned
52+
# and the corresponding condition will be empty.
53+
54+
# :return: The input points of the problem.
55+
# :rtype: dict
56+
# """
57+
# to_return = {}
58+
# for cond_name, data in self.collected_data.items():
59+
# to_return[cond_name] = data["input"]
60+
# return to_return
6161

6262
@property
6363
def discretised_domains(self):

pina/_src/solver/autoregressive_solver/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)