Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
ae6be7d
WIP: add gramian-based jac_to_grad
ValerianRey Jan 21, 2026
8bdf512
Update changelog
ValerianRey Jan 21, 2026
aaf2544
Use deque to free memory asap
ValerianRey Jan 23, 2026
64b06ad
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 23, 2026
745f707
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 28, 2026
5eb77f9
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 28, 2026
8f65caa
Use gramian_weighting in jac_to_grad
ValerianRey Jan 28, 2026
6fe15a4
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 28, 2026
d5cb5c2
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 29, 2026
f986950
Only optimize when no forward hooks
ValerianRey Jan 29, 2026
4cf5cbb
Make _gramian_based take aggregator instead of weighting
ValerianRey Jan 29, 2026
add549c
Add _can_skip_jacobian_combination helper function
ValerianRey Jan 29, 2026
453971a
Add test_can_skip_jacobian_combination
ValerianRey Jan 29, 2026
9d4c41c
Optimize compute_gramian for when contracted_dims=-1
ValerianRey Jan 29, 2026
48cd70b
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 30, 2026
8f2660d
Use TypeGuard in _can_skip_jacobian_combination
ValerianRey Jan 30, 2026
fc9bbcf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2026
3f9a6d1
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 1, 2026
9d9cbf0
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 4, 2026
b5ca226
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2026
0baa914
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 5, 2026
2ed1d7c
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 13, 2026
86be778
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 13, 2026
4ace19e
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 13, 2026
2a84bef
Add ruff if-else squeezing
ValerianRey Feb 13, 2026
4b6209c
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 23, 2026
1b1c660
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 23, 2026
b714253
Many fixes of problems coming from the merge
ValerianRey Feb 23, 2026
2bb8ab1
Fix _can_skip_jacobian_combination
ValerianRey Feb 23, 2026
63c9dde
Make check_consistent_first_dimension work with Deque
ValerianRey Feb 23, 2026
0f85811
Improve test_can_skip_jacobian_combination
ValerianRey Feb 23, 2026
9d55215
Add optimize_gramian_computation param and add error when not compatible
ValerianRey Feb 23, 2026
456510b
Fix overloads (partly) and add missing code coverage
ValerianRey Feb 23, 2026
55c69d1
Fix overloads
ValerianRey Feb 23, 2026
8a401a3
Fix docstring
ValerianRey Feb 23, 2026
b4bf7c4
fixup what @ValerianRey did wrong
PierreQuinton Feb 23, 2026
24a991a
Improve error message
ValerianRey Feb 23, 2026
2ea44a4
Improve docstring
ValerianRey Feb 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,8 @@ changelog does not include internal changes that do not affect the user.
jac_to_grad(shared_module.parameters(), aggregator)
```

- Removed an unnecessary memory duplication. This should significantly improve the memory efficiency
of `autojac`.
- Removed an unnecessary internal cloning of gradient. This should slightly improve the memory
efficiency of `autojac`.
- Removed several unnecessary memory duplications. This should significantly improve the memory
efficiency and speed of `autojac`.

## [0.8.1] - 2026-01-07

Expand Down
56 changes: 46 additions & 10 deletions src/torchjd/autojac/_jac_to_grad.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from collections import deque
from collections.abc import Iterable
from typing import cast
Comment thread
ValerianRey marked this conversation as resolved.
Outdated

import torch
from torch import Tensor

from torchjd.aggregation import Aggregator
from torchjd._linalg import PSDMatrix, compute_gramian
from torchjd.aggregation import Aggregator, Weighting
from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator

from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac

Expand Down Expand Up @@ -63,29 +67,61 @@ def jac_to_grad(
if len(tensors_) == 0:
return

jacobians = [t.jac for t in tensors_]
jacobians = deque(t.jac for t in tensors_)

if not all([jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]]):
if not all([jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians]):
raise ValueError("All Jacobians should have the same number of rows.")

if not retain_jac:
_free_jacs(tensors_)

if isinstance(aggregator, GramianWeightedAggregator):
# When it's possible, avoid the concatenation of the jacobians that can be very costly in
# memory.
gradients = _gramian_based(aggregator.weighting.weighting, jacobians, tensors_)
Comment thread
ValerianRey marked this conversation as resolved.
Outdated
Comment thread
ValerianRey marked this conversation as resolved.
Outdated
else:
gradients = _jacobian_based(aggregator, jacobians, tensors_)
accumulate_grads(tensors_, gradients)


def _jacobian_based(
aggregator: Aggregator, jacobians: deque[Tensor], tensors: list[TensorWithJac]
) -> list[Tensor]:
jacobian_matrix = _unite_jacobians(jacobians)
gradient_vector = aggregator(jacobian_matrix)
gradients = _disunite_gradient(gradient_vector, jacobians, tensors_)
accumulate_grads(tensors_, gradients)
gradients = _disunite_gradient(gradient_vector, tensors)
return gradients


def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
jacobian_matrices = [jacobian.reshape(jacobian.shape[0], -1) for jacobian in jacobians]
def _gramian_based(
weighting: Weighting[PSDMatrix], jacobians: deque[Tensor], tensors: list[TensorWithJac]
) -> list[Tensor]:
gramian = _compute_gramian_sum(jacobians)
weights = weighting(gramian)

gradients = list[Tensor]()
while jacobians:
jacobian = jacobians.popleft() # get jacobian + dereference it to free memory asap
gradients.append(torch.tensordot(weights, jacobian, dims=1))

return gradients


def _compute_gramian_sum(jacobians: deque[Tensor]) -> PSDMatrix:
gramian = sum([compute_gramian(matrix) for matrix in jacobians])
return cast(PSDMatrix, gramian)


def _unite_jacobians(jacobians: deque[Tensor]) -> Tensor:
jacobian_matrices = list[Tensor]()
while jacobians:
jacobian = jacobians.popleft() # get jacobian + dereference it to free memory asap
jacobian_matrices.append(jacobian.reshape(jacobian.shape[0], -1))
jacobian_matrix = torch.concat(jacobian_matrices, dim=1)
return jacobian_matrix


def _disunite_gradient(
gradient_vector: Tensor, jacobians: list[Tensor], tensors: list[TensorWithJac]
) -> list[Tensor]:
def _disunite_gradient(gradient_vector: Tensor, tensors: list[TensorWithJac]) -> list[Tensor]:
gradient_vectors = gradient_vector.split([t.numel() for t in tensors])
gradients = [g.view(t.shape) for g, t in zip(gradient_vectors, tensors)]
return gradients
Expand Down
Loading