1- from typing import cast
1+ from typing import Optional , cast
22
3+ from openai import AsyncOpenAI , BaseModel
34from openai .types .chat import ChatCompletion
45
56from verifiers .clients .openai_chat_completions_client import (
67 OpenAIChatCompletionsClient ,
78 OpenAIChatMessages ,
89 OpenAIChatResponse ,
910 OpenAITool ,
10- )
11- from verifiers .clients .openai_chat_completions_client import (
1211 handle_openai_overlong_prompt ,
1312)
1413from verifiers .types import SamplingArgs , State
15- from verifiers .utils .token_utils import get_prompt_ids
14+ from verifiers .utils .message_utils import concat_messages
15+
16+
17+ # copy from vllm/entrypoints/openai/protocol.py
18+ class TokenizeResponse (BaseModel ):
19+ count : int
20+ max_model_len : int
21+ tokens : list [int ]
22+ token_strs : Optional [list [str ]] = None
1623
1724
1825class OpenAIChatCompletionsTokenClient (OpenAIChatCompletionsClient ):
1926 """Wrapper for custom vLLM route /v1/chat/completions/tokens via AsyncOpenAI client."""
2027
28+ @property
29+ def token_client (self ) -> AsyncOpenAI :
30+ """Strips trailing /v1 from the OpenAI client."""
31+ base_url = str (self .client .base_url ).rstrip ("/" )
32+ if base_url .endswith ("/v1" ):
33+ base_url = base_url [:- 3 ]
34+ return self .client .with_options (base_url = base_url )
35+
2136 @handle_openai_overlong_prompt
2237 async def get_native_response (
2338 self ,
@@ -49,7 +64,7 @@ def normalize_sampling_args(sampling_args: SamplingArgs):
4964 return await super ().get_native_response (
5065 prompt , model , sampling_args , tools
5166 )
52- prompt_ids = await get_prompt_ids (state , prompt , tools , self . client )
67+ prompt_ids = await self . get_prompt_ids (state , prompt , tools )
5368 extra_body = sampling_args .pop ("extra_body" , {})
5469 body = dict (
5570 model = model ,
@@ -65,3 +80,138 @@ def normalize_sampling_args(sampling_args: SamplingArgs):
6580 body = body ,
6681 cast_to = ChatCompletion ,
6782 )
83+
84+ async def get_prompt_ids (
85+ self ,
86+ state : State ,
87+ prompt_messages : OpenAIChatMessages ,
88+ oai_tools : list [OpenAITool ] | None ,
89+ ) -> list [int ]:
90+ """
91+ Build prompt_ids (token prompt) corresponding to prompt_messages. We assume
92+ that this method is called *before* making the model response from
93+ prompt_messages, i.e. the previous turn's prompt and completion do not yet
94+ include the environment response and next turn's model response.
95+ """
96+ prev_turn_prompt = state ["trajectory" ][- 1 ]["prompt" ]
97+ prev_turn_completion = state ["trajectory" ][- 1 ]["completion" ]
98+ prev_turn_tokens = state ["trajectory" ][- 1 ]["tokens" ]
99+ assert prev_turn_tokens is not None
100+ prev_turn_prompt_ids = prev_turn_tokens ["prompt_ids" ]
101+ prev_turn_completion_ids = prev_turn_tokens ["completion_ids" ]
102+ prev_turn_ids = prev_turn_prompt_ids + prev_turn_completion_ids
103+
104+ # the env response is all messages after the previous turn
105+ messages = concat_messages ([prev_turn_prompt , prev_turn_completion ])
106+ env_response = prompt_messages [len (messages ) :]
107+
108+ def compute_suffix_ids (lst : list [int ], value : int ) -> list [int ]:
109+ """Returns all tokens after the last occurrence of `value` in `lst`, if any."""
110+
111+ def find_last_index (lst : list [int ], target : int ) -> int :
112+ for i in range (len (lst ) - 1 , - 1 , - 1 ):
113+ if lst [i ] == target :
114+ return i
115+ raise ValueError
116+
117+ try :
118+ i = find_last_index (lst , value )
119+ suffix_ids = lst [i + 1 :]
120+ return suffix_ids
121+ except ValueError :
122+ # end of message token not found, so we don't need to add any suffix tokens
123+ return []
124+
125+ def find_largest_overlap (a : list [int ], b : list [int ]) -> int :
126+ """Find the largest overlapping sequence between the end of a and beginning of b."""
127+ if not a or not b :
128+ return 0
129+
130+ max_possible = min (len (a ), len (b ))
131+ for overlap_len in reversed (range (1 , max_possible + 1 )):
132+ a_suffix = a [- overlap_len :]
133+ b_prefix = b [:overlap_len ]
134+
135+ if a_suffix == b_prefix :
136+ return overlap_len
137+
138+ return 0
139+
140+ # we build the env_response_ids using simple tokenization
141+ env_response_ids = await self .tokenize (
142+ messages = env_response ,
143+ tools = None ,
144+ model = state ["model" ],
145+ )
146+
147+ # we add suffix_ids to prev_turn_ids. suffix_ids are tokens that are added
148+ # by the chat template after messages, but not generated by the model, i.e.
149+ # they will be part of messages_ids (from the chat template) but not of
150+ # prev_turn_ids (from the engine). to not train OOD w.r.t. the chat
151+ # template, we add these suffix tokens to prev_turn_ids. we compute the
152+ # suffix_ids once, and cache them for future use. then, for each turn, we
153+ # find the largest overlap between the end of prev_turn_ids and the
154+ # beginning of the suffix_ids. this is to correctly handle truncated turns
155+ # that did not produce message delimiting tokens.
156+ if state .get ("_cached_suffix_ids" ) is None :
157+ dummy_content = "World!"
158+ dummy_messages = cast (
159+ OpenAIChatMessages ,
160+ [
161+ {"role" : "user" , "content" : "Hello" },
162+ {"role" : "assistant" , "content" : dummy_content },
163+ ],
164+ )
165+ dummy_content_ids = await self .tokenize (
166+ messages = dummy_content ,
167+ tools = oai_tools ,
168+ model = state ["model" ],
169+ )
170+ dummy_messages_ids = await self .tokenize (
171+ messages = dummy_messages ,
172+ tools = oai_tools ,
173+ model = state ["model" ],
174+ extra_kwargs = dict (add_generation_prompt = False ),
175+ )
176+ # these are typically chat template specific tokens, such as
177+ # eom tokens, newlines, etc.
178+ suffix_ids = compute_suffix_ids (dummy_messages_ids , dummy_content_ids [- 1 ])
179+ state ["_cached_suffix_ids" ] = suffix_ids
180+ else :
181+ suffix_ids = state ["_cached_suffix_ids" ]
182+ overlap_len = find_largest_overlap (prev_turn_ids , suffix_ids )
183+ prev_turn_ids += suffix_ids [overlap_len :]
184+
185+ prompt_ids = prev_turn_ids + env_response_ids
186+
187+ return prompt_ids
188+
189+ async def tokenize (
190+ self ,
191+ messages : str | OpenAIChatMessages ,
192+ tools : list [OpenAITool ] | None ,
193+ model : str ,
194+ extra_kwargs : dict = {},
195+ ** kwargs ,
196+ ) -> list [int ]:
197+ """Tokenize messages using the vLLM /tokenize API."""
198+ if isinstance (messages , str ):
199+ body = dict (
200+ model = model ,
201+ prompt = messages ,
202+ ** extra_kwargs ,
203+ )
204+ tokenize_response = await self .token_client .post (
205+ "/tokenize" , body = body , cast_to = TokenizeResponse
206+ )
207+ else :
208+ body = dict (
209+ model = model ,
210+ messages = messages ,
211+ tools = tools ,
212+ ** extra_kwargs ,
213+ )
214+ tokenize_response = await self .token_client .post (
215+ "/tokenize" , body = body , cast_to = TokenizeResponse
216+ )
217+ return tokenize_response .tokens
0 commit comments