Skip to content

Commit 429a505

Browse files
add interface + base class structure
1 parent edfd0a0 commit 429a505

11 files changed

Lines changed: 271 additions & 221 deletions

File tree

docs/source/_rst/_code.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,10 @@ Equations and Differential Operators
195195
.. toctree::
196196
:titlesonly:
197197

198-
EquationInterface <equation/equation_interface.rst>
198+
Equation Interface <equation/equation_interface.rst>
199+
Base Equation <equation/base_equation.rst>
199200
Equation <equation/equation.rst>
200-
SystemEquation <equation/system_equation.rst>
201+
System Equation <equation/system_equation.rst>
201202
Equation Factory <equation/equation_factory.rst>
202203
Differential Operators <operator.rst>
203204

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Base Equation
2+
====================
3+
4+
.. currentmodule:: pina.equation.base_equation
5+
.. autoclass:: pina._src.equation.base_equation.BaseEquation
6+
:members:
7+
:show-inheritance:

docs/source/_rst/equation/equation_factory.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,9 @@ Equation Factory
3939
:show-inheritance:
4040

4141
.. autoclass:: pina._src.equation.equation_factory.Poisson
42+
:members:
43+
:show-inheritance:
44+
45+
.. autoclass:: pina._src.equation.equation_factory.AcousticWave
4246
:members:
4347
:show-inheritance:
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Module for the Base Equation."""
2+
3+
from abc import ABCMeta, abstractmethod
4+
import torch
5+
6+
7+
class BaseEquation(metaclass=ABCMeta):
8+
"""
9+
Base class for all equations, implementing common functionality.
10+
11+
Equations are fundamental components in PINA, representing mathematical
12+
constraints that must be satisfied by the model outputs. They can be passed
13+
to :class:`~pina.condition.condition.Condition` objects to define the
14+
conditions under which the model is trained.
15+
16+
All specific equation types should inherit from this class and implement its
17+
abstract methods.
18+
19+
This class is not meant to be instantiated directly.
20+
"""
21+
22+
@abstractmethod
23+
def residual(self, input_, output_, params_):
24+
"""
25+
Evaluate the equation residual at the given inputs.
26+
27+
:param LabelTensor input_: The input points where the residual is
28+
computed.
29+
:param LabelTensor output_: The output tensor, potentially produced by a
30+
:class:`torch.nn.Module` instance.
31+
:param dict params_: An optional dictionary of unknown parameters, used
32+
in :class:`~pina.problem.inverse_problem.InverseProblem` settings.
33+
If the equation is not related to an inverse problem, this should be
34+
set to ``None``. Default is ``None``.
35+
:return: The residual values of the equation.
36+
:rtype: LabelTensor
37+
"""
38+
39+
def to(self, device):
40+
"""
41+
Move all tensor attributes to the specified device.
42+
43+
:param torch.device device: The target device to move the tensors to.
44+
:return: The instance moved to the specified device.
45+
:rtype: BaseEquation
46+
"""
47+
# Iterate over all attributes of the Equation
48+
for key, val in self.__dict__.items():
49+
50+
# Move tensors in dictionaries to the specified device
51+
if isinstance(val, dict):
52+
self.__dict__[key] = {
53+
k: v.to(device) if torch.is_tensor(v) else v
54+
for k, v in val.items()
55+
}
56+
57+
# Move tensors in lists to the specified device
58+
elif isinstance(val, list):
59+
self.__dict__[key] = [
60+
v.to(device) if torch.is_tensor(v) else v for v in val
61+
]
62+
63+
# Move tensor attributes to the specified device
64+
elif torch.is_tensor(val):
65+
self.__dict__[key] = val.to(device)
66+
67+
return self

pina/_src/equation/equation.py

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,65 @@
11
"""Module for the Equation."""
22

33
import inspect
4-
from pina._src.equation.equation_interface import EquationInterface
4+
from pina._src.equation.base_equation import BaseEquation
55

66

7-
class Equation(EquationInterface):
7+
class Equation(BaseEquation):
88
"""
9-
Implementation of the Equation class. Every ``equation`` passed to a
10-
:class:`~pina.condition.condition.Condition` object must be either an
11-
instance of :class:`Equation` or
12-
:class:`~pina.equation.system_equation.SystemEquation`.
9+
Implementation of the Equation class, representing a single mathematical
10+
equation to be satisfied by the model outputs.
11+
12+
It can be passed to a :class:`~pina.condition.condition.Condition` object to
13+
define the conditions under which the model is trained.
1314
"""
1415

1516
def __init__(self, equation):
1617
"""
1718
Initialization of the :class:`Equation` class.
1819
19-
:param Callable equation: A ``torch`` callable function used to compute
20-
the residual of a mathematical equation.
20+
:param Callable equation: A callable function used to compute the
21+
residual of a mathematical equation.
2122
:raises ValueError: If the equation is not a callable function.
2223
"""
24+
# Check consistency
2325
if not callable(equation):
24-
raise ValueError(
25-
"equation must be a callable function."
26-
"Expected a callable function, got "
27-
f"{equation}"
28-
)
29-
# compute the signature
26+
raise ValueError(f"Expected a callable function, got {equation}")
27+
28+
# Compute the signature length
3029
sig = inspect.signature(equation)
3130
self.__len_sig = len(sig.parameters)
3231
self.__equation = equation
3332

3433
def residual(self, input_, output_, params_=None):
3534
"""
36-
Compute the residual of the equation.
35+
Evaluate the equation residual at the given inputs.
3736
38-
:param LabelTensor input_: Input points where the equation is evaluated.
39-
:param LabelTensor output_: Output tensor, eventually produced by a
37+
:param LabelTensor input_: The input points where the residual is
38+
computed.
39+
:param LabelTensor output_: The output tensor, potentially produced by a
4040
:class:`torch.nn.Module` instance.
41-
:param dict params_: Dictionary of unknown parameters, associated with a
42-
:class:`~pina.problem.inverse_problem.InverseProblem` instance.
43-
If the equation is not related to a
44-
:class:`~pina.problem.inverse_problem.InverseProblem` instance, the
45-
parameters must be initialized to ``None``. Default is ``None``.
46-
:return: The computed residual of the equation.
41+
:param dict params_: An optional dictionary of unknown parameters, used
42+
in :class:`~pina.problem.inverse_problem.InverseProblem` settings.
43+
If the equation is not related to an inverse problem, this should be
44+
set to ``None``. Default is ``None``.
45+
:raises RuntimeError: If the underlying equation signature is neither of
46+
length 2 for direct problems nor of length 3 for inverse problems.
47+
:return: The residual values of the equation.
4748
:rtype: LabelTensor
48-
:raises RuntimeError: If the underlying equation signature length is not
49-
2 (direct problem) or 3 (inverse problem).
5049
"""
5150
# Move the equation to the input_ device
5251
self.to(input_.device)
5352

54-
# Call the underlying equation based on its signature length
53+
# Evaluate the equation for direct problems
5554
if self.__len_sig == 2:
5655
return self.__equation(input_, output_)
56+
57+
# Evaluate the equation for inverse problems
5758
if self.__len_sig == 3:
5859
return self.__equation(input_, output_, params_)
60+
61+
# Raise an error if the signature length is unexpected
5962
raise RuntimeError(
6063
f"Unexpected number of arguments in equation: {self.__len_sig}. "
61-
"Expected either 2 (direct problem) or 3 (inverse problem)."
64+
"Expected either 2 for direct problems, or 3 for inverse problems."
6265
)

pina/_src/equation/equation_factory.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ def equation(_, output_):
2828
"""
2929
Definition of the equation to enforce a fixed value.
3030
31-
:param LabelTensor input_: Input points where the equation is
32-
evaluated.
33-
:param LabelTensor output_: Output tensor, eventually produced by a
34-
:class:`torch.nn.Module` instance.
35-
:return: The computed residual of the equation.
31+
:param LabelTensor input_: The input points where the residual is
32+
computed.
33+
:param LabelTensor output_: The output tensor, potentially produced
34+
by a :class:`torch.nn.Module` instance.
35+
:return: The residual values of the equation.
3636
:rtype: LabelTensor
3737
"""
3838
if components is None:
@@ -66,11 +66,11 @@ def equation(input_, output_):
6666
"""
6767
Definition of the equation to enforce a fixed gradient.
6868
69-
:param LabelTensor input_: Input points where the equation is
70-
evaluated.
71-
:param LabelTensor output_: Output tensor, eventually produced by a
72-
:class:`torch.nn.Module` instance.
73-
:return: The computed residual of the equation.
69+
:param LabelTensor input_: The input points where the residual is
70+
computed.
71+
:param LabelTensor output_: The output tensor, potentially produced
72+
by a :class:`torch.nn.Module` instance.
73+
:return: The residual values of the equation.
7474
:rtype: LabelTensor
7575
"""
7676
return grad(output_, input_, components=components, d=d) - value
@@ -101,11 +101,11 @@ def equation(input_, output_):
101101
"""
102102
Definition of the equation to enforce a fixed flux.
103103
104-
:param LabelTensor input_: Input points where the equation is
105-
evaluated.
106-
:param LabelTensor output_: Output tensor, eventually produced by a
107-
:class:`torch.nn.Module` instance.
108-
:return: The computed residual of the equation.
104+
:param LabelTensor input_: The input points where the residual is
105+
computed.
106+
:param LabelTensor output_: The output tensor, potentially produced
107+
by a :class:`torch.nn.Module` instance.
108+
:return: The residual values of the equation.
109109
:rtype: LabelTensor
110110
"""
111111
return div(output_, input_, components=components, d=d) - value
@@ -137,11 +137,11 @@ def equation(input_, output_):
137137
"""
138138
Definition of the equation to enforce a fixed laplacian.
139139
140-
:param LabelTensor input_: Input points where the equation is
141-
evaluated.
142-
:param LabelTensor output_: Output tensor, eventually produced by a
143-
:class:`torch.nn.Module` instance.
144-
:return: The computed residual of the equation.
140+
:param LabelTensor input_: The input points where the residual is
141+
computed.
142+
:param LabelTensor output_: The output tensor, potentially produced
143+
by a :class:`torch.nn.Module` instance.
144+
:return: The residual values of the equation.
145145
:rtype: LabelTensor
146146
"""
147147
return (
@@ -158,7 +158,7 @@ class Laplace(FixedLaplacian): # pylint: disable=R0903
158158
159159
.. math::
160160
161-
\delta u = 0
161+
\Delta u = 0
162162
163163
"""
164164

Lines changed: 12 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,31 @@
11
"""Module for the Equation Interface."""
22

33
from abc import ABCMeta, abstractmethod
4-
import torch
54

65

76
class EquationInterface(metaclass=ABCMeta):
87
"""
9-
Abstract base class for equations.
10-
11-
Equations in PINA simplify the training process. When defining a problem,
12-
each equation passed to a :class:`~pina.condition.condition.Condition`
13-
object must be either an :class:`~pina.equation.equation.Equation` or a
14-
:class:`~pina.equation.system_equation.SystemEquation` instance.
15-
16-
An :class:`~pina.equation.equation.Equation` is a wrapper for a callable
17-
function, while :class:`~pina.equation.system_equation.SystemEquation`
18-
wraps a list of callable functions. To streamline code writing, PINA
19-
provides a diverse set of pre-implemented equations, such as
20-
:class:`~pina.equation.equation_factory.FixedValue`,
21-
:class:`~pina.equation.equation_factory.FixedGradient`, and many others.
8+
Abstract interface for all equations.
229
"""
2310

2411
@abstractmethod
25-
def residual(self, input_, output_, params_):
12+
def residual(self, input_, output_, params_=None):
2613
"""
27-
Abstract method to compute the residual of an equation.
14+
Evaluate the equation residual at the given inputs.
2815
29-
:param LabelTensor input_: Input points where the equation is evaluated.
30-
:param LabelTensor output_: Output tensor, eventually produced by a
16+
:param LabelTensor input_: The input points where the residual is
17+
computed.
18+
:param LabelTensor output_: The output tensor, potentially produced by a
3119
:class:`torch.nn.Module` instance.
32-
:param dict params_: Dictionary of unknown parameters, associated with a
33-
:class:`~pina.problem.inverse_problem.InverseProblem` instance.
34-
:return: The computed residual of the equation.
20+
:param dict params_: An optional dictionary of unknown parameters, used
21+
in :class:`~pina.problem.inverse_problem.InverseProblem` settings.
22+
If the equation is not related to an inverse problem, this should be
23+
set to ``None``. Default is ``None``.
24+
:return: The residual values of the equation.
3525
:rtype: LabelTensor
3626
"""
3727

28+
@abstractmethod
3829
def to(self, device):
3930
"""
4031
Move all tensor attributes to the specified device.
@@ -43,24 +34,3 @@ def to(self, device):
4334
:return: The instance moved to the specified device.
4435
:rtype: EquationInterface
4536
"""
46-
# Iterate over all attributes of the Equation
47-
for key, val in self.__dict__.items():
48-
49-
# Move tensors in dictionaries to the specified device
50-
if isinstance(val, dict):
51-
self.__dict__[key] = {
52-
k: v.to(device) if torch.is_tensor(v) else v
53-
for k, v in val.items()
54-
}
55-
56-
# Move tensors in lists to the specified device
57-
elif isinstance(val, list):
58-
self.__dict__[key] = [
59-
v.to(device) if torch.is_tensor(v) else v for v in val
60-
]
61-
62-
# Move tensor attributes to the specified device
63-
elif torch.is_tensor(val):
64-
self.__dict__[key] = val.to(device)
65-
66-
return self

0 commit comments

Comments
 (0)