|
1 | 1 | """Module for the Equation.""" |
2 | 2 |
|
3 | 3 | import inspect |
4 | | -from pina._src.equation.equation_interface import EquationInterface |
| 4 | +from pina._src.equation.base_equation import BaseEquation |
5 | 5 |
|
6 | 6 |
|
7 | | -class Equation(EquationInterface): |
| 7 | +class Equation(BaseEquation): |
8 | 8 | """ |
9 | | - Implementation of the Equation class. Every ``equation`` passed to a |
10 | | - :class:`~pina.condition.condition.Condition` object must be either an |
11 | | - instance of :class:`Equation` or |
12 | | - :class:`~pina.equation.system_equation.SystemEquation`. |
| 9 | + Implementation of the Equation class, representing a single mathematical |
| 10 | + equation to be satisfied by the model outputs. |
| 11 | +
|
| 12 | + It can be passed to a :class:`~pina.condition.condition.Condition` object to |
| 13 | + define the conditions under which the model is trained. |
13 | 14 | """ |
14 | 15 |
|
15 | 16 | def __init__(self, equation): |
16 | 17 | """ |
17 | 18 | Initialization of the :class:`Equation` class. |
18 | 19 |
|
19 | | - :param Callable equation: A ``torch`` callable function used to compute |
20 | | - the residual of a mathematical equation. |
| 20 | + :param Callable equation: A callable function used to compute the |
| 21 | + residual of a mathematical equation. |
21 | 22 | :raises ValueError: If the equation is not a callable function. |
22 | 23 | """ |
| 24 | + # Check consistency |
23 | 25 | if not callable(equation): |
24 | | - raise ValueError( |
25 | | - "equation must be a callable function." |
26 | | - "Expected a callable function, got " |
27 | | - f"{equation}" |
28 | | - ) |
29 | | - # compute the signature |
| 26 | + raise ValueError(f"Expected a callable function, got {equation}") |
| 27 | + |
| 28 | + # Compute the signature length |
30 | 29 | sig = inspect.signature(equation) |
31 | 30 | self.__len_sig = len(sig.parameters) |
32 | 31 | self.__equation = equation |
33 | 32 |
|
34 | 33 | def residual(self, input_, output_, params_=None): |
35 | 34 | """ |
36 | | - Compute the residual of the equation. |
| 35 | + Evaluate the equation residual at the given inputs. |
37 | 36 |
|
38 | | - :param LabelTensor input_: Input points where the equation is evaluated. |
39 | | - :param LabelTensor output_: Output tensor, eventually produced by a |
| 37 | + :param LabelTensor input_: The input points where the residual is |
| 38 | + computed. |
| 39 | + :param LabelTensor output_: The output tensor, potentially produced by a |
40 | 40 | :class:`torch.nn.Module` instance. |
41 | | - :param dict params_: Dictionary of unknown parameters, associated with a |
42 | | - :class:`~pina.problem.inverse_problem.InverseProblem` instance. |
43 | | - If the equation is not related to a |
44 | | - :class:`~pina.problem.inverse_problem.InverseProblem` instance, the |
45 | | - parameters must be initialized to ``None``. Default is ``None``. |
46 | | - :return: The computed residual of the equation. |
| 41 | + :param dict params_: An optional dictionary of unknown parameters, used |
| 42 | + in :class:`~pina.problem.inverse_problem.InverseProblem` settings. |
| 43 | + If the equation is not related to an inverse problem, this should be |
| 44 | + set to ``None``. Default is ``None``. |
| 45 | + :raises RuntimeError: If the underlying equation signature is neither of |
| 46 | + length 2 for direct problems nor of length 3 for inverse problems. |
| 47 | + :return: The residual values of the equation. |
47 | 48 | :rtype: LabelTensor |
48 | | - :raises RuntimeError: If the underlying equation signature length is not |
49 | | - 2 (direct problem) or 3 (inverse problem). |
50 | 49 | """ |
51 | 50 | # Move the equation to the input_ device |
52 | 51 | self.to(input_.device) |
53 | 52 |
|
54 | | - # Call the underlying equation based on its signature length |
| 53 | + # Evaluate the equation for direct problems |
55 | 54 | if self.__len_sig == 2: |
56 | 55 | return self.__equation(input_, output_) |
| 56 | + |
| 57 | + # Evaluate the equation for inverse problems |
57 | 58 | if self.__len_sig == 3: |
58 | 59 | return self.__equation(input_, output_, params_) |
| 60 | + |
| 61 | + # Raise an error if the signature length is unexpected |
59 | 62 | raise RuntimeError( |
60 | 63 | f"Unexpected number of arguments in equation: {self.__len_sig}. " |
61 | | - "Expected either 2 (direct problem) or 3 (inverse problem)." |
| 64 | + "Expected either 2 for direct problems, or 3 for inverse problems." |
62 | 65 | ) |
0 commit comments