Skip to content

Commit 5901ecc

Browse files
Davide-Miotti and GiovanniCanali
authored and committed
update tests for cleaner version
1 parent 6200fe8 commit 5901ecc

File tree

3 files changed

+39
-40
lines changed

3 files changed

+39
-40
lines changed

pina/_src/solver/autoregressive_solver/autoregressive_solver.py

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ def __init__(
4242
self,
4343
problem,
4444
model,
45-
eps=None,
4645
loss=None,
4746
optimizer=None,
4847
scheduler=None,
@@ -57,8 +56,6 @@ def __init__(
5756
the time series data conditions.
5857
:param torch.nn.Module model: Neural network that predicts the
5958
next state given the current state.
60-
:param float eps: If provided, applies exponential weighting to the per-step losses.
61-
If ``None``, uniform weights are used. Default is ``None``.
6259
:param torch.nn.Module loss: Loss function to minimize.
6360
If ``None``, :class:`torch.nn.MSELoss` is used.
6461
Default is ``None``.
@@ -207,15 +204,13 @@ def loss_data(
207204
208205
:param torch.Tensor unroll: Batch of unroll windows with shape
209206
``[B1, B2, Twin, *state_shape]`` where ``Twin = unroll_length + 1``.
207+
:param str condition_name: Name of the condition associated with this data.
210208
:param float eps: If provided, applies step weighting through
211209
:meth:`weighting_strategy`. If ``None``, uniform normalized weights are used.
212210
:param callable aggregation_strategy: Reduction applied to the weighted per-step
213211
losses. If ``None``, :func:`torch.sum` is used.
214-
:param str condition_name: Name of the condition associated with this data.
215-
:kwargs: Additional keyword arguments forwarded to
216-
:meth:`_step_kwargs` and subsequently to :meth:`preprocess_step`.
212+
:kwargs: Additional keyword arguments.
217213
:return: Scalar loss value for the given batch.
218-
:rtype: torch.Tensor
219214
"""
220215
if unroll.dim() < 4:
221216
raise ValueError(
@@ -225,23 +220,19 @@ def loss_data(
225220
B1, B2, Twin = unroll.shape[0], unroll.shape[1], unroll.shape[2]
226221
state_shape = unroll.shape[3:]
227222

228-
# current_value = unroll[:, :, 0, ...] # first time step of each batch
229223
losses = []
230-
231224
for step in range(1, Twin):
232225

233226
model_input = self.preprocess_step(unroll, step=step, **kwargs)
234227
model_output = self.model_forward(model_input)
235-
predicted_state = self.postprocess_step(model_output, unroll=unroll, step=step, **kwargs)
228+
predicted_state = self.postprocess_step(
229+
model_output, unroll=unroll, step=step, **kwargs
230+
)
236231

237232
target_state = unroll[:, :, step, ...]
238-
step_loss = self._loss_fn(
239-
predicted_state, target_state, **kwargs
240-
)
233+
step_loss = self._loss_fn(predicted_state, target_state, **kwargs)
241234
losses.append(step_loss)
242235

243-
# current_value = predicted_state
244-
245236
step_losses = torch.stack(losses) # [unroll_length]
246237

247238
with torch.no_grad():
@@ -260,7 +251,7 @@ def loss_data(
260251
aggregation_strategy = torch.sum
261252

262253
return aggregation_strategy(step_losses * weights)
263-
254+
264255
def preprocess_step(self, unroll, step=None, **kwargs):
265256
"""
266257
Pre-process the input unroll for the current step before feeding it to the model.
@@ -276,7 +267,6 @@ def preprocess_step(self, unroll, step=None, **kwargs):
276267
return unroll
277268
else:
278269
return unroll[:, :, step - 1, ...]
279-
280270

281271
def model_forward(self, model_input, **kwargs):
282272
"""
@@ -292,8 +282,8 @@ def model_forward(self, model_input, **kwargs):
292282
"""
293283

294284
return self.model(model_input)
295-
296-
def postprocess_step(self, model_output, unroll=None, step=None, **kwargs):
285+
286+
def postprocess_step(self, model_output, unroll, step=None, **kwargs):
297287
"""
298288
Post-process the predicted state after obtaining it from the model.
299289
This method can be overridden by subclasses to implement specific post-processing logic.
@@ -303,13 +293,17 @@ def postprocess_step(self, model_output, unroll=None, step=None, **kwargs):
303293
or incorporating additional context from the unroll.
304294
305295
:param torch.Tensor model_output: The output of the model.
296+
:param torch.Tensor unroll: The original unroll tensor, which can be used for context.
297+
:param int step: The current step index within the unroll.
298+
By default is ``None``, which is meant to be used in inference.
299+
:kwargs: Additional keyword arguments for post-processing.
306300
:return: The post-processed model output.
307301
:rtype: torch.Tensor
308302
"""
309-
if unroll is not None and step is not None:
310-
#reshape model output to match target shape if needed
311-
if model_output.shape != unroll[:, :, step, ...].shape:
312-
model_output = model_output.view_as(unroll[:, :, step, ...])
303+
if step is not None:
304+
# do the logic for the training phase, also involving unroll if needed
305+
return model_output
306+
313307
return model_output
314308

315309
def get_weights(self, condition_name, step_losses, eps):
@@ -401,12 +395,12 @@ def predict(self, initial_input, num_steps, **kwargs):
401395

402396
with torch.no_grad():
403397
for step in range(1, num_steps + 1):
404-
model_input = self.preprocess_state(
405-
predictions[-1], **kwargs
398+
model_input = self.preprocess_step(
399+
predictions[-1], step=None, **kwargs
406400
)
407401
next_state = self.model_forward(model_input)
408-
next_state = self.post_process_state(
409-
next_state, **kwargs
402+
next_state = self.postprocess_step(
403+
next_state, unroll=predictions[-1], step=None, **kwargs
410404
)
411405
predictions.append(next_state)
412406

pina/_src/solver/autoregressive_solver/autoregressive_solver_interface.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,19 @@ def optimization_cycle(self, batch):
7979

8080
condition_loss = {}
8181
for condition_name, points in batch:
82+
if hasattr(self.problem.conditions[condition_name], "settings"):
83+
settings = self.problem.conditions[condition_name].settings
84+
eps = settings.get("eps", None)
85+
kwargs = settings.get("kwargs", {})
86+
else:
87+
eps = None
88+
kwargs = {}
89+
8290
loss = self.loss_data(
83-
points["input"]["unroll"],
91+
points["input"],
8492
condition_name=condition_name,
85-
eps=points["input"].get("eps", None),
86-
**points["input"].get("kwargs", {})
93+
eps=eps,
94+
**kwargs,
8795
)
8896
condition_loss[condition_name] = loss
8997
return condition_loss

tests/test_solver/test_autoregressive_solver.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,26 +53,25 @@ def test_end_to_end(y_data_large):
5353
the AutoregressiveSolver with curriculum learning
5454
"""
5555

56-
# AbstractProblem with empty conditions and conditions_settings to be filled later
56+
# AbstractProblem with empty conditions to be filled later
5757
class Problem(AbstractProblem):
5858
output_variables = None
5959
input_variables = None
6060
conditions = {}
61-
conditions_settings = {}
6261

6362
problem = Problem()
6463

6564
solver = AutoregressiveSolver(
6665
problem=problem,
6766
model=MinimalModel(),
68-
optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.008),
67+
optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.015),
6968
)
7069
# PHASE1: train with 'short' condition only
7170
y_short = AutoregressiveSolver.unroll(
7271
y_data_large, unroll_length=4, num_unrolls=20, randomize=False
7372
)
7473
problem.conditions["short"] = DataCondition(input=y_short)
75-
problem.conditions_settings["short"] = {"eps": 0.1}
74+
problem.conditions["short"].settings = {"eps": 0.1}
7675
trainer1 = Trainer(
7776
solver, max_epochs=300, accelerator="cpu", enable_model_summary=False
7877
)
@@ -84,8 +83,7 @@ class Problem(AbstractProblem):
8483
)
8584
problem.conditions.clear()
8685
problem.conditions["medium"] = DataCondition(input=y_medium)
87-
problem.conditions_settings.clear()
88-
problem.conditions_settings["medium"] = {"eps": 0.2}
86+
problem.conditions["medium"].settings = {"eps": 0.2}
8987
trainer2 = Trainer(
9088
solver, max_epochs=1500, accelerator="cpu", enable_model_summary=False
9189
)
@@ -97,8 +95,7 @@ class Problem(AbstractProblem):
9795
)
9896
problem.conditions.clear()
9997
problem.conditions["long"] = DataCondition(input=y_long)
100-
problem.conditions_settings.clear()
101-
problem.conditions_settings["long"] = {"eps": 0.2}
98+
problem.conditions["long"].settings = {"eps": 0.25}
10299
trainer3 = Trainer(
103100
solver, max_epochs=4000, accelerator="cpu", enable_model_summary=False
104101
)
@@ -117,10 +114,10 @@ class Problem(AbstractProblem):
117114
total_mse = torch.nn.functional.mse_loss(
118115
prediction.squeeze(1)[:, 1:, :], ground_truth[:, 1:, :]
119116
)
120-
assert total_mse < 1e-6
117+
assert total_mse < 1e-5
121118

122119

123-
### UNIT TESTS #############################################################################
120+
# ### UNIT TESTS #############################################################################
124121

125122
NUM_TIMESTEPS = 10
126123
NUM_FEATURES = 3

0 commit comments

Comments (0)