Skip to content

Commit 27a4245

Browse files
committed
Add test_empty_outputs for Jac
1 parent 7a29ee7 commit 27a4245

1 file changed

Lines changed: 22 additions & 0 deletions

File tree

tests/unit/autojac/_transform/test_jac.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,28 @@ def test_empty_inputs_2(chunk_size: int | None):
7272
assert_tensor_dicts_are_close(jacobians, expected_jacobians)
7373

7474

75+
@mark.parametrize("chunk_size", [1, 3, None])
76+
def test_empty_outputs(chunk_size: int | None):
77+
"""
78+
Tests that the Jac transform works correctly when the `outputs` parameter is an empty
79+
`Iterable`.
80+
"""
81+
82+
a1 = torch.tensor(1.0, requires_grad=True)
83+
a2 = torch.tensor(1.0, requires_grad=True)
84+
input = Jacobians({})
85+
86+
jac = Jac(outputs=[], inputs=[a1, a2], chunk_size=chunk_size)
87+
88+
jacobians = jac(input)
89+
expected_jacobians = {
90+
a1: torch.empty_like(a1).unsqueeze(0)[:0], # Jacobian with no row
91+
a2: torch.empty_like(a2).unsqueeze(0)[:0], # Jacobian with no row
92+
}
93+
94+
assert_tensor_dicts_are_close(jacobians, expected_jacobians)
95+
96+
7597
def test_retain_graph():
7698
"""Tests that the `Jac` transform behaves as expected with the `retain_graph` flag."""
7799

0 commit comments

Comments
 (0)