File tree Expand file tree Collapse file tree
tests/unit/autojac/_transform Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1111
1212class FakeTransform (Transform [_B , _C ]):
1313 """
14- Fake ``Transform`` to test `required_keys` and `output_keys ` when composing and conjuncting.
14+ Fake ``Transform`` to test `check_keys ` when composing and conjuncting.
1515 """
1616
1717 def __init__ (self , required_keys : set [Tensor ], output_keys : set [Tensor ]):
@@ -28,6 +28,7 @@ def __call__(self, input: _B) -> _C:
2828 return typing .cast (_C , output_dict )
2929
3030 def check_keys (self , input_keys : set [Tensor ]) -> set [Tensor ]:
31+ # Arbitrary requirement for testing purposes.
3132 if not input_keys == self ._required_keys :
3233 raise RequirementError ()
3334 return self ._output_keys
Original file line number Diff line number Diff line change 99
1010
1111class FakeGradientsTransform (Transform [EmptyTensorDict , Gradients ]):
12- """
13- Transform that produces gradients filled with ones, for testing purposes. Note that it does the
14- same thing as Init, but it does not depend on Init.
15- """
12+ """Transform that produces gradients filled with ones, for testing purposes."""
1613
1714 def __init__ (self , keys : Iterable [Tensor ]):
1815 self .keys = set (keys )
You can’t perform that action at this time.
0 commit comments