
Commit 0bcf1c0

Merge branch 'main' into revamp-interface

2 parents: e08e45f + 502973f

3 files changed: 18 additions & 8 deletions

CHANGELOG.md (5 additions & 0 deletions)

@@ -8,6 +8,11 @@ changelog does not include internal changes that do not affect the user.
 
 ## [Unreleased]
 
+### Changed
+
+- Removed an unnecessary internal cloning of gradients. This should slightly improve the memory
+  efficiency of `autojac`.
+
 ## [0.8.1] - 2026-01-07
 
 ### Added

src/torchjd/utils/_accumulation.py (11 additions & 7 deletions)

@@ -13,12 +13,16 @@ def _accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None:
             param_ = cast(TensorWithJac, param)
             param_.jac += jac
         else:
-            # TODO: this could be a serious memory issue
-            # We clone the value because we do not want subsequent accumulations to also affect
-            # this value (in case it is still used outside). We do not detach from the
-            # computation graph because the value can have grad_fn that we want to keep track of
-            # (in case it was obtained via create_graph=True and a differentiable aggregator).
-            param.__setattr__("jac", jac.clone())
+            # We do not clone the value, to save memory and time. Subsequent modifications
+            # of param.jac (subsequent accumulations) will therefore also affect the
+            # corresponding value in jacobians, and outside changes to that value will also
+            # affect param.jac. To be safe, the values of jacobians should not be used
+            # anymore after being passed to this function.
+            #
+            # We do not detach from the computation graph because the value can have grad_fn
+            # that we want to keep track of (in case it was obtained via create_graph=True and a
+            # differentiable aggregator).
+            param.__setattr__("jac", jac)
 
 
 def _accumulate_grads(params: Iterable[Tensor], gradients: Iterable[Tensor]) -> None:
@@ -27,7 +31,7 @@ def _accumulate_grads(params: Iterable[Tensor], gradients: Iterable[Tensor]) -> None:
         if hasattr(param, "grad") and param.grad is not None:
             param.grad += grad
         else:
-            param.grad = grad.clone()
+            param.grad = grad
 
 
 def _check_expects_grad(tensor: Tensor) -> None:
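The aliasing that the rewritten comment warns about is ordinary PyTorch semantics and can be reproduced outside torchjd. A minimal sketch of the `_accumulate_grads` case (the tensors below are invented for illustration, not taken from the diff): once `param.grad` is assigned without a clone, it shares storage with the caller's tensor, so a later in-place accumulation mutates both.

import torch

param = torch.zeros(3, requires_grad=True)
grad = torch.ones(3)

# First accumulation: assigning without .clone(), as the new code does,
# makes param.grad and grad the same tensor object.
param.grad = grad

# A second accumulation is done in-place (param.grad += ...), so it
# also mutates the tensor that the caller passed in.
param.grad += torch.ones(3)

print(grad)        # tensor([2., 2., 2.]), the caller's tensor changed too
print(param.grad)  # tensor([2., 2., 2.])

With the previous behavior (`param.grad = grad.clone()`), `grad` would have stayed at `tensor([1., 1., 1.])`; the commit trades that isolation for the memory and time saved by not cloning.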

tests/unit/autojac/_transform/test_accumulate.py (2 additions & 1 deletion)

@@ -45,11 +45,12 @@ def test_multiple_accumulation(iterations: int):
     value1 = ones_([])
     value2 = ones_([1])
     value3 = ones_([2, 3])
-    input = {key1: value1, key2: value2, key3: value3}
 
     accumulate = Accumulate()
 
     for i in range(iterations):
+        # Clone the values to ensure that we accumulate values that are never used afterwards.
+        input = {key1: value1.clone(), key2: value2.clone(), key3: value3.clone()}
         accumulate(input)
 
     grads = {key1: key1.grad, key2: key2.grad, key3: key3.grad}
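The test change exercises the new contract from the caller's side: since the accumulation no longer clones (per the comment in `_accumulation.py` above, presumably reached through `Accumulate`), the test clones `value1`, `value2` and `value3` itself before each iteration, so the reference tensors it keeps are never aliased by the accumulated `.grad` values. Moving the clone from callee to caller means the copy is only paid for when the caller actually needs to keep using the tensors.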
