Skip to content

Commit cb458b0

Browse files
committed
Doc conditions
1 parent b7f5ac1 commit cb458b0

File tree

7 files changed

+121
-56
lines changed

7 files changed

+121
-56
lines changed

pina/condition/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Module for conditions.
2+
Module for importing Conditions objects.
33
"""
44

55
__all__ = [

pina/condition/condition.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
"""Condition module."""
1+
"""
2+
Condition module.
3+
"""
24

35
import warnings
46
from .data_condition import DataCondition
@@ -13,12 +15,11 @@
1315

1416

1517
def warning_function(new, old):
16-
"""Handle the deprecation warning.
18+
"""
19+
Handle the deprecation warning.
1720
18-
:param new: Object to use instead of the old one.
19-
:type new: str
20-
:param old: Object to deprecate.
21-
:type old: str
21+
:param str new: Object to use instead of the old one.
22+
:param str old: Object to deprecate.
2223
"""
2324
warnings.warn(
2425
f"'{old}' is deprecated and will be removed "
@@ -72,7 +73,6 @@ class Condition:
7273
... input=data,
7374
... conditional_variables=conditional_variables
7475
... )
75-
7676
"""
7777

7878
__slots__ = list(
@@ -85,7 +85,19 @@ class Condition:
8585
)
8686

8787
def __new__(cls, *args, **kwargs):
88-
88+
"""
89+
Create a new condition object based on the keyword arguments passed.
90+
- ``input`` and ``target``: :class:`InputTargetCondition`
91+
- ``domain`` and ``equation``: :class:`DomainEquationCondition`
92+
- ``input`` and ``equation``: :class:`InputEquationCondition`
93+
- ``input``: :class:`DataCondition`
94+
- ``input`` and ``conditional_variables``: :class:`DataCondition`
95+
96+
:raises ValueError: No valid condition has been found.
97+
:return: A new condition instance belonging to the proper class.
98+
:rtype: ConditionInputTarget | ConditionInputEquation |
99+
ConditionDomainEquation | ConditionData
100+
"""
89101
if len(args) != 0:
90102
raise ValueError(
91103
"Condition takes only the following keyword "

pina/condition/condition_interface.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,36 @@ def problem(self):
2121
"""
2222
Return the problem to which the condition is associated.
2323
24-
:return: Problem to which the condition is associated
24+
:return: Problem to which the condition is associated.
2525
:rtype: pina.problem.AbstractProblem
2626
"""
27+
2728
return self._problem
2829

2930
@problem.setter
3031
def problem(self, value):
32+
"""
33+
Set the problem to which the condition is associated.
34+
35+
:param value: Problem to which the condition is associated.
36+
:type value: pina.problem.AbstractProblem
37+
"""
38+
3139
self._problem = value
3240

3341
@staticmethod
3442
def _check_graph_list_consistency(data_list):
43+
"""
44+
Check if the list of Data/Graph objects is consistent.
45+
46+
:param data_list: list of Data/Graph objects.
47+
:type data_list: list(Data) | list(Graph)
48+
49+
:raises ValueError: Input data must be either Data or Graph objects.
50+
:raises ValueError: All elements in the list must have the same keys.
51+
:raises ValueError: Type mismatch in data tensors.
52+
:raises ValueError: Label mismatch in LabelTensors.
53+
"""
3554

3655
# If the data is a Graph or Data object, return (do not need to check
3756
# anything)

pina/condition/data_condition.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,20 @@ class DataCondition(ConditionInterface):
2323

2424
def __new__(cls, input, conditional_variables=None):
2525
"""
26-
Instanciate the correct subclass of DataCondition by checking the type
27-
of the input data (input and conditional_variables).
26+
Instantiate the appropriate subclass of DataCondition based on the
27+
types of input data.
28+
29+
:param input: Input data for the condition.
30+
:type input: torch.Tensor | LabelTensor | Graph | Data
31+
:param conditional_variables: Conditional variables for the condition.
32+
:type conditional_variables: torch.Tensor | LabelTensor
33+
:return: Subclass of DataCondition.
34+
:rtype: TensorDataCondition | GraphDataCondition
35+
36+
:raises ValueError: If input is not of type :class:`torch.Tensor`,
37+
:class:`LabelTensor`, :class:`Graph`, or :class:`Data`.
38+
2839
29-
:param input: torch.Tensor or Graph/Data object containing the input
30-
data
31-
:type input: torch.Tensor or Graph or Data
32-
:param conditional_variables: torch.Tensor or LabelTensor containing
33-
the conditional variables
34-
:type conditional_variables: torch.Tensor or LabelTensor
35-
:return: DataCondition subclass
36-
:rtype: TensorDataCondition or GraphDataCondition
3740
"""
3841
if cls != DataCondition:
3942
return super().__new__(cls)
@@ -56,12 +59,15 @@ def __init__(self, input, conditional_variables=None):
5659
Initialize the DataCondition, storing the input and conditional
5760
variables (if any).
5861
59-
:param input: torch.Tensor or Graph/Data object containing the input
60-
data
62+
:param input: Input data for the condition.
6163
:type input: torch.Tensor or Graph or Data
62-
:param conditional_variables: torch.Tensor or LabelTensor containing
63-
the conditional variables
64+
:param conditional_variables: Conditional variables for the condition.
6465
:type conditional_variables: torch.Tensor or LabelTensor
66+
67+
.. note::
68+
If either `input` is composed by a list of :class:`Graph`/
69+
:class:`Data` objects, all elements must have the same structure
70+
(keys and data types)
6571
"""
6672
super().__init__()
6773
self.input = input

pina/condition/domain_equation_condition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ def __init__(self, domain, equation):
2020
"""
2121
Initialize the DomainEquationCondition, storing the domain and equation.
2222
23-
:param DomainInterface domain: Domain object containing the domain data
23+
:param DomainInterface domain: Domain object containing the domain data.
2424
:param EquationInterface equation: Equation object containing the
25-
equation data
25+
equation data.
2626
"""
2727
super().__init__()
2828
self.domain = domain

pina/condition/input_equation_condition.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,18 @@ class InputEquationCondition(ConditionInterface):
2222

2323
def __new__(cls, input, equation):
2424
"""
25-
Instanciate the correct subclass of InputEquationCondition by checking
26-
the type of the input data (only `input`).
25+
Instantiate the appropriate subclass of InputEquationCondition based on
26+
the type of input data.
2727
28-
:param input: torch.Tensor or Graph/Data object containing the input
29-
:type input: torch.Tensor or Graph or Data
28+
:param input: Input data. It can be a LabelTensor or a Graph object.
29+
:type input: LabelTensor | Graph
3030
:param EquationInterface equation: Equation object containing the
31-
equation function
32-
:return: InputEquationCondition subclass
33-
:rtype: InputTensorEquationCondition or InputGraphEquationCondition
31+
equation function.
32+
:return: Subclass of InputEquationCondition, based on the input type.
33+
:rtype: InputTensorEquationCondition | InputGraphEquationCondition
34+
35+
:raises ValueError: If input is not of type :class:`torch.Tensor`,
36+
:class:`LabelTensor`, :class:`Graph`, or :class:`Data`.
3437
"""
3538

3639
# If the class is already a subclass, return the instance
@@ -56,11 +59,18 @@ def __init__(self, input, equation):
5659
"""
5760
Initialize the InputEquationCondition by storing the input and equation.
5861
59-
:param input: torch.Tensor or Graph/Data object containing the input
60-
:type input: torch.Tensor or Graph or Data
62+
:param input: torch.Tensor or Graph/Data object containing the input.
63+
:type input: torch.Tensor | Graph
6164
:param EquationInterface equation: Equation object containing the
62-
equation function
65+
equation function.
66+
67+
.. note::
68+
If ``input`` is composed by a list of :class:`Graph`/:class:`Data`
69+
objects, all elements must have the same structure (keys and data
70+
types). Moreover, at least one attribute must be a
71+
:class:`LabelTensor`.
6372
"""
73+
6474
super().__init__()
6575
self.input = input
6676
self.equation = equation
@@ -90,11 +100,15 @@ class InputGraphEquationCondition(InputEquationCondition):
90100
@staticmethod
91101
def _check_label_tensor(input):
92102
"""
93-
Check if the input is a LabelTensor.
103+
Check if at least one LabelTensor is present in the Graph object.
94104
95-
:param input: input data
105+
:param input: Input data.
96106
:type input: torch.Tensor or Graph or Data
107+
108+
:raises ValueError: If the input data object does not contain at least
109+
one LabelTensor.
97110
"""
111+
98112
# Store the fist element of the list/tuple if input is a list/tuple
99113
# it is anougth to check the first element because all elements must
100114
# have the same type and structure (already checked)

pina/condition/input_target_condition.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,23 @@ class InputTargetCondition(ConditionInterface):
2121

2222
def __new__(cls, input, target):
2323
"""
24-
Instanciate the correct subclass of InputTargetCondition by checking the
25-
type of the input and target data.
26-
27-
:param input: torch.Tensor or Graph/Data object containing the input
28-
:type input: torch.Tensor or Graph or Data
29-
:param target: torch.Tensor or Graph/Data object containing the target
30-
:type target: torch.Tensor or Graph or Data
31-
:return: InputTargetCondition subclass
32-
:rtype: TensorInputTensorTargetCondition or
33-
TensorInputGraphTargetCondition or GraphInputTensorTargetCondition
34-
or GraphInputGraphTargetCondition
24+
Instantiate the appropriate subclass of InputTargetCondition based on
25+
the types of input and target data.
26+
27+
:param input: Input data for the condition.
28+
:type input: torch.Tensor | Graph | Data | list | tuple
29+
:param target: Target data for the condition.
30+
Graph, Data, or list/tuple.
31+
:type target: torch.Tensor | Graph | Data | list | tuple
32+
:return: Subclass of InputTargetCondition
33+
:rtype: TensorInputTensorTargetCondition |
34+
TensorInputGraphTargetCondition |
35+
GraphInputTensorTargetCondition |
36+
GraphInputGraphTargetCondition
37+
38+
:raises ValueError: If input and or target are not of type
39+
:class:`torch.Tensor`, :class:`LabelTensor`, :class:`Graph`, or
40+
:class:`Data`.
3541
"""
3642
if cls != InputTargetCondition:
3743
return super().__new__(cls)
@@ -74,10 +80,16 @@ def __init__(self, input, target):
7480
Initialize the InputTargetCondition, storing the input and target data.
7581
7682
:param input: torch.Tensor or Graph/Data object containing the input
77-
:type input: torch.Tensor or Graph or Data
83+
:type input: torch.Tensor | Graph or Data
7884
:param target: torch.Tensor or Graph/Data object containing the target
7985
:type target: torch.Tensor or Graph or Data
86+
87+
.. note::
88+
If either ``input`` or ``target`` are composed by a list of
89+
:class:`Graph`/:class:`Data` objects, all elements must have the
90+
same structure (keys and data types)
8091
"""
92+
8193
super().__init__()
8294
self._check_input_target_len(input, target)
8395
self.input = input
@@ -97,25 +109,27 @@ def _check_input_target_len(input, target):
97109

98110
class TensorInputTensorTargetCondition(InputTargetCondition):
99111
"""
100-
InputTargetCondition subclass for torch.Tensor input and target data.
112+
InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor`
113+
input and target data.
101114
"""
102115

103116

104117
class TensorInputGraphTargetCondition(InputTargetCondition):
105118
"""
106-
InputTargetCondition subclass for torch.Tensor input and Graph/Data target
107-
data.
119+
InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor`
120+
input and :class:`Graph`/:class:`Data` target data.
108121
"""
109122

110123

111124
class GraphInputTensorTargetCondition(InputTargetCondition):
112125
"""
113-
InputTargetCondition subclass for Graph/Data input and torch.Tensor target
114-
data.
126+
InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
127+
:class:`torch.Tensor`/:class:`LabelTensor` target data.
115128
"""
116129

117130

118131
class GraphInputGraphTargetCondition(InputTargetCondition):
119132
"""
120-
InputTargetCondition subclass for Graph/Data input and target data.
133+
InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
134+
target data.
121135
"""

0 commit comments

Comments
 (0)