@@ -42,7 +42,6 @@ def __init__(
4242 self ,
4343 problem ,
4444 model ,
45- eps = None ,
4645 loss = None ,
4746 optimizer = None ,
4847 scheduler = None ,
@@ -57,8 +56,6 @@ def __init__(
5756 the time series data conditions.
5857 :param torch.nn.Module model: Neural network that predicts the
5958 next state given the current state.
60- :param float eps: If provided, applies exponential weighting to the per-step losses.
61- If ``None``, uniform weights are used. Default is ``None``.
6259 :param torch.nn.Module loss: Loss function to minimize.
6360 If ``None``, :class:`torch.nn.MSELoss` is used.
6461 Default is ``None``.
@@ -207,15 +204,13 @@ def loss_data(
207204
208205 :param torch.Tensor unroll: Batch of unroll windows with shape
209206 ``[B1, B2, Twin, *state_shape]`` where ``Twin = unroll_length + 1``.
207+ :param str condition_name: Name of the condition associated with this data.
210208 :param float eps: If provided, applies step weighting through
211209 :meth:`weighting_strategy`. If ``None``, uniform normalized weights are used.
212210 :param callable aggregation_strategy: Reduction applied to the weighted per-step
213211 losses. If ``None``, :func:`torch.sum` is used.
214- :param str condition_name: Name of the condition associated with this data.
215- :kwargs: Additional keyword arguments forwarded to
216- :meth:`_step_kwargs` and subsequently to :meth:`preprocess_step`.
212+ :kwargs: Additional keyword arguments.
217213 :return: Scalar loss value for the given batch.
218- :rtype: torch.Tensor
219214 """
220215 if unroll .dim () < 4 :
221216 raise ValueError (
@@ -225,23 +220,19 @@ def loss_data(
225220 B1 , B2 , Twin = unroll .shape [0 ], unroll .shape [1 ], unroll .shape [2 ]
226221 state_shape = unroll .shape [3 :]
227222
228- # current_value = unroll[:, :, 0, ...] # first time step of each batch
229223 losses = []
230-
231224 for step in range (1 , Twin ):
232225
233226 model_input = self .preprocess_step (unroll , step = step , ** kwargs )
234227 model_output = self .model_forward (model_input )
235- predicted_state = self .postprocess_step (model_output , unroll = unroll , step = step , ** kwargs )
228+ predicted_state = self .postprocess_step (
229+ model_output , unroll = unroll , step = step , ** kwargs
230+ )
236231
237232 target_state = unroll [:, :, step , ...]
238- step_loss = self ._loss_fn (
239- predicted_state , target_state , ** kwargs
240- )
233+ step_loss = self ._loss_fn (predicted_state , target_state , ** kwargs )
241234 losses .append (step_loss )
242235
243- # current_value = predicted_state
244-
245236 step_losses = torch .stack (losses ) # [unroll_length]
246237
247238 with torch .no_grad ():
@@ -260,7 +251,7 @@ def loss_data(
260251 aggregation_strategy = torch .sum
261252
262253 return aggregation_strategy (step_losses * weights )
263-
254+
264255 def preprocess_step (self , unroll , step = None , ** kwargs ):
265256 """
266257 Pre-process the input unroll for the current step before feeding it to the model.
@@ -276,7 +267,6 @@ def preprocess_step(self, unroll, step=None, **kwargs):
276267 return unroll
277268 else :
278269 return unroll [:, :, step - 1 , ...]
279-
280270
281271 def model_forward (self , model_input , ** kwargs ):
282272 """
@@ -292,8 +282,8 @@ def model_forward(self, model_input, **kwargs):
292282 """
293283
294284 return self .model (model_input )
295-
296- def postprocess_step (self , model_output , unroll = None , step = None , ** kwargs ):
285+
286+ def postprocess_step (self , model_output , unroll , step = None , ** kwargs ):
297287 """
298288 Post-process the predicted state after obtaining it from the model.
299289 This method can be overridden by subclasses to implement specific post-processing logic.
@@ -303,13 +293,17 @@ def postprocess_step(self, model_output, unroll=None, step=None, **kwargs):
303293 or incorporating additional context from the unroll.
304294
305295 :param torch.Tensor model_output: The output of the model.
296+ :param torch.Tensor unroll: The original unroll tensor, which can be used for context.
297+ :param int step: The current step index within the unroll.
298+ By default is ``None``, which is meant to be used in inference.
299+ :kwargs: Additional keyword arguments for post-processing.
306300 :return: The post-processed model output.
307301 :rtype: torch.Tensor
308302 """
309- if unroll is not None and step is not None :
310- #reshape model output to match target shape if needed
311- if model_output . shape != unroll [:, :, step , ...]. shape :
312- model_output = model_output . view_as ( unroll [:, :, step , ...])
303+ if step is not None :
304+ # do the logic for the training phase, also involving unroll if needed
305+ return model_output
306+
313307 return model_output
314308
315309 def get_weights (self , condition_name , step_losses , eps ):
@@ -401,12 +395,12 @@ def predict(self, initial_input, num_steps, **kwargs):
401395
402396 with torch .no_grad ():
403397 for step in range (1 , num_steps + 1 ):
404- model_input = self .preprocess_state (
405- predictions [- 1 ], ** kwargs
398+ model_input = self .preprocess_step (
399+ predictions [- 1 ], step = None , ** kwargs
406400 )
407401 next_state = self .model_forward (model_input )
408- next_state = self .post_process_state (
409- next_state , ** kwargs
402+ next_state = self .postprocess_step (
403+ next_state , unroll = predictions [ - 1 ], step = None , ** kwargs
410404 )
411405 predictions .append (next_state )
412406
0 commit comments