Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 8 additions & 63 deletions pina/solver/physics_informed_solver/pinn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,16 @@

from abc import ABCMeta, abstractmethod
import torch
from torch.nn.modules.loss import _Loss

from ..solver import SolverInterface
from ...utils import check_consistency
from ...loss.loss_interface import LossInterface
from ...problem import InverseProblem
from ..supervised_solver import SupervisedSolverInterface
from ...condition import (
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition,
)


class PINNInterface(SolverInterface, metaclass=ABCMeta):
class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
"""
Base class for Physics-Informed Neural Network (PINN) solvers, implementing
the :class:`~pina.solver.solver.SolverInterface` class.
Expand All @@ -32,7 +28,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
DomainEquationCondition,
)

def __init__(self, problem, loss=None, **kwargs):
def __init__(self, **kwargs):
"""
Initialization of the :class:`PINNInterface` class.

Expand All @@ -41,28 +37,13 @@ def __init__(self, problem, loss=None, **kwargs):
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is `None`.
:param kwargs: Additional keyword arguments to be passed to the
:class:`~pina.solver.solver.SolverInterface` class.
:class:`~pina.solver.supervised_solver.SupervisedSolverInterface`
class.
"""
kwargs["use_lt"] = True
super().__init__(**kwargs)

if loss is None:
loss = torch.nn.MSELoss()

super().__init__(problem=problem, use_lt=True, **kwargs)

# check consistency
check_consistency(loss, (LossInterface, _Loss), subclass=False)

# assign variables
self._loss_fn = loss

# inverse problem handling
if isinstance(self.problem, InverseProblem):
self._params = self.problem.unknown_parameters
self._clamp_params = self._clamp_inverse_problem_params
else:
self._params = None
self._clamp_params = lambda: None

# current condition name
self.__metric = None

def optimization_cycle(self, batch, loss_residuals=None):
Expand Down Expand Up @@ -103,8 +84,6 @@ def optimization_cycle(self, batch, loss_residuals=None):
)
# append loss
condition_loss[condition_name] = loss
# clamp unknown parameters in InverseProblem (if needed)
self._clamp_params()
return condition_loss

@torch.set_grad_enabled(True)
Expand Down Expand Up @@ -135,20 +114,6 @@ def test_step(self, batch):
"""
return super().test_step(batch, loss_residuals=self._residual_loss)

@abstractmethod
def loss_data(self, input, target):
"""
Compute the data loss for the PINN solver by evaluating the loss
between the network's output and the true solution. This method should
be overridden by the derived class.

:param LabelTensor input: The input to the neural network.
:param LabelTensor target: The target to compare with the
network's output.
:return: The supervised loss, averaged over the number of observations.
:rtype: LabelTensor
"""

@abstractmethod
def loss_phys(self, samples, equation):
"""
Expand Down Expand Up @@ -196,26 +161,6 @@ def _residual_loss(self, samples, equation):
residuals = self.compute_residual(samples, equation)
return self._loss_fn(residuals, torch.zeros_like(residuals))

def _clamp_inverse_problem_params(self):
"""
Clamps the parameters of the inverse problem solver to specified ranges.
"""
for v in self._params:
self._params[v].data.clamp_(
self.problem.unknown_parameter_domain.range_[v][0],
self.problem.unknown_parameter_domain.range_[v][1],
)

@property
def loss(self):
"""
The loss used for training.

:return: The loss function used for training.
:rtype: torch.nn.Module
"""
return self._loss_fn

@property
def current_condition_name(self):
"""
Expand Down
34 changes: 33 additions & 1 deletion pina/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch

from torch._dynamo import OptimizedModule
from ..problem import AbstractProblem
from ..problem import AbstractProblem, InverseProblem
from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler
from ..loss import WeightingInterface
from ..loss.scalar_weighting import _NoWeighting
Expand Down Expand Up @@ -64,6 +64,14 @@ def __init__(self, problem, weighting, use_lt):
self._pina_optimizers = None
self._pina_schedulers = None

# inverse problem handling
if isinstance(self.problem, InverseProblem):
self._params = self.problem.unknown_parameters
self._clamp_params = self._clamp_inverse_problem_params
else:
self._params = None
self._clamp_params = lambda: None

@abstractmethod
def forward(self, *args, **kwargs):
"""
Expand Down Expand Up @@ -231,14 +239,29 @@ def _optimization_cycle(self, batch, **kwargs):
containing the condition name and the associated scalar loss.
:rtype: dict
"""
# compute losses
losses = self.optimization_cycle(batch)
# clamp unknown parameters in InverseProblem (if needed)
self._clamp_params()
# store log
for name, value in losses.items():
self.store_log(
f"{name}_loss", value.item(), self.get_batch_size(batch)
)
# aggregate
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
return loss

def _clamp_inverse_problem_params(self):
"""
Clamps the parameters of the inverse problem solver to specified ranges.
"""
for v in self._params:
self._params[v].data.clamp_(
self.problem.unknown_parameter_domain.range_[v][0],
self.problem.unknown_parameter_domain.range_[v][1],
)

@staticmethod
def _compile_modules(model):
"""
Expand Down Expand Up @@ -405,6 +428,15 @@ def configure_optimizers(self):
:rtype: tuple[list[Optimizer], list[Scheduler]]
"""
self.optimizer.hook(self.model.parameters())
if isinstance(self.problem, InverseProblem):
self.optimizer.instance.add_param_group(
{
"params": [
self._params[var]
for var in self.problem.unknown_variables
]
}
)
self.scheduler.hook(self.optimizer)
return ([self.optimizer.instance], [self.scheduler.instance])

Expand Down
Loading