@@ -214,16 +214,17 @@ def _init_solver_components(
214214 self ._pina_optimizers = optimizers
215215 self ._pina_schedulers = schedulers
216216
217- def training_step (self , batch ):
217+ def training_step (self , batch , batch_idx ):
218218 """
219219 Solver training step.
220220
221221 :param list[tuple[str, dict]] batch: A batch of data. Each element is a
222222 tuple containing a condition name and a dictionary of points.
223+ :param int batch_idx: The index of the current batch.
223224 :return: The loss of the training step.
224225 :rtype: torch.Tensor
225226 """
226- loss = self .batch_evaluation_step (batch = batch )
227+ loss = self .batch_evaluation_step (batch = batch , batch_idx = batch_idx )
227228 self .log (
228229 name = "train_loss" ,
229230 value = loss .item (),
@@ -232,16 +233,17 @@ def training_step(self, batch):
232233 )
233234 return loss
234235
235- def validation_step (self , batch ):
236+ def validation_step (self , batch , batch_idx ):
236237 """
237238 Solver validation step.
238239
239240 :param list[tuple[str, dict]] batch: A batch of data. Each element is a
240241 tuple containing a condition name and a dictionary of points.
242+ :param int batch_idx: The index of the current batch.
241243 :return: The loss of the training step.
242244 :rtype: torch.Tensor
243245 """
244- loss = self .batch_evaluation_step (batch = batch )
246+ loss = self .batch_evaluation_step (batch = batch , batch_idx = batch_idx )
245247 self .log (
246248 name = "val_loss" ,
247249 value = loss .item (),
@@ -250,16 +252,17 @@ def validation_step(self, batch):
250252 )
251253 return loss
252254
253- def test_step (self , batch ):
255+ def test_step (self , batch , batch_idx ):
254256 """
255257 Solver test step.
256258
257259 :param list[tuple[str, dict]] batch: A batch of data. Each element is a
258260 tuple containing a condition name and a dictionary of points.
261+ :param int batch_idx: The index of the current batch.
259262 :return: The loss of the training step.
260263 :rtype: torch.Tensor
261264 """
262- loss = self .batch_evaluation_step (batch = batch )
265+ loss = self .batch_evaluation_step (batch = batch , batch_idx = batch_idx )
263266 self .log (
264267 name = "test_loss" ,
265268 value = loss .item (),
@@ -268,13 +271,14 @@ def test_step(self, batch):
268271 )
269272 return loss
270273
271- def _compute_condition_loss (self , condition , data ):
274+ def _compute_condition_loss (self , condition , data , batch_idx ):
272275 """
273276 Compute the scalar loss for a given condition and its data.
274277
275278 :param BaseCondition condition: The condition for which to compute the
276279 loss.
277280 :param dict data: The data corresponding to the condition.
281+ :param int batch_idx: The index of the current batch.
278282 :return: The scalar loss for the condition.
279283 :rtype: torch.Tensor
280284 """
@@ -289,26 +293,27 @@ def _compute_condition_loss(self, condition, data):
289293 # Retrieve condition name for more complex weighting schemes
290294 condition_name = condition .name if hasattr (condition , "name" ) else None
291295
292- # Compute the scalar loss from the residual tensor and return it
293- condition_loss = self ._loss_from_residual (condition_name )
296+ # Compute the tensor loss from the residual tensor
297+ condition_tensor_loss = self ._loss_from_residual (condition_name )
294298
295- return condition_loss
299+ # Compute the scalar loss from the tensor loss and return it
300+ condition_scalar_loss = self ._apply_reduction (condition_tensor_loss )
301+
302+ return condition_scalar_loss
296303
297304 def _loss_from_residual (self , condition_name = None ):
298305 """
299- Compute the scalar loss from the residual tensor.
306+ Compute the tensor loss from the residual tensor.
300307
301308 :param str condition_name: The name of the condition.
302- :return: The scalar loss computed from the residual tensor.
309+ :return: The tensor loss computed from the residual tensor.
303310 :rtype: torch.Tensor | LabelTensor
304311 """
305312 # Compute the loss tensor and appply reduction
306- loss_tensor = self ._loss_fn (
313+ return self ._loss_fn (
307314 self .residual_tensor , torch .zeros_like (self .residual_tensor )
308315 )
309316
310- return self ._apply_reduction (loss_tensor )
311-
312317 def _apply_reduction (self , value ):
313318 """
314319 Apply the specified reduction to the loss tensor.
0 commit comments