Skip to content

Composable per-term losses#2715

Draft
snimu wants to merge 20 commits into
mainfrom
sebastian/losses-2026-06-04
Draft

Composable per-term losses#2715
snimu wants to merge 20 commits into
mainfrom
sebastian/losses-2026-06-04

Conversation

@snimu
Copy link
Copy Markdown
Collaborator

@snimu snimu commented Jun 4, 2026

What

Introduces a composable per-term loss framework. A run defines a list of named loss terms under a single shared losses key; each env selects which terms apply via enabled_losses; the trainer sums the enabled terms over one shared forward/backward. Echo ships as a preset term — {core: SFT cross-entropy, filters: [role], weight: per-token alpha} — rather than bespoke code.

The default losses = [rl] reproduces today's DPPO+KL training bit-for-bit. It too is just a preset that expands to the full loss config.

Changes vs main

  • Removed [trainer.loss] (DefaultLossConfig / CustomLossConfig). Its DPPO+KL knobs (kl_tau, dppo_mask_low/high, adv_tau) now live on a [[losses]] type = "rl" term; a custom loss is a type = "custom" term.
  • Added the unified, shared losses list (term types rl / echo / custom), the echo overlay capability (new — not present on main), and per-env enabled_losses + loss_overrides.
  • Unchanged: the DPPO+KL math (relocated verbatim into the rl term), the opd / sft cores (still dispatched by training_mode), and — with default losses — every training run's numerics.

How it works

A loss term = { core, filters, weight }:

  • Cores run trainer-side (they need logprobs): rl (DPPO+KL), echo (weighted masked NLL), custom (your import path). The opd (teacher-KL) and sft (masked NLL) cores are dispatched by training_mode and are not loss-list terms.
  • Filters + weights run orchestrator-side (they need role attribution / GRPO advantage). Echo's per-role token selection and per-token alpha are resolved here and shipped to the trainer.
  • Every enabled term differentiates the same shared forward → the trainer sums them into one backward.

Key behaviors:

  • Shared config. losses is set once at the top level and propagated to both trainer.losses and orchestrator.losses (propagate_shared_fields). Don't set it under [trainer]/[orchestrator] separately — a validator rejects a mismatch.
  • Per-env selection. enabled_losses (default None = all terms) picks which terms apply to an env; loss_overrides deep-merges per-env params into a named term (currently echo-only, e.g. a different alpha).
  • Per-env primary disable. If an env's enabled_losses omits the rl/custom primary, that env's completion mask is zeroed (no RL gradient) while echo still applies — i.e. an echo-only env. The zero_advantage filter is gated so echo-only envs aren't dropped when their advantage is 0.
  • Echo semantics. Echo is a per-role CE overlay on context tokens (system / user / tool / assistant). It is computed on the rollout's temperature-scaled logprobs (not a true T=1 NLL) — scale alpha to compensate. alpha = 0 keeps a token supervised with zero gradient; negative alpha suppresses tokens (anti-echo). Echo masks/weights are kept separate from loss_mask/advantages, so RL and echo may cover overlapping tokens and their gradients simply sum.
  • Prompt-role echo needs prompt_attribution. Supervising system/user/tool tokens requires a renderer that emits prompt attribution; under MITO (renderer = None, which sft mode forces) those tokens silently no-op, so config validation warns. Assistant-role echo has no such requirement.
  • Multi-run. Each run's orchestrator.losses is validated to equal the trainer's via a MultiRunManager config hook, so all runs served by one trainer agree on the term list.

Config surface

# top-level (shared) — propagated to trainer + orchestrator
[[losses]]
type = "rl"                  # the DPPO+KL primary; default term
kl_tau = 1e-3                # dppo_mask_low/high, adv_tau also live here

[[losses]]
type = "echo"
name = "echo"                # unique; referenced by enabled_losses
assistant = { alpha = 0.5 }  # any of system/user/assistant/tool (>=1 required)
tool = { alpha = 0.3, tool_names = ["calc"] }   # restrict to named tools

# per-env selection (orchestrator)
[[orchestrator.train.env]]
id = "math-env"
enabled_losses = ["rl", "echo"]                          # None (default) = all terms
loss_overrides = { echo = { system = { alpha = 0.01 } } } # deep-merged per env

Term fieldsrl: kl_tau, dppo_mask_low, dppo_mask_high, adv_tau. echo: name, system/user/assistant/tool (each { alpha, [tool_names] }, ≥1 required), optional filter = { import_path, kwargs }. custom: name, import_path, kwargs.

Validation

Unique term names · ≤1 primary (rl/custom) term · enabled_losses ⊆ defined terms · ≤1 echo term per env · loss_overrides keys must be defined echo terms · training_mode = "rl" requires an rl/custom term · trainer.losses == orchestrator.losses · MITO + prompt-role echo warning.

Tests

tests/unit/train/rl/test_loss_terms.py (term registry, multi-term compute_loss, echo core, default==rl-only) · tests/unit/orchestrator/test_echo.py (role attribution, per-token alpha, filter) · tests/unit/orchestrator/test_batch.py (echo plumbing) · tests/unit/orchestrator/test_filters.py (echo-only zero-advantage gating) · losses propagation + validators in tests/unit/test_configs.py.

Deferred (follow-ups)

term_weights dict (multiple independent-core terms; would lift the ≤1-echo-per-env limit) · a custom weight fn · per-env trainer-side core kwargs (per-env kl_tau). Custom cores (type = "custom") and custom echo filters (filter.import_path) already ship. The IPO ablation TOMLs under the configs/private submodule are migrated separately (owner-managed).


Example configs

Each block shows only the loss-relevant keys; [model], [orchestrator], [trainer], [inference] are as usual (see configs/gsm8k/rl.toml for a minimal full config).

1 — Default RL (no change from main). With no losses key, the default is [rl] → DPPO+KL exactly as before. To tune the knobs (the former [trainer.loss]):

[[losses]]
type = "rl"
kl_tau = 5e-4
dppo_mask_low = 0.2
dppo_mask_high = 0.2
adv_tau = 1.0

2 — RL + assistant echo. Overlay CE on the model's own assistant tokens (no renderer requirement). Works with any renderer:

[[losses]]
type = "rl"

[[losses]]
type = "echo"
assistant = { alpha = 0.5 }

3 — RL + prompt-role echo. Supervise system/user/tool context tokens. Requires a renderer that emits prompt_attribution (i.e. not MITO):

[[losses]]
type = "rl"

[[losses]]
type = "echo"
name = "echo"
system = { alpha = 1.0 }
user   = { alpha = 0.5 }
tool   = { alpha = 0.3, tool_names = ["python", "search"] }  # only these tools

[orchestrator.renderer]
name = "auto"   # a prompt_attribution-capable renderer (not renderer = None)

4 — Multi-env: echo on some envs, not others. A shared term list; each env opts in. The math env gets RL-only; the chat env adds echo with a per-env alpha override:

[[losses]]
type = "rl"

[[losses]]
type = "echo"
name = "echo"
assistant = { alpha = 0.5 }

[[orchestrator.train.env]]
id = "math-env"
name = "math"
enabled_losses = ["rl"]            # RL only here

[[orchestrator.train.env]]
id = "chat-env"
name = "chat"
enabled_losses = ["rl", "echo"]                                   # RL + echo
loss_overrides = { echo = { assistant = { alpha = 0.1 } } }       # softer echo here

5 — Echo-only env (primary disabled). Omit the primary from enabled_losses: the env's completion mask is zeroed (no RL gradient) and only echo trains. Useful for a distillation/imitation env mixed alongside RL envs:

[[losses]]
type = "rl"

[[losses]]
type = "echo"
name = "echo"
assistant = { alpha = 1.0 }

[[orchestrator.train.env]]
id = "rl-env"
name = "math"
enabled_losses = ["rl"]

[[orchestrator.train.env]]
id = "imitate-env"
name = "demos"
enabled_losses = ["echo"]   # echo-only: no RL signal, not dropped on zero advantage

6 — Anti-echo (negative alpha). Suppress a role's tokens instead of reinforcing them:

[[losses]]
type = "rl"

[[losses]]
type = "echo"
name = "suppress-tools"
tool = { alpha = -0.2 }   # push probability mass away from tool-message tokens

7 — Echo with a custom filter. Narrow role-selected tokens further with an imported callable:

[[losses]]
type = "rl"

[[losses]]
type = "echo"
name = "echo"
user = { alpha = 0.5 }
filter = { import_path = "my_pkg.filters.drop_boilerplate", kwargs = { min_len = 8 } }

8 — Custom loss core (replaces the RL primary). Point at your own core (def core(inputs, **kwargs) -> LossOutputs). At most one primary (rl/custom) term per run:

[[losses]]
type = "custom"
name = "my-loss"
import_path = "my_pkg.losses.my_core"
kwargs = { beta = 0.1 }

snimu and others added 20 commits June 4, 2026 14:32
Design doc for generalizing the RL loss into a list of composable terms
(loss core + filters + weight), with echo as a preset. Default stays
byte-for-byte main. See plans/losses.md.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- Drop all references to other branches / code that only exists off-main;
  the doc now reads for someone who knows verifiers main + the echo concept.
- Record decisions: single shared losses list at RLConfig level; ship
  per-sample resolved core kwargs for trainer-side knobs; bake adv_tau
  orchestrator-side.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
opd registers as a built-in core with an `opd` preset, reusing opd_loss_fn;
its teacher-KL signal is derived inside the core (the weight slot stays
external/orchestrator-only). The opd path stays separate for now; §16 records
how the loss framework + future env sampler collapse the three paths into one.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
compute_loss now selects loss terms via build_loss_terms() from the core
registry and sums them before the single backward, instead of dispatching a
single fn. Today there is exactly one term per training_mode, so behavior is
unchanged — this is the seam where echo/custom terms attach in phase 2.

Adds CPU golden tests asserting compute_loss equals the per-sample core loss
summed and scaled (rl/sft/opd), plus build_loss_terms unit tests. No behavior
change; compute_loss/setup_loss_fns signatures are untouched so train.py and
existing tests are unaffected.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Add echo_loss_fn (weighted masked NLL: advantages carry the per-token weight,
loss_mask is the term's selection mask) and an ExtraTerm carrier. compute_loss
now sums extra terms — each with its own per-sample mask/weight and scale —
alongside the primary training_mode term. The rl-only path is bit-identical
(the primary term keeps the single global divide); extra_terms defaults to
None so train.py and existing call sites are unaffected. Orchestrator wiring
and the echo preset land in phase 2b.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Wire echo through the stack as an additive loss term. Per the design, there is
no rl/echo exclusion: echo_mask/echo_weight stay separate from loss_mask/
advantages, so terms may overlap and their gradients sum.

- orchestrator: EchoConfig (per-role alphas + tool_names + optional filter) on
  TrainEnvConfig; new echo.py builds per-token echo_alpha from prompt_attribution
  (+ optional filter); envs binds the filter fn; train_sink/trajectories stamp
  echo_alpha onto samples.
- wire: TrainingSample.echo_alpha; MicroBatch.echo_mask + echo_weight (omit_defaults).
- trainer: batch.py builds echo_mask/echo_weight (token 0 excluded — no shifted
  logprob) and carries them through packing/padding/dummy; data.py collates them;
  train.py computes a separate global echo denominator (folded into one all-reduce)
  and builds an echo ExtraTerm applied via the phase-2a multi-term compute_loss.

The rl-only path is unchanged: no echo config -> no echo term, same loss_scale.
Adds CPU unit tests for the echo role/alpha/filter logic and prepare_sample's
echo_mask/echo_weight building. Not run locally (linux-only lockfile); ruff clean.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Introduce the canonical `losses` surface: a discriminated list of named term
presets (rl/sft/opd/echo/custom) that will replace trainer.loss + per-env echo.
Self-contained config module, not yet wired — the branch stays green so the
surface can be reviewed before the breaking wire-up + config migration.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…tep 2)

Replace trainer.loss + per-env echo with the shared `losses` term list:

- RLConfig.losses propagated to trainer.losses + orchestrator.losses via
  propagate_shared_fields; per-env enabled_losses + loss_overrides on
  TrainEnvConfig. Remove DefaultLossConfig/CustomLossConfig and the per-env
  EchoConfig/role classes (now in configs.losses).
- train_sink resolves each env's enabled echo term (<=1, per-env overrides
  deep-merged) and binds its filter; builds the echo overlay from it. The
  single echo weight stream + phase-2 wire/batch/data are reused unchanged.
- trainer reads the rl/custom core from the losses list (setup_loss_fns(list));
  token_export reads the rl term for its DPPO-threshold annotations.
- migrate tests + docs/algorithms.md + the configs skill to the new surface.

Default losses=[rl] reproduces prior behavior. Deferred to phase 4: per-env
trainer-side core kwargs (per-env kl_tau) and the term_weights dict
generalization (only needed for multiple independent-core / custom terms).

Not run locally (no 3.12+pydantic on this box); ruff clean, py_compile + an
isolated DSL static check pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- LossList type rejects duplicate term names wherever `losses` is set.
- OrchestratorConfig validator: per-env enabled_losses must reference defined
  term names, and at most one echo term may be enabled per env (config-time,
  mirroring the train_sink runtime guard).
- validate_shared_losses: trainer.losses must equal orchestrator.losses, catching
  `losses` set under [trainer]/[orchestrator] separately instead of top-level.
- tests for the validators + losses propagation in test_configs.py.

Custom pointers: the custom core (CustomLossTermConfig) and custom echo filter
(EchoFilterConfig.import_path) already work end-to-end. A custom weight fn and the
term_weights dict generalization (multiple independent-core terms) are deferred as
a follow-up — they belong together.

ruff clean, py_compile OK; not run locally (no 3.12+pydantic on this box).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- enabled_losses=None now validated as the full term list, so >1 echo term per
  env is caught at config time instead of at rollout time. [review #6]
- loss_overrides keys validated against `losses`; non-echo overrides rejected. [#7]
- warn (don't fail) when prompt-role echo is configured with renderer=None
  (MITO), where prompt_attribution is unavailable so it would silently no-op. [#8]
- token_export: add echo_mask/echo_weight columns + export sequences trained
  only via echo (gate on loss_mask OR echo_mask). [#9]
- doc notes: echo CE uses the rollout temperature (scale alpha to compensate,
  kept as-is); negative alpha is intentional (suppresses tokens). [#1, #10]
- tests for the new config validators.

Deferred to a follow-up pass (per the review): full per-sample primary routing /
rl-disable [#2b] + the <=1-primary validation it enables [#5], and the multi-run
losses fingerprint [#3]. Not run locally; ruff + py_compile clean.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…#2/#5)

- enabled_losses can now disable the primary loss for an env: when the term
  matching training_mode isn't enabled, train_sink zeroes that env's samples'
  completion_mask (after the decode/prefill metric), so the primary core trains
  nothing while echo (its own mask) still applies. [#2b]
- setup_loss_fns no longer fabricates a default rl core when no rl/custom term
  exists; the rl core raises if an rl-mode batch is applied. [#2a, trainer-side]
- OrchestratorConfig.validate_primary_loss: training_mode must have a matching
  term (rl/custom | sft | opd) in `losses`. [#2a, config-time]
- LossList rejects more than one primary (rl/custom) term. [#5]
- tests for the validators + the no-rl-term rl-core error.

Not run locally; ruff + py_compile clean.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…review #3)

The trainer builds its loss cores once from its startup `losses`; each run's
orchestrator carries its own. Register a MultiRunManager config-validation hook
(alongside the LoRA-rank hook, reusing the per-run config-discovery path) that
rejects a run whose `losses` differ from the trainer's. Single-run is already
covered by validate_shared_losses.

Not run locally; ruff + py_compile clean.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
sft/opd dispatch to fixed cores (sft_loss_fn/opd_loss_fn) selected by
training_mode, independent of the `losses` list — only the rl-mode primary comes
from an rl/custom term. The previous validate_primary_loss + _primary_enabled
wrongly required a matching sft/opd term, which (a) rejected every sft/opd config
using the default losses=[rl] (test_load_configs) and (b) would have zeroed
completion_mask for all sft/opd samples. Now both only apply to rl-mode.

Fixes the 7 test_load_configs failures (configs/debug/training_modes/{sft,opd}*,
configs/ci/integration/reverse_text_rl_{opd,sft}). ruff + py_compile clean.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The sft/opd fix reworded validate_primary_loss; update the test regex to match
("requires an rl or custom loss term"). The validation behavior is unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…sft warning

- Drop sft/opd presets from the losses DSL: they dispatch to fixed cores by
  training_mode and are not loss-list terms, so listing them (or naming them in
  enabled_losses) was a silent no-op footgun. `losses` now holds rl/custom +
  echo only. [review #2]
- Gate the zero-advantage filter on the env's rl primary being active
  (ZeroAdvantageFilter.primary_active, wired to train_sink._primary_enabled), so
  echo-only / rl-disabled envs (where advantage is irrelevant — e.g. group_size=1)
  aren't dropped on zero advantage. Default always-active → unchanged without echo. [review #1]
- The MITO prompt-role-echo warning also fires for training_mode='sft' (which
  forces renderer=None after the warning's original check). [review #7]
- tests for sft-type rejection + the filter gate.

Left as-is: trainer.loss migration (intentional break), backfill prompt-role echo
(best-effort, known), docs DPPO clipping (pre-existing main doc bug).

Not run locally; ruff + py_compile clean.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
test_duplicate_loss_names_rejected still fed {"type": "sft"}, which now
fails the discriminated-union tag check before validate_loss_list can
raise "Duplicate loss term names". Swap the second term to a valid echo
term sharing the same name so the duplicate-name path is exercised.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ation

- orchestrator: gate the zero_advantage filter on primary-active in BOTH pre-
  and post-batch filter lists (echo-only envs were dropped pre-batch).
- train_sink: count n_trainable from real loss-bearing tokens (primary or echo),
  not just unfiltered rollouts.
- configs: reject an explicitly-empty enabled_losses; validate per-env
  loss_overrides at config time by constructing the merged EchoLossConfig
  (shared deep_merge / apply_echo_override helper, reused by the orchestrator).
- loss: echo_loss_fn skips its metrics for empty-mask packed splits (removes the
  packing-composition bias in echo_nll).
- train: entropy/env metrics span loss_mask | echo_mask so echo-only envs aren't
  logged empty/nan; mismatch_kl stays completion-only.
- token_export: bump SCHEMA_VERSION to 2 for the echo_mask/echo_weight columns.
- docs/skills: fix the DPPO loss math (probs-diff masking, not min-clip), the
  length_penalty / filter config paths, and the trainer.optim.lr /
  orchestrator.train.env paths.
- tests: empty enabled_losses + malformed loss_override validators; pure-function
  coverage for the n_trainable trainability check.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
review-main-diff-2026-06-04.md was committed by an over-broad `git add -A`;
it's a local review note, not part of the PR. Untrack it (kept on disk).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant