|
| 1 | +from __future__ import annotations |
| 2 | + |
1 | 3 | import asyncio |
2 | 4 | import time |
3 | 5 | from concurrent.futures import ThreadPoolExecutor |
4 | 6 | from itertools import cycle |
5 | 7 | from pathlib import Path |
6 | | -from typing import Any |
| 8 | +from typing import TYPE_CHECKING, Any |
7 | 9 |
|
8 | 10 | import pandas as pd |
9 | 11 | import verifiers as vf |
10 | 12 | from rich.console import Console |
11 | 13 | from rich.table import Table |
12 | 14 | from verifiers.utils.client_utils import setup_openai_client |
13 | 15 |
|
14 | | -from prime_rl.transport import TrainingSample |
15 | 16 | from prime_rl.utils.logger import get_logger |
16 | 17 | from prime_rl.utils.utils import ( |
17 | 18 | format_time, |
|
20 | 21 | get_step_path, |
21 | 22 | ) |
22 | 23 |
|
| 24 | +if TYPE_CHECKING: |
| 25 | + from prime_rl.transport.types import TrainingSample |
| 26 | + |
23 | 27 |
|
24 | 28 | def set_default_executor(max_workers: int = 64) -> None: |
25 | 29 | """Scale the default asyncio thread pool so asyncio.to_thread has enough capacity.""" |
@@ -84,51 +88,78 @@ async def compute_teacher_logprobs( |
84 | 88 | ) -> list[list[float]]: |
85 | 89 | """Compute teacher model logprobs for a batch of training samples via prefill.""" |
86 | 90 | import httpx |
87 | | - from vllm.entrypoints.serve.disagg.protocol import GenerateResponse |
| 91 | + |
| 92 | + def _teacher_generate_url(base_url: str) -> str: |
| 93 | + base = base_url.rstrip("/") |
| 94 | + if base.endswith("/api/v1"): |
| 95 | + return f"{base}/generate" |
| 96 | + return f"{base.removesuffix('/v1')}/inference/v1/generate" |
| 97 | + |
| 98 | + def _teacher_generate_body(base_url: str, model_name: str, token_ids: list[int]) -> dict[str, Any]: |
| 99 | + sampling_params = { |
| 100 | + "max_tokens": 1, |
| 101 | + "temperature": 1.0, |
| 102 | + "top_p": 1.0, |
| 103 | + "prompt_logprobs": 1, |
| 104 | + } |
| 105 | + if base_url.rstrip("/").endswith("/api/v1"): |
| 106 | + return { |
| 107 | + "model": model_name, |
| 108 | + "prompt_token_ids": token_ids, |
| 109 | + "sampling_params": sampling_params, |
| 110 | + } |
| 111 | + return { |
| 112 | + "model": model_name, |
| 113 | + "token_ids": token_ids, |
| 114 | + "sampling_params": sampling_params, |
| 115 | + } |
| 116 | + |
| 117 | + def _flatten_prompt_logprobs(response: dict[str, Any], token_ids: list[int]) -> list[float]: |
| 118 | + # ``prompt_logprobs[i]`` is a ``{token_id: Logprob}`` dict for tokens |
| 119 | + # the engine could score, or ``None`` for the leading token which has |
| 120 | + # no preceding context. vLLM can include both the target token and the |
| 121 | + # top-k alternatives; select the exact target token at each position. |
| 122 | + prompt_logprobs = response.get("prompt_logprobs") or [] |
| 123 | + if len(prompt_logprobs) != len(token_ids): |
| 124 | + raise ValueError( |
| 125 | + f"teacher prompt_logprobs length != sample length ({len(prompt_logprobs)} != {len(token_ids)})" |
| 126 | + ) |
| 127 | + flat: list[float] = [] |
| 128 | + for i, (token_id, entry) in enumerate(zip(token_ids, prompt_logprobs)): |
| 129 | + if not entry: |
| 130 | + if i != 0: |
| 131 | + raise ValueError(f"teacher prompt_logprobs missing entry for token id {token_id}") |
| 132 | + flat.append(0.0) |
| 133 | + continue |
| 134 | + target = entry.get(str(token_id)) or entry.get(token_id) |
| 135 | + if target is None: |
| 136 | + raise ValueError(f"teacher prompt_logprobs missing token id {token_id}") |
| 137 | + lp = target.get("logprob") |
| 138 | + flat.append(float(lp) if lp is not None else 0.0) |
| 139 | + return flat |
88 | 140 |
|
89 | 141 | async def _compute_single(client_config: vf.ClientConfig, sample: TrainingSample) -> list[float]: |
90 | 142 | client = setup_openai_client(client_config) |
| 143 | + token_ids = list(sample.prompt_ids) + list(sample.completion_ids) |
91 | 144 |
|
92 | 145 | # Two escape hatches from ``AsyncOpenAI.post``: |
93 | | - # 1. URL — ``/inference/v1/generate`` is mounted at server root, not |
94 | | - # under ``/v1``. Pass an absolute URL so the SDK's |
95 | | - # ``_prepare_url`` skips the base-url merge (it short-circuits |
96 | | - # when the path passes ``httpx.URL.is_relative_url`` as False). |
| 146 | + # 1. URL — vLLM mounts ``/inference/v1/generate`` at server root, |
| 147 | + # while Prime Inference exposes ``/api/v1/generate``. Pass an |
| 148 | + # absolute URL so the SDK's ``_prepare_url`` skips base-url merge. |
97 | 149 | # 2. Parse — vLLM's ``GenerateResponse`` is a plain |
98 | 150 | # ``pydantic.BaseModel`` and the SDK's parse layer rejects any |
99 | 151 | # ``cast_to`` that doesn't subclass ``openai.BaseModel``. Use |
100 | 152 | # ``cast_to=httpx.Response`` so the SDK still builds the request |
101 | 153 | # (preserving ``auth_headers``, retries, timeouts, idempotency |
102 | 154 | # keys) and just hands us the raw response to validate ourselves. |
103 | | - base = str(client.base_url).rstrip("/").removesuffix("/v1") |
104 | 155 | http_response = await client.post( |
105 | | - f"{base}/inference/v1/generate", |
| 156 | + _teacher_generate_url(str(client.base_url)), |
106 | 157 | cast_to=httpx.Response, |
107 | | - body={ |
108 | | - "model": model_name, |
109 | | - "token_ids": list(sample.prompt_ids) + list(sample.completion_ids), |
110 | | - "sampling_params": { |
111 | | - "max_tokens": 1, |
112 | | - "temperature": 1.0, |
113 | | - "top_p": 1.0, |
114 | | - "prompt_logprobs": 1, |
115 | | - }, |
116 | | - }, |
| 158 | + body=_teacher_generate_body(str(client.base_url), model_name, token_ids), |
117 | 159 | ) |
118 | | - response = GenerateResponse.model_validate_json(http_response.content) |
119 | | - # ``prompt_logprobs[i]`` is a ``{token_id: Logprob}`` dict for tokens |
120 | | - # the engine could score, or ``None`` for the leading token which has |
121 | | - # no preceding context. Flatten to ``list[float]`` with 0.0 in the |
122 | | - # unscored slot. |
123 | | - flat: list[float] = [] |
124 | | - for entry in response.prompt_logprobs or []: |
125 | | - if not entry: |
126 | | - flat.append(0.0) |
127 | | - continue |
128 | | - first = next(iter(entry.values())) |
129 | | - lp = first.logprob if hasattr(first, "logprob") else first.get("logprob") |
130 | | - flat.append(float(lp) if lp is not None else 0.0) |
131 | | - return flat |
| 160 | + http_response.raise_for_status() |
| 161 | + response = http_response.json() |
| 162 | + return _flatten_prompt_logprobs(response, token_ids) |
132 | 163 |
|
133 | 164 | return await asyncio.gather(*[_compute_single(client, sample) for client, sample in zip(cycle(clients), samples)]) |
134 | 165 |
|
|
0 commit comments