Skip to content

Commit 0698f40

Browse files
refactor(aggregation)!: Remove generalized gramians (#692)
* `Engine.compute_gramian` now always returns a flat `[m, m]` gramian, regardless of the output shape * Remove `GeneralizedWeighting`, and `Flattening` — they are no longer needed * Update the IWMTL example to use `UPGradWeighting` directly and reshape the weights before calling `backward` * Add a migration guide in `CHANGELOG.md` --------- Co-authored-by: Pierre Quinton <pierre.quinton@epfl.ch>
1 parent f93ee7e commit 0698f40

10 files changed

Lines changed: 72 additions & 204 deletions

File tree

CHANGELOG.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,29 @@ changelog does not include internal changes that do not affect the user.
2020
(installed with `pip install torchjd`) much lighter, but it means that users of `UPGrad` and
2121
`DualProj` now have to install the new optional dependency group `quadprog_projector` explicitly
2222
(with e.g. `pip install "torchjd[quadprog_projector]"`).
23+
- **BREAKING**: Removed entirely the concept of generalized Gramians. The `Engine.compute_gramian`
24+
method now always returns a square matrix of shape `[m, m]`, where `m` is the total number of
25+
elements of the ``output`` tensor (treating all dimensions uniformly). Previously, an output of
26+
shape `[m1, m2]` would return a 4D generalized Gramian of shape `[m1, m2, m2, m1]`; it now
27+
returns a `[m1 * m2, m1 * m2]` matrix.
28+
This also removes `GeneralizedWeighting` and `Flattening`.
29+
To update, replace `Flattening(weighting)` with a standard `Weighting` and reshape the resulting
30+
weight vector yourself:
31+
```python
32+
# Before
33+
from torchjd.aggregation import Flattening, UPGradWeighting
34+
weighting = Flattening(UPGradWeighting())
35+
gramian = engine.compute_gramian(losses) # shape: [m1, m2, m2, m1]
36+
weights = weighting(gramian) # shape: [m1, m2]
37+
losses.backward(weights)
38+
39+
# After
40+
from torchjd.aggregation import UPGradWeighting
41+
weighting = UPGradWeighting()
42+
gramian = engine.compute_gramian(losses) # shape: [m1 * m2, m1 * m2]
43+
weights = weighting(gramian).reshape(losses.shape) # shape: [m1, m2]
44+
losses.backward(weights)
45+
```
2346

2447
## [0.11.0] - 2026-05-18
2548

docs/source/docs/aggregation/flattening.rst

Lines changed: 0 additions & 7 deletions
This file was deleted.

docs/source/docs/aggregation/index.rst

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@ Abstract base classes
1919
.. autoclass:: torchjd.aggregation.Weighting
2020
:members: __call__
2121

22-
.. autoclass:: torchjd.aggregation.GeneralizedWeighting
23-
:members: __call__
24-
2522
.. autoclass:: torchjd.aggregation.Stateful
2623
:members: reset
2724

@@ -38,7 +35,6 @@ Abstract base classes
3835
cr_mogm.rst
3936
dualproj.rst
4037
fairgrad.rst
41-
flattening.rst
4238
graddrop.rst
4339
gradvac.rst
4440
imtl_g.rst

docs/source/examples/iwmtl.rst

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ The following example shows how to do that.
1616
from torch.nn import Linear, MSELoss, ReLU, Sequential
1717
from torch.optim import SGD
1818

19-
from torchjd.aggregation import Flattening, UPGradWeighting
19+
from torchjd.aggregation import UPGradWeighting
2020
from torchjd.autogram import Engine
2121

2222
shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
@@ -30,7 +30,7 @@ The following example shows how to do that.
3030
3131
optimizer = SGD(params, lr=0.1)
3232
mse = MSELoss(reduction="none")
33-
weighting = Flattening(UPGradWeighting())
33+
weighting = UPGradWeighting()
3434
engine = Engine(shared_module, batch_dim=0)
3535

3636
inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
@@ -46,20 +46,19 @@ The following example shows how to do that.
4646
losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1) # shape: [16, 2]
4747

4848
# Compute the gramian (inner products between pairs of gradients of the losses)
49-
gramian = engine.compute_gramian(losses) # shape: [16, 2, 2, 16]
49+
gramian = engine.compute_gramian(losses) # shape: [32, 32]
5050

5151
# Obtain the weights that lead to no conflict between reweighted gradients
52-
weights = weighting(gramian) # shape: [16, 2]
52+
weights = weighting(gramian) # shape: [32]
5353

5454
# Do the standard backward pass, but weighted using the obtained weights
55-
losses.backward(weights)
55+
losses.backward(weights.reshape(losses.shape))
5656
optimizer.step()
5757
optimizer.zero_grad()
5858

5959
.. note::
60-
In this example, the tensor of losses is a matrix rather than a vector. The gramian is thus a
61-
4D tensor rather than a matrix, and a
62-
:class:`~torchjd.aggregation._weighting_bases.GeneralizedWeighting`, such as
63-
:class:`~torchjd.aggregation._flattening.Flattening`, has to be used to extract a matrix of
64-
weights from it. More information about ``GeneralizedWeighting`` can be found in the
65-
:doc:`../../docs/aggregation/index` page.
60+
In this example, the tensor of losses is a matrix of shape ``[16, 2]`` (16 samples, 2 tasks).
61+
The autogram engine flattens this into a vector of ``m = 16 × 2 = 32`` objectives, so the
62+
Gramian has shape ``[32, 32]``. A standard :class:`~torchjd.aggregation.Weighting` is then used
63+
to extract a vector of 32 weights, which is reshaped back to ``[16, 2]`` before being passed to
64+
:meth:`~torch.Tensor.backward`.

src/torchjd/aggregation/__init__.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -36,28 +36,6 @@
3636
>>> weights = weighting(gramian)
3737
>>> weights
3838
tensor([1.1109, 0.7894])
39-
40-
When dealing with a more general tensor of objectives, of shape ``[m_1, ..., m_k]`` (i.e. not
41-
necessarily a simple vector), the Jacobian will be of shape ``[m_1, ..., m_k, n]``, and its Gramian
42-
will be called a `generalized Gramian`, of shape ``[m_1, ..., m_k, m_k, ..., m_1]``. One can use a
43-
:class:`GeneralizedWeighting<torchjd.aggregation.GeneralizedWeighting>` to extract
44-
a tensor of weights (of shape ``[m_1, ..., m_k]``) from such a generalized Gramian. The simplest
45-
:class:`GeneralizedWeighting<torchjd.aggregation.GeneralizedWeighting>` is
46-
:class:`Flattening<torchjd.aggregation.Flattening>`: it simply "flattens" the
47-
generalized Gramian into a square Gramian matrix (of shape ``[m_1 * ... * m_k, m_1 * ... * m_k]``),
48-
applies a normal weighting to it to obtain a vector of weights, and returns the reshaped tensor of
49-
weights.
50-
51-
>>> from torch import ones
52-
>>> from torchjd.aggregation import Flattening, UPGradWeighting
53-
>>>
54-
>>> weighting = Flattening(UPGradWeighting())
55-
>>> # Generate a generalized Gramian filled with ones, for the sake of the example
56-
>>> generalized_gramian = ones((2, 3, 3, 2))
57-
>>> weights = weighting(generalized_gramian)
58-
>>> weights
59-
tensor([[0.1667, 0.1667, 0.1667],
60-
[0.1667, 0.1667, 0.1667]])
6139
"""
6240

6341
from ._aggregator_bases import Aggregator, GramianWeightedAggregator, WeightedAggregator
@@ -68,7 +46,6 @@
6846
from ._cr_mogm import CRMOGMWeighting
6947
from ._dualproj import DualProj, DualProjWeighting
7048
from ._fairgrad import FairGrad, FairGradWeighting
71-
from ._flattening import Flattening
7249
from ._graddrop import GradDrop
7350
from ._gradvac import GradVac, GradVacWeighting
7451
from ._imtl_g import IMTLG, IMTLGWeighting
@@ -82,7 +59,7 @@
8259
from ._sum import Sum, SumWeighting
8360
from ._trimmed_mean import TrimmedMean
8461
from ._upgrad import UPGrad, UPGradWeighting
85-
from ._weighting_bases import GeneralizedWeighting, Weighting
62+
from ._weighting_bases import Weighting
8663

8764
__all__ = [
8865
"Aggregator",
@@ -98,8 +75,6 @@
9875
"DualProjWeighting",
9976
"FairGrad",
10077
"FairGradWeighting",
101-
"Flattening",
102-
"GeneralizedWeighting",
10378
"GradDrop",
10479
"GradVac",
10580
"GradVacWeighting",

src/torchjd/aggregation/_flattening.py

Lines changed: 0 additions & 32 deletions
This file was deleted.

src/torchjd/aggregation/_weighting_bases.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from torch import Tensor, nn
88

9-
from torchjd._linalg import PSDTensor, is_psd_tensor
109
from torchjd.linalg import Matrix, PSDMatrix
1110

1211
_T = TypeVar("_T", contravariant=True, bound=Tensor)
@@ -76,30 +75,3 @@ def __call__(self, gramian: Tensor, /) -> Tensor:
7675
:param gramian: The Gramian from which the weights must be extracted.
7776
"""
7877
return super().__call__(gramian)
79-
80-
81-
class GeneralizedWeighting(nn.Module, ABC):
82-
r"""
83-
Abstract base class for all weightings that operate on generalized Gramians. It has the role of
84-
extracting a tensor of weights of dimension :math:`m_1 \times \dots \times m_k` from a
85-
generalized Gramian of dimension
86-
:math:`m_1 \times \dots \times m_k \times m_k \times \dots \times m_1`.
87-
"""
88-
89-
def __init__(self) -> None:
90-
super().__init__()
91-
92-
@abstractmethod
93-
def forward(self, generalized_gramian: PSDTensor, /) -> Tensor:
94-
"""Computes the vector of weights from the input generalized Gramian."""
95-
96-
def __call__(self, generalized_gramian: Tensor, /) -> Tensor:
97-
"""
98-
Computes the tensor of weights from the input generalized Gramian and applies all registered
99-
hooks.
100-
101-
:param generalized_gramian: The tensor from which the weights must be extracted.
102-
"""
103-
104-
assert is_psd_tensor(generalized_gramian)
105-
return super().__call__(generalized_gramian)

src/torchjd/autogram/_engine.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import Tensor, nn, vmap
55
from torch.autograd.graph import get_gradient_edge
66

7-
from torchjd._linalg import movedim, reshape
7+
from torchjd._linalg import flatten, movedim, reshape
88
from torchjd.linalg import PSDMatrix
99

1010
from ._edge_registry import EdgeRegistry
@@ -246,28 +246,18 @@ def compute_gramian(self, output: Tensor, /) -> Tensor:
246246
Computes the Gramian of the Jacobian of ``output`` with respect to the direct parameters of
247247
all ``modules``.
248248
249-
:param output: The tensor of arbitrary shape to differentiate. The shape of the returned
250-
Gramian depends on the shape of this output.
251-
252-
.. note::
253-
This function doesn't require ``output`` to be a vector. For example, if ``output`` is
254-
a matrix of shape :math:`[m_1, m_2]`, its Jacobian :math:`J` with respect to the
255-
parameters will be of shape :math:`[m_1, m_2, n]`, where :math:`n` is the number of
256-
parameters in the model. This is what we call a `generalized Jacobian`. The
257-
corresponding Gramian :math:`G = J J^\top` will be of shape
258-
:math:`[m_1, m_2, m_2, m_1]`. This is what we call a `generalized Gramian`. The number
259-
of dimensions of the returned generalized Gramian will always be twice that of the
260-
``output``.
249+
:param output: The tensor to differentiate. Its elements are treated as a flat vector of
250+
:math:`m` objectives (where :math:`m` is the total number of elements of ``output``),
251+
so the returned Gramian always has shape :math:`[m, m]`.
261252
262253
A few examples:
263-
- 0D (scalar) ``output``: 0D Gramian (this can be used to efficiently compute the
264-
squared norm of the gradient of ``output``).
265-
- 1D (vector) ``output``: 2D Gramian (this is the standard setting of Jacobian
266-
descent).
267-
- 2D (matrix) ``output``: 4D Gramian (this can be used for :doc:`Instance-Wise
268-
Multi-Task Learning (IWMTL) <../../examples/iwmtl>`, as each sample in the batch
269-
has one loss per task).
270-
- etc.
254+
- Scalar ``output``: :math:`1\times 1` Gramian (this can be used to efficiently
255+
compute the squared norm of the gradient of ``output``).
256+
- Vector ``output`` of dimension :math:`m`: :math:`m \times m` Gramian (this is the
257+
standard setting of Jacobian descent).
258+
- Matrix ``output`` of dimension :math:`m_1\times m_2`: :math:`m_1 m_2 \times m_1 m_2`
259+
Gramian (this can be used for :doc:`Instance-Wise Multi-Task Learning (IWMTL)
260+
<../../examples/iwmtl>`, as each sample in the batch has one loss per task).
271261
"""
272262

273263
if self._batch_dim is not None:
@@ -305,12 +295,11 @@ def compute_gramian(self, output: Tensor, /) -> Tensor:
305295
for gramian_computer in self._gramian_computers.values():
306296
gramian_computer.reset()
307297

308-
unordered_gramian = reshape(square_gramian, ordered_shape)
309-
310298
if self._batch_dim is not None:
311-
gramian = movedim(unordered_gramian, [-1], [self._batch_dim])
299+
unordered_gramian = reshape(square_gramian, ordered_shape)
300+
gramian = flatten(movedim(unordered_gramian, [-1], [self._batch_dim]))
312301
else:
313-
gramian = unordered_gramian
302+
gramian = square_gramian
314303

315304
return gramian
316305

tests/unit/aggregation/test_flattening.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

0 commit comments

Comments
 (0)