Skip to content

Commit 77309f4

Browse files
perf(autojac): Add optimize_gramian_computation to jac_to_grad (#525)
* Add optimization to compute_gramian * Add optimize_gramian_computation parameter to jac_to_grad, with the effect of skipping the Jacobian computation if set to True * Use deque for the Jacobians to free memory asap * Make check_consistent_first_dimension work with deque too * Add some tests * Update changelog --------- Co-authored-by: Pierre Quinton <pierre.quinton@gmail.com>
1 parent e9f3cca commit 77309f4

File tree

5 files changed

+265
-31
lines changed

5 files changed

+265
-31
lines changed

CHANGELOG.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,8 @@ changelog does not include internal changes that do not affect the user.
6868
- `GeneralizedWeighting.__call__`: The `generalized_gramian` parameter is now positional-only.
6969
Suggested change: `generalized_weighting(generalized_gramian=generalized_gramian)` =>
7070
`generalized_weighting(generalized_gramian)`.
71-
- Removed an unnecessary memory duplication. This should significantly improve the memory efficiency
72-
of `autojac`.
73-
- Removed an unnecessary internal cloning of gradient. This should slightly improve the memory
74-
efficiency of `autojac`.
71+
- Removed several unnecessary memory duplications. This should significantly improve the memory
72+
efficiency and speed of `autojac`.
7573
- Increased the lower bounds of the torch (from 2.0.0 to 2.3.0) and numpy (from 1.21.0
7674
to 1.21.2) dependencies to reflect what really works with torchjd. We now also run torchjd's tests
7775
with the dependency lower-bounds specified in `pyproject.toml`, so we should now always accurately

src/torchjd/_linalg/_gramian.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,20 @@ def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor:
3535
first dimension).
3636
"""
3737

38-
contracted_dims = contracted_dims if contracted_dims >= 0 else contracted_dims + t.ndim
39-
indices_source = list(range(t.ndim - contracted_dims))
40-
indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1))
41-
transposed = t.movedim(indices_source, indices_dest)
42-
gramian = torch.tensordot(t, transposed, dims=contracted_dims)
38+
# Optimization: it's faster to do that than moving dims and using tensordot, and this case
39+
# happens very often, sometimes hundreds of times for a single jac_to_grad.
40+
if contracted_dims == -1:
41+
matrix = t.unsqueeze(1) if t.ndim == 1 else t.flatten(start_dim=1)
42+
43+
gramian = matrix @ matrix.T
44+
45+
else:
46+
contracted_dims = contracted_dims if contracted_dims >= 0 else contracted_dims + t.ndim
47+
indices_source = list(range(t.ndim - contracted_dims))
48+
indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1))
49+
transposed = t.movedim(indices_source, indices_dest)
50+
gramian = torch.tensordot(t, transposed, dims=contracted_dims)
51+
4352
return cast(PSDTensor, gramian)
4453

4554

src/torchjd/autojac/_jac_to_grad.py

Lines changed: 100 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1+
from collections import deque
12
from collections.abc import Iterable
2-
from typing import overload
3+
from typing import TypeGuard, cast, overload
34

45
import torch
5-
from torch import Tensor
6+
from torch import Tensor, nn
67

7-
from torchjd._linalg import Matrix
8+
from torchjd._linalg import Matrix, PSDMatrix, compute_gramian
89
from torchjd.aggregation import Aggregator, Weighting
9-
from torchjd.aggregation._aggregator_bases import WeightedAggregator
10+
from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator, WeightedAggregator
1011

1112
from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac
1213
from ._utils import check_consistent_first_dimension
@@ -16,7 +17,18 @@
1617
def jac_to_grad(
1718
tensors: Iterable[Tensor],
1819
/,
19-
aggregator: WeightedAggregator,
20+
aggregator: GramianWeightedAggregator,
21+
*,
22+
retain_jac: bool = False,
23+
optimize_gramian_computation: bool = False,
24+
) -> Tensor: ...
25+
26+
27+
@overload
28+
def jac_to_grad(
29+
tensors: Iterable[Tensor],
30+
/,
31+
aggregator: WeightedAggregator, # Not a GramianWA, because overloads are checked in order
2032
*,
2133
retain_jac: bool = False,
2234
) -> Tensor: ...
@@ -38,6 +50,7 @@ def jac_to_grad(
3850
aggregator: Aggregator,
3951
*,
4052
retain_jac: bool = False,
53+
optimize_gramian_computation: bool = False,
4154
) -> Tensor | None:
4255
r"""
4356
Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors`` and accumulates the result
@@ -50,12 +63,27 @@ def jac_to_grad(
5063
the Jacobians, ``jac_to_grad`` will also return the computed weights.
5164
:param retain_jac: Whether to preserve the ``.jac`` fields of the tensors after they have been
5265
used. Defaults to ``False``.
66+
:param optimize_gramian_computation: When the ``aggregator`` computes weights based on the
67+
Gramian of the Jacobian, it's possible to skip the concatenation of the Jacobians and to
68+
instead compute the Gramian as the sum of the Gramians of the individual Jacobians. This
69+
saves memory (up to 50% memory saving) but can be slightly slower (up to 15%) on CUDA. We
70+
advise trying this optimization if memory is an issue for you. Defaults to ``False``.
5371
5472
.. note::
55-
This function starts by "flattening" the ``.jac`` fields into matrices (i.e. flattening all
56-
of their dimensions except the first one), then concatenates those matrices into a combined
57-
Jacobian matrix. The aggregator is then used on this matrix, which returns a combined
58-
gradient vector, that is split and reshaped to fit into the ``.grad`` fields of the tensors.
73+
When ``optimize_gramian_computation=False``, this function starts by "flattening" the
74+
``.jac`` fields into matrices (i.e. flattening all of their dimensions except the first
75+
one), then concatenates those matrices into a combined Jacobian matrix. The ``aggregator``
76+
is then used on this matrix, which returns a combined gradient vector, that is split and
77+
reshaped to fit into the ``.grad`` fields of the tensors.
78+
79+
.. note::
80+
When ``optimize_gramian_computation=True``, this function computes and sums the Gramian
81+
of each individual ``.jac`` field, iteratively. The inner weighting of the ``aggregator`` is
82+
then used to extract some weights from the obtained Gramian, used to compute a linear
83+
combination of the rows of each ``.jac`` field, to be stored into the corresponding
84+
``.grad`` field. This is mathematically equivalent to the approach with
85+
``optimize_gramian_computation=False``, but saves memory by not having to hold the
86+
concatenated Jacobian matrix in memory at any time.
5987
6088
.. admonition::
6189
Example
@@ -96,13 +124,46 @@ def jac_to_grad(
96124
if len(tensors_) == 0:
97125
raise ValueError("The `tensors` parameter cannot be empty.")
98126

99-
jacobians = [t.jac for t in tensors_]
100-
127+
jacobians = deque(t.jac for t in tensors_)
101128
check_consistent_first_dimension(jacobians, "tensors.jac")
102129

103130
if not retain_jac:
104131
_free_jacs(tensors_)
105132

133+
if optimize_gramian_computation:
134+
if not _can_skip_jacobian_combination(aggregator):
135+
raise ValueError(
136+
"In order to use `jac_to_grad` with `optimize_gramian_computation=True`, you must "
137+
"provide an `Aggregator` that computes weights based on the Gramian of the Jacobian"
138+
" (e.g. `UPGrad`) and that doesn't have any forward hooks attached to it."
139+
)
140+
141+
gradients, weights = _gramian_based(aggregator, jacobians)
142+
else:
143+
gradients, weights = _jacobian_based(aggregator, jacobians, tensors_)
144+
accumulate_grads(tensors_, gradients)
145+
146+
return weights
147+
148+
149+
def _can_skip_jacobian_combination(aggregator: Aggregator) -> TypeGuard[GramianWeightedAggregator]:
150+
return (
151+
isinstance(aggregator, GramianWeightedAggregator)
152+
and not _has_forward_hook(aggregator)
153+
and not _has_forward_hook(aggregator.weighting)
154+
)
155+
156+
157+
def _has_forward_hook(module: nn.Module) -> bool:
158+
"""Return whether the module has any forward hook registered."""
159+
return len(module._forward_hooks) > 0 or len(module._forward_pre_hooks) > 0
160+
161+
162+
def _jacobian_based(
163+
aggregator: Aggregator,
164+
jacobians: deque[Tensor],
165+
tensors: list[TensorWithJac],
166+
) -> tuple[list[Tensor], Tensor | None]:
106167
jacobian_matrix = _unite_jacobians(jacobians)
107168
weights: Tensor | None = None
108169

@@ -124,13 +185,36 @@ def capture_hook(_m: Weighting[Matrix], _i: tuple[Tensor], output: Tensor) -> No
124185
handle.remove()
125186
else:
126187
gradient_vector = aggregator(jacobian_matrix)
127-
gradients = _disunite_gradient(gradient_vector, tensors_)
128-
accumulate_grads(tensors_, gradients)
129-
return weights
188+
gradients = _disunite_gradient(gradient_vector, tensors)
189+
return gradients, weights
190+
191+
192+
def _gramian_based(
193+
aggregator: GramianWeightedAggregator,
194+
jacobians: deque[Tensor],
195+
) -> tuple[list[Tensor], Tensor]:
196+
weighting = aggregator.gramian_weighting
197+
gramian = _compute_gramian_sum(jacobians)
198+
weights = weighting(gramian)
199+
200+
gradients = list[Tensor]()
201+
while jacobians:
202+
jacobian = jacobians.popleft() # get jacobian + dereference it to free memory asap
203+
gradients.append(torch.tensordot(weights, jacobian, dims=1))
204+
205+
return gradients, weights
206+
207+
208+
def _compute_gramian_sum(jacobians: deque[Tensor]) -> PSDMatrix:
209+
gramian = sum([compute_gramian(matrix) for matrix in jacobians])
210+
return cast(PSDMatrix, gramian)
130211

131212

132-
def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
133-
jacobian_matrices = [jacobian.reshape(jacobian.shape[0], -1) for jacobian in jacobians]
213+
def _unite_jacobians(jacobians: deque[Tensor]) -> Tensor:
214+
jacobian_matrices = list[Tensor]()
215+
while jacobians:
216+
jacobian = jacobians.popleft() # get jacobian + dereference it to free memory asap
217+
jacobian_matrices.append(jacobian.reshape(jacobian.shape[0], -1))
134218
jacobian_matrix = torch.concat(jacobian_matrices, dim=1)
135219
return jacobian_matrix
136220

src/torchjd/autojac/_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,9 @@ def check_consistent_first_dimension(
113113
:param jacobians: Sequence of Jacobian tensors to validate.
114114
:param variable_name: Name of the variable to include in the error message.
115115
"""
116+
116117
if len(jacobians) > 0 and not all(
117-
jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]
118+
jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians
118119
):
119120
raise ValueError(f"All Jacobians in `{variable_name}` should have the same number of rows.")
120121

0 commit comments

Comments (0)