Skip to content

Commit 746c2f3

Browse files
committed
formatter
1 parent ee11345 commit 746c2f3

9 files changed

Lines changed: 35 additions & 25 deletions

File tree

pina/solver/ensemble_solver/ensemble_pinn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def _residual_loss(self, samples, equation):
150150
Computes the physics loss for the physics-informed solver based on the
151151
provided samples and equation. This method should never be overridden
152152
by the user, if not intentionally,
153-
since it is used internally to compute validation loss. It overrides the
153+
since it is used internally to compute validation loss. It overrides the
154154
:obj:`~pina.solver.physics_informed_solver.PINNInterface._residual_loss`
155155
method.
156156

pina/solver/ensemble_solver/ensemble_solver_interface.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class DeepEnsembleSolverInterface(MultiSolverInterface):
4545
processing systems, 30.
4646
DOI: `arXiv:1612.01474 <https://arxiv.org/abs/1612.01474>`_.
4747
"""
48+
4849
def __init__(
4950
self,
5051
problem,
@@ -99,19 +100,19 @@ def forward(self, x, ensemble_idx=None):
99100
return self.models[ensemble_idx].forward(x)
100101
# otherwise return the stacked output
101102
return torch.stack(
102-
[self.forward(x, idx) for idx in range(self.num_ensembles)],
103-
dim=self.ensemble_dim,
104-
)
103+
[self.forward(x, idx) for idx in range(self.num_ensembles)],
104+
dim=self.ensemble_dim,
105+
)
105106

106107
def training_step(self, batch):
107108
"""
108109
Training step for the solver, overridden for manual optimization.
109110
This method performs a forward pass, calculates the loss, and applies
110-
manual backward propagation and optimization steps for each model in
111+
manual backward propagation and optimization steps for each model in
111112
the ensemble.
112113
113-
:param list[tuple[str, dict]] batch: A batch of training data.
114-
Each element is a tuple containing a condition name and a
114+
:param list[tuple[str, dict]] batch: A batch of training data.
115+
Each element is a tuple containing a condition name and a
115116
dictionary of points.
116117
:return: The aggregated loss after the training step.
117118
:rtype: torch.Tensor

pina/solver/ensemble_solver/ensemble_supervised.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class DeepEnsembleSupervisedSolver(
5757
processing systems, 30.
5858
DOI: `arXiv:1612.01474 <https://arxiv.org/abs/1612.01474>`_.
5959
"""
60+
6061
def __init__(
6162
self,
6263
problem,
@@ -102,8 +103,8 @@ def __init__(
102103

103104
def loss_data(self, input, target):
104105
"""
105-
Compute the data loss for the EnsembleSupervisedSolver by evaluating
106-
the loss between the network's output and the true solution for each
106+
Compute the data loss for the EnsembleSupervisedSolver by evaluating
107+
the loss between the network's output and the true solution for each
107108
model. This method should not be overridden, if not intentionally.
108109
109110
:param input: The input to the neural network.

pina/solver/garom.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ def __init__(
5151
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
5252
Default is ``None``.
5353
:param Optimizer optimizer_discriminator: The optimizer for the
54-
discriminator. If ``None``, the :class:`torch.optim.Adam` optimizer is
55-
used. Default is ``None``.
54+
discriminator. If ``None``, the :class:`torch.optim.Adam`
55+
optimizer is used. Default is ``None``.
5656
:param Scheduler scheduler_generator: The learning rate scheduler for
5757
the generator.
5858
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`

pina/solver/physics_informed_solver/pinn_interface.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(self, problem, loss=None, **kwargs):
5353
check_consistency(loss, (LossInterface, _Loss), subclass=False)
5454

5555
# assign variables
56-
self._loss = loss
56+
self._loss_fn = loss
5757

5858
# inverse problem handling
5959
if isinstance(self.problem, InverseProblem):
@@ -185,7 +185,7 @@ def _residual_loss(self, samples, equation):
185185
"""
186186
Computes the physics loss for the physics-informed solver based on the
187187
provided samples and equation. This method should never be overridden
188-
by the user, if not intentionally,
188+
by the user, if not intentionally,
189189
since it is used internally to compute validation loss.
190190
191191
@@ -215,7 +215,7 @@ def loss(self):
215215
:return: The loss function used for training.
216216
:rtype: torch.nn.Module
217217
"""
218-
return self._loss
218+
return self._loss_fn
219219

220220
@property
221221
def current_condition_name(self):

pina/solver/supervised_solver/supervised.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __init__(
6868
weighting=weighting,
6969
use_lt=use_lt,
7070
)
71-
71+
7272
def loss_data(self, input, target):
7373
"""
7474
Compute the data loss for the Supervised solver by evaluating the loss
@@ -82,4 +82,4 @@ def loss_data(self, input, target):
8282
:return: The supervised loss, averaged over the number of observations.
8383
:rtype: LabelTensor | torch.Tensor | Graph | Data
8484
"""
85-
return self._loss(self.forward(input), target)
85+
return self.loss(self.forward(input), target)

pina/solver/supervised_solver/supervised_solver_interface.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""Module for the Supervised solver interface."""
22

3-
import torch
4-
53
from abc import abstractmethod
64

5+
import torch
6+
77
from torch.nn.modules.loss import _Loss
88
from ..solver import SolverInterface
99
from ...utils import check_consistency
@@ -16,8 +16,8 @@ class SupervisedSolverInterface(SolverInterface):
1616
Base class for Supervised solvers. This class implements a Supervised Solver
1717
, using a user specified ``model`` to solve a specific ``problem``.
1818
19-
The ``SupervisedSolverInterface`` class can be used to define
20-
Supervised solvers that work with one or multiple optimizers and/or models.
19+
The ``SupervisedSolverInterface`` class can be used to define
20+
Supervised solvers that work with one or multiple optimizers and/or models.
2121
By default, it is compatible with problems defined by
2222
:class:`~pina.problem.abstract_problem.AbstractProblem`,
2323
and users can choose the problem type the solver is meant to address.
@@ -45,7 +45,7 @@ def __init__(self, loss=None, **kwargs):
4545
check_consistency(loss, (LossInterface, _Loss), subclass=False)
4646

4747
# assign variables
48-
self._loss = loss
48+
self._loss_fn = loss
4949

5050
def optimization_cycle(self, batch):
5151
"""
@@ -87,4 +87,4 @@ def loss(self):
8787
:return: The loss function to be minimized.
8888
:rtype: torch.nn.Module
8989
"""
90-
return self._loss
90+
return self._loss_fn

tests/test_solver/test_ensemble_pinn.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727

2828
# define models
2929
models = [
30-
FeedForward(len(problem.input_variables), len(problem.output_variables), n_layers=1)
30+
FeedForward(
31+
len(problem.input_variables), len(problem.output_variables), n_layers=1
32+
)
3133
for _ in range(5)
3234
]
3335

@@ -84,6 +86,7 @@ def test_solver_validation(batch_size, compile):
8486
[isinstance(model, OptimizedModule) for model in solver.models]
8587
)
8688

89+
8790
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
8891
@pytest.mark.parametrize("compile", [True, False])
8992
def test_solver_test(batch_size, compile):
@@ -104,6 +107,7 @@ def test_solver_test(batch_size, compile):
104107
[isinstance(model, OptimizedModule) for model in solver.models]
105108
)
106109

110+
107111
def test_train_load_restore():
108112
dir = "tests/test_solver/tmp"
109113
solver = DeepEnsemblePINN(models=models, problem=problem)

tests/test_solver/test_ensemble_supervised_solver.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@ def forward(self, batch):
8989

9090

9191
def test_constructor():
92-
solver=DeepEnsembleSupervisedSolver(problem=TensorProblem(), models=models)
92+
solver = DeepEnsembleSupervisedSolver(
93+
problem=TensorProblem(), models=models
94+
)
9395
DeepEnsembleSupervisedSolver(problem=LabelTensorProblem(), models=models)
9496
assert DeepEnsembleSupervisedSolver.accepted_conditions_types == (
9597
InputTargetCondition
@@ -118,7 +120,9 @@ def test_solver_train(use_lt, batch_size, compile):
118120

119121
trainer.train()
120122
if trainer.compile:
121-
assert all([isinstance(model, OptimizedModule) for model in solver.models])
123+
assert all(
124+
[isinstance(model, OptimizedModule) for model in solver.models]
125+
)
122126

123127

124128
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])

0 commit comments

Comments
 (0)