Skip to content

Commit 31b5fd1

Browse files
committed
fix: keep OPD teacher scoring on native vLLM route
1 parent 8d100cf commit 31b5fd1

3 files changed

Lines changed: 11 additions & 64 deletions

File tree

skills/configs/SKILL.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ In `opd`, rollouts are generated by the student. The orchestrator scores the stu
7070

7171
`[inference]` is required for the usual online path because it starts the student inference server and auto-configures `orchestrator.student.client.base_url`. The student pool is used for online evals and policy weight sync. For externally started student inference, set `orchestrator.student.client.base_url` explicitly instead.
7272

73-
Teacher logprob scoring supports both self-hosted vLLM and Prime API teacher clients: `/inference/v1/generate` for vLLM server roots, `/api/v1/generate` when the teacher client base URL ends in `/api/v1`.
73+
Teacher logprob scoring uses PrimeRL's vLLM-native `/inference/v1/generate` route. The request field `token_ids` carries the prompt plus completion tokens to score; the response's `choices[].token_ids` still holds the generated completion tokens and is not used for OPD scoring.
7474

7575
## RL trainer token exports
7676

src/prime_rl/orchestrator/utils.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -107,20 +107,15 @@ async def compute_teacher_logprobs(
107107
"""Compute teacher model logprobs for a batch of training samples via prefill."""
108108
import httpx
109109

110-
def _teacher_generate_request(base_url: str, model_name: str, token_ids: list[int]) -> tuple[str, dict[str, Any]]:
110+
def _teacher_generate_request(
111+
base_url: str,
112+
model_name: str,
113+
scored_token_ids: list[int],
114+
) -> tuple[str, dict[str, Any]]:
111115
base = base_url.rstrip("/")
112-
if base.endswith("/api/v1"):
113-
return f"{base}/generate", {
114-
"model": model_name,
115-
"prompt_token_ids": token_ids,
116-
"max_tokens": 1,
117-
"temperature": 1.0,
118-
"top_p": 1.0,
119-
"prompt_logprobs": 1,
120-
}
121116
return f"{base.removesuffix('/v1')}/inference/v1/generate", {
122117
"model": model_name,
123-
"token_ids": token_ids,
118+
"token_ids": scored_token_ids,
124119
"sampling_params": {
125120
"max_tokens": 1,
126121
"temperature": 1.0,
@@ -157,27 +152,26 @@ def _flatten_prompt_logprobs(response: dict[str, Any], token_ids: list[int]) ->
157152

158153
async def _compute_single(client_config: vf.ClientConfig, sample: TrainingSample) -> list[float]:
159154
client = setup_openai_client(client_config)
160-
token_ids = list(sample.prompt_ids) + list(sample.completion_ids)
155+
scored_token_ids = list(sample.prompt_ids) + list(sample.completion_ids)
161156

162157
# Two escape hatches from ``AsyncOpenAI.post``:
163158
# 1. URL — vLLM mounts ``/inference/v1/generate`` at server root,
164-
# while Prime Inference exposes ``/api/v1/generate``. Pass an
165-
# absolute URL so the SDK's ``_prepare_url`` skips base-url merge.
159+
# so pass an absolute URL; the SDK's ``_prepare_url`` then skips the base-url merge.
166160
# 2. Parse — vLLM's ``GenerateResponse`` is a plain
167161
# ``pydantic.BaseModel`` and the SDK's parse layer rejects any
168162
# ``cast_to`` that doesn't subclass ``openai.BaseModel``. Use
169163
# ``cast_to=httpx.Response`` so the SDK still builds the request
170164
# (preserving ``auth_headers``, retries, timeouts, idempotency
171165
# keys) and just hands us the raw response to validate ourselves.
172-
url, body = _teacher_generate_request(str(client.base_url), model_name, token_ids)
166+
url, body = _teacher_generate_request(str(client.base_url), model_name, scored_token_ids)
173167
http_response = await client.post(
174168
url,
175169
cast_to=httpx.Response,
176170
body=body,
177171
)
178172
http_response.raise_for_status()
179173
response = http_response.json()
180-
return _flatten_prompt_logprobs(response, token_ids)
174+
return _flatten_prompt_logprobs(response, scored_token_ids)
181175

182176
return await asyncio.gather(*[_compute_single(client, sample) for client, sample in zip(cycle(clients), samples)])
183177

tests/unit/orchestrator/test_teacher_logprobs.py

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -85,53 +85,6 @@ async def _run():
8585
asyncio.run(_run())
8686

8787

88-
def test_compute_teacher_logprobs_uses_prime_generate_for_api_base_url(monkeypatch):
89-
async def _run():
90-
fake_client = _FakeOpenAIClient(
91-
{
92-
"request_id": "gen-test",
93-
"choices": [],
94-
"prompt_logprobs": [
95-
None,
96-
{"13": {"logprob": -0.1}, "2": {"logprob": -0.7}},
97-
{"198": {"logprob": -0.2}, "3": {"logprob": -0.3}},
98-
],
99-
"kv_transfer_params": None,
100-
},
101-
base_url="https://api.primeintellect.ai/api/v1",
102-
)
103-
monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client)
104-
105-
sample = SimpleNamespace(
106-
prompt_ids=[1],
107-
prompt_mask=[True],
108-
completion_ids=[2, 3],
109-
completion_mask=[True, True],
110-
completion_logprobs=[-0.1, -0.2],
111-
completion_temperatures=[1.0, 1.0],
112-
env_name="test-env",
113-
)
114-
115-
result = await orchestrator_utils.compute_teacher_logprobs(
116-
clients=[vf.ClientConfig()],
117-
model_name="teacher-model",
118-
samples=[sample],
119-
)
120-
121-
assert result == [[0.0, -0.7, -0.3]]
122-
assert fake_client.calls[0]["url"] == "https://api.primeintellect.ai/api/v1/generate"
123-
assert fake_client.calls[0]["body"] == {
124-
"model": "teacher-model",
125-
"prompt_token_ids": [1, 2, 3],
126-
"max_tokens": 1,
127-
"temperature": 1.0,
128-
"top_p": 1.0,
129-
"prompt_logprobs": 1,
130-
}
131-
132-
asyncio.run(_run())
133-
134-
13588
def test_compute_teacher_logprobs_rejects_wrong_length(monkeypatch):
13689
async def _run():
13790
fake_client = _FakeOpenAIClient(

0 commit comments

Comments
 (0)