Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions skills/configs/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,18 @@ CLI: `--env.0.id reverse-text --env.1.id math-env`.

In TOML, an empty section header (`[ckpt]`) does the same.

## Distillation training modes

Set `orchestrator.training_mode = "sft"` or `orchestrator.training_mode = "opd"` (or top-level `training_mode`, which auto-propagates) and configure `orchestrator.teacher` with the teacher endpoint.

In `sft`, rollouts are generated by the teacher. The shared `training_mode` validator sets `trainer.loss.type = "sft"`, and trainer loss dispatch uses `sft_loss_fn`.

In `opd`, rollouts are generated by the student. The orchestrator scores the student token sequence with the teacher, writes those values to `TrainingSample.teacher_logprobs`, and trainer loss dispatch uses `opd_loss_fn`.

`[inference]` is required for the usual online path because it starts the student inference server and auto-configures `orchestrator.student.client.base_url`. The student pool is used for online evals and policy weight sync. For externally started student inference, set `orchestrator.student.client.base_url` explicitly instead.

Teacher logprob scoring uses PrimeRL's vLLM-native `/inference/v1/generate` route. The request field is `token_ids`, meaning the prompt plus completion tokens to score; response `choices[].token_ids` remains generated completion tokens and is not used for OPD scoring.
Comment on lines +63 to +73
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets remove also ? not sure what is the link with the pr


## RL trainer token exports

For rollout debugging, enable trainer-side token export under `trainer.experimental.token_export` (or `experimental.token_export` when running the trainer entrypoint directly). It writes one JSONL record per exported sequence under `output_dir/token_exports/step_<step>/rank_<rank>.jsonl`. Each record stores aligned per-token arrays for token ids, loss mask, advantage, reward, entropy, mismatch KL, inference/trainer logprobs, importance ratios, probability deltas, and masking diagnostics. It does not decode token text in the trainer.
Expand Down
86 changes: 55 additions & 31 deletions src/prime_rl/orchestrator/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

import asyncio
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from itertools import cycle
from pathlib import Path
from typing import Any

import orjson
import verifiers as vf
Expand Down Expand Up @@ -103,51 +106,72 @@ async def compute_teacher_logprobs(
) -> list[list[float]]:
"""Compute teacher model logprobs for a batch of training samples via prefill."""
import httpx
from vllm.entrypoints.serve.disagg.protocol import GenerateResponse

def _teacher_generate_request(
base_url: str,
model_name: str,
scored_token_ids: list[int],
) -> tuple[str, dict[str, Any]]:
base = base_url.rstrip("/")
return f"{base.removesuffix('/v1')}/inference/v1/generate", {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing Prime API URL routing for hosted teacher

High Severity

_teacher_generate_request always constructs a /inference/v1/generate URL, but the PR claims to also route to Prime API /api/v1/generate when a hosted teacher base URL is used. With a Prime API base URL like https://api.pinference.ai/api/v1 (as seen in configs/debug/training_modes/sft_external.toml), the function strips /v1 and appends /inference/v1/generate, producing https://api.pinference.ai/api/inference/v1/generate — an incorrect path that would 404. The expected Prime API URL is https://api.pinference.ai/api/v1/generate. No conditional routing exists and no unit test covers a Prime API base URL, despite the PR caveat stating the path is "unit/dry-run covered."

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 31b5fd1. Configure here.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is stale after the latest scope change. PrimeRL no longer branches to /api/v1/generate; this PR intentionally keeps OPD teacher scoring on the native vLLM /inference/v1/generate route. The Prime Inference /api/v1/generate gateway is platform/pi-inference-owned via PrimeIntellect-ai/platform#2342.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this break opd in prime rl , we can't do this

"model": model_name,
"token_ids": scored_token_ids,
"sampling_params": {
"max_tokens": 1,
"temperature": 1.0,
"top_p": 1.0,
"prompt_logprobs": 1,
},
}

def _flatten_prompt_logprobs(response: dict[str, Any], token_ids: list[int]) -> list[float]:
# ``prompt_logprobs[i]`` is a ``{token_id: Logprob}`` dict for tokens
# the engine could score, or ``None`` for the leading token which has
# no preceding context. vLLM can include both the target token and the
# top-k alternatives; select the exact target token at each position.
prompt_logprobs = response.get("prompt_logprobs") or []
if len(prompt_logprobs) != len(token_ids):
raise ValueError(
f"teacher prompt_logprobs length != sample length ({len(prompt_logprobs)} != {len(token_ids)})"
)
flat: list[float] = []
for i, (token_id, entry) in enumerate(zip(token_ids, prompt_logprobs)):
if not entry:
if i != 0:
raise ValueError(f"teacher prompt_logprobs missing entry at position {i} for token id {token_id}")
flat.append(0.0)
continue
target = entry.get(str(token_id))
if target is None:
target = entry.get(token_id)
if target is None:
raise ValueError(f"teacher prompt_logprobs missing token id {token_id}")
lp = target.get("logprob")
flat.append(float(lp) if lp is not None else 0.0)
return flat

async def _compute_single(client_config: vf.ClientConfig, sample: TrainingSample) -> list[float]:
client = setup_openai_client(client_config)
scored_token_ids = list(sample.prompt_ids) + list(sample.completion_ids)

# Two escape hatches from ``AsyncOpenAI.post``:
# 1. URL — ``/inference/v1/generate`` is mounted at server root, not
# under ``/v1``. Pass an absolute URL so the SDK's
# ``_prepare_url`` skips the base-url merge (it short-circuits
# when the path passes ``httpx.URL.is_relative_url`` as False).
# 1. URL — vLLM mounts ``/inference/v1/generate`` at server root,
# so pass an absolute URL and skip the SDK's base-url merge.
# 2. Parse — vLLM's ``GenerateResponse`` is a plain
# ``pydantic.BaseModel`` and the SDK's parse layer rejects any
# ``cast_to`` that doesn't subclass ``openai.BaseModel``. Use
# ``cast_to=httpx.Response`` so the SDK still builds the request
# (preserving ``auth_headers``, retries, timeouts, idempotency
# keys) and just hands us the raw response to validate ourselves.
base = str(client.base_url).rstrip("/").removesuffix("/v1")
url, body = _teacher_generate_request(str(client.base_url), model_name, scored_token_ids)
http_response = await client.post(
f"{base}/inference/v1/generate",
url,
cast_to=httpx.Response,
body={
"model": model_name,
"token_ids": list(sample.prompt_ids) + list(sample.completion_ids),
"sampling_params": {
"max_tokens": 1,
"temperature": 1.0,
"top_p": 1.0,
"prompt_logprobs": 1,
},
},
body=body,
)
response = GenerateResponse.model_validate_json(http_response.content)
# ``prompt_logprobs[i]`` is a ``{token_id: Logprob}`` dict for tokens
# the engine could score, or ``None`` for the leading token which has
# no preceding context. Flatten to ``list[float]`` with 0.0 in the
# unscored slot.
flat: list[float] = []
for entry in response.prompt_logprobs or []:
if not entry:
flat.append(0.0)
continue
first = next(iter(entry.values()))
lp = first.logprob if hasattr(first, "logprob") else first.get("logprob")
flat.append(float(lp) if lp is not None else 0.0)
return flat
http_response.raise_for_status()
response = http_response.json()
return _flatten_prompt_logprobs(response, scored_token_ids)

return await asyncio.gather(*[_compute_single(client, sample) for client, sample in zip(cycle(clients), samples)])

Expand Down
157 changes: 151 additions & 6 deletions tests/unit/orchestrator/test_teacher_logprobs.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import asyncio
import json
from types import SimpleNamespace

import httpx
import verifiers as vf

from prime_rl.orchestrator import utils as orchestrator_utils
from prime_rl.transport import TrainingSample


class _FakeOpenAIClient:
Expand All @@ -14,17 +14,18 @@ class _FakeOpenAIClient:
handed back verbatim, mirroring the real SDK's short-circuit at
``AsyncAPIClient._process_response``."""

def __init__(self, payload: dict):
def __init__(self, payload: dict, base_url: str = "http://fake-host:8000/v1", status_code: int = 200):
# Match what AsyncOpenAI exposes — utils.py reads ``str(client.base_url)``.
self.base_url = "http://fake-host:8000/v1"
self.base_url = base_url
self._payload = payload
self._status_code = status_code
self.calls: list[dict] = []

async def post(self, url, *, cast_to, body):
self.calls.append({"url": url, "cast_to": cast_to, "body": body})
request = httpx.Request("POST", url, json=body)
return httpx.Response(
status_code=200,
status_code=self._status_code,
content=json.dumps(self._payload).encode(),
request=request,
)
Expand All @@ -37,13 +38,17 @@ async def _run():
"request_id": "gen-test",
"choices": [],
# Upstream wire shape: list[dict[token_id, Logprob] | None]
"prompt_logprobs": [None, {"11": {"logprob": -0.7}}, {"12": {"logprob": -0.3}}],
"prompt_logprobs": [
None,
{"13": {"logprob": -0.1}, "2": {"logprob": -0.7}},
{"198": {"logprob": -0.2}, "3": {"logprob": -0.3}},
],
"kv_transfer_params": None,
}
)
monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client)

sample = TrainingSample(
sample = SimpleNamespace(
prompt_ids=[1],
prompt_mask=[True],
completion_ids=[2, 3],
Expand Down Expand Up @@ -78,3 +83,143 @@ async def _run():
]

asyncio.run(_run())


def test_compute_teacher_logprobs_rejects_wrong_length(monkeypatch):
async def _run():
fake_client = _FakeOpenAIClient(
{
"request_id": "gen-test",
"choices": [],
"prompt_logprobs": [None, {"2": {"logprob": -0.7}}],
"kv_transfer_params": None,
}
)
monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client)

sample = SimpleNamespace(
prompt_ids=[1],
prompt_mask=[True],
completion_ids=[2, 3],
completion_mask=[True, True],
completion_logprobs=[-0.1, -0.2],
completion_temperatures=[1.0, 1.0],
env_name="test-env",
)

try:
await orchestrator_utils.compute_teacher_logprobs(
clients=[vf.ClientConfig()],
model_name="teacher-model",
samples=[sample],
)
except ValueError as exc:
assert "teacher prompt_logprobs length != sample length" in str(exc)
else:
raise AssertionError("Expected ValueError")

asyncio.run(_run())


def test_compute_teacher_logprobs_rejects_missing_token_id(monkeypatch):
async def _run():
fake_client = _FakeOpenAIClient(
{
"request_id": "gen-test",
"choices": [],
"prompt_logprobs": [None, {"13": {"logprob": -0.1}}, {"3": {"logprob": -0.3}}],
"kv_transfer_params": None,
}
)
monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client)

sample = SimpleNamespace(
prompt_ids=[1],
prompt_mask=[True],
completion_ids=[2, 3],
completion_mask=[True, True],
completion_logprobs=[-0.1, -0.2],
completion_temperatures=[1.0, 1.0],
env_name="test-env",
)

try:
await orchestrator_utils.compute_teacher_logprobs(
clients=[vf.ClientConfig()],
model_name="teacher-model",
samples=[sample],
)
except ValueError as exc:
assert "teacher prompt_logprobs missing token id 2" in str(exc)
else:
raise AssertionError("Expected ValueError")

asyncio.run(_run())


def test_compute_teacher_logprobs_rejects_missing_non_leading_entry(monkeypatch):
async def _run():
fake_client = _FakeOpenAIClient(
{
"request_id": "gen-test",
"choices": [],
"prompt_logprobs": [None, None, {"3": {"logprob": -0.3}}],
"kv_transfer_params": None,
}
)
monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client)

sample = SimpleNamespace(
prompt_ids=[1],
prompt_mask=[True],
completion_ids=[2, 3],
completion_mask=[True, True],
completion_logprobs=[-0.1, -0.2],
completion_temperatures=[1.0, 1.0],
env_name="test-env",
)

try:
await orchestrator_utils.compute_teacher_logprobs(
clients=[vf.ClientConfig()],
model_name="teacher-model",
samples=[sample],
)
except ValueError as exc:
assert "teacher prompt_logprobs missing entry at position 1 for token id 2" in str(exc)
else:
raise AssertionError("Expected ValueError")

asyncio.run(_run())


def test_compute_teacher_logprobs_raises_for_teacher_http_error(monkeypatch):
async def _run():
fake_client = _FakeOpenAIClient(
{"error": {"message": "invalid teacher api key"}},
status_code=401,
)
monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client)

sample = SimpleNamespace(
prompt_ids=[1],
prompt_mask=[True],
completion_ids=[2, 3],
completion_mask=[True, True],
completion_logprobs=[-0.1, -0.2],
completion_temperatures=[1.0, 1.0],
env_name="test-env",
)

try:
await orchestrator_utils.compute_teacher_logprobs(
clients=[vf.ClientConfig()],
model_name="teacher-model",
samples=[sample],
)
except httpx.HTTPStatusError as exc:
assert exc.response.status_code == 401
else:
raise AssertionError("Expected HTTPStatusError")

asyncio.run(_run())
Comment on lines +88 to +225
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets remove tests here

Loading