Skip to content

Commit f375911

Browse files
implement mixin logic for solvers
1 parent 990030d commit f375911

57 files changed

Lines changed: 2374 additions & 2529 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

pina/_src/callback/refinement/base_refinement.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Module for the Base Refinement class."""
22

3-
from pina._src.solver.pinn import PINN
3+
from pina._src.solver.physics_informed_single_model_solver import (
4+
PhysicsInformedSingleModelSolver,
5+
)
46
from lightning.pytorch import Callback
57
from pina._src.core.utils import check_consistency, check_positive_integer
68
from pina._src.callback.refinement.refinement_interface import (
@@ -65,7 +67,7 @@ def on_train_start(self, trainer, solver):
6567
'domain' attribute for sampling.
6668
"""
6769
# Check solver consistency
68-
if not isinstance(solver, PINN):
70+
if not isinstance(solver, PhysicsInformedSingleModelSolver):
6971
raise RuntimeError(
7072
"Refinement strategies require a physics-informed solver. "
7173
f"Got '{type(solver).__name__}'."

pina/_src/core/trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import warnings
55
import torch
66
import lightning
7+
from pina._src.solver.mixin.physics_informed_mixin import _PhysicsInformedMixin
78
from pina._src.solver.base_solver import BaseSolver
89
from pina._src.data.data_module import DataModule
9-
from pina._src.solver.pinn import PINN
1010
from pina._src.core.utils import (
1111
check_consistency,
1212
custom_warning_format,
@@ -132,8 +132,8 @@ def __init__(
132132
f"Expected one of: {sorted(self._AVAIL_BATCHING_MODES)}."
133133
)
134134

135-
# Set inference mode to false for PINN solvers to track gradients
136-
if isinstance(solver, PINN):
135+
# Set inference mode to false when usiing physics-informed mixin
136+
if isinstance(solver, _PhysicsInformedMixin):
137137
kwargs["inference_mode"] = False
138138

139139
# Set log_every_n_steps to 0 if batch_size is None, otherwise default

pina/_src/solver/__init__.py

Whitespace-only changes.
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from pina._src.solver.mixin.autoregressive_mixin import _AutoregressiveMixin
2+
from pina._src.condition.time_series_condition import TimeSeriesCondition
3+
from pina._src.solver.ensemble_solver import EnsembleSolver
4+
5+
6+
class AutoregressiveEnsembleSolver(_AutoregressiveMixin, EnsembleSolver):
7+
"""
8+
Ensemble solver specialized for autoregressive conditions.
9+
"""
10+
11+
# Accepted conditions types for this solver
12+
accepted_conditions_types = (TimeSeriesCondition,)
13+
14+
def __init__(
15+
self,
16+
problem,
17+
models,
18+
optimizers=None,
19+
schedulers=None,
20+
weighting=None,
21+
loss=None,
22+
use_lt=True,
23+
eps=0.0,
24+
reset_weights_at_epoch_start=True,
25+
kwargs=None,
26+
):
27+
"""
28+
Initialization of the :class:`AutoregressiveEnsembleSolver` class.
29+
"""
30+
# Initialize the parent class
31+
EnsembleSolver.__init__(
32+
self,
33+
problem=problem,
34+
models=models,
35+
optimizers=optimizers,
36+
schedulers=schedulers,
37+
weighting=weighting,
38+
loss=loss,
39+
use_lt=use_lt,
40+
)
41+
42+
# Initialize the autoregressive components
43+
self._init_autoregressive_components(
44+
eps=eps,
45+
reset_weights_at_epoch_start=reset_weights_at_epoch_start,
46+
kwargs=kwargs,
47+
)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from pina._src.solver.mixin.autoregressive_mixin import _AutoregressiveMixin
2+
from pina._src.condition.time_series_condition import TimeSeriesCondition
3+
from pina._src.solver.single_model_solver import SingleModelSolver
4+
5+
6+
class AutoregressiveSingleModelSolver(_AutoregressiveMixin, SingleModelSolver):
7+
r"""
8+
The autoregressive solver for learning dynamical systems.
9+
"""
10+
11+
# Accepted conditions types for this solver
12+
accepted_conditions_types = (TimeSeriesCondition,)
13+
14+
def __init__(
15+
self,
16+
problem,
17+
model,
18+
loss=None,
19+
optimizer=None,
20+
scheduler=None,
21+
weighting=None,
22+
use_lt=False,
23+
eps=0.0,
24+
reset_weights_at_epoch_start=True,
25+
kwargs=None,
26+
):
27+
"""
28+
Initialization of the :class:`AutoregressiveSingleModelSolver` class.
29+
"""
30+
31+
# Initialize the parent class
32+
SingleModelSolver.__init__(
33+
self,
34+
problem=problem,
35+
model=model,
36+
optimizer=optimizer,
37+
scheduler=scheduler,
38+
weighting=weighting,
39+
loss=loss,
40+
use_lt=use_lt,
41+
)
42+
43+
# Initialize the autoregressive components
44+
self._init_autoregressive_components(
45+
eps=eps,
46+
reset_weights_at_epoch_start=reset_weights_at_epoch_start,
47+
kwargs=kwargs,
48+
)

0 commit comments

Comments
 (0)