Skip to content

Commit 23a7de0

Browse files
Merge branch 'main' into stationarity_property
2 parents fc83099 + b514092 commit 23a7de0

3 files changed

Lines changed: 41 additions & 6 deletions

File tree

src/torchjd/autojac/_transform/grad.py

Lines changed: 1 addition & 6 deletions
```diff
@@ -37,12 +37,7 @@ def _differentiate(self, grad_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
             return tuple()

         if len(outputs) == 0:
-            return tuple(
-                [
-                    torch.empty(input.shape, device=input.device, dtype=input.dtype)
-                    for input in inputs
-                ]
-            )
+            return tuple([torch.zeros_like(input) for input in inputs])

         optional_grads = torch.autograd.grad(
             outputs,
```

tests/unit/autojac/_transform/test_grad.py

Lines changed: 18 additions & 0 deletions
```diff
@@ -62,6 +62,24 @@ def test_empty_inputs_2():
     assert_tensor_dicts_are_close(gradients, expected_gradients)


+def test_empty_outputs():
+    """
+    Tests that the Grad transform works correctly when the `outputs` parameter is an empty
+    `Iterable`.
+    """
+
+    a1 = torch.tensor(1.0, requires_grad=True)
+    a2 = torch.tensor([1.0, 2.0], requires_grad=True)
+    input = Gradients({})
+
+    grad = Grad(outputs=[], inputs=[a1, a2])
+
+    gradients = grad(input)
+    expected_gradients = {a1: torch.zeros_like(a1), a2: torch.zeros_like(a2)}
+
+    assert_tensor_dicts_are_close(gradients, expected_gradients)
+
+
 def test_retain_graph():
     """Tests that the `Grad` transform behaves as expected with the `retain_graph` flag."""
```

tests/unit/autojac/_transform/test_jac.py

Lines changed: 22 additions & 0 deletions
```diff
@@ -72,6 +72,28 @@ def test_empty_inputs_2(chunk_size: int | None):
     assert_tensor_dicts_are_close(jacobians, expected_jacobians)


+@mark.parametrize("chunk_size", [1, 3, None])
+def test_empty_outputs(chunk_size: int | None):
+    """
+    Tests that the Jac transform works correctly when the `outputs` parameter is an empty
+    `Iterable`.
+    """
+
+    a1 = torch.tensor(1.0, requires_grad=True)
+    a2 = torch.tensor([1.0, 2.0], requires_grad=True)
+    input = Jacobians({})
+
+    jac = Jac(outputs=[], inputs=[a1, a2], chunk_size=chunk_size)
+
+    jacobians = jac(input)
+    expected_jacobians = {
+        a1: torch.empty_like(a1).unsqueeze(0)[:0],  # Jacobian with no row
+        a2: torch.empty_like(a2).unsqueeze(0)[:0],  # Jacobian with no row
+    }
+
+    assert_tensor_dicts_are_close(jacobians, expected_jacobians)
+
+
 def test_retain_graph():
     """Tests that the `Jac` transform behaves as expected with the `retain_graph` flag."""
```

0 commit comments

Comments
 (0)