We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 3f11448 commit 5a97c18Copy full SHA for 5a97c18
1 file changed
tests/unit/autojac/_transform/test_stack.py
@@ -20,13 +20,8 @@ def __init__(self, keys: Iterable[Tensor]):
20
def _compute(self, input: EmptyTensorDict) -> Gradients:
21
return Gradients({key: torch.ones_like(key) for key in self.keys})
22
23
- @property
24
- def required_keys(self) -> set[Tensor]:
25
- return set()
26
-
27
28
- def output_keys(self) -> set[Tensor]:
29
- return self.keys
+ def check_keys(self) -> tuple[set[Tensor], set[Tensor]]:
+ return set(), self.keys
30
31
32
def test_single_key():
0 commit comments