Composable per-term losses#2715
Draft
snimu wants to merge 20 commits into
Draft
Conversation
Design doc for generalizing the RL loss into a list of composable terms (loss core + filters + weight), with echo as a preset. Default stays byte-for-byte main. See plans/losses.md. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- Drop all references to other branches / code that only exists off-main; the doc now reads for someone who knows verifiers main + the echo concept. - Record decisions: single shared losses list at RLConfig level; ship per-sample resolved core kwargs for trainer-side knobs; bake adv_tau orchestrator-side. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
opd registers as a built-in core with an `opd` preset, reusing opd_loss_fn; its teacher-KL signal is derived inside the core (the weight slot stays external/orchestrator-only). The opd path stays separate for now; §16 records how the loss framework + future env sampler collapse the three paths into one. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
compute_loss now selects loss terms via build_loss_terms() from the core registry and sums them before the single backward, instead of dispatching a single fn. Today there is exactly one term per training_mode, so behavior is unchanged — this is the seam where echo/custom terms attach in phase 2. Adds CPU golden tests asserting compute_loss equals the per-sample core loss summed and scaled (rl/sft/opd), plus build_loss_terms unit tests. No behavior change; compute_loss/setup_loss_fns signatures are untouched so train.py and existing tests are unaffected. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Add echo_loss_fn (weighted masked NLL: advantages carry the per-token weight, loss_mask is the term's selection mask) and an ExtraTerm carrier. compute_loss now sums extra terms — each with its own per-sample mask/weight and scale — alongside the primary training_mode term. The rl-only path is bit-identical (the primary term keeps the single global divide); extra_terms defaults to None so train.py and existing call sites are unaffected. Orchestrator wiring and the echo preset land in phase 2b. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Wire echo through the stack as an additive loss term. Per the design, there is no rl/echo exclusion: echo_mask/echo_weight stay separate from loss_mask/ advantages, so terms may overlap and their gradients sum. - orchestrator: EchoConfig (per-role alphas + tool_names + optional filter) on TrainEnvConfig; new echo.py builds per-token echo_alpha from prompt_attribution (+ optional filter); envs binds the filter fn; train_sink/trajectories stamp echo_alpha onto samples. - wire: TrainingSample.echo_alpha; MicroBatch.echo_mask + echo_weight (omit_defaults). - trainer: batch.py builds echo_mask/echo_weight (token 0 excluded — no shifted logprob) and carries them through packing/padding/dummy; data.py collates them; train.py computes a separate global echo denominator (folded into one all-reduce) and builds an echo ExtraTerm applied via the phase-2a multi-term compute_loss. The rl-only path is unchanged: no echo config -> no echo term, same loss_scale. Adds CPU unit tests for the echo role/alpha/filter logic and prepare_sample's echo_mask/echo_weight building. Not run locally (linux-only lockfile); ruff clean. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Introduce the canonical `losses` surface: a discriminated list of named term presets (rl/sft/opd/echo/custom) that will replace trainer.loss + per-env echo. Self-contained config module, not yet wired — the branch stays green so the surface can be reviewed before the breaking wire-up + config migration. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…tep 2) Replace trainer.loss + per-env echo with the shared `losses` term list: - RLConfig.losses propagated to trainer.losses + orchestrator.losses via propagate_shared_fields; per-env enabled_losses + loss_overrides on TrainEnvConfig. Remove DefaultLossConfig/CustomLossConfig and the per-env EchoConfig/role classes (now in configs.losses). - train_sink resolves each env's enabled echo term (<=1, per-env overrides deep-merged) and binds its filter; builds the echo overlay from it. The single echo weight stream + phase-2 wire/batch/data are reused unchanged. - trainer reads the rl/custom core from the losses list (setup_loss_fns(list)); token_export reads the rl term for its DPPO-threshold annotations. - migrate tests + docs/algorithms.md + the configs skill to the new surface. Default losses=[rl] reproduces prior behavior. Deferred to phase 4: per-env trainer-side core kwargs (per-env kl_tau) and the term_weights dict generalization (only needed for multiple independent-core / custom terms). Not run locally (no 3.12+pydantic on this box); ruff clean, py_compile + an isolated DSL static check pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- LossList type rejects duplicate term names wherever `losses` is set. - OrchestratorConfig validator: per-env enabled_losses must reference defined term names, and at most one echo term may be enabled per env (config-time, mirroring the train_sink runtime guard). - validate_shared_losses: trainer.losses must equal orchestrator.losses, catching `losses` set under [trainer]/[orchestrator] separately instead of top-level. - tests for the validators + losses propagation in test_configs.py. Custom pointers: the custom core (CustomLossTermConfig) and custom echo filter (EchoFilterConfig.import_path) already work end-to-end. A custom weight fn and the term_weights dict generalization (multiple independent-core terms) are deferred as a follow-up — they belong together. ruff clean, py_compile OK; not run locally (no 3.12+pydantic on this box). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- enabled_losses=None now validated as the full term list, so >1 echo term per env is caught at config time instead of at rollout time. [review #6] - loss_overrides keys validated against `losses`; non-echo overrides rejected. [#7] - warn (don't fail) when prompt-role echo is configured with renderer=None (MITO), where prompt_attribution is unavailable so it would silently no-op. [#8] - token_export: add echo_mask/echo_weight columns + export sequences trained only via echo (gate on loss_mask OR echo_mask). [#9] - doc notes: echo CE uses the rollout temperature (scale alpha to compensate, kept as-is); negative alpha is intentional (suppresses tokens). [#1, #10] - tests for the new config validators. Deferred to a follow-up pass (per the review): full per-sample primary routing / rl-disable [#2b] + the <=1-primary validation it enables [#5], and the multi-run losses fingerprint [#3]. Not run locally; ruff + py_compile clean. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…#2/#5) - enabled_losses can now disable the primary loss for an env: when the term matching training_mode isn't enabled, train_sink zeroes that env's samples' completion_mask (after the decode/prefill metric), so the primary core trains nothing while echo (its own mask) still applies. [#2b] - setup_loss_fns no longer fabricates a default rl core when no rl/custom term exists; the rl core raises if an rl-mode batch is applied. [#2a, trainer-side] - OrchestratorConfig.validate_primary_loss: training_mode must have a matching term (rl/custom | sft | opd) in `losses`. [#2a, config-time] - LossList rejects more than one primary (rl/custom) term. [#5] - tests for the validators + the no-rl-term rl-core error. Not run locally; ruff + py_compile clean. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…review #3) The trainer builds its loss cores once from its startup `losses`; each run's orchestrator carries its own. Register a MultiRunManager config-validation hook (alongside the LoRA-rank hook, reusing the per-run config-discovery path) that rejects a run whose `losses` differ from the trainer's. Single-run is already covered by validate_shared_losses. Not run locally; ruff + py_compile clean. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
sft/opd dispatch to fixed cores (sft_loss_fn/opd_loss_fn) selected by
training_mode, independent of the `losses` list — only the rl-mode primary comes
from an rl/custom term. The previous validate_primary_loss + _primary_enabled
wrongly required a matching sft/opd term, which (a) rejected every sft/opd config
using the default losses=[rl] (test_load_configs) and (b) would have zeroed
completion_mask for all sft/opd samples. Now both only apply to rl-mode.
Fixes the 7 test_load_configs failures (configs/debug/training_modes/{sft,opd}*,
configs/ci/integration/reverse_text_rl_{opd,sft}). ruff + py_compile clean.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The sft/opd fix reworded validate_primary_loss; update the test regex to match
("requires an rl or custom loss term"). The validation behavior is unchanged.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…sft warning - Drop sft/opd presets from the losses DSL: they dispatch to fixed cores by training_mode and are not loss-list terms, so listing them (or naming them in enabled_losses) was a silent no-op footgun. `losses` now holds rl/custom + echo only. [review #2] - Gate the zero-advantage filter on the env's rl primary being active (ZeroAdvantageFilter.primary_active, wired to train_sink._primary_enabled), so echo-only / rl-disabled envs (where advantage is irrelevant — e.g. group_size=1) aren't dropped on zero advantage. Default always-active → unchanged without echo. [review #1] - The MITO prompt-role-echo warning also fires for training_mode='sft' (which forces renderer=None after the warning's original check). [review #7] - tests for sft-type rejection + the filter gate. Left as-is: trainer.loss migration (intentional break), backfill prompt-role echo (best-effort, known), docs DPPO clipping (pre-existing main doc bug). Not run locally; ruff + py_compile clean. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
test_duplicate_loss_names_rejected still fed {"type": "sft"}, which now
fails the discriminated-union tag check before validate_loss_list can
raise "Duplicate loss term names". Swap the second term to a valid echo
term sharing the same name so the duplicate-name path is exercised.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ation - orchestrator: gate the zero_advantage filter on primary-active in BOTH pre- and post-batch filter lists (echo-only envs were dropped pre-batch). - train_sink: count n_trainable from real loss-bearing tokens (primary or echo), not just unfiltered rollouts. - configs: reject an explicitly-empty enabled_losses; validate per-env loss_overrides at config time by constructing the merged EchoLossConfig (shared deep_merge / apply_echo_override helper, reused by the orchestrator). - loss: echo_loss_fn skips its metrics for empty-mask packed splits (removes the packing-composition bias in echo_nll). - train: entropy/env metrics span loss_mask | echo_mask so echo-only envs aren't logged empty/nan; mismatch_kl stays completion-only. - token_export: bump SCHEMA_VERSION to 2 for the echo_mask/echo_weight columns. - docs/skills: fix the DPPO loss math (probs-diff masking, not min-clip), the length_penalty / filter config paths, and the trainer.optim.lr / orchestrator.train.env paths. - tests: empty enabled_losses + malformed loss_override validators; pure-function coverage for the n_trainable trainability check. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
review-main-diff-2026-06-04.md was committed by an over-broad `git add -A`; it's a local review note, not part of the PR. Untrack it (kept on disk). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What
Introduces a composable per-term loss framework. A run defines a list of named loss terms under a single shared
losseskey; each env selects which terms apply viaenabled_losses; the trainer sums the enabled terms over one shared forward/backward. Echo ships as a preset term —{core: SFT cross-entropy, filters: [role], weight: per-token alpha}— rather than bespoke code.The default
losses = [rl]reproduces today's DPPO+KL training bit-for-bit. It too is just a preset that expands to the full loss config.Changes vs
main[trainer.loss](DefaultLossConfig/CustomLossConfig). Its DPPO+KL knobs (kl_tau,dppo_mask_low/high,adv_tau) now live on a[[losses]]type = "rl"term; a custom loss is atype = "custom"term.losseslist (term typesrl/echo/custom), the echo overlay capability (new — not present onmain), and per-envenabled_losses+loss_overrides.rlterm), theopd/sftcores (still dispatched bytraining_mode), and — with defaultlosses— every training run's numerics.How it works
A loss term =
{ core, filters, weight }:rl(DPPO+KL),echo(weighted masked NLL),custom(your import path). Theopd(teacher-KL) andsft(masked NLL) cores are dispatched bytraining_modeand are not loss-list terms.alphaare resolved here and shipped to the trainer.Key behaviors:
lossesis set once at the top level and propagated to bothtrainer.lossesandorchestrator.losses(propagate_shared_fields). Don't set it under[trainer]/[orchestrator]separately — a validator rejects a mismatch.enabled_losses(defaultNone= all terms) picks which terms apply to an env;loss_overridesdeep-merges per-env params into a named term (currently echo-only, e.g. a differentalpha).enabled_lossesomits the rl/custom primary, that env's completion mask is zeroed (no RL gradient) while echo still applies — i.e. an echo-only env. Thezero_advantagefilter is gated so echo-only envs aren't dropped when their advantage is 0.T=1NLL) — scalealphato compensate.alpha = 0keeps a token supervised with zero gradient; negativealphasuppresses tokens (anti-echo). Echo masks/weights are kept separate fromloss_mask/advantages, so RL and echo may cover overlapping tokens and their gradients simply sum.prompt_attribution. Supervising system/user/tool tokens requires a renderer that emits prompt attribution; under MITO (renderer = None, whichsftmode forces) those tokens silently no-op, so config validation warns. Assistant-role echo has no such requirement.orchestrator.lossesis validated to equal the trainer's via aMultiRunManagerconfig hook, so all runs served by one trainer agree on the term list.Config surface
Term fields —
rl:kl_tau,dppo_mask_low,dppo_mask_high,adv_tau.echo:name,system/user/assistant/tool(each{ alpha, [tool_names] }, ≥1 required), optionalfilter = { import_path, kwargs }.custom:name,import_path,kwargs.Validation
Unique term names · ≤1 primary (rl/custom) term ·
enabled_losses ⊆defined terms · ≤1 echo term per env ·loss_overrideskeys must be defined echo terms ·training_mode = "rl"requires an rl/custom term ·trainer.losses == orchestrator.losses· MITO + prompt-role echo warning.Tests
tests/unit/train/rl/test_loss_terms.py(term registry, multi-termcompute_loss, echo core, default==rl-only) ·tests/unit/orchestrator/test_echo.py(role attribution, per-token alpha, filter) ·tests/unit/orchestrator/test_batch.py(echo plumbing) ·tests/unit/orchestrator/test_filters.py(echo-only zero-advantage gating) ·lossespropagation + validators intests/unit/test_configs.py.Deferred (follow-ups)
term_weightsdict (multiple independent-core terms; would lift the ≤1-echo-per-env limit) · a custom weight fn · per-env trainer-side core kwargs (per-envkl_tau). Custom cores (type = "custom") and custom echo filters (filter.import_path) already ship. The IPO ablation TOMLs under theconfigs/privatesubmodule are migrated separately (owner-managed).Example configs
1 — Default RL (no change from
main). With nolosseskey, the default is[rl]→ DPPO+KL exactly as before. To tune the knobs (the former[trainer.loss]):2 — RL + assistant echo. Overlay CE on the model's own assistant tokens (no renderer requirement). Works with any renderer:
3 — RL + prompt-role echo. Supervise system/user/tool context tokens. Requires a renderer that emits
prompt_attribution(i.e. not MITO):4 — Multi-env: echo on some envs, not others. A shared term list; each env opts in. The math env gets RL-only; the chat env adds echo with a per-env
alphaoverride:5 — Echo-only env (primary disabled). Omit the primary from
enabled_losses: the env's completion mask is zeroed (no RL gradient) and only echo trains. Useful for a distillation/imitation env mixed alongside RL envs:6 — Anti-echo (negative alpha). Suppress a role's tokens instead of reinforcing them:
7 — Echo with a custom filter. Narrow role-selected tokens further with an imported callable:
8 — Custom loss core (replaces the RL primary). Point at your own core (
def core(inputs, **kwargs) -> LossOutputs). At most one primary (rl/custom) term per run: