Commit 1030d57

feat(aggregation): add GradVac aggregator

Implement Gradient Vaccine (ICLR 2021) as a stateful Jacobian aggregator. Support group_type 0 (whole model), 1 (all_layer via encoder), and 2 (all_matrix via shared_params), with DEFAULT_GRADVAC_EPS and configurable eps. Add Sphinx page and unit tests. Autogram is not supported; use torch.manual_seed for reproducible task shuffle order.

Made-with: Cursor

1 parent 511561f

5 files changed: +462 -0
docs/source/docs/aggregation/gradvac.rst

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
:hide-toc:

GradVac
=======

.. autodata:: torchjd.aggregation.DEFAULT_GRADVAC_EPS

The constructor argument ``group_type`` (default ``0``) sets **parameter granularity** for the
per-block cosine statistics in GradVac:

* ``0`` — **whole model** (``whole_model``): one block per task gradient row. Omit ``encoder`` and
  ``shared_params``.
* ``1`` — **all layer** (``all_layer``): one block per leaf submodule with parameters under
  ``encoder`` (same traversal as ``encoder.modules()`` in the reference formulation).
* ``2`` — **all matrix** (``all_matrix``): one block per tensor in ``shared_params``, in order. Use
  the same tensors as for the shared-parameter Jacobian columns (e.g. the parameters you would pass
  to a shared-gradient helper).

.. autoclass:: torchjd.aggregation.GradVac
    :members:
    :undoc-members:
    :exclude-members: forward
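For illustration, a minimal sketch of how the three group_type values documented above translate into constructor calls; the toy encoder below is an assumption chosen for the example, not part of the commit:

import torch.nn as nn
from torchjd.aggregation import GradVac

# Toy shared encoder, used only to illustrate the three grouping modes.
encoder = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

# group_type=0 (whole_model): each task gradient row is a single block.
whole_model = GradVac(beta=0.5, group_type=0)

# group_type=1 (all_layer): one block per parameterized leaf module of `encoder`.
all_layer = GradVac(beta=0.5, group_type=1, encoder=encoder)

# group_type=2 (all_matrix): one block per shared parameter tensor, in the same
# order as the Jacobian columns for those parameters.
all_matrix = GradVac(beta=0.5, group_type=2, shared_params=list(encoder.parameters()))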

docs/source/docs/aggregation/index.rst

Lines changed: 1 addition & 0 deletions

@@ -35,6 +35,7 @@ Abstract base classes
    dualproj.rst
    flattening.rst
    graddrop.rst
+   gradvac.rst
    imtl_g.rst
    krum.rst
    mean.rst

src/torchjd/aggregation/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -66,6 +66,7 @@
 from ._dualproj import DualProj, DualProjWeighting
 from ._flattening import Flattening
 from ._graddrop import GradDrop
+from ._gradvac import DEFAULT_GRADVAC_EPS, GradVac
 from ._imtl_g import IMTLG, IMTLGWeighting
 from ._krum import Krum, KrumWeighting
 from ._mean import Mean, MeanWeighting
@@ -87,11 +88,13 @@
     "ConFIG",
     "Constant",
     "ConstantWeighting",
+    "DEFAULT_GRADVAC_EPS",
     "DualProj",
     "DualProjWeighting",
     "Flattening",
     "GeneralizedWeighting",
     "GradDrop",
+    "GradVac",
     "IMTLG",
     "IMTLGWeighting",
     "Krum",
src/torchjd/aggregation/_gradvac.py

Lines changed: 251 additions & 0 deletions

@@ -0,0 +1,251 @@
from __future__ import annotations

from collections.abc import Iterable

import torch
import torch.nn as nn
from torch import Tensor

from torchjd._linalg import Matrix

from ._aggregator_bases import Aggregator
from ._utils.non_differentiable import raise_non_differentiable_error

#: Default small constant added to denominators for numerical stability.
DEFAULT_GRADVAC_EPS = 1e-8


def _gradvac_all_layer_group_sizes(encoder: nn.Module) -> tuple[int, ...]:
    """
    Block sizes per leaf submodule with parameters, matching the ``all_layer`` grouping: iterate
    ``encoder.modules()`` and append the total number of elements in each module that has no child
    submodules and registers at least one parameter.
    """
    return tuple(
        sum(w.numel() for w in module.parameters())
        for module in encoder.modules()
        if len(module._modules) == 0 and len(module._parameters) > 0
    )


def _gradvac_all_matrix_group_sizes(shared_params: Iterable[Tensor]) -> tuple[int, ...]:
    """One block per tensor in ``shared_params`` order (``all_matrix`` / shared-parameter layout)."""
    return tuple(p.numel() for p in shared_params)


class GradVac(Aggregator):
    r"""
    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing Gradient Vaccine
    (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task Optimization in
    Massively Multilingual Models (ICLR 2021 Spotlight)
    <https://openreview.net/forum?id=F1vEjWK-lH_>`_.

    The input matrix is a Jacobian :math:`G \in \mathbb{R}^{M \times D}` whose rows are per-task
    gradients. For each task :math:`i`, the rows of the other tasks are visited in a random order;
    for each such task :math:`j` and each parameter block :math:`k`, the cosine correlation
    :math:`\rho_{ijk}` between the (possibly already modified) gradient of task :math:`i` and the
    original gradient of task :math:`j` on that block is compared to an EMA target
    :math:`\bar{\rho}_{ijk}`. When :math:`\rho_{ijk} < \bar{\rho}_{ijk}`, a closed-form correction
    adds a scaled copy of :math:`g_j` to the block of :math:`g_i^{(\mathrm{PC})}`. The EMA is then
    updated with :math:`\bar{\rho}_{ijk} \leftarrow (1-\beta)\bar{\rho}_{ijk} + \beta \rho_{ijk}`.
    The aggregated vector is the sum of the modified rows.

    This aggregator is stateful: it keeps :math:`\bar{\rho}` across calls. Use :meth:`reset` when
    the number of tasks, parameter dimension, grouping, device, or dtype changes.

    **Parameter granularity** is selected by ``group_type`` (integer, default ``0``). It defines
    how each task gradient row is partitioned into blocks :math:`k` so that cosines and EMA
    targets :math:`\bar{\rho}_{ijk}` are computed **per block** rather than only globally:

    * ``0`` — **whole model** (``whole_model``): the full row of length :math:`D` is a single
      block. Cosine similarity is taken between entire task gradients. Do not pass ``encoder`` or
      ``shared_params``.
    * ``1`` — **all layer** (``all_layer``): one block per leaf ``nn.Module`` under ``encoder``
      that holds parameters (same rule as iterating ``encoder.modules()`` and selecting leaves
      with parameters). Pass ``encoder``; ``shared_params`` must be omitted.
    * ``2`` — **all matrix** (``all_matrix``): one block per tensor in ``shared_params``, in
      iteration order. That order must match how Jacobian columns are laid out for those shared
      parameters. Pass ``shared_params``; ``encoder`` must be omitted.

    :param beta: EMA decay for :math:`\bar{\rho}` (paper default ``0.5``).
    :param group_type: Granularity of parameter grouping; see **Parameter granularity** above.
    :param encoder: Module whose subtree defines ``all_layer`` blocks when ``group_type == 1``.
    :param shared_params: Iterable of parameter tensors defining ``all_matrix`` block sizes and
        order when ``group_type == 2``. It is materialized once at construction.
    :param eps: Small positive constant added to denominators when computing cosines and the
        vaccine weight (default :data:`~torchjd.aggregation.DEFAULT_GRADVAC_EPS`). You may read or
        assign the :attr:`eps` attribute between steps to tune numerical behavior.

    .. note::
        GradVac is not compatible with autogram: it needs full Jacobian rows and per-block inner
        products, not only a Gram matrix. Only the autojac path is supported.

    .. note::
        Task-order shuffling uses the global PyTorch RNG (``torch.randperm``). Seed it with
        ``torch.manual_seed`` if you need reproducibility.
    """

    def __init__(
        self,
        beta: float = 0.5,
        group_type: int = 0,
        encoder: nn.Module | None = None,
        shared_params: Iterable[Tensor] | None = None,
        eps: float = DEFAULT_GRADVAC_EPS,
    ) -> None:
        super().__init__()
        if not (0.0 <= beta <= 1.0):
            raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.")
        if group_type not in (0, 1, 2):
            raise ValueError(
                "Parameter `group_type` must be 0 (whole_model), 1 (all_layer), or 2 (all_matrix). "
                f"Found group_type={group_type!r}.",
            )
        params_tuple: tuple[Tensor, ...] = ()
        fixed_block_sizes: tuple[int, ...] | None
        if group_type == 0:
            if encoder is not None:
                raise ValueError("Parameter `encoder` must be None when `group_type == 0`.")
            if shared_params is not None:
                raise ValueError("Parameter `shared_params` must be None when `group_type == 0`.")
            fixed_block_sizes = None
        elif group_type == 1:
            if encoder is None:
                raise ValueError("Parameter `encoder` is required when `group_type == 1`.")
            if shared_params is not None:
                raise ValueError("Parameter `shared_params` must be None when `group_type == 1`.")
            fixed_block_sizes = _gradvac_all_layer_group_sizes(encoder)
            if sum(fixed_block_sizes) == 0:
                raise ValueError("Parameter `encoder` has no parameters in any leaf module.")
        else:
            if shared_params is None:
                raise ValueError("Parameter `shared_params` is required when `group_type == 2`.")
            if encoder is not None:
                raise ValueError("Parameter `encoder` must be None when `group_type == 2`.")
            params_tuple = tuple(shared_params)
            if len(params_tuple) == 0:
                raise ValueError(
                    "Parameter `shared_params` must be non-empty when `group_type == 2`."
                )
            fixed_block_sizes = _gradvac_all_matrix_group_sizes(params_tuple)

        if eps <= 0.0:
            raise ValueError(f"Parameter `eps` must be positive. Found eps={eps!r}.")

        self._beta = beta
        self._group_type = group_type
        self._encoder = encoder
        self._shared_params_len = len(params_tuple)
        self._fixed_block_sizes = fixed_block_sizes
        self._eps = float(eps)

        self._rho_t: Tensor | None = None
        self._state_key: tuple[int, int, tuple[int, ...], torch.device, torch.dtype] | None = None

        self.register_full_backward_pre_hook(raise_non_differentiable_error)

    @property
    def eps(self) -> float:
        """Small positive constant added to denominators for numerical stability."""
        return self._eps

    @eps.setter
    def eps(self, value: float) -> None:
        v = float(value)
        if v <= 0.0:
            raise ValueError(f"Attribute `eps` must be positive. Found eps={value!r}.")
        self._eps = v

    def reset(self) -> None:
        """Clears the EMA state so that the next forward starts from zero targets."""
        self._rho_t = None
        self._state_key = None

    def __repr__(self) -> str:
        enc = "None" if self._encoder is None else f"{self._encoder.__class__.__name__}(...)"
        sp = "None" if self._group_type != 2 else f"n_params={self._shared_params_len}"
        return (
            f"{self.__class__.__name__}(beta={self._beta!r}, group_type={self._group_type!r}, "
            f"encoder={enc}, shared_params={sp}, eps={self._eps!r})"
        )

    def _resolve_segment_sizes(self, n: int) -> tuple[int, ...]:
        if self._group_type == 0:
            return (n,)
        assert self._fixed_block_sizes is not None
        sizes = self._fixed_block_sizes
        if sum(sizes) != n:
            raise ValueError(
                "The Jacobian width `D` must equal the sum of block sizes implied by "
                f"`encoder` or `shared_params` for this `group_type`. Found D={n}, "
                f"sum(block_sizes)={sum(sizes)}.",
            )
        return sizes

    def _ensure_state(
        self,
        m: int,
        n: int,
        sizes: tuple[int, ...],
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        key = (m, n, sizes, device, dtype)
        num_groups = len(sizes)
        if self._state_key != key or self._rho_t is None:
            self._rho_t = torch.zeros(m, m, num_groups, device=device, dtype=dtype)
            self._state_key = key

    def forward(self, matrix: Matrix, /) -> Tensor:
        grads = matrix
        m, n = grads.shape
        if m == 0 or n == 0:
            return torch.zeros(n, dtype=grads.dtype, device=grads.device)

        sizes = self._resolve_segment_sizes(n)
        device = grads.device
        dtype = grads.dtype
        self._ensure_state(m, n, sizes, device, dtype)
        assert self._rho_t is not None

        rho_t = self._rho_t
        beta = self._beta
        eps = self.eps

        pc_grads = grads.clone()
        offsets = [0]
        for s in sizes:
            offsets.append(offsets[-1] + s)

        for i in range(m):
            others = [j for j in range(m) if j != i]
            perm = torch.randperm(len(others))
            order = perm.tolist()
            shuffled_js = [others[idx] for idx in order]

            for j in shuffled_js:
                for k in range(len(sizes)):
                    beg, end = offsets[k], offsets[k + 1]
                    slice_i = pc_grads[i, beg:end]
                    slice_j = grads[j, beg:end]

                    norm_i = slice_i.norm()
                    norm_j = slice_j.norm()
                    denom = norm_i * norm_j + eps
                    rho_ijk = slice_i.dot(slice_j) / denom

                    bar = rho_t[i, j, k]
                    if rho_ijk < bar:
                        sqrt_1_rho2 = (1.0 - rho_ijk * rho_ijk).clamp(min=0.0).sqrt()
                        sqrt_1_bar2 = (1.0 - bar * bar).clamp(min=0.0).sqrt()
                        denom_w = norm_j * sqrt_1_bar2 + eps
                        w = norm_i * (bar * sqrt_1_rho2 - rho_ijk * sqrt_1_bar2) / denom_w
                        pc_grads[i, beg:end] = slice_i + slice_j * w

                    rho_t[i, j, k] = (1.0 - beta) * bar + beta * rho_ijk

        return pc_grads.sum(dim=0)
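For a quick sanity check on the closed-form correction used in forward, here is a standalone sketch in plain PyTorch; the toy gradients and the zero EMA target are assumptions chosen for illustration, not part of the commit:

import torch

eps = 1e-8   # plays the same role as DEFAULT_GRADVAC_EPS
beta = 0.5   # paper-default EMA decay

g_i = torch.tensor([1.0, 0.0])   # gradient of task i
g_j = torch.tensor([-1.0, 1.0])  # gradient of task j, conflicting with g_i
bar = torch.tensor(0.0)          # EMA target rho_bar, zero-initialized as in the aggregator

norm_i, norm_j = g_i.norm(), g_j.norm()
rho = g_i.dot(g_j) / (norm_i * norm_j + eps)  # current cosine, about -0.71

if rho < bar:  # cosine below the target: apply the vaccine correction
    sqrt_1_rho2 = (1.0 - rho * rho).clamp(min=0.0).sqrt()
    sqrt_1_bar2 = (1.0 - bar * bar).clamp(min=0.0).sqrt()
    w = norm_i * (bar * sqrt_1_rho2 - rho * sqrt_1_bar2) / (norm_j * sqrt_1_bar2 + eps)
    g_i = g_i + w * g_j  # here w = 0.5, giving g_i = [0.5, 0.5]

bar = (1.0 - beta) * bar + beta * rho  # EMA update of the target

# The corrected cosine matches the target (0.0, up to eps effects).
print(g_i, torch.cosine_similarity(g_i, g_j, dim=0))

In the aggregator itself this update runs per parameter block and per pair of tasks, the order of the other tasks j is shuffled with torch.randperm, and the modified rows are summed; seeding the global RNG with torch.manual_seed makes that shuffle reproducible, as noted in the commit message.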
