Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/prime_rl/orchestrator/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from prime_rl.configs.orchestrator import EnvConfig, EvalEnvConfig, TrainEnvConfig
from prime_rl.orchestrator.eval_utils import compute_pass_at_k
from prime_rl.orchestrator.sampling import SampleStrategy, ShuffledCursorSampler
from prime_rl.utils.logger import ProgressTracker, get_logger
from prime_rl.utils.monitor import get_monitor
from prime_rl.utils.utils import capitalize
Expand Down Expand Up @@ -170,10 +171,24 @@ class TrainEnv(Env):
def __init__(self, config: TrainEnvConfig):
    """Initialise the train env from ``config``.

    On top of the base-class initialisation this caches the sampling
    arguments derived from ``config.sampling`` and reserves the ``sampler``
    slot, which stays empty until ``build_sampler`` runs.
    """
    super().__init__(config)
    sampling_cfg = config.sampling
    self.sampling_args = sampling_cfg.to_sampling_args()
    # Filled in by ``build_sampler`` (TrainSource calls it at setup). The
    # sampler owns this env's dataset and selection state, and the sink
    # reaches it through this attribute to call ``observe``.
    self.sampler: SampleStrategy | None = None

def get_dataset(self, seed: int | None = None):
    """Return the wrapped env's dataset, forwarding ``seed`` unchanged."""
    dataset = self.env.get_dataset(seed=seed)
    return dataset

def build_sampler(self, *, seed: int | None) -> None:
    """Load this env's dataset and install its default ``SampleStrategy``.

    Every dataset row is copied into a fresh dict and stamped with
    ``env_name`` (``example_id`` comes from the dataset). The rows are then
    handed to a ``ShuffledCursorSampler`` that is stored on ``self.sampler``.

    Args:
        seed: Forwarded both to ``get_dataset`` and to the sampler.

    Raises:
        ValueError: If the dataset yields no rows. Raised here, with the env
            name in the message, so that a multi-env run reports *which* env
            is empty instead of a generic sampler error.
    """
    # ``dict(row, env_name=...)`` copies the row before stamping it, so the
    # dataset's own row objects are never mutated.
    rows: list[dict] = [dict(row, env_name=self.name) for row in self.get_dataset(seed=seed)]
    if not rows:
        raise ValueError(f"Train env {self.name!r} has an empty dataset; cannot build a sampler for it")
    self.sampler = ShuffledCursorSampler(rows, seed=seed)


class EvalEnv(Env):
config: EvalEnvConfig
Expand Down
89 changes: 89 additions & 0 deletions src/prime_rl/orchestrator/sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Sampling strategies for training rollouts.

Two seams sit between the train envs and the dispatcher:

- ``EnvMixStrategy`` (global) — decides *which* env to draw from next.
- ``SampleStrategy`` (per-env) — decides *what* example that env serves next,
and (via ``observe``) can learn from finished, scored groups. Each env owns
its own ``SampleStrategy`` instance, so it can hold dataset + per-env state
(cursor today; curriculum / replay buffers later).

The defaults (``WeightedRoundRobin`` + ``ShuffledCursorSampler``) reproduce the
previous ``TrainSource`` behavior: a seeded, weighted random choice of env on
every draw (the "weighted round-robin" name is historical), over per-env
datasets that are each shuffled once and walked with a reshuffling cursor.
"""

from __future__ import annotations

import random
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from prime_rl.orchestrator.types import TrainRollout


class SampleStrategy(ABC):
    """Per-env policy deciding which example is served next.

    Each env keeps one stateful instance alive for the whole run. Subclasses
    must implement ``next``; ``observe`` is an optional feedback hook that
    stateful strategies can override.
    """

    @abstractmethod
    def next(self) -> dict:
        """Return the next example dict for this env.

        The dict carries ``env_name`` and ``example_id``.
        """

    def observe(self, group: list[TrainRollout]) -> None:
        """Receive one finished, scored group of this env's rollouts.

        Invoked after advantages have been assigned. The base implementation
        ignores the group; curriculum- or replay-style strategies override it
        to learn from outcomes.
        """
        return None


class ShuffledCursorSampler(SampleStrategy):
    """Default per-env sampler: an endless, reshuffling walk over the rows.

    The rows are copied and shuffled once at construction. ``next`` then
    serves them in order and reshuffles in place whenever the cursor runs off
    the end, so the sampler never exhausts.
    """

    def __init__(self, rows: list[dict], *, seed: int | None) -> None:
        """Copy ``rows``, shuffle the copy with an RNG seeded by ``seed``.

        Raises:
            ValueError: If ``rows`` is empty.
        """
        if not rows:
            raise ValueError("ShuffledCursorSampler needs at least one example")
        # Work on a private copy so the caller's list keeps its order.
        self._rows = list(rows)
        self._rng = random.Random(seed)
        self._rng.shuffle(self._rows)
        self._cursor = 0

    @property
    def dataset_size(self) -> int:
        """Number of rows this sampler cycles over."""
        return len(self._rows)

    def next(self) -> dict:
        """Return the row under the cursor and advance, reshuffling on wrap."""
        rows = self._rows
        if self._cursor >= len(rows):
            # Pass finished: draw a new order and restart from the front.
            self._rng.shuffle(rows)
            self._cursor = 0
        index = self._cursor
        self._cursor = index + 1
        return rows[index]


class EnvMixStrategy(ABC):
    """Global policy choosing which env the next example is drawn from."""

    @abstractmethod
    def pick(self) -> str:
        """Return the name of the env to sample from next."""


class WeightedRoundRobin(EnvMixStrategy):
    """Default env mix: seeded, weighted random choice over env names.

    Despite the name, envs are not visited in a fixed rotation: every
    ``pick`` is an independent weighted draw. Callers pass either the
    configured per-env ratios (when all are set) or the per-env dataset sizes
    as weights.
    """

    def __init__(self, env_names: list[str], weights: list[float], *, seed: int | None) -> None:
        """Store the env names and weights and seed the private RNG.

        Args:
            env_names: Names ``pick`` may return; must be non-empty.
            weights: One non-negative weight per env name, in the same order.
                Zero is allowed for individual envs as long as the total is
                positive.
            seed: Seed for the private ``random.Random``.

        Raises:
            ValueError: If ``env_names`` is empty, the two lists differ in
                length, any weight is negative, or the weights do not sum to
                a positive finite number.
        """
        if not env_names:
            raise ValueError("WeightedRoundRobin needs at least one env")
        self._env_names = list(env_names)
        self._weights = list(weights)
        # Validate here rather than letting ``random.choices`` fail on the
        # first ``pick`` in the middle of the dispatch loop.
        if len(self._weights) != len(self._env_names):
            raise ValueError(
                f"WeightedRoundRobin got {len(self._weights)} weights for {len(self._env_names)} envs"
            )
        # ``random.choices`` assumes non-negative weights and does not check.
        if any(weight < 0 for weight in self._weights):
            raise ValueError("WeightedRoundRobin weights must be non-negative")
        total = sum(self._weights)
        # The chained comparison is also False for NaN, so NaN weights are
        # rejected along with an all-zero or infinite total.
        if not 0 < total < float("inf"):
            raise ValueError("WeightedRoundRobin weights must sum to a positive, finite value")
        self._rng = random.Random(seed)

    def pick(self) -> str:
        """Return one env name drawn with probability proportional to its weight."""
        return self._rng.choices(self._env_names, weights=self._weights, k=1)[0]
7 changes: 7 additions & 0 deletions src/prime_rl/orchestrator/train_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,13 @@ def process_group(self, group_id: uuid.UUID) -> None:

assign_advantages(survivors, self.advantage_fns[env_name])

# Feedback hook: let this env's sampler learn from the finished, scored
# group (advantages now assigned). No-op for the default cursor sampler;
# curriculum / replay samplers override ``observe``.
sampler = self.train_envs.get(env_name).sampler
if sampler is not None:
sampler.observe(survivors)

# Propagate to the pre-tokenized samples so the orchestrator can
# collect samples at ship time without re-walking rollouts. The env
# has a single sampling temperature; fan it out across each sample's
Expand Down
75 changes: 38 additions & 37 deletions src/prime_rl/orchestrator/train_source.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,60 @@
"""TrainSource: weighted round-robin across train envs, infinite pull.
"""TrainSource: weighted env mix over per-env samplers, infinite pull.

Weights default to configured ``ratio`` (when every env sets one) or to
per-env dataset size. ``next_example`` reshuffles on cursor exhaustion."""
Env selection is delegated to an ``EnvMixStrategy`` (default: weighted
round-robin by configured ``ratio`` when all envs set one, else by per-env
dataset size); example selection within an env is delegated to that env's
``SampleStrategy`` (default: a reshuffling cursor). Both are swappable seams.
Returned dicts carry ``env_name`` + ``example_id``.
"""

from __future__ import annotations

import random

from prime_rl.orchestrator.envs import TrainEnvs
from prime_rl.orchestrator.envs import TrainEnv, TrainEnvs
from prime_rl.orchestrator.sampling import WeightedRoundRobin


class TrainSource:
    """Infinite source of training examples drawn from a weighted mix of envs.

    ``next_example(available_permits)`` picks an env via the mix strategy and
    pulls that env's next example from its sampler (or returns ``None`` when
    the env's per-call permit cost doesn't fit — the dispatch loop retries
    when permits free up).
    """

    def __init__(self, train_envs: TrainEnvs, *, seed: int | None) -> None:
        """Build every env's sampler, its permit cost, and the env mix.

        Args:
            train_envs: Iterable of train envs; must yield at least one.
            seed: Base seed. Env ``i`` builds its sampler with ``seed + i``
                and the env mix is seeded with ``seed``; ``None`` is passed
                through to all of them.

        Raises:
            ValueError: If there are no envs, or if dataset-size weighting is
                needed but a sampler exposes no ``dataset_size``.
        """
        self.envs = list(train_envs)
        if not self.envs:
            raise ValueError("TrainSource needs at least one train env")
        self._envs_by_name = {env.name: env for env in self.envs}

        # Build each env's sampler (which owns its dataset) and per-env permit
        # cost. Group-scoring envs reserve ``group_size`` permits up front;
        # per-rollout envs need 1. Per-env seeds keep distinct envs from
        # shuffling in lockstep.
        self.env_costs: dict[str, int] = {}
        for i, env in enumerate(self.envs):
            env.build_sampler(seed=(seed + i) if seed is not None else None)
            self.env_costs[env.name] = env.config.group_size if env.requires_group_scoring else 1

        env_names = [env.name for env in self.envs]
        configured_ratios = [env.config.ratio for env in self.envs]
        if all(r is not None for r in configured_ratios):
            weights = [float(r) for r in configured_ratios]  # type: ignore[arg-type]
        else:
            # Without a complete set of ratios, weight envs by dataset size.
            weights = [float(self._dataset_size(env)) for env in self.envs]
        self.env_mix = WeightedRoundRobin(env_names, weights, seed=seed)

    @staticmethod
    def _dataset_size(env: TrainEnv) -> int:
        """Return the size reported by ``env.sampler.dataset_size``.

        Raises:
            ValueError: If the sampler has no ``dataset_size`` (or it is
                ``None``), since the env then cannot be weighted by size.
        """
        size = getattr(env.sampler, "dataset_size", None)
        if size is None:
            raise ValueError(
                f"Env {env.name!r} sampler exposes no dataset_size; set explicit per-env ratios to weight the env mix."
            )
        return size

    def next_example(self, available_permits: int) -> dict | None:
        """Pick an env and return its next example, or ``None`` if it can't fit.

        The env is always drawn first; when its permit cost exceeds
        ``available_permits`` the draw is dropped and ``None`` is returned
        without touching that env's sampler.
        """
        env_name = self.env_mix.pick()
        if self.env_costs[env_name] > available_permits:
            return None
        sampler = self._envs_by_name[env_name].sampler
        if sampler is None:
            # Samplers are built in ``__init__``; an explicit raise (unlike
            # ``assert``) still guards this invariant under ``python -O``.
            raise RuntimeError(f"Env {env_name!r} has no sampler; build_sampler was not called")
        return sampler.next()
60 changes: 60 additions & 0 deletions tests/unit/orchestrator/test_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from collections import Counter

import pytest

from prime_rl.orchestrator.sampling import (
ShuffledCursorSampler,
WeightedRoundRobin,
)


def _rows(n: int) -> list[dict]:
    """Build ``n`` minimal example rows with ids ``0..n-1`` for env ``"e"``."""
    rows: list[dict] = []
    for example_id in range(n):
        rows.append({"example_id": example_id, "env_name": "e"})
    return rows


def test_shuffled_cursor_cycles_without_repeats_then_reshuffles():
    """Two consecutive passes each serve every example exactly once."""
    size = 5
    sampler = ShuffledCursorSampler(_rows(size), seed=42)
    every_id = list(range(size))
    for _ in range(2):
        # One full walk of the cursor; running off the end makes the sampler
        # reshuffle before the following pass.
        served = sorted(sampler.next()["example_id"] for _ in range(size))
        assert served == every_id


def test_shuffled_cursor_dataset_size():
    """``dataset_size`` reports the number of rows given to the sampler."""
    sampler = ShuffledCursorSampler(_rows(7), seed=0)
    assert sampler.dataset_size == 7


def test_shuffled_cursor_is_deterministic_per_seed():
    """Two samplers built with the same rows and seed serve the same order."""
    first = ShuffledCursorSampler(_rows(8), seed=123)
    second = ShuffledCursorSampler(_rows(8), seed=123)
    first_ids = [first.next()["example_id"] for _ in range(8)]
    second_ids = [second.next()["example_id"] for _ in range(8)]
    assert first_ids == second_ids


def test_shuffled_cursor_empty_raises():
    """Constructing the sampler with no rows is rejected with ``ValueError``."""
    no_rows: list[dict] = []
    with pytest.raises(ValueError):
        ShuffledCursorSampler(no_rows, seed=0)


def test_observe_default_is_noop():
    """The inherited ``observe`` accepts an empty group and returns ``None``."""
    sampler = ShuffledCursorSampler(_rows(3), seed=0)
    outcome = sampler.observe([])
    assert outcome is None


def test_weighted_round_robin_honors_weights():
    """With weights 1:3 the observed B/A pick ratio lands near 3."""
    mix = WeightedRoundRobin(["A", "B"], [1.0, 3.0], seed=0)
    tally = Counter()
    for _ in range(4000):
        tally[mix.pick()] += 1
    # B is weighted 3x A, so the split should be roughly 3:1.
    ratio = tally["B"] / tally["A"]
    assert 2.5 < ratio < 3.5


def test_weighted_round_robin_is_deterministic_per_seed():
    """Two mixes built with the same names, weights and seed pick identically."""
    names = ["A", "B", "C"]
    first = WeightedRoundRobin(list(names), [1.0, 1.0, 1.0], seed=7)
    second = WeightedRoundRobin(list(names), [1.0, 1.0, 1.0], seed=7)
    first_picks = [first.pick() for _ in range(20)]
    second_picks = [second.pick() for _ in range(20)]
    assert first_picks == second_picks


def test_weighted_round_robin_empty_raises():
    """Constructing the mix with no envs is rejected with ``ValueError``."""
    no_envs: list[str] = []
    no_weights: list[float] = []
    with pytest.raises(ValueError):
        WeightedRoundRobin(no_envs, no_weights, seed=0)
Loading