From b561ec830b09eed6cf419a17d4de1f8967c70b75 Mon Sep 17 00:00:00 2001 From: hallerite Date: Fri, 5 Jun 2026 13:42:00 +0000 Subject: [PATCH] feat(orchestrator): per-env advantage strategy Advantage was a single global config applied to every training env. Make it configurable per env: each `TrainEnvConfig` can set its own `advantage`, inheriting the top-level `orchestrator.advantage` when unset (same pattern as `group_size`). `TrainSink` resolves one advantage fn per env and applies it in `process_group` via `self.advantage_fns[env_name]`. The advantage/length-penalty config classes move above `EnvConfig` so `TrainEnvConfig` can reference `AdvantageConfig` (the module has no `from __future__ import annotations`, so annotations evaluate eagerly). Behavior-preserving: with the default config every env inherits the global default, producing identical advantages. Co-Authored-By: Claude Opus 4.8 (1M context) --- docs/algorithms.md | 16 ++++ .../src/prime_rl/configs/orchestrator.py | 96 ++++++++++--------- src/prime_rl/orchestrator/orchestrator.py | 1 - src/prime_rl/orchestrator/train_sink.py | 17 ++-- 4 files changed, 79 insertions(+), 51 deletions(-) diff --git a/docs/algorithms.md b/docs/algorithms.md index fdd5b6e2da..0ffe69edba 100644 --- a/docs/algorithms.md +++ b/docs/algorithms.md @@ -167,6 +167,22 @@ kwargs = { eps = 1e-8 } `AdvantageInputs.rollouts` is a list of `verifiers.RolloutOutput`, so you have access to the full rollout (turns, tool calls, custom metadata) — not just the reward. Use this for anything reward-shaping-like that needs trajectory context. +### Per-Env Advantage + +`advantage` can be set per training environment. 
An env that doesn't set its own `advantage` inherits the top-level `[orchestrator.advantage]`, so mixed-env runs only need to override it on the envs that need a different advantage computation: + +```toml +[orchestrator.advantage] +type = "default" # the default every env inherits unless it overrides + +[[orchestrator.train.env]] +id = "math-env" # inherits the default above + +[[orchestrator.train.env]] +id = "agent-env" +advantage = { type = "custom", import_path = "my_module.normalized_advantage" } +``` + ## Filters Filters drop rollouts between scoring and training. Built-ins (composable): diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index be5fe249f3..83f3fea7a6 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -143,6 +143,49 @@ def _deprecate_max_tokens(cls, data: Any) -> Any: return data +class TokensLengthPenaltyConfig(BaseConfig): + type: Literal["tokens"] = "tokens" + + completion_weight: float = Field(1.0, ge=0, allow_inf_nan=False) + """Weight on model completion tokens. Finite and non-negative.""" + + tool_response_weight: float = Field(1.0, ge=0, allow_inf_nan=False) + """Weight on tool-response tokens (read from the rollout's ``*_total_tool_response_tokens`` harness metric; 0 if absent). Finite and non-negative.""" + + +class TurnsLengthPenaltyConfig(BaseConfig): + type: Literal["turns"] = "turns" + + +LengthPenaltyConfig: TypeAlias = Annotated[ + TokensLengthPenaltyConfig | TurnsLengthPenaltyConfig, + Field(discriminator="type"), +] + + +class DefaultAdvantageConfig(BaseConfig): + type: Literal["default"] = "default" + + length_penalty: LengthPenaltyConfig | None = None + """Correctness-gated length penalty. ``tokens`` shapes by weighted token cost; ``turns`` shapes by trajectory turn count; None disables shaping. 
In mixed groups, lower-cost correct rollouts get amplified advantage (up to 2x), higher-cost correct rollouts are unchanged, incorrect untouched. In all-correct groups, below-average-cost rollouts get advantage in [0, 1], others get 0.""" + + +class CustomAdvantageConfig(BaseConfig): + type: Literal["custom"] = "custom" + + import_path: str + """Import path to the advantage function (e.g. ``my_module.my_advantage``).""" + + kwargs: dict[str, Any] = Field(default_factory=dict) + """Kwargs forwarded to the advantage function.""" + + +AdvantageConfig: TypeAlias = Annotated[ + DefaultAdvantageConfig | CustomAdvantageConfig, + Field(discriminator="type"), +] + + class EnvConfig(BaseConfig): id: str = "reverse-text" """Registered verifiers environment ID (e.g. ``math-env``, ``primeintellect/math-env``). May include an ``@version`` suffix for installation.""" @@ -214,6 +257,11 @@ class TrainEnvConfig(EnvConfig): """Rollouts generated per example for GRPO group-relative advantages. Inherits from ``orchestrator.group_size`` when unset.""" + advantage: AdvantageConfig | None = None + """Advantage strategy for this env's GRPO groups. Inherits from the top-level + ``orchestrator.advantage`` when unset; set a different ``default``/``custom`` + config to give this env its own advantage computation.""" + class EvalEnvConfig(EnvConfig): sampling: EvalSamplingConfig = EvalSamplingConfig() @@ -374,49 +422,6 @@ class CheckpointConfig(BaseConfig): """Skip loading the progress from checkpoint.""" -class TokensLengthPenaltyConfig(BaseConfig): - type: Literal["tokens"] = "tokens" - - completion_weight: float = Field(1.0, ge=0, allow_inf_nan=False) - """Weight on model completion tokens. Finite and non-negative.""" - - tool_response_weight: float = Field(1.0, ge=0, allow_inf_nan=False) - """Weight on tool-response tokens (read from the rollout's ``*_total_tool_response_tokens`` harness metric; 0 if absent). 
Finite and non-negative.""" - - -class TurnsLengthPenaltyConfig(BaseConfig): - type: Literal["turns"] = "turns" - - -LengthPenaltyConfig: TypeAlias = Annotated[ - TokensLengthPenaltyConfig | TurnsLengthPenaltyConfig, - Field(discriminator="type"), -] - - -class DefaultAdvantageConfig(BaseConfig): - type: Literal["default"] = "default" - - length_penalty: LengthPenaltyConfig | None = None - """Correctness-gated length penalty. ``tokens`` shapes by weighted token cost; ``turns`` shapes by trajectory turn count; None disables shaping. In mixed groups, lower-cost correct rollouts get amplified advantage (up to 2x), higher-cost correct rollouts are unchanged, incorrect untouched. In all-correct groups, below-average-cost rollouts get advantage in [0, 1], others get 0.""" - - -class CustomAdvantageConfig(BaseConfig): - type: Literal["custom"] = "custom" - - import_path: str - """Import path to the advantage function (e.g. ``my_module.my_advantage``).""" - - kwargs: dict[str, Any] = Field(default_factory=dict) - """Kwargs forwarded to the advantage function.""" - - -AdvantageConfig: TypeAlias = Annotated[ - DefaultAdvantageConfig | CustomAdvantageConfig, - Field(discriminator="type"), -] - - # Flags rare tokens generated at high entropy (Section 5.2, https://arxiv.org/abs/2510.02387). class GibberishFilterConfig(BaseConfig): type: Literal["gibberish"] = "gibberish" @@ -876,6 +881,11 @@ def resolve_batching(self): if "group_size" not in env_cfg.model_fields_set: env_cfg.group_size = self.group_size + # Propagate the top-level ``advantage`` into each train env that didn't set its own. 
+ for env_cfg in self.train.env: + if "advantage" not in env_cfg.model_fields_set: + env_cfg.advantage = self.advantage + # Resolve train env num_workers from max_inflight_rollouts for env_cfg in self.train.env: if env_cfg.num_workers == "auto": diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 902c8b963b..81bfb0ed5b 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -389,7 +389,6 @@ async def setup(self) -> None: mm_token_type_ids_mapping=self.mm_token_type_ids_mapping, batch_size=config.batch_size, token_batch_size=config.token_batch_size, - advantage_config=config.advantage, pre_filters=pre_filters, post_filters=post_filters, ) diff --git a/src/prime_rl/orchestrator/train_sink.py b/src/prime_rl/orchestrator/train_sink.py index 26e7b915b0..1c735a6a05 100644 --- a/src/prime_rl/orchestrator/train_sink.py +++ b/src/prime_rl/orchestrator/train_sink.py @@ -17,8 +17,8 @@ import uuid from collections import defaultdict -from prime_rl.configs.orchestrator import AdvantageConfig, OrchestratorConfig -from prime_rl.orchestrator.advantage import assign_advantages, setup_advantage_fn +from prime_rl.configs.orchestrator import OrchestratorConfig +from prime_rl.orchestrator.advantage import AdvantageFn, assign_advantages, setup_advantage_fn from prime_rl.orchestrator.envs import TrainEnvs from prime_rl.orchestrator.filters import RolloutFilter, apply_filters from prime_rl.orchestrator.trajectories import ( @@ -44,7 +44,6 @@ def __init__( mm_token_type_ids_mapping: dict[int, int] | None, batch_size: int | None, token_batch_size: int | None, - advantage_config: AdvantageConfig | None, pre_filters: list[RolloutFilter], post_filters: list[RolloutFilter], ) -> None: @@ -58,9 +57,13 @@ def __init__( self.mm_token_type_ids_mapping = mm_token_type_ids_mapping self.batch_size = batch_size self.token_batch_size = token_batch_size - # Built once — custom advantage funcs do an 
``import_object`` and - # we don't want to pay that per group. ``None`` = reward-only path - self.advantage_fn = setup_advantage_fn(advantage_config) if advantage_config is not None else None + # Built once per env — custom advantage funcs do an ``import_object`` and + # we don't want to pay that per group. Each env carries its own advantage + # config (inheriting the top-level default when unset). ``None`` = reward-only path. + self.advantage_fns: dict[str, AdvantageFn | None] = { + env.name: setup_advantage_fn(env.config.advantage) if env.config.advantage is not None else None + for env in train_envs + } self.pre_filters = pre_filters self.post_filters = post_filters @@ -200,7 +203,7 @@ def process_group(self, group_id: uuid.UUID) -> None: ) return - assign_advantages(survivors, self.advantage_fn) + assign_advantages(survivors, self.advantage_fns[env_name]) # Propagate to the pre-tokenized samples so the orchestrator can # collect samples at ship time without re-walking rollouts. The env