4 changes: 4 additions & 0 deletions configs/debug/training_modes/README.md
@@ -10,6 +10,7 @@ Minimal end-to-end configs for the three training modes (`rl` / `opd` / `sft`) a
| `sft.toml` | `sft` | local vLLM (`Qwen3-0.6B-Reverse-Text-RL`) | |
| `sft_lora.toml` | `sft` | local vLLM (`Qwen3-0.6B-Reverse-Text-RL`) | trains a LoRA adapter (rank 8) |
| `sft_external.toml` | `sft` | PI inference (`openai/gpt-5-mini`) | external OAI endpoint; no local teacher |
| `sft_replay.toml` | `sft` | none | replays saved message traces through `sft-replay` |

The student inference server is auto-launched on GPU 0 at `http://localhost:8000/v1` with `gpu_memory_utilization=0.5`. The local teacher (used by everything except `rl.toml`, `sft_external.toml`, and `sft_replay.toml`) is **not** auto-launched; start it manually on GPU 1.

@@ -42,6 +43,9 @@ uv run rl @ configs/debug/training_modes/sft_lora.toml
# SFT hard distill from openai/gpt-5-mini via PI inference
# (requires PRIME_API_KEY + PRIME_TEAM_ID in env; no local teacher needed)
uv run rl @ configs/debug/training_modes/sft_external.toml

# SFT from replayed dataset traces (no teacher)
uv run rl @ configs/debug/training_modes/sft_replay.toml
```

See [docs/training.md](../../docs/training.md#training-modes-rl--opd--sft-via-orchestrator) for what each mode does.
41 changes: 41 additions & 0 deletions configs/debug/training_modes/sft_replay.toml
@@ -0,0 +1,41 @@
# Static trace SFT through the RL orchestrator. No teacher server is needed:
# sft-replay turns dataset message rows into replayed rollout trajectories.

max_steps = 20
seq_len = 2048

[model]
name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT"

[wandb]
project = "reverse-text-debug"
name = "debug-sft-replay"

[orchestrator]
training_mode = "sft"
batch_size = 128
group_size = 1

[[orchestrator.train.env]]
id = "sft-replay"

[orchestrator.train.env.args.taskset]
dataset = "PrimeIntellect/Reverse-Text-SFT"

[orchestrator.eval]
interval = 1
num_examples = 128

[orchestrator.eval.sampling]
max_completion_tokens = 128

[[orchestrator.eval.env]]
id = "reverse-text"

[trainer.optim]
lr = 3e-6

[ckpt]

[inference]
gpu_memory_utilization = 0.5
2 changes: 1 addition & 1 deletion deps/verifiers
Submodule verifiers updated 137 files
8 changes: 5 additions & 3 deletions docs/training.md
@@ -89,16 +89,18 @@ The RL entrypoint supports three training modes, switched via `orchestrator.trai
|---|---|---|---|
| `rl` | Required | Forbidden | Standard RL |
| `opd` | Required | Required, must be vLLM (needs `prompt_logprobs`) | [On-policy distillation](https://thinkingmachines.ai/blog/on-policy-distillation/): student generates rollouts, trainer minimizes KL to teacher logprobs |
| `sft` | Required | Required, any OpenAI-compatible endpoint | Hard-distill: teacher generates rollouts, student trains on them |
| `sft` | Required | Optional | Hard-distill from teacher-generated rollouts, or train from replayed message traces via `sft-replay` |

The `rl` entrypoint only manages student-policy inference. For OPD and (local-vLLM) SFT, start the teacher inference server manually and point `[orchestrator.teacher.client]` at it:
The `rl` entrypoint only manages student-policy inference. For OPD and teacher-backed local-vLLM SFT, start the teacher inference server manually and point `[orchestrator.teacher.client]` at it:

```bash
CUDA_VISIBLE_DEVICES=1 uv run inference \
--model.name <teacher> --server.port 8001
```

The standalone `uv run sft` entrypoint is the more traditional SFT path — pure dataset-based, no teacher, no orchestrator. Use `orchestrator.training_mode = "sft"` only when you want a teacher to generate the supervision on the fly.
Teacherless orchestrator SFT is valid only when every train env is `sft-replay` and each env config sets `args.taskset.dataset` (or `args.config.taskset.dataset`). In that path, the env replays stored assistant messages into trajectories without model calls, then prime-rl tokenizes them for the trainer.
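
The rule above can be sketched as a standalone check. This is a minimal illustration that mirrors the two config helpers added in this change (`_is_sft_replay_env_id` and `_sft_replay_dataset_arg`); the `check_teacherless_sft` function and the plain-dict env shape are assumptions made for the example, not part of the prime-rl API.

```python
from __future__ import annotations

SFT_REPLAY_ENV_ID = "sft-replay"


def is_sft_replay_env_id(env_id: str) -> bool:
    # Ignore anything after "@", then accept a bare or "/"-namespaced id.
    stripped = env_id.split("@")[0]
    return stripped == SFT_REPLAY_ENV_ID or stripped.endswith(f"/{SFT_REPLAY_ENV_ID}")


def replay_dataset_arg(env_args: dict) -> object | None:
    # Look under args.taskset first, then under args.config.taskset.
    for holder in (env_args, env_args.get("config")):
        if isinstance(holder, dict):
            taskset = holder.get("taskset")
            if isinstance(taskset, dict) and taskset.get("dataset"):
                return taskset["dataset"]
    return None


def check_teacherless_sft(envs: list[dict]) -> None:
    """Hypothetical helper: raise ValueError if teacherless SFT is misconfigured."""
    non_replay = [e["id"] for e in envs if not is_sft_replay_env_id(e["id"])]
    if non_replay:
        raise ValueError(f"non-replay train env(s): {non_replay}")
    missing = [e["id"] for e in envs if replay_dataset_arg(e.get("args", {})) is None]
    if missing:
        raise ValueError(f"missing taskset.dataset for: {missing}")


ok_envs = [
    {"id": "sft-replay", "args": {"taskset": {"dataset": "PrimeIntellect/Reverse-Text-SFT"}}},
]
check_teacherless_sft(ok_envs)  # passes silently
```

Mixing a replay env with any other train env (for example `reverse-text`) fails the first check, so a teacherless run cannot silently fall back to student-generated rollouts.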

The standalone `uv run sft` entrypoint is the more traditional SFT path — pure dataset-based, no teacher, no orchestrator. Use `orchestrator.training_mode = "sft"` when you want teacher-generated supervision or replayed env trajectories inside the RL orchestrator/eval pipeline.

### Important Metrics

65 changes: 59 additions & 6 deletions packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py
@@ -20,6 +20,31 @@
from prime_rl.configs.trainer import TokenizerConfig
from prime_rl.utils.config import BaseConfig

SFT_REPLAY_ENV_ID = "sft-replay"


def _is_sft_replay_env_id(env_id: str) -> bool:
    stripped = env_id.split("@")[0]
    return stripped == SFT_REPLAY_ENV_ID or stripped.endswith(f"/{SFT_REPLAY_ENV_ID}")


def _sft_replay_dataset_arg(env_args: dict) -> object | None:
    taskset = env_args.get("taskset")
    if isinstance(taskset, dict):
        dataset = taskset.get("dataset")
        if dataset:
            return dataset

    config = env_args.get("config")
    if isinstance(config, dict):
        taskset = config.get("taskset")
        if isinstance(taskset, dict):
            dataset = taskset.get("dataset")
            if dataset:
                return dataset

    return None


class OptimizerConfig(BaseConfig):
    lr: float = Field(1e-4, ge=0)
@@ -501,13 +526,13 @@ class RolloutModelConfig(BaseConfig):

class OrchestratorConfig(BaseConfig):
    training_mode: Literal["rl", "opd", "sft"] = "rl"
    """Training mode. ``rl``: student generates rollouts, no teacher. ``opd``: student generates rollouts, teacher computes logprobs (teacher_tau > 0). ``sft``: teacher generates rollouts, student inference pool used for evals and weight sync."""
    """Training mode. ``rl``: student generates rollouts, no teacher. ``opd``: student generates rollouts, teacher computes logprobs (teacher_tau > 0). ``sft``: teacher generates rollouts when configured, otherwise train envs must provide replayed traces."""

    student: RolloutModelConfig = Field(RolloutModelConfig(), validation_alias=AliasChoices("student", "model"))
    """Student rollout participant (model + client) — the model being trained."""

    teacher: RolloutModelConfig | None = Field(None, validation_alias=AliasChoices("teacher", "teacher_model"))
    """Teacher rollout participant (model + client). Role depends on ``training_mode``: ``opd`` — teacher computes logprobs; ``sft`` — teacher generates rollouts."""
    """Teacher rollout participant (model + client). Role depends on ``training_mode``: ``opd`` — teacher computes logprobs; ``sft`` — teacher generates rollouts when configured."""

    train: TrainConfig = TrainConfig()

@@ -752,10 +777,16 @@ def validate_unique_filter_types(self):
            )
        return self

    @model_validator(mode="after")
    def _drop_default_sft_zero_advantage_filter(self):
        if self.training_mode == "sft" and "post_batch_filters" not in self.model_fields_set:
            self.post_batch_filters = [f for f in self.post_batch_filters if f.type != "zero_advantage"]
        return self

    @model_validator(mode="after")
    def _force_no_renderer_for_sft(self):
        """SFT rolls out via the teacher's plain chat-completions endpoint; the
        renderer client doesn't apply. Force ``renderer=None`` so the user
        """SFT train rollouts use teacher chat completions or replayed traces;
        the renderer client doesn't apply. Force ``renderer=None`` so the user
        doesn't have to remember to set it. Declared before the renderer
        validators below so they see the corrected value."""
        if self.training_mode == "sft":
@@ -768,8 +799,30 @@ def validate_training_mode(self):
        has_teacher = self.teacher is not None
        if self.training_mode == "rl" and has_teacher:
            raise ValueError("orchestrator.teacher must not be set when training_mode = 'rl'.")
        if self.training_mode in ("opd", "sft") and not has_teacher:
            raise ValueError(f"orchestrator.teacher must be configured when training_mode = '{self.training_mode}'.")
        if self.training_mode == "opd" and not has_teacher:
            raise ValueError("orchestrator.teacher must be configured when training_mode = 'opd'.")
        return self

    @model_validator(mode="after")
    def validate_teacherless_sft_uses_sft_replay(self):
        """Teacherless SFT is only valid when train envs replay existing data."""
        if self.training_mode != "sft" or self.teacher is not None:
            return self

        non_replay_envs = [env.id for env in self.train.env if not _is_sft_replay_env_id(env.id)]
        if non_replay_envs:
            raise ValueError(
                "orchestrator.teacher must be configured for SFT unless every train env uses "
                f"{SFT_REPLAY_ENV_ID!r}; got non-replay train env(s): {non_replay_envs}."
            )

        missing_dataset = [env.resolved_name for env in self.train.env if _sft_replay_dataset_arg(env.args) is None]
        if missing_dataset:
            raise ValueError(
                f"teacherless SFT with {SFT_REPLAY_ENV_ID!r} requires an explicit "
                "env.args.taskset.dataset or env.args.config.taskset.dataset for "
                f"each train env; missing for: {missing_dataset}."
            )
        return self

    @model_validator(mode="after")
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -88,6 +88,7 @@ envs = [
    "rlm-swe",
    "science-env",
    "simpleqa-verified",
    "sft-replay",
    "tau2-bench",
    "wiki-search",
    "wordle",
@@ -150,6 +151,8 @@ override-dependencies = [
    "transformers==5.6.2",
    "torch>=2.9.0",
    "openenv-core",
    "harnesses>=0.1.0",
    "tasksets>=0.1.0",
]

# ModelExpress 0.3.0 publishes protobuf<6 metadata, but its generated proto is
@@ -224,6 +227,9 @@ reverse-text = { path = "deps/verifiers/environments/reverse_text", editable = t
rlm-swe = { path = "deps/research-environments/environments/rlm_swe", editable = true }
science-env = { path = "deps/research-environments/environments/science_env", editable = true }
simpleqa-verified = { path = "deps/research-environments/environments/simpleqa_verified", editable = true }
sft-replay = { path = "deps/verifiers/environments/sft_replay", editable = true }
harnesses = { path = "deps/verifiers/packages/harnesses", editable = true }
tasksets = { path = "deps/verifiers/packages/tasksets", editable = true }
tau2-bench = { path = "deps/research-environments/environments/tau2_bench", editable = true }
wiki-search = { path = "deps/verifiers/environments/wiki_search", editable = true }
wordle = { path = "deps/verifiers/environments/wordle", editable = true }
17 changes: 17 additions & 0 deletions skills/configs/SKILL.md
@@ -60,6 +60,23 @@ CLI: `--env.0.id reverse-text --env.1.id math-env`.

In TOML, an empty section header (`[ckpt]`) does the same.

## Replay-backed SFT config

For teacherless SFT through the `rl` orchestrator, set `orchestrator.training_mode = "sft"` and use only `sft-replay` train envs. Each train env must provide the dataset under the taskset config:

```toml
[orchestrator]
training_mode = "sft"

[[orchestrator.train.env]]
id = "sft-replay"

[orchestrator.train.env.args.taskset]
dataset = "PrimeIntellect/Reverse-Text-SFT"
```

Do not pass the dataset at `env.args.dataset`; config validation rejects that shape because replay data belongs to the taskset.

## RL trainer token exports

For rollout debugging, enable trainer-side token export under `trainer.experimental.token_export` (or `experimental.token_export` when running the trainer entrypoint directly). It writes one JSONL record per exported sequence under `output_dir/token_exports/step_<step>/rank_<rank>.jsonl`. Each record stores aligned per-token arrays for token ids, loss mask, advantage, reward, entropy, mismatch KL, inference/trainer logprobs, importance ratios, probability deltas, and masking diagnostics. It does not decode token text in the trainer.
3 changes: 2 additions & 1 deletion skills/training/start-run/SKILL.md
@@ -34,6 +34,7 @@ uv run rl @ examples/reverse_text/rl.toml --dry-run
- Config: `RLConfig` (`packages/prime-rl-configs/src/prime_rl/configs/rl.py`)
- Entrypoint: `src/prime_rl/entrypoints/rl.py`
- SLURM: single- and multi-node
- Training modes: `orchestrator.training_mode = "rl"` (default), `"opd"` (requires teacher), or `"sft"`. SFT can use a configured teacher, or teacherless replay when every train env is `sft-replay` with `args.taskset.dataset` set.
- Environment packages: before launching a config with a non-core verifier env id,
verify the package imports under `uv run` (for example
`uv run python -c "import importlib.util; print(importlib.util.find_spec('rlm_swe'))"`).
@@ -82,7 +83,7 @@ curl http://localhost:8000/v1/chat/completions \

| Command | Purpose | Typical use |
|---------|---------|-------------|
| `rl` | Full RL pipeline | Production RL training |
| `rl` | Full orchestrator pipeline | RL, OPD, and orchestrator-backed SFT |
| `sft` | Supervised fine-tuning | SFT and hard-distill |
| `inference` | vLLM server | Standalone serving / debugging |

11 changes: 6 additions & 5 deletions src/prime_rl/orchestrator/dispatcher.py
@@ -135,8 +135,9 @@ def __init__(
        self.policy = policy
        self.train_envs = train_envs
        self.eval_envs = eval_envs
        # Train rollouts go to ``inference`` (the teacher in SFT mode);
        # eval always evaluates the student, so it uses ``eval_inference``.
        # Train rollouts go to ``inference`` (teacher in teacher-SFT, student
        # otherwise); eval always evaluates the student, so it uses
        # ``eval_inference``.
        self.inference = inference
        self.eval_inference = eval_inference
        self.train_source = train_source
@@ -173,9 +174,9 @@ def __init__(

    @property
    def train_model_name(self) -> str:
        """Model name for *train* rollouts. In SFT mode train data comes from
        the teacher pool, so use its model name; otherwise the live student
        policy. (Eval always uses ``policy.model_name`` — the student.)"""
        """Model name for *train* rollouts. Teacher-SFT uses the teacher pool
        name; replay SFT receives the student name but ignores it. Eval always
        uses ``policy.model_name`` — the student."""
        if self.training_mode == "sft":
            return self.inference.model_name
        return self.policy.model_name
8 changes: 4 additions & 4 deletions src/prime_rl/orchestrator/orchestrator.py
@@ -345,10 +345,10 @@ async def setup(self) -> None:
        else:
            get_logger().info("Training from scratch")

        # SFT generates rollouts via the teacher (the student is trained on
        # the teacher's outputs); RL / OPD generate via the student
        if config.training_mode == "sft":
            assert self.teacher_inference is not None, "sft mode requires teacher inference"
        # SFT train rollouts come from the teacher when configured. Teacherless
        # SFT is validated at config parse time to use replay envs, which ignore
        # the client/model passed by the dispatcher.
        if config.training_mode == "sft" and self.teacher_inference is not None:
            rollout_inference = self.teacher_inference
        else:
            rollout_inference = self.student_inference
25 changes: 25 additions & 0 deletions src/prime_rl/orchestrator/trajectories.py
@@ -164,6 +164,24 @@ def _tokenize_step_with_renderer(
    return build_trajectory_step(renderer, prompt, completion, tools=tools)


def _set_token_usage_from_trajectory(output: vf.RolloutOutput) -> None:
    trajectory = output.get("trajectory") or []
    tokenized_steps = [step for step in trajectory if step.get("tokens") is not None]
    if not tokenized_steps:
        return

    prompt_tokens = [len(step["tokens"]["prompt_ids"]) for step in tokenized_steps]
    completion_tokens = [len(step["tokens"]["completion_ids"]) for step in tokenized_steps]
    total_completion = sum(completion_tokens)
    last_total = prompt_tokens[-1] + completion_tokens[-1]
    output["token_usage"] = {
        "input_tokens": float(sum(prompt_tokens)),
        "output_tokens": float(total_completion),
        "final_input_tokens": float(max(0, last_total - total_completion)),
        "final_output_tokens": float(total_completion),
    }


def backfill_rollout_tokens(
    output: vf.RolloutOutput,
    tokenizer: PreTrainedTokenizer,
@@ -175,6 +193,9 @@ def backfill_rollout_tokens(
    Otherwise falls back to the tokenizer + apply_chat_template path.
    """
    if all(step["tokens"] is not None for step in output["trajectory"]):
        token_usage = output.get("token_usage") or {}
        if "final_input_tokens" not in token_usage or "final_output_tokens" not in token_usage:
            _set_token_usage_from_trajectory(output)
        return True

    logger = get_logger()
@@ -198,6 +219,10 @@ def backfill_rollout_tokens(
        reconstructed.pop("original_prompt_len")
        step["tokens"] = reconstructed

    token_usage = output.get("token_usage") or {}
    if "final_input_tokens" not in token_usage or "final_output_tokens" not in token_usage:
        _set_token_usage_from_trajectory(output)

    return True
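
The token-usage arithmetic in `_set_token_usage_from_trajectory` is easier to follow on a small worked case. The sketch below repeats the same formulas on plain dicts; the function name `token_usage_from_trajectory` and the two-step trajectory are made up for illustration, and returning the dict (rather than mutating a `vf.RolloutOutput`) is a simplification.

```python
from __future__ import annotations


def token_usage_from_trajectory(trajectory: list[dict]) -> dict[str, float] | None:
    """Illustrative copy of the arithmetic; returns None when no step is tokenized."""
    steps = [s for s in trajectory if s.get("tokens") is not None]
    if not steps:
        return None
    prompt = [len(s["tokens"]["prompt_ids"]) for s in steps]
    completion = [len(s["tokens"]["completion_ids"]) for s in steps]
    total_completion = sum(completion)
    # Length of the final step's full sequence (its prompt plus its completion).
    last_total = prompt[-1] + completion[-1]
    return {
        "input_tokens": float(sum(prompt)),
        "output_tokens": float(total_completion),
        "final_input_tokens": float(max(0, last_total - total_completion)),
        "final_output_tokens": float(total_completion),
    }


# Two steps: prompts of 10 and 18 tokens, completions of 5 and 7 tokens.
trajectory = [
    {"tokens": {"prompt_ids": list(range(10)), "completion_ids": list(range(5))}},
    {"tokens": {"prompt_ids": list(range(18)), "completion_ids": list(range(7))}},
]
usage = token_usage_from_trajectory(trajectory)
print(usage)
# input_tokens = 10 + 18 = 28, output_tokens = 5 + 7 = 12,
# final_input_tokens = max(0, (18 + 7) - 12) = 13, final_output_tokens = 12
```

In this example the final step's sequence is 25 tokens, so subtracting all 12 completion tokens leaves 13 tokens counted as final input.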

