Skip to content

Commit 8d3d98d

Browse files
committed
Same as before for FakeTransforms
1 parent df6c200 commit 8d3d98d

2 files changed

Lines changed: 2 additions & 2 deletions

File tree

tests/unit/autojac/_transform/test_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def __init__(self, required_keys: set[Tensor], output_keys: set[Tensor]):
2121
def __str__(self):
2222
return "T"
2323

24-
def _compute(self, input: _B) -> _C:
24+
def __call__(self, input: _B) -> _C:
2525
# Ignore the input, create a dictionary with the right keys as an output.
2626
# Cast the type for the purpose of type-checking.
2727
output_dict = {key: torch.empty(0) for key in self._output_keys}

tests/unit/autojac/_transform/test_stack.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class FakeGradientsTransform(Transform[EmptyTensorDict, Gradients]):
1717
def __init__(self, keys: Iterable[Tensor]):
1818
self.keys = set(keys)
1919

20-
def _compute(self, input: EmptyTensorDict) -> Gradients:
20+
def __call__(self, input: EmptyTensorDict) -> Gradients:
2121
return Gradients({key: torch.ones_like(key) for key in self.keys})
2222

2323
def check_keys(self) -> tuple[set[Tensor], set[Tensor]]:

0 commit comments

Comments
 (0)