Skip to content

Commit e56ebdf

Browse files
committed
multi model
1 parent 31e2c03 commit e56ebdf

2 files changed

Lines changed: 267 additions & 0 deletions

File tree

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
"""Module for the MultiModelSimpleSolver."""
2+
3+
import torch
4+
from torch.nn.modules.loss import _Loss
5+
6+
from pina._src.condition.domain_equation_condition import (
7+
DomainEquationCondition,
8+
)
9+
from pina._src.condition.input_equation_condition import (
10+
InputEquationCondition,
11+
)
12+
from pina._src.condition.input_target_condition import InputTargetCondition
13+
from pina._src.core.utils import check_consistency
14+
from pina._src.loss.loss_interface import LossInterface
15+
from pina._src.solver.solver import MultiSolverInterface
16+
17+
18+
class MultiModelSimpleSolver(MultiSolverInterface):
    """
    Minimal multi-model solver with explicit residual evaluation, reduction,
    and loss aggregation across conditions.

    The solver orchestrates a uniform workflow for all conditions in the batch.
    Each model in the ensemble contributes its own forward pass independently,
    and the outputs are stacked along ``ensemble_dim``:

    .. math::
        \\hat{\\mathbf{u}}_i = \\mathcal{M}_i(\\mathbf{s}),
        \\quad i = 1, \\dots, N_{\\rm ensemble}

    During the optimization cycle each model's prediction is evaluated against
    the condition independently, and the resulting per-model losses are
    averaged to form the aggregated condition loss:

    .. math::
        \\mathcal{L}_{\\rm condition} = \\frac{1}{N_{\\rm ensemble}}
        \\sum_{i=1}^{N_{\\rm ensemble}} \\mathcal{L}_i

    The per-condition workflow is:

    1. evaluate the condition for each model and obtain non-aggregated
       loss tensors;
    2. apply the configured reduction to each per-model tensor;
    3. average the reduced per-model losses into a single scalar for
       the condition;
    4. return the per-condition losses, which are aggregated by the
       inherited solver machinery through the configured weighting.
    """

    accepted_conditions_types = (
        InputTargetCondition,
        InputEquationCondition,
        DomainEquationCondition,
    )

    def __init__(
        self,
        problem,
        models,
        optimizers=None,
        schedulers=None,
        weighting=None,
        loss=None,
        use_lt=True,
        ensemble_dim=0,
    ):
        """
        Initialize the multi-model simple solver.

        :param AbstractProblem problem: The problem to be solved.
        :param list[torch.nn.Module] models: The neural network models to be
            used. Must be a list or tuple with at least two models.
        :param list[Optimizer] optimizers: The optimizers to be used.
            If ``None``, the :class:`torch.optim.Adam` optimizer is used for
            each model. Default is ``None``.
        :param list[Scheduler] schedulers: The learning rate schedulers.
            If ``None``, :class:`torch.optim.lr_scheduler.ConstantLR` is used
            for each model. Default is ``None``.
        :param WeightingInterface weighting: The weighting schema to be used.
            If ``None``, no weighting schema is used. Default is ``None``.
        :param torch.nn.Module loss: The element-wise loss module whose
            reduction strategy is reused by the solver. If ``None``,
            :class:`torch.nn.MSELoss` is used. Default is ``None``.
            The ``reduction`` attribute of the passed object, when present,
            is overwritten in place with ``"none"``; the original value is
            remembered and applied by the solver itself.
        :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
            Default is ``True``.
        :param int ensemble_dim: The dimension along which the per-model
            outputs are stacked in :meth:`forward`. Default is ``0``.
        """
        if loss is None:
            loss = torch.nn.MSELoss()

        check_consistency(loss, (LossInterface, _Loss), subclass=False)
        check_consistency(ensemble_dim, int)

        super().__init__(
            problem=problem,
            models=models,
            optimizers=optimizers,
            schedulers=schedulers,
            weighting=weighting,
            use_lt=use_lt,
        )

        self._loss_fn = loss
        # Remember the user-requested reduction before it is switched off
        # on the module; falls back to "mean" when the loss exposes none.
        self._reduction = getattr(loss, "reduction", "mean")
        self._ensemble_dim = ensemble_dim

        # The loss module must return element-wise values so that the
        # reduction can be applied explicitly in ``_apply_reduction``.
        if hasattr(self._loss_fn, "reduction"):
            self._loss_fn.reduction = "none"

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    def forward(self, x, model_idx=None):
        """
        Forward pass through the ensemble models.

        If ``model_idx`` is provided, returns the output of the single model
        at that index. Otherwise stacks the outputs of all models along
        ``ensemble_dim``.

        :param LabelTensor x: The input tensor to the models.
        :param int model_idx: Optional index to select a specific model from
            the ensemble. If ``None`` results for all models are stacked in
            the ``ensemble_dim`` dimension. Default is ``None``.
        :return: The output of the selected model, or the stacked outputs from
            all models.
        :rtype: LabelTensor | torch.Tensor
        """
        if model_idx is not None:
            return self.models[model_idx].forward(x)
        return torch.stack(
            [self.forward(x, idx) for idx in range(self.num_models)],
            dim=self._ensemble_dim,
        )

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    def training_step(self, batch):
        """
        Training step for the solver, overridden for manual optimization.

        Zeroes the gradients of every optimizer, delegates the loss
        computation to the parent ``training_step``, applies manual backward
        propagation and then steps each optimizer and its paired scheduler.

        :param list[tuple[str, dict]] batch: A batch of training data. Each
            element is a tuple containing a condition name and a dictionary of
            points.
        :return: The aggregated loss after the training step.
        :rtype: torch.Tensor
        """
        # zero grad for all optimizers
        for opt in self.optimizers:
            opt.instance.zero_grad()
        # compute condition losses (calls optimization_cycle internally via
        # the parent training_step)
        loss = super().training_step(batch)
        # backpropagate
        self.manual_backward(loss)
        # optimizer + scheduler step for each model
        for opt, sched in zip(self.optimizers, self.schedulers):
            opt.instance.step()
            sched.instance.step()
        return loss

    def optimization_cycle(self, batch):
        """
        Compute one reduced, ensemble-averaged loss per condition in the batch.

        For each condition the method evaluates every model independently,
        applies the configured reduction to each result and averages the
        per-model values.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
            tuple containing a condition name and a dictionary of points.
        :return: The reduced, ensemble-averaged losses for all conditions.
        :rtype: dict[str, torch.Tensor]
        """
        condition_losses = {}

        for condition_name, data in batch:
            condition = self.problem.conditions[condition_name]
            condition_data = dict(data)

            # Evaluate each model independently and average the losses.
            per_model_losses = [
                self._apply_reduction(
                    self._evaluate_single_model(
                        condition, condition_data, idx
                    )
                )
                for idx in range(self.num_models)
            ]

            condition_losses[condition_name] = torch.stack(
                per_model_losses
            ).mean()

        return condition_losses

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _evaluate_single_model(self, condition, condition_data, model_idx):
        """
        Evaluate a condition using only one model of the ensemble.

        ``self.forward`` is temporarily replaced by a function that routes
        the input through the selected model only, so that
        ``condition.evaluate`` sees just that model. The previous ``forward``
        is restored even if the evaluation raises, so a failing condition
        cannot leave the solver stuck on a single model.

        :param condition: The condition to evaluate.
        :param dict condition_data: The data associated with the condition.
        :param int model_idx: Index of the model to expose through
            ``forward``.
        :return: The non-aggregated tensor returned by the condition.
        :rtype: torch.Tensor
        """

        def single_model_forward(x):
            return self.models[model_idx].forward(x)

        original_forward = self.forward
        self.forward = single_model_forward
        try:
            return condition.evaluate(condition_data, self, self._loss_fn)
        finally:
            self.forward = original_forward

    def _apply_reduction(self, value):
        """
        Apply the configured reduction to a non-aggregated condition tensor.

        :param value: The non-aggregated tensor returned by a condition.
        :type value: torch.Tensor
        :return: The reduced scalar tensor for ``"mean"`` and ``"sum"``; the
            input tensor unchanged for ``"none"``.
        :rtype: torch.Tensor
        :raises ValueError: If the reduction is not supported.
        """
        if self._reduction == "none":
            return value
        if self._reduction == "mean":
            return value.mean()
        if self._reduction == "sum":
            return value.sum()
        raise ValueError(f"Unsupported reduction '{self._reduction}'.")

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def loss(self):
        """
        The underlying element-wise loss module.

        :return: The stored loss module.
        :rtype: torch.nn.Module
        """
        return self._loss_fn

    @property
    def ensemble_dim(self):
        """
        The dimension along which the per-model outputs are stacked.

        :return: The ensemble dimension.
        :rtype: int
        """
        return self._ensemble_dim

    @property
    def num_models(self):
        """
        The number of models in the ensemble.

        :return: The number of models.
        :rtype: int
        """
        return len(self.models)

pina/solver/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"SingleSolverInterface",
1515
"MultiSolverInterface",
1616
"SingleModelSimpleSolver",
17+
"MultiModelSimpleSolver",
1718
"PINNInterface",
1819
"PINN",
1920
"GradientPINN",
@@ -39,6 +40,9 @@
3940
from pina._src.solver.single_model_simple_solver import (
4041
SingleModelSimpleSolver,
4142
)
43+
from pina._src.solver.multi_model_simple_solver import (
44+
MultiModelSimpleSolver,
45+
)
4246
from pina._src.solver.pinn import PINNInterface, PINN
4347
from pina._src.solver.physics_informed_solver.gradient_pinn import GradientPINN
4448
from pina._src.solver.physics_informed_solver.causal_pinn import CausalPINN

0 commit comments

Comments
 (0)