Skip to content

Commit b514092

Browse files
authored
test(autojac): Add tests for differentiation with empty outputs (#273)
* Add test_empty_outputs for Grad * Add test_empty_outputs for Jac
1 parent e951575 commit b514092

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

tests/unit/autojac/_transform/test_grad.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,24 @@ def test_empty_inputs_2():
6262
assert_tensor_dicts_are_close(gradients, expected_gradients)
6363

6464

65+
def test_empty_outputs():
66+
"""
67+
Tests that the Grad transform works correctly when the `outputs` parameter is an empty
68+
`Iterable`.
69+
"""
70+
71+
a1 = torch.tensor(1.0, requires_grad=True)
72+
a2 = torch.tensor([1.0, 2.0], requires_grad=True)
73+
input = Gradients({})
74+
75+
grad = Grad(outputs=[], inputs=[a1, a2])
76+
77+
gradients = grad(input)
78+
expected_gradients = {a1: torch.zeros_like(a1), a2: torch.zeros_like(a2)}
79+
80+
assert_tensor_dicts_are_close(gradients, expected_gradients)
81+
82+
6583
def test_retain_graph():
6684
"""Tests that the `Grad` transform behaves as expected with the `retain_graph` flag."""
6785

tests/unit/autojac/_transform/test_jac.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,28 @@ def test_empty_inputs_2(chunk_size: int | None):
7272
assert_tensor_dicts_are_close(jacobians, expected_jacobians)
7373

7474

75+
@mark.parametrize("chunk_size", [1, 3, None])
76+
def test_empty_outputs(chunk_size: int | None):
77+
"""
78+
Tests that the Jac transform works correctly when the `outputs` parameter is an empty
79+
`Iterable`.
80+
"""
81+
82+
a1 = torch.tensor(1.0, requires_grad=True)
83+
a2 = torch.tensor([1.0, 2.0], requires_grad=True)
84+
input = Jacobians({})
85+
86+
jac = Jac(outputs=[], inputs=[a1, a2], chunk_size=chunk_size)
87+
88+
jacobians = jac(input)
89+
expected_jacobians = {
90+
a1: torch.empty_like(a1).unsqueeze(0)[:0], # Jacobian with no row
91+
a2: torch.empty_like(a2).unsqueeze(0)[:0], # Jacobian with no row
92+
}
93+
94+
assert_tensor_dicts_are_close(jacobians, expected_jacobians)
95+
96+
7597
def test_retain_graph():
7698
"""Tests that the `Jac` transform behaves as expected with the `retain_graph` flag."""
7799

0 commit comments

Comments
 (0)