Skip to content

Commit fde6450

Browse files
committed
Small fixes in conditions
1 parent cb458b0 commit fde6450

File tree

5 files changed

+67
-60
lines changed

5 files changed

+67
-60
lines changed

pina/condition/condition.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -56,23 +56,23 @@ class Condition:
5656
5757
Example::
5858
59-
>>> from pina import Condition
60-
>>> condition = Condition(
61-
... input=input,
62-
... target=target
63-
... )
64-
>>> condition = Condition(
65-
... domain=location,
66-
... equation=equation
67-
... )
68-
>>> condition = Condition(
69-
... input=input,
70-
... equation=equation
71-
... )
72-
>>> condition = Condition(
73-
... input=data,
74-
... conditional_variables=conditional_variables
75-
... )
59+
>>> from pina import Condition
60+
>>> condition = Condition(
61+
... input=input,
62+
... target=target
63+
... )
64+
>>> condition = Condition(
65+
... domain=location,
66+
... equation=equation
67+
... )
68+
>>> condition = Condition(
69+
... input=input,
70+
... equation=equation
71+
... )
72+
>>> condition = Condition(
73+
... input=data,
74+
... conditional_variables=conditional_variables
75+
... )
7676
"""
7777

7878
__slots__ = list(
@@ -87,6 +87,7 @@ class Condition:
8787
def __new__(cls, *args, **kwargs):
8888
"""
8989
Create a new condition object based on the keyword arguments passed.
90+
9091
- ``input`` and ``target``: :class:`InputTargetCondition`
9192
- ``domain`` and ``equation``: :class:`DomainEquationCondition`
9293
- ``input`` and ``equation``: :class:`InputEquationCondition`
@@ -95,8 +96,8 @@ def __new__(cls, *args, **kwargs):
9596
9697
:raises ValueError: No valid condition has been found.
9798
:return: A new condition instance belonging to the proper class.
98-
:rtype: ConditionInputTarget | ConditionInputEquation |
99-
ConditionDomainEquation | ConditionData
99+
:rtype: InputTargetCondition | DomainEquationCondition |
100+
InputEquationCondition | DataCondition
100101
"""
101102
if len(args) != 0:
102103
raise ValueError(

pina/condition/condition_interface.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,14 @@ def problem(self, value):
4141
@staticmethod
4242
def _check_graph_list_consistency(data_list):
4343
"""
44-
Check if the list of Data/Graph objects is consistent.
44+
Check if the list of :class:`torch_geometric.data.Data`/:class:`Graph`
45+
objects is consistent.
4546
46-
:param data_list: list of Data/Graph objects.
47-
:type data_list: list(Data) | list(Graph)
47+
:param data_list: List of graph type objects.
48+
:type data_list: list(torch_geometric.data.Data) | list(Graph)
4849
49-
:raises ValueError: Input data must be either Data or Graph objects.
50+
:raises ValueError: Input data must be either torch_geometric.data.Data
51+
or Graph objects.
5052
:raises ValueError: All elements in the list must have the same keys.
5153
:raises ValueError: Type mismatch in data tensors.
5254
:raises ValueError: Label mismatch in LabelTensors.

pina/condition/data_condition.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,9 @@
1111

1212
class DataCondition(ConditionInterface):
1313
"""
14-
Condition for data. This condition must be used every
15-
time a Unsupervised Loss is needed in the Solver. The conditionalvariable
16-
can be passed as extra-input when the model learns a conditional
17-
distribution
14+
This condition must be used every time a Unsupervised Loss is needed in
15+
the Solver. The conditionalvariable can be passed as extra-input when
16+
the model learns a conditional distribution.
1817
"""
1918

2019
__slots__ = ["input", "conditional_variables"]
@@ -27,14 +26,16 @@ def __new__(cls, input, conditional_variables=None):
2726
types of input data.
2827
2928
:param input: Input data for the condition.
30-
:type input: torch.Tensor | LabelTensor | Graph | Data
29+
:type input: torch.Tensor | LabelTensor | Graph | \
30+
torch_geometric.data.Data
3131
:param conditional_variables: Conditional variables for the condition.
3232
:type conditional_variables: torch.Tensor | LabelTensor
3333
:return: Subclass of DataCondition.
3434
:rtype: TensorDataCondition | GraphDataCondition
3535
3636
:raises ValueError: If input is not of type :class:`torch.Tensor`,
37-
:class:`LabelTensor`, :class:`Graph`, or :class:`Data`.
37+
:class:`LabelTensor`, :class:`Graph`, or
38+
:class:`torch_geometric.data.Data`.
3839
3940
4041
"""
@@ -51,7 +52,7 @@ def __new__(cls, input, conditional_variables=None):
5152

5253
raise ValueError(
5354
"Invalid input types. "
54-
"Please provide either Data or Graph objects."
55+
"Please provide either torch_geometric.data.Data or Graph objects."
5556
)
5657

5758
def __init__(self, input, conditional_variables=None):
@@ -60,14 +61,14 @@ def __init__(self, input, conditional_variables=None):
6061
variables (if any).
6162
6263
:param input: Input data for the condition.
63-
:type input: torch.Tensor or Graph or Data
64+
:type input: torch.Tensor or Graph or torch_geometric.data.Data
6465
:param conditional_variables: Conditional variables for the condition.
6566
:type conditional_variables: torch.Tensor or LabelTensor
6667
6768
.. note::
6869
If either `input` is composed by a list of :class:`Graph`/
69-
:class:`Data` objects, all elements must have the same structure
70-
(keys and data types)
70+
:class:`torch_geometric.data.Data` objects, all elements must have
71+
the same structure (keys and data types)
7172
"""
7273
super().__init__()
7374
self.input = input

pina/condition/input_equation_condition.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def __new__(cls, input, equation):
3333
:rtype: InputTensorEquationCondition | InputGraphEquationCondition
3434
3535
:raises ValueError: If input is not of type :class:`torch.Tensor`,
36-
:class:`LabelTensor`, :class:`Graph`, or :class:`Data`.
36+
:class:`LabelTensor`, :class:`Graph`, or
37+
:class:`torch_geometric.data.Data`.
3738
"""
3839

3940
# If the class is already a subclass, return the instance
@@ -59,16 +60,16 @@ def __init__(self, input, equation):
5960
"""
6061
Initialize the InputEquationCondition by storing the input and equation.
6162
62-
:param input: torch.Tensor or Graph/Data object containing the input.
63+
:param input: Input data for the condition.
6364
:type input: torch.Tensor | Graph
6465
:param EquationInterface equation: Equation object containing the
6566
equation function.
6667
6768
.. note::
68-
If ``input`` is composed by a list of :class:`Graph`/:class:`Data`
69-
objects, all elements must have the same structure (keys and data
70-
types). Moreover, at least one attribute must be a
71-
:class:`LabelTensor`.
69+
If ``input`` is composed by a list of :class:`Graph`/
70+
:class:`torch_geometric.data.Data` objects, all elements must have
71+
the same structure (keys and data types). Moreover, at least one
72+
attribute must be a :class:`LabelTensor`.
7273
"""
7374

7475
super().__init__()
@@ -103,7 +104,7 @@ def _check_label_tensor(input):
103104
Check if at least one LabelTensor is present in the Graph object.
104105
105106
:param input: Input data.
106-
:type input: torch.Tensor or Graph or Data
107+
:type input: torch.Tensor | Graph | torch_geometric.data.Data
107108
108109
:raises ValueError: If the input data object does not contain at least
109110
one LabelTensor.

pina/condition/input_target_condition.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,20 @@ def __new__(cls, input, target):
2525
the types of input and target data.
2626
2727
:param input: Input data for the condition.
28-
:type input: torch.Tensor | Graph | Data | list | tuple
28+
:type input: torch.Tensor | Graph | torch_geometric.data.Data | list | \
29+
tuple
2930
:param target: Target data for the condition.
30-
Graph, Data, or list/tuple.
31-
:type target: torch.Tensor | Graph | Data | list | tuple
31+
:type target: torch.Tensor | Graph | torch_geometric.data.Data | list \
32+
| tuple
3233
:return: Subclass of InputTargetCondition
33-
:rtype: TensorInputTensorTargetCondition |
34-
TensorInputGraphTargetCondition |
35-
GraphInputTensorTargetCondition |
34+
:rtype: TensorInputTensorTargetCondition | \
35+
TensorInputGraphTargetCondition | \
36+
GraphInputTensorTargetCondition | \
3637
GraphInputGraphTargetCondition
3738
3839
:raises ValueError: If input and or target are not of type
3940
:class:`torch.Tensor`, :class:`LabelTensor`, :class:`Graph`, or
40-
:class:`Data`.
41+
:class:`torch_geometric.data.Data`.
4142
"""
4243
if cls != InputTargetCondition:
4344
return super().__new__(cls)
@@ -71,23 +72,23 @@ def __new__(cls, input, target):
7172

7273
raise ValueError(
7374
"Invalid input/target types. "
74-
"Please provide either Data, Graph, LabelTensor or torch.Tensor "
75-
"objects."
75+
"Please provide either torch_geometric.data.Data, Graph, LabelTensor "
76+
"or torch.Tensor objects."
7677
)
7778

7879
def __init__(self, input, target):
7980
"""
8081
Initialize the InputTargetCondition, storing the input and target data.
8182
82-
:param input: torch.Tensor or Graph/Data object containing the input
83-
:type input: torch.Tensor | Graph or Data
84-
:param target: torch.Tensor or Graph/Data object containing the target
85-
:type target: torch.Tensor or Graph or Data
83+
:param input: Input data for the condition.
84+
:type input: torch.Tensor | Graph | torch_geometric.data.Data
85+
:param target: Target data for the condition.
86+
:type target: torch.Tensor | Graph | torch_geometric.data.Data
8687
8788
.. note::
8889
If either ``input`` or ``target`` are composed by a list of
89-
:class:`Graph`/:class:`Data` objects, all elements must have the
90-
same structure (keys and data types)
90+
:class:`Graph`/:class:`torch_geometric.data.Data` objects, all
91+
elements must have the same structure (keys and data types)
9192
"""
9293

9394
super().__init__()
@@ -117,19 +118,20 @@ class TensorInputTensorTargetCondition(InputTargetCondition):
117118
class TensorInputGraphTargetCondition(InputTargetCondition):
118119
"""
119120
InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor`
120-
input and :class:`Graph`/:class:`Data` target data.
121+
input and :class:`Graph`/:class:`torch_geometric.data.Data` target data.
121122
"""
122123

123124

124125
class GraphInputTensorTargetCondition(InputTargetCondition):
125126
"""
126-
InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
127-
:class:`torch.Tensor`/:class:`LabelTensor` target data.
127+
InputTargetCondition subclass for :class:`Graph`/
128+
:class:`torch_geometric.data.Data` input and :class:`torch.Tensor`/
129+
:class:`LabelTensor` target data.
128130
"""
129131

130132

131133
class GraphInputGraphTargetCondition(InputTargetCondition):
132134
"""
133-
InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
134-
target data.
135+
InputTargetCondition subclass for :class:`Graph`/
136+
:class:`torch_geometric.data.Data` input and target data.
135137
"""

0 commit comments

Comments
 (0)