fix: harden OPD teacher logprob scoring #2683
Pull request overview
This PR updates OPD teacher logprob scoring so it works with both hosted Prime API teachers and self-hosted vLLM teachers, and tightens logprob parsing to align per-token logprobs to the exact target token id with explicit validation on response mismatches.
Changes:
- Route teacher scoring to `/api/v1/generate` when the teacher client base URL ends with `/api/v1`, otherwise use vLLM’s `/inference/v1/generate`.
- Parse `prompt_logprobs` from raw JSON and select the logprob corresponding to the actual token id at each position, failing on length/token-id mismatches.
- Add unit tests for endpoint routing and mismatch handling, and document SFT vs OPD teacher scoring behavior in the config skill doc.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `tests/unit/orchestrator/test_teacher_logprobs.py` | Extends unit coverage for hosted Prime API teacher routing and stricter logprob response validation. |
| `src/prime_rl/orchestrator/utils.py` | Implements endpoint selection and token-id-aligned prompt logprob flattening for teacher scoring. |
| `skills/config/SKILL.md` | Documents SFT vs OPD distillation modes and which teacher endpoints are supported for scoring. |
@@ -115,20 +145,8 @@ async def _compute_single(client_config: vf.ClientConfig, sample: TrainingSample
             },
         },
     )
-    response = GenerateResponse.model_validate_json(http_response.content)
-    # ``prompt_logprobs[i]`` is a ``{token_id: Logprob}`` dict for tokens
-    # the engine could score, or ``None`` for the leading token which has
-    # no preceding context. Flatten to ``list[float]`` with 0.0 in the
-    # unscored slot.
-    flat: list[float] = []
-    for entry in response.prompt_logprobs or []:
-        if not entry:
-            flat.append(0.0)
-            continue
-        first = next(iter(entry.values()))
-        lp = first.logprob if hasattr(first, "logprob") else first.get("logprob")
-        flat.append(float(lp) if lp is not None else 0.0)
-    return flat
+    response = http_response.json()
+    return _flatten_prompt_logprobs(response, token_ids)
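For readers following the hunk above, here is a minimal sketch of what a token-id-aligned flattening step could look like. It is not the PR's actual `_flatten_prompt_logprobs`; the function name `flatten_prompt_logprobs_sketch` is hypothetical, and the error strings are assumptions modelled on the messages asserted in this PR's unit tests. The removed code took the first value in each `prompt_logprobs` dict, which can be a different token than the one being scored; the sketch looks up the target token id instead.

```python
from typing import Any


def flatten_prompt_logprobs_sketch(response: dict[str, Any], token_ids: list[int]) -> list[float]:
    """Hypothetical stand-in for the PR's _flatten_prompt_logprobs helper."""
    entries = response.get("prompt_logprobs") or []
    if len(entries) != len(token_ids):
        raise ValueError(
            f"teacher prompt_logprobs length != sample length ({len(entries)} vs {len(token_ids)})"
        )

    flat: list[float] = []
    for position, (entry, token_id) in enumerate(zip(entries, token_ids)):
        if not entry:
            if position == 0:
                # The leading token has no preceding context, so it is unscored.
                flat.append(0.0)
                continue
            raise ValueError(
                f"teacher prompt_logprobs missing entry at position {position} for token id {token_id}"
            )
        # Raw JSON object keys are strings, so look the token id up as a string.
        scored = entry.get(str(token_id))
        if scored is None:
            raise ValueError(f"teacher prompt_logprobs missing token id {token_id}")
        flat.append(float(scored["logprob"]))
    return flat


# Position 1 lists an alternative token ("13") before the scored token ("2").
response = {
    "prompt_logprobs": [
        None,
        {"13": {"logprob": -9.0}, "2": {"logprob": -0.7}},
        {"3": {"logprob": -0.3}},
    ]
}
print(flatten_prompt_logprobs_sketch(response, [1, 2, 3]))  # [0.0, -0.7, -0.3]
```

With the same input, a first-value lookup would have returned -9.0 at position 1, which is the misalignment the stricter parsing is meant to rule out.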
Force-pushed from 3638f41 to 7a55a28.
@codex review
Codex Review: Didn't find any major issues. Breezy!
Force-pushed from 7a55a28 to 8d100cf.
@codex review
Force-pushed from 9d34387 to 31b5fd1.
Codex Review: Didn't find any major issues. Breezy!
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Reviewed by Cursor Bugbot for commit 31b5fd1.
    scored_token_ids: list[int],
) -> tuple[str, dict[str, Any]]:
    base = base_url.rstrip("/")
    return f"{base.removesuffix('/v1')}/inference/v1/generate", {
Missing Prime API URL routing for hosted teacher
High Severity
`_teacher_generate_request` always constructs a `/inference/v1/generate` URL, but the PR claims to also route to the Prime API `/api/v1/generate` when a hosted teacher base URL is used. With a Prime API base URL like `https://api.pinference.ai/api/v1` (as seen in `configs/debug/training_modes/sft_external.toml`), the function strips `/v1` and appends `/inference/v1/generate`, producing `https://api.pinference.ai/api/inference/v1/generate`, an incorrect path that would 404. The expected Prime API URL is `https://api.pinference.ai/api/v1/generate`. No conditional routing exists and no unit test covers a Prime API base URL, despite the PR caveat stating the path is "unit/dry-run covered."
I think this is stale after the latest scope change. PrimeRL no longer branches to /api/v1/generate; this PR intentionally keeps OPD teacher scoring on the native vLLM /inference/v1/generate route. The Prime Inference /api/v1/generate gateway is platform/pi-inference-owned via PrimeIntellect-ai/platform#2342.
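To make the path arithmetic in this thread concrete, the sketch below copies the two quoted diff lines into a standalone function. The name `teacher_generate_url_sketch` is hypothetical, and `http://localhost:8000/v1` is an assumed self-hosted vLLM base URL, not one taken from the PR.

```python
def teacher_generate_url_sketch(base_url: str) -> str:
    """Mirror the URL arithmetic from the quoted diff lines (name is hypothetical)."""
    base = base_url.rstrip("/")
    # str.removesuffix needs Python 3.9+; it only strips a trailing "/v1".
    return f"{base.removesuffix('/v1')}/inference/v1/generate"


for base_url in (
    "http://localhost:8000/v1",          # assumed self-hosted vLLM teacher
    "https://api.pinference.ai/api/v1",  # hosted base URL cited by Bugbot
):
    print(base_url, "->", teacher_generate_url_sketch(base_url))
```

The self-hosted base maps to `http://localhost:8000/inference/v1/generate`, while the hosted base maps to `https://api.pinference.ai/api/inference/v1/generate`. That second result is the path Bugbot flagged, and it matches the reply above that hosted Prime API routing is out of scope for this PR.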
@codex review
Codex Review: Didn't find any major issues. Keep them coming!
def test_compute_teacher_logprobs_rejects_wrong_length(monkeypatch):
    async def _run():
        fake_client = _FakeOpenAIClient(
            {
                "request_id": "gen-test",
                "choices": [],
                "prompt_logprobs": [None, {"2": {"logprob": -0.7}}],
                "kv_transfer_params": None,
            }
        )
        monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client)

        sample = SimpleNamespace(
            prompt_ids=[1],
            prompt_mask=[True],
            completion_ids=[2, 3],
            completion_mask=[True, True],
            completion_logprobs=[-0.1, -0.2],
            completion_temperatures=[1.0, 1.0],
            env_name="test-env",
        )

        try:
            await orchestrator_utils.compute_teacher_logprobs(
                clients=[vf.ClientConfig()],
                model_name="teacher-model",
                samples=[sample],
            )
        except ValueError as exc:
            assert "teacher prompt_logprobs length != sample length" in str(exc)
        else:
            raise AssertionError("Expected ValueError")

    asyncio.run(_run())

def test_compute_teacher_logprobs_rejects_missing_token_id(monkeypatch):
    async def _run():
        fake_client = _FakeOpenAIClient(
            {
                "request_id": "gen-test",
                "choices": [],
                "prompt_logprobs": [None, {"13": {"logprob": -0.1}}, {"3": {"logprob": -0.3}}],
                "kv_transfer_params": None,
            }
        )
        monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client)

        sample = SimpleNamespace(
            prompt_ids=[1],
            prompt_mask=[True],
            completion_ids=[2, 3],
            completion_mask=[True, True],
            completion_logprobs=[-0.1, -0.2],
            completion_temperatures=[1.0, 1.0],
            env_name="test-env",
        )

        try:
            await orchestrator_utils.compute_teacher_logprobs(
                clients=[vf.ClientConfig()],
                model_name="teacher-model",
                samples=[sample],
            )
        except ValueError as exc:
            assert "teacher prompt_logprobs missing token id 2" in str(exc)
        else:
            raise AssertionError("Expected ValueError")

    asyncio.run(_run())

def test_compute_teacher_logprobs_rejects_missing_non_leading_entry(monkeypatch):
    async def _run():
        fake_client = _FakeOpenAIClient(
            {
                "request_id": "gen-test",
                "choices": [],
                "prompt_logprobs": [None, None, {"3": {"logprob": -0.3}}],
                "kv_transfer_params": None,
            }
        )
        monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client)

        sample = SimpleNamespace(
            prompt_ids=[1],
            prompt_mask=[True],
            completion_ids=[2, 3],
            completion_mask=[True, True],
            completion_logprobs=[-0.1, -0.2],
            completion_temperatures=[1.0, 1.0],
            env_name="test-env",
        )

        try:
            await orchestrator_utils.compute_teacher_logprobs(
                clients=[vf.ClientConfig()],
                model_name="teacher-model",
                samples=[sample],
            )
        except ValueError as exc:
            assert "teacher prompt_logprobs missing entry at position 1 for token id 2" in str(exc)
        else:
            raise AssertionError("Expected ValueError")

    asyncio.run(_run())

def test_compute_teacher_logprobs_raises_for_teacher_http_error(monkeypatch):
    async def _run():
        fake_client = _FakeOpenAIClient(
            {"error": {"message": "invalid teacher api key"}},
            status_code=401,
        )
        monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client)

        sample = SimpleNamespace(
            prompt_ids=[1],
            prompt_mask=[True],
            completion_ids=[2, 3],
            completion_mask=[True, True],
            completion_logprobs=[-0.1, -0.2],
            completion_temperatures=[1.0, 1.0],
            env_name="test-env",
        )

        try:
            await orchestrator_utils.compute_teacher_logprobs(
                clients=[vf.ClientConfig()],
                model_name="teacher-model",
                samples=[sample],
            )
        except httpx.HTTPStatusError as exc:
            assert exc.response.status_code == 401
        else:
            raise AssertionError("Expected HTTPStatusError")

    asyncio.run(_run())
## Distillation training modes

Set `orchestrator.training_mode = "sft"` or `orchestrator.training_mode = "opd"` (or top-level `training_mode`, which auto-propagates) and configure `orchestrator.teacher` with the teacher endpoint.

In `sft`, rollouts are generated by the teacher. The shared `training_mode` validator sets `trainer.loss.type = "sft"`, and trainer loss dispatch uses `sft_loss_fn`.

In `opd`, rollouts are generated by the student. The orchestrator scores the student token sequence with the teacher, writes those values to `TrainingSample.teacher_logprobs`, and trainer loss dispatch uses `opd_loss_fn`.

`[inference]` is required for the usual online path because it starts the student inference server and auto-configures `orchestrator.student.client.base_url`. The student pool is used for online evals and policy weight sync. For externally started student inference, set `orchestrator.student.client.base_url` explicitly instead.

Teacher logprob scoring uses PrimeRL's vLLM-native `/inference/v1/generate` route. The request field is `token_ids`, meaning the prompt plus completion tokens to score; response `choices[].token_ids` remains generated completion tokens and is not used for OPD scoring.
Let's remove this as well? I'm not sure how it relates to the PR.
    scored_token_ids: list[int],
) -> tuple[str, dict[str, Any]]:
    base = base_url.rstrip("/")
    return f"{base.removesuffix('/v1')}/inference/v1/generate", {
This breaks OPD in PrimeRL; we can't do this.
Closing again to keep the active workstream focused on replay-backed SFT. OPD follow-up is not needed here.


Summary
- Keep OPD teacher scoring on the vLLM-native `/inference/v1/generate` route
- Remove the `/api/v1/generate` branch from PrimeRL now that the gateway path is platform/pi-inference-owned
- Document the `token_ids` scoring contract in the local config skill
Scope
This PR is PrimeRL-native scoring only. It intentionally does not route hosted Prime API base URLs to `/api/v1/generate`; that gateway path is owned by platform/pi-inference in PrimeIntellect-ai/platform#2342.
For this PR, teacher scoring expects a vLLM-compatible endpoint that serves `/inference/v1/generate` with `token_ids` + `prompt_logprobs`.
Verification
- `reverse_text_rl_opd`
- `uv run pytest tests/unit/orchestrator/test_teacher_logprobs.py` is blocked by the workspace dependency conflict (`prime-pydantic-config:dev` pins `ruff==0.5.0`; `prime-rl:dev` requires `ruff>=0.12.1`)
- `UV_NO_SYNC=1 uv run pytest tests/unit/orchestrator/test_teacher_logprobs.py` is blocked because `pytest` is not installed in the unsynced env
- `git diff --check` passed for the staged PrimeRL changes before commit
Review Notes
- `deps/renderers` submodule rewind, which was intentionally not staged or pushed
- `@codex review` trigger comment added
Linear: APR-57
Note
Medium Risk
Changes OPD distillation’s teacher scoring contract and error handling; wrong logprobs would skew training, but scope is isolated to `compute_teacher_logprobs` with added validation and tests.
Overview
Hardens OPD teacher logprob scoring so the orchestrator only uses PrimeRL’s vLLM-native `/inference/v1/generate` path with `token_ids` (prompt + completion) and `prompt_logprobs`, instead of relying on vLLM’s `GenerateResponse` pydantic parsing.
`compute_teacher_logprobs` now builds the request via helpers, parses the raw JSON response, and picks the logprob for each scored token id (not the first alternative in the dict). It errors on length mismatches, missing non-leading entries, missing target token ids, and HTTP failures before parsing. The config skill documents SFT vs OPD flows and the `token_ids` scoring contract; unit tests cover happy path, validation failures, and 401 handling.
Reviewed by Cursor Bugbot for commit 31b5fd1. Bugbot is set up for automated code reviews on this repo.