
Commit 87b66f8

Fix comment in accumulate_jacs that applied to accumulate_grads
1 parent 8b3d447

1 file changed: 4 additions & 5 deletions

src/torchjd/autojac/_accumulation.py
@@ -31,14 +31,13 @@ def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None:
             param_.jac += jac
         else:
             # We do not clone the value to save memory and time, so subsequent modifications of
-            # the value of key.grad (subsequent accumulations) will also affect the value of
-            # gradients[key] and outside changes to the value of gradients[key] will also affect
-            # the value of key.grad. So to be safe, the values of gradients should not be used
+            # the value of key.jac (subsequent accumulations) will also affect the value of
+            # jacobians[key] and outside changes to the value of jacobians[key] will also affect
+            # the value of key.jac. So to be safe, the values of jacobians should not be used
             # anymore after being passed to this function.
             #
             # We do not detach from the computation graph because the value can have grad_fn
-            # that we want to keep track of (in case it was obtained via create_graph=True and a
-            # differentiable aggregator).
+            # that we want to keep track of (in case it was obtained via create_graph=True).
             param.__setattr__("jac", jac)
 
 
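For illustration, here is a minimal runnable sketch of the aliasing behavior the fixed comment warns about. The accumulate_jacs below is a simplified stand-in reconstructed from the hunk above, not torchjd's actual private implementation; in particular, the zip-based pairing of params and jacobians is an assumption.

from collections.abc import Iterable

import torch
from torch import Tensor


def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None:
    # Simplified stand-in for the real function; pairing via zip is assumed.
    for param, jac in zip(params, jacobians):
        if hasattr(param, "jac"):
            param.jac += jac  # accumulate into the existing .jac in place
        else:
            # No clone and no detach: .jac aliases the caller's tensor and keeps
            # its grad_fn (relevant when it came from create_graph=True).
            param.__setattr__("jac", jac)


param = torch.zeros(3)
jac = torch.ones(2, 3)  # e.g. one jacobian row per objective
accumulate_jacs([param], [jac])

jac[0, 0] = 42.0        # mutating the caller's tensor after the call...
print(param.jac[0, 0])  # ...is visible through param.jac: tensor(42.)

This is why the fixed comment says the values of jacobians should not be used anymore after being passed to the function: the caller's tensors and the stored .jac attributes are the same objects.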
