Skip to content

Commit 2737ad1

Browse files
Davide-MiottiGiovanniCanali
authored andcommitted
bug fix for training
1 parent 5901ecc commit 2737ad1

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

pina/_src/solver/autoregressive_solver/autoregressive_solver.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,12 @@ def loss_data(
221221
state_shape = unroll.shape[3:]
222222

223223
losses = []
224+
current_state = unroll[:, :, 0, ...] # [B1, B2, *state_shape]
224225
for step in range(1, Twin):
225226

226-
model_input = self.preprocess_step(unroll, step=step, **kwargs)
227+
model_input = self.preprocess_step(
228+
current_state, unroll, step=step, **kwargs
229+
)
227230
model_output = self.model_forward(model_input)
228231
predicted_state = self.postprocess_step(
229232
model_output, unroll=unroll, step=step, **kwargs
@@ -233,6 +236,8 @@ def loss_data(
233236
step_loss = self._loss_fn(predicted_state, target_state, **kwargs)
234237
losses.append(step_loss)
235238

239+
current_state = predicted_state
240+
236241
step_losses = torch.stack(losses) # [unroll_length]
237242

238243
with torch.no_grad():
@@ -252,21 +257,22 @@ def loss_data(
252257

253258
return aggregation_strategy(step_losses * weights)
254259

255-
def preprocess_step(self, unroll, step=None, **kwargs):
260+
def preprocess_step(self, current_state, unroll, step=None, **kwargs):
256261
"""
257262
Pre-process the input unroll for the current step before feeding it to the model.
258263
This method can be overridden by subclasses to implement specific preprocessing logic.
259264
260-
:param torch.Tensor unroll: The unroll.
265+
:param torch.Tensor current_state: The current state.
266+
:param torch.Tensor unroll: The unroll, which can be used for context.
261267
:param int step: The current step index within the unroll.
262268
:kwargs: Additional keyword arguments for preprocessing.
263269
:return: The preprocessed unroll for the given step.
264270
:rtype: torch.Tensor
265271
"""
266272
if step is None:
267-
return unroll
273+
return current_state
268274
else:
269-
return unroll[:, :, step - 1, ...]
275+
return current_state
270276

271277
def model_forward(self, model_input, **kwargs):
272278
"""
@@ -396,7 +402,7 @@ def predict(self, initial_input, num_steps, **kwargs):
396402
with torch.no_grad():
397403
for step in range(1, num_steps + 1):
398404
model_input = self.preprocess_step(
399-
predictions[-1], step=None, **kwargs
405+
predictions[-1], unroll=predictions, step=None, **kwargs
400406
)
401407
next_state = self.model_forward(model_input)
402408
next_state = self.postprocess_step(

pina/_src/solver/autoregressive_solver/autoregressive_solver_interface.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,19 @@ def optimization_cycle(self, batch):
9797
return condition_loss
9898

9999
@abstractmethod
100-
def preprocess_step(self, tensor, step=None, **kwargs):
100+
def preprocess_step(self, current_state, unroll, step=None, **kwargs):
101101
"""
102-
Preprocess the input state before passing it to the model.
102+
Preprocess the current state before passing it to the model.
103103
This method can be overridden by subclasses to implement
104104
specific preprocessing steps.
105105
106-
:param torch.Tensor tensor: The tensor to be preprocessed.
106+
:param torch.Tensor current_state: The current state to be preprocessed.
107+
:param torch.Tensor unroll: The unroll tensor, which can be used for context.
108+
During inference (predict), this may be a list of previous predictions.
107109
:param int step: The current step index within the unroll.
108110
By default is ``None``, which is meant to be used in inference.
109111
:kwargs: Additional keyword arguments for preprocessing.
110-
:return: The preprocessed tensor for the given step.
112+
:return: The preprocessed state for the given step.
111113
:rtype: torch.Tensor
112114
"""
113115
pass

0 commit comments

Comments
 (0)