Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions pina/equation/system_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,51 @@

class SystemEquation(EquationInterface):
"""
Implementation of the System of Equations. Every ``equation`` passed to a
:class:`~pina.condition.condition.Condition` object must be either a
:class:`~pina.equation.equation.Equation` or a
:class:`~pina.equation.system_equation.SystemEquation` instance.
Implementation of the System of Equations, to be passed to a
:class:`~pina.condition.condition.Condition` object.

Unlike the :class:`~pina.equation.equation.Equation` class, which represents
a single equation, the :class:`SystemEquation` class allows multiple
equations to be grouped together into a system. This is particularly useful
when dealing with multi-component outputs or coupled physical models, where
the residual must be computed collectively across several constraints.

Each equation in the system must be either:
- An instance of :class:`~pina.equation.equation.Equation`;
- A callable function.

The residuals from each equation are computed independently and then
aggregated using an optional reduction strategy (e.g., ``mean``, ``sum``).
The resulting residual is returned as a single :class:`~pina.LabelTensor`.

:Example:

>>> from pina.equation import SystemEquation, FixedValue, FixedGradient
>>> from pina import LabelTensor
>>> import torch
>>> pts = LabelTensor(torch.rand(10, 2), labels=["x", "y"])
>>> pts.requires_grad = True
>>> output_ = torch.pow(pts, 2)
>>> output_.labels = ["u", "v"]
>>> system_equation = SystemEquation(
... [
... FixedValue(value=1.0, components=["u"]),
... FixedGradient(value=0.0, components=["v"],d=["y"]),
... ],
... reduction="mean",
... )
>>> residual = system_equation.residual(pts, output_)

"""

def __init__(self, list_equation, reduction=None):
"""
Initialization of the :class:`SystemEquation` class.

:param Callable equation: A ``torch`` callable function used to compute
the residual of a mathematical equation.
:param list_equation: A list containing either callable functions or
instances of :class:`~pina.equation.equation.Equation`, used to
compute the residuals of mathematical equations.
:type list_equation: list[Callable] | list[Equation]
:param str reduction: The reduction method to aggregate the residuals of
each equation. Available options are: ``None``, ``mean``, ``sum``,
``callable``.
Expand All @@ -32,9 +65,10 @@ def __init__(self, list_equation, reduction=None):
check_consistency([list_equation], list)

# equations definition
self.equations = []
for _, equation in enumerate(list_equation):
self.equations.append(Equation(equation))
self.equations = [
equation if isinstance(equation, Equation) else Equation(equation)
Comment thread
dario-coscia marked this conversation as resolved.
for equation in list_equation
]

# possible reduction
if reduction == "mean":
Expand All @@ -45,7 +79,7 @@ def __init__(self, list_equation, reduction=None):
self.reduction = reduction
else:
raise NotImplementedError(
"Only mean and sum reductions implemented."
"Only mean and sum reductions are currenly supported."
)

def residual(self, input_, output_, params_=None):
Expand Down
78 changes: 61 additions & 17 deletions tests/test_equations/test_system_equation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pina.equation import SystemEquation
from pina.equation import SystemEquation, FixedValue, FixedGradient
from pina.operator import grad, laplacian
from pina import LabelTensor
import torch
Expand All @@ -24,34 +24,78 @@ def foo():
pass


def test_constructor():
SystemEquation([eq1, eq2])
SystemEquation([eq1, eq2], reduction="sum")
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_constructor(reduction):

# Constructor with callable functions
SystemEquation([eq1, eq2], reduction=reduction)

# Constructor with Equation instances
SystemEquation(
[
FixedValue(value=0.0, components=["u1"]),
FixedGradient(value=0.0, components=["u2"]),
],
reduction=reduction,
)

# Constructor with mixed types
SystemEquation(
[
FixedValue(value=0.0, components=["u1"]),
eq1,
],
reduction=reduction,
)

# Non-standard reduction not implemented
with pytest.raises(NotImplementedError):
SystemEquation([eq1, eq2], reduction="foo")

# Invalid input type
with pytest.raises(ValueError):
SystemEquation(foo)


def test_residual():
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_residual(reduction):

# Generate random points and output
pts = LabelTensor(torch.rand(10, 2), labels=["x", "y"])
pts.requires_grad = True
u = torch.pow(pts, 2)
u.labels = ["u1", "u2"]

eq_1 = SystemEquation([eq1, eq2], reduction="mean")
res = eq_1.residual(pts, u)
assert res.shape == torch.Size([10])
# System with callable functions
system_eq = SystemEquation([eq1, eq2], reduction=reduction)
res = system_eq.residual(pts, u)

# Checks on the shape of the residual
shape = torch.Size([10, 3]) if reduction is None else torch.Size([10])
assert res.shape == shape

eq_1 = SystemEquation([eq1, eq2], reduction="sum")
res = eq_1.residual(pts, u)
assert res.shape == torch.Size([10])
# System with Equation instances
system_eq = SystemEquation(
[
FixedValue(value=0.0, components=["u1"]),
FixedGradient(value=0.0, components=["u2"]),
],
reduction=reduction,
)

eq_1 = SystemEquation([eq1, eq2], reduction=None)
res = eq_1.residual(pts, u)
assert res.shape == torch.Size([10, 3])
# Checks on the shape of the residual
shape = torch.Size([10, 3]) if reduction is None else torch.Size([10])
assert res.shape == shape

# System with mixed types
system_eq = SystemEquation(
[
FixedValue(value=0.0, components=["u1"]),
eq1,
],
reduction=reduction,
)

eq_1 = SystemEquation([eq1, eq2])
res = eq_1.residual(pts, u)
assert res.shape == torch.Size([10, 3])
# Checks on the shape of the residual
shape = torch.Size([10, 3]) if reduction is None else torch.Size([10])
assert res.shape == shape