11"""Module for the Supervised solver interface."""
22
3- import torch
4-
53from abc import abstractmethod
64
5+ import torch
6+
77from torch .nn .modules .loss import _Loss
88from ..solver import SolverInterface
99from ...utils import check_consistency
@@ -16,8 +16,8 @@ class SupervisedSolverInterface(SolverInterface):
1616 Base class for Supervised solvers. This class implements a Supervised Solver
1717 , using a user specified ``model`` to solve a specific ``problem``.
1818
19- The ``SupervisedSolverInterface`` class can be used to define
20- Supervised solvers that work with one or multiple optimizers and/or models.
19+ The ``SupervisedSolverInterface`` class can be used to define
20+ Supervised solvers that work with one or multiple optimizers and/or models.
2121 By default, it is compatible with problems defined by
2222 :class:`~pina.problem.abstract_problem.AbstractProblem`,
2323 and users can choose the problem type the solver is meant to address.
@@ -45,7 +45,7 @@ def __init__(self, loss=None, **kwargs):
4545 check_consistency (loss , (LossInterface , _Loss ), subclass = False )
4646
4747 # assign variables
48- self ._loss = loss
48+ self ._loss_fn = loss
4949
5050 def optimization_cycle (self , batch ):
5151 """
@@ -87,4 +87,4 @@ def loss(self):
8787 :return: The loss function to be minimized.
8888 :rtype: torch.nn.Module
8989 """
90- return self ._loss
90+ return self ._loss_fn
0 commit comments