|
1 | | -"""Module for the LpLoss class.""" |
| 1 | +"""Module for the Lp Loss class.""" |
2 | 2 |
|
3 | 3 | import torch |
4 | | -from pina._src.loss.loss_interface import LossInterface |
| 4 | +from pina._src.loss.base_loss import BaseLoss |
5 | 5 | from pina._src.core.utils import check_consistency |
6 | 6 |
|
7 | 7 |
|
8 | | -class LpLoss(LossInterface): |
| 8 | +class LpLoss(BaseLoss): |
9 | 9 | r""" |
10 | | - Implementation of the Lp Loss. It defines a criterion to measures the |
11 | | - pointwise Lp error between values in the input :math:`x` and values in the |
12 | | - target :math:`y`. |
| 10 | + Implementation of the :math:`L^p` loss measuring the pointwise :math:`L^p` |
| 11 | + distance between an input tensor :math:`x` and a target tensor :math:`y`. |
13 | 12 |
|
14 | | - If ``reduction`` is set to ``none``, the loss can be written as: |
| 13 | + Given a batch of size :math:`N` and feature dimension :math:`D`, the |
| 14 | + unreduced loss (``reduction="none"``) is defined as: |
15 | 15 |
|
16 | 16 | .. math:: |
17 | | - \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad |
18 | | - l_n = \left[\sum_{i=1}^{D} \left| x_n^i - y_n^i \right|^p \right], |
19 | | - |
20 | | - If ``relative`` is set to ``True``, the relative Lp error is computed: |
| 17 | + L = \{l_1, \dots, l_N\}^\top, \quad |
| 18 | + l_n = \left( \sum_{i=1}^{D} \left| x_n^i - y_n^i \right|^p \right)^{1/p} |
| 19 | +
|
| 20 | + If ``relative=True``, each term is normalized by the :math:`L^p` norm of the |
| 21 | + input tensor :math:`x`: |
21 | 22 |
|
22 | 23 | .. math:: |
23 | | - \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad |
24 | | - l_n = \frac{ [\sum_{i=1}^{D} | x_n^i - y_n^i|^p] } |
25 | | - {[\sum_{i=1}^{D}|y_n^i|^p]}, |
| 24 | + l_n = \frac{\left( \sum_{i=1}^{D} |x_n^i - y_n^i|^p \right)^{1/p}} |
| 25 | + {\left( \sum_{i=1}^{D} |x_n^i|^p \right)^{1/p}} |
26 | 26 |
|
27 | | - where :math:`N` is the batch size. |
28 | | - |
29 | | - If ``reduction`` is not ``none``, then: |
| 27 | + If ``reduction`` is set to ``"mean"`` or ``"sum"``, the vector :math:`L` |
| 28 | + is aggregated accordingly: |
30 | 29 |
|
31 | 30 | .. math:: |
32 | 31 | \ell(x, y) = |
33 | 32 | \begin{cases} |
34 | | - \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ |
35 | | - \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} |
| 33 | + \operatorname{mean}(L), & \text{if reduction} = \text{``mean''} \\ |
| 34 | + \operatorname{sum}(L), & \text{if reduction} = \text{``sum''} |
36 | 35 | \end{cases} |
| 36 | +
|
| 37 | + where :math:`N` is the batch size. |
37 | 38 | """ |
38 | 39 |
|
39 | 40 | def __init__(self, p=2, reduction="mean", relative=False): |
40 | 41 | """ |
41 | 42 | Initialization of the :class:`LpLoss` class. |
42 | 43 |
|
43 | | - :param int p: Degree of the Lp norm. It specifies the norm to be |
44 | | - computed. Default is ``2`` (euclidean norm). |
45 | | - :param str reduction: The reduction method for the loss. |
46 | | - Available options: ``none``, ``mean``, ``sum``. |
47 | | - If ``none``, no reduction is applied. If ``mean``, the sum of the |
48 | | - loss values is divided by the number of values. If ``sum``, the loss |
49 | | - values are summed. Default is ``mean``. |
50 | | - :param bool relative: If ``True``, the relative error is computed. |
| 44 | + :param p: The order of the norm. It can be a numeric value for standard |
| 45 | + p-norms or one of the following strings: ``"inf"`` for maximum |
| 46 | + absolute value, ``"-inf"`` for minimum absolute value. The values |
| 47 | + ``"inf"`` and ``"-inf"`` are internally converted to their floating |
| 48 | + counterparts. Default is ``2``. |
| 49 | + :type p: int | float | str |
| 50 | + :param str reduction: The reduction method to aggregate pointwise loss |
| 51 | + values. Available options include: ``"none"`` for unreduced loss, |
| 52 | + ``"mean"`` for the average of the loss values, and ``"sum"`` for |
| 53 | + their total sum. Default is ``"mean"``. |
| 54 | + :param bool relative: If ``True``, computes the relative error. |
51 | 55 | Default is ``False``. |
| 56 | + :raises ValueError: If ``relative`` is not a boolean. |
| 57 | + :raises ValueError: If ``p`` is not a valid norm order. |
52 | 58 | """ |
53 | 59 | super().__init__(reduction=reduction) |
54 | 60 |
|
55 | | - # check consistency |
56 | | - check_consistency(p, (str, int, float)) |
| 61 | + # Convert to float if inf or -inf |
| 62 | + if p == "inf": |
| 63 | + p = float("inf") |
| 64 | + elif p == "-inf": |
| 65 | + p = float("-inf") |
| 66 | + |
| 67 | + # Check consistency |
57 | 68 | check_consistency(relative, bool) |
| 69 | + check_consistency(p, (int, float)) |
58 | 70 |
|
| 71 | + # Initialize attributes |
59 | 72 | self.p = p |
60 | 73 | self.relative = relative |
61 | 74 |
|
62 | 75 | def forward(self, input, target): |
63 | 76 | """ |
64 | 77 | Forward method of the loss function. |
65 | 78 |
|
66 | | - :param torch.Tensor input: Input tensor from real data. |
67 | | - :param torch.Tensor target: Model tensor output. |
68 | | - :return: Loss evaluation. |
| 79 | + :param torch.Tensor input: The input tensor. |
| 80 | + :param torch.Tensor target: The target tensor. |
| 81 | + :return: The computed loss. |
69 | 82 | :rtype: torch.Tensor |
70 | 83 | """ |
| 84 | + # Compute the standard loss |
71 | 85 | loss = torch.linalg.norm((input - target), ord=self.p, dim=-1) |
| 86 | + |
| 87 | + # Compute the input norm for relative error |
72 | 88 | if self.relative: |
73 | 89 | loss = loss / torch.linalg.norm(input, ord=self.p, dim=-1) |
| 90 | + |
74 | 91 | return self._reduction(loss) |
0 commit comments