Skip to content

Commit 7a29ee7

Browse files
committed
Add test_empty_outputs for Grad
1 parent e951575 commit 7a29ee7

1 file changed

Lines changed: 17 additions & 0 deletions

File tree

tests/unit/autojac/_transform/test_grad.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,23 @@ def test_empty_inputs_2():
6262
assert_tensor_dicts_are_close(gradients, expected_gradients)
6363

6464

65+
def test_empty_outputs():
66+
"""
67+
Tests that the Grad transform works correctly when the `outputs` parameter is an empty
68+
`Iterable`.
69+
"""
70+
71+
a = torch.tensor(1.0, requires_grad=True)
72+
input = Gradients({})
73+
74+
grad = Grad(outputs=[], inputs=[a])
75+
76+
gradients = grad(input)
77+
expected_gradients = {a: torch.zeros_like(a)}
78+
79+
assert_tensor_dicts_are_close(gradients, expected_gradients)
80+
81+
6582
def test_retain_graph():
6683
"""Tests that the `Grad` transform behaves as expected with the `retain_graph` flag."""
6784

0 commit comments

Comments
 (0)