-
Notifications
You must be signed in to change notification settings - Fork 304
fix: harden OPD teacher logprob scoring #2683
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,12 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import logging | ||
| import time | ||
| from concurrent.futures import ThreadPoolExecutor | ||
| from itertools import cycle | ||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
| import orjson | ||
| import verifiers as vf | ||
|
|
@@ -103,51 +106,72 @@ async def compute_teacher_logprobs( | |
| ) -> list[list[float]]: | ||
| """Compute teacher model logprobs for a batch of training samples via prefill.""" | ||
| import httpx | ||
| from vllm.entrypoints.serve.disagg.protocol import GenerateResponse | ||
|
|
||
| def _teacher_generate_request( | ||
| base_url: str, | ||
| model_name: str, | ||
| scored_token_ids: list[int], | ||
| ) -> tuple[str, dict[str, Any]]: | ||
| base = base_url.rstrip("/") | ||
| return f"{base.removesuffix('/v1')}/inference/v1/generate", { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missing Prime API URL routing for hosted teacher (High Severity)
Reviewed by Cursor Bugbot for commit 31b5fd1. Configure here.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this is stale after the latest scope change. PrimeRL no longer branches to
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This breaks OPD in PrimeRL; we can't do this. |
||
| "model": model_name, | ||
| "token_ids": scored_token_ids, | ||
| "sampling_params": { | ||
| "max_tokens": 1, | ||
| "temperature": 1.0, | ||
| "top_p": 1.0, | ||
| "prompt_logprobs": 1, | ||
| }, | ||
| } | ||
|
|
||
| def _flatten_prompt_logprobs(response: dict[str, Any], token_ids: list[int]) -> list[float]: | ||
| # ``prompt_logprobs[i]`` is a ``{token_id: Logprob}`` dict for tokens | ||
| # the engine could score, or ``None`` for the leading token which has | ||
| # no preceding context. vLLM can include both the target token and the | ||
| # top-k alternatives; select the exact target token at each position. | ||
| prompt_logprobs = response.get("prompt_logprobs") or [] | ||
| if len(prompt_logprobs) != len(token_ids): | ||
| raise ValueError( | ||
| f"teacher prompt_logprobs length != sample length ({len(prompt_logprobs)} != {len(token_ids)})" | ||
| ) | ||
| flat: list[float] = [] | ||
| for i, (token_id, entry) in enumerate(zip(token_ids, prompt_logprobs)): | ||
| if not entry: | ||
| if i != 0: | ||
| raise ValueError(f"teacher prompt_logprobs missing entry at position {i} for token id {token_id}") | ||
| flat.append(0.0) | ||
| continue | ||
| target = entry.get(str(token_id)) | ||
| if target is None: | ||
| target = entry.get(token_id) | ||
| if target is None: | ||
| raise ValueError(f"teacher prompt_logprobs missing token id {token_id}") | ||
| lp = target.get("logprob") | ||
| flat.append(float(lp) if lp is not None else 0.0) | ||
| return flat | ||
|
|
||
| async def _compute_single(client_config: vf.ClientConfig, sample: TrainingSample) -> list[float]: | ||
| client = setup_openai_client(client_config) | ||
| scored_token_ids = list(sample.prompt_ids) + list(sample.completion_ids) | ||
|
|
||
| # Two escape hatches from ``AsyncOpenAI.post``: | ||
| # 1. URL — ``/inference/v1/generate`` is mounted at server root, not | ||
| # under ``/v1``. Pass an absolute URL so the SDK's | ||
| # ``_prepare_url`` skips the base-url merge (it short-circuits | ||
| # when the path passes ``httpx.URL.is_relative_url`` as False). | ||
| # 1. URL — vLLM mounts ``/inference/v1/generate`` at server root, | ||
| # so pass an absolute URL and skip the SDK's base-url merge. | ||
| # 2. Parse — vLLM's ``GenerateResponse`` is a plain | ||
| # ``pydantic.BaseModel`` and the SDK's parse layer rejects any | ||
| # ``cast_to`` that doesn't subclass ``openai.BaseModel``. Use | ||
| # ``cast_to=httpx.Response`` so the SDK still builds the request | ||
| # (preserving ``auth_headers``, retries, timeouts, idempotency | ||
| # keys) and just hands us the raw response to validate ourselves. | ||
| base = str(client.base_url).rstrip("/").removesuffix("/v1") | ||
| url, body = _teacher_generate_request(str(client.base_url), model_name, scored_token_ids) | ||
| http_response = await client.post( | ||
| f"{base}/inference/v1/generate", | ||
| url, | ||
| cast_to=httpx.Response, | ||
| body={ | ||
| "model": model_name, | ||
| "token_ids": list(sample.prompt_ids) + list(sample.completion_ids), | ||
| "sampling_params": { | ||
| "max_tokens": 1, | ||
| "temperature": 1.0, | ||
| "top_p": 1.0, | ||
| "prompt_logprobs": 1, | ||
| }, | ||
| }, | ||
| body=body, | ||
| ) | ||
| response = GenerateResponse.model_validate_json(http_response.content) | ||
| # ``prompt_logprobs[i]`` is a ``{token_id: Logprob}`` dict for tokens | ||
| # the engine could score, or ``None`` for the leading token which has | ||
| # no preceding context. Flatten to ``list[float]`` with 0.0 in the | ||
| # unscored slot. | ||
| flat: list[float] = [] | ||
| for entry in response.prompt_logprobs or []: | ||
| if not entry: | ||
| flat.append(0.0) | ||
| continue | ||
| first = next(iter(entry.values())) | ||
| lp = first.logprob if hasattr(first, "logprob") else first.get("logprob") | ||
| flat.append(float(lp) if lp is not None else 0.0) | ||
| return flat | ||
| http_response.raise_for_status() | ||
| response = http_response.json() | ||
| return _flatten_prompt_logprobs(response, scored_token_ids) | ||
|
|
||
| return await asyncio.gather(*[_compute_single(client, sample) for client, sample in zip(cycle(clients), samples)]) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,11 @@ | ||
| import asyncio | ||
| import json | ||
| from types import SimpleNamespace | ||
|
|
||
| import httpx | ||
| import verifiers as vf | ||
|
|
||
| from prime_rl.orchestrator import utils as orchestrator_utils | ||
| from prime_rl.transport import TrainingSample | ||
|
|
||
|
|
||
| class _FakeOpenAIClient: | ||
|
|
@@ -14,17 +14,18 @@ class _FakeOpenAIClient: | |
| handed back verbatim, mirroring the real SDK's short-circuit at | ||
| ``AsyncAPIClient._process_response``.""" | ||
|
|
||
| def __init__(self, payload: dict): | ||
| def __init__(self, payload: dict, base_url: str = "http://fake-host:8000/v1", status_code: int = 200): | ||
| # Match what AsyncOpenAI exposes — utils.py reads ``str(client.base_url)``. | ||
| self.base_url = "http://fake-host:8000/v1" | ||
| self.base_url = base_url | ||
| self._payload = payload | ||
| self._status_code = status_code | ||
| self.calls: list[dict] = [] | ||
|
|
||
| async def post(self, url, *, cast_to, body): | ||
| self.calls.append({"url": url, "cast_to": cast_to, "body": body}) | ||
| request = httpx.Request("POST", url, json=body) | ||
| return httpx.Response( | ||
| status_code=200, | ||
| status_code=self._status_code, | ||
| content=json.dumps(self._payload).encode(), | ||
| request=request, | ||
| ) | ||
|
|
@@ -37,13 +38,17 @@ async def _run(): | |
| "request_id": "gen-test", | ||
| "choices": [], | ||
| # Upstream wire shape: list[dict[token_id, Logprob] | None] | ||
| "prompt_logprobs": [None, {"11": {"logprob": -0.7}}, {"12": {"logprob": -0.3}}], | ||
| "prompt_logprobs": [ | ||
| None, | ||
| {"13": {"logprob": -0.1}, "2": {"logprob": -0.7}}, | ||
| {"198": {"logprob": -0.2}, "3": {"logprob": -0.3}}, | ||
| ], | ||
| "kv_transfer_params": None, | ||
| } | ||
| ) | ||
| monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client) | ||
|
|
||
| sample = TrainingSample( | ||
| sample = SimpleNamespace( | ||
| prompt_ids=[1], | ||
| prompt_mask=[True], | ||
| completion_ids=[2, 3], | ||
|
|
@@ -78,3 +83,143 @@ async def _run(): | |
| ] | ||
|
|
||
| asyncio.run(_run()) | ||
|
|
||
|
|
||
| def test_compute_teacher_logprobs_rejects_wrong_length(monkeypatch): | ||
| async def _run(): | ||
| fake_client = _FakeOpenAIClient( | ||
| { | ||
| "request_id": "gen-test", | ||
| "choices": [], | ||
| "prompt_logprobs": [None, {"2": {"logprob": -0.7}}], | ||
| "kv_transfer_params": None, | ||
| } | ||
| ) | ||
| monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client) | ||
|
|
||
| sample = SimpleNamespace( | ||
| prompt_ids=[1], | ||
| prompt_mask=[True], | ||
| completion_ids=[2, 3], | ||
| completion_mask=[True, True], | ||
| completion_logprobs=[-0.1, -0.2], | ||
| completion_temperatures=[1.0, 1.0], | ||
| env_name="test-env", | ||
| ) | ||
|
|
||
| try: | ||
| await orchestrator_utils.compute_teacher_logprobs( | ||
| clients=[vf.ClientConfig()], | ||
| model_name="teacher-model", | ||
| samples=[sample], | ||
| ) | ||
| except ValueError as exc: | ||
| assert "teacher prompt_logprobs length != sample length" in str(exc) | ||
| else: | ||
| raise AssertionError("Expected ValueError") | ||
|
|
||
| asyncio.run(_run()) | ||
|
|
||
|
|
||
| def test_compute_teacher_logprobs_rejects_missing_token_id(monkeypatch): | ||
| async def _run(): | ||
| fake_client = _FakeOpenAIClient( | ||
| { | ||
| "request_id": "gen-test", | ||
| "choices": [], | ||
| "prompt_logprobs": [None, {"13": {"logprob": -0.1}}, {"3": {"logprob": -0.3}}], | ||
| "kv_transfer_params": None, | ||
| } | ||
| ) | ||
| monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client) | ||
|
|
||
| sample = SimpleNamespace( | ||
| prompt_ids=[1], | ||
| prompt_mask=[True], | ||
| completion_ids=[2, 3], | ||
| completion_mask=[True, True], | ||
| completion_logprobs=[-0.1, -0.2], | ||
| completion_temperatures=[1.0, 1.0], | ||
| env_name="test-env", | ||
| ) | ||
|
|
||
| try: | ||
| await orchestrator_utils.compute_teacher_logprobs( | ||
| clients=[vf.ClientConfig()], | ||
| model_name="teacher-model", | ||
| samples=[sample], | ||
| ) | ||
| except ValueError as exc: | ||
| assert "teacher prompt_logprobs missing token id 2" in str(exc) | ||
| else: | ||
| raise AssertionError("Expected ValueError") | ||
|
|
||
| asyncio.run(_run()) | ||
|
|
||
|
|
||
| def test_compute_teacher_logprobs_rejects_missing_non_leading_entry(monkeypatch): | ||
| async def _run(): | ||
| fake_client = _FakeOpenAIClient( | ||
| { | ||
| "request_id": "gen-test", | ||
| "choices": [], | ||
| "prompt_logprobs": [None, None, {"3": {"logprob": -0.3}}], | ||
| "kv_transfer_params": None, | ||
| } | ||
| ) | ||
| monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client) | ||
|
|
||
| sample = SimpleNamespace( | ||
| prompt_ids=[1], | ||
| prompt_mask=[True], | ||
| completion_ids=[2, 3], | ||
| completion_mask=[True, True], | ||
| completion_logprobs=[-0.1, -0.2], | ||
| completion_temperatures=[1.0, 1.0], | ||
| env_name="test-env", | ||
| ) | ||
|
|
||
| try: | ||
| await orchestrator_utils.compute_teacher_logprobs( | ||
| clients=[vf.ClientConfig()], | ||
| model_name="teacher-model", | ||
| samples=[sample], | ||
| ) | ||
| except ValueError as exc: | ||
| assert "teacher prompt_logprobs missing entry at position 1 for token id 2" in str(exc) | ||
| else: | ||
| raise AssertionError("Expected ValueError") | ||
|
|
||
| asyncio.run(_run()) | ||
|
|
||
|
|
||
| def test_compute_teacher_logprobs_raises_for_teacher_http_error(monkeypatch): | ||
| async def _run(): | ||
| fake_client = _FakeOpenAIClient( | ||
| {"error": {"message": "invalid teacher api key"}}, | ||
| status_code=401, | ||
| ) | ||
| monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client) | ||
|
|
||
| sample = SimpleNamespace( | ||
| prompt_ids=[1], | ||
| prompt_mask=[True], | ||
| completion_ids=[2, 3], | ||
| completion_mask=[True, True], | ||
| completion_logprobs=[-0.1, -0.2], | ||
| completion_temperatures=[1.0, 1.0], | ||
| env_name="test-env", | ||
| ) | ||
|
|
||
| try: | ||
| await orchestrator_utils.compute_teacher_logprobs( | ||
| clients=[vf.ClientConfig()], | ||
| model_name="teacher-model", | ||
| samples=[sample], | ||
| ) | ||
| except httpx.HTTPStatusError as exc: | ||
| assert exc.response.status_code == 401 | ||
| else: | ||
| raise AssertionError("Expected HTTPStatusError") | ||
|
|
||
| asyncio.run(_run()) | ||
|
Comment on lines
+88
to
+225
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's remove the tests here. |
||


There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's remove this as well? I'm not sure how it relates to the PR.