Skip to content

Commit 2f6f9f0

Browse files
add interface + base class structure for losses
1 parent 1cf29f8 commit 2f6f9f0

11 files changed

Lines changed: 252 additions & 190 deletions

File tree

docs/source/_rst/_code.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,9 @@ Losses and Weightings
310310
:titlesonly:
311311

312312
LossInterface <loss/loss_interface.rst>
313-
LpLoss <loss/lploss.rst>
314-
PowerLoss <loss/powerloss.rst>
313+
BaseLoss <loss/base_loss.rst>
314+
LpLoss <loss/lp_loss.rst>
315+
PowerLoss <loss/power_loss.rst>
315316
WeightingInterface <loss/weighting_interface.rst>
316317
ScalarWeighting <loss/scalar_weighting.rst>
317318
NeuralTangentKernelWeighting <loss/ntk_weighting.rst>
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
Base Loss
2+
===============
3+
.. currentmodule:: pina.loss.base_loss
4+
5+
.. automodule:: pina._src.loss.base_loss
6+
7+
.. autoclass:: pina._src.loss.base_loss.BaseLoss
8+
:members:
9+
:show-inheritance:

docs/source/_rst/loss/loss_interface.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
LossInterface
1+
Loss Interface
22
===============
33
.. currentmodule:: pina.loss.loss_interface
44

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
LpLoss
1+
Lp Loss
22
===============
33
.. currentmodule:: pina.loss.lp_loss
44

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
PowerLoss
1+
Power Loss
22
====================
33
.. currentmodule:: pina.loss.power_loss
44

pina/_src/loss/base_loss.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""Module for the BaseLoss class."""
2+
3+
import torch
4+
from pina._src.loss.loss_interface import LossInterface
5+
6+
7+
class BaseLoss(LossInterface):
8+
"""
9+
Base class for all losses, implementing common functionality.
10+
11+
All specific loss types should inherit from this class and implement its
12+
abstract methods.
13+
14+
This class is not meant to be instantiated directly.
15+
"""
16+
17+
# Define available reduction methods
18+
_REDUCTION_METHOD = {
19+
"sum": lambda x: torch.sum(x, keepdim=True, dim=-1),
20+
"mean": lambda x: torch.mean(x, keepdim=True, dim=-1),
21+
"none": lambda x: x,
22+
}
23+
24+
def __init__(self, reduction="mean"):
25+
"""
26+
Initialization of the :class:`BaseLoss` class.
27+
28+
:param str reduction: The reduction method to aggregate pointwise loss
29+
values. Available options include: ``"none"`` for unreduced loss,
30+
``"mean"`` for the average of the loss values, and ``"sum"`` for
31+
their total sum. Default is ``"mean"``.
32+
:raises ValueError: If the specified reduction method is not among the
33+
available options.
34+
"""
35+
# Check that the reduction method is available
36+
if reduction not in self._REDUCTION_METHOD:
37+
raise ValueError(
38+
f"Invalid reduction method. Available options: "
39+
f"{list(self._REDUCTION_METHOD.keys())}. Got {reduction}."
40+
)
41+
42+
# Initialization
43+
super().__init__(reduction=reduction, size_average=None, reduce=None)
44+
45+
def _reduction(self, loss):
46+
"""
47+
Apply the configured reduction operation to pointwise loss values.
48+
49+
:param torch.Tensor loss: The tensor of pointwise losses.
50+
:return: The reduced loss tensor.
51+
:rtype: torch.Tensor
52+
"""
53+
return self._REDUCTION_METHOD[self.reduction](loss)

pina/_src/loss/loss_interface.py

Lines changed: 9 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,51 +2,30 @@
22

33
from abc import ABCMeta, abstractmethod
44
from torch.nn.modules.loss import _Loss
5-
import torch
65

76

87
class LossInterface(_Loss, metaclass=ABCMeta):
98
"""
10-
Abstract base class for all losses. All classes defining a loss function
11-
should inherit from this interface.
9+
Abstract interface for all losses.
1210
"""
1311

14-
def __init__(self, reduction="mean"):
15-
"""
16-
Initialization of the :class:`LossInterface` class.
17-
18-
:param str reduction: The reduction method for the loss.
19-
Available options: ``none``, ``mean``, ``sum``.
20-
If ``none``, no reduction is applied. If ``mean``, the sum of the
21-
loss values is divided by the number of values. If ``sum``, the loss
22-
values are summed. Default is ``mean``.
23-
"""
24-
super().__init__(reduction=reduction, size_average=None, reduce=None)
25-
2612
@abstractmethod
2713
def forward(self, input, target):
2814
"""
2915
Forward method of the loss function.
3016
31-
:param torch.Tensor input: Input tensor from real data.
32-
:param torch.Tensor target: Model tensor output.
17+
:param torch.Tensor input: The input tensor.
18+
:param torch.Tensor target: The target tensor.
19+
:return: The computed loss.
20+
:rtype: torch.Tensor
3321
"""
3422

23+
@abstractmethod
3524
def _reduction(self, loss):
3625
"""
37-
Apply the reduction to the loss.
26+
Apply the configured reduction operation to pointwise loss values.
3827
39-
:param torch.Tensor loss: The tensor containing the pointwise losses.
40-
:raises ValueError: If the reduction method is not valid.
41-
:return: Reduced loss.
28+
:param torch.Tensor loss: The tensor of pointwise losses.
29+
:return: The reduced loss tensor.
4230
:rtype: torch.Tensor
4331
"""
44-
if self.reduction == "none":
45-
ret = loss
46-
elif self.reduction == "mean":
47-
ret = torch.mean(loss, keepdim=True, dim=-1)
48-
elif self.reduction == "sum":
49-
ret = torch.sum(loss, keepdim=True, dim=-1)
50-
else:
51-
raise ValueError(self.reduction + " is not valid")
52-
return ret

pina/_src/loss/lp_loss.py

Lines changed: 49 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,91 @@
1-
"""Module for the LpLoss class."""
1+
"""Module for the Lp Loss class."""
22

33
import torch
4-
from pina._src.loss.loss_interface import LossInterface
4+
from pina._src.loss.base_loss import BaseLoss
55
from pina._src.core.utils import check_consistency
66

77

8-
class LpLoss(LossInterface):
8+
class LpLoss(BaseLoss):
99
r"""
10-
Implementation of the Lp Loss. It defines a criterion to measures the
11-
pointwise Lp error between values in the input :math:`x` and values in the
12-
target :math:`y`.
10+
Implementation of the :math:`L^p` loss measuring the pointwise :math:`L^p`
11+
distance between an input tensor :math:`x` and a target tensor :math:`y`.
1312
14-
If ``reduction`` is set to ``none``, the loss can be written as:
13+
Given a batch of size :math:`N` and feature dimension :math:`D`, the
14+
unreduced loss (``reduction="none"``) is defined as:
1515
1616
.. math::
17-
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
18-
l_n = \left[\sum_{i=1}^{D} \left| x_n^i - y_n^i \right|^p \right],
19-
20-
If ``relative`` is set to ``True``, the relative Lp error is computed:
17+
L = \{l_1, \dots, l_N\}^\top, \quad
18+
l_n = \left( \sum_{i=1}^{D} \left| x_n^i - y_n^i \right|^p \right)^{1/p}
19+
20+
If ``relative=True``, each term is normalized by the :math:`L^p` norm of the
21+
input tensor :math:`x`:
2122
2223
.. math::
23-
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
24-
l_n = \frac{ [\sum_{i=1}^{D} | x_n^i - y_n^i|^p] }
25-
{[\sum_{i=1}^{D}|y_n^i|^p]},
24+
l_n = \frac{\left( \sum_{i=1}^{D} |x_n^i - y_n^i|^p \right)^{1/p}}
25+
{\left( \sum_{i=1}^{D} |x_n^i|^p \right)^{1/p}}
2626
27-
where :math:`N` is the batch size.
28-
29-
If ``reduction`` is not ``none``, then:
27+
If ``reduction`` is set to ``"mean"`` or ``"sum"``, the vector :math:`L`
28+
is aggregated accordingly:
3029
3130
.. math::
3231
\ell(x, y) =
3332
\begin{cases}
34-
\operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
35-
\operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
33+
\operatorname{mean}(L), & \text{if reduction} = \text{``mean''} \\
34+
\operatorname{sum}(L), & \text{if reduction} = \text{``sum''}
3635
\end{cases}
36+
37+
where :math:`N` is the batch size.
3738
"""
3839

3940
def __init__(self, p=2, reduction="mean", relative=False):
4041
"""
4142
Initialization of the :class:`LpLoss` class.
4243
43-
:param int p: Degree of the Lp norm. It specifies the norm to be
44-
computed. Default is ``2`` (euclidean norm).
45-
:param str reduction: The reduction method for the loss.
46-
Available options: ``none``, ``mean``, ``sum``.
47-
If ``none``, no reduction is applied. If ``mean``, the sum of the
48-
loss values is divided by the number of values. If ``sum``, the loss
49-
values are summed. Default is ``mean``.
50-
:param bool relative: If ``True``, the relative error is computed.
44+
:param p: The order of the norm. It can be a numeric value for standard
45+
p-norms or one of the following strings: ``"inf"`` for maximum
46+
absolute value, ``"-inf"`` for minimum absolute value. The values
47+
``"inf"`` and ``"-inf"`` are internally converted to their floating
48+
counterparts. Default is ``2``.
49+
:type p: int | float | str
50+
:param str reduction: The reduction method to aggregate pointwise loss
51+
values. Available options include: ``"none"`` for unreduced loss,
52+
``"mean"`` for the average of the loss values, and ``"sum"`` for
53+
their total sum. Default is ``"mean"``.
54+
:param bool relative: If ``True``, computes the relative error.
5155
Default is ``False``.
56+
:raises ValueError: If ``relative`` is not a boolean.
57+
:raises ValueError: If ``p`` is not a valid norm order.
5258
"""
5359
super().__init__(reduction=reduction)
5460

55-
# check consistency
56-
check_consistency(p, (str, int, float))
61+
# Convert to float if inf or -inf
62+
if p == "inf":
63+
p = float("inf")
64+
elif p == "-inf":
65+
p = float("-inf")
66+
67+
# Check consistency
5768
check_consistency(relative, bool)
69+
check_consistency(p, (int, float))
5870

71+
# Initialize attributes
5972
self.p = p
6073
self.relative = relative
6174

6275
def forward(self, input, target):
6376
"""
6477
Forward method of the loss function.
6578
66-
:param torch.Tensor input: Input tensor from real data.
67-
:param torch.Tensor target: Model tensor output.
68-
:return: Loss evaluation.
79+
:param torch.Tensor input: The input tensor.
80+
:param torch.Tensor target: The target tensor.
81+
:return: The computed loss.
6982
:rtype: torch.Tensor
7083
"""
84+
# Compute the standard loss
7185
loss = torch.linalg.norm((input - target), ord=self.p, dim=-1)
86+
87+
# Compute the input norm for relative error
7288
if self.relative:
7389
loss = loss / torch.linalg.norm(input, ord=self.p, dim=-1)
90+
7491
return self._reduction(loss)

pina/_src/loss/power_loss.py

Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,81 @@
1-
"""Module for the PowerLoss class."""
1+
"""Module for the Power Loss class."""
22

33
import torch
4+
from pina._src.loss.base_loss import BaseLoss
5+
from pina._src.core.utils import check_consistency, check_positive_integer
46

5-
from pina._src.loss.loss_interface import LossInterface
6-
from pina._src.core.utils import check_consistency
77

8-
9-
class PowerLoss(LossInterface):
8+
class PowerLoss(BaseLoss):
109
r"""
11-
Implementation of the Power Loss. It defines a criterion to measures the
12-
pointwise error between values in the input :math:`x` and values in the
13-
target :math:`y`.
10+
Implementation of the Power loss, measuring the pointwise averaged
11+
:math:`p`-power error between an input tensor :math:`x` and a target tensor
12+
:math:`y`.
1413
15-
If ``reduction`` is set to ``none``, the loss can be written as:
14+
Given a batch of size :math:`N` and feature dimension :math:`D`, the
15+
unreduced loss (``reduction="none"``) is defined as:
1616
1717
.. math::
18-
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
19-
l_n = \frac{1}{D}\left[\sum_{i=1}^{D}
20-
\left| x_n^i - y_n^i \right|^p\right],
21-
22-
If ``relative`` is set to ``True``, the relative error is computed:
18+
L = \{l_1, \dots, l_N\}^\top, \quad
19+
l_n = \frac{1}{D} \sum_{i=1}^{D} \left| x_n^i - y_n^i \right|^p
20+
21+
If ``relative=True``, each term is normalized by the averaged
22+
:math:`p`-power magnitude of the input tensor :math:`x`:
2323
2424
.. math::
25-
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
26-
l_n = \frac{ \sum_{i=1}^{D} | x_n^i - y_n^i|^p }
27-
{\sum_{i=1}^{D}|y_n^i|^p},
25+
l_n = \frac{\frac{1}{D} \sum_{i=1}^{D} |x_n^i - y_n^i|^p}
26+
{\frac{1}{D} \sum_{i=1}^{D} |x_n^i|^p}
2827
29-
where :math:`N` is the batch size.
30-
31-
If ``reduction`` is not ``none``, then:
28+
If ``reduction`` is set to ``"mean"`` or ``"sum"``, the vector :math:`L`
29+
is aggregated accordingly:
3230
3331
.. math::
3432
\ell(x, y) =
3533
\begin{cases}
36-
\operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
37-
\operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
34+
\operatorname{mean}(L), & \text{if reduction} = \text{``mean''} \\
35+
\operatorname{sum}(L), & \text{if reduction} = \text{``sum''}
3836
\end{cases}
37+
38+
where :math:`N` is the batch size.
3939
"""
4040

4141
def __init__(self, p=2, reduction="mean", relative=False):
4242
"""
4343
Initialization of the :class:`PowerLoss` class.
4444
45-
:param int p: Degree of the Lp norm. It specifies the norm to be
46-
computed. Default is ``2`` (euclidean norm).
47-
:param str reduction: The reduction method for the loss.
48-
Available options: ``none``, ``mean``, ``sum``.
49-
If ``none``, no reduction is applied. If ``mean``, the sum of the
50-
loss values is divided by the number of values. If ``sum``, the loss
51-
values are summed. Default is ``mean``.
52-
:param bool relative: If ``True``, the relative error is computed.
45+
:param int p: The order of the p-norm. Default is ``2``.
46+
:param str reduction: The reduction method to aggregate pointwise loss
47+
values. Available options include: ``"none"`` for unreduced loss,
48+
``"mean"`` for the average of the loss values, and ``"sum"`` for
49+
their total sum. Default is ``"mean"``.
50+
:param bool relative: If ``True``, computes the relative error.
5351
Default is ``False``.
52+
:raises ValueError: If ``relative`` is not a boolean.
53+
:raises ValueError: If ``p`` is not a positive integer.
5454
"""
5555
super().__init__(reduction=reduction)
5656

57-
# check consistency
58-
check_consistency(p, (str, int, float))
57+
# Check consistency
5958
check_consistency(relative, bool)
59+
check_positive_integer(p, strict=True)
6060

61+
# Initialize attributes
6162
self.p = p
6263
self.relative = relative
6364

6465
def forward(self, input, target):
6566
"""
6667
Forward method of the loss function.
6768
68-
:param torch.Tensor input: Input tensor from real data.
69-
:param torch.Tensor target: Model tensor output.
70-
:return: Loss evaluation.
69+
:param torch.Tensor input: The input tensor.
70+
:param torch.Tensor target: The target tensor.
71+
:return: The computed loss.
7172
:rtype: torch.Tensor
7273
"""
74+
# Compute the standard loss
7375
loss = torch.abs((input - target)).pow(self.p).mean(-1)
76+
77+
# Compute the input norm for relative error
7478
if self.relative:
7579
loss = loss / torch.abs(input).pow(self.p).mean(-1)
80+
7681
return self._reduction(loss)

0 commit comments

Comments
 (0)