Skip to content

Commit afc5614

Browse files
self-adaptive physics informed solver
1 parent df2b7fc commit afc5614

8 files changed

Lines changed: 580 additions & 2 deletions

docs/source/_rst/_code.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ Solvers and Mixins
9191
Physics-Informed Ensemble Solver <solver/physics_informed_ensemble_solver.rst>
9292
Autoregressive Single-Model Solver <solver/autoregressive_single_model_solver.rst>
9393
Autoregressive Ensemble Solver <solver/autoregressive_ensemble_solver.rst>
94+
Self-Adaptive Physics-Informed Solver <solver/self_adaptive_physics_informed_solver.rst>
9495
Single-Model Mixin <solver/mixin/single_model_mixin.rst>
9596
Multi-Model Mixin <solver/mixin/multi_model_mixin.rst>
9697
Ensemble Mixin <solver/mixin/ensemble_mixin.rst>
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
Self-Adaptive Physics-Informed Solver
2+
=======================================
3+
4+
.. currentmodule:: pina.solver.self_adaptive_physics_informed_solver
5+
6+
.. automodule:: pina._src.solver.self_adaptive_physics_informed_solver
7+
8+
.. autoclass:: pina._src.solver.self_adaptive_physics_informed_solver.SelfAdaptivePhysicsInformedSolver
9+
:members:
10+
:show-inheritance:
11+
:noindex:

pina/_src/problem/zoo/inverse_poisson_problem.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,5 @@ def __init__(self, load=True, data_size=1.0):
174174
self.conditions["data"] = Condition(
175175
input=input_data[:n_data], target=output_data[:n_data]
176176
)
177+
self.conditions["data"].problem = self
178+
self.conditions["data"].name = "data"
Lines changed: 335 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
"""Module for the self-adaptive physics-informed multi-model solver."""
2+
3+
import torch
4+
from pina._src.solver.mixin.physics_informed_mixin import _PhysicsInformedMixin
5+
from pina._src.condition.input_equation_condition import InputEquationCondition
6+
from pina._src.condition.input_target_condition import InputTargetCondition
7+
from pina._src.solver.multi_model_solver import MultiModelSolver
8+
from pina._src.core.utils import check_consistency
9+
from pina._src.condition.domain_equation_condition import (
10+
DomainEquationCondition,
11+
)
12+
13+
14+
class SelfAdaptivePhysicsInformedSolver(
    _PhysicsInformedMixin, MultiModelSolver
):
    r"""
    Multi-model solver for self-adaptive physics-informed learning problems.

    This solver approximates the solution of a differential problem using a
    trainable model together with condition-wise self-adaptive weights. It is
    intended for problems whose conditions may include supervised data,
    equation residuals evaluated on input points, and equation residuals
    sampled from domains.

    Given a model :math:`\mathcal{M}`, the predicted solution is

    .. math::

        \hat{\mathbf{u}}(\mathbf{x}) = \mathcal{M}(\mathbf{x}).

    For each condition, the solver introduces one trainable weight per point.
    These weights are passed through a user-defined weight function
    :math:`m`, typically chosen to keep the effective weights bounded or
    positive. The resulting weighted objective encourages the model to focus
    on regions where the residual is larger.

    For a problem with governing equation operator :math:`\mathcal{A}` in the
    domain :math:`\Omega` and boundary operator :math:`\mathcal{B}` on the
    boundary :math:`\partial\Omega`, the objective can be written as

    .. math::

        \mathcal{L}_{\mathrm{problem}} = \frac{1}{N_{\Omega}}
        \sum_{i=1}^{N_{\Omega}} m(\lambda_{\Omega}^{i}) \mathcal{L}
        \left( \mathcal{A}[\hat{\mathbf{u}}](\mathbf{x}_i) \right)
        + \frac{1}{N_{\partial\Omega}} \sum_{i=1}^{N_{\partial\Omega}}
        m(\lambda_{\partial\Omega}^{i})
        \mathcal{L} \left( \mathcal{B}[\hat{\mathbf{u}}](\mathbf{x}_i) \right),

    where :math:`\lambda_{\Omega}^{i}` and :math:`\lambda_{\partial\Omega}^{i}`
    are the self-adaptive weights associated with points in :math:`\Omega` and
    :math:`\partial\Omega`, respectively, and :math:`\mathcal{L}` is the
    selected loss function, typically the mean squared error.

    The model parameters and the self-adaptive weights are optimized through a
    min-max problem:

    .. math::

        \min_{\theta} \max_{\lambda} \mathcal{L}_{\mathrm{problem}},

    where :math:`\theta` denotes the model parameters and :math:`\lambda`
    denotes the collection of self-adaptive weights.

    Internally the solver handles two modules: the user model (exposed as
    :attr:`model`) and a container module (exposed as :attr:`weights`) that
    stores the weight function under the attribute ``func`` and one
    ``torch.nn.Parameter`` per condition, registered under the condition name.

    .. seealso::

        **Original reference**: McClenny, L. D., & Braga-Neto, U. M. (2023).
        *Self-adaptive physics-informed neural networks.*
        Journal of Computational Physics, 474, 111722.
        DOI: `10.1016/j.jcp.2022.111722
        <https://doi.org/10.1016/j.jcp.2022.111722>`_.
    """

    # Accepted conditions types for this solver
    accepted_conditions_types = (
        InputTargetCondition,
        InputEquationCondition,
        DomainEquationCondition,
    )

    def __init__(
        self,
        problem,
        model,
        weight_function=None,
        optimizer_model=None,
        optimizer_weights=None,
        scheduler_model=None,
        scheduler_weights=None,
        weighting=None,
        loss=None,
    ):
        """
        Initialization of the :class:`SelfAdaptivePhysicsInformedSolver` class.

        :param BaseProblem problem: The problem to be solved.
        :param torch.nn.Module model: The model used by the solver.
        :param torch.nn.Module weight_function: The weight function used to
            compute self-adaptive weights. If ``None``, a new
            ``torch.nn.Sigmoid()`` instance is created for this solver.
            Default is ``None``.
        :param TorchOptimizer optimizer_model: The optimizer of the main model.
            If ``None``, the ``torch.optim.Adam`` optimizer with a learning rate
            of ``0.001`` is used. Default is ``None``.
        :param TorchOptimizer optimizer_weights: The optimizer of the
            self-adaptive weights. If ``None``, the ``torch.optim.Adam``
            optimizer with a learning rate of ``0.001`` is used.
            Default is ``None``.
        :param TorchScheduler scheduler_model: The scheduler of the main model.
            If ``None``, the ``torch.optim.lr_scheduler.ConstantLR`` scheduler
            with a factor of ``1.0`` is used. Default is ``None``.
        :param TorchScheduler scheduler_weights: The scheduler of the
            self-adaptive weights. If ``None``, the
            ``torch.optim.lr_scheduler.ConstantLR`` scheduler with a factor of
            ``1.0`` is used. Default is ``None``.
        :param BaseWeighting weighting: The weighting strategy used to combine
            condition losses. If ``None``, no weighting is applied. Default is
            ``None``.
        :param loss: The loss function used to compute residual losses.
            If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
        :raises ValueError: If ``weight_function`` is not a ``torch.nn.Module``.
        :raises ValueError: If not all domains have been discretised.
        :raises ValueError: If a condition name clashes with an attribute of
            the weights container (for instance ``"func"``).
        """
        # A module instance written as a default argument is created once, at
        # definition time, and would then be registered as a submodule of
        # every solver. Build the default here so each solver owns its own.
        if weight_function is None:
            weight_function = torch.nn.Sigmoid()

        # Check consistency
        check_consistency(weight_function, torch.nn.Module)

        # Check that all domains have been discretised
        if not problem.are_all_domains_discretised:
            raise ValueError(
                "All domains must be discretised before initializing the "
                "solver."
            )

        # Compute the number of points for each condition: domain-based
        # conditions read the sampled points from the problem, the others
        # read the input stored in the condition itself.
        # NOTE(review): the discretised domains are looked up with the
        # condition name - confirm they are keyed that way.
        num_points = {
            cond: (
                problem._discretised_domains[cond].shape[0]
                if isinstance(problem.conditions[cond], DomainEquationCondition)
                else problem.conditions[cond].data.input.shape[0]
            )
            for cond in problem.conditions
        }

        # Initialize weights container and per-condition parameters
        weights = torch.nn.Module()

        # Attach the weight function as a submodule
        weights.func = weight_function

        # Register a torch.nn.Parameter for each condition to store the weights
        for cond in problem.conditions:
            # Assigning a Parameter over an existing attribute either fails
            # with an unrelated error or silently replaces it (e.g. a
            # condition called "func" would overwrite the weight function),
            # so reject clashing names with an explicit message.
            if hasattr(weights, cond):
                raise ValueError(
                    f"Condition name '{cond}' clashes with an attribute of "
                    "the self-adaptive weights container; please rename the "
                    "condition."
                )
            p = torch.nn.Parameter(torch.zeros(num_points[cond], 1))
            setattr(weights, cond, p)

        # Prepare optimizers: forward a [model, weights] pair as soon as one
        # of the two is given, otherwise let the base solver pick defaults.
        # NOTE(review): when only one is given the pair contains a ``None``
        # entry - confirm the base solver replaces it with its default.
        optimizers = (
            [optimizer_model, optimizer_weights]
            if any(o is not None for o in (optimizer_model, optimizer_weights))
            else None
        )

        # Prepare schedulers (same convention as the optimizers)
        schedulers = (
            [scheduler_model, scheduler_weights]
            if any(s is not None for s in (scheduler_model, scheduler_weights))
            else None
        )

        # Initialize the base solver
        MultiModelSolver.__init__(
            self,
            problem=problem,
            models=[model, weights],
            optimizers=optimizers,
            schedulers=schedulers,
            weighting=weighting,
            loss=loss,
            use_lt=True,
        )

    def training_step(self, batch, batch_idx):
        """
        Solver training step.

        The step first updates the self-adaptive weights by back-propagating
        the negated loss (ascent on the loss), then evaluates the loss again
        and updates the model by back-propagating the loss itself (descent).
        Both schedulers are stepped once per call.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
            tuple containing a condition name and a dictionary of points.
        :param int batch_idx: The index of the current batch.
        :return: The loss of the training step, as computed for the model
            update.
        :rtype: torch.Tensor
        """
        # Zero the gradients of weights optimizer and compute the loss
        self.optimizer_weights.instance.zero_grad()
        loss = self.batch_evaluation_step(batch, batch_idx)

        # Perform the backward pass and complete a step for the weights; the
        # sign flip turns the optimizer's minimization into a maximization
        self.manual_backward(-loss)
        self.optimizer_weights.instance.step()
        self.scheduler_weights.instance.step()

        # Zero the gradients of model optimizer and compute the loss again,
        # so that the model update sees the freshly updated weights
        self.optimizer_model.instance.zero_grad()
        loss = self.batch_evaluation_step(batch, batch_idx)

        # Perform the backward pass and complete a step for the model
        self.manual_backward(loss)
        self.optimizer_model.instance.step()
        self.scheduler_model.instance.step()

        # Log the training loss
        self.log(
            name="train_loss",
            value=loss.item(),
            batch_size=self.get_batch_size(batch),
            **self.trainer.logging_kwargs,
        )

        return loss

    def forward(self, x):
        """
        Forward pass through the model.

        :param x: The input data.
        :type x: torch.Tensor | LabelTensor | Data | Graph
        :return: The output of the model.
        :rtype: torch.Tensor | LabelTensor | Data | Graph
        """
        return self.model(x)

    def _compute_condition_loss(self, condition, data, batch_idx):
        """
        Compute the scalar loss for a given condition and its data.

        The pointwise loss of the condition is multiplied by the activated
        self-adaptive weights selected for the current batch and then reduced
        to a scalar.

        :param BaseCondition condition: The condition for which to compute the
            loss.
        :param dict data: The data corresponding to the condition.
        :param int batch_idx: The index of the current batch.
        :return: The scalar loss for the condition.
        :rtype: torch.Tensor
        """
        # Clone the input tensor if it exists to avoid in-place modifications;
        # the dictionary is copied too, so the caller's batch is untouched
        if "input" in data and hasattr(data["input"], "clone"):
            data = dict(data)
            data["input"] = data["input"].clone()

        # Compute and store the residual tensor for the condition
        self.residual_tensor = condition.evaluate(data, self)

        # Retrieve condition name for more complex weighting schemes
        condition_name = condition.name

        # Apply the activation function to the condition-specific weights
        weight_param = getattr(self.weights, condition_name)
        weight_tensor = self.weights.func(weight_param)

        # Compute the tensor loss from the residual tensor
        # NOTE(review): presumably the helper reads ``self.residual_tensor``
        # set above; it is defined outside this class - confirm.
        condition_tensor_loss = self._loss_from_residual(condition_name)

        # Get the correct indices to retrieve the weights for the current batch
        len_residuals = self.residual_tensor.shape[0]

        # Get the total number of points, together with the start / end indices
        # NOTE(review): this window assumes that every batch of the condition
        # carries the same number of residual rows and that batches walk the
        # points in storage order - confirm against the data loader (shuffling
        # or a shorter last batch would select other weights).
        total_points = weight_param.shape[0]
        start = (batch_idx * len_residuals) % total_points
        end = start + len_residuals

        # Retrieve the weights for the current batch using modular indexing,
        # so that a window running past the last point wraps to the first one
        idx = torch.arange(start, end, device=self.residual_tensor.device)
        idx = idx % total_points

        # Compute the scalar loss from the tensor loss and return it
        condition_scalar_loss = self._apply_reduction(
            condition_tensor_loss * weight_tensor[idx]
        )

        return condition_scalar_loss

    @property
    def model(self):
        """
        The single model used by the solver.

        :return: The single model used by the solver.
        :rtype: torch.nn.Module
        """
        return self._pina_models[0]

    @property
    def weights(self):
        """
        The self-adaptive weights used by the solver.

        The returned module holds the weight function as ``func`` and one
        parameter per condition, named after the condition.

        :return: The self-adaptive weights used by the solver.
        :rtype: torch.nn.Module
        """
        return self._pina_models[1]

    @property
    def optimizer_model(self):
        """
        The optimizer for the model used by the solver.

        :return: The optimizer for the model used by the solver.
        :rtype: TorchOptimizer
        """
        return self.optimizers[0]

    @property
    def optimizer_weights(self):
        """
        The optimizer for the weights used by the solver.

        :return: The optimizer for the weights used by the solver.
        :rtype: TorchOptimizer
        """
        return self.optimizers[1]

    @property
    def scheduler_model(self):
        """
        The scheduler for the model used by the solver.

        :return: The scheduler for the model used by the solver.
        :rtype: TorchScheduler
        """
        return self.schedulers[0]

    @property
    def scheduler_weights(self):
        """
        The scheduler for the weights used by the solver.

        :return: The scheduler for the weights used by the solver.
        :rtype: TorchScheduler
        """
        return self.schedulers[1]

pina/solver/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"PhysicsInformedEnsembleSolver",
1313
"AutoregressiveSingleModelSolver",
1414
"AutoregressiveEnsembleSolver",
15+
"SelfAdaptivePhysicsInformedSolver",
1516
]
1617

1718

@@ -36,3 +37,6 @@
3637
from pina._src.solver.autoregressive_ensemble_solver import (
3738
AutoregressiveEnsembleSolver,
3839
)
40+
from pina._src.solver.self_adaptive_physics_informed_solver import (
41+
SelfAdaptivePhysicsInformedSolver,
42+
)

tests/test_solver/test_physics_informed_ensemble_solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def define_direct_problem_model(n_pts=10, n_models=3):
3939
def define_inverse_problem_model(n_pts=10, n_models=5):
4040

4141
# Initialize inverse problem
42-
problem = InversePoisson2DSquareProblem()
42+
problem = InversePoisson2DSquareProblem(load=True, data_size=0.01)
4343
problem.discretise_domain(n_pts)
4444

4545
# Initialize the models

tests/test_solver/test_physics_informed_single_model_solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def define_direct_problem_model(n_pts=10):
3838
def define_inverse_problem_model(n_pts=10):
3939

4040
# Initialize inverse problem
41-
problem = InversePoisson2DSquareProblem()
41+
problem = InversePoisson2DSquareProblem(load=True, data_size=0.01)
4242
problem.discretise_domain(n_pts)
4343

4444
# Initialize the model

0 commit comments

Comments
 (0)