From e7b80f084be1ad0786adeb02304c84e2e3001f7d Mon Sep 17 00:00:00 2001 From: hallerite Date: Fri, 5 Jun 2026 14:23:09 +0000 Subject: [PATCH] feat(orchestrator): per-env sample strategy + env-mix seam MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce the per-env sampling seam. Each train env owns a `SampleStrategy` (what example to serve, plus an `observe()` feedback hook); env selection is delegated to a swappable `EnvMixStrategy`. Defaults reproduce today's behavior (weighted round-robin over per-env reshuffling-cursor datasets). - `orchestrator/sampling.py` (new): SampleStrategy + ShuffledCursorSampler; EnvMixStrategy + WeightedRoundRobin. - TrainEnv owns its dataset via `build_sampler()` and holds `.sampler`. - TrainSource slims to env-mix + per-env samplers. - TrainSink.process_group calls `env.sampler.observe(survivors)` after advantages (no-op default) — the feedback wire for curriculum / replay samplers. Behavior-equivalent; RNG partitioned per-env + mix. Stacked on feat/per-env-advantage. Co-Authored-By: Claude Opus 4.8 (1M context) --- src/prime_rl/orchestrator/envs.py | 15 ++++ src/prime_rl/orchestrator/sampling.py | 89 +++++++++++++++++++++++ src/prime_rl/orchestrator/train_sink.py | 7 ++ src/prime_rl/orchestrator/train_source.py | 75 +++++++++---------- tests/unit/orchestrator/test_sampling.py | 60 +++++++++++++++ 5 files changed, 209 insertions(+), 37 deletions(-) create mode 100644 src/prime_rl/orchestrator/sampling.py create mode 100644 tests/unit/orchestrator/test_sampling.py diff --git a/src/prime_rl/orchestrator/envs.py b/src/prime_rl/orchestrator/envs.py index fe02d2e61a..d857d06a97 100644 --- a/src/prime_rl/orchestrator/envs.py +++ b/src/prime_rl/orchestrator/envs.py @@ -16,6 +16,7 @@ from prime_rl.configs.orchestrator import EnvConfig, EvalEnvConfig, TrainEnvConfig from prime_rl.orchestrator.eval_utils import compute_pass_at_k +from prime_rl.orchestrator.sampling import SampleStrategy, ShuffledCursorSampler from prime_rl.utils.logger import ProgressTracker, get_logger from prime_rl.utils.monitor import get_monitor from prime_rl.utils.utils import capitalize @@ -170,10 +171,24 @@ class TrainEnv(Env): def __init__(self, config: TrainEnvConfig): super().__init__(config) self.sampling_args = config.sampling.to_sampling_args() + # Set by ``build_sampler`` (called by TrainSource at setup). Owns this + # env's dataset + selection state; reached by the sink for ``observe``. + self.sampler: SampleStrategy | None = None def get_dataset(self, seed: int | None = None): return self.env.get_dataset(seed=seed) + def build_sampler(self, *, seed: int | None) -> None: + """Load this env's dataset and build its default ``SampleStrategy``. + Each row is stamped with ``env_name`` (``example_id`` comes from the + dataset).""" + rows: list[dict] = [] + for row in self.get_dataset(seed=seed): + ex = dict(row) + ex["env_name"] = self.name + rows.append(ex) + self.sampler = ShuffledCursorSampler(rows, seed=seed) + class EvalEnv(Env): config: EvalEnvConfig diff --git a/src/prime_rl/orchestrator/sampling.py b/src/prime_rl/orchestrator/sampling.py new file mode 100644 index 0000000000..ad72b0e2e1 --- /dev/null +++ b/src/prime_rl/orchestrator/sampling.py @@ -0,0 +1,89 @@ +"""Sampling strategies for training rollouts. + +Two seams sit between the train envs and the dispatcher: + +- ``EnvMixStrategy`` (global) — decides *which* env to draw from next. +- ``SampleStrategy`` (per-env) — decides *what* example that env serves next, + and (via ``observe``) can learn from finished, scored groups. Each env owns + its own ``SampleStrategy`` instance, so it can hold dataset + per-env state + (cursor today; curriculum / replay buffers later). + +The defaults (``WeightedRoundRobin`` + ``ShuffledCursorSampler``) reproduce the +previous ``TrainSource`` behavior: a weighted round-robin over per-env datasets +that are each shuffled once and walked with a reshuffling cursor. +""" + +from __future__ import annotations + +import random +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from prime_rl.orchestrator.types import TrainRollout + + +class SampleStrategy(ABC): + """Per-env example selection. One stateful instance per env, alive for the + whole run. ``next`` returns the next example dict (carrying ``env_name`` + + ``example_id``); ``observe`` is the feedback hook for stateful strategies.""" + + @abstractmethod + def next(self) -> dict: + """Return the next example for this env.""" + ... + + def observe(self, group: list[TrainRollout]) -> None: + """Called with one finished, scored group of this env's rollouts (after + advantages are assigned). Default is a no-op; stateful strategies + (curriculum, replay) override this to learn from outcomes.""" + return + + +class ShuffledCursorSampler(SampleStrategy): + """Default sampler: shuffle the env's rows once, walk a cursor, reshuffle on + exhaustion (infinite pull).""" + + def __init__(self, rows: list[dict], *, seed: int | None) -> None: + if not rows: + raise ValueError("ShuffledCursorSampler needs at least one example") + self._rng = random.Random(seed) + self._rows = list(rows) + self._rng.shuffle(self._rows) + self._cursor = 0 + + @property + def dataset_size(self) -> int: + return len(self._rows) + + def next(self) -> dict: + if self._cursor >= len(self._rows): + self._rng.shuffle(self._rows) + self._cursor = 0 + row = self._rows[self._cursor] + self._cursor += 1 + return row + + +class EnvMixStrategy(ABC): + """Global: which env to draw from next. ``pick`` returns an env name.""" + + @abstractmethod + def pick(self) -> str: + """Return the env name to sample from next.""" + ... + + +class WeightedRoundRobin(EnvMixStrategy): + """Default env mix: weighted random choice over env names. Weights are the + configured per-env ratios (when all set) or per-env dataset sizes.""" + + def __init__(self, env_names: list[str], weights: list[float], *, seed: int | None) -> None: + if not env_names: + raise ValueError("WeightedRoundRobin needs at least one env") + self._rng = random.Random(seed) + self._env_names = list(env_names) + self._weights = list(weights) + + def pick(self) -> str: + return self._rng.choices(self._env_names, weights=self._weights, k=1)[0] diff --git a/src/prime_rl/orchestrator/train_sink.py b/src/prime_rl/orchestrator/train_sink.py index 1c735a6a05..1a3bac6c24 100644 --- a/src/prime_rl/orchestrator/train_sink.py +++ b/src/prime_rl/orchestrator/train_sink.py @@ -205,6 +205,13 @@ def process_group(self, group_id: uuid.UUID) -> None: assign_advantages(survivors, self.advantage_fns[env_name]) + # Feedback hook: let this env's sampler learn from the finished, scored + # group (advantages now assigned). No-op for the default cursor sampler; + # curriculum / replay samplers override ``observe``. + sampler = self.train_envs.get(env_name).sampler + if sampler is not None: + sampler.observe(survivors) + # Propagate to the pre-tokenized samples so the orchestrator can # collect samples at ship time without re-walking rollouts. The env # has a single sampling temperature; fan it out across each sample's diff --git a/src/prime_rl/orchestrator/train_source.py b/src/prime_rl/orchestrator/train_source.py index db439f7539..67b69dd405 100644 --- a/src/prime_rl/orchestrator/train_source.py +++ b/src/prime_rl/orchestrator/train_source.py @@ -1,59 +1,60 @@ -"""TrainSource: weighted round-robin across train envs, infinite pull. +"""TrainSource: weighted env mix over per-env samplers, infinite pull. -Weights default to configured ``ratio`` (when every env sets one) or to -per-env dataset size. ``next_example`` reshuffles on cursor exhaustion.""" +Env selection is delegated to an ``EnvMixStrategy`` (default: weighted +round-robin by configured ``ratio`` when all envs set one, else by per-env +dataset size); example selection within an env is delegated to that env's +``SampleStrategy`` (default: a reshuffling cursor). Both are swappable seams. +Returned dicts carry ``env_name`` + ``example_id``. +""" from __future__ import annotations -import random - -from prime_rl.orchestrator.envs import TrainEnvs +from prime_rl.orchestrator.envs import TrainEnv, TrainEnvs +from prime_rl.orchestrator.sampling import WeightedRoundRobin class TrainSource: - """``next_example(available_permits)`` picks a weighted-RR env and - returns its next example (or ``None`` when the env's per-call permit - cost doesn't fit — the dispatch loop retries when permits free up). - Returned dicts carry ``env_name`` + ``example_id``.""" + """``next_example(available_permits)`` picks an env via the mix strategy and + pulls that env's next example from its sampler (or ``None`` when the env's + per-call permit cost doesn't fit — the dispatch loop retries when permits + free up).""" def __init__(self, train_envs: TrainEnvs, *, seed: int | None) -> None: - self.rng = random.Random(seed) self.envs = list(train_envs) if not self.envs: raise ValueError("TrainSource needs at least one train env") + self._envs_by_name = {env.name: env for env in self.envs} - self.examples: dict[str, list[dict]] = {} - self.cursors: dict[str, int] = {} - # Group-scoring envs reserve ``group_size`` permits up front; - # per-rollout envs need 1 + # Build each env's sampler (which owns its dataset) and per-env permit + # cost. Group-scoring envs reserve ``group_size`` permits up front; + # per-rollout envs need 1. Per-env seeds keep distinct envs from + # shuffling in lockstep. self.env_costs: dict[str, int] = {} - for env in self.envs: - rows: list[dict] = [] - for row in env.get_dataset(seed=seed): - ex = dict(row) - ex["env_name"] = env.name - rows.append(ex) - self.rng.shuffle(rows) - self.examples[env.name] = rows - self.cursors[env.name] = 0 + for i, env in enumerate(self.envs): + env.build_sampler(seed=(seed + i) if seed is not None else None) self.env_costs[env.name] = env.config.group_size if env.requires_group_scoring else 1 - self.env_names = [e.name for e in self.envs] - configured_ratios = [e.config.ratio for e in self.envs] + env_names = [env.name for env in self.envs] + configured_ratios = [env.config.ratio for env in self.envs] if all(r is not None for r in configured_ratios): - self.weights: list[float] = [float(r) for r in configured_ratios] # type: ignore[arg-type] + weights = [float(r) for r in configured_ratios] # type: ignore[arg-type] else: - self.weights = [float(len(self.examples[name])) for name in self.env_names] + weights = [float(self._dataset_size(env)) for env in self.envs] + self.env_mix = WeightedRoundRobin(env_names, weights, seed=seed) + + @staticmethod + def _dataset_size(env: TrainEnv) -> int: + size = getattr(env.sampler, "dataset_size", None) + if size is None: + raise ValueError( + f"Env {env.name!r} sampler exposes no dataset_size; set explicit per-env ratios to weight the env mix." + ) + return size def next_example(self, available_permits: int) -> dict | None: - env_name = self.rng.choices(self.env_names, weights=self.weights, k=1)[0] + env_name = self.env_mix.pick() if self.env_costs[env_name] > available_permits: return None - rows = self.examples[env_name] - cursor = self.cursors[env_name] - if cursor >= len(rows): - self.rng.shuffle(rows) - cursor = 0 - example = rows[cursor] - self.cursors[env_name] = cursor + 1 - return example + sampler = self._envs_by_name[env_name].sampler + assert sampler is not None # built in __init__ + return sampler.next() diff --git a/tests/unit/orchestrator/test_sampling.py b/tests/unit/orchestrator/test_sampling.py new file mode 100644 index 0000000000..b6b6944907 --- /dev/null +++ b/tests/unit/orchestrator/test_sampling.py @@ -0,0 +1,60 @@ +from collections import Counter + +import pytest + +from prime_rl.orchestrator.sampling import ( + ShuffledCursorSampler, + WeightedRoundRobin, +) + + +def _rows(n: int) -> list[dict]: + return [{"example_id": i, "env_name": "e"} for i in range(n)] + + +def test_shuffled_cursor_cycles_without_repeats_then_reshuffles(): + sampler = ShuffledCursorSampler(_rows(5), seed=42) + cycle1 = [sampler.next()["example_id"] for _ in range(5)] + cycle2 = [sampler.next()["example_id"] for _ in range(5)] + # Each cycle visits every example exactly once (cursor), then reshuffles. + assert sorted(cycle1) == list(range(5)) + assert sorted(cycle2) == list(range(5)) + + +def test_shuffled_cursor_dataset_size(): + assert ShuffledCursorSampler(_rows(7), seed=0).dataset_size == 7 + + +def test_shuffled_cursor_is_deterministic_per_seed(): + a = ShuffledCursorSampler(_rows(8), seed=123) + b = ShuffledCursorSampler(_rows(8), seed=123) + assert [a.next()["example_id"] for _ in range(8)] == [b.next()["example_id"] for _ in range(8)] + + +def test_shuffled_cursor_empty_raises(): + with pytest.raises(ValueError): + ShuffledCursorSampler([], seed=0) + + +def test_observe_default_is_noop(): + sampler = ShuffledCursorSampler(_rows(3), seed=0) + # Default observe accepts a (possibly empty) group and does nothing observable. + assert sampler.observe([]) is None + + +def test_weighted_round_robin_honors_weights(): + mix = WeightedRoundRobin(["A", "B"], [1.0, 3.0], seed=0) + counts = Counter(mix.pick() for _ in range(4000)) + # B weighted 3x A — expect roughly a 3:1 split. + assert 2.5 < counts["B"] / counts["A"] < 3.5 + + +def test_weighted_round_robin_is_deterministic_per_seed(): + a = WeightedRoundRobin(["A", "B", "C"], [1.0, 1.0, 1.0], seed=7) + b = WeightedRoundRobin(["A", "B", "C"], [1.0, 1.0, 1.0], seed=7) + assert [a.pick() for _ in range(20)] == [b.pick() for _ in range(20)] + + +def test_weighted_round_robin_empty_raises(): + with pytest.raises(ValueError): + WeightedRoundRobin([], [], seed=0)