Skip to content

Commit 882d32e

Browse files
add batching support for self-adaptive pinns
1 parent 6d10989 commit 882d32e

2 files changed

Lines changed: 206 additions & 138 deletions

File tree

pina/solver/physics_informed_solver/self_adaptive_pinn.py

Lines changed: 188 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,20 @@ class Weights(torch.nn.Module):
1515
:class:`SelfAdaptivePINN` solver.
1616
"""
1717

18-
def __init__(self, func):
18+
def __init__(self, func, num_points):
1919
"""
2020
Initialization of the :class:`Weights` class.
2121
2222
:param torch.nn.Module func: the mask model.
23+
:param int num_points: the number of input points.
2324
"""
2425
super().__init__()
26+
27+
# Check consistency
2528
check_consistency(func, torch.nn.Module)
26-
self.sa_weights = torch.nn.Parameter(torch.Tensor())
29+
30+
# Initialize the weights as a learnable parameter
31+
self.sa_weights = torch.nn.Parameter(torch.zeros(num_points, 1))
2732
self.func = func
2833

2934
def forward(self):
@@ -140,134 +145,111 @@ def __init__(
140145
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
141146
Default is `None`.
142147
"""
143-
# check consistency weitghs_function
148+
# Check consistency
144149
check_consistency(weight_function, torch.nn.Module)
145150

146-
# create models for weights
147-
weights_dict = {}
148-
for condition_name in problem.conditions:
149-
weights_dict[condition_name] = Weights(weight_function)
150-
weights_dict = torch.nn.ModuleDict(weights_dict)
151+
# Define a ModuleDict for the weights
152+
weights = {}
153+
for cond, data in problem.input_pts.items():
154+
weights[cond] = Weights(func=weight_function, num_points=len(data))
155+
weights = torch.nn.ModuleDict(weights)
151156

152157
super().__init__(
153-
models=[model, weights_dict],
158+
models=[model, weights],
154159
problem=problem,
155160
optimizers=[optimizer_model, optimizer_weights],
156161
schedulers=[scheduler_model, scheduler_weights],
157162
weighting=weighting,
158163
loss=loss,
159164
)
160165

161-
self._vectorial_loss = deepcopy(self.loss)
162-
self._vectorial_loss.reduction = "none"
166+
# Extract the reduction method from the loss function
167+
self._reduction = self._loss_fn.reduction
163168

164-
def forward(self, x):
165-
"""
166-
Forward pass.
169+
# Set the loss function to return non-aggregated losses
170+
self._loss_fn = type(self._loss_fn)(reduction="none")
167171

168-
:param LabelTensor x: Input tensor.
169-
:return: The output of the neural network.
170-
:rtype: LabelTensor
172+
def training_step(self, batch, batch_idx, **kwargs):
171173
"""
172-
return self.model(x)
173-
174-
def training_step(self, batch):
175-
"""
176-
Solver training step, overridden to perform manual optimization.
174+
Solver training step. It computes the optimization cycle and aggregates
175+
the losses using the ``weighting`` attribute.
177176
178177
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
179178
tuple containing a condition name and a dictionary of points.
180-
:return: The aggregated loss.
181-
:rtype: LabelTensor
179+
:param int batch_idx: The index of the current batch.
180+
:param dict kwargs: Additional keyword arguments passed to
181+
``optimization_cycle``.
182+
:return: The loss of the training step.
183+
:rtype: torch.Tensor
182184
"""
183185
# Weights optimization
184186
self.optimizer_weights.instance.zero_grad()
185-
loss = super().training_step(batch)
187+
loss = self._optimization_cycle(
188+
batch=batch, batch_idx=batch_idx, **kwargs
189+
)
186190
self.manual_backward(-loss)
187191
self.optimizer_weights.instance.step()
188192
self.scheduler_weights.instance.step()
189193

190194
# Model optimization
191195
self.optimizer_model.instance.zero_grad()
192-
loss = super().training_step(batch)
196+
loss = self._optimization_cycle(
197+
batch=batch, batch_idx=batch_idx, **kwargs
198+
)
193199
self.manual_backward(loss)
194200
self.optimizer_model.instance.step()
195201
self.scheduler_model.instance.step()
196202

203+
# Log the loss
204+
self.store_log("train_loss", loss, self.get_batch_size(batch))
205+
197206
return loss
198207

199-
def configure_optimizers(self):
208+
@torch.set_grad_enabled(True)
209+
def validation_step(self, batch, **kwargs):
200210
"""
201-
Optimizer configuration.
211+
The validation step for the Self-Adaptive PINN solver. It reduces the
212+
non-aggregated loss of each condition and returns their average.
202213
203-
:return: The optimizers and the schedulers
204-
:rtype: tuple[list[Optimizer], list[Scheduler]]
205-
"""
206-
# If the problem is an InverseProblem, add the unknown parameters
207-
# to the parameters to be optimized
208-
self.optimizer_model.hook(self.model.parameters())
209-
self.optimizer_weights.hook(self.weights_dict.parameters())
210-
if isinstance(self.problem, InverseProblem):
211-
self.optimizer_model.instance.add_param_group(
212-
{
213-
"params": [
214-
self._params[var]
215-
for var in self.problem.unknown_variables
216-
]
217-
}
218-
)
219-
self.scheduler_model.hook(self.optimizer_model)
220-
self.scheduler_weights.hook(self.optimizer_weights)
221-
return (
222-
[self.optimizer_model.instance, self.optimizer_weights.instance],
223-
[self.scheduler_model.instance, self.scheduler_weights.instance],
224-
)
225-
226-
def on_train_start(self):
214+
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
215+
tuple containing a condition name and a dictionary of points.
216+
:param dict kwargs: Additional keyword arguments passed to
217+
``optimization_cycle``.
218+
:return: The loss of the validation step.
219+
:rtype: torch.Tensor
227220
"""
228-
This method is called at the start of the training process to set the
229-
self-adaptive weights as parameters of the mask model.
221+
losses = self.optimization_cycle(batch=batch, **kwargs)
230222

231-
:raises NotImplementedError: If the batch size is not ``None``.
232-
"""
233-
if self.trainer.batch_size is not None:
234-
raise NotImplementedError(
235-
"SelfAdaptivePINN only works with full "
236-
"batch size, set batch_size=None inside "
237-
"the Trainer to use the solver."
238-
)
239-
device = torch.device(
240-
self.trainer._accelerator_connector._accelerator_flag
241-
)
223+
# Aggregate losses for each condition
224+
for cond, loss in losses.items():
225+
losses[cond] = self._apply_reduction(loss=losses[cond])
242226

243-
# Initialize the self adaptive weights only for training points
244-
for (
245-
condition_name,
246-
tensor,
247-
) in self.trainer.data_module.train_dataset.input.items():
248-
self.weights_dict[condition_name].sa_weights.data = torch.rand(
249-
(tensor.shape[0], 1), device=device
250-
)
251-
return super().on_train_start()
227+
loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
228+
self.store_log("val_loss", loss, self.get_batch_size(batch))
229+
return loss
252230

253-
def on_load_checkpoint(self, checkpoint):
231+
@torch.set_grad_enabled(True)
232+
def test_step(self, batch, **kwargs):
254233
"""
255-
Override of the Pytorch Lightning ``on_load_checkpoint`` method to
256-
handle checkpoints for Self-Adaptive Weights. This method should not be
257-
overridden, if not intentionally.
234+
The test step for the Self-Adaptive PINN solver. It reduces the
235+
non-aggregated loss of each condition and returns their average.
258236
259-
:param dict checkpoint: Pytorch Lightning checkpoint dict.
237+
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
238+
tuple containing a condition name and a dictionary of points.
239+
:param dict kwargs: Additional keyword arguments passed to
240+
``optimization_cycle``.
241+
:return: The loss of the test step.
242+
:rtype: torch.Tensor
260243
"""
261-
# First initialize self-adaptive weights with correct shape,
262-
# then load the values from the checkpoint.
263-
for condition_name, _ in self.problem.input_pts.items():
264-
shape = checkpoint["state_dict"][
265-
f"_pina_models.1.{condition_name}.sa_weights"
266-
].shape
267-
self.weights_dict[condition_name].sa_weights.data = torch.rand(
268-
shape
269-
)
270-
return super().on_load_checkpoint(checkpoint)
244+
losses = self.optimization_cycle(batch=batch, **kwargs)
245+
246+
# Aggregate losses for each condition
247+
for cond, loss in losses.items():
248+
losses[cond] = self._apply_reduction(loss=losses[cond])
249+
250+
loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
251+
self.store_log("test_loss", loss, self.get_batch_size(batch))
252+
return loss
271253

272254
def loss_phys(self, samples, equation):
273255
"""
@@ -279,47 +261,138 @@ def loss_phys(self, samples, equation):
279261
:return: The computed physics loss.
280262
:rtype: LabelTensor
281263
"""
282-
residual = self.compute_residual(samples, equation)
283-
weights = self.weights_dict[self.current_condition_name].forward()
284-
loss_value = self._vectorial_loss(
285-
torch.zeros_like(residual, requires_grad=True), residual
286-
)
287-
return self._vect_to_scalar(weights * loss_value)
264+
residuals = self.compute_residual(samples, equation)
265+
return self._loss_fn(residuals, torch.zeros_like(residuals))
288266

289267
def loss_data(self, input, target):
290268
"""
291-
Compute the data loss for the PINN solver by evaluating the loss
269+
Compute the data loss for the Self-Adaptive PINN solver by evaluating the loss
292270
between the network's output and the true solution. This method should
293271
not be overridden, if not intentionally.
294272
295273
:param input: The input to the neural network.
296-
:type input: LabelTensor
274+
:type input: LabelTensor | torch.Tensor | Graph | Data
297275
:param target: The target to compare with the network's output.
298-
:type target: LabelTensor
276+
:type target: LabelTensor | torch.Tensor | Graph | Data
299277
:return: The supervised loss, not reduced; the reduction is applied later.
300-
:rtype: LabelTensor
278+
:rtype: LabelTensor | torch.Tensor | Graph | Data
301279
"""
302280
return self._loss_fn(self.forward(input), target)
303281

304-
def _vect_to_scalar(self, loss_value):
282+
def forward(self, x):
305283
"""
306-
Computation of the scalar loss.
284+
Forward pass.
307285
308-
:param LabelTensor loss_value: the tensor of pointwise losses.
309-
:raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
310-
:return: The computed scalar loss.
311-
:rtype: LabelTensor
286+
:param x: Input tensor.
287+
:type x: torch.Tensor | LabelTensor
288+
:return: The output of the neural network.
289+
:rtype: torch.Tensor | LabelTensor
290+
"""
291+
return self.model(x)
292+
293+
def configure_optimizers(self):
294+
"""
295+
Optimizer configuration.
296+
297+
:return: The optimizers and the schedulers
298+
:rtype: tuple[list[Optimizer], list[Scheduler]]
312299
"""
313-
if self.loss.reduction == "mean":
314-
ret = torch.mean(loss_value)
315-
elif self.loss.reduction == "sum":
316-
ret = torch.sum(loss_value)
317-
else:
318-
raise RuntimeError(
319-
f"Invalid reduction, got {self.loss.reduction} "
320-
"but expected mean or sum."
300+
# Hook the optimizers to the models
301+
self.optimizer_model.hook(self.model.parameters())
302+
self.optimizer_weights.hook(self.weights.parameters())
303+
304+
# Add unknown parameters to optimization list in case of InverseProblem
305+
if isinstance(self.problem, InverseProblem):
306+
self.optimizer_model.instance.add_param_group(
307+
{
308+
"params": [
309+
self._params[var]
310+
for var in self.problem.unknown_variables
311+
]
312+
}
321313
)
322-
return ret
314+
315+
# Hook the schedulers to the optimizers
316+
self.scheduler_model.hook(self.optimizer_model)
317+
self.scheduler_weights.hook(self.optimizer_weights)
318+
319+
return (
320+
[self.optimizer_model.instance, self.optimizer_weights.instance],
321+
[self.scheduler_model.instance, self.scheduler_weights.instance],
322+
)
323+
324+
def _optimization_cycle(self, batch, batch_idx, **kwargs):
325+
"""
326+
Aggregate the loss for each condition in the batch.
327+
328+
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
329+
tuple containing a condition name and a dictionary of points.
330+
:param int batch_idx: The index of the current batch.
331+
:param dict kwargs: Additional keyword arguments (currently not
332+
forwarded to ``optimization_cycle``).
333+
:return: The loss aggregated over all conditions in the batch, cast
334+
to a subclass of :class:`torch.Tensor`. Each condition's residuals are
335+
scaled by the self-adaptive weights and reduced before aggregation.
336+
:rtype: torch.Tensor
337+
"""
338+
# Compute non-aggregated residuals
339+
residuals = self.optimization_cycle(batch)
340+
341+
# Compute losses
342+
losses = {}
343+
for cond, res in residuals.items():
344+
345+
weight_tensor = self.weights[cond]()
346+
347+
# Get the correct indices for the weights. Modulus is used according
348+
# to the number of points in the condition, as in the PinaDataset.
349+
len_res = len(res)
350+
idx = torch.arange(
351+
batch_idx * len_res,
352+
(batch_idx + 1) * len_res,
353+
device=res.device,
354+
) % len(self.problem.input_pts[cond])
355+
356+
# Apply the weights to the residuals
357+
losses[cond] = self._apply_reduction(
358+
loss=(res * weight_tensor[idx])
359+
)
360+
361+
# Store log
362+
self.store_log(
363+
f"{cond}_loss", losses[cond].item(), self.get_batch_size(batch)
364+
)
365+
366+
# Clamp unknown parameters in InverseProblem (if needed)
367+
self._clamp_params()
368+
369+
# Aggregate
370+
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
371+
372+
return loss
373+
374+
def _apply_reduction(self, loss):
375+
"""
376+
Apply the specified reduction to the loss. The reduction is deferred
377+
until the end of the optimization cycle to allow self-adaptive weights
378+
to be applied to each point beforehand.
379+
380+
:param torch.Tensor loss: The loss tensor to be reduced.
381+
:return: The reduced loss tensor.
382+
:rtype: torch.Tensor
383+
:raises ValueError: If the reduction method is neither "mean" nor "sum".
384+
"""
385+
# Apply the specified reduction method
386+
if self._reduction == "mean":
387+
return loss.mean()
388+
if self._reduction == "sum":
389+
return loss.sum()
390+
391+
# Raise an error if the reduction method is not recognized
392+
raise ValueError(
393+
f"Unknown reduction: {self._reduction}."
394+
" Supported reductions are 'mean' and 'sum'."
395+
)
323396

324397
@property
325398
def model(self):
@@ -332,7 +405,7 @@ def model(self):
332405
return self.models[0]
333406

334407
@property
335-
def weights_dict(self):
408+
def weights(self):
336409
"""
337410
The self-adaptive weights.
338411

0 commit comments

Comments
 (0)