diff --git a/tests/unit/autojac/_transform/test_grad.py b/tests/unit/autojac/_transform/test_grad.py index d5bdb2e83..cef17d590 100644 --- a/tests/unit/autojac/_transform/test_grad.py +++ b/tests/unit/autojac/_transform/test_grad.py @@ -62,6 +62,24 @@ def test_empty_inputs_2(): assert_tensor_dicts_are_close(gradients, expected_gradients) +def test_empty_outputs(): + """ + Tests that the Grad transform works correctly when the `outputs` parameter is an empty + `Iterable`. + """ + + a1 = torch.tensor(1.0, requires_grad=True) + a2 = torch.tensor([1.0, 2.0], requires_grad=True) + input = Gradients({}) + + grad = Grad(outputs=[], inputs=[a1, a2]) + + gradients = grad(input) + expected_gradients = {a1: torch.zeros_like(a1), a2: torch.zeros_like(a2)} + + assert_tensor_dicts_are_close(gradients, expected_gradients) + + def test_retain_graph(): """Tests that the `Grad` transform behaves as expected with the `retain_graph` flag.""" diff --git a/tests/unit/autojac/_transform/test_jac.py b/tests/unit/autojac/_transform/test_jac.py index 92be71549..54ac47e34 100644 --- a/tests/unit/autojac/_transform/test_jac.py +++ b/tests/unit/autojac/_transform/test_jac.py @@ -72,6 +72,28 @@ def test_empty_inputs_2(chunk_size: int | None): assert_tensor_dicts_are_close(jacobians, expected_jacobians) +@mark.parametrize("chunk_size", [1, 3, None]) +def test_empty_outputs(chunk_size: int | None): + """ + Tests that the Jac transform works correctly when the `outputs` parameter is an empty + `Iterable`. + """ + + a1 = torch.tensor(1.0, requires_grad=True) + a2 = torch.tensor([1.0, 2.0], requires_grad=True) + input = Jacobians({}) + + jac = Jac(outputs=[], inputs=[a1, a2], chunk_size=chunk_size) + + jacobians = jac(input) + expected_jacobians = { + a1: torch.empty_like(a1).unsqueeze(0)[:0], # Jacobian with no row + a2: torch.empty_like(a2).unsqueeze(0)[:0], # Jacobian with no row + } + + assert_tensor_dicts_are_close(jacobians, expected_jacobians) + + def test_retain_graph(): """Tests that the `Jac` transform behaves as expected with the `retain_graph` flag."""