|
| 1 | +# Aggregators Reference |
| 2 | + |
| 3 | +This document describes the architecture, conventions, and patterns for implementing a new |
| 4 | +aggregator in `torchjd.aggregation`. Read this alongside the existing implementations |
| 5 | +(`_cr_mogm.py`, `_modo.py`, `_sdmgrad.py`, `_excess_mtl.py`, `_gradvac.py`, `_nash_mtl.py`) |
| 6 | +before writing any code. |
| 7 | + |
| 8 | +--- |
| 9 | + |
| 10 | +## What is an Aggregator? |
| 11 | + |
| 12 | +An `Aggregator` maps a Jacobian matrix `J ∈ ℝ^{m×n}` (m tasks × n parameters) to a single |
| 13 | +gradient vector `g ∈ ℝ^n`. It answers: *given the per-task gradients, which single direction |
| 14 | +should we update the model parameters in?* |
| 15 | + |
| 16 | +Most aggregators work by computing a weight vector `λ ∈ ℝ^m` and returning `λ @ J`. This is |
| 17 | +the `WeightedAggregator` pattern. |
| 18 | + |
| 19 | +--- |
| 20 | + |
| 21 | +## Class Hierarchy |
| 22 | + |
| 23 | +``` |
| 24 | +nn.Module |
| 25 | +└── Aggregator # ABC: validates input, calls forward() |
| 26 | + ├── WeightedAggregator # forward = weighting(J) @ J |
| 27 | + └── GramianWeightedAggregator # forward = gramian_weighting(J @ J^T) @ J |
| 28 | + # (pre-computes the Gramian before calling weighting) |
| 29 | +
|
| 30 | +Weighting[_T] # ABC: takes a statistic _T, returns weights [m] |
| 31 | +├── _MatrixWeighting # _T = Matrix → takes raw Jacobian J |
| 32 | +└── _GramianWeighting # _T = PSDMatrix → takes Gramian J @ J^T |
| 33 | +``` |
| 34 | + |
| 35 | +**Key rule:** if your method needs coordinate-wise information (e.g. per-element gradient |
| 36 | +history like ExcessMTL) or a cross-Jacobian product `J_1 @ J_2^T` (like MoDo, SDMGrad), |
| 37 | +use `_MatrixWeighting`. If your method only needs task-task inner products (the Gramian), |
| 38 | +use `_GramianWeighting`. |
| 39 | + |
| 40 | +--- |
| 41 | + |
| 42 | +## The Paired Class Convention |
| 43 | + |
| 44 | +Almost every algorithm ships as **two classes**: |
| 45 | + |
| 46 | +1. `_FooWeighting(_MatrixWeighting or _GramianWeighting, ...)` — the core computation. |
| 47 | + Prefixed with `_` to signal it is private. Takes either `Matrix` or `PSDMatrix` as input. |
| 48 | +2. `Foo(WeightedAggregator or GramianWeightedAggregator, ...)` — the public-facing aggregator |
| 49 | + wrapping the weighting. |
| 50 | + |
| 51 | +**Exception:** if the method is a *modifier* that wraps any existing weighting (like CR-MOGM), |
| 52 | +ship only the `*Weighting` class. Do not create a convenience aggregator — the user composes |
| 53 | +it themselves with `WeightedAggregator` or `GramianWeightedAggregator`. |
| 54 | + |
| 55 | +--- |
| 56 | + |
| 57 | +## Stateful Aggregators |
| 58 | + |
| 59 | +If the weighting maintains state across calls (e.g. EMA of weights, accumulated gradient |
| 60 | +history, warm-started weights), inherit from `Stateful` and implement `reset()`. |
| 61 | + |
| 62 | +```python |
| 63 | +from torchjd._mixins import Stateful |
| 64 | + |
| 65 | +class _FooWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): |
| 66 | + def __init__(self, ...): |
| 67 | + super().__init__() |
| 68 | + # Register ALL state tensors as buffers — never as plain Python attributes. |
| 69 | + # register_buffer ensures .to(device) moves them correctly. |
| 70 | + self.register_buffer("_my_state", None) # None = lazily initialized |
| 71 | + self._state_key: tuple[...] | None = None # plain attribute, not a tensor |
| 72 | + |
| 73 | + def reset(self) -> None: |
| 74 | + """Clears all state so the next forward starts fresh.""" |
| 75 | + self._my_state = None |
| 76 | + self._state_key = None |
| 77 | +``` |
| 78 | + |
| 79 | +**State key:** use `(m, dtype, device)` when state shape depends only on `m` (number of |
| 80 | +tasks). Use `(m, n, dtype, device)` when state shape is `[m, n]` (e.g. ExcessMTL's |
| 81 | +`_grad_sum`). Auto-reset state when the key changes — never raise an error. |
| 82 | + |
| 83 | +**Lazy initialisation:** since `m` (and sometimes `n`) is only known at `forward` time, |
| 84 | +initialize state tensors in `_ensure_state`, not in `__init__`. |
| 85 | + |
| 86 | +```python |
| 87 | +def _ensure_state(self, matrix: Matrix) -> None: |
| 88 | + key = (matrix.shape[0], matrix.dtype, matrix.device) |
| 89 | + if self._state_key == key and self._my_state is not None: |
| 90 | + return |
| 91 | + m = matrix.shape[0] |
| 92 | + self._my_state = matrix.new_zeros(m) # or new_full, etc. |
| 93 | + self._state_key = key |
| 94 | +``` |
| 95 | + |
| 96 | +--- |
| 97 | + |
| 98 | +## Non-Differentiable Weightings |
| 99 | + |
| 100 | +If the weights are computed from detached statistics (no gradient should flow through the |
| 101 | +weighting), inherit from `_NonDifferentiable`: |
| 102 | + |
| 103 | +```python |
| 104 | +from torchjd.aggregation._mixins import _NonDifferentiable |
| 105 | +``` |
| 106 | + |
| 107 | +This is the case for most stateful aggregators (GradVac, NashMTL, CR-MOGM, MoDo, SDMGrad, |
| 108 | +ExcessMTL). The `_NonDifferentiable` mixin registers a backward hook on the Aggregator that |
| 109 | +raises a clear error if a user accidentally tries to backprop through the weighting. |
| 110 | + |
| 111 | +--- |
| 112 | + |
| 113 | +## Property-Based Validation |
| 114 | + |
| 115 | +All constructor parameters must be exposed as properties with validating setters. This |
| 116 | +allows safe mutation after construction and gives immediate, clear error messages. |
| 117 | + |
| 118 | +```python |
| 119 | +@property |
| 120 | +def alpha(self) -> float: |
| 121 | + return self._alpha |
| 122 | + |
| 123 | +@alpha.setter |
| 124 | +def alpha(self, value: float) -> None: |
| 125 | + if not (0.0 <= value <= 1.0): |
| 126 | + raise ValueError( |
| 127 | + f"Attribute `alpha` must be in [0, 1]. Found alpha={value!r}." |
| 128 | + ) |
| 129 | + self._alpha = value |
| 130 | +``` |
| 131 | + |
| 132 | +Common constraints: |
| 133 | +- Step sizes / learning rates: `> 0` |
| 134 | +- Momentum: `in [0, 1)` |
| 135 | +- Regularisation coefficients: `>= 0` |
| 136 | +- Iteration counts: `>= 1` (int) |
| 137 | +- Preference vectors: `ndim == 1`, non-negative, sums to 1 |
| 138 | + |
| 139 | +--- |
| 140 | + |
| 141 | +## `set_losses` Setter |
| 142 | + |
| 143 | +If the method needs raw loss values (not just the Jacobian), expose them via a setter |
| 144 | +rather than passing them to `forward`. CONTRIBUTING.md explicitly anticipates this pattern. |
| 145 | + |
| 146 | +```python |
| 147 | +def set_losses(self, losses: Tensor) -> None: |
| 148 | + """Must be called before each forward with the current per-task losses.""" |
| 149 | + self._losses = losses.detach() |
| 150 | +``` |
| 151 | + |
| 152 | +Document clearly in the docstring that users must call `set_losses` before each training |
| 153 | +step. |
| 154 | + |
| 155 | +--- |
| 156 | + |
| 157 | +## `__repr__` |
| 158 | + |
| 159 | +Override `__repr__` to show all hyperparameters. Do not override `__str__` — `Aggregator` |
| 160 | +already defines `__str__` to return just the class name, and `Weighting` inherits |
| 161 | +`nn.Module.__repr__` as its `__str__`. Concrete example: |
| 162 | + |
| 163 | +```python |
| 164 | +def __repr__(self) -> str: |
| 165 | + return ( |
| 166 | + f"{self.__class__.__name__}(" |
| 167 | + f"alpha={self.alpha!r}, " |
| 168 | + f"rho={self.rho!r})" |
| 169 | + ) |
| 170 | +``` |
| 171 | + |
| 172 | +--- |
| 173 | + |
| 174 | +## Attribution |
| 175 | + |
| 176 | +If any code is adapted from an external implementation, add: |
| 177 | +1. A comment at the top of `_foo.py`: |
| 178 | + ```python |
| 179 | + # Partly adapted from https://github.com/... — MIT License, Copyright (c) ... |
| 180 | + # See NOTICES for the full license text. |
| 181 | + ``` |
| 182 | +2. An entry in the `NOTICES` file following the existing template. |
| 183 | + |
| 184 | +--- |
| 185 | + |
| 186 | +## Files to Create or Modify |
| 187 | + |
| 188 | +| File | Action | |
| 189 | +|---|---| |
| 190 | +| `src/torchjd/aggregation/_foo.py` | Create — the implementation | |
| 191 | +| `src/torchjd/aggregation/__init__.py` | Add import + `__all__` entry (alphabetical) | |
| 192 | +| `tests/unit/aggregation/test_foo.py` | Create — mirror `test_modo.py` or `test_excess_mtl.py` | |
| 193 | +| `docs/source/docs/aggregation/foo.rst` | Create — mirror `modo.rst` | |
| 194 | +| `docs/source/docs/aggregation/index.rst` | Add `foo.rst` to toctree (alphabetical) | |
| 195 | +| `CHANGELOG.md` | Add entry under `[Unreleased] → Added` | |
| 196 | +| `NOTICES` | Add entry if code is adapted from an external source | |
| 197 | + |
| 198 | +--- |
| 199 | + |
| 200 | +## Tests |
| 201 | + |
| 202 | +Mirror `tests/unit/aggregation/test_modo.py` (for matrix-weighting) or |
| 203 | +`tests/unit/aggregation/test_excess_mtl.py` (for stateful + set_losses). At minimum cover: |
| 204 | + |
| 205 | +- `test_representations` — verify `repr(...)` string exactly |
| 206 | +- `test_expected_structure_*` — use `assert_expected_structure` from `_asserts.py`, |
| 207 | + parametrize over `typical_matrices + scaled_matrices` from `_inputs.py` |
| 208 | +- `test_reset_restores_first_step_behavior` — call → call → `reset()` → call; third == first |
| 209 | +- setter tests — `*_accepts_valid` and `*_rejects_*` for every hyperparameter |
| 210 | +- `test_output_lies_on_simplex` — returned weights sum to 1 and are ≥ 0 |
| 211 | +- `test_update_recurrence` — manually verify the formula for one step |
| 212 | +- `test_two_consecutive_steps` — verify warm-start carry-over if stateful |
| 213 | +- `test_changing_m_auto_resets` — state resets when number of tasks changes |
| 214 | +- `test_non_differentiable` — weights have no grad if `_NonDifferentiable` |
| 215 | +- `test_zero_columns` — `(m, 0)` input → output shape `(0,)` |
| 216 | + |
| 217 | +**Always use `utils.tensors` partials** (`randn_`, `tensor_`, `ones_`, etc.) — never raw |
| 218 | +`torch.*`. This ensures tests run on CUDA/float64 via environment variables. |
| 219 | + |
| 220 | +--- |
| 221 | + |
| 222 | +## Examples by Pattern |
| 223 | + |
| 224 | +### Simple stateful wrapper (CR-MOGM) |
| 225 | +Wraps any `Weighting[_T]` generically, applies EMA to the output weights. |
| 226 | +`_CRMOGMWeighting(Weighting[_T], Stateful)` — generic over `_T`, no convenience aggregator. |
| 227 | +State: `_lambda: Tensor | None`, `_initial_weights: Tensor | None`. |
| 228 | + |
| 229 | +### Cross-Gramian / double-sampling (MoDo, SDMGrad) |
| 230 | +Receives `A = J_1 @ J_2^T` (not PSD → use `_MatrixWeighting`, NOT `_GramianWeighting`). |
| 231 | +Users compute `A` via `autojac.jac` on two independent mini-batches. |
| 232 | +State: `_w: Tensor | None` (warm-started weights). |
| 233 | +Returns `w + λ·w̃` (SDMGrad) or `w` (MoDo) normalized to sum = 1. |
| 234 | + |
| 235 | +### Stateful with Jacobian-sized state (ExcessMTL) |
| 236 | +`_ExcessMTLWeighting(_MatrixWeighting, Stateful, _NonDifferentiable)`. |
| 237 | +State key: `(m, n, dtype, device)` — because `_grad_sum` has shape `[m, n]`. |
| 238 | +Memory warning: `_grad_sum` is Jacobian-sized, held persistently. Document this. |
| 239 | +Uses `set_losses` if loss values are needed (ExcessMTL does not, but GradNorm would). |
0 commit comments