|
| 1 | +# Reference: implementing a `Scalarizer` |
| 2 | + |
| 3 | +A `Scalarizer` reduces a tensor of values of any shape into a single scalar — the baseline that |
| 4 | +combines *losses* directly (a plain `loss.backward()` then gives the gradient), as opposed to an |
| 5 | +`Aggregator` which combines per-loss *gradients*. Base class: `Scalarizer` in |
| 6 | +`src/torchjd/scalarization/_scalarizer_base.py`. |
| 7 | + |
| 8 | +**Don't work from this file alone — read the closest existing class end-to-end (its `_*.py` + `.rst` |
| 9 | ++ `test_*.py`) and mirror it.** This reference is the map and the non-obvious rules, not a template. |
| 10 | + |
| 11 | +## Contract for the subclass |
| 12 | + |
| 13 | +- Subclasses `Scalarizer` (an `nn.Module`); `forward(self, values: Tensor, /) -> Tensor` returns a |
| 14 | + **0-dim** scalar. |
| 15 | +- The parameter is named **`values`** (positional-only), not `losses` — `Scalarizer` is generic |
| 16 | + (maintainer decision). Accepts **any shape** and reduces over all elements (flatten if needed). |
| 17 | + |
| 18 | +## Files to create / edit (new scalarizer `Foo`) |
| 19 | + |
| 20 | +1. `src/torchjd/scalarization/_foo.py` — the class. |
| 21 | +2. `src/torchjd/scalarization/__init__.py` — add the import + the `__all__` entry. |
| 22 | +3. `docs/source/docs/scalarization/foo.rst` — doc page (mirror `geometric_mean.rst`). |
| 23 | +4. `docs/source/docs/scalarization/index.rst` — add `foo.rst` to the `.. toctree::`. |
| 24 | +5. `tests/unit/scalarization/test_foo.py` — tests. |
| 25 | +6. `CHANGELOG.md` — entry under `[Unreleased] > ### Added`. |
| 26 | +7. *(Only if you adapt third-party code)* license header in `_foo.py` + an entry in `NOTICES`. |
| 27 | + |
| 28 | +## Pick the pattern and mirror it |
| 29 | + |
| 30 | +| Pattern | Mirror | File | |
| 31 | +|---|---|---| |
| 32 | +| Stateless one-liner | `GeometricMean`, `Mean`, `Sum` | `_geometric_mean.py`, `_mean.py`, `_sum.py` | |
| 33 | +| Stateless + preference/reference vector | `STCH`, `COSMOS`, `PBI` | `_stch.py`, `_cosmos.py`, `_pbi.py` | |
| 34 | +| Stateful, trainable parameter | `UW`, `IMTL-L` | `_uw.py`, `_imtl_l.py` | |
| 35 | +| Stateful, non-trainable history buffer | `DWA` | `_dwa.py` | |
| 36 | +| Internal optimizer + multi-call protocol | `FAMO` | `_famo.py` | |
| 37 | + |
| 38 | +### Pattern-specific rules (the things not obvious from one file) |
| 39 | + |
| 40 | +- **Trainable** (`UW`/`IMTL-L`): also subclass `Stateful` (`from torchjd._mixins import Stateful`) |
| 41 | + and implement `reset()`. State is an `nn.Parameter`, init to a neutral default (usually `0`), with |
| 42 | + a `shape: int | Sequence[int]` arg (`Foo(3)` → `(3,)`). Validate `values.shape` at call time |
| 43 | + (`ValueError`). The params are in `.parameters()`, so the user passes them to the optimizer — show |
| 44 | + this in a doctest. A trained per-position param makes it **not** permutation-invariant; don't |
| 45 | + assert it. Add a `shape`-aware `__repr__`. |
| 46 | +- **History buffer** (`DWA`): **no** `nn.Parameter` (`list(Foo().parameters())` must be empty); hold |
| 47 | + state in a `register_buffer` (moves with `.to()`, can be created lazily from the first input |
| 48 | + shape). Provide an explicit update method (e.g. scheduler-like `step()`); `forward` **detaches** |
| 49 | + weights derived from the state; `reset()` clears the buffer. |
| 50 | +- **Internal optimizer / multi-call** (`FAMO`): private `nn.Parameter` (`_w`) with `.grad` cleared |
| 51 | + after each step; a lazily-created internal `torch.optim.Adam`; an `update(new_losses)` method; |
| 52 | + `forward` detaches the weights. Read `_famo.py` before copying. |
| 53 | +- **Preference / reference vector** (`STCH`/`COSMOS`/`PBI`): validate shapes at call time |
| 54 | + (`ValueError`, like `Constant`); flatten `weights`/`values`/`reference` in `forward`. `reference` |
| 55 | + (z*) usually defaults to `0`; `weights` is required or uniform per the paper. Watch `nan`-gradient |
| 56 | + footguns — `‖x‖` has a `0/0` grad at `0` (use `sqrt(‖x‖² + eps)`, see `PBI`); cosine needs an |
| 57 | + eps-clamped denominator (use `torch.nn.functional.cosine_similarity`, see `COSMOS`). Lock with a |
| 58 | + test. |
| 59 | + |
| 60 | +## Docstring conventions |
| 61 | + |
| 62 | +- Use a **raw** `r"""` docstring **only** if it contains LaTeX (`:math:` / `.. math::`) so |
| 63 | + backslashes stay single; plain `"""` otherwise. |
| 64 | +- Start with the `:class:` cross-ref(s) (`:class:`~torchjd.scalarization.Scalarizer``, plus |
| 65 | + `:class:`~torchjd.Stateful`` if stateful); link the paper by full title + URL. |
| 66 | +- Multi-symbol math → a `.. math::` block + a bullet list defining each symbol (not one dense inline |
| 67 | + paragraph; see `STCH`). Document every `:param:`. Add a usage doctest (for stateful methods show |
| 68 | + the optimizer / `step()` / `update()` cadence). Note preconditions in `.. note::` and decide |
| 69 | + whether to enforce (`ValueError`) or let `nan`/`inf` propagate. |
| 70 | + |
| 71 | +## Tests |
| 72 | + |
| 73 | +Mirror `test_geometric_mean.py` (stateless) or `test_uw.py` (stateful). Shared infra in |
| 74 | +`tests/unit/scalarization/`: `_inputs.py` (`shapes = [[], [5], [3, 4], [2, 3, 4]]`, `all_inputs`); |
| 75 | +`_asserts.py` (`assert_returns_scalar`, `assert_grad_flow`, `assert_permutation_invariant`); |
| 76 | +`utils.tensors` helpers (`tensor_`, `rand_`, `randn_`, `ones_`, `zeros_`, `randperm_` — they respect |
| 77 | +`PYTEST_TORCH_DEVICE`/`PYTEST_TORCH_DTYPE`; for stateful instances make a `_foo(shape)` helper that |
| 78 | +`.to(device=DEVICE, dtype=DTYPE)`, see `test_uw.py`). Cover: `test_value` (hand-checked), |
| 79 | +`test_expected_structure` + `test_grad_flow` (parametrized over shapes), `test_permutation_invariant` |
| 80 | +**only if** invariant, the documented edge cases/contracts (e.g. assert `nan` propagates on a bad |
| 81 | +input so a future clamp can't slip in; a `does_not_raise()`/`raises(ValueError)` shape table; `reset` |
| 82 | +clears state; params train / buffer rolls), and `test_representations`. |
| 83 | + |
| 84 | +## CHANGELOG |
| 85 | + |
| 86 | +`- Added `Foo` from [Paper Title](url) (Venue Year), a `Scalarizer` that <one-line description>.` |
| 87 | + |
| 88 | +## Third-party attribution (only if adapting code, e.g. `FAMO`) |
| 89 | + |
| 90 | +Header comment in `_foo.py`: `# Partly adapted from <url> — <License>, Copyright (c) <year> |
| 91 | +<author>. # See NOTICES for the full license text.` plus the full license text in `NOTICES`. |
| 92 | + |
| 93 | +## Verify (from repo root) |
| 94 | + |
| 95 | +```bash |
| 96 | +uv run pytest tests/unit/scalarization -W error -v # new tests |
| 97 | +uv run pytest tests/unit -W error # full unit regression |
| 98 | +uv run ruff check && uv run ruff format --check # lint + format |
| 99 | +uv run make doctest -C docs && uv run make clean -C docs && uv run make html -C docs |
| 100 | +uv run pre-commit run --all-files |
| 101 | +PYTEST_TORCH_DEVICE=cuda:0 uv run pytest tests/unit -W error # GPU (needs CUDA) |
| 102 | +``` |
| 103 | + |
| 104 | +- If `uv run` re-syncs unexpectedly, prefix with `UV_NO_SYNC=1`. Docs build is strict (`-W -n`), so |
| 105 | + an `.rst` title underline must match its title length. |
| 106 | +- `test_dualproj.py::test_permutation_invariant` and `test_upgrad.py::test_permutation_invariant` |
| 107 | + are known flaky off-Linux (~1 float32 ULP, quadprog), pre-existing and unrelated. CI (Linux) is |
| 108 | + the source of truth. |
0 commit comments