Skip to content

Commit 8d100cf

Browse files
committed
fix: support hosted OPD teacher logprobs
1 parent da9a92a commit 8d100cf

3 files changed

Lines changed: 271 additions & 37 deletions

File tree

skills/configs/SKILL.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,18 @@ CLI: `--env.0.id reverse-text --env.1.id math-env`.
6060

6161
In TOML, an empty section header (`[ckpt]`) does the same.
6262

63+
## Distillation training modes
64+
65+
Set `orchestrator.training_mode = "sft"` or `orchestrator.training_mode = "opd"` (or top-level `training_mode`, which auto-propagates) and configure `orchestrator.teacher` with the teacher endpoint.
66+
67+
In `sft`, rollouts are generated by the teacher. The shared `training_mode` validator sets `trainer.loss.type = "sft"`, and trainer loss dispatch uses `sft_loss_fn`.
68+
69+
In `opd`, rollouts are generated by the student. The orchestrator scores the student token sequence with the teacher, writes those values to `TrainingSample.teacher_logprobs`, and trainer loss dispatch uses `opd_loss_fn`.
70+
71+
`[inference]` is required for the usual online path because it starts the student inference server and auto-configures `orchestrator.student.client.base_url`. The student pool is used for online evals and policy weight sync. For externally started student inference, set `orchestrator.student.client.base_url` explicitly instead.
72+
73+
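Teacher logprob scoring supports both self-hosted vLLM and Prime API teacher clients. For a vLLM teacher, the request goes to `/inference/v1/generate` at the server root (a trailing `/v1` on the teacher client base URL is stripped first). When the teacher client base URL ends in `/api/v1`, the request goes to `/api/v1/generate` instead.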
74+
6375
## RL trainer token exports
6476

6577
For rollout debugging, enable trainer-side token export under `trainer.experimental.token_export` (or `experimental.token_export` when running the trainer entrypoint directly). It writes one JSONL record per exported sequence under `output_dir/token_exports/step_<step>/rank_<rank>.jsonl`. Each record stores aligned per-token arrays for token ids, loss mask, advantage, reward, entropy, mismatch KL, inference/trainer logprobs, importance ratios, probability deltas, and masking diagnostics. It does not decode token text in the trainer.

src/prime_rl/orchestrator/utils.py

Lines changed: 61 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
from __future__ import annotations
2+
13
import asyncio
24
import logging
35
import time
46
from concurrent.futures import ThreadPoolExecutor
57
from itertools import cycle
68
from pathlib import Path
9+
from typing import Any
710

811
import orjson
912
import verifiers as vf
@@ -103,51 +106,78 @@ async def compute_teacher_logprobs(
103106
) -> list[list[float]]:
104107
"""Compute teacher model logprobs for a batch of training samples via prefill."""
105108
import httpx
106-
from vllm.entrypoints.serve.disagg.protocol import GenerateResponse
109+
110+
def _teacher_generate_request(base_url: str, model_name: str, token_ids: list[int]) -> tuple[str, dict[str, Any]]:
111+
base = base_url.rstrip("/")
112+
if base.endswith("/api/v1"):
113+
return f"{base}/generate", {
114+
"model": model_name,
115+
"prompt_token_ids": token_ids,
116+
"max_tokens": 1,
117+
"temperature": 1.0,
118+
"top_p": 1.0,
119+
"prompt_logprobs": 1,
120+
}
121+
return f"{base.removesuffix('/v1')}/inference/v1/generate", {
122+
"model": model_name,
123+
"token_ids": token_ids,
124+
"sampling_params": {
125+
"max_tokens": 1,
126+
"temperature": 1.0,
127+
"top_p": 1.0,
128+
"prompt_logprobs": 1,
129+
},
130+
}
131+
132+
def _flatten_prompt_logprobs(response: dict[str, Any], token_ids: list[int]) -> list[float]:
133+
# ``prompt_logprobs[i]`` is a ``{token_id: Logprob}`` dict for tokens
134+
# the engine could score, or ``None`` for the leading token which has
135+
# no preceding context. vLLM can include both the target token and the
136+
# top-k alternatives; select the exact target token at each position.
137+
prompt_logprobs = response.get("prompt_logprobs") or []
138+
if len(prompt_logprobs) != len(token_ids):
139+
raise ValueError(
140+
f"teacher prompt_logprobs length != sample length ({len(prompt_logprobs)} != {len(token_ids)})"
141+
)
142+
flat: list[float] = []
143+
for i, (token_id, entry) in enumerate(zip(token_ids, prompt_logprobs)):
144+
if not entry:
145+
if i != 0:
146+
raise ValueError(f"teacher prompt_logprobs missing entry at position {i} for token id {token_id}")
147+
flat.append(0.0)
148+
continue
149+
target = entry.get(str(token_id))
150+
if target is None:
151+
target = entry.get(token_id)
152+
if target is None:
153+
raise ValueError(f"teacher prompt_logprobs missing token id {token_id}")
154+
lp = target.get("logprob")
155+
flat.append(float(lp) if lp is not None else 0.0)
156+
return flat
107157

108158
async def _compute_single(client_config: vf.ClientConfig, sample: TrainingSample) -> list[float]:
109159
client = setup_openai_client(client_config)
160+
token_ids = list(sample.prompt_ids) + list(sample.completion_ids)
110161

111162
# Two escape hatches from ``AsyncOpenAI.post``:
112-
# 1. URL — ``/inference/v1/generate`` is mounted at server root, not
113-
# under ``/v1``. Pass an absolute URL so the SDK's
114-
# ``_prepare_url`` skips the base-url merge (it short-circuits
115-
# when the path passes ``httpx.URL.is_relative_url`` as False).
163+
# 1. URL — vLLM mounts ``/inference/v1/generate`` at server root,
164+
# while Prime Inference exposes ``/api/v1/generate``. Pass an
165+
# absolute URL so the SDK's ``_prepare_url`` skips base-url merge.
116166
# 2. Parse — vLLM's ``GenerateResponse`` is a plain
117167
# ``pydantic.BaseModel`` and the SDK's parse layer rejects any
118168
# ``cast_to`` that doesn't subclass ``openai.BaseModel``. Use
119169
# ``cast_to=httpx.Response`` so the SDK still builds the request
120170
# (preserving ``auth_headers``, retries, timeouts, idempotency
121171
# keys) and just hands us the raw response to validate ourselves.
122-
base = str(client.base_url).rstrip("/").removesuffix("/v1")
172+
url, body = _teacher_generate_request(str(client.base_url), model_name, token_ids)
123173
http_response = await client.post(
124-
f"{base}/inference/v1/generate",
174+
url,
125175
cast_to=httpx.Response,
126-
body={
127-
"model": model_name,
128-
"token_ids": list(sample.prompt_ids) + list(sample.completion_ids),
129-
"sampling_params": {
130-
"max_tokens": 1,
131-
"temperature": 1.0,
132-
"top_p": 1.0,
133-
"prompt_logprobs": 1,
134-
},
135-
},
176+
body=body,
136177
)
137-
response = GenerateResponse.model_validate_json(http_response.content)
138-
# ``prompt_logprobs[i]`` is a ``{token_id: Logprob}`` dict for tokens
139-
# the engine could score, or ``None`` for the leading token which has
140-
# no preceding context. Flatten to ``list[float]`` with 0.0 in the
141-
# unscored slot.
142-
flat: list[float] = []
143-
for entry in response.prompt_logprobs or []:
144-
if not entry:
145-
flat.append(0.0)
146-
continue
147-
first = next(iter(entry.values()))
148-
lp = first.logprob if hasattr(first, "logprob") else first.get("logprob")
149-
flat.append(float(lp) if lp is not None else 0.0)
150-
return flat
178+
http_response.raise_for_status()
179+
response = http_response.json()
180+
return _flatten_prompt_logprobs(response, token_ids)
151181

152182
return await asyncio.gather(*[_compute_single(client, sample) for client, sample in zip(cycle(clients), samples)])
153183

tests/unit/orchestrator/test_teacher_logprobs.py

Lines changed: 198 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import asyncio
22
import json
3+
from types import SimpleNamespace
34

45
import httpx
56
import verifiers as vf
67

78
from prime_rl.orchestrator import utils as orchestrator_utils
8-
from prime_rl.transport import TrainingSample
99

1010

1111
class _FakeOpenAIClient:
@@ -14,17 +14,18 @@ class _FakeOpenAIClient:
1414
handed back verbatim, mirroring the real SDK's short-circuit at
1515
``AsyncAPIClient._process_response``."""
1616

17-
def __init__(self, payload: dict):
17+
def __init__(self, payload: dict, base_url: str = "http://fake-host:8000/v1", status_code: int = 200):
1818
# Match what AsyncOpenAI exposes — utils.py reads ``str(client.base_url)``.
19-
self.base_url = "http://fake-host:8000/v1"
19+
self.base_url = base_url
2020
self._payload = payload
21+
self._status_code = status_code
2122
self.calls: list[dict] = []
2223

2324
async def post(self, url, *, cast_to, body):
2425
self.calls.append({"url": url, "cast_to": cast_to, "body": body})
2526
request = httpx.Request("POST", url, json=body)
2627
return httpx.Response(
27-
status_code=200,
28+
status_code=self._status_code,
2829
content=json.dumps(self._payload).encode(),
2930
request=request,
3031
)
@@ -37,13 +38,17 @@ async def _run():
3738
"request_id": "gen-test",
3839
"choices": [],
3940
# Upstream wire shape: list[dict[token_id, Logprob] | None]
40-
"prompt_logprobs": [None, {"11": {"logprob": -0.7}}, {"12": {"logprob": -0.3}}],
41+
"prompt_logprobs": [
42+
None,
43+
{"13": {"logprob": -0.1}, "2": {"logprob": -0.7}},
44+
{"198": {"logprob": -0.2}, "3": {"logprob": -0.3}},
45+
],
4146
"kv_transfer_params": None,
4247
}
4348
)
4449
monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client)
4550

46-
sample = TrainingSample(
51+
sample = SimpleNamespace(
4752
prompt_ids=[1],
4853
prompt_mask=[True],
4954
completion_ids=[2, 3],
@@ -78,3 +83,190 @@ async def _run():
7883
]
7984

8085
asyncio.run(_run())
86+
87+
88+
def test_compute_teacher_logprobs_uses_prime_generate_for_api_base_url(monkeypatch):
    """A client base URL ending in ``/api/v1`` is scored through ``<base>/generate`` with a flat body."""
    teacher_payload = {
        "request_id": "gen-test",
        "choices": [],
        "prompt_logprobs": [
            None,
            {"13": {"logprob": -0.1}, "2": {"logprob": -0.7}},
            {"198": {"logprob": -0.2}, "3": {"logprob": -0.3}},
        ],
        "kv_transfer_params": None,
    }
    expected_body = {
        "model": "teacher-model",
        "prompt_token_ids": [1, 2, 3],
        "max_tokens": 1,
        "temperature": 1.0,
        "top_p": 1.0,
        "prompt_logprobs": 1,
    }

    async def _scenario():
        stub_client = _FakeOpenAIClient(
            teacher_payload,
            base_url="https://api.primeintellect.ai/api/v1",
        )
        monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: stub_client)

        rollout = SimpleNamespace(
            prompt_ids=[1],
            prompt_mask=[True],
            completion_ids=[2, 3],
            completion_mask=[True, True],
            completion_logprobs=[-0.1, -0.2],
            completion_temperatures=[1.0, 1.0],
            env_name="test-env",
        )

        scored = await orchestrator_utils.compute_teacher_logprobs(
            clients=[vf.ClientConfig()],
            model_name="teacher-model",
            samples=[rollout],
        )

        # The target ids (2 and 3) must be selected over the alternatives (13 and 198).
        assert scored == [[0.0, -0.7, -0.3]]
        first_call = stub_client.calls[0]
        assert first_call["url"] == "https://api.primeintellect.ai/api/v1/generate"
        assert first_call["body"] == expected_body

    asyncio.run(_scenario())
133+
134+
135+
def test_compute_teacher_logprobs_rejects_wrong_length(monkeypatch):
    """Two ``prompt_logprobs`` entries for a three-token sample must raise ``ValueError``."""
    teacher_payload = {
        "request_id": "gen-test",
        "choices": [],
        "prompt_logprobs": [None, {"2": {"logprob": -0.7}}],
        "kv_transfer_params": None,
    }

    async def _scenario():
        stub_client = _FakeOpenAIClient(teacher_payload)
        monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: stub_client)

        rollout = SimpleNamespace(
            prompt_ids=[1],
            prompt_mask=[True],
            completion_ids=[2, 3],
            completion_mask=[True, True],
            completion_logprobs=[-0.1, -0.2],
            completion_temperatures=[1.0, 1.0],
            env_name="test-env",
        )

        caught = None
        try:
            await orchestrator_utils.compute_teacher_logprobs(
                clients=[vf.ClientConfig()],
                model_name="teacher-model",
                samples=[rollout],
            )
        except ValueError as exc:
            caught = exc

        if caught is None:
            raise AssertionError("Expected ValueError")
        assert "teacher prompt_logprobs length != sample length" in str(caught)

    asyncio.run(_scenario())
169+
170+
171+
def test_compute_teacher_logprobs_rejects_missing_token_id(monkeypatch):
    """An entry that lacks the sample's own token id (2 at position 1) must raise ``ValueError``."""
    teacher_payload = {
        "request_id": "gen-test",
        "choices": [],
        "prompt_logprobs": [None, {"13": {"logprob": -0.1}}, {"3": {"logprob": -0.3}}],
        "kv_transfer_params": None,
    }

    async def _scenario():
        stub_client = _FakeOpenAIClient(teacher_payload)
        monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: stub_client)

        rollout = SimpleNamespace(
            prompt_ids=[1],
            prompt_mask=[True],
            completion_ids=[2, 3],
            completion_mask=[True, True],
            completion_logprobs=[-0.1, -0.2],
            completion_temperatures=[1.0, 1.0],
            env_name="test-env",
        )

        caught = None
        try:
            await orchestrator_utils.compute_teacher_logprobs(
                clients=[vf.ClientConfig()],
                model_name="teacher-model",
                samples=[rollout],
            )
        except ValueError as exc:
            caught = exc

        if caught is None:
            raise AssertionError("Expected ValueError")
        assert "teacher prompt_logprobs missing token id 2" in str(caught)

    asyncio.run(_scenario())
205+
206+
207+
def test_compute_teacher_logprobs_rejects_missing_non_leading_entry(monkeypatch):
    """A ``None`` entry anywhere other than position 0 must raise ``ValueError``."""
    teacher_payload = {
        "request_id": "gen-test",
        "choices": [],
        "prompt_logprobs": [None, None, {"3": {"logprob": -0.3}}],
        "kv_transfer_params": None,
    }

    async def _scenario():
        stub_client = _FakeOpenAIClient(teacher_payload)
        monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: stub_client)

        rollout = SimpleNamespace(
            prompt_ids=[1],
            prompt_mask=[True],
            completion_ids=[2, 3],
            completion_mask=[True, True],
            completion_logprobs=[-0.1, -0.2],
            completion_temperatures=[1.0, 1.0],
            env_name="test-env",
        )

        caught = None
        try:
            await orchestrator_utils.compute_teacher_logprobs(
                clients=[vf.ClientConfig()],
                model_name="teacher-model",
                samples=[rollout],
            )
        except ValueError as exc:
            caught = exc

        if caught is None:
            raise AssertionError("Expected ValueError")
        assert "teacher prompt_logprobs missing entry at position 1 for token id 2" in str(caught)

    asyncio.run(_scenario())
241+
242+
243+
def test_compute_teacher_logprobs_raises_for_teacher_http_error(monkeypatch):
    """A 401 response from the teacher must raise ``httpx.HTTPStatusError`` carrying that status."""
    error_payload = {"error": {"message": "invalid teacher api key"}}

    async def _scenario():
        stub_client = _FakeOpenAIClient(error_payload, status_code=401)
        monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: stub_client)

        rollout = SimpleNamespace(
            prompt_ids=[1],
            prompt_mask=[True],
            completion_ids=[2, 3],
            completion_mask=[True, True],
            completion_logprobs=[-0.1, -0.2],
            completion_temperatures=[1.0, 1.0],
            env_name="test-env",
        )

        caught = None
        try:
            await orchestrator_utils.compute_teacher_logprobs(
                clients=[vf.ClientConfig()],
                model_name="teacher-model",
                samples=[rollout],
            )
        except httpx.HTTPStatusError as exc:
            caught = exc

        if caught is None:
            raise AssertionError("Expected HTTPStatusError")
        assert caught.response.status_code == 401

    asyncio.run(_scenario())

0 commit comments

Comments
 (0)