Skip to content

Commit df2b7fc

Browse files
fix batch_idx + clamp_params + tensor_loss
1 parent 1dc2a39 commit df2b7fc

13 files changed

Lines changed: 58 additions & 59 deletions

pina/_src/problem/base_problem.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ def move_discretisation_into_conditions(self):
248248
# Set the domain and problem attributes of the new condition
249249
new_condition.domain = cond.domain
250250
new_condition.problem = self
251+
new_condition.name = name
251252

252253
# Replace the old condition in the conditions dictionary
253254
self.conditions[name] = new_condition

pina/_src/solver/autoregressive_single_model_solver.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,10 @@ def __init__(
6767
:param torch.nn.Module model: The model used by the solver.
6868
:param TorchOptimizer optimizer: The optimizer used by the solver.
6969
If ``None``, the ``torch.optim.Adam`` optimizer with a learning rate
70-
of ``0.001`` is used for each model. Default is ``None``.
70+
of ``0.001`` is used. Default is ``None``.
7171
:param TorchScheduler scheduler: The scheduler used by the solver.
7272
If ``None``, the ``torch.optim.lr_scheduler.ConstantLR`` scheduler
73-
with a factor of ``1.0`` is used for each model.
74-
Default is ``None``.
73+
with a factor of ``1.0`` is used. Default is ``None``.
7574
:param BaseWeighting weighting: The weighting strategy used to combine
7675
condition losses. If ``None``, no weighting is applied. Default is
7776
``None``.

pina/_src/solver/base_solver.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -214,16 +214,17 @@ def _init_solver_components(
214214
self._pina_optimizers = optimizers
215215
self._pina_schedulers = schedulers
216216

217-
def training_step(self, batch):
217+
def training_step(self, batch, batch_idx):
218218
"""
219219
Solver training step.
220220
221221
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
222222
tuple containing a condition name and a dictionary of points.
223+
:param int batch_idx: The index of the current batch.
223224
:return: The loss of the training step.
224225
:rtype: torch.Tensor
225226
"""
226-
loss = self.batch_evaluation_step(batch=batch)
227+
loss = self.batch_evaluation_step(batch=batch, batch_idx=batch_idx)
227228
self.log(
228229
name="train_loss",
229230
value=loss.item(),
@@ -232,16 +233,17 @@ def training_step(self, batch):
232233
)
233234
return loss
234235

235-
def validation_step(self, batch):
236+
def validation_step(self, batch, batch_idx):
236237
"""
237238
Solver validation step.
238239
239240
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
240241
tuple containing a condition name and a dictionary of points.
242+
:param int batch_idx: The index of the current batch.
241243
:return: The loss of the validation step.
242244
:rtype: torch.Tensor
243245
"""
244-
loss = self.batch_evaluation_step(batch=batch)
246+
loss = self.batch_evaluation_step(batch=batch, batch_idx=batch_idx)
245247
self.log(
246248
name="val_loss",
247249
value=loss.item(),
@@ -250,16 +252,17 @@ def validation_step(self, batch):
250252
)
251253
return loss
252254

253-
def test_step(self, batch):
255+
def test_step(self, batch, batch_idx):
254256
"""
255257
Solver test step.
256258
257259
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
258260
tuple containing a condition name and a dictionary of points.
261+
:param int batch_idx: The index of the current batch.
259262
:return: The loss of the test step.
260263
:rtype: torch.Tensor
261264
"""
262-
loss = self.batch_evaluation_step(batch=batch)
265+
loss = self.batch_evaluation_step(batch=batch, batch_idx=batch_idx)
263266
self.log(
264267
name="test_loss",
265268
value=loss.item(),
@@ -268,13 +271,14 @@ def test_step(self, batch):
268271
)
269272
return loss
270273

271-
def _compute_condition_loss(self, condition, data):
274+
def _compute_condition_loss(self, condition, data, batch_idx):
272275
"""
273276
Compute the scalar loss for a given condition and its data.
274277
275278
:param BaseCondition condition: The condition for which to compute the
276279
loss.
277280
:param dict data: The data corresponding to the condition.
281+
:param int batch_idx: The index of the current batch.
278282
:return: The scalar loss for the condition.
279283
:rtype: torch.Tensor
280284
"""
@@ -289,26 +293,27 @@ def _compute_condition_loss(self, condition, data):
289293
# Retrieve condition name for more complex weighting schemes
290294
condition_name = condition.name if hasattr(condition, "name") else None
291295

292-
# Compute the scalar loss from the residual tensor and return it
293-
condition_loss = self._loss_from_residual(condition_name)
296+
# Compute the tensor loss from the residual tensor
297+
condition_tensor_loss = self._loss_from_residual(condition_name)
294298

295-
return condition_loss
299+
# Compute the scalar loss from the tensor loss and return it
300+
condition_scalar_loss = self._apply_reduction(condition_tensor_loss)
301+
302+
return condition_scalar_loss
296303

297304
def _loss_from_residual(self, condition_name=None):
298305
"""
299-
Compute the scalar loss from the residual tensor.
306+
Compute the tensor loss from the residual tensor.
300307
301308
:param str condition_name: The name of the condition.
302-
:return: The scalar loss computed from the residual tensor.
309+
:return: The tensor loss computed from the residual tensor.
303310
:rtype: torch.Tensor | LabelTensor
304311
"""
305312
# Compute the unreduced loss tensor (reduction is applied by the caller)
306-
loss_tensor = self._loss_fn(
313+
return self._loss_fn(
307314
self.residual_tensor, torch.zeros_like(self.residual_tensor)
308315
)
309316

310-
return self._apply_reduction(loss_tensor)
311-
312317
def _apply_reduction(self, value):
313318
"""
314319
Apply the specified reduction to the loss tensor.

pina/_src/solver/mixin/autoregressive_mixin.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ def _init_autoregressive_components(
4747

4848
def _loss_from_residual(self, condition_name=None):
4949
"""
50-
Compute the scalar loss from the residual tensor.
50+
Compute the tensor loss from the residual tensor.
5151
5252
:param str condition_name: The name of the condition.
53-
:return: The scalar loss computed from the residual tensor.
53+
:return: The tensor loss computed from the residual tensor.
5454
:rtype: torch.Tensor | LabelTensor
5555
"""
5656
# Compute the step losses from the residual tensor
@@ -62,10 +62,7 @@ def _loss_from_residual(self, condition_name=None):
6262
with torch.no_grad():
6363
weights = self._get_weights(condition_name or "default", step_loss)
6464

65-
# Compute the weighted step losses
66-
weighted_step_loss = step_loss * weights
67-
68-
return self._apply_reduction(weighted_step_loss)
65+
return step_loss * weights
6966

7067
def _get_weights(self, condition_name, step_loss):
7168
"""

pina/_src/solver/mixin/condition_aggregator_mixin.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class _ConditionAggregatorMixin:
1010
:class:`~pina._src.solver.base_solver.BaseSolver`.
1111
"""
1212

13-
def batch_evaluation_step(self, batch):
13+
def batch_evaluation_step(self, batch, batch_idx):
1414
"""
1515
Evaluate and aggregate the losses for all conditions in a batch.
1616
@@ -21,6 +21,7 @@ def batch_evaluation_step(self, batch):
2121
2222
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
2323
tuple containing a condition name and a dictionary of points.
24+
:param int batch_idx: The index of the current batch.
2425
:return: The aggregated scalar loss for the batch.
2526
:rtype: torch.Tensor
2627
"""
@@ -34,8 +35,12 @@ def batch_evaluation_step(self, batch):
3435
condition_losses[condition_name] = self._compute_condition_loss(
3536
condition=self.problem.conditions[condition_name],
3637
data=dict(data),
38+
batch_idx=batch_idx,
3739
)
3840

41+
# Clamp parameters - null operation if problem is not InverseProblem
42+
self._clamp_params()
43+
3944
# Log the individual condition losses
4045
for name, value in condition_losses.items():
4146
self.log(

pina/_src/solver/mixin/manual_optimization_mixin.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@ def _init_manual_optimization(self):
1414
"""
1515
self.automatic_optimization = False
1616

17-
def training_step(self, batch):
17+
def training_step(self, batch, batch_idx):
1818
"""
1919
Solver training step.
2020
2121
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
2222
tuple containing a condition name and a dictionary of points.
23+
:param int batch_idx: The index of the current batch.
2324
:return: The loss of the training step.
2425
:rtype: torch.Tensor
2526
"""
@@ -28,7 +29,7 @@ def training_step(self, batch):
2829
opt.instance.zero_grad()
2930

3031
# Perform the forward pass and compute the loss
31-
loss = super().training_step(batch)
32+
loss = super().training_step(batch, batch_idx)
3233

3334
# Perform the backward pass
3435
self.manual_backward(loss)

pina/_src/solver/mixin/multi_model_mixin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def configure_optimizers(self):
3131
Configure the optimizers and schedulers for all models.
3232
3333
:return: The optimizers and the schedulers.
34-
:rtype: tuple[list[Optimizer], list[Scheduler]]
34+
:rtype: tuple[list[TorchOptimizer], list[TorchScheduler]]
3535
"""
3636
# Iterate over models, optimizers, and schedulers to hook them together
3737
for optimizer, scheduler, model in zip(

pina/_src/solver/mixin/physics_informed_mixin.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,27 @@ class _PhysicsInformedMixin:
1212
"""
1313

1414
@torch.enable_grad()
15-
def validation_step(self, batch):
15+
def validation_step(self, batch, batch_idx):
1616
"""
1717
Solver validation step.
1818
1919
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
2020
tuple containing a condition name and a dictionary of points.
21+
:param int batch_idx: The index of the current batch.
2122
:return: The loss of the validation step.
2223
:rtype: torch.Tensor
2324
"""
24-
return super().validation_step(batch)
25+
return super().validation_step(batch, batch_idx)
2526

2627
@torch.enable_grad()
27-
def test_step(self, batch):
28+
def test_step(self, batch, batch_idx):
2829
"""
2930
Solver test step.
3031
3132
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
3233
tuple containing a condition name and a dictionary of points.
34+
:param int batch_idx: The index of the current batch.
3335
:return: The loss of the test step.
3436
:rtype: torch.Tensor
3537
"""
36-
return super().test_step(batch)
38+
return super().test_step(batch, batch_idx)

pina/_src/solver/mixin/single_model_mixin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def configure_optimizers(self):
2828
Configure the optimizer and scheduler for the single model.
2929
3030
:return: The optimizer and the scheduler
31-
:rtype: tuple[list[Optimizer], list[Scheduler]]
31+
:rtype: tuple[list[TorchOptimizer], list[TorchScheduler]]
3232
"""
3333
# Hook the optimizer to the model parameters
3434
self.optimizer.hook(self.model.parameters())

pina/_src/solver/physics_informed_single_model_solver.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,10 @@ def __init__(
7676
:param torch.nn.Module model: The model used by the solver.
7777
:param TorchOptimizer optimizer: The optimizer used by the solver.
7878
If ``None``, the ``torch.optim.Adam`` optimizer with a learning rate
79-
of ``0.001`` is used for each model. Default is ``None``.
79+
of ``0.001`` is used. Default is ``None``.
8080
:param TorchScheduler scheduler: The scheduler used by the solver.
8181
If ``None``, the ``torch.optim.lr_scheduler.ConstantLR`` scheduler
82-
with a factor of ``1.0`` is used for each model.
83-
Default is ``None``.
82+
with a factor of ``1.0`` is used. Default is ``None``.
8483
:param BaseWeighting weighting: The weighting strategy used to combine
8584
condition losses. If ``None``, no weighting is applied. Default is
8685
``None``.

0 commit comments

Comments
 (0)