Skip to content

Commit ee11345

Browse files
committed
Refactoring solvers
* Simplify logic compile * Improve and update doc * Create SupervisedSolverInterface * Specialize SupervisedSolver and ReducedOrderModelSolver * Create EnsembleSolverInterface + EnsembleSupervisedSolver * Create tests ensemble solvers
1 parent 4e16d0a commit ee11345

30 files changed

Lines changed: 1433 additions & 463 deletions

docs/source/_rst/_code.rst

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,19 @@ Solvers
6868
SolverInterface <solver/solver_interface.rst>
6969
SingleSolverInterface <solver/single_solver_interface.rst>
7070
MultiSolverInterface <solver/multi_solver_interface.rst>
71+
SupervisedSolverInterface <solver/supervised_solver/supervised_solver_interface>
72+
DeepEnsembleSolverInterface <solver/ensemble_solver/ensemble_solver_interface>
7173
PINNInterface <solver/physics_informed_solver/pinn_interface.rst>
7274
PINN <solver/physics_informed_solver/pinn.rst>
7375
GradientPINN <solver/physics_informed_solver/gradient_pinn.rst>
7476
CausalPINN <solver/physics_informed_solver/causal_pinn.rst>
7577
CompetitivePINN <solver/physics_informed_solver/competitive_pinn.rst>
7678
SelfAdaptivePINN <solver/physics_informed_solver/self_adaptive_pinn.rst>
7779
RBAPINN <solver/physics_informed_solver/rba_pinn.rst>
78-
SupervisedSolver <solver/supervised.rst>
79-
ReducedOrderModelSolver <solver/reduced_order_model.rst>
80+
DeepEnsemblePINN <solver/ensemble_solver/ensemble_pinn>
81+
SupervisedSolver <solver/supervised_solver/supervised.rst>
82+
DeepEnsembleSupervisedSolver <solver/ensemble_solver/ensemble_supervised>
83+
ReducedOrderModelSolver <solver/supervised_solver/reduced_order_model.rst>
8084
GAROM <solver/garom.rst>
8185

8286

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
DeepEnsemblePINN
2+
==================
3+
.. currentmodule:: pina.solver.ensemble_solver.ensemble_pinn
4+
5+
.. autoclass:: DeepEnsemblePINN
6+
:show-inheritance:
7+
:members:
8+
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
DeepEnsembleSolverInterface
2+
=============================
3+
.. currentmodule:: pina.solver.ensemble_solver.ensemble_solver_interface
4+
5+
.. autoclass:: DeepEnsembleSolverInterface
6+
:show-inheritance:
7+
:members:
8+
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
DeepEnsembleSupervisedSolver
2+
=============================
3+
.. currentmodule:: pina.solver.ensemble_solver.ensemble_supervised
4+
5+
.. autoclass:: DeepEnsembleSupervisedSolver
6+
:show-inheritance:
7+
:members:
8+

docs/source/_rst/solver/reduced_order_model.rst renamed to docs/source/_rst/solver/supervised_solver/reduced_order_model.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
ReducedOrderModelSolver
22
==========================
3-
.. currentmodule:: pina.solver.reduced_order_model
3+
.. currentmodule:: pina.solver.supervised_solver.reduced_order_model
44

55
.. autoclass:: ReducedOrderModelSolver
66
:members:

docs/source/_rst/solver/supervised.rst renamed to docs/source/_rst/solver/supervised_solver/supervised.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
SupervisedSolver
22
===================
3-
.. currentmodule:: pina.solver.supervised
3+
.. currentmodule:: pina.solver.supervised_solver.supervised
44

55
.. autoclass:: SupervisedSolver
66
:members:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
SupervisedSolverInterface
2+
==========================
3+
.. currentmodule:: pina.solver.supervised_solver.supervised_solver_interface
4+
5+
.. autoclass:: SupervisedSolverInterface
6+
:show-inheritance:
7+
:members:
8+

pina/solver/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@
1313
"RBAPINN",
1414
"SupervisedSolver",
1515
"ReducedOrderModelSolver",
16+
"DeepEnsembleSolverInterface",
17+
"DeepEnsembleSupervisedSolver",
18+
"DeepEnsemblePINN",
1619
"GAROM",
1720
]
1821

1922
from .solver import SolverInterface, SingleSolverInterface, MultiSolverInterface
2023
from .physics_informed_solver import *
21-
from .supervised import SupervisedSolver
22-
from .reduced_order_model import ReducedOrderModelSolver
24+
from .supervised_solver import *
25+
from .ensemble_solver import *
2326
from .garom import GAROM
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Module for the Ensemble solver classes."""
2+
3+
__all__ = [
4+
"DeepEnsembleSolverInterface",
5+
"DeepEnsembleSupervisedSolver",
6+
"DeepEnsemblePINN",
7+
]
8+
9+
from .ensemble_solver_interface import DeepEnsembleSolverInterface
10+
from .ensemble_supervised import DeepEnsembleSupervisedSolver
11+
from .ensemble_pinn import DeepEnsemblePINN
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
"""Module for the DeepEnsemble physics solver."""
2+
3+
import torch
4+
5+
from .ensemble_solver_interface import DeepEnsembleSolverInterface
6+
from ..physics_informed_solver import PINNInterface
7+
from ...problem import InverseProblem
8+
9+
10+
class DeepEnsemblePINN(PINNInterface, DeepEnsembleSolverInterface):
11+
r"""
12+
Deep Ensemble Physics Informed Solver class. This class implements a
13+
Deep Ensemble for Physics Informed Neural Networks using user
14+
specified ``model``s to solve a specific ``problem``.
15+
16+
An ensemble model is constructed by combining multiple models that solve
17+
the same type of problem. Mathematically, this creates an implicit
18+
distribution :math:`p(\mathbf{u} \mid \mathbf{s})` over the possible
19+
outputs :math:`\mathbf{u}`, given the original input :math:`\mathbf{s}`.
20+
The models :math:`\mathcal{M}_{i\in (1,\dots,r)}` in
21+
the ensemble work collaboratively to capture different
22+
aspects of the data or task, with each model contributing a distinct
23+
prediction :math:`\mathbf{y}_{i}=\mathcal{M}_i(\mathbf{u} \mid \mathbf{s})`.
24+
By aggregating these predictions, the ensemble
25+
model can achieve greater robustness and accuracy compared to individual
26+
models, leveraging the diversity of the models to reduce overfitting and
27+
improve generalization. Furthemore, statistical metrics can
28+
be computed, e.g. the ensemble mean and variance:
29+
30+
.. math::
31+
\mathbf{\mu} = \frac{1}{N}\sum_{i=1}^r \mathbf{y}_{i}
32+
33+
.. math::
34+
\mathbf{\sigma^2} = \frac{1}{N}\sum_{i=1}^r
35+
(\mathbf{y}_{i} - \mathbf{\mu})^2
36+
37+
During training the PINN loss is minimized by each ensemble model:
38+
39+
.. math::
40+
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^4
41+
\mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i)) +
42+
\frac{1}{N}\sum_{i=1}^N
43+
\mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i)),
44+
45+
for the differential system:
46+
47+
.. math::
48+
49+
\begin{cases}
50+
\mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\
51+
\mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad,
52+
\mathbf{x}\in\partial\Omega
53+
\end{cases}
54+
55+
:math:`\mathcal{L}` indicates a specific loss function, typically the MSE:
56+
57+
.. math::
58+
\mathcal{L}(v) = \| v \|^2_2.
59+
60+
.. seealso::
61+
62+
**Original reference**: Zou, Z., Wang, Z., & Karniadakis, G. E. (2025).
63+
*Learning and discovering multiple solutions using physics-informed
64+
neural networks with random initialization and deep ensemble*.
65+
DOI: `arXiv:2503.06320 <https://arxiv.org/abs/2503.06320>`_.
66+
67+
.. warning::
68+
This solver does not work with inverse problem. Hence in the ``problem``
69+
definition must not inherit from
70+
:class:`~pina.problem.inverse_problem.InverseProblem`.
71+
"""
72+
73+
def __init__(
74+
self,
75+
problem,
76+
models,
77+
loss=None,
78+
optimizers=None,
79+
schedulers=None,
80+
weighting=None,
81+
ensemble_dim=0,
82+
):
83+
"""
84+
Initialization of the :class:`DeepEnsemblePINN` class.
85+
86+
:param AbstractProblem problem: The problem to be solved.
87+
:param torch.nn.Module models: The neural network models to be used.
88+
:param torch.nn.Module loss: The loss function to be minimized.
89+
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
90+
Default is ``None``.
91+
:param Optimizer optimizer: The optimizer to be used.
92+
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
93+
Default is ``None``.
94+
:param Scheduler scheduler: Learning rate scheduler.
95+
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
96+
scheduler is used. Default is ``None``.
97+
:param WeightingInterface weighting: The weighting schema to be used.
98+
If ``None``, no weighting schema is used. Default is ``None``.
99+
:param int ensemble_dim: The dimension along which the ensemble
100+
outputs are stacked. Default is 0.
101+
"""
102+
if isinstance(problem, InverseProblem):
103+
raise NotImplementedError(
104+
"DeepEnsemblePINN does not work on inverse problems."
105+
)
106+
super().__init__(
107+
problem=problem,
108+
models=models,
109+
loss=loss,
110+
optimizers=optimizers,
111+
schedulers=schedulers,
112+
weighting=weighting,
113+
ensemble_dim=ensemble_dim,
114+
)
115+
116+
def loss_data(self, input, target):
117+
"""
118+
Compute the data loss for the ensemble PINN solver by evaluating
119+
the loss between the network's output and the true solution for each
120+
model. This method should not be overridden, if not intentionally.
121+
122+
:param input: The input to the neural network.
123+
:type input: LabelTensor | torch.Tensor | Graph | Data
124+
:param target: The target to compare with the network's output.
125+
:type target: LabelTensor | torch.Tensor | Graph | Data
126+
:return: The supervised loss, averaged over the number of observations.
127+
:rtype: torch.Tensor
128+
"""
129+
loss = sum(
130+
self.loss(self.forward(input, idx), target)
131+
for idx in range(self.num_ensembles)
132+
)
133+
return loss / self.num_ensembles
134+
135+
def loss_phys(self, samples, equation):
136+
"""
137+
Computes the physics loss for the ensemble PINN solver by evaluating
138+
the loss between the network's output and the true solution for each
139+
model. This method should not be overridden, if not intentionally.
140+
141+
:param LabelTensor samples: The samples to evaluate the physics loss.
142+
:param EquationInterface equation: The governing equation.
143+
:return: The computed physics loss.
144+
:rtype: LabelTensor
145+
"""
146+
return self._residual_loss(samples, equation)
147+
148+
def _residual_loss(self, samples, equation):
149+
"""
150+
Computes the physics loss for the physics-informed solver based on the
151+
provided samples and equation. This method should never be overridden
152+
by the user, if not intentionally,
153+
since it is used internally to compute validation loss. It overrides the
154+
:obj:`~pina.solver.physics_informed_solver.PINNInterface._residual_loss`
155+
method.
156+
157+
:param LabelTensor samples: The samples to evaluate the loss.
158+
:param EquationInterface equation: The governing equation.
159+
:return: The residual loss.
160+
:rtype: torch.Tensor
161+
"""
162+
loss = 0
163+
for idx in range(self.num_ensembles):
164+
residuals = equation.residual(samples, self.forward(samples, idx))
165+
loss = loss + self.loss(residuals, torch.zeros_like(residuals))
166+
return loss / self.num_ensembles

0 commit comments

Comments
 (0)