Skip to content

Commit c9cfedb

Browse files
authored
Merge branch 'main' into feat/aggregators-skill
2 parents 2916067 + 1d854be commit c9cfedb

2 files changed

Lines changed: 237 additions & 0 deletions

File tree

skills/implement-method/SKILL.md

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
---
2+
name: implement-method
3+
description: Implements a new method (scalarizer or aggregator) in TorchJD, starting from the research produced by the research-method skill and following the established file-by-file conventions. Use when a contributor wants to add the actual implementation of a scalarizer or aggregator that has already been investigated and listed in the tracking issues.
4+
---
5+
6+
# Implement new method
7+
8+
This skill implements a new method by recovering its research, comparing the paper against the
9+
existing implementations, settling the non-standard parts of its interface, and producing the full
10+
set of TorchJD files (class, docs, tests, changelog) that match the established conventions.
11+
12+
It is the companion of the `research-method` skill: that one investigates a method and records a row
13+
in a tracking issue; this one turns that row into a merged implementation.
14+
15+
**For agents:** invoke as `/implement-method method-name (paper-name)` (e.g.
16+
`/implement-method stch (Smooth Tchebycheff Scalarization for Multi-objective Optimization)`).
17+
If no method name is provided, ask the user for the name of the method and the title of the paper.
18+
19+
**For humans:** follow the numbered steps below to guide your development.
20+
21+
---
22+
23+
## Instructions
24+
25+
### Step 1: Recover the research context
26+
27+
Determine whether the method should be a **scalarizer** or an **aggregator**, then read everything
28+
the `research-method` skill already found about it:
29+
30+
- Scalarizers are tracked in https://github.com/SimplexLab/TorchJD/issues/667, aggregators in
31+
https://github.com/SimplexLab/TorchJD/issues/665. Fetch the relevant issue and find the row for
32+
this method. Read every column:
33+
- **Ref** — the paper (open it; you will need the exact equations / algorithm).
34+
- **Stateful** — whether and how the method holds state.
35+
- **Existing implementations** — links to the official repo (if any) and the best-known
36+
third-party ones (LibMTL, libmoon, pymoo, ...), ideally with the exact file(s) and line(s).
37+
- **Special Remarks** — may link to a full research write-up (e.g. a `claude.ai` share produced by
38+
`research-method`). Read it if present.
39+
- The most valuable inputs are the **non-standard interface aspects** uncovered during research:
40+
statefulness, trainable parameters, randomness, warm-up / history buffers, statistics beyond the
41+
`forward` values (e.g. per-task losses for an aggregator), and preconditions. If these are not
42+
fully captured in the issue, **ask the user to share the `research-method` findings** before
43+
continuing. Do not guess them.
44+
45+
If the method is not in the tracking issue yet, run `research-method` first.
46+
47+
### Step 2: Load the implementation reference for this method type
48+
49+
Read only the reference matching the method type, to keep context focused:
50+
51+
- **Scalarizer** → read `references/scalarizers.md`.
52+
- **Aggregator** → read `references/aggregators.md`.
53+
54+
Each reference lists the exact files to create/edit and the TorchJD-specific conventions, with the
55+
closest existing methods to mirror.
56+
57+
### Step 3: Compare the paper with the existing implementations
58+
59+
Always do this — it is the step we invariably end up needing. Read the relevant equations / the
60+
algorithm box in the paper, then read the official and best-known third-party implementations at the
61+
exact files/lines from the tracking row.
62+
63+
Reconcile any discrepancies between them. The ones that most often bite:
64+
65+
- **Minimization vs maximization.** TorchJD minimizes losses; much MOO/evolutionary work is written
66+
for maximization, with the minimization form buried in a footnote. Find it, and check the sign of
67+
every reference / ideal-point subtraction.
68+
- **Normalization.** A direction or weight vector may be normalized (`w / ‖w‖`) in the code but not
69+
the paper, or vice versa.
70+
- **Dead arguments.** An impl may accept a parameter (e.g. a reference point) yet silently ignore it.
71+
- **Droppable terms.** An `abs` / `clamp` / `max(0, ·)` in the paper may be unnecessary under the
72+
method's preconditions (e.g. non-negative weights); drop it only with a justification.
73+
- **Other:** an extra factor, an init value, a stabilization / epsilon trick.
74+
75+
Decide which to follow, note **why**, and surface the disagreement to the user — the implementation
76+
should be faithful to a clearly-stated source, not an unexplained blend.
77+
78+
### Step 4: Settle the interface and design decisions
79+
80+
Using the research findings (Step 1) and the comparison (Step 3), map each non-standard aspect onto
81+
the closest existing pattern from the reference loaded in Step 2 (statefulness, trainable parameters,
82+
an internal optimizer, a preference/reference vector, ...). Then settle, for any method type:
83+
84+
- **Preconditions** (e.g. positivity): enforce them (raise `ValueError`) or only document them, and
85+
how `nan`/`inf` should propagate.
86+
- Which constructor arguments are **required vs optional**, and their **defaults**.
87+
88+
List the non-standard parts and your proposed handling, and **confirm the design with the user
89+
before writing code.** This is where most of the maintainer review happens, so settle it up front.
90+
91+
### Step 5: Implement the method
92+
93+
Follow the file-by-file checklist in the reference loaded at Step 2. Match the style, naming, and
94+
conventions of the closest existing method. If you adapt code from a third-party repository, add the
95+
license header to the source file and an entry to `NOTICES` (see the reference).
96+
97+
### Step 6: Verify
98+
99+
Run the checks listed in the reference (unit tests with `-W error`, lint, and the docs
100+
build/doctest). GPU tests require a CUDA device; if you cannot run them, provide the exact commands
101+
for the user to run on their GPU and report back the results.
102+
103+
### Step 7: Self-review the code you produced
104+
105+
Before opening anything, re-read your own diff against the requirements and improve what can be
106+
improved. Check that:
107+
108+
- The class follows the closest existing method's conventions (the reference's checklist): correct
109+
base class(es), `forward(self, values, /)` returning a 0-dim scalar, shape validation, `reset()`
110+
for stateful methods, a correct `__repr__`, and the docstring conventions (`r"""` only with LaTeX,
111+
`:class:` cross-ref, `.. math::` + bullet list, a usage doctest, `:param:` for each argument).
112+
- The design decisions settled in Step 4 are actually reflected in the code, and any discrepancy
113+
between the paper and the existing implementations (Step 3) is resolved deliberately, with a
114+
comment or docstring note where it is non-obvious.
115+
- The tests cover the documented edge cases and contracts, not just the happy path.
116+
- All six files are present and consistent (class, `__init__.py`, `.rst`, toctree, test,
117+
`CHANGELOG.md`), plus `NOTICES` + a license header if you adapted code.
118+
119+
Apply the fixes you find, then re-run the relevant checks from Step 6.
120+
121+
### Step 8: Open a draft PR
122+
123+
Create a new branch, commit, and open a **draft** pull request targeting `main`, following the
124+
repository's PR conventions (a `CHANGELOG.md` entry under `[Unreleased] > ### Added`; when asked for
125+
a PR description, output raw GitHub-flavored markdown in a fenced code block, with GitHub math syntax
126+
`$...$` / `$$...$$` and no em dashes). Keep it a draft so the contributor can read the code
127+
themselves before requesting maintainer review. Return the PR URL when done.
128+
129+
---
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Reference: implementing a `Scalarizer`
2+
3+
A `Scalarizer` reduces a tensor of values of any shape into a single scalar — the baseline that
4+
combines *losses* directly (a plain `loss.backward()` then gives the gradient), as opposed to an
5+
`Aggregator` which combines per-loss *gradients*. Base class: `Scalarizer` in
6+
`src/torchjd/scalarization/_scalarizer_base.py`.
7+
8+
**Don't work from this file alone — read the closest existing class end-to-end (its `_*.py` + `.rst`
9+
+ `test_*.py`) and mirror it.** This reference is the map and the non-obvious rules, not a template.
10+
11+
## Contract for the subclass
12+
13+
- Subclasses `Scalarizer` (an `nn.Module`); `forward(self, values: Tensor, /) -> Tensor` returns a
14+
**0-dim** scalar.
15+
- The parameter is named **`values`** (positional-only), not `losses``Scalarizer` is generic
16+
(maintainer decision). Accepts **any shape** and reduces over all elements (flatten if needed).
17+
18+
## Files to create / edit (new scalarizer `Foo`)
19+
20+
1. `src/torchjd/scalarization/_foo.py` — the class.
21+
2. `src/torchjd/scalarization/__init__.py` — add the import + the `__all__` entry.
22+
3. `docs/source/docs/scalarization/foo.rst` — doc page (mirror `geometric_mean.rst`).
23+
4. `docs/source/docs/scalarization/index.rst` — add `foo.rst` to the `.. toctree::`.
24+
5. `tests/unit/scalarization/test_foo.py` — tests.
25+
6. `CHANGELOG.md` — entry under `[Unreleased] > ### Added`.
26+
7. *(Only if you adapt third-party code)* license header in `_foo.py` + an entry in `NOTICES`.
27+
28+
## Pick the pattern and mirror it
29+
30+
| Pattern | Mirror | File |
31+
|---|---|---|
32+
| Stateless one-liner | `GeometricMean`, `Mean`, `Sum` | `_geometric_mean.py`, `_mean.py`, `_sum.py` |
33+
| Stateless + preference/reference vector | `STCH`, `COSMOS`, `PBI` | `_stch.py`, `_cosmos.py`, `_pbi.py` |
34+
| Stateful, trainable parameter | `UW`, `IMTL-L` | `_uw.py`, `_imtl_l.py` |
35+
| Stateful, non-trainable history buffer | `DWA` | `_dwa.py` |
36+
| Internal optimizer + multi-call protocol | `FAMO` | `_famo.py` |
37+
38+
### Pattern-specific rules (the things not obvious from one file)
39+
40+
- **Trainable** (`UW`/`IMTL-L`): also subclass `Stateful` (`from torchjd._mixins import Stateful`)
41+
and implement `reset()`. State is an `nn.Parameter`, init to a neutral default (usually `0`), with
42+
a `shape: int | Sequence[int]` arg (`Foo(3)``(3,)`). Validate `values.shape` at call time
43+
(`ValueError`). The params are in `.parameters()`, so the user passes them to the optimizer — show
44+
this in a doctest. A trained per-position param makes it **not** permutation-invariant; don't
45+
assert it. Add a `shape`-aware `__repr__`.
46+
- **History buffer** (`DWA`): **no** `nn.Parameter` (`list(Foo().parameters())` must be empty); hold
47+
state in a `register_buffer` (moves with `.to()`, can be created lazily from the first input
48+
shape). Provide an explicit update method (e.g. scheduler-like `step()`); `forward` **detaches**
49+
weights derived from the state; `reset()` clears the buffer.
50+
- **Internal optimizer / multi-call** (`FAMO`): private `nn.Parameter` (`_w`) with `.grad` cleared
51+
after each step; a lazily-created internal `torch.optim.Adam`; an `update(new_losses)` method;
52+
`forward` detaches the weights. Read `_famo.py` before copying.
53+
- **Preference / reference vector** (`STCH`/`COSMOS`/`PBI`): validate shapes at call time
54+
(`ValueError`, like `Constant`); flatten `weights`/`values`/`reference` in `forward`. `reference`
55+
(z*) usually defaults to `0`; `weights` is required or uniform per the paper. Watch `nan`-gradient
56+
footguns — `‖x‖` has a `0/0` grad at `0` (use `sqrt(‖x‖² + eps)`, see `PBI`); cosine needs an
57+
eps-clamped denominator (use `torch.nn.functional.cosine_similarity`, see `COSMOS`). Lock with a
58+
test.
59+
60+
## Docstring conventions
61+
62+
- Use a **raw** `r"""` docstring **only** if it contains LaTeX (`:math:` / `.. math::`) so
63+
backslashes stay single; plain `"""` otherwise.
64+
- Start with the `:class:` cross-ref(s) (`:class:`~torchjd.scalarization.Scalarizer``, plus
65+
`:class:`~torchjd.Stateful`` if stateful); link the paper by full title + URL.
66+
- Multi-symbol math → a `.. math::` block + a bullet list defining each symbol (not one dense inline
67+
paragraph; see `STCH`). Document every `:param:`. Add a usage doctest (for stateful methods show
68+
the optimizer / `step()` / `update()` cadence). Note preconditions in `.. note::` and decide
69+
whether to enforce (`ValueError`) or let `nan`/`inf` propagate.
70+
71+
## Tests
72+
73+
Mirror `test_geometric_mean.py` (stateless) or `test_uw.py` (stateful). Shared infra in
74+
`tests/unit/scalarization/`: `_inputs.py` (`shapes = [[], [5], [3, 4], [2, 3, 4]]`, `all_inputs`);
75+
`_asserts.py` (`assert_returns_scalar`, `assert_grad_flow`, `assert_permutation_invariant`);
76+
`utils.tensors` helpers (`tensor_`, `rand_`, `randn_`, `ones_`, `zeros_`, `randperm_` — they respect
77+
`PYTEST_TORCH_DEVICE`/`PYTEST_TORCH_DTYPE`; for stateful instances make a `_foo(shape)` helper that
78+
`.to(device=DEVICE, dtype=DTYPE)`, see `test_uw.py`). Cover: `test_value` (hand-checked),
79+
`test_expected_structure` + `test_grad_flow` (parametrized over shapes), `test_permutation_invariant`
80+
**only if** invariant, the documented edge cases/contracts (e.g. assert `nan` propagates on a bad
81+
input so a future clamp can't slip in; a `does_not_raise()`/`raises(ValueError)` shape table; `reset`
82+
clears state; params train / buffer rolls), and `test_representations`.
83+
84+
## CHANGELOG
85+
86+
`- Added `Foo` from [Paper Title](url) (Venue Year), a `Scalarizer` that <one-line description>.`
87+
88+
## Third-party attribution (only if adapting code, e.g. `FAMO`)
89+
90+
Header comment in `_foo.py`: `# Partly adapted from <url> — <License>, Copyright (c) <year>
91+
<author>. # See NOTICES for the full license text.` plus the full license text in `NOTICES`.
92+
93+
## Verify (from repo root)
94+
95+
```bash
96+
uv run pytest tests/unit/scalarization -W error -v # new tests
97+
uv run pytest tests/unit -W error # full unit regression
98+
uv run ruff check && uv run ruff format --check # lint + format
99+
uv run make doctest -C docs && uv run make clean -C docs && uv run make html -C docs
100+
uv run pre-commit run --all-files
101+
PYTEST_TORCH_DEVICE=cuda:0 uv run pytest tests/unit -W error # GPU (needs CUDA)
102+
```
103+
104+
- If `uv run` re-syncs unexpectedly, prefix with `UV_NO_SYNC=1`. Docs build is strict (`-W -n`), so
105+
an `.rst` title underline must match its title length.
106+
- Treat CI as the source of truth. A pre-existing test unrelated to your change can fail by
107+
a tiny float tolerance on other platforms; confirm your new tests pass and that nothing you
108+
touched regressed, rather than chasing it.

0 commit comments

Comments
 (0)