1+ from collections import deque
12from collections .abc import Iterable
2- from typing import overload
3+ from typing import TypeGuard , cast , overload
34
45import torch
5- from torch import Tensor
6+ from torch import Tensor , nn
67
7- from torchjd ._linalg import Matrix
8+ from torchjd ._linalg import Matrix , PSDMatrix , compute_gramian
89from torchjd .aggregation import Aggregator , Weighting
9- from torchjd .aggregation ._aggregator_bases import WeightedAggregator
10+ from torchjd .aggregation ._aggregator_bases import GramianWeightedAggregator , WeightedAggregator
1011
1112from ._accumulation import TensorWithJac , accumulate_grads , is_tensor_with_jac
1213from ._utils import check_consistent_first_dimension
def jac_to_grad(
    tensors: Iterable[Tensor],
    /,
    aggregator: GramianWeightedAggregator,
    *,
    retain_jac: bool = False,
    optimize_gramian_computation: bool = False,
) -> Tensor: ...
25+
26+
@overload
def jac_to_grad(
    tensors: Iterable[Tensor],
    /,
    aggregator: WeightedAggregator,  # Not a GramianWA, because overloads are checked in order
    *,
    retain_jac: bool = False,
) -> Tensor: ...
@@ -38,6 +50,7 @@ def jac_to_grad(
3850 aggregator : Aggregator ,
3951 * ,
4052 retain_jac : bool = False ,
53+ optimize_gramian_computation : bool = False ,
4154) -> Tensor | None :
4255 r"""
4356 Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors`` and accumulates the result
@@ -50,12 +63,27 @@ def jac_to_grad(
5063 the Jacobians, ``jac_to_grad`` will also return the computed weights.
5164 :param retain_jac: Whether to preserve the ``.jac`` fields of the tensors after they have been
5265 used. Defaults to ``False``.
66+ :param optimize_gramian_computation: When the ``aggregator`` computes weights based on the
67+ Gramian of the Jacobian, it's possible to skip the concatenation of the Jacobians and to
68+ instead compute the Gramian as the sum of the Gramians of the individual Jacobians. This
        saves memory (up to 50%) but can be slightly slower (up to 15%) on CUDA.
        We advise trying this optimization if memory is an issue for you. Defaults to ``False``.
5371
5472 .. note::
55- This function starts by "flattening" the ``.jac`` fields into matrices (i.e. flattening all
56- of their dimensions except the first one), then concatenates those matrices into a combined
57- Jacobian matrix. The aggregator is then used on this matrix, which returns a combined
58- gradient vector, that is split and reshaped to fit into the ``.grad`` fields of the tensors.
73+ When ``optimize_gramian_computation=False``, this function starts by "flattening" the
74+ ``.jac`` fields into matrices (i.e. flattening all of their dimensions except the first
75+ one), then concatenates those matrices into a combined Jacobian matrix. The ``aggregator``
76+ is then used on this matrix, which returns a combined gradient vector, that is split and
77+ reshaped to fit into the ``.grad`` fields of the tensors.
78+
79+ .. note::
80+ When ``optimize_gramian_computation=True``, this function computes and sums the Gramian
81+ of each individual ``.jac`` field, iteratively. The inner weighting of the ``aggregator`` is
82+ then used to extract some weights from the obtained Gramian, used to compute a linear
83+ combination of the rows of each ``.jac`` field, to be stored into the corresponding
84+ ``.grad`` field. This is mathematically equivalent to the approach with
85+ ``optimize_gramian_computation=False``, but saves memory by not having to hold the
86+ concatenated Jacobian matrix in memory at any time.
5987
6088 .. admonition::
6189 Example
@@ -96,13 +124,46 @@ def jac_to_grad(
96124 if len (tensors_ ) == 0 :
97125 raise ValueError ("The `tensors` parameter cannot be empty." )
98126
99- jacobians = [t .jac for t in tensors_ ]
100-
127+ jacobians = deque (t .jac for t in tensors_ )
101128 check_consistent_first_dimension (jacobians , "tensors.jac" )
102129
103130 if not retain_jac :
104131 _free_jacs (tensors_ )
105132
133+ if optimize_gramian_computation :
134+ if not _can_skip_jacobian_combination (aggregator ):
135+ raise ValueError (
136+ "In order to use `jac_to_grad` with `optimize_gramian_computation=True`, you must "
137+ "provide an `Aggregator` that computes weights based on the Gramian of the Jacobian"
138+ " (e.g. `UPGrad`) and that doesn't have any forward hooks attached to it."
139+ )
140+
141+ gradients , weights = _gramian_based (aggregator , jacobians )
142+ else :
143+ gradients , weights = _jacobian_based (aggregator , jacobians , tensors_ )
144+ accumulate_grads (tensors_ , gradients )
145+
146+ return weights
147+
148+
def _can_skip_jacobian_combination(aggregator: Aggregator) -> TypeGuard[GramianWeightedAggregator]:
    """
    Return whether the aggregator only needs the Gramian of the Jacobian, so that the
    concatenation of the Jacobians can be skipped entirely.
    """
    if not isinstance(aggregator, GramianWeightedAggregator):
        return False
    # A forward hook could observe (or alter) the full Jacobian-based call, so its
    # presence on either module forbids taking the Gramian-only shortcut.
    return not (_has_forward_hook(aggregator) or _has_forward_hook(aggregator.weighting))
155+
156+
157+ def _has_forward_hook (module : nn .Module ) -> bool :
158+ """Return whether the module has any forward hook registered."""
159+ return len (module ._forward_hooks ) > 0 or len (module ._forward_pre_hooks ) > 0
160+
161+
162+ def _jacobian_based (
163+ aggregator : Aggregator ,
164+ jacobians : deque [Tensor ],
165+ tensors : list [TensorWithJac ],
166+ ) -> tuple [list [Tensor ], Tensor | None ]:
106167 jacobian_matrix = _unite_jacobians (jacobians )
107168 weights : Tensor | None = None
108169
@@ -124,13 +185,36 @@ def capture_hook(_m: Weighting[Matrix], _i: tuple[Tensor], output: Tensor) -> No
124185 handle .remove ()
125186 else :
126187 gradient_vector = aggregator (jacobian_matrix )
127- gradients = _disunite_gradient (gradient_vector , tensors_ )
128- accumulate_grads (tensors_ , gradients )
129- return weights
188+ gradients = _disunite_gradient (gradient_vector , tensors )
189+ return gradients , weights
190+
191+
def _gramian_based(
    aggregator: GramianWeightedAggregator,
    jacobians: deque[Tensor],
) -> tuple[list[Tensor], Tensor]:
    """
    Compute per-tensor gradients from the Gramian of the (never materialized) combined
    Jacobian: extract weights with the aggregator's inner weighting, then contract each
    Jacobian's first dimension with those weights. Consumes ``jacobians``.
    """
    gramian = _compute_gramian_sum(jacobians)
    weights = aggregator.gramian_weighting(gramian)

    gradients: list[Tensor] = []
    while jacobians:
        # popleft drops the deque's reference so each Jacobian can be freed asap.
        jacobian = jacobians.popleft()
        gradients.append(torch.tensordot(weights, jacobian, dims=1))

    return gradients, weights
206+
207+
def _compute_gramian_sum(jacobians: deque[Tensor]) -> PSDMatrix:
    """
    Compute the Gramian of the combined Jacobian as the sum of the Gramians of the
    individual Jacobians (the Gramian of a column-wise concatenation is the sum of the
    per-block Gramians). Does not consume ``jacobians``.
    """
    # Generator (not a list) so only one per-Jacobian Gramian is held besides the
    # accumulator — this function exists precisely to save memory.
    gramian = sum(compute_gramian(matrix) for matrix in jacobians)
    return cast(PSDMatrix, gramian)
130211
131212
132- def _unite_jacobians (jacobians : list [Tensor ]) -> Tensor :
133- jacobian_matrices = [jacobian .reshape (jacobian .shape [0 ], - 1 ) for jacobian in jacobians ]
213+ def _unite_jacobians (jacobians : deque [Tensor ]) -> Tensor :
214+ jacobian_matrices = list [Tensor ]()
215+ while jacobians :
216+ jacobian = jacobians .popleft () # get jacobian + dereference it to free memory asap
217+ jacobian_matrices .append (jacobian .reshape (jacobian .shape [0 ], - 1 ))
134218 jacobian_matrix = torch .concat (jacobian_matrices , dim = 1 )
135219 return jacobian_matrix
136220
0 commit comments