Skip to content

Commit c91cdcc

Browse files
FilippoOlivodario-coscia
authored andcommitted
Update doc condition
1 parent f723937 commit c91cdcc

File tree

4 files changed

+28
-12
lines changed

4 files changed

+28
-12
lines changed

pina/condition/condition_interface.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ def _check_graph_list_consistency(data_list):
4545
objects is consistent.
4646
4747
:param data_list: List of graph type objects.
48-
:type data_list: list(torch_geometric.data.Data) | list(Graph)
48+
:type data_list: torch_geometric.data.Data | Graph|
49+
list[torch_geometric.data.Data] | list[Graph]
4950
5051
:raises ValueError: Input data must be either torch_geometric.data.Data
5152
or Graph objects.

pina/condition/data_condition.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@ def __new__(cls, input, conditional_variables=None):
2626
types of input data.
2727
2828
:param input: Input data for the condition.
29-
:type input: torch.Tensor | LabelTensor | Graph |
30-
torch_geometric.data.Data
29+
:type input: torch.Tensor | LabelTensor | Graph |
30+
torch_geometric.data.Data | list[Graph] |
31+
list[torch_geometric.data.Data] | tuple[Graph] |
32+
tuple[torch_geometric.data.Data]
3133
:param conditional_variables: Conditional variables for the condition.
3234
:type conditional_variables: torch.Tensor | LabelTensor
3335
:return: Subclass of DataCondition.
@@ -61,7 +63,10 @@ def __init__(self, input, conditional_variables=None):
6163
variables (if any).
6264
6365
:param input: Input data for the condition.
64-
:type input: torch.Tensor or Graph or torch_geometric.data.Data
66+
:type input: torch.Tensor | LabelTensor | Graph |
67+
torch_geometric.data.Data | list[Graph] |
68+
list[torch_geometric.data.Data] | tuple[Graph] |
69+
tuple[torch_geometric.data.Data]
6570
:param conditional_variables: Conditional variables for the condition.
6671
:type conditional_variables: torch.Tensor or LabelTensor
6772

pina/condition/input_equation_condition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __new__(cls, input, equation):
2626
the type of input data.
2727
2828
:param input: Input data. It can be a LabelTensor or a Graph object.
29-
:type input: LabelTensor | Graph
29+
:type input: LabelTensor | Graph | list[Graph] | tuple[Graph]
3030
:param EquationInterface equation: Equation object containing the
3131
equation function.
3232
:return: Subclass of InputEquationCondition, based on the input type.
@@ -61,7 +61,7 @@ def __init__(self, input, equation):
6161
Initialize the InputEquationCondition by storing the input and equation.
6262
6363
:param input: Input data for the condition.
64-
:type input: torch.Tensor | Graph
64+
:type input: LabelTensor | Graph | list[Graph] | tuple[Graph]
6565
:param EquationInterface equation: Equation object containing the
6666
equation function.
6767

pina/condition/input_target_condition.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,15 @@ def __new__(cls, input, target):
2525
the types of input and target data.
2626
2727
:param input: Input data for the condition.
28-
:type input: torch.Tensor | Graph | torch_geometric.data.Data | list | \
29-
tuple
28+
:type input: torch.Tensor | LabelTensor | Graph |
29+
torch_geometric.data.Data | list[Graph] |
30+
list[torch_geometric.data.Data] | tuple[Graph] |
31+
tuple[torch_geometric.data.Data]
3032
:param target: Target data for the condition.
31-
:type target: torch.Tensor | Graph | torch_geometric.data.Data | list \
32-
| tuple
33+
:type target: torch.Tensor | LabelTensor | Graph |
34+
torch_geometric.data.Data | list[Graph] |
35+
list[torch_geometric.data.Data] | tuple[Graph] |
36+
tuple[torch_geometric.data.Data]
3337
:return: Subclass of InputTargetCondition
3438
:rtype: TensorInputTensorTargetCondition | \
3539
TensorInputGraphTargetCondition | \
@@ -81,9 +85,15 @@ def __init__(self, input, target):
8185
Initialize the InputTargetCondition, storing the input and target data.
8286
8387
:param input: Input data for the condition.
84-
:type input: torch.Tensor | Graph | torch_geometric.data.Data
88+
:type input: torch.Tensor | LabelTensor | Graph |
89+
torch_geometric.data.Data | list[Graph] |
90+
list[torch_geometric.data.Data] | tuple[Graph] |
91+
tuple[torch_geometric.data.Data]
8592
:param target: Target data for the condition.
86-
:type target: torch.Tensor | Graph | torch_geometric.data.Data
93+
:type target: torch.Tensor | LabelTensor | Graph |
94+
torch_geometric.data.Data | list[Graph] |
95+
list[torch_geometric.data.Data] | tuple[Graph] |
96+
tuple[torch_geometric.data.Data]
8797
8898
.. note::
8999
If either ``input`` or ``target`` are composed by a list of

0 commit comments

Comments
 (0)