Commit f58d00d

implement skill
1 parent 89b1782 commit f58d00d

2 files changed

Lines changed: 213 additions & 0 deletions

skills/implement-method/SKILL.md

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
---
name: implement-method
description: Implements a new method (scalarizer or aggregator) in TorchJD, starting from the research produced by the research-method skill and following the established file-by-file conventions. Use when a contributor wants to add the actual implementation of a scalarizer or aggregator that has already been investigated and listed in the tracking issues.
---

# Implement new method

This skill implements a new method by recovering its research, comparing the paper against the
existing implementations, settling the non-standard parts of its interface, and producing the full
set of TorchJD files (class, docs, tests, changelog) that match the established conventions.

It is the companion of the `research-method` skill: that one investigates a method and records a row
in a tracking issue; this one turns that row into a merged implementation.

**For agents:** invoke as `/implement-method method-name (paper-name)` (e.g.
`/implement-method stch (Smooth Tchebycheff Scalarization for Multi-objective Optimization)`).
If no method name is provided, ask the user for the name of the method and the title of the paper.

**For humans:** follow the numbered steps below to guide your development.

---

## Instructions

### Step 1: Recover the research context

Determine whether the method should be a **scalarizer** or an **aggregator**, then read everything
the `research-method` skill already found about it:

- Scalarizers are tracked in https://github.com/SimplexLab/TorchJD/issues/667, aggregators in
  https://github.com/SimplexLab/TorchJD/issues/665. Fetch the relevant issue and find the row for
  this method. Read every column:
  - **Ref** — the paper (open it; you will need the exact equations / algorithm).
  - **Stateful** — whether and how the method holds state.
  - **Existing implementations** — links to the official repo (if any) and the best-known
    third-party ones (LibMTL, libmoon, pymoo, ...), ideally with the exact file(s) and line(s).
  - **Special Remarks** — may link to a full research write-up (e.g. a `claude.ai` share produced by
    `research-method`). Read it if present.
- The most valuable inputs are the **non-standard interface aspects** uncovered during research:
  statefulness, trainable parameters, randomness, warm-up / history buffers, statistics beyond the
  `forward` values (e.g. per-task losses for an aggregator), and preconditions. If these are not
  fully captured in the issue, **ask the user to share the `research-method` findings** before
  continuing. Do not guess them.

If the method is not in the tracking issue yet, run `research-method` first.

### Step 2: Load the implementation reference for this method type

Read only the reference matching the method type, to keep context focused:

- **Scalarizer** → read `references/scalarizers.md`.
- **Aggregator** → read `references/aggregators.md`.

Each reference lists the exact files to create/edit and the TorchJD-specific conventions, with the
closest existing methods to mirror.

### Step 3: Compare the paper with the existing implementations

Always do this — it is the step we invariably end up needing. Read the relevant equations / the
algorithm box in the paper, then read the official and best-known third-party implementations at the
exact files/lines from the tracking row.

Reconcile any discrepancies between them (a different normalization, an extra factor, an
initialization value, a stabilization trick, a sign convention). Decide which version to follow,
note **why**, and surface the disagreement to the user. The implementation should be faithful to a
clearly-stated source, not an unexplained blend.

### Step 4: Settle the interface and design decisions

Using the research findings (Step 1) and the comparison (Step 3), decide how each non-standard
aspect maps onto TorchJD, reusing the closest existing pattern:

- **Stateful, trainable** parameter(s) → `nn.Parameter` + the `Stateful` mixin's `reset()` (mirror
  `UW` / `IMTL-L`).
- **Stateful, non-trainable** history/buffer → a registered buffer + an explicit update method (e.g.
  `step()`), with no `nn.Parameter` (mirror `DWA`).
- **Internal optimizer / multi-call update protocol** (mirror `FAMO`).
- **Preference / reference vector** argument (mirror `STCH` / `COSMOS` / `PBI`).
- **Preconditions** (e.g. positivity): decide whether to enforce them (raise `ValueError`) or only
  document them, and how `nan`/`inf` should propagate.
- Which constructor arguments are **required vs optional**, and their **defaults**.

List the non-standard parts and your proposed handling, and **confirm the design with the user
before writing code.** This is where most of the maintainer review happens, so settle it up front.
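
As a rough illustration of the first two bullets of the list above, the sketch below contrasts a trainable state (`nn.Parameter`) with a non-trainable one (a registered buffer plus an explicit `step()`). It uses plain `torch.nn.Module` and the hypothetical class names `TrainableWeights` and `HistoryWeights`; it does not import TorchJD, and the weighting formulas are placeholders, not those of `UW` or `DWA`.

```python
import torch
from torch import Tensor, nn


class TrainableWeights(nn.Module):
    """Hypothetical stand-in: the state is an nn.Parameter trained by the user's optimizer."""

    def __init__(self, n: int) -> None:
        super().__init__()
        self.log_weights = nn.Parameter(torch.zeros(n))  # neutral default

    def reset(self) -> None:
        # Restore the initial state in place, keeping the same Parameter object.
        with torch.no_grad():
            self.log_weights.zero_()

    def forward(self, values: Tensor, /) -> Tensor:
        return (torch.exp(self.log_weights) * values.flatten()).sum()


class HistoryWeights(nn.Module):
    """Hypothetical stand-in: the state is a buffer, updated explicitly and never trained."""

    def __init__(self, n: int) -> None:
        super().__init__()
        self.register_buffer("previous", torch.ones(n))

    def step(self, values: Tensor) -> None:
        # Explicit, scheduler-like update called by the user.
        self.previous.copy_(values.detach().flatten())

    def reset(self) -> None:
        self.previous.fill_(1.0)

    def forward(self, values: Tensor, /) -> Tensor:
        flat = values.flatten()
        # Placeholder weighting; detached so no gradient flows through the weights.
        weights = (flat.detach() / self.previous).softmax(dim=0)
        return (weights * flat).sum()


trainable = TrainableWeights(3)
history = HistoryWeights(3)
# The trainable variant exposes its state to an optimizer; the buffer variant does not.
print(len(list(trainable.parameters())), len(list(history.parameters())))
```

Listing the parameters is a quick way to confirm which pattern a class actually follows before the design is presented to the user.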

### Step 5: Implement the method

Follow the file-by-file checklist in the reference loaded at Step 2. Match the style, naming, and
conventions of the closest existing method. If you adapt code from a third-party repository, add the
license header to the source file and an entry to `NOTICES` (see the reference).

### Step 6: Verify

Run the checks listed in the reference (unit tests with `-W error`, lint, type-check, and the docs
build/doctest). GPU tests require a CUDA device; if you cannot run them, provide the exact commands
for the user to run on their GPU and report back the results.

### Step 7: Open the PR

Create a new branch, commit, and open a pull request targeting `main`, following the repository's
PR conventions (a `CHANGELOG.md` entry under `[Unreleased] > ### Added`; when asked for a PR
description, output raw GitHub-flavored markdown in a fenced code block, with GitHub math syntax
`$...$` / `$$...$$` and no em dashes). Return the PR URL when done.

---
Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
# Reference: implementing a `Scalarizer`

A `Scalarizer` reduces a tensor of values of any shape into a single scalar — the baseline that
combines *losses* directly (a plain `loss.backward()` then gives the gradient), as opposed to an
`Aggregator` which combines per-loss *gradients*. Base class: `Scalarizer` in
`src/torchjd/scalarization/_scalarizer_base.py`.

**Don't work from this file alone — read the closest existing class end-to-end (its `_*.py` + `.rst` +
`test_*.py`) and mirror it.** This reference is the map and the non-obvious rules, not a template.

## Contract for the subclass

- Subclasses `Scalarizer` (an `nn.Module`); `forward(self, values: Tensor, /) -> Tensor` returns a
  **0-dim** scalar.
- The parameter is named **`values`** (positional-only), not `losses`, because `Scalarizer` is generic
  (maintainer decision). Accepts **any shape** and reduces over all elements (flatten if needed).
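
To make the contract concrete, here is a minimal sketch of a stateless scalarizer. It subclasses `torch.nn.Module` directly as a stand-in so that it runs without TorchJD installed; a real implementation would subclass the `Scalarizer` base class named above. The class name `SumOfSquares` is hypothetical.

```python
import torch
from torch import Tensor, nn


class SumOfSquares(nn.Module):  # hypothetical; a real one would subclass Scalarizer
    """Reduces a tensor of any shape to a 0-dim scalar."""

    def forward(self, values: Tensor, /) -> Tensor:
        # Flatten so that 0-dim, 1-dim and higher-dim inputs are all handled alike.
        return values.flatten().pow(2).sum()


scalarizer = SumOfSquares()
for shape in ([], [5], [3, 4], [2, 3, 4]):
    values = torch.rand(shape, requires_grad=True)
    result = scalarizer(values)
    assert result.dim() == 0  # always a 0-dim scalar
    result.backward()  # a plain backward gives the gradient w.r.t. the values
    assert values.grad.shape == values.shape
```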

## Files to create / edit (new scalarizer `Foo`)

1. `src/torchjd/scalarization/_foo.py` — the class.
2. `src/torchjd/scalarization/__init__.py` — add the import + the `__all__` entry.
3. `docs/source/docs/scalarization/foo.rst` — doc page (mirror `geometric_mean.rst`).
4. `docs/source/docs/scalarization/index.rst` — add `foo.rst` to the `.. toctree::`.
5. `tests/unit/scalarization/test_foo.py` — tests.
6. `CHANGELOG.md` — entry under `[Unreleased] > ### Added`.
7. *(Only if you adapt third-party code)* license header in `_foo.py` + an entry in `NOTICES`.

## Pick the pattern and mirror it

| Pattern | Mirror | File |
|---|---|---|
| Stateless one-liner | `GeometricMean`, `Mean`, `Sum` | `_geometric_mean.py`, `_mean.py`, `_sum.py` |
| Stateless + preference/reference vector | `STCH`, `COSMOS`, `PBI` | `_stch.py`, `_cosmos.py`, `_pbi.py` |
| Stateful, trainable parameter | `UW`, `IMTL-L` | `_uw.py`, `_imtl_l.py` |
| Stateful, non-trainable history buffer | `DWA` | `_dwa.py` |
| Internal optimizer + multi-call protocol | `FAMO` | `_famo.py` |

### Pattern-specific rules (the things not obvious from one file)

- **Trainable** (`UW`/`IMTL-L`): also subclass `Stateful` (`from torchjd._mixins import Stateful`)
  and implement `reset()`. State is an `nn.Parameter`, init to a neutral default (usually `0`), with
  a `shape: int | Sequence[int]` arg (`Foo(3)` gives shape `(3,)`). Validate `values.shape` at call
  time (`ValueError`). The params are in `.parameters()`, so the user passes them to the optimizer —
  show this in a doctest. A trained per-position param makes it **not** permutation-invariant; don't
  assert it. Add a `shape`-aware `__repr__`.
- **History buffer** (`DWA`): **no** `nn.Parameter` (`list(Foo().parameters())` must be empty); hold
  state in a `register_buffer` (moves with `.to()`, can be created lazily from the first input
  shape). Provide an explicit update method (e.g. scheduler-like `step()`); `forward` **detaches**
  weights derived from the state; `reset()` clears the buffer.
- **Internal optimizer / multi-call** (`FAMO`): private `nn.Parameter` (`_w`) with `.grad` cleared
  after each step; a lazily-created internal `torch.optim.Adam`; an `update(new_losses)` method;
  `forward` detaches the weights. Read `_famo.py` before copying.
- **Preference / reference vector** (`STCH`/`COSMOS`/`PBI`): validate shapes at call time
  (`ValueError`, like `Constant`); flatten `weights`/`values`/`reference` in `forward`. `reference`
  (z*) usually defaults to `0`; `weights` is required or uniform per the paper. Watch `nan`-gradient
  footguns — `‖x‖` has a `0/0` grad at `0` (use `sqrt(‖x‖² + eps)`, see `PBI`); cosine needs an
  eps-clamped denominator (use `torch.nn.functional.cosine_similarity`, see `COSMOS`). Lock with a
  test.
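
The `0/0` gradient mentioned in the last bullet is easy to reproduce. The sketch below is an illustration, not code taken from `PBI`: it computes the norm naively as the square root of a sum of squares, and the value of `eps` is an arbitrary assumption.

```python
import torch

eps = 1e-12  # assumed value, for illustration only

# Naive norm: sqrt has an infinite derivative at 0, which multiplies the zero
# gradient of the sum of squares and yields nan.
x = torch.zeros(3, requires_grad=True)
x.pow(2).sum().sqrt().backward()
print(x.grad)  # every entry is nan

# Smoothed norm: the derivative of sqrt stays finite, so the gradient is 0.
y = torch.zeros(3, requires_grad=True)
(y.pow(2).sum() + eps).sqrt().backward()
print(y.grad)  # every entry is 0
```

A test that evaluates the gradient exactly at the reference point is what keeps a later refactor from reintroducing the naive form.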

## Docstring conventions

- Use a **raw** `r"""` docstring **only** if it contains LaTeX (`:math:` / `.. math::`) so
  backslashes stay single; plain `"""` otherwise.
- Start with the `:class:` cross-ref(s) (`:class:`~torchjd.scalarization.Scalarizer``, plus
  `:class:`~torchjd.Stateful`` if stateful); link the paper by full title + URL.
- Multi-symbol math → a `.. math::` block + a bullet list defining each symbol (not one dense inline
  paragraph; see `STCH`). Document every `:param:`. Add a usage doctest (for stateful methods show
  the optimizer / `step()` / `update()` cadence). Note preconditions in `.. note::` and decide
  whether to enforce (`ValueError`) or let `nan`/`inf` propagate.
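
A short illustration of why the raw prefix matters for LaTeX. The two functions are hypothetical and exist only so that their docstrings can be compared.

```python
def with_raw() -> None:
    r"""Computes :math:`\alpha \cdot \ell`."""


def without_raw() -> None:
    """Computes :math:`\alpha`."""  # "\a" is silently read as the bell character


print("\\alpha" in with_raw.__doc__)     # True: backslashes are kept as written
print("\\alpha" in without_raw.__doc__)  # False: "\a" collapsed into one control character
```

Other sequences such as `\c` are not valid escapes at all and trigger a warning in a plain string, which a run with `-W error` would turn into a failure.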

## Tests

Mirror `test_geometric_mean.py` (stateless) or `test_uw.py` (stateful). Shared infra in
`tests/unit/scalarization/`: `_inputs.py` (`shapes = [[], [5], [3, 4], [2, 3, 4]]`, `all_inputs`);
`_asserts.py` (`assert_returns_scalar`, `assert_grad_flow`, `assert_permutation_invariant`);
`utils.tensors` helpers (`tensor_`, `rand_`, `randn_`, `ones_`, `zeros_`, `randperm_` — they respect
`PYTEST_TORCH_DEVICE`/`PYTEST_TORCH_DTYPE`; for stateful instances make a `_foo(shape)` helper that
`.to(device=DEVICE, dtype=DTYPE)`, see `test_uw.py`). Cover: `test_value` (hand-checked),
`test_expected_structure` + `test_grad_flow` (parametrized over shapes), `test_permutation_invariant`
**only if** invariant, the documented edge cases/contracts (e.g. assert `nan` propagates on a bad
input so a future clamp can't slip in; a `does_not_raise()`/`raises(ValueError)` shape table; `reset`
clears state; params train / buffer rolls), and `test_representations`.
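
The coverage list above can be sketched as follows. This is only an outline of the kinds of checks involved: it uses plain `assert` statements on a hypothetical `LogMean` module instead of the shared helpers in `_asserts.py` and `utils.tensors`, whose signatures are not reproduced here.

```python
import math

import torch
from torch import Tensor, nn


class LogMean(nn.Module):  # hypothetical scalarizer, defined only to have something to test
    def forward(self, values: Tensor, /) -> Tensor:
        return values.flatten().log().mean()


def test_value() -> None:
    # Hand-checked: (log(1) + log(e^2)) / 2 = (0 + 2) / 2 = 1.
    result = LogMean()(torch.tensor([1.0, math.e**2]))
    assert torch.isclose(result, torch.tensor(1.0))


def test_grad_flow() -> None:
    values = (torch.rand(3, 4) + 0.5).requires_grad_()
    LogMean()(values).backward()
    assert values.grad is not None
    assert torch.isfinite(values.grad).all()


def test_permutation_invariant() -> None:
    values = torch.rand(5) + 0.5
    permuted = values[torch.randperm(5)]
    assert torch.isclose(LogMean()(values), LogMean()(permuted))


def test_nan_propagates_on_negative_input() -> None:
    # Locks the documented contract: invalid inputs are not silently clamped.
    assert torch.isnan(LogMean()(torch.tensor([1.0, -1.0])))


for test in (test_value, test_grad_flow, test_permutation_invariant,
             test_nan_propagates_on_negative_input):
    test()
```

In the real suite these functions would be collected by pytest and built on the device-aware helpers, so the explicit loop at the end is only there to make the sketch runnable on its own.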

## CHANGELOG

`- Added `Foo` from [Paper Title](url) (Venue Year), a `Scalarizer` that <one-line description>.`

## Third-party attribution (only if adapting code, e.g. `FAMO`)

Header comment in `_foo.py`: `# Partly adapted from <url> — <License>, Copyright (c) <year>
<author>. # See NOTICES for the full license text.` plus the full license text in `NOTICES`.

## Verify (from repo root)

```bash
uv run pytest tests/unit/scalarization -W error -v  # new tests
uv run pytest tests/unit -W error  # full unit regression
uv run ruff check && uv run ruff format --check  # lint + format
uv run make doctest -C docs && uv run make clean -C docs && uv run make html -C docs
uv run pre-commit run --all-files
PYTEST_TORCH_DEVICE=cuda:0 uv run pytest tests/unit -W error  # GPU (needs CUDA)
```

- If `uv run` re-syncs unexpectedly, prefix with `UV_NO_SYNC=1`. Docs build is strict (`-W -n`), so
  an `.rst` title underline must match its title length.
- `test_dualproj.py::test_permutation_invariant` and `test_upgrad.py::test_permutation_invariant`
  are known flaky off-Linux (~1 float32 ULP, quadprog), pre-existing and unrelated. CI (Linux) is
  the source of truth.
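
Since a mismatched `.rst` title underline fails the strict docs build, a quick local check can save a build cycle. The helper below is a hypothetical sketch, not part of the repository; it flags any length difference, following the rule as stated above, and only looks at simple title/underline pairs.

```python
ADORNMENT_CHARS = set("=-~^\"'#*+")


def is_adornment(line: str) -> bool:
    """True for a line made of one repeated punctuation character, such as '====='."""
    return len(line) >= 2 and len(set(line)) == 1 and line[0] in ADORNMENT_CHARS


def find_bad_underlines(rst_text: str) -> list[tuple[int, str]]:
    """Returns (1-based line number, title) for each title whose underline length differs."""
    lines = rst_text.splitlines()
    bad = []
    for i in range(1, len(lines)):
        title, underline = lines[i - 1], lines[i]
        if title.strip() and not is_adornment(title) and is_adornment(underline):
            if len(underline) != len(title):
                bad.append((i, title))
    return bad


sample = "Foo\n===\n\nBar baz\n----\n"
print(find_bad_underlines(sample))  # [(4, 'Bar baz')]
```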
