
Commit 502973f

refactor(autojac): Stop cloning in Accumulate (#511)
1 parent 8de14e0 commit 502973f

File tree

3 files changed: 17 additions, 6 deletions


CHANGELOG.md (5 additions, 0 deletions)

@@ -8,6 +8,11 @@ changelog does not include internal changes that do not affect the user.
 
 ## [Unreleased]
 
+### Changed
+
+- Removed an unnecessary internal cloning of gradient. This should slightly improve the memory
+  efficiency of `autojac`.
+
 ## [0.8.1] - 2026-01-07
 
 ### Added

src/torchjd/autojac/_transform/_accumulate.py (10 additions, 5 deletions)

@@ -15,11 +15,16 @@ def __call__(self, gradients: TensorDict) -> TensorDict:
         if hasattr(key, "grad") and key.grad is not None:
             key.grad += gradients[key]
         else:
-            # We clone the value because we do not want subsequent accumulations to also affect
-            # this value (in case it is still used outside). We do not detach from the
-            # computation graph because the value can have grad_fn that we want to keep track of
-            # (in case it was obtained via create_graph=True and a differentiable aggregator).
-            key.grad = gradients[key].clone()
+            # We do not clone the value to save memory and time, so subsequent modifications of
+            # the value of key.grad (subsequent accumulations) will also affect the value of
+            # gradients[key] and outside changes to the value of gradients[key] will also affect
+            # the value of key.grad. So to be safe, the values of gradients should not be used
+            # anymore after being passed to this function.
+            #
+            # We do not detach from the computation graph because the value can have grad_fn
+            # that we want to keep track of (in case it was obtained via create_graph=True and a
+            # differentiable aggregator).
+            key.grad = gradients[key]
 
         return {}
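The aliasing consequence described in the new comment can be illustrated with a minimal PyTorch sketch (the `param`/`grad` tensors here are hypothetical stand-ins, not torchjd code): assigning a tensor to `.grad` without `.clone()` makes both names share the same storage, so a later in-place accumulation is visible to the original holder of the gradient.

```python
import torch

# Hypothetical stand-in for a parameter and an incoming gradient.
param = torch.zeros(3, requires_grad=True)
grad = torch.ones(3)

# Assign without cloning, as the updated Accumulate does: param.grad and
# grad now alias the same storage.
param.grad = grad

# A subsequent in-place accumulation on param.grad...
param.grad += torch.full((3,), 2.0)

# ...is also visible through the caller's reference.
print(grad)  # tensor([3., 3., 3.])
```

This is exactly why the comment warns that the values of `gradients` should not be used anymore after being passed to this function.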

tests/unit/autojac/_transform/test_accumulate.py (2 additions, 1 deletion)

@@ -45,11 +45,12 @@ def test_multiple_accumulation(iterations: int):
     value1 = ones_([])
     value2 = ones_([1])
     value3 = ones_([2, 3])
-    input = {key1: value1, key2: value2, key3: value3}
 
     accumulate = Accumulate()
 
     for i in range(iterations):
+        # Clone values to ensure that we accumulate values that are not ever used afterwards
+        input = {key1: value1.clone(), key2: value2.clone(), key3: value3.clone()}
         accumulate(input)
 
     grads = {key1: key1.grad, key2: key2.grad, key3: key3.grad}
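The comment about not detaching can also be demonstrated in isolation. A gradient obtained with `create_graph=True` carries a `grad_fn`, and plain assignment to `.grad` (as the updated `Accumulate` does) preserves it, keeping higher-order differentiation possible. A minimal sketch with a hypothetical scalar, not torchjd code:

```python
import torch

# Hypothetical leaf tensor; y = x**3, so dy/dx = 3 * x**2.
x = torch.tensor(2.0, requires_grad=True)
y = x ** 3

# create_graph=True makes the returned gradient itself differentiable:
# g = 3 * x**2 carries a grad_fn.
(g,) = torch.autograd.grad(y, x, create_graph=True)

# Assigning without .detach() or .clone() keeps that grad_fn alive.
x.grad = g
print(x.grad.grad_fn is not None)  # True
```

A `.clone()` would have kept the `grad_fn` too (clone is differentiable), so dropping it loses nothing on this front while saving a copy.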
