Skip to content

Commit 7a55a28

Browse files
committed
fix: support hosted OPD teacher logprobs
1 parent 845eb6e commit 7a55a28

3 files changed

Lines changed: 270 additions & 41 deletions

File tree

skills/config/SKILL.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,18 @@ uv run sft --data.type fake --data.batch-size 4
161161

162162
If you wish to configure values of the default variant, you don't need to set the `type` field.
163163

164-
### SFT hard distill override
164+
### Distillation training modes
165165

166-
Set `orchestrator.training_mode = "sft"` (or top-level `training_mode = "sft"`, which auto-propagates) and configure `orchestrator.teacher` with the teacher endpoint. The orchestrator stamps each `TrainingSample.sft_loss = True` and the shared `training_mode` validator sets `trainer.loss.type = "sft"`, which the trainer's `compute_loss` honors by dispatching to `sft_loss_fn` per batch.
166+
Set `orchestrator.training_mode = "sft"` or `orchestrator.training_mode = "opd"` (or top-level `training_mode`, which auto-propagates) and configure `orchestrator.teacher` with the teacher endpoint.
167+
168+
In `sft`, rollouts are generated by the teacher. The shared `training_mode` validator sets `trainer.loss.type = "sft"`, and trainer loss dispatch uses `sft_loss_fn`.
169+
170+
In `opd`, rollouts are generated by the student. The orchestrator scores the student token sequence with the teacher, writes those values to `TrainingSample.teacher_logprobs`, and trainer loss dispatch uses `opd_loss_fn`.
167171

168172
`[inference]` is required in both modes (same as rl) — it starts the student inference server and auto-configures `orchestrator.student.client.base_url`. The student pool is used for online evals and policy weight sync, and in `opd` it also generates the rollouts. For externally started student inference, set `orchestrator.student.client.base_url` explicitly instead.
169173

174+
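Teacher logprob scoring supports both self-hosted vLLM and Prime API teacher clients. When the teacher client base URL ends in `/api/v1`, the request is sent to `<base_url>/generate` (i.e. `/api/v1/generate`); otherwise the base URL is treated as a vLLM server, any trailing `/v1` is stripped, and the request is sent to `/inference/v1/generate` at the server root.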
175+
170176
### RL rollout client defaults
171177

172178
For text-only RL rollouts, the orchestrator defaults to renderer-backed TITO (`use_renderer = true`). VLM configs must explicitly fall back to MITO (`use_renderer = false`) so image preprocessing and chat templating stay server-side. External teacher rollouts must also set `use_renderer = false`.

src/prime_rl/orchestrator/utils.py

Lines changed: 64 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1+
from __future__ import annotations
2+
13
import asyncio
24
import time
35
from concurrent.futures import ThreadPoolExecutor
46
from itertools import cycle
57
from pathlib import Path
6-
from typing import Any
8+
from typing import TYPE_CHECKING, Any
79

810
import pandas as pd
911
import verifiers as vf
1012
from rich.console import Console
1113
from rich.table import Table
1214
from verifiers.utils.client_utils import setup_openai_client
1315

14-
from prime_rl.transport import TrainingSample
1516
from prime_rl.utils.logger import get_logger
1617
from prime_rl.utils.utils import (
1718
format_time,
@@ -20,6 +21,9 @@
2021
get_step_path,
2122
)
2223

24+
if TYPE_CHECKING:
25+
from prime_rl.transport.types import TrainingSample
26+
2327

2428
def set_default_executor(max_workers: int = 64) -> None:
2529
"""Scale the default asyncio thread pool so asyncio.to_thread has enough capacity."""
@@ -84,51 +88,78 @@ async def compute_teacher_logprobs(
8488
) -> list[list[float]]:
8589
"""Compute teacher model logprobs for a batch of training samples via prefill."""
8690
import httpx
87-
from vllm.entrypoints.serve.disagg.protocol import GenerateResponse
91+
92+
def _teacher_generate_request(base_url: str, model_name: str, token_ids: list[int]) -> tuple[str, dict[str, Any]]:
93+
base = base_url.rstrip("/")
94+
if base.endswith("/api/v1"):
95+
return f"{base}/generate", {
96+
"model": model_name,
97+
"prompt_token_ids": token_ids,
98+
"max_tokens": 1,
99+
"temperature": 1.0,
100+
"top_p": 1.0,
101+
"prompt_logprobs": 1,
102+
}
103+
return f"{base.removesuffix('/v1')}/inference/v1/generate", {
104+
"model": model_name,
105+
"token_ids": token_ids,
106+
"sampling_params": {
107+
"max_tokens": 1,
108+
"temperature": 1.0,
109+
"top_p": 1.0,
110+
"prompt_logprobs": 1,
111+
},
112+
}
113+
114+
def _flatten_prompt_logprobs(response: dict[str, Any], token_ids: list[int]) -> list[float]:
115+
# ``prompt_logprobs[i]`` is a ``{token_id: Logprob}`` dict for tokens
116+
# the engine could score, or ``None`` for the leading token which has
117+
# no preceding context. vLLM can include both the target token and the
118+
# top-k alternatives; select the exact target token at each position.
119+
prompt_logprobs = response.get("prompt_logprobs") or []
120+
if len(prompt_logprobs) != len(token_ids):
121+
raise ValueError(
122+
f"teacher prompt_logprobs length != sample length ({len(prompt_logprobs)} != {len(token_ids)})"
123+
)
124+
flat: list[float] = []
125+
for i, (token_id, entry) in enumerate(zip(token_ids, prompt_logprobs)):
126+
if not entry:
127+
if i != 0:
128+
raise ValueError(f"teacher prompt_logprobs missing entry at position {i} for token id {token_id}")
129+
flat.append(0.0)
130+
continue
131+
target = entry.get(str(token_id))
132+
if target is None:
133+
target = entry.get(token_id)
134+
if target is None:
135+
raise ValueError(f"teacher prompt_logprobs missing token id {token_id}")
136+
lp = target.get("logprob")
137+
flat.append(float(lp) if lp is not None else 0.0)
138+
return flat
88139

89140
async def _compute_single(client_config: vf.ClientConfig, sample: TrainingSample) -> list[float]:
90141
client = setup_openai_client(client_config)
142+
token_ids = list(sample.prompt_ids) + list(sample.completion_ids)
91143

92144
# Two escape hatches from ``AsyncOpenAI.post``:
93-
# 1. URL — ``/inference/v1/generate`` is mounted at server root, not
94-
# under ``/v1``. Pass an absolute URL so the SDK's
95-
# ``_prepare_url`` skips the base-url merge (it short-circuits
96-
# when the path passes ``httpx.URL.is_relative_url`` as False).
145+
# 1. URL — vLLM mounts ``/inference/v1/generate`` at server root,
146+
# while Prime Inference exposes ``/api/v1/generate``. Pass an
147+
# absolute URL so the SDK's ``_prepare_url`` skips base-url merge.
97148
# 2. Parse — the response is decoded as plain JSON rather than an
98149
# ``openai.BaseModel``, and the SDK's parse layer rejects any
99150
# ``cast_to`` that doesn't subclass ``openai.BaseModel``. Use
100151
# ``cast_to=httpx.Response`` so the SDK still builds the request
101152
# (preserving ``auth_headers``, retries, timeouts, idempotency
102153
# keys) and just hands us the raw response to validate ourselves.
103-
base = str(client.base_url).rstrip("/").removesuffix("/v1")
154+
url, body = _teacher_generate_request(str(client.base_url), model_name, token_ids)
104155
http_response = await client.post(
105-
f"{base}/inference/v1/generate",
156+
url,
106157
cast_to=httpx.Response,
107-
body={
108-
"model": model_name,
109-
"token_ids": list(sample.prompt_ids) + list(sample.completion_ids),
110-
"sampling_params": {
111-
"max_tokens": 1,
112-
"temperature": 1.0,
113-
"top_p": 1.0,
114-
"prompt_logprobs": 1,
115-
},
116-
},
158+
body=body,
117159
)
118-
response = GenerateResponse.model_validate_json(http_response.content)
119-
# ``prompt_logprobs[i]`` is a ``{token_id: Logprob}`` dict for tokens
120-
# the engine could score, or ``None`` for the leading token which has
121-
# no preceding context. Flatten to ``list[float]`` with 0.0 in the
122-
# unscored slot.
123-
flat: list[float] = []
124-
for entry in response.prompt_logprobs or []:
125-
if not entry:
126-
flat.append(0.0)
127-
continue
128-
first = next(iter(entry.values()))
129-
lp = first.logprob if hasattr(first, "logprob") else first.get("logprob")
130-
flat.append(float(lp) if lp is not None else 0.0)
131-
return flat
160+
http_response.raise_for_status()
161+
response = http_response.json()
162+
return _flatten_prompt_logprobs(response, token_ids)
132163

133164
return await asyncio.gather(*[_compute_single(client, sample) for client, sample in zip(cycle(clients), samples)])
134165

0 commit comments

Comments
 (0)