From 8d100cfbb576342f14ff9b0682342d0907a85329 Mon Sep 17 00:00:00 2001 From: Timothy Kostolansky <39891386+tim0120@users.noreply.github.com> Date: Mon, 1 Jun 2026 14:04:34 -0700 Subject: [PATCH 1/2] fix: support hosted OPD teacher logprobs --- skills/configs/SKILL.md | 12 ++ src/prime_rl/orchestrator/utils.py | 92 +++++--- .../orchestrator/test_teacher_logprobs.py | 204 +++++++++++++++++- 3 files changed, 271 insertions(+), 37 deletions(-) diff --git a/skills/configs/SKILL.md b/skills/configs/SKILL.md index 83f7dd8d47..c3270926ed 100644 --- a/skills/configs/SKILL.md +++ b/skills/configs/SKILL.md @@ -60,6 +60,18 @@ CLI: `--env.0.id reverse-text --env.1.id math-env`. In TOML, an empty section header (`[ckpt]`) does the same. +## Distillation training modes + +Set `orchestrator.training_mode = "sft"` or `orchestrator.training_mode = "opd"` (or top-level `training_mode`, which auto-propagates) and configure `orchestrator.teacher` with the teacher endpoint. + +In `sft`, rollouts are generated by the teacher. The shared `training_mode` validator sets `trainer.loss.type = "sft"`, and trainer loss dispatch uses `sft_loss_fn`. + +In `opd`, rollouts are generated by the student. The orchestrator scores the student token sequence with the teacher, writes those values to `TrainingSample.teacher_logprobs`, and trainer loss dispatch uses `opd_loss_fn`. + +`[inference]` is required for the usual online path because it starts the student inference server and auto-configures `orchestrator.student.client.base_url`. The student pool is used for online evals and policy weight sync. For externally started student inference, set `orchestrator.student.client.base_url` explicitly instead. + +Teacher logprob scoring supports both self-hosted vLLM and Prime API teacher clients: `/inference/v1/generate` for vLLM server roots, `/api/v1/generate` when the teacher client base URL ends in `/api/v1`. + ## RL trainer token exports For rollout debugging, enable trainer-side token export under `trainer.experimental.token_export` (or `experimental.token_export` when running the trainer entrypoint directly). It writes one JSONL record per exported sequence under `output_dir/token_exports/step_/rank_.jsonl`. Each record stores aligned per-token arrays for token ids, loss mask, advantage, reward, entropy, mismatch KL, inference/trainer logprobs, importance ratios, probability deltas, and masking diagnostics. It does not decode token text in the trainer. diff --git a/src/prime_rl/orchestrator/utils.py b/src/prime_rl/orchestrator/utils.py index 5675ba3f34..ff7d98bd38 100644 --- a/src/prime_rl/orchestrator/utils.py +++ b/src/prime_rl/orchestrator/utils.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import asyncio import logging import time from concurrent.futures import ThreadPoolExecutor from itertools import cycle from pathlib import Path +from typing import Any import orjson import verifiers as vf @@ -103,51 +106,78 @@ async def compute_teacher_logprobs( ) -> list[list[float]]: """Compute teacher model logprobs for a batch of training samples via prefill.""" import httpx - from vllm.entrypoints.serve.disagg.protocol import GenerateResponse + + def _teacher_generate_request(base_url: str, model_name: str, token_ids: list[int]) -> tuple[str, dict[str, Any]]: + base = base_url.rstrip("/") + if base.endswith("/api/v1"): + return f"{base}/generate", { + "model": model_name, + "prompt_token_ids": token_ids, + "max_tokens": 1, + "temperature": 1.0, + "top_p": 1.0, + "prompt_logprobs": 1, + } + return f"{base.removesuffix('/v1')}/inference/v1/generate", { + "model": model_name, + "token_ids": token_ids, + "sampling_params": { + "max_tokens": 1, + "temperature": 1.0, + "top_p": 1.0, + "prompt_logprobs": 1, + }, + } + + def _flatten_prompt_logprobs(response: dict[str, Any], token_ids: list[int]) -> list[float]: + # ``prompt_logprobs[i]`` is a ``{token_id: Logprob}`` dict for tokens + # the engine could score, or ``None`` for the leading token which has + # no preceding context. vLLM can include both the target token and the + # top-k alternatives; select the exact target token at each position. + prompt_logprobs = response.get("prompt_logprobs") or [] + if len(prompt_logprobs) != len(token_ids): + raise ValueError( + f"teacher prompt_logprobs length != sample length ({len(prompt_logprobs)} != {len(token_ids)})" + ) + flat: list[float] = [] + for i, (token_id, entry) in enumerate(zip(token_ids, prompt_logprobs)): + if not entry: + if i != 0: + raise ValueError(f"teacher prompt_logprobs missing entry at position {i} for token id {token_id}") + flat.append(0.0) + continue + target = entry.get(str(token_id)) + if target is None: + target = entry.get(token_id) + if target is None: + raise ValueError(f"teacher prompt_logprobs missing token id {token_id}") + lp = target.get("logprob") + flat.append(float(lp) if lp is not None else 0.0) + return flat async def _compute_single(client_config: vf.ClientConfig, sample: TrainingSample) -> list[float]: client = setup_openai_client(client_config) + token_ids = list(sample.prompt_ids) + list(sample.completion_ids) # Two escape hatches from ``AsyncOpenAI.post``: - # 1. URL — ``/inference/v1/generate`` is mounted at server root, not - # under ``/v1``. Pass an absolute URL so the SDK's - # ``_prepare_url`` skips the base-url merge (it short-circuits - # when the path passes ``httpx.URL.is_relative_url`` as False). + # 1. URL — vLLM mounts ``/inference/v1/generate`` at server root, + # while Prime Inference exposes ``/api/v1/generate``. Pass an + # absolute URL so the SDK's ``_prepare_url`` skips base-url merge. # 2. Parse — vLLM's ``GenerateResponse`` is a plain # ``pydantic.BaseModel`` and the SDK's parse layer rejects any # ``cast_to`` that doesn't subclass ``openai.BaseModel``. Use # ``cast_to=httpx.Response`` so the SDK still builds the request # (preserving ``auth_headers``, retries, timeouts, idempotency # keys) and just hands us the raw response to validate ourselves. - base = str(client.base_url).rstrip("/").removesuffix("/v1") + url, body = _teacher_generate_request(str(client.base_url), model_name, token_ids) http_response = await client.post( - f"{base}/inference/v1/generate", + url, cast_to=httpx.Response, - body={ - "model": model_name, - "token_ids": list(sample.prompt_ids) + list(sample.completion_ids), - "sampling_params": { - "max_tokens": 1, - "temperature": 1.0, - "top_p": 1.0, - "prompt_logprobs": 1, - }, - }, + body=body, ) - response = GenerateResponse.model_validate_json(http_response.content) - # ``prompt_logprobs[i]`` is a ``{token_id: Logprob}`` dict for tokens - # the engine could score, or ``None`` for the leading token which has - # no preceding context. Flatten to ``list[float]`` with 0.0 in the - # unscored slot. - flat: list[float] = [] - for entry in response.prompt_logprobs or []: - if not entry: - flat.append(0.0) - continue - first = next(iter(entry.values())) - lp = first.logprob if hasattr(first, "logprob") else first.get("logprob") - flat.append(float(lp) if lp is not None else 0.0) - return flat + http_response.raise_for_status() + response = http_response.json() + return _flatten_prompt_logprobs(response, token_ids) return await asyncio.gather(*[_compute_single(client, sample) for client, sample in zip(cycle(clients), samples)]) diff --git a/tests/unit/orchestrator/test_teacher_logprobs.py b/tests/unit/orchestrator/test_teacher_logprobs.py index d63fdce792..918724ac58 100644 --- a/tests/unit/orchestrator/test_teacher_logprobs.py +++ b/tests/unit/orchestrator/test_teacher_logprobs.py @@ -1,11 +1,11 @@ import asyncio import json +from types import SimpleNamespace import httpx import verifiers as vf from prime_rl.orchestrator import utils as orchestrator_utils -from prime_rl.transport import TrainingSample class _FakeOpenAIClient: @@ -14,17 +14,18 @@ class _FakeOpenAIClient: handed back verbatim, mirroring the real SDK's short-circuit at ``AsyncAPIClient._process_response``.""" - def __init__(self, payload: dict): + def __init__(self, payload: dict, base_url: str = "http://fake-host:8000/v1", status_code: int = 200): # Match what AsyncOpenAI exposes — utils.py reads ``str(client.base_url)``. - self.base_url = "http://fake-host:8000/v1" + self.base_url = base_url self._payload = payload + self._status_code = status_code self.calls: list[dict] = [] async def post(self, url, *, cast_to, body): self.calls.append({"url": url, "cast_to": cast_to, "body": body}) request = httpx.Request("POST", url, json=body) return httpx.Response( - status_code=200, + status_code=self._status_code, content=json.dumps(self._payload).encode(), request=request, ) @@ -37,13 +38,17 @@ async def _run(): "request_id": "gen-test", "choices": [], # Upstream wire shape: list[dict[token_id, Logprob] | None] - "prompt_logprobs": [None, {"11": {"logprob": -0.7}}, {"12": {"logprob": -0.3}}], + "prompt_logprobs": [ + None, + {"13": {"logprob": -0.1}, "2": {"logprob": -0.7}}, + {"198": {"logprob": -0.2}, "3": {"logprob": -0.3}}, + ], "kv_transfer_params": None, } ) monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client) - sample = TrainingSample( + sample = SimpleNamespace( prompt_ids=[1], prompt_mask=[True], completion_ids=[2, 3], @@ -78,3 +83,190 @@ async def _run(): ] asyncio.run(_run()) + + +def test_compute_teacher_logprobs_uses_prime_generate_for_api_base_url(monkeypatch): + async def _run(): + fake_client = _FakeOpenAIClient( + { + "request_id": "gen-test", + "choices": [], + "prompt_logprobs": [ + None, + {"13": {"logprob": -0.1}, "2": {"logprob": -0.7}}, + {"198": {"logprob": -0.2}, "3": {"logprob": -0.3}}, + ], + "kv_transfer_params": None, + }, + base_url="https://api.primeintellect.ai/api/v1", + ) + monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client) + + sample = SimpleNamespace( + prompt_ids=[1], + prompt_mask=[True], + completion_ids=[2, 3], + completion_mask=[True, True], + completion_logprobs=[-0.1, -0.2], + completion_temperatures=[1.0, 1.0], + env_name="test-env", + ) + + result = await orchestrator_utils.compute_teacher_logprobs( + clients=[vf.ClientConfig()], + model_name="teacher-model", + samples=[sample], + ) + + assert result == [[0.0, -0.7, -0.3]] + assert fake_client.calls[0]["url"] == "https://api.primeintellect.ai/api/v1/generate" + assert fake_client.calls[0]["body"] == { + "model": "teacher-model", + "prompt_token_ids": [1, 2, 3], + "max_tokens": 1, + "temperature": 1.0, + "top_p": 1.0, + "prompt_logprobs": 1, + } + + asyncio.run(_run()) + + +def test_compute_teacher_logprobs_rejects_wrong_length(monkeypatch): + async def _run(): + fake_client = _FakeOpenAIClient( + { + "request_id": "gen-test", + "choices": [], + "prompt_logprobs": [None, {"2": {"logprob": -0.7}}], + "kv_transfer_params": None, + } + ) + monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client) + + sample = SimpleNamespace( + prompt_ids=[1], + prompt_mask=[True], + completion_ids=[2, 3], + completion_mask=[True, True], + completion_logprobs=[-0.1, -0.2], + completion_temperatures=[1.0, 1.0], + env_name="test-env", + ) + + try: + await orchestrator_utils.compute_teacher_logprobs( + clients=[vf.ClientConfig()], + model_name="teacher-model", + samples=[sample], + ) + except ValueError as exc: + assert "teacher prompt_logprobs length != sample length" in str(exc) + else: + raise AssertionError("Expected ValueError") + + asyncio.run(_run()) + + +def test_compute_teacher_logprobs_rejects_missing_token_id(monkeypatch): + async def _run(): + fake_client = _FakeOpenAIClient( + { + "request_id": "gen-test", + "choices": [], + "prompt_logprobs": [None, {"13": {"logprob": -0.1}}, {"3": {"logprob": -0.3}}], + "kv_transfer_params": None, + } + ) + monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client) + + sample = SimpleNamespace( + prompt_ids=[1], + prompt_mask=[True], + completion_ids=[2, 3], + completion_mask=[True, True], + completion_logprobs=[-0.1, -0.2], + completion_temperatures=[1.0, 1.0], + env_name="test-env", + ) + + try: + await orchestrator_utils.compute_teacher_logprobs( + clients=[vf.ClientConfig()], + model_name="teacher-model", + samples=[sample], + ) + except ValueError as exc: + assert "teacher prompt_logprobs missing token id 2" in str(exc) + else: + raise AssertionError("Expected ValueError") + + asyncio.run(_run()) + + +def test_compute_teacher_logprobs_rejects_missing_non_leading_entry(monkeypatch): + async def _run(): + fake_client = _FakeOpenAIClient( + { + "request_id": "gen-test", + "choices": [], + "prompt_logprobs": [None, None, {"3": {"logprob": -0.3}}], + "kv_transfer_params": None, + } + ) + monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client) + + sample = SimpleNamespace( + prompt_ids=[1], + prompt_mask=[True], + completion_ids=[2, 3], + completion_mask=[True, True], + completion_logprobs=[-0.1, -0.2], + completion_temperatures=[1.0, 1.0], + env_name="test-env", + ) + + try: + await orchestrator_utils.compute_teacher_logprobs( + clients=[vf.ClientConfig()], + model_name="teacher-model", + samples=[sample], + ) + except ValueError as exc: + assert "teacher prompt_logprobs missing entry at position 1 for token id 2" in str(exc) + else: + raise AssertionError("Expected ValueError") + + asyncio.run(_run()) + + +def test_compute_teacher_logprobs_raises_for_teacher_http_error(monkeypatch): + async def _run(): + fake_client = _FakeOpenAIClient( + {"error": {"message": "invalid teacher api key"}}, + status_code=401, + ) + monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client) + + sample = SimpleNamespace( + prompt_ids=[1], + prompt_mask=[True], + completion_ids=[2, 3], + completion_mask=[True, True], + completion_logprobs=[-0.1, -0.2], + completion_temperatures=[1.0, 1.0], + env_name="test-env", + ) + + try: + await orchestrator_utils.compute_teacher_logprobs( + clients=[vf.ClientConfig()], + model_name="teacher-model", + samples=[sample], + ) + except httpx.HTTPStatusError as exc: + assert exc.response.status_code == 401 + else: + raise AssertionError("Expected HTTPStatusError") + + asyncio.run(_run()) From 31b5fd1324bf0931a75e85fc49a84dc4446c67cd Mon Sep 17 00:00:00 2001 From: Timothy Kostolansky <39891386+tim0120@users.noreply.github.com> Date: Mon, 1 Jun 2026 20:34:21 -0700 Subject: [PATCH 2/2] fix: keep OPD teacher scoring on native vLLM route --- skills/configs/SKILL.md | 2 +- src/prime_rl/orchestrator/utils.py | 26 ++++------ .../orchestrator/test_teacher_logprobs.py | 47 ------------------- 3 files changed, 11 insertions(+), 64 deletions(-) diff --git a/skills/configs/SKILL.md b/skills/configs/SKILL.md index c3270926ed..ae1e037197 100644 --- a/skills/configs/SKILL.md +++ b/skills/configs/SKILL.md @@ -70,7 +70,7 @@ In `opd`, rollouts are generated by the student. The orchestrator scores the stu `[inference]` is required for the usual online path because it starts the student inference server and auto-configures `orchestrator.student.client.base_url`. The student pool is used for online evals and policy weight sync. For externally started student inference, set `orchestrator.student.client.base_url` explicitly instead. -Teacher logprob scoring supports both self-hosted vLLM and Prime API teacher clients: `/inference/v1/generate` for vLLM server roots, `/api/v1/generate` when the teacher client base URL ends in `/api/v1`. +Teacher logprob scoring uses PrimeRL's vLLM-native `/inference/v1/generate` route. The request field is `token_ids`, meaning the prompt plus completion tokens to score; response `choices[].token_ids` remains generated completion tokens and is not used for OPD scoring. ## RL trainer token exports diff --git a/src/prime_rl/orchestrator/utils.py b/src/prime_rl/orchestrator/utils.py index ff7d98bd38..e26b6dbcfb 100644 --- a/src/prime_rl/orchestrator/utils.py +++ b/src/prime_rl/orchestrator/utils.py @@ -107,20 +107,15 @@ async def compute_teacher_logprobs( """Compute teacher model logprobs for a batch of training samples via prefill.""" import httpx - def _teacher_generate_request(base_url: str, model_name: str, token_ids: list[int]) -> tuple[str, dict[str, Any]]: + def _teacher_generate_request( + base_url: str, + model_name: str, + scored_token_ids: list[int], + ) -> tuple[str, dict[str, Any]]: base = base_url.rstrip("/") - if base.endswith("/api/v1"): - return f"{base}/generate", { - "model": model_name, - "prompt_token_ids": token_ids, - "max_tokens": 1, - "temperature": 1.0, - "top_p": 1.0, - "prompt_logprobs": 1, - } return f"{base.removesuffix('/v1')}/inference/v1/generate", { "model": model_name, - "token_ids": token_ids, + "token_ids": scored_token_ids, "sampling_params": { "max_tokens": 1, "temperature": 1.0, @@ -157,19 +152,18 @@ def _flatten_prompt_logprobs(response: dict[str, Any], token_ids: list[int]) -> async def _compute_single(client_config: vf.ClientConfig, sample: TrainingSample) -> list[float]: client = setup_openai_client(client_config) - token_ids = list(sample.prompt_ids) + list(sample.completion_ids) + scored_token_ids = list(sample.prompt_ids) + list(sample.completion_ids) # Two escape hatches from ``AsyncOpenAI.post``: # 1. URL — vLLM mounts ``/inference/v1/generate`` at server root, - # while Prime Inference exposes ``/api/v1/generate``. Pass an - # absolute URL so the SDK's ``_prepare_url`` skips base-url merge. + # so pass an absolute URL and skip the SDK's base-url merge. # 2. Parse — vLLM's ``GenerateResponse`` is a plain # ``pydantic.BaseModel`` and the SDK's parse layer rejects any # ``cast_to`` that doesn't subclass ``openai.BaseModel``. Use # ``cast_to=httpx.Response`` so the SDK still builds the request # (preserving ``auth_headers``, retries, timeouts, idempotency # keys) and just hands us the raw response to validate ourselves. - url, body = _teacher_generate_request(str(client.base_url), model_name, token_ids) + url, body = _teacher_generate_request(str(client.base_url), model_name, scored_token_ids) http_response = await client.post( url, cast_to=httpx.Response, @@ -177,7 +171,7 @@ async def _compute_single(client_config: vf.ClientConfig, sample: TrainingSample ) http_response.raise_for_status() response = http_response.json() - return _flatten_prompt_logprobs(response, token_ids) + return _flatten_prompt_logprobs(response, scored_token_ids) return await asyncio.gather(*[_compute_single(client, sample) for client, sample in zip(cycle(clients), samples)]) diff --git a/tests/unit/orchestrator/test_teacher_logprobs.py b/tests/unit/orchestrator/test_teacher_logprobs.py index 918724ac58..b7e9b39cb1 100644 --- a/tests/unit/orchestrator/test_teacher_logprobs.py +++ b/tests/unit/orchestrator/test_teacher_logprobs.py @@ -85,53 +85,6 @@ async def _run(): asyncio.run(_run()) -def test_compute_teacher_logprobs_uses_prime_generate_for_api_base_url(monkeypatch): - async def _run(): - fake_client = _FakeOpenAIClient( - { - "request_id": "gen-test", - "choices": [], - "prompt_logprobs": [ - None, - {"13": {"logprob": -0.1}, "2": {"logprob": -0.7}}, - {"198": {"logprob": -0.2}, "3": {"logprob": -0.3}}, - ], - "kv_transfer_params": None, - }, - base_url="https://api.primeintellect.ai/api/v1", - ) - monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client) - - sample = SimpleNamespace( - prompt_ids=[1], - prompt_mask=[True], - completion_ids=[2, 3], - completion_mask=[True, True], - completion_logprobs=[-0.1, -0.2], - completion_temperatures=[1.0, 1.0], - env_name="test-env", - ) - - result = await orchestrator_utils.compute_teacher_logprobs( - clients=[vf.ClientConfig()], - model_name="teacher-model", - samples=[sample], - ) - - assert result == [[0.0, -0.7, -0.3]] - assert fake_client.calls[0]["url"] == "https://api.primeintellect.ai/api/v1/generate" - assert fake_client.calls[0]["body"] == { - "model": "teacher-model", - "prompt_token_ids": [1, 2, 3], - "max_tokens": 1, - "temperature": 1.0, - "top_p": 1.0, - "prompt_logprobs": 1, - } - - asyncio.run(_run()) - - def test_compute_teacher_logprobs_rejects_wrong_length(monkeypatch): async def _run(): fake_client = _FakeOpenAIClient(