Skip to content

Commit 25342e2

Browse files
committed
Make jac behave like autograd.grad when inputs are repeated
1 parent eac2016 commit 25342e2

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

src/torchjd/autojac/_jac.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,10 @@ def jac(
118118

119119
if inputs is None:
120120
inputs_ = get_leaf_tensors(tensors=outputs_, excluded=set())
121+
inputs_with_repetition = list(inputs_)
121122
else:
122-
inputs_ = OrderedSet(inputs)
123+
inputs_with_repetition = list(inputs) # Create a list to avoid emptying generator
124+
inputs_ = OrderedSet(inputs_with_repetition)
123125

124126
jac_transform = _create_transform(
125127
outputs=outputs_,
@@ -129,7 +131,7 @@ def jac(
129131
)
130132

131133
result = jac_transform({})
132-
return tuple(val for val in result.values())
134+
return tuple(result[input] for input in inputs_with_repetition)
133135

134136

135137
def _create_transform(

tests/unit/autojac/test_jac.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,9 @@ def test_repeated_tensors():
261261

262262
def test_repeated_inputs():
263263
"""
264-
Tests that jac correctly works when some inputs are repeated. This behaviour is different than
265-
torch.autograd.grad, which would repeat the output gradients as many times as the inputs are
266-
repeated.
264+
Tests that jac correctly works when some inputs are repeated. In this case, since
265+
torch.autograd.grad repeats the output gradients, it is natural for autojac to also repeat the
266+
output jacobians.
267267
"""
268268

269269
a1 = tensor_([1.0, 2.0], requires_grad=True)
@@ -276,6 +276,7 @@ def test_repeated_inputs():
276276
J2 = tensor_([[1.0, 1.0], [6.0, 8.0]])
277277

278278
jacobians = jac([y1, y2], inputs=[a1, a1, a2])
279-
assert len(jacobians) == 2
279+
assert len(jacobians) == 3
280280
assert_close(jacobians[0], J1)
281-
assert_close(jacobians[1], J2)
281+
assert_close(jacobians[1], J1)
282+
assert_close(jacobians[2], J2)

0 commit comments

Comments
 (0)