Skip to content

Commit 012b1ba

Browse files
feat(aggregation): Add GradVac aggregator (#638)
1 parent 8974877 commit 012b1ba

File tree

9 files changed

+366
-0
lines changed

9 files changed

+366
-0
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ changelog does not include internal changes that do not affect the user.
88

99
## [Unreleased]
1010

11+
### Added
12+
13+
- Added `GradVac` and `GradVacWeighting` from
14+
[Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874).
15+
1116
### Fixed
1217

1318
- Added a fallback for when the inner optimization of `NashMTL` fails (which can happen for example

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ TorchJD provides many existing aggregators from the literature, listed in the fo
281281
| [Constant](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.Constant) | [ConstantWeighting](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.ConstantWeighting) | - |
282282
| [DualProj](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProj) | [DualProjWeighting](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProjWeighting) | [Gradient Episodic Memory for Continual Learning](https://arxiv.org/pdf/1706.08840) |
283283
| [GradDrop](https://torchjd.org/stable/docs/aggregation/graddrop#torchjd.aggregation.GradDrop) | - | [Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout](https://arxiv.org/pdf/2010.06808) |
284+
| [GradVac](https://torchjd.org/stable/docs/aggregation/gradvac#torchjd.aggregation.GradVac) | [GradVacWeighting](https://torchjd.org/stable/docs/aggregation/gradvac#torchjd.aggregation.GradVacWeighting) | [Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874) |
284285
| [IMTLG](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLG) | [IMTLGWeighting](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLGWeighting) | [Towards Impartial Multi-task Learning](https://discovery.ucl.ac.uk/id/eprint/10120667/) |
285286
| [Krum](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.Krum) | [KrumWeighting](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.KrumWeighting) | [Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent](https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf) |
286287
| [Mean](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.Mean) | [MeanWeighting](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.MeanWeighting) | - |
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
:hide-toc:
2+
3+
GradVac
4+
=======
5+
6+
.. autoclass:: torchjd.aggregation.GradVac
7+
:members:
8+
:undoc-members:
9+
:exclude-members: forward, eps, beta
10+
11+
.. autoclass:: torchjd.aggregation.GradVacWeighting
12+
:members:
13+
:undoc-members:
14+
:exclude-members: forward, eps, beta

docs/source/docs/aggregation/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ Abstract base classes
3535
dualproj.rst
3636
flattening.rst
3737
graddrop.rst
38+
gradvac.rst
3839
imtl_g.rst
3940
krum.rst
4041
mean.rst

src/torchjd/aggregation/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
from ._dualproj import DualProj, DualProjWeighting
6767
from ._flattening import Flattening
6868
from ._graddrop import GradDrop
69+
from ._gradvac import GradVac, GradVacWeighting
6970
from ._imtl_g import IMTLG, IMTLGWeighting
7071
from ._krum import Krum, KrumWeighting
7172
from ._mean import Mean, MeanWeighting
@@ -92,6 +93,8 @@
9293
"Flattening",
9394
"GeneralizedWeighting",
9495
"GradDrop",
96+
"GradVac",
97+
"GradVacWeighting",
9598
"IMTLG",
9699
"IMTLGWeighting",
97100
"Krum",
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
from __future__ import annotations
2+
3+
from typing import cast
4+
5+
import torch
6+
from torch import Tensor
7+
8+
from torchjd._linalg import PSDMatrix
9+
10+
from ._aggregator_bases import GramianWeightedAggregator
11+
from ._utils.non_differentiable import raise_non_differentiable_error
12+
from ._weighting_bases import Weighting
13+
14+
15+
class GradVac(GramianWeightedAggregator):
    r"""
    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of
    Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task
    Optimization in Massively Multilingual Models (ICLR 2021 Spotlight)
    <https://openreview.net/forum?id=F1vEjWK-lH_>`_.

    Every task :math:`i` visits the other tasks :math:`j` in a freshly shuffled order. For each
    pair :math:`(i, j)`, the cosine similarity :math:`\phi_{ij}` between the (possibly already
    modified) gradient of task :math:`i` and the original gradient of task :math:`j` is compared
    against an EMA target :math:`\hat{\phi}_{ij}`. Whenever
    :math:`\phi_{ij} < \hat{\phi}_{ij}`, a closed-form correction adds a scaled copy of
    :math:`g_j` to :math:`g_i^{(\mathrm{PC})}`. The EMA target is then refreshed as
    :math:`\hat{\phi}_{ij} \leftarrow (1-\beta)\hat{\phi}_{ij} + \beta \phi_{ij}`, and the
    aggregated vector is the sum of the modified rows.

    This aggregator is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when
    the number of tasks or dtype changes.

    :param beta: EMA decay for :math:`\hat{\phi}`.
    :param eps: Small non-negative constant added to denominators.

    .. note::
        For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently
        using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if
        you need reproducibility.
    """

    def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
        # All parameters and EMA state live in the weighting; this class merely delegates.
        inner = GradVacWeighting(beta=beta, eps=eps)
        super().__init__(inner)
        self._gradvac_weighting = inner
        # Backward through this aggregator is unsupported: the hook raises instead of
        # letting autograd silently flow through (see `raise_non_differentiable_error`).
        self.register_full_backward_pre_hook(raise_non_differentiable_error)

    @property
    def beta(self) -> float:
        # Forward to the underlying weighting so both views always agree.
        return self._gradvac_weighting.beta

    @beta.setter
    def beta(self, value: float) -> None:
        self._gradvac_weighting.beta = value

    @property
    def eps(self) -> float:
        # Forward to the underlying weighting so both views always agree.
        return self._gradvac_weighting.eps

    @eps.setter
    def eps(self, value: float) -> None:
        self._gradvac_weighting.eps = value

    def reset(self) -> None:
        """Clears EMA state so the next forward starts from zero targets."""

        self._gradvac_weighting.reset()

    def __repr__(self) -> str:
        return "GradVac(beta={!r}, eps={!r})".format(self.beta, self.eps)
class GradVacWeighting(Weighting[PSDMatrix]):
    r"""
    :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
    :class:`~torchjd.aggregation.GradVac`.

    All required quantities (gradient norms, cosine similarities, and their updates after the
    vaccine correction) are derived purely from the Gramian, without needing the full Jacobian.
    If :math:`g_i^{(\mathrm{PC})} = \sum_k c_{ik} g_k`, then:

    .. math::

        \|g_i^{(\mathrm{PC})}\|^2 = \mathbf{c}_i G \mathbf{c}_i^\top,\qquad
        g_i^{(\mathrm{PC})} \cdot g_j = \mathbf{c}_i G_{:,j}

    where :math:`G` is the Gramian. The correction :math:`g_i^{(\mathrm{PC})} \mathrel{+}= w
    g_j` then becomes :math:`c_{ij} \mathrel{+}= w`, and the updated dot products follow
    immediately.

    This weighting is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when
    the number of tasks or dtype changes.

    :param beta: EMA decay for :math:`\hat{\phi}`.
    :param eps: Small non-negative constant added to denominators.
    """

    def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
        super().__init__()
        if not (0.0 <= beta <= 1.0):
            raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.")
        if eps < 0.0:
            raise ValueError(f"Parameter `eps` must be non-negative. Found eps={eps!r}.")

        self._beta = beta
        self._eps = eps
        # EMA of pairwise cosine similarities (the \hat{\phi} targets), lazily allocated on the
        # first forward call; kept on CPU.
        self._phi_t: Tensor | None = None
        # (num_tasks, dtype) for which `_phi_t` is valid; used to invalidate stale state when
        # the task count or dtype changes between calls.
        self._state_key: tuple[int, torch.dtype] | None = None

    @property
    def beta(self) -> float:
        return self._beta

    @beta.setter
    def beta(self, value: float) -> None:
        # Same validation as in `__init__` so the invariant holds after mutation.
        if not (0.0 <= value <= 1.0):
            raise ValueError(f"Attribute `beta` must be in [0, 1]. Found beta={value!r}.")
        self._beta = value

    @property
    def eps(self) -> float:
        return self._eps

    @eps.setter
    def eps(self, value: float) -> None:
        # Same validation as in `__init__` so the invariant holds after mutation.
        if value < 0.0:
            raise ValueError(f"Attribute `eps` must be non-negative. Found eps={value!r}.")
        self._eps = value

    def reset(self) -> None:
        """Clears EMA state so the next forward starts from zero targets."""

        self._phi_t = None
        self._state_key = None

    def forward(self, gramian: PSDMatrix, /) -> Tensor:
        """Computes the GradVac weights from the Gramian of the task gradients.

        Returns a vector of length ``m`` (on ``gramian``'s device) such that the aggregated
        gradient is the weighted sum of the original task gradients. Mutates the EMA state.
        """

        # Move all computations to cpu: the double loop below reads many scalars, and doing so
        # on an accelerator would incur a host<->device transfer per read at each iteration.
        device = gramian.device
        dtype = gramian.dtype
        cpu = torch.device("cpu")

        G = cast(PSDMatrix, gramian.to(device=cpu))
        m = G.shape[0]  # number of tasks

        # (Re)allocate the EMA state if the task count or dtype changed since the last call.
        self._ensure_state(m, dtype)
        phi_t = cast(Tensor, self._phi_t)

        beta = self._beta
        eps = self._eps

        # C[i, :] holds coefficients such that g_i^PC = sum_k C[i,k] * g_k (original gradients).
        # Initially each modified gradient equals the original, so C = I.
        C = torch.eye(m, device=cpu, dtype=dtype)

        for i in range(m):
            # Dot products of g_i^PC with every original g_j, shape (m,).
            cG = C[i] @ G

            # Visit the other tasks in a fresh random order (global PyTorch RNG).
            others = [j for j in range(m) if j != i]
            perm = torch.randperm(len(others))
            shuffled_js = [others[idx] for idx in perm.tolist()]

            for j in shuffled_js:
                dot_ij = cG[j]
                # ||g_i^PC||^2 = c_i G c_i^T; clamp guards against tiny negative values that
                # rounding can produce even though G is PSD.
                norm_i_sq = (cG * C[i]).sum()
                norm_i = norm_i_sq.clamp(min=0.0).sqrt()
                norm_j = G[j, j].clamp(min=0.0).sqrt()
                denom = norm_i * norm_j + eps
                phi_ijk = dot_ij / denom  # cosine similarity between g_i^PC and g_j

                phi_hat = phi_t[i, j]  # EMA target \hat{\phi}_{ij}
                if phi_ijk < phi_hat:
                    # Closed-form GradVac correction: add w * g_j to g_i^PC so that its cosine
                    # similarity with g_j is raised towards the target \hat{\phi}_{ij}.
                    sqrt_1_phi2 = (1.0 - phi_ijk * phi_ijk).clamp(min=0.0).sqrt()
                    sqrt_1_hat2 = (1.0 - phi_hat * phi_hat).clamp(min=0.0).sqrt()
                    denom_w = norm_j * sqrt_1_hat2 + eps
                    w = norm_i * (phi_hat * sqrt_1_phi2 - phi_ijk * sqrt_1_hat2) / denom_w
                    # Adding w * g_j to g_i^PC only bumps one coefficient of C[i] and shifts
                    # every dot product by w * G[j] — no Jacobian access needed.
                    C[i, j] = C[i, j] + w
                    cG = cG + w * G[j]

                # Update the EMA target with the pre-correction similarity.
                phi_t[i, j] = (1.0 - beta) * phi_hat + beta * phi_ijk

        # g_agg = sum_i g_i^PC = sum_k (sum_i C[i,k]) g_k, so the weights are column sums of C.
        weights = C.sum(dim=0)
        return weights.to(device)

    def _ensure_state(self, m: int, dtype: torch.dtype) -> None:
        """Allocates a fresh (m, m) zero EMA matrix when the (m, dtype) key has changed."""

        key = (m, dtype)
        if self._state_key != key or self._phi_t is None:
            self._phi_t = torch.zeros(m, m, dtype=dtype)
            self._state_key = key

tests/plots/interactive_plotter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
ConFIG,
1818
DualProj,
1919
GradDrop,
20+
GradVac,
2021
Mean,
2122
NashMTL,
2223
PCGrad,
@@ -48,6 +49,7 @@ def main() -> None:
4849
ConFIG(),
4950
DualProj(),
5051
GradDrop(),
52+
GradVac(),
5153
IMTLG(),
5254
Mean(),
5355
MGDA(),

0 commit comments

Comments
 (0)