
Commit 4f1016e (1 parent: a258254)

refactor: Avoid cat in Jac when not needed (#518)
File tree: 2 files changed (+8 −1 lines)

CHANGELOG.md (2 additions, 0 deletions):

````diff
@@ -35,6 +35,8 @@ changelog does not include internal changes that do not affect the user.
   jac_to_grad(shared_module.parameters(), aggregator)
   ```
 
+- Removed an unnecessary memory duplication. This should significantly improve the memory efficiency
+  of `autojac`.
 - Removed an unnecessary internal cloning of gradient. This should slightly improve the memory
   efficiency of `autojac`.
 
````
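The "unnecessary memory duplication" entry refers to `torch.cat` being applied even when there was only a single chunk to concatenate. `torch.cat` always materializes a fresh output tensor, so even a one-element call copies its input. A minimal sketch of that effect (the tensor shape is an arbitrary stand-in, not taken from the commit):

```python
import torch

t = torch.randn(1000, 1000)  # ~4 MB of float32 data

# torch.cat always allocates a new output tensor, even for a
# one-element list, so the data is briefly held twice in memory.
copy = torch.cat([t])

assert copy.data_ptr() != t.data_ptr()  # distinct storage: a full copy was made
assert torch.equal(copy, t)             # identical contents
```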
src/torchjd/autojac/_transform/_jac.py (6 additions, 1 deletion):

```diff
@@ -91,7 +91,12 @@ def _differentiate(self, jac_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
             jacs_chunks.append(_get_jacs_chunk(jac_outputs_chunk, get_vjp_last))
 
         n_inputs = len(self.inputs)
-        jacs = tuple(torch.cat([chunks[i] for chunks in jacs_chunks]) for i in range(n_inputs))
+        if len(jacs_chunks) == 1:
+            # Avoid using cat to avoid doubling memory usage, if it's not needed
+            jacs = jacs_chunks[0]
+        else:
+            jacs = tuple(torch.cat([chunks[i] for chunks in jacs_chunks]) for i in range(n_inputs))
+
         return jacs
```
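In isolation, the new fast path behaves as follows. This is a hedged sketch using stand-in tensors: the names `jacs_chunks` and `n_inputs` follow the diff, but the shapes and values are invented for illustration.

```python
import torch

# Stand-in for the per-chunk Jacobians built by _differentiate: each
# chunk holds one Jacobian tensor per differentiated input.
jacs_chunks = [(torch.randn(4, 3), torch.randn(4, 5))]  # a single chunk
n_inputs = 2

if len(jacs_chunks) == 1:
    # Single chunk: return its tensors directly instead of copying
    # them through torch.cat.
    jacs = jacs_chunks[0]
else:
    jacs = tuple(torch.cat([chunks[i] for chunks in jacs_chunks]) for i in range(n_inputs))

# The fast path hands back the original tensors: no allocation, no copy.
assert jacs[0] is jacs_chunks[0][0]

# With several chunks, cat is still required to stack the rows:
multi_chunks = [(torch.randn(4, 3),), (torch.randn(2, 3),)]
stacked = tuple(torch.cat([chunks[i] for chunks in multi_chunks]) for i in range(1))
assert stacked[0].shape == (6, 3)
```

Returning `jacs_chunks[0]` directly preserves the declared `tuple[Tensor, ...]` return type, since each chunk is itself a tuple with one Jacobian per input.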