Skip to content

Commit 5a97c18

Browse files
committed
add check_keys in test_stack.FakeGradientsTransform
1 parent 3f11448 commit 5a97c18

1 file changed

Lines changed: 2 additions & 7 deletions

File tree

tests/unit/autojac/_transform/test_stack.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,8 @@ def __init__(self, keys: Iterable[Tensor]):
2020
def _compute(self, input: EmptyTensorDict) -> Gradients:
2121
return Gradients({key: torch.ones_like(key) for key in self.keys})
2222

23-
@property
24-
def required_keys(self) -> set[Tensor]:
25-
return set()
26-
27-
@property
28-
def output_keys(self) -> set[Tensor]:
29-
return self.keys
23+
def check_keys(self) -> tuple[set[Tensor], set[Tensor]]:
24+
return set(), self.keys
3025

3126

3227
def test_single_key():

0 commit comments

Comments
 (0)