@@ -221,9 +221,12 @@ def loss_data(
221221 state_shape = unroll .shape [3 :]
222222
223223 losses = []
224+ current_state = unroll [:, :, 0 , ...] # [B1, B2, *state_shape]
224225 for step in range (1 , Twin ):
225226
226- model_input = self .preprocess_step (unroll , step = step , ** kwargs )
227+ model_input = self .preprocess_step (
228+ current_state , unroll , step = step , ** kwargs
229+ )
227230 model_output = self .model_forward (model_input )
228231 predicted_state = self .postprocess_step (
229232 model_output , unroll = unroll , step = step , ** kwargs
@@ -233,6 +236,8 @@ def loss_data(
233236 step_loss = self ._loss_fn (predicted_state , target_state , ** kwargs )
234237 losses .append (step_loss )
235238
239+ current_state = predicted_state
240+
236241 step_losses = torch .stack (losses ) # [unroll_length]
237242
238243 with torch .no_grad ():
@@ -252,21 +257,22 @@ def loss_data(
252257
253258 return aggregation_strategy (step_losses * weights )
254259
255- def preprocess_step (self , unroll , step = None , ** kwargs ):
260+ def preprocess_step (self , current_state , unroll , step = None , ** kwargs ):
256261 """
257262 Pre-process the input unroll for the current step before feeding it to the model.
258263 This method can be overridden by subclasses to implement specific preprocessing logic.
259264
260- :param torch.Tensor unroll: The unroll.
265+ :param torch.Tensor current_state: The current state.
266+ :param torch.Tensor unroll: The unroll, which can be used for context.
261267 :param int step: The current step index within the unroll.
262268 :kwargs: Additional keyword arguments for preprocessing.
263269 :return: The preprocessed unroll for the given step.
264270 :rtype: torch.Tensor
265271 """
266272 if step is None :
267- return unroll
273+ return current_state
268274 else :
269- return unroll [:, :, step - 1 , ...]
275+ return current_state
270276
271277 def model_forward (self , model_input , ** kwargs ):
272278 """
@@ -396,7 +402,7 @@ def predict(self, initial_input, num_steps, **kwargs):
396402 with torch .no_grad ():
397403 for step in range (1 , num_steps + 1 ):
398404 model_input = self .preprocess_step (
399- predictions [- 1 ], step = None , ** kwargs
405+ predictions [- 1 ], unroll = predictions , step = None , ** kwargs
400406 )
401407 next_state = self .model_forward (model_input )
402408 next_state = self .postprocess_step (
0 commit comments