|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from collections.abc import Iterable |
| 4 | + |
| 5 | +import torch |
| 6 | +import torch.nn as nn |
| 7 | +from torch import Tensor |
| 8 | + |
| 9 | +from torchjd._linalg import Matrix |
| 10 | + |
| 11 | +from ._aggregator_bases import Aggregator |
| 12 | +from ._utils.non_differentiable import raise_non_differentiable_error |
| 13 | + |
#: Default small constant added to denominators for numerical stability.
DEFAULT_GRADVAC_EPS: float = 1e-8
| 16 | + |
| 17 | + |
| 18 | +def _gradvac_all_layer_group_sizes(encoder: nn.Module) -> tuple[int, ...]: |
| 19 | + """ |
| 20 | + Block sizes per leaf submodule with parameters, matching the ``all_layer`` grouping: iterate |
| 21 | + ``encoder.modules()`` and append the total number of elements in each module that has no child |
| 22 | + submodules and registers at least one parameter. |
| 23 | + """ |
| 24 | + |
| 25 | + return tuple( |
| 26 | + sum(w.numel() for w in module.parameters()) |
| 27 | + for module in encoder.modules() |
| 28 | + if len(module._modules) == 0 and len(module._parameters) > 0 |
| 29 | + ) |
| 30 | + |
| 31 | + |
| 32 | +def _gradvac_all_matrix_group_sizes(shared_params: Iterable[Tensor]) -> tuple[int, ...]: |
| 33 | + """One block per tensor in ``shared_params`` order (``all_matrix`` / shared-parameter layout).""" |
| 34 | + |
| 35 | + return tuple(p.numel() for p in shared_params) |
| 36 | + |
| 37 | + |
class GradVac(Aggregator):
    r"""
    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing Gradient Vaccine
    (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task Optimization in
    Massively Multilingual Models (ICLR 2021 Spotlight)
    <https://openreview.net/forum?id=F1vEjWK-lH_>`_.

    The input matrix is a Jacobian :math:`G \in \mathbb{R}^{M \times D}` whose rows are per-task
    gradients. For each task :math:`i`, rows are visited in a random order; for each other task
    :math:`j` and each parameter block :math:`k`, the cosine correlation :math:`\rho_{ijk}` between
    the (possibly already modified) gradient of task :math:`i` and the original gradient of task
    :math:`j` on that block is compared to an EMA target :math:`\bar{\rho}_{ijk}`. When
    :math:`\rho_{ijk} < \bar{\rho}_{ijk}`, a closed-form correction adds a scaled copy of
    :math:`g_j` to the block of :math:`g_i^{(\mathrm{PC})}`. The EMA is then updated with
    :math:`\bar{\rho}_{ijk} \leftarrow (1-\beta)\bar{\rho}_{ijk} + \beta \rho_{ijk}`. The aggregated
    vector is the sum of the modified rows.

    This aggregator is stateful: it keeps :math:`\bar{\rho}` across calls. Use :meth:`reset` when
    the number of tasks, parameter dimension, grouping, device, or dtype changes.

    **Parameter granularity** is selected by ``group_type`` (integer, default ``0``). It defines how
    each task gradient row is partitioned into blocks :math:`k` so that cosines and EMA targets
    :math:`\bar{\rho}_{ijk}` are computed **per block** rather than only globally:

    * ``0`` — **whole model** (``whole_model``): the full row of length :math:`D` is a single block.
      Cosine similarity is taken between entire task gradients. Do not pass ``encoder`` or
      ``shared_params``.
    * ``1`` — **all layer** (``all_layer``): one block per leaf ``nn.Module`` under ``encoder`` that
      holds parameters (same rule as iterating ``encoder.modules()`` and selecting leaves with
      parameters). Pass ``encoder``; ``shared_params`` must be omitted.
    * ``2`` — **all matrix** (``all_matrix``): one block per tensor in ``shared_params``, in
      iteration order. That order must match how Jacobian columns are laid out for those shared
      parameters. Pass ``shared_params``; ``encoder`` must be omitted.

    :param beta: EMA decay for :math:`\bar{\rho}` (paper default ``0.5``).
    :param group_type: Granularity of parameter grouping; see **Parameter granularity** above.
    :param encoder: Module whose subtree defines ``all_layer`` blocks when ``group_type == 1``.
    :param shared_params: Iterable of parameter tensors defining ``all_matrix`` block sizes and
        order when ``group_type == 2``. It is materialized once at construction.
    :param eps: Small positive constant added to denominators when computing cosines and the
        vaccine weight (default :data:`~torchjd.aggregation.DEFAULT_GRADVAC_EPS`). You may read or
        assign the :attr:`eps` attribute between steps to tune numerical behavior.

    .. note::
        GradVac is not compatible with autogram: it needs full Jacobian rows and per-block inner
        products, not only a Gram matrix. Only the autojac path is supported.

    .. note::
        Task-order shuffling uses the global PyTorch RNG (``torch.randperm``). Seed it with
        ``torch.manual_seed`` if you need reproducibility.
    """

    def __init__(
        self,
        beta: float = 0.5,
        group_type: int = 0,
        encoder: nn.Module | None = None,
        shared_params: Iterable[Tensor] | None = None,
        eps: float = DEFAULT_GRADVAC_EPS,
    ) -> None:
        super().__init__()
        if not (0.0 <= beta <= 1.0):
            raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.")
        if group_type not in (0, 1, 2):
            raise ValueError(
                "Parameter `group_type` must be 0 (whole_model), 1 (all_layer), or 2 (all_matrix). "
                f"Found group_type={group_type!r}.",
            )
        # `encoder` and `shared_params` are mutually exclusive and each is tied to one
        # `group_type`; every branch below validates the pairing and derives the fixed
        # block sizes when they are known at construction time.
        params_tuple: tuple[Tensor, ...] = ()
        fixed_block_sizes: tuple[int, ...] | None
        if group_type == 0:
            if encoder is not None:
                raise ValueError("Parameter `encoder` must be None when `group_type == 0`.")
            if shared_params is not None:
                raise ValueError("Parameter `shared_params` must be None when `group_type == 0`.")
            # Whole-model grouping: the single block spans the full row, so sizes depend
            # on the Jacobian width and are resolved lazily in `_resolve_segment_sizes`.
            fixed_block_sizes = None
        elif group_type == 1:
            if encoder is None:
                raise ValueError("Parameter `encoder` is required when `group_type == 1`.")
            if shared_params is not None:
                raise ValueError("Parameter `shared_params` must be None when `group_type == 1`.")
            fixed_block_sizes = _gradvac_all_layer_group_sizes(encoder)
            if sum(fixed_block_sizes) == 0:
                raise ValueError("Parameter `encoder` has no parameters in any leaf module.")
        else:
            if shared_params is None:
                raise ValueError("Parameter `shared_params` is required when `group_type == 2`.")
            if encoder is not None:
                raise ValueError("Parameter `encoder` must be None when `group_type == 2`.")
            # Materialize once: `shared_params` may be a one-shot iterator.
            params_tuple = tuple(shared_params)
            if len(params_tuple) == 0:
                raise ValueError(
                    "Parameter `shared_params` must be non-empty when `group_type == 2`."
                )
            fixed_block_sizes = _gradvac_all_matrix_group_sizes(params_tuple)

        if eps <= 0.0:
            raise ValueError(f"Parameter `eps` must be positive. Found eps={eps!r}.")

        self._beta = beta
        self._group_type = group_type
        self._encoder = encoder
        # Only the count is kept (for `__repr__`); the tensors themselves are not needed
        # after the block sizes have been computed.
        self._shared_params_len = len(params_tuple)
        self._fixed_block_sizes = fixed_block_sizes
        self._eps = float(eps)

        # EMA state: `_rho_t` holds one target per (task i, task j, block k) triple;
        # `_state_key` records the (m, n, sizes, device, dtype) it was allocated for.
        self._rho_t: Tensor | None = None
        self._state_key: tuple[int, int, tuple[int, ...], torch.device, torch.dtype] | None = None

        # `forward` is not differentiable; fail loudly if autograd tries to go through it.
        self.register_full_backward_pre_hook(raise_non_differentiable_error)

    @property
    def eps(self) -> float:
        """Small positive constant added to denominators for numerical stability."""

        return self._eps

    @eps.setter
    def eps(self, value: float) -> None:
        # Validate on assignment so an invalid value cannot silently corrupt later steps.
        v = float(value)
        if v <= 0.0:
            raise ValueError(f"Attribute `eps` must be positive. Found eps={value!r}.")
        self._eps = v

    def reset(self) -> None:
        """Clears EMA state so the next forward starts from zero targets."""

        self._rho_t = None
        self._state_key = None

    def __repr__(self) -> str:
        # Summarize `encoder` / `shared_params` compactly instead of embedding full reprs.
        enc = "None" if self._encoder is None else f"{self._encoder.__class__.__name__}(...)"
        sp = "None" if self._group_type != 2 else f"n_params={self._shared_params_len}"
        return (
            f"{self.__class__.__name__}(beta={self._beta!r}, group_type={self._group_type!r}, "
            f"encoder={enc}, shared_params={sp}, eps={self._eps!r})"
        )

    def _resolve_segment_sizes(self, n: int) -> tuple[int, ...]:
        # Return the per-block column sizes for a Jacobian of width `n`, checking that
        # the fixed sizes (group_type 1 or 2) actually cover all `n` columns.
        if self._group_type == 0:
            # `whole_model`: a single block covering the whole row.
            return (n,)
        assert self._fixed_block_sizes is not None
        sizes = self._fixed_block_sizes
        if sum(sizes) != n:
            raise ValueError(
                "The Jacobian width `D` must equal the sum of block sizes implied by "
                f"`encoder` or `shared_params` for this `group_type`. Found D={n}, "
                f"sum(block_sizes)={sum(sizes)}.",
            )
        return sizes

    def _ensure_state(
        self,
        m: int,
        n: int,
        sizes: tuple[int, ...],
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        # (Re)allocate the EMA tensor whenever any aspect of the problem changed
        # (task count, width, grouping, device, or dtype); the targets then restart
        # at zero, exactly as after `reset`.
        key = (m, n, sizes, device, dtype)
        num_groups = len(sizes)
        if self._state_key != key or self._rho_t is None:
            self._rho_t = torch.zeros(m, m, num_groups, device=device, dtype=dtype)
            self._state_key = key

    def forward(self, matrix: Matrix, /) -> Tensor:
        grads = matrix
        m, n = grads.shape
        if m == 0 or n == 0:
            # Degenerate Jacobian: nothing to vaccinate, return an all-zero vector.
            return torch.zeros(n, dtype=grads.dtype, device=grads.device)

        sizes = self._resolve_segment_sizes(n)
        device = grads.device
        dtype = grads.dtype
        self._ensure_state(m, n, sizes, device, dtype)
        assert self._rho_t is not None

        rho_t = self._rho_t
        beta = self._beta
        eps = self.eps

        # Rows of `pc_grads` are corrected in place; `grads` keeps the originals, which
        # are always used as the "other task" side of each cosine.
        pc_grads = grads.clone()
        # Prefix sums of block sizes -> column offsets [0, s0, s0+s1, ...].
        offsets = [0]
        for s in sizes:
            offsets.append(offsets[-1] + s)

        for i in range(m):
            # Visit the other tasks in a random order (global PyTorch RNG).
            others = [j for j in range(m) if j != i]
            perm = torch.randperm(len(others))
            order = perm.tolist()
            shuffled_js = [others[idx] for idx in order]

            for j in shuffled_js:
                for k in range(len(sizes)):
                    beg, end = offsets[k], offsets[k + 1]
                    # Cosine between the *current* (possibly already corrected) row i
                    # and the *original* row j, restricted to block k.
                    slice_i = pc_grads[i, beg:end]
                    slice_j = grads[j, beg:end]

                    norm_i = slice_i.norm()
                    norm_j = slice_j.norm()
                    denom = norm_i * norm_j + eps
                    rho_ijk = slice_i.dot(slice_j) / denom

                    bar = rho_t[i, j, k]
                    if rho_ijk < bar:
                        # Closed-form GradVac correction: add w * g_j so that the new
                        # cosine with g_j reaches the EMA target `bar`. The clamps guard
                        # against tiny negative values from floating-point error before
                        # the square roots.
                        sqrt_1_rho2 = (1.0 - rho_ijk * rho_ijk).clamp(min=0.0).sqrt()
                        sqrt_1_bar2 = (1.0 - bar * bar).clamp(min=0.0).sqrt()
                        denom_w = norm_j * sqrt_1_bar2 + eps
                        w = norm_i * (bar * sqrt_1_rho2 - rho_ijk * sqrt_1_bar2) / denom_w
                        pc_grads[i, beg:end] = slice_i + slice_j * w

                    # EMA target update, applied whether or not a correction happened.
                    rho_t[i, j, k] = (1.0 - beta) * bar + beta * rho_ijk

        # Aggregate: sum of the (vaccinated) task gradients.
        return pc_grads.sum(dim=0)
0 commit comments