Skip to content

Commit fed952f

Browse files
committed
ensemble solver
1 parent e56ebdf commit fed952f

25 files changed

Lines changed: 832 additions & 1866 deletions

pina/_src/callback/refinement/r3_refinement.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
)
77
from pina._src.core.label_tensor import LabelTensor
88
from pina._src.core.utils import check_consistency
9-
from pina._src.loss.loss_interface import LossInterface
9+
from pina._src.loss.loss_interface import DualLossInterface as LossInterface
1010

1111

1212
class R3Refinement(RefinementInterface):

pina/_src/callback/refinement/refinement_interface.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
from abc import ABCMeta, abstractmethod
77
from lightning.pytorch import Callback
88
from pina._src.core.utils import check_consistency
9-
from pina._src.solver.physics_informed_solver.pinn_interface import (
10-
PINNInterface,
11-
)
9+
from pina._src.solver.pinn import PINN as PINNInterface
1210

1311

1412
class RefinementInterface(Callback, metaclass=ABCMeta):

pina/_src/core/trainer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
import lightning
77
from pina._src.core.utils import check_consistency, custom_warning_format
88
from pina._src.data.data_module import PinaDataModule
9-
from pina._src.solver.supervised_solver.supervised_solver_interface import (
9+
from pina._src.solver.solver_interface import (
1010
SolverInterface,
1111
)
12-
from pina._src.solver.physics_informed_solver.pinn_interface import (
13-
PINNInterface,
14-
)
12+
# from pina._src.solver.physics_informed_solver.pinn_interface import (
13+
# PINNInterface,
14+
# )
1515

1616
# set the warning for compile options
1717
warnings.formatwarning = custom_warning_format

pina/_src/loss/loss_interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
import torch
66

77

8-
class LossInterface(_Loss, metaclass=ABCMeta):
8+
class DualLossInterface(_Loss, metaclass=ABCMeta):
99
"""
1010
Abstract base class for all losses. All classes defining a loss function
1111
should inherit from this interface.
1212
"""
1313

1414
def __init__(self, reduction="mean"):
1515
"""
16-
Initialization of the :class:`LossInterface` class.
16+
Initialization of the :class:`DualLossInterface` class.
1717
1818
:param str reduction: The reduction method for the loss.
1919
Available options: ``none``, ``mean``, ``sum``.
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""Module for the DeepEnsemble simple solver."""
2+
3+
from pina._src.solver.multi_model_simple_solver import MultiModelSimpleSolver
4+
5+
6+
class DeepEnsembleSimpleSolver(MultiModelSimpleSolver):
7+
r"""
8+
Deep Ensemble Simple Solver class. This class implements a Deep Ensemble
9+
solver for generic conditions (data, equations, or domain residuals) using
10+
user-specified ``models`` to solve a specific ``problem``.
11+
12+
It is the ensemble counterpart of
13+
:class:`~pina.solver.SingleModelSimpleSolver`: each model in the ensemble
14+
evaluates every condition independently, and the per-model scalar losses
15+
are averaged to produce the final condition loss.
16+
17+
An ensemble model is constructed by combining multiple models that solve
18+
the same type of problem. Mathematically, this creates an implicit
19+
distribution :math:`p(\mathbf{u} \mid \mathbf{s})` over the possible
20+
outputs :math:`\mathbf{u}`, given the original input :math:`\mathbf{s}`.
21+
The models :math:`\mathcal{M}_{i\in (1,\dots,r)}` in
22+
the ensemble work collaboratively to capture different
23+
aspects of the data or task, with each model contributing a distinct
24+
prediction
25+
:math:`\mathbf{y}_{i}=\mathcal{M}_i(\mathbf{u} \mid \mathbf{s})`.
26+
By aggregating these predictions, the ensemble
27+
model can achieve greater robustness and accuracy compared to individual
28+
models, leveraging the diversity of the models to reduce overfitting and
29+
improve generalization. Furthemore, statistical metrics can
30+
be computed, e.g. the ensemble mean and variance:
31+
32+
.. math::
33+
\mathbf{\mu} = \frac{1}{N}\sum_{i=1}^r \mathbf{y}_{i}
34+
35+
.. math::
36+
\mathbf{\sigma^2} = \frac{1}{N}\sum_{i=1}^r
37+
(\mathbf{y}_{i} - \mathbf{\mu})^2
38+
39+
During training the condition loss is minimised by each ensemble model
40+
independently and then averaged:
41+
42+
.. math::
43+
\mathcal{L}_{\rm{condition}} = \frac{1}{N_{\rm{ensemble}}}
44+
\sum_{i=1}^{N_{\rm{ensemble}}}
45+
\mathcal{L}_i(\mathcal{M}_i, \mathbf{s})
46+
47+
where :math:`\mathcal{L}` is a specific loss function, typically the MSE:
48+
49+
.. math::
50+
\mathcal{L}(v) = \| v \|^2_2.
51+
52+
.. seealso::
53+
54+
**Original reference**: Lakshminarayanan, B., Pritzel, A., & Blundell,
55+
C. (2017). *Simple and scalable predictive uncertainty estimation
56+
using deep ensembles*. Advances in neural information
57+
processing systems, 30.
58+
DOI: `arXiv:1612.01474 <https://arxiv.org/abs/1612.01474>`_.
59+
"""
60+
61+
def __init__(
62+
self,
63+
problem,
64+
models,
65+
optimizers=None,
66+
schedulers=None,
67+
weighting=None,
68+
loss=None,
69+
use_lt=True,
70+
ensemble_dim=0,
71+
):
72+
"""
73+
Initialization of the :class:`DeepEnsembleSimpleSolver` class.
74+
75+
:param AbstractProblem problem: The problem to be solved.
76+
:param list[torch.nn.Module] models: The neural network models to be
77+
used. Must be a list or tuple with at least two models.
78+
:param list[Optimizer] optimizers: The optimizers to be used.
79+
If ``None``, the :class:`torch.optim.Adam` optimizer is used for
80+
each model. Default is ``None``.
81+
:param list[Scheduler] schedulers: The learning rate schedulers.
82+
If ``None``, :class:`torch.optim.lr_scheduler.ConstantLR` is used
83+
for each model. Default is ``None``.
84+
:param WeightingInterface weighting: The weighting schema to be used.
85+
If ``None``, no weighting schema is used. Default is ``None``.
86+
:param torch.nn.Module loss: The element-wise loss module.
87+
If ``None``, :class:`torch.nn.MSELoss` is used. Default is
88+
``None``.
89+
:param bool use_lt: If ``True``, the solver uses LabelTensors as
90+
input. Default is ``True``.
91+
:param int ensemble_dim: The dimension along which the per-model
92+
outputs are stacked in :meth:`forward`. Default is ``0``.
93+
"""
94+
super().__init__(
95+
problem=problem,
96+
models=models,
97+
optimizers=optimizers,
98+
schedulers=schedulers,
99+
weighting=weighting,
100+
loss=loss,
101+
use_lt=use_lt,
102+
ensemble_dim=ensemble_dim,
103+
)

pina/_src/solver/ensemble_solver/__init__.py

Whitespace-only changes.

pina/_src/solver/ensemble_solver/ensemble_pinn.py

Lines changed: 0 additions & 174 deletions
This file was deleted.

0 commit comments

Comments
 (0)