diff --git a/skills/configs/SKILL.md b/skills/configs/SKILL.md index 83f7dd8d47..ae1e037197 100644 --- a/skills/configs/SKILL.md +++ b/skills/configs/SKILL.md @@ -60,6 +60,18 @@ CLI: `--env.0.id reverse-text --env.1.id math-env`. In TOML, an empty section header (`[ckpt]`) does the same. +## Distillation training modes + +Set `orchestrator.training_mode = "sft"` or `orchestrator.training_mode = "opd"` (or top-level `training_mode`, which auto-propagates) and configure `orchestrator.teacher` with the teacher endpoint. + +In `sft`, rollouts are generated by the teacher. The shared `training_mode` validator sets `trainer.loss.type = "sft"`, and trainer loss dispatch uses `sft_loss_fn`. + +In `opd`, rollouts are generated by the student. The orchestrator scores the student token sequence with the teacher, writes those values to `TrainingSample.teacher_logprobs`, and trainer loss dispatch uses `opd_loss_fn`. + +`[inference]` is required for the usual online path because it starts the student inference server and auto-configures `orchestrator.student.client.base_url`. The student pool is used for online evals and policy weight sync. For externally started student inference, set `orchestrator.student.client.base_url` explicitly instead. + +Teacher logprob scoring uses PrimeRL's vLLM-native `/inference/v1/generate` route. The request field is `token_ids`, meaning the prompt plus completion tokens to score; response `choices[].token_ids` remains generated completion tokens and is not used for OPD scoring. + ## RL trainer token exports For rollout debugging, enable trainer-side token export under `trainer.experimental.token_export` (or `experimental.token_export` when running the trainer entrypoint directly). It writes one JSONL record per exported sequence under `output_dir/token_exports/step_/rank_.jsonl`. Each record stores aligned per-token arrays for token ids, loss mask, advantage, reward, entropy, mismatch KL, inference/trainer logprobs, importance ratios, probability deltas, and masking diagnostics. It does not decode token text in the trainer. diff --git a/src/prime_rl/orchestrator/utils.py b/src/prime_rl/orchestrator/utils.py index 5675ba3f34..e26b6dbcfb 100644 --- a/src/prime_rl/orchestrator/utils.py +++ b/src/prime_rl/orchestrator/utils.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import asyncio import logging import time from concurrent.futures import ThreadPoolExecutor from itertools import cycle from pathlib import Path +from typing import Any import orjson import verifiers as vf @@ -103,51 +106,72 @@ async def compute_teacher_logprobs( ) -> list[list[float]]: """Compute teacher model logprobs for a batch of training samples via prefill.""" import httpx - from vllm.entrypoints.serve.disagg.protocol import GenerateResponse + + def _teacher_generate_request( + base_url: str, + model_name: str, + scored_token_ids: list[int], + ) -> tuple[str, dict[str, Any]]: + base = base_url.rstrip("/") + return f"{base.removesuffix('/v1')}/inference/v1/generate", { + "model": model_name, + "token_ids": scored_token_ids, + "sampling_params": { + "max_tokens": 1, + "temperature": 1.0, + "top_p": 1.0, + "prompt_logprobs": 1, + }, + } + + def _flatten_prompt_logprobs(response: dict[str, Any], token_ids: list[int]) -> list[float]: + # ``prompt_logprobs[i]`` is a ``{token_id: Logprob}`` dict for tokens + # the engine could score, or ``None`` for the leading token which has + # no preceding context. vLLM can include both the target token and the + # top-k alternatives; select the exact target token at each position. + prompt_logprobs = response.get("prompt_logprobs") or [] + if len(prompt_logprobs) != len(token_ids): + raise ValueError( + f"teacher prompt_logprobs length != sample length ({len(prompt_logprobs)} != {len(token_ids)})" + ) + flat: list[float] = [] + for i, (token_id, entry) in enumerate(zip(token_ids, prompt_logprobs)): + if not entry: + if i != 0: + raise ValueError(f"teacher prompt_logprobs missing entry at position {i} for token id {token_id}") + flat.append(0.0) + continue + target = entry.get(str(token_id)) + if target is None: + target = entry.get(token_id) + if target is None: + raise ValueError(f"teacher prompt_logprobs missing token id {token_id}") + lp = target.get("logprob") + flat.append(float(lp) if lp is not None else 0.0) + return flat async def _compute_single(client_config: vf.ClientConfig, sample: TrainingSample) -> list[float]: client = setup_openai_client(client_config) + scored_token_ids = list(sample.prompt_ids) + list(sample.completion_ids) # Two escape hatches from ``AsyncOpenAI.post``: - # 1. URL — ``/inference/v1/generate`` is mounted at server root, not - # under ``/v1``. Pass an absolute URL so the SDK's - # ``_prepare_url`` skips the base-url merge (it short-circuits - # when the path passes ``httpx.URL.is_relative_url`` as False). + # 1. URL — vLLM mounts ``/inference/v1/generate`` at server root, + # so pass an absolute URL and skip the SDK's base-url merge. # 2. Parse — vLLM's ``GenerateResponse`` is a plain # ``pydantic.BaseModel`` and the SDK's parse layer rejects any # ``cast_to`` that doesn't subclass ``openai.BaseModel``. Use # ``cast_to=httpx.Response`` so the SDK still builds the request # (preserving ``auth_headers``, retries, timeouts, idempotency # keys) and just hands us the raw response to validate ourselves. - base = str(client.base_url).rstrip("/").removesuffix("/v1") + url, body = _teacher_generate_request(str(client.base_url), model_name, scored_token_ids) http_response = await client.post( - f"{base}/inference/v1/generate", + url, cast_to=httpx.Response, - body={ - "model": model_name, - "token_ids": list(sample.prompt_ids) + list(sample.completion_ids), - "sampling_params": { - "max_tokens": 1, - "temperature": 1.0, - "top_p": 1.0, - "prompt_logprobs": 1, - }, - }, + body=body, ) - response = GenerateResponse.model_validate_json(http_response.content) - # ``prompt_logprobs[i]`` is a ``{token_id: Logprob}`` dict for tokens - # the engine could score, or ``None`` for the leading token which has - # no preceding context. Flatten to ``list[float]`` with 0.0 in the - # unscored slot. - flat: list[float] = [] - for entry in response.prompt_logprobs or []: - if not entry: - flat.append(0.0) - continue - first = next(iter(entry.values())) - lp = first.logprob if hasattr(first, "logprob") else first.get("logprob") - flat.append(float(lp) if lp is not None else 0.0) - return flat + http_response.raise_for_status() + response = http_response.json() + return _flatten_prompt_logprobs(response, scored_token_ids) return await asyncio.gather(*[_compute_single(client, sample) for client, sample in zip(cycle(clients), samples)]) diff --git a/tests/unit/orchestrator/test_teacher_logprobs.py b/tests/unit/orchestrator/test_teacher_logprobs.py index d63fdce792..b7e9b39cb1 100644 --- a/tests/unit/orchestrator/test_teacher_logprobs.py +++ b/tests/unit/orchestrator/test_teacher_logprobs.py @@ -1,11 +1,11 @@ import asyncio import json +from types import SimpleNamespace import httpx import verifiers as vf from prime_rl.orchestrator import utils as orchestrator_utils -from prime_rl.transport import TrainingSample class _FakeOpenAIClient: @@ -14,17 +14,18 @@ class _FakeOpenAIClient: handed back verbatim, mirroring the real SDK's short-circuit at ``AsyncAPIClient._process_response``.""" - def __init__(self, payload: dict): + def __init__(self, payload: dict, base_url: str = "http://fake-host:8000/v1", status_code: int = 200): # Match what AsyncOpenAI exposes — utils.py reads ``str(client.base_url)``. - self.base_url = "http://fake-host:8000/v1" + self.base_url = base_url self._payload = payload + self._status_code = status_code self.calls: list[dict] = [] async def post(self, url, *, cast_to, body): self.calls.append({"url": url, "cast_to": cast_to, "body": body}) request = httpx.Request("POST", url, json=body) return httpx.Response( - status_code=200, + status_code=self._status_code, content=json.dumps(self._payload).encode(), request=request, ) @@ -37,13 +38,17 @@ async def _run(): "request_id": "gen-test", "choices": [], # Upstream wire shape: list[dict[token_id, Logprob] | None] - "prompt_logprobs": [None, {"11": {"logprob": -0.7}}, {"12": {"logprob": -0.3}}], + "prompt_logprobs": [ + None, + {"13": {"logprob": -0.1}, "2": {"logprob": -0.7}}, + {"198": {"logprob": -0.2}, "3": {"logprob": -0.3}}, + ], "kv_transfer_params": None, } ) monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client) - sample = TrainingSample( + sample = SimpleNamespace( prompt_ids=[1], prompt_mask=[True], completion_ids=[2, 3], @@ -78,3 +83,143 @@ async def _run(): ] asyncio.run(_run()) + + +def test_compute_teacher_logprobs_rejects_wrong_length(monkeypatch): + async def _run(): + fake_client = _FakeOpenAIClient( + { + "request_id": "gen-test", + "choices": [], + "prompt_logprobs": [None, {"2": {"logprob": -0.7}}], + "kv_transfer_params": None, + } + ) + monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client) + + sample = SimpleNamespace( + prompt_ids=[1], + prompt_mask=[True], + completion_ids=[2, 3], + completion_mask=[True, True], + completion_logprobs=[-0.1, -0.2], + completion_temperatures=[1.0, 1.0], + env_name="test-env", + ) + + try: + await orchestrator_utils.compute_teacher_logprobs( + clients=[vf.ClientConfig()], + model_name="teacher-model", + samples=[sample], + ) + except ValueError as exc: + assert "teacher prompt_logprobs length != sample length" in str(exc) + else: + raise AssertionError("Expected ValueError") + + asyncio.run(_run()) + + +def test_compute_teacher_logprobs_rejects_missing_token_id(monkeypatch): + async def _run(): + fake_client = _FakeOpenAIClient( + { + "request_id": "gen-test", + "choices": [], + "prompt_logprobs": [None, {"13": {"logprob": -0.1}}, {"3": {"logprob": -0.3}}], + "kv_transfer_params": None, + } + ) + monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client) + + sample = SimpleNamespace( + prompt_ids=[1], + prompt_mask=[True], + completion_ids=[2, 3], + completion_mask=[True, True], + completion_logprobs=[-0.1, -0.2], + completion_temperatures=[1.0, 1.0], + env_name="test-env", + ) + + try: + await orchestrator_utils.compute_teacher_logprobs( + clients=[vf.ClientConfig()], + model_name="teacher-model", + samples=[sample], + ) + except ValueError as exc: + assert "teacher prompt_logprobs missing token id 2" in str(exc) + else: + raise AssertionError("Expected ValueError") + + asyncio.run(_run()) + + +def test_compute_teacher_logprobs_rejects_missing_non_leading_entry(monkeypatch): + async def _run(): + fake_client = _FakeOpenAIClient( + { + "request_id": "gen-test", + "choices": [], + "prompt_logprobs": [None, None, {"3": {"logprob": -0.3}}], + "kv_transfer_params": None, + } + ) + monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client) + + sample = SimpleNamespace( + prompt_ids=[1], + prompt_mask=[True], + completion_ids=[2, 3], + completion_mask=[True, True], + completion_logprobs=[-0.1, -0.2], + completion_temperatures=[1.0, 1.0], + env_name="test-env", + ) + + try: + await orchestrator_utils.compute_teacher_logprobs( + clients=[vf.ClientConfig()], + model_name="teacher-model", + samples=[sample], + ) + except ValueError as exc: + assert "teacher prompt_logprobs missing entry at position 1 for token id 2" in str(exc) + else: + raise AssertionError("Expected ValueError") + + asyncio.run(_run()) + + +def test_compute_teacher_logprobs_raises_for_teacher_http_error(monkeypatch): + async def _run(): + fake_client = _FakeOpenAIClient( + {"error": {"message": "invalid teacher api key"}}, + status_code=401, + ) + monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client) + + sample = SimpleNamespace( + prompt_ids=[1], + prompt_mask=[True], + completion_ids=[2, 3], + completion_mask=[True, True], + completion_logprobs=[-0.1, -0.2], + completion_temperatures=[1.0, 1.0], + env_name="test-env", + ) + + try: + await orchestrator_utils.compute_teacher_logprobs( + clients=[vf.ClientConfig()], + model_name="teacher-model", + samples=[sample], + ) + except httpx.HTTPStatusError as exc: + assert exc.response.status_code == 401 + else: + raise AssertionError("Expected HTTPStatusError") + + asyncio.run(_run())