@@ -40,72 +40,54 @@ def __str__(self) -> str:
4040 return type (self ).__name__
4141
4242 @abstractmethod
43- def _compute (self , input : _B ) -> _C :
44- """Applies the transform to the input."""
45-
4643 def __call__ (self , input : _B ) -> _C :
47- input .check_keys_are (self .required_keys )
48- return self ._compute (input )
44+ """Applies the transform to the input."""
4945
50- @property
5146 @abstractmethod
52- def required_keys (self ) -> set [Tensor ]:
53- """
54- Returns the set of keys that the transform requires to be present in its input TensorDicts.
47+ def check_and_get_keys (self ) -> tuple [set [Tensor ], set [Tensor ]]:
5548 """
49+ Returns a pair containing (in order) the required keys and the output keys of the Transform
50+ and recursively checks that the transform is valid.
5651
57- @property
58- @abstractmethod
59- def output_keys (self ) -> set [Tensor ]:
60- """Returns the set of keys that will be present in the output of the transform."""
52+ The required keys are the set of keys that the transform requires to be present in its input
53+ TensorDicts. The output keys are the set of keys that will be present in the output
54+ TensorDicts of the transform.
55+
56+ Since the computation of the required and output keys and the verification that the
57+ transform is valid are sometimes intertwined operations, we do them in a single method.
58+ """
6159
6260 __lshift__ = compose
6361 __or__ = conjunct
6462
6563
6664class Composition (Transform [_A , _C ]):
6765 def __init__ (self , outer : Transform [_B , _C ], inner : Transform [_A , _B ]):
68- if outer .required_keys != inner .output_keys :
69- raise ValueError (
70- "The `output_keys` of `inner` must match with the `required_keys` of "
71- f"outer. Found { outer .required_keys } and { inner .output_keys } "
72- )
7366 self .outer = outer
7467 self .inner = inner
7568
7669 def __str__ (self ) -> str :
7770 return str (self .outer ) + " ∘ " + str (self .inner )
7871
79- def _compute (self , input : _A ) -> _C :
72+ def __call__ (self , input : _A ) -> _C :
8073 intermediate = self .inner (input )
8174 return self .outer (intermediate )
8275
83- @property
84- def required_keys (self ) -> set [Tensor ]:
85- return self .inner .required_keys
86-
87- @property
88- def output_keys (self ) -> set [Tensor ]:
89- return self .outer .output_keys
76+ def check_and_get_keys (self ) -> tuple [set [Tensor ], set [Tensor ]]:
77+ outer_required_keys , outer_output_keys = self .outer .check_and_get_keys ()
78+ inner_required_keys , inner_output_keys = self .inner .check_and_get_keys ()
79+ if outer_required_keys != inner_output_keys :
80+ raise ValueError (
81+ "The `output_keys` of `inner` must match with the `required_keys` of "
82+ f"outer. Found { outer_required_keys } and { inner_output_keys } "
83+ )
84+ return inner_required_keys , outer_output_keys
9085
9186
9287class Conjunction (Transform [_A , _B ]):
9388 def __init__ (self , transforms : Sequence [Transform [_A , _B ]]):
9489 self .transforms = transforms
9590
96- self ._required_keys = set (
97- key for transform in transforms for key in transform .required_keys
98- )
99- for transform in transforms :
100- if transform .required_keys != self .required_keys :
101- raise ValueError ("All transforms should require the same set of keys." )
102-
103- output_keys_with_duplicates = [key for t in transforms for key in t .output_keys ]
104- self ._output_keys = set (output_keys_with_duplicates )
105-
106- if len (self ._output_keys ) != len (output_keys_with_duplicates ):
107- raise ValueError ("The sets of output keys of transforms should be disjoint." )
108-
10991 def __str__ (self ) -> str :
11092 strings = []
11193 for t in self .transforms :
@@ -116,14 +98,22 @@ def __str__(self) -> str:
11698 strings .append (s )
11799 return "(" + " | " .join (strings ) + ")"
118100
119- def _compute (self , tensor_dict : _A ) -> _B :
101+ def __call__ (self , tensor_dict : _A ) -> _B :
120102 output = _union ([transform (tensor_dict ) for transform in self .transforms ])
121103 return output
122104
123- @property
124- def required_keys (self ) -> set [Tensor ]:
125- return self ._required_keys
105+ def check_and_get_keys (self ) -> tuple [set [Tensor ], set [Tensor ]]:
106+ keys_pairs = [transform .check_and_get_keys () for transform in self .transforms ]
107+
108+ required_keys = set (key for required_keys , _ in keys_pairs for key in required_keys )
109+ for transform_required_keys , _ in keys_pairs :
110+ if transform_required_keys != required_keys :
111+ raise ValueError ("All transforms should require the same set of keys." )
112+
113+ output_keys_with_duplicates = [key for _ , output_keys in keys_pairs for key in output_keys ]
114+ output_keys = set (output_keys_with_duplicates )
115+
116+ if len (output_keys ) != len (output_keys_with_duplicates ):
117+ raise ValueError ("The sets of output keys of transforms should be disjoint." )
126118
127- @property
128- def output_keys (self ) -> set [Tensor ]:
129- return self ._output_keys
119+ return required_keys , output_keys
0 commit comments