Skip to content

Commit e35fe3c

Browse files
authored
move token utils into oai token client (#913)
* move utils where they belong * same stripping as before
1 parent 75a3702 commit e35fe3c

2 files changed

Lines changed: 155 additions & 179 deletions

File tree

verifiers/clients/openai_chat_completions_token_client.py

Lines changed: 155 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,38 @@
1-
from typing import cast
1+
from typing import Optional, cast
22

3+
from openai import AsyncOpenAI, BaseModel
34
from openai.types.chat import ChatCompletion
45

56
from verifiers.clients.openai_chat_completions_client import (
67
OpenAIChatCompletionsClient,
78
OpenAIChatMessages,
89
OpenAIChatResponse,
910
OpenAITool,
10-
)
11-
from verifiers.clients.openai_chat_completions_client import (
1211
handle_openai_overlong_prompt,
1312
)
1413
from verifiers.types import SamplingArgs, State
15-
from verifiers.utils.token_utils import get_prompt_ids
14+
from verifiers.utils.message_utils import concat_messages
15+
16+
17+
# copy from vllm/entrypoints/openai/protocol.py
18+
class TokenizeResponse(BaseModel):
19+
count: int
20+
max_model_len: int
21+
tokens: list[int]
22+
token_strs: Optional[list[str]] = None
1623

1724

1825
class OpenAIChatCompletionsTokenClient(OpenAIChatCompletionsClient):
1926
"""Wrapper for custom vLLM route /v1/chat/completions/tokens via AsyncOpenAI client."""
2027

28+
@property
29+
def token_client(self) -> AsyncOpenAI:
30+
"""Strips trailing /v1 from the OpenAI client."""
31+
base_url = str(self.client.base_url).rstrip("/")
32+
if base_url.endswith("/v1"):
33+
base_url = base_url[:-3]
34+
return self.client.with_options(base_url=base_url)
35+
2136
@handle_openai_overlong_prompt
2237
async def get_native_response(
2338
self,
@@ -49,7 +64,7 @@ def normalize_sampling_args(sampling_args: SamplingArgs):
4964
return await super().get_native_response(
5065
prompt, model, sampling_args, tools
5166
)
52-
prompt_ids = await get_prompt_ids(state, prompt, tools, self.client)
67+
prompt_ids = await self.get_prompt_ids(state, prompt, tools)
5368
extra_body = sampling_args.pop("extra_body", {})
5469
body = dict(
5570
model=model,
@@ -65,3 +80,138 @@ def normalize_sampling_args(sampling_args: SamplingArgs):
6580
body=body,
6681
cast_to=ChatCompletion,
6782
)
83+
84+
async def get_prompt_ids(
85+
self,
86+
state: State,
87+
prompt_messages: OpenAIChatMessages,
88+
oai_tools: list[OpenAITool] | None,
89+
) -> list[int]:
90+
"""
91+
Build prompt_ids (token prompt) corresponding to prompt_messages. We assume
92+
that this method is called *before* making the model response from
93+
prompt_messages, i.e. the previous turn's prompt and completion do not yet
94+
include the environment response and next turn's model response.
95+
"""
96+
prev_turn_prompt = state["trajectory"][-1]["prompt"]
97+
prev_turn_completion = state["trajectory"][-1]["completion"]
98+
prev_turn_tokens = state["trajectory"][-1]["tokens"]
99+
assert prev_turn_tokens is not None
100+
prev_turn_prompt_ids = prev_turn_tokens["prompt_ids"]
101+
prev_turn_completion_ids = prev_turn_tokens["completion_ids"]
102+
prev_turn_ids = prev_turn_prompt_ids + prev_turn_completion_ids
103+
104+
# the env response is all messages after the previous turn
105+
messages = concat_messages([prev_turn_prompt, prev_turn_completion])
106+
env_response = prompt_messages[len(messages) :]
107+
108+
def compute_suffix_ids(lst: list[int], value: int) -> list[int]:
109+
"""Returns all tokens after the last occurrence of `value` in `lst`, if any."""
110+
111+
def find_last_index(lst: list[int], target: int) -> int:
112+
for i in range(len(lst) - 1, -1, -1):
113+
if lst[i] == target:
114+
return i
115+
raise ValueError
116+
117+
try:
118+
i = find_last_index(lst, value)
119+
suffix_ids = lst[i + 1 :]
120+
return suffix_ids
121+
except ValueError:
122+
# end of message token not found, so we don't need to add any suffix tokens
123+
return []
124+
125+
def find_largest_overlap(a: list[int], b: list[int]) -> int:
126+
"""Find the largest overlapping sequence between the end of a and beginning of b."""
127+
if not a or not b:
128+
return 0
129+
130+
max_possible = min(len(a), len(b))
131+
for overlap_len in reversed(range(1, max_possible + 1)):
132+
a_suffix = a[-overlap_len:]
133+
b_prefix = b[:overlap_len]
134+
135+
if a_suffix == b_prefix:
136+
return overlap_len
137+
138+
return 0
139+
140+
# we build the env_response_ids using simple tokenization
141+
env_response_ids = await self.tokenize(
142+
messages=env_response,
143+
tools=None,
144+
model=state["model"],
145+
)
146+
147+
# we add suffix_ids to prev_turn_ids. suffix_ids are tokens that are added
148+
# by the chat template after messages, but not generated by the model, i.e.
149+
# they will be part of messages_ids (from the chat template) but not of
150+
# prev_turn_ids (from the engine). to not train OOD w.r.t. the chat
151+
# template, we add these suffix tokens to prev_turn_ids. we compute the
152+
# suffix_ids once, and cache them for future use. then, for each turn, we
153+
# find the largest overlap between the end of prev_turn_ids and the
154+
# beginning of the suffix_ids. this is to correctly handle truncated turns
155+
# that did not produce message delimiting tokens.
156+
if state.get("_cached_suffix_ids") is None:
157+
dummy_content = "World!"
158+
dummy_messages = cast(
159+
OpenAIChatMessages,
160+
[
161+
{"role": "user", "content": "Hello"},
162+
{"role": "assistant", "content": dummy_content},
163+
],
164+
)
165+
dummy_content_ids = await self.tokenize(
166+
messages=dummy_content,
167+
tools=oai_tools,
168+
model=state["model"],
169+
)
170+
dummy_messages_ids = await self.tokenize(
171+
messages=dummy_messages,
172+
tools=oai_tools,
173+
model=state["model"],
174+
extra_kwargs=dict(add_generation_prompt=False),
175+
)
176+
# these are typically chat template specific tokens, such as
177+
# eom tokens, newlines, etc.
178+
suffix_ids = compute_suffix_ids(dummy_messages_ids, dummy_content_ids[-1])
179+
state["_cached_suffix_ids"] = suffix_ids
180+
else:
181+
suffix_ids = state["_cached_suffix_ids"]
182+
overlap_len = find_largest_overlap(prev_turn_ids, suffix_ids)
183+
prev_turn_ids += suffix_ids[overlap_len:]
184+
185+
prompt_ids = prev_turn_ids + env_response_ids
186+
187+
return prompt_ids
188+
189+
async def tokenize(
190+
self,
191+
messages: str | OpenAIChatMessages,
192+
tools: list[OpenAITool] | None,
193+
model: str,
194+
extra_kwargs: dict = {},
195+
**kwargs,
196+
) -> list[int]:
197+
"""Tokenize messages using the vLLM /tokenize API."""
198+
if isinstance(messages, str):
199+
body = dict(
200+
model=model,
201+
prompt=messages,
202+
**extra_kwargs,
203+
)
204+
tokenize_response = await self.token_client.post(
205+
"/tokenize", body=body, cast_to=TokenizeResponse
206+
)
207+
else:
208+
body = dict(
209+
model=model,
210+
messages=messages,
211+
tools=tools,
212+
**extra_kwargs,
213+
)
214+
tokenize_response = await self.token_client.post(
215+
"/tokenize", body=body, cast_to=TokenizeResponse
216+
)
217+
return tokenize_response.tokens

verifiers/utils/token_utils.py

Lines changed: 0 additions & 174 deletions
This file was deleted.

0 commit comments

Comments
 (0)