@@ -21,17 +21,23 @@ class InputTargetCondition(ConditionInterface):
2121
2222 def __new__ (cls , input , target ):
2323 """
24- Instanciate the correct subclass of InputTargetCondition by checking the
25- type of the input and target data.
26-
27- :param input: torch.Tensor or Graph/Data object containing the input
28- :type input: torch.Tensor or Graph or Data
29- :param target: torch.Tensor or Graph/Data object containing the target
30- :type target: torch.Tensor or Graph or Data
31- :return: InputTargetCondition subclass
32- :rtype: TensorInputTensorTargetCondition or
33- TensorInputGraphTargetCondition or GraphInputTensorTargetCondition
34- or GraphInputGraphTargetCondition
24+ Instantiate the appropriate subclass of InputTargetCondition based on
25+ the types of input and target data.
26+
27+ :param input: Input data for the condition.
28+ :type input: torch.Tensor | Graph | Data | list | tuple
29+ :param target: Target data for the condition.
30+ Graph, Data, or list/tuple.
31+ :type target: torch.Tensor | Graph | Data | list | tuple
32+ :return: Subclass of InputTargetCondition
33+ :rtype: TensorInputTensorTargetCondition |
34+ TensorInputGraphTargetCondition |
35+ GraphInputTensorTargetCondition |
36+ GraphInputGraphTargetCondition
37+
38+ :raises ValueError: If input and or target are not of type
39+ :class:`torch.Tensor`, :class:`LabelTensor`, :class:`Graph`, or
40+ :class:`Data`.
3541 """
3642 if cls != InputTargetCondition :
3743 return super ().__new__ (cls )
@@ -74,10 +80,16 @@ def __init__(self, input, target):
7480 Initialize the InputTargetCondition, storing the input and target data.
7581
7682 :param input: torch.Tensor or Graph/Data object containing the input
77- :type input: torch.Tensor or Graph or Data
83+ :type input: torch.Tensor | Graph or Data
7884 :param target: torch.Tensor or Graph/Data object containing the target
7985 :type target: torch.Tensor or Graph or Data
86+
87+ .. note::
88+ If either ``input`` or ``target`` are composed by a list of
89+ :class:`Graph`/:class:`Data` objects, all elements must have the
90+ same structure (keys and data types)
8091 """
92+
8193 super ().__init__ ()
8294 self ._check_input_target_len (input , target )
8395 self .input = input
@@ -97,25 +109,27 @@ def _check_input_target_len(input, target):
97109
98110class TensorInputTensorTargetCondition (InputTargetCondition ):
99111 """
100- InputTargetCondition subclass for torch.Tensor input and target data.
112+ InputTargetCondition subclass for :class:`torch.Tensor`/:class:`LabelTensor`
113+ input and target data.
101114 """
102115
103116
104117class TensorInputGraphTargetCondition (InputTargetCondition ):
105118 """
106- InputTargetCondition subclass for torch.Tensor input and Graph/Data target
107- data.
119+ InputTargetCondition subclass for :class:` torch.Tensor`/:class:`LabelTensor`
120+ input and :class:`Graph`/:class:`Data` target data.
108121 """
109122
110123
111124class GraphInputTensorTargetCondition (InputTargetCondition ):
112125 """
113- InputTargetCondition subclass for Graph/ Data input and torch.Tensor target
114- data.
126+ InputTargetCondition subclass for :class:` Graph`/:class:` Data` input and
127+ :class:`torch.Tensor`/:class:`LabelTensor` target data.
115128 """
116129
117130
118131class GraphInputGraphTargetCondition (InputTargetCondition ):
119132 """
120- InputTargetCondition subclass for Graph/Data input and target data.
133+ InputTargetCondition subclass for :class:`Graph`/:class:`Data` input and
134+ target data.
121135 """
0 commit comments