Skip to content

Commit 8208ee0

Browse files
committed
Add replay-backed SFT path
1 parent 65914f9 commit 8208ee0

11 files changed

Lines changed: 350 additions & 16 deletions

File tree

configs/debug/training_modes/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Minimal end-to-end configs for the three training modes (`rl` / `opd` / `sft`) a
1010
| `sft.toml` | `sft` | local vLLM (`Qwen3-0.6B-Reverse-Text-RL`) | |
1111
| `sft_lora.toml` | `sft` | local vLLM (`Qwen3-0.6B-Reverse-Text-RL`) | trains a LoRA adapter (rank 8) |
1212
| `sft_external.toml` | `sft` | PI inference (`openai/gpt-5-mini`) | external OAI endpoint; no local teacher |
13+
| `sft_replay.toml` | `sft` | none | replays saved message traces through `sft-replay` |
1314

1415
The student inference server is auto-launched on GPU 0 at `http://localhost:8000/v1` with `gpu_memory_utilization=0.5`. The local teacher (used by everything except `rl.toml` and `sft_external.toml`) is **not** auto-launched — start it manually on GPU 1.
1516

@@ -42,6 +43,9 @@ uv run rl @ configs/debug/training_modes/sft_lora.toml
4243
# SFT hard distill from openai/gpt-5-mini via PI inference
4344
# (requires PRIME_API_KEY + PRIME_TEAM_ID in env; no local teacher needed)
4445
uv run rl @ configs/debug/training_modes/sft_external.toml
46+
47+
# SFT from replayed dataset traces (no teacher)
48+
uv run rl @ configs/debug/training_modes/sft_replay.toml
4549
```
4650

4751
See [docs/training.md](../../docs/training.md#training-modes-rl--opd--sft-via-orchestrator) for what each mode does.
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Static trace SFT through the RL orchestrator. No teacher server is needed:
2+
# sft-replay turns dataset message rows into replayed rollout trajectories.
3+
4+
max_steps = 20
5+
seq_len = 2048
6+
7+
[model]
8+
name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT"
9+
10+
[wandb]
11+
project = "reverse-text-debug"
12+
name = "debug-sft-replay"
13+
14+
[orchestrator]
15+
training_mode = "sft"
16+
batch_size = 128
17+
group_size = 1
18+
19+
[[orchestrator.train.env]]
20+
id = "sft-replay"
21+
22+
[orchestrator.train.env.args.taskset]
23+
dataset = "PrimeIntellect/Reverse-Text-SFT"
24+
25+
[orchestrator.eval]
26+
interval = 1
27+
num_examples = 128
28+
29+
[orchestrator.eval.sampling]
30+
max_completion_tokens = 128
31+
32+
[[orchestrator.eval.env]]
33+
id = "reverse-text"
34+
35+
[trainer.optim]
36+
lr = 3e-6
37+
38+
[ckpt]
39+
40+
[inference]
41+
gpu_memory_utilization = 0.5

deps/verifiers

Submodule verifiers updated 124 files

packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,31 @@
2020
from prime_rl.configs.trainer import TokenizerConfig
2121
from prime_rl.utils.config import BaseConfig
2222

23+
SFT_REPLAY_ENV_ID = "sft-replay"
24+
25+
26+
def _is_sft_replay_env_id(env_id: str) -> bool:
27+
stripped = env_id.split("@")[0]
28+
return stripped == SFT_REPLAY_ENV_ID or stripped.endswith(f"/{SFT_REPLAY_ENV_ID}")
29+
30+
31+
def _sft_replay_dataset_arg(env_args: dict) -> object | None:
32+
taskset = env_args.get("taskset")
33+
if isinstance(taskset, dict):
34+
dataset = taskset.get("dataset")
35+
if dataset:
36+
return dataset
37+
38+
config = env_args.get("config")
39+
if isinstance(config, dict):
40+
taskset = config.get("taskset")
41+
if isinstance(taskset, dict):
42+
dataset = taskset.get("dataset")
43+
if dataset:
44+
return dataset
45+
46+
return None
47+
2348

2449
class OptimizerConfig(BaseConfig):
2550
lr: float = Field(1e-4, ge=0)
@@ -501,13 +526,13 @@ class RolloutModelConfig(BaseConfig):
501526

502527
class OrchestratorConfig(BaseConfig):
503528
training_mode: Literal["rl", "opd", "sft"] = "rl"
504-
"""Training mode. ``rl``: student generates rollouts, no teacher. ``opd``: student generates rollouts, teacher computes logprobs (teacher_tau > 0). ``sft``: teacher generates rollouts, student inference pool used for evals and weight sync."""
529+
"""Training mode. ``rl``: student generates rollouts, no teacher. ``opd``: student generates rollouts, teacher computes logprobs. ``sft``: teacher generates rollouts when configured, otherwise train envs must provide replayed traces."""
505530

506531
student: RolloutModelConfig = Field(RolloutModelConfig(), validation_alias=AliasChoices("student", "model"))
507532
"""Student rollout participant (model + client) — the model being trained."""
508533

509534
teacher: RolloutModelConfig | None = Field(None, validation_alias=AliasChoices("teacher", "teacher_model"))
510-
"""Teacher rollout participant (model + client). Role depends on ``training_mode``: ``opd`` — teacher computes logprobs; ``sft`` — teacher generates rollouts."""
535+
"""Teacher rollout participant (model + client). Required for ``opd``. Optional for ``sft`` when train envs provide replayed traces."""
511536

512537
train: TrainConfig = TrainConfig()
513538

@@ -752,10 +777,16 @@ def validate_unique_filter_types(self):
752777
)
753778
return self
754779

780+
@model_validator(mode="after")
781+
def _drop_default_sft_zero_advantage_filter(self):
782+
if self.training_mode == "sft" and "post_batch_filters" not in self.model_fields_set:
783+
self.post_batch_filters = [f for f in self.post_batch_filters if f.type != "zero_advantage"]
784+
return self
785+
755786
@model_validator(mode="after")
756787
def _force_no_renderer_for_sft(self):
757-
"""SFT rolls out via the teacher's plain chat-completions endpoint; the
758-
renderer client doesn't apply. Force ``renderer=None`` so the user
788+
"""SFT train rollouts use teacher chat completions or replayed traces;
789+
the renderer client doesn't apply. Force ``renderer=None`` so the user
759790
doesn't have to remember to set it. Declared before the renderer
760791
validators below so they see the corrected value."""
761792
if self.training_mode == "sft":
@@ -768,8 +799,30 @@ def validate_training_mode(self):
768799
has_teacher = self.teacher is not None
769800
if self.training_mode == "rl" and has_teacher:
770801
raise ValueError("orchestrator.teacher must not be set when training_mode = 'rl'.")
771-
if self.training_mode in ("opd", "sft") and not has_teacher:
772-
raise ValueError(f"orchestrator.teacher must be configured when training_mode = '{self.training_mode}'.")
802+
if self.training_mode == "opd" and not has_teacher:
803+
raise ValueError("orchestrator.teacher must be configured when training_mode = 'opd'.")
804+
return self
805+
806+
@model_validator(mode="after")
807+
def validate_teacherless_sft_uses_sft_replay(self):
808+
"""Teacherless SFT is only valid when train envs replay existing data."""
809+
if self.training_mode != "sft" or self.teacher is not None:
810+
return self
811+
812+
non_replay_envs = [env.id for env in self.train.env if not _is_sft_replay_env_id(env.id)]
813+
if non_replay_envs:
814+
raise ValueError(
815+
"orchestrator.teacher must be configured for SFT unless every train env uses "
816+
f"{SFT_REPLAY_ENV_ID!r}; got non-replay train env(s): {non_replay_envs}."
817+
)
818+
819+
missing_dataset = [env.resolved_name for env in self.train.env if _sft_replay_dataset_arg(env.args) is None]
820+
if missing_dataset:
821+
raise ValueError(
822+
f"teacherless SFT with {SFT_REPLAY_ENV_ID!r} requires an explicit "
823+
"env.args.taskset.dataset or env.args.config.taskset.dataset for "
824+
f"each train env; missing for: {missing_dataset}."
825+
)
773826
return self
774827

775828
@model_validator(mode="after")

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ envs = [
8888
"rlm-swe",
8989
"science-env",
9090
"simpleqa-verified",
91+
"sft-replay",
9192
"tau2-bench",
9293
"wiki-search",
9394
"wordle",
@@ -150,6 +151,8 @@ override-dependencies = [
150151
"transformers==5.6.2",
151152
"torch>=2.9.0",
152153
"openenv-core",
154+
"harnesses>=0.1.0",
155+
"tasksets>=0.1.0",
153156
]
154157

155158
# ModelExpress 0.3.0 publishes protobuf<6 metadata, but its generated proto is
@@ -224,6 +227,9 @@ reverse-text = { path = "deps/verifiers/environments/reverse_text", editable = t
224227
rlm-swe = { path = "deps/research-environments/environments/rlm_swe", editable = true }
225228
science-env = { path = "deps/research-environments/environments/science_env", editable = true }
226229
simpleqa-verified = { path = "deps/research-environments/environments/simpleqa_verified", editable = true }
230+
sft-replay = { path = "deps/verifiers/environments/sft_replay", editable = true }
231+
harnesses = { path = "deps/verifiers/packages/harnesses", editable = true }
232+
tasksets = { path = "deps/verifiers/packages/tasksets", editable = true }
227233
tau2-bench = { path = "deps/research-environments/environments/tau2_bench", editable = true }
228234
wiki-search = { path = "deps/verifiers/environments/wiki_search", editable = true }
229235
wordle = { path = "deps/verifiers/environments/wordle", editable = true }

src/prime_rl/orchestrator/dispatcher.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,9 @@ def __init__(
135135
self.policy = policy
136136
self.train_envs = train_envs
137137
self.eval_envs = eval_envs
138-
# Train rollouts go to ``inference`` (the teacher in SFT mode);
139-
# eval always evaluates the student, so it uses ``eval_inference``.
138+
# Train rollouts go to ``inference`` (teacher in teacher-SFT, student
139+
# otherwise); eval always evaluates the student, so it uses
140+
# ``eval_inference``.
140141
self.inference = inference
141142
self.eval_inference = eval_inference
142143
self.train_source = train_source
@@ -173,9 +174,9 @@ def __init__(
173174

174175
@property
175176
def train_model_name(self) -> str:
176-
"""Model name for *train* rollouts. In SFT mode train data comes from
177-
the teacher pool, so use its model name; otherwise the live student
178-
policy. (Eval always uses ``policy.model_name`` — the student.)"""
177+
"""Model name for *train* rollouts. Teacher-SFT uses the teacher pool
178+
name; replay SFT receives the student name but ignores it. Eval always
179+
uses ``policy.model_name`` — the student."""
179180
if self.training_mode == "sft":
180181
return self.inference.model_name
181182
return self.policy.model_name

src/prime_rl/orchestrator/orchestrator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -345,10 +345,10 @@ async def setup(self) -> None:
345345
else:
346346
get_logger().info("Training from scratch")
347347

348-
# SFT generates rollouts via the teacher (the student is trained on
349-
# the teacher's outputs); RL / OPD generate via the student
350-
if config.training_mode == "sft":
351-
assert self.teacher_inference is not None, "sft mode requires teacher inference"
348+
# SFT train rollouts come from the teacher when configured. Teacherless
349+
# SFT is validated at config parse time to use replay envs, which ignore
350+
# the client/model passed by the dispatcher.
351+
if config.training_mode == "sft" and self.teacher_inference is not None:
352352
rollout_inference = self.teacher_inference
353353
else:
354354
rollout_inference = self.student_inference

src/prime_rl/orchestrator/trajectories.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,24 @@ def _tokenize_step_with_renderer(
164164
return build_trajectory_step(renderer, prompt, completion, tools=tools)
165165

166166

167+
def _set_token_usage_from_trajectory(output: vf.RolloutOutput) -> None:
168+
trajectory = output.get("trajectory") or []
169+
tokenized_steps = [step for step in trajectory if step.get("tokens") is not None]
170+
if not tokenized_steps:
171+
return
172+
173+
prompt_tokens = [len(step["tokens"]["prompt_ids"]) for step in tokenized_steps]
174+
completion_tokens = [len(step["tokens"]["completion_ids"]) for step in tokenized_steps]
175+
total_completion = sum(completion_tokens)
176+
last_total = prompt_tokens[-1] + completion_tokens[-1]
177+
output["token_usage"] = {
178+
"input_tokens": float(sum(prompt_tokens)),
179+
"output_tokens": float(total_completion),
180+
"final_input_tokens": float(max(0, last_total - total_completion)),
181+
"final_output_tokens": float(total_completion),
182+
}
183+
184+
167185
def backfill_rollout_tokens(
168186
output: vf.RolloutOutput,
169187
tokenizer: PreTrainedTokenizer,
@@ -175,6 +193,9 @@ def backfill_rollout_tokens(
175193
Otherwise falls back to the tokenizer + apply_chat_template path.
176194
"""
177195
if all(step["tokens"] is not None for step in output["trajectory"]):
196+
token_usage = output.get("token_usage") or {}
197+
if "final_input_tokens" not in token_usage or "final_output_tokens" not in token_usage:
198+
_set_token_usage_from_trajectory(output)
178199
return True
179200

180201
logger = get_logger()
@@ -198,6 +219,10 @@ def backfill_rollout_tokens(
198219
reconstructed.pop("original_prompt_len")
199220
step["tokens"] = reconstructed
200221

222+
token_usage = output.get("token_usage") or {}
223+
if "final_input_tokens" not in token_usage or "final_output_tokens" not in token_usage:
224+
_set_token_usage_from_trajectory(output)
225+
201226
return True
202227

203228

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import pytest
2+
import verifiers as vf
3+
from verifiers.clients import Client
4+
5+
from prime_rl.orchestrator.trajectories import backfill_rollout_tokens, interleave_rollout
6+
7+
8+
class NoopClient(Client):
9+
def setup_client(self, config):
10+
return object()
11+
12+
async def to_native_tool(self, tool):
13+
raise AssertionError("sft-replay must not convert tools")
14+
15+
async def to_native_prompt(self, messages):
16+
raise AssertionError("sft-replay must not render prompts through a client")
17+
18+
async def get_native_response(self, prompt, model, sampling_args, tools=None, **kwargs):
19+
raise AssertionError("sft-replay must not request model responses")
20+
21+
async def raise_from_native_response(self, response) -> None:
22+
raise AssertionError("sft-replay must not handle native responses")
23+
24+
async def from_native_response(self, response):
25+
raise AssertionError("sft-replay must not parse native responses")
26+
27+
async def close(self) -> None:
28+
return None
29+
30+
31+
class SimpleChatTokenizer:
32+
def __init__(self):
33+
self._tok2id: dict[str, int] = {}
34+
self._next_id = 1
35+
36+
def _id(self, token: str) -> int:
37+
if token not in self._tok2id:
38+
self._tok2id[token] = self._next_id
39+
self._next_id += 1
40+
return self._tok2id[token]
41+
42+
def apply_chat_template(self, messages, add_generation_prompt=False, return_dict=False, tools=None):
43+
del return_dict, tools
44+
ids = []
45+
for message in messages:
46+
role = message.get("role", "unknown")
47+
ids.append(self._id(f"<|{role}|>"))
48+
content = message.get("content", "")
49+
if isinstance(content, str):
50+
if content:
51+
ids.append(self._id(content))
52+
else:
53+
ids.append(self._id(str(content)))
54+
if add_generation_prompt:
55+
ids.append(self._id("<|assistant|>"))
56+
return ids
57+
58+
59+
def role_content(messages) -> list[tuple[str, object]]:
60+
return [(message["role"], message["content"]) for message in messages]
61+
62+
63+
@pytest.mark.asyncio
64+
async def test_sft_replay_env_replays_messages_for_prime_rl_training_path():
65+
env = vf.load_environment("sft-replay", taskset={})
66+
row = dict(env.get_dataset()[0])
67+
68+
output = await env.run_rollout(
69+
row,
70+
client=NoopClient(vf.ClientConfig()),
71+
model="unused-student",
72+
sampling_args={},
73+
state_columns=["trajectory", "sampling_args"],
74+
)
75+
76+
assert output["error"] is None
77+
assert output["stop_condition"] == "replayed_messages"
78+
assert len(output["trajectory"]) == 1
79+
assert output["trajectory"][0]["tokens"] is None
80+
assert role_content(output["trajectory"][0]["prompt"]) == [("user", "Reverse abc.")]
81+
assert role_content(output["trajectory"][0]["completion"]) == [("assistant", "cba")]
82+
83+
backfill_rollout_tokens(output, SimpleChatTokenizer())
84+
samples = interleave_rollout(output, env_name="sft-replay")
85+
86+
assert samples is not None
87+
assert len(samples) == 1
88+
assert any(samples[0].completion_mask)
89+
assert output["token_usage"]["final_input_tokens"] > 0
90+
assert output["token_usage"]["final_output_tokens"] > 0

0 commit comments

Comments
 (0)