Skip to content

Commit 2916067

Browse files
committed
chore: Add aggregators reference for implement-method skill
1 parent f5e67dc commit 2916067

1 file changed

Lines changed: 239 additions & 0 deletions

File tree

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
# Aggregators Reference
2+
3+
This document describes the architecture, conventions, and patterns for implementing a new
4+
aggregator in `torchjd.aggregation`. Read this alongside the existing implementations
5+
(`_cr_mogm.py`, `_modo.py`, `_sdmgrad.py`, `_excess_mtl.py`, `_gradvac.py`, `_nash_mtl.py`)
6+
before writing any code.
7+
8+
---
9+
10+
## What is an Aggregator?
11+
12+
An `Aggregator` maps a Jacobian matrix `J ∈ ℝ^{m×n}` (m tasks × n parameters) to a single
13+
gradient vector `g ∈ ℝ^n`. It answers: *given the per-task gradients, which single direction
14+
should we update the model parameters in?*
15+
16+
Most aggregators work by computing a weight vector `λ ∈ ℝ^m` and returning `λ @ J`. This is
17+
the `WeightedAggregator` pattern.
18+
19+
---
20+
21+
## Class Hierarchy
22+
23+
```
24+
nn.Module
25+
└── Aggregator # ABC: validates input, calls forward()
26+
├── WeightedAggregator # forward = weighting(J) @ J
27+
└── GramianWeightedAggregator # forward = gramian_weighting(J @ J^T) @ J
28+
# (pre-computes the Gramian before calling weighting)
29+
30+
Weighting[_T] # ABC: takes a statistic _T, returns weights [m]
31+
├── _MatrixWeighting # _T = Matrix → takes raw Jacobian J
32+
└── _GramianWeighting # _T = PSDMatrix → takes Gramian J @ J^T
33+
```
34+
35+
**Key rule:** if your method needs coordinate-wise information (e.g. per-element gradient
36+
history like ExcessMTL) or a cross-Jacobian product `J_1 @ J_2^T` (like MoDo, SDMGrad),
37+
use `_MatrixWeighting`. If your method only needs task-task inner products (the Gramian),
38+
use `_GramianWeighting`.
39+
40+
---
41+
42+
## The Paired Class Convention
43+
44+
Almost every algorithm ships as **two classes**:
45+
46+
1. `_FooWeighting(_MatrixWeighting or _GramianWeighting, ...)` — the core computation.
47+
Prefixed with `_` to signal it is private. Takes either `Matrix` or `PSDMatrix` as input.
48+
2. `Foo(WeightedAggregator or GramianWeightedAggregator, ...)` — the public-facing aggregator
49+
wrapping the weighting.
50+
51+
**Exception:** if the method is a *modifier* that wraps any existing weighting (like CR-MOGM),
52+
ship only the `*Weighting` class. Do not create a convenience aggregator — the user composes
53+
it themselves with `WeightedAggregator` or `GramianWeightedAggregator`.
54+
55+
---
56+
57+
## Stateful Aggregators
58+
59+
If the weighting maintains state across calls (e.g. EMA of weights, accumulated gradient
60+
history, warm-started weights), inherit from `Stateful` and implement `reset()`.
61+
62+
```python
63+
from torchjd._mixins import Stateful
64+
65+
class _FooWeighting(_MatrixWeighting, Stateful, _NonDifferentiable):
66+
def __init__(self, ...):
67+
super().__init__()
68+
# Register ALL state tensors as buffers — never as plain Python attributes.
69+
# register_buffer ensures .to(device) moves them correctly.
70+
self.register_buffer("_my_state", None) # None = lazily initialized
71+
self._state_key: tuple[...] | None = None # plain attribute, not a tensor
72+
73+
def reset(self) -> None:
74+
"""Clears all state so the next forward starts fresh."""
75+
self._my_state = None
76+
self._state_key = None
77+
```
78+
79+
**State key:** use `(m, dtype, device)` when state shape depends only on `m` (number of
80+
tasks). Use `(m, n, dtype, device)` when state shape is `[m, n]` (e.g. ExcessMTL's
81+
`_grad_sum`). Auto-reset state when the key changes — never raise an error.
82+
83+
**Lazy initialisation:** since `m` (and sometimes `n`) is only known at `forward` time,
84+
initialize state tensors in `_ensure_state`, not in `__init__`.
85+
86+
```python
87+
def _ensure_state(self, matrix: Matrix) -> None:
88+
key = (matrix.shape[0], matrix.dtype, matrix.device)
89+
if self._state_key == key and self._my_state is not None:
90+
return
91+
m = matrix.shape[0]
92+
self._my_state = matrix.new_zeros(m) # or new_full, etc.
93+
self._state_key = key
94+
```
95+
96+
---
97+
98+
## Non-Differentiable Weightings
99+
100+
If the weights are computed from detached statistics (no gradient should flow through the
101+
weighting), inherit from `_NonDifferentiable`:
102+
103+
```python
104+
from torchjd.aggregation._mixins import _NonDifferentiable
105+
```
106+
107+
This is the case for most stateful aggregators (GradVac, NashMTL, CR-MOGM, MoDo, SDMGrad,
108+
ExcessMTL). The `_NonDifferentiable` mixin registers a backward hook on the Aggregator that
109+
raises a clear error if a user accidentally tries to backprop through the weighting.
110+
111+
---
112+
113+
## Property-Based Validation
114+
115+
All constructor parameters must be exposed as properties with validating setters. This
116+
allows safe mutation after construction and gives immediate, clear error messages.
117+
118+
```python
119+
@property
120+
def alpha(self) -> float:
121+
return self._alpha
122+
123+
@alpha.setter
124+
def alpha(self, value: float) -> None:
125+
if not (0.0 <= value <= 1.0):
126+
raise ValueError(
127+
f"Attribute `alpha` must be in [0, 1]. Found alpha={value!r}."
128+
)
129+
self._alpha = value
130+
```
131+
132+
Common constraints:
133+
- Step sizes / learning rates: `> 0`
134+
- Momentum: `in [0, 1)`
135+
- Regularisation coefficients: `>= 0`
136+
- Iteration counts: `>= 1` (int)
137+
- Preference vectors: `ndim == 1`, non-negative, sums to 1
138+
139+
---
140+
141+
## `set_losses` Setter
142+
143+
If the method needs raw loss values (not just the Jacobian), expose them via a setter
144+
rather than passing them to `forward`. CONTRIBUTING.md explicitly anticipates this pattern.
145+
146+
```python
147+
def set_losses(self, losses: Tensor) -> None:
148+
"""Must be called before each forward with the current per-task losses."""
149+
self._losses = losses.detach()
150+
```
151+
152+
Document clearly in the docstring that users must call `set_losses` before each training
153+
step.
154+
155+
---
156+
157+
## `__repr__`
158+
159+
Override `__repr__` to show all hyperparameters. Do not override `__str__``Aggregator`
160+
already defines `__str__` to return just the class name, and `Weighting` inherits
161+
`nn.Module.__repr__` as its `__str__`. Concrete example:
162+
163+
```python
164+
def __repr__(self) -> str:
165+
return (
166+
f"{self.__class__.__name__}("
167+
f"alpha={self.alpha!r}, "
168+
f"rho={self.rho!r})"
169+
)
170+
```
171+
172+
---
173+
174+
## Attribution
175+
176+
If any code is adapted from an external implementation, add:
177+
1. A comment at the top of `_foo.py`:
178+
```python
179+
# Partly adapted from https://github.com/... — MIT License, Copyright (c) ...
180+
# See NOTICES for the full license text.
181+
```
182+
2. An entry in the `NOTICES` file following the existing template.
183+
184+
---
185+
186+
## Files to Create or Modify
187+
188+
| File | Action |
189+
|---|---|
190+
| `src/torchjd/aggregation/_foo.py` | Create — the implementation |
191+
| `src/torchjd/aggregation/__init__.py` | Add import + `__all__` entry (alphabetical) |
192+
| `tests/unit/aggregation/test_foo.py` | Create — mirror `test_modo.py` or `test_excess_mtl.py` |
193+
| `docs/source/docs/aggregation/foo.rst` | Create — mirror `modo.rst` |
194+
| `docs/source/docs/aggregation/index.rst` | Add `foo.rst` to toctree (alphabetical) |
195+
| `CHANGELOG.md` | Add entry under `[Unreleased] → Added` |
196+
| `NOTICES` | Add entry if code is adapted from an external source |
197+
198+
---
199+
200+
## Tests
201+
202+
Mirror `tests/unit/aggregation/test_modo.py` (for matrix-weighting) or
203+
`tests/unit/aggregation/test_excess_mtl.py` (for stateful + set_losses). At minimum cover:
204+
205+
- `test_representations` — verify `repr(...)` string exactly
206+
- `test_expected_structure_*` — use `assert_expected_structure` from `_asserts.py`,
207+
parametrize over `typical_matrices + scaled_matrices` from `_inputs.py`
208+
- `test_reset_restores_first_step_behavior` — call → call → `reset()` → call; third == first
209+
- setter tests — `*_accepts_valid` and `*_rejects_*` for every hyperparameter
210+
- `test_output_lies_on_simplex` — returned weights sum to 1 and are ≥ 0
211+
- `test_update_recurrence` — manually verify the formula for one step
212+
- `test_two_consecutive_steps` — verify warm-start carry-over if stateful
213+
- `test_changing_m_auto_resets` — state resets when number of tasks changes
214+
- `test_non_differentiable` — weights have no grad if `_NonDifferentiable`
215+
- `test_zero_columns``(m, 0)` input → output shape `(0,)`
216+
217+
**Always use `utils.tensors` partials** (`randn_`, `tensor_`, `ones_`, etc.) — never raw
218+
`torch.*`. This ensures tests run on CUDA/float64 via environment variables.
219+
220+
---
221+
222+
## Examples by Pattern
223+
224+
### Simple stateful wrapper (CR-MOGM)
225+
Wraps any `Weighting[_T]` generically, applies EMA to the output weights.
226+
`_CRMOGMWeighting(Weighting[_T], Stateful)` — generic over `_T`, no convenience aggregator.
227+
State: `_lambda: Tensor | None`, `_initial_weights: Tensor | None`.
228+
229+
### Cross-Gramian / double-sampling (MoDo, SDMGrad)
230+
Receives `A = J_1 @ J_2^T` (not PSD → use `_MatrixWeighting`, NOT `_GramianWeighting`).
231+
Users compute `A` via `autojac.jac` on two independent mini-batches.
232+
State: `_w: Tensor | None` (warm-started weights).
233+
Returns `w + λ·w̃` (SDMGrad) or `w` (MoDo) normalized to sum = 1.
234+
235+
### Stateful with Jacobian-sized state (ExcessMTL)
236+
`_ExcessMTLWeighting(_MatrixWeighting, Stateful, _NonDifferentiable)`.
237+
State key: `(m, n, dtype, device)` — because `_grad_sum` has shape `[m, n]`.
238+
Memory warning: `_grad_sum` is Jacobian-sized, held persistently. Document this.
239+
Uses `set_losses` if loss values are needed (ExcessMTL does not, but GradNorm would).

0 commit comments

Comments
 (0)