Skip to content

Commit c306687

Browse files
competitive physics informed solver
1 parent afc5614 commit c306687

6 files changed

Lines changed: 511 additions & 2 deletions

File tree

docs/source/_rst/_code.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ Solvers and Mixins
9292
Autoregressive Single-Model Solver <solver/autoregressive_single_model_solver.rst>
9393
Autoregressive Ensemble Solver <solver/autoregressive_ensemble_solver.rst>
9494
Self-Adaptive Physics-Informed Solver <solver/self_adaptive_physics_informed_solver.rst>
95+
Competitive Physics-Informed Solver <solver/competitive_physics_informed_solver.rst>
9596
Single-Model Mixin <solver/mixin/single_model_mixin.rst>
9697
Multi-Model Mixin <solver/mixin/multi_model_mixin.rst>
9798
Ensemble Mixin <solver/mixin/ensemble_mixin.rst>
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
Competitive Physics-Informed Solver
2+
=======================================
3+
4+
.. currentmodule:: pina.solver.competitive_physics_informed_solver
5+
6+
.. automodule:: pina._src.solver.competitive_physics_informed_solver
7+
8+
.. autoclass:: pina._src.solver.competitive_physics_informed_solver.CompetitivePhysicsInformedSolver
9+
:members:
10+
:show-inheritance:
11+
:noindex:
Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
"""Module for the competitive physics-informed multi-model solver."""
2+
3+
import copy
4+
from pina._src.solver.mixin.physics_informed_mixin import _PhysicsInformedMixin
5+
from pina._src.condition.input_equation_condition import InputEquationCondition
6+
from pina._src.condition.input_target_condition import InputTargetCondition
7+
from pina._src.solver.multi_model_solver import MultiModelSolver
8+
from pina._src.condition.domain_equation_condition import (
9+
DomainEquationCondition,
10+
)
11+
12+
13+
class CompetitivePhysicsInformedSolver(_PhysicsInformedMixin, MultiModelSolver):
14+
r"""
15+
Multi-model solver for competitive physics-informed learning problems.
16+
17+
This solver approximates the solution of a differential problem using a
18+
trainable model together with a discriminator network. It is intended for
19+
problems whose conditions may include supervised data, equation residuals
20+
evaluated on input points, and equation residuals sampled from domains.
21+
22+
Given a model :math:`\mathcal{M}`, the predicted solution is
23+
24+
.. math::
25+
26+
\hat{\mathbf{u}}(\mathbf{x}) = \mathcal{M}(\mathbf{x}).
27+
28+
The discriminator :math:`D` assigns pointwise weights to the residuals,
29+
encouraging the model to focus on regions where the approximation performs
30+
poorly. The model parameters are optimized by minimizing the loss, while the
31+
discriminator parameters are optimized by maximizing it.
32+
33+
For a problem with governing equation operator :math:`\mathcal{A}` in the
34+
domain :math:`\Omega` and boundary operator :math:`\mathcal{B}` on the
35+
boundary :math:`\partial\Omega`, the competitive objective can be written as
36+
37+
.. math::
38+
39+
\mathcal{L}_{\mathrm{problem}} = \frac{1}{N_{\Omega}}
40+
\sum_{i=1}^{N_{\Omega}} \mathcal{L}
41+
\left(D(\mathbf{x}_i)\mathcal{A}[\hat{\mathbf{u}}](\mathbf{x}_i)\right)
42+
+\frac{1}{N_{\partial\Omega}} \sum_{i=1}^{N_{\partial\Omega}}
43+
\mathcal{L}
44+
\left(D(\mathbf{x}_i)\mathcal{B}[\hat{\mathbf{u}}](\mathbf{x}_i)\right),
45+
46+
where :math:`D` is the discriminator network and :math:`\mathcal{L}` is the
47+
selected loss function, typically the mean squared error.
48+
49+
The model and discriminator are trained through a min-max problem:
50+
51+
.. math::
52+
53+
\min_{\theta} \max_{\phi} \mathcal{L}_{\mathrm{problem}},
54+
55+
where :math:`\theta` denotes the model parameters and :math:`\phi` denotes
56+
the discriminator parameters.
57+
58+
.. seealso::
59+
60+
**Original reference**: Zeng, Q., Kothari, P., Chou, E., & Masi, G.
61+
(2022).
62+
*Competitive physics informed networks.*
63+
International Conference on Learning Representations, ICLR 2022.
64+
`OpenReview Preprint <https://openreview.net/forum?id=z9SIj-IM7tn>`_.
65+
"""
66+
67+
# Accepted conditions types for this solver
68+
accepted_conditions_types = (
69+
InputTargetCondition,
70+
InputEquationCondition,
71+
DomainEquationCondition,
72+
)
73+
74+
def __init__(
75+
self,
76+
problem,
77+
model,
78+
discriminator=None,
79+
optimizer_model=None,
80+
optimizer_discriminator=None,
81+
scheduler_model=None,
82+
scheduler_discriminator=None,
83+
weighting=None,
84+
loss=None,
85+
):
86+
"""
87+
Initialization of the :class:`CompetitivePhysicsInformedSolver` class.
88+
89+
:param BaseProblem problem: The problem to be solved.
90+
:param torch.nn.Module model: The model used by the solver.
91+
:param torch.nn.Module discriminator: The discriminator used by the
92+
solver. If ``None``, a deep copy of the model is used as
93+
discriminator. Default is ``None``.
94+
:param TorchOptimizer optimizer_model: The optimizer of the main model.
95+
If ``None``, the ``torch.optim.Adam`` optimizer with a learning rate
96+
of ``0.001`` is used. Default is ``None``.
97+
:param TorchOptimizer optimizer_discriminator: The optimizer of the
98+
discriminator. If ``None``, the ``torch.optim.Adam`` optimizer with
99+
a learning rate of ``0.001`` is used. Default is ``None``.
100+
:param TorchScheduler scheduler_model: The scheduler of the main model.
101+
If ``None``, the ``torch.optim.lr_scheduler.ConstantLR`` scheduler
102+
with a factor of ``1.0`` is used. Default is ``None``.
103+
:param TorchScheduler scheduler_discriminator: The scheduler of the
104+
discriminator.
105+
If ``None``, the ``torch.optim.lr_scheduler.ConstantLR`` scheduler
106+
with a factor of ``1.0`` is used. Default is ``None``.
107+
:param BaseWeighting weighting: The weighting strategy used to combine
108+
condition losses. If ``None``, no weighting is applied. Default is
109+
``None``.
110+
:param loss: The loss function used to compute residual losses.
111+
If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
112+
:raises ValueError: If ``weight_function`` is not a ``torch.nn.Module``.
113+
:raises ValueError: If not all domains have been discretised.
114+
"""
115+
# Initialize the discriminator if not provided
116+
if discriminator is None:
117+
discriminator = copy.deepcopy(model)
118+
119+
# Prepare optimizers
120+
optimizers = (
121+
[optimizer_model, optimizer_discriminator]
122+
if any(
123+
o is not None
124+
for o in (optimizer_model, optimizer_discriminator)
125+
)
126+
else None
127+
)
128+
129+
# Prepare schedulers
130+
schedulers = (
131+
[scheduler_model, scheduler_discriminator]
132+
if any(
133+
s is not None
134+
for s in (scheduler_model, scheduler_discriminator)
135+
)
136+
else None
137+
)
138+
139+
# Initialize the base solver
140+
MultiModelSolver.__init__(
141+
self,
142+
problem=problem,
143+
models=[model, discriminator],
144+
optimizers=optimizers,
145+
schedulers=schedulers,
146+
weighting=weighting,
147+
loss=loss,
148+
use_lt=True,
149+
)
150+
151+
def training_step(self, batch, batch_idx):
152+
"""
153+
Solver training step.
154+
155+
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
156+
tuple containing a condition name and a dictionary of points.
157+
:param int batch_idx: The index of the current batch.
158+
:return: The loss of the training step.
159+
:rtype: torch.Tensor
160+
"""
161+
# Zero the gradients of the model optimizer and compute the loss
162+
self.optimizer_model.instance.zero_grad()
163+
loss = self.batch_evaluation_step(batch, batch_idx)
164+
165+
# Perform the backward pass and complete a step for the model
166+
self.manual_backward(loss)
167+
self.optimizer_model.instance.step()
168+
self.scheduler_model.instance.step()
169+
170+
# Zero the gradients of the discriminator optimizer and compute the loss
171+
self.optimizer_discriminator.instance.zero_grad()
172+
loss = self.batch_evaluation_step(batch, batch_idx)
173+
174+
# Perform the backward pass and complete a step for the discriminator
175+
self.manual_backward(-loss)
176+
self.optimizer_discriminator.instance.step()
177+
self.scheduler_discriminator.instance.step()
178+
179+
# Log the training loss
180+
self.log(
181+
name="train_loss",
182+
value=loss.item(),
183+
batch_size=self.get_batch_size(batch),
184+
**self.trainer.logging_kwargs,
185+
)
186+
187+
return loss
188+
189+
def forward(self, x):
190+
"""
191+
Forward pass through the model.
192+
193+
:param x: The input data.
194+
:type x: torch.Tensor | LabelTensor | Data | Graph
195+
:return: The output of the model.
196+
:rtype: torch.Tensor | LabelTensor | Data | Graph
197+
"""
198+
return self.model(x)
199+
200+
def _compute_condition_loss(self, condition, data, batch_idx):
201+
"""
202+
Compute the scalar loss for a given condition and its data.
203+
204+
:param BaseCondition condition: The condition for which to compute the
205+
loss.
206+
:param dict data: The data corresponding to the condition.
207+
:param int batch_idx: The index of the current batch.
208+
:return: The scalar loss for the condition.
209+
:rtype: torch.Tensor
210+
"""
211+
# Clone the input tensor if it exists to avoid in-place modifications
212+
if "input" in data and hasattr(data["input"], "clone"):
213+
data = dict(data)
214+
data["input"] = data["input"].clone()
215+
216+
# Compute and store the residual tensor for the condition
217+
self.residual_tensor = condition.evaluate(data, self)
218+
219+
# Compute the discriminator bets for the current condition
220+
discriminator_input = data["input"][self.problem.input_variables]
221+
discriminator_bets = self.discriminator(discriminator_input)
222+
223+
# Weight the residual tensor using the discriminator bets
224+
self.residual_tensor = self.residual_tensor * discriminator_bets
225+
226+
# Retrieve condition name for more complex weighting schemes
227+
condition_name = condition.name if hasattr(condition, "name") else None
228+
229+
# Compute the tensor loss from the residual tensor
230+
condition_tensor_loss = self._loss_from_residual(condition_name)
231+
232+
# Compute the scalar loss from the tensor loss and return it
233+
condition_scalar_loss = self._apply_reduction(condition_tensor_loss)
234+
235+
return condition_scalar_loss
236+
237+
@property
238+
def model(self):
239+
"""
240+
The single model used by the solver.
241+
242+
:return: The single model used by the solver.
243+
:rtype: torch.nn.Module
244+
"""
245+
return self._pina_models[0]
246+
247+
@property
248+
def discriminator(self):
249+
"""
250+
The discriminator used by the solver.
251+
252+
:return: The discriminator used by the solver.
253+
:rtype: torch.nn.Module
254+
"""
255+
return self._pina_models[1]
256+
257+
@property
258+
def optimizer_model(self):
259+
"""
260+
The optimizer for the model used by the solver.
261+
262+
:return: The optimizer for the model used by the solver.
263+
:rtype: TorchOptimizer
264+
"""
265+
return self.optimizers[0]
266+
267+
@property
268+
def optimizer_discriminator(self):
269+
"""
270+
The optimizer for the discriminator used by the solver.
271+
272+
:return: The optimizer for the discriminator used by the solver.
273+
:rtype: TorchOptimizer
274+
"""
275+
return self.optimizers[1]
276+
277+
@property
278+
def scheduler_model(self):
279+
"""
280+
The scheduler for the model used by the solver.
281+
282+
:return: The scheduler for the model used by the solver.
283+
:rtype: TorchScheduler
284+
"""
285+
return self.schedulers[0]
286+
287+
@property
288+
def scheduler_discriminator(self):
289+
"""
290+
The scheduler for the discriminator used by the solver.
291+
292+
:return: The scheduler for the discriminator used by the solver.
293+
:rtype: TorchScheduler
294+
"""
295+
return self.schedulers[1]

pina/solver/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"AutoregressiveSingleModelSolver",
1414
"AutoregressiveEnsembleSolver",
1515
"SelfAdaptivePhysicsInformedSolver",
16+
"CompetitivePhysicsInformedSolver",
1617
]
1718

1819

@@ -40,3 +41,6 @@
4041
from pina._src.solver.self_adaptive_physics_informed_solver import (
4142
SelfAdaptivePhysicsInformedSolver,
4243
)
44+
from pina._src.solver.competitive_physics_informed_solver import (
45+
CompetitivePhysicsInformedSolver,
46+
)

tests/test_callback/test_metric_tracker.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,5 @@ def test_routine(metrics_to_track, batch_size):
6767
for metric in expected_metrics
6868
for suffix in ("step", "epoch")
6969
]
70-
print(f"Logged metrics: {logged_metrics}")
71-
print(f"Expected metrics: {expected_metrics}")
70+
7271
assert sorted(logged_metrics) == sorted(expected_metrics)

0 commit comments

Comments
 (0)