|
14 | 14 | from datetime import datetime |
15 | 15 | from http import HTTPStatus |
16 | 16 | from pathlib import Path |
17 | | -from typing import (Annotated, Any, AsyncGenerator, AsyncIterator, List, |
18 | | - Optional, Union) |
| 17 | +from typing import (Annotated, Any, AsyncGenerator, AsyncIterator, Coroutine, |
| 18 | + List, Optional, Union) |
19 | 19 |
|
20 | 20 | import uvicorn |
21 | 21 | from fastapi import Body, FastAPI, Request |
@@ -1238,32 +1238,21 @@ async def chat_stream_generator( |
1238 | 1238 | self.multimodal_server_config, |
1239 | 1239 | request_media_io_kwargs=request.media_io_kwargs) |
1240 | 1240 |
|
1241 | | - if request.prompt_token_ids is not None: |
1242 | | - prompt = request.prompt_token_ids |
1243 | | - else: |
1244 | | - prompt: str = apply_chat_template( |
1245 | | - model_type=resolve_top_level_model_type(self.model_config), |
1246 | | - tokenizer=self.tokenizer, |
1247 | | - processor=self.processor, |
1248 | | - conversation=conversation, |
1249 | | - add_generation_prompt=request.add_generation_prompt, |
1250 | | - mm_placeholder_counts=mm_placeholder_counts, |
1251 | | - tools=tool_dicts, |
1252 | | - documents=request.documents, |
1253 | | - chat_template=request.chat_template or self.chat_template, |
1254 | | - chat_template_kwargs=request.chat_template_kwargs or {}, |
1255 | | - ) |
1256 | | - prompt = prompt_inputs(prompt) |
1257 | | - |
1258 | | - mm_data, mm_embeddings = await mm_coroutines |
1259 | | - if mm_data: |
1260 | | - prompt["multi_modal_data"] = mm_data |
1261 | | - if mm_embeddings: |
1262 | | - prompt["multi_modal_embeddings"] = mm_embeddings |
1263 | | - if mm_data and mm_embeddings: |
1264 | | - raise ValueError( |
1265 | | - "Passing 'multi_modal_data' and 'multi_modal_embeddings' at the same time is not supported." |
1266 | | - ) |
| 1241 | + prompt = await _prepare_chat_prompt_inputs_nonblocking( |
| 1242 | + model_type=resolve_top_level_model_type(self.model_config), |
| 1243 | + tokenizer=self.tokenizer, |
| 1244 | + processor=self.processor, |
| 1245 | + conversation=conversation, |
| 1246 | + add_generation_prompt=request.add_generation_prompt, |
| 1247 | + mm_coroutines=mm_coroutines, |
| 1248 | + mm_placeholder_counts=mm_placeholder_counts, |
| 1249 | + tools=tool_dicts, |
| 1250 | + documents=request.documents, |
| 1251 | + chat_template=request.chat_template or self.chat_template, |
| 1252 | + chat_template_kwargs=request.chat_template_kwargs, |
| 1253 | + prompt_token_ids=request.prompt_token_ids, |
| 1254 | + allow_mm_embeddings=True, |
| 1255 | + ) |
1267 | 1256 |
|
1268 | 1257 | postproc_args.reasoning_parser = self.generator.args.reasoning_parser |
1269 | 1258 | postproc_args.tool_parser = self.tool_parser |
@@ -1388,28 +1377,21 @@ async def create_mm_embedding_response(promise: RequestOutput): |
1388 | 1377 | self.multimodal_server_config, |
1389 | 1378 | request_media_io_kwargs=request.media_io_kwargs) |
1390 | 1379 |
|
1391 | | - if request.prompt_token_ids is not None: |
1392 | | - prompt = request.prompt_token_ids |
1393 | | - else: |
1394 | | - prompt: str = apply_chat_template( |
1395 | | - model_type=resolve_top_level_model_type(self.model_config), |
1396 | | - tokenizer=self.tokenizer, |
1397 | | - processor=self.processor, |
1398 | | - conversation=conversation, |
1399 | | - add_generation_prompt=request.add_generation_prompt, |
1400 | | - mm_placeholder_counts=mm_placeholder_counts, |
1401 | | - tools=tool_dicts, |
1402 | | - documents=request.documents, |
1403 | | - chat_template=request.chat_template, |
1404 | | - chat_template_kwargs=request.chat_template_kwargs or {}, |
1405 | | - ) |
1406 | | - prompt = prompt_inputs(prompt) |
1407 | | - |
1408 | | - mm_data, mm_embeddings = await mm_coroutines |
1409 | | - if mm_embeddings: |
1410 | | - raise ValueError("Cannot use multimodal embeddings as input") |
1411 | | - if mm_data is not None: |
1412 | | - prompt["multi_modal_data"] = mm_data |
| 1380 | + prompt = await _prepare_chat_prompt_inputs_nonblocking( |
| 1381 | + model_type=resolve_top_level_model_type(self.model_config), |
| 1382 | + tokenizer=self.tokenizer, |
| 1383 | + processor=self.processor, |
| 1384 | + conversation=conversation, |
| 1385 | + add_generation_prompt=request.add_generation_prompt, |
| 1386 | + mm_coroutines=mm_coroutines, |
| 1387 | + mm_placeholder_counts=mm_placeholder_counts, |
| 1388 | + tools=tool_dicts, |
| 1389 | + documents=request.documents, |
| 1390 | + chat_template=request.chat_template or self.chat_template, |
| 1391 | + chat_template_kwargs=request.chat_template_kwargs, |
| 1392 | + prompt_token_ids=request.prompt_token_ids, |
| 1393 | + allow_mm_embeddings=False, |
| 1394 | + ) |
1413 | 1395 |
|
1414 | 1396 | promise = self.generator.generate_async(inputs=prompt, ) |
1415 | 1397 | asyncio.create_task(self.await_disconnected(raw_request, promise)) |
@@ -2191,3 +2173,68 @@ async def _register_after_serving(): |
2191 | 2173 |
|
2192 | 2174 | asyncio.create_task(_register_after_serving()) |
2193 | 2175 | await server.serve(sockets=sockets) |
| 2176 | + |
| 2177 | + |
async def _apply_chat_template_nonblocking(**kwargs: Any
                                           ) -> Union[str, List[str]]:
    """Render the chat template on a worker thread and return its result.

    Every keyword argument is forwarded unchanged to ``apply_chat_template``.
    The call is handed to ``asyncio.to_thread`` and awaited, so the server
    event loop is not blocked while the template is being rendered.
    """

    # Thread-safety note: the worker thread shares the server tokenizer and
    # processor with every other request, so the main hazard would be hidden
    # mutable state inside a custom tokenizer.
    # For this openai_server.py use the risk is considered low:
    # apply_chat_template reads tokenizer/processor configuration and mutates
    # only per-request conversation dictionaries.
    def _render() -> Union[str, List[str]]:
        return apply_chat_template(**kwargs)

    rendered = await asyncio.to_thread(_render)
    return rendered
| 2186 | + |
| 2187 | + |
async def _prepare_chat_prompt_inputs_nonblocking(
    *,
    model_type: str,
    tokenizer: Any,
    processor: Any,
    conversation: List[ConversationMessage],
    add_generation_prompt: bool,
    mm_coroutines: Coroutine[Any, Any, tuple[Optional[dict[str, List[Any]]],
                                             Optional[dict[str, List[Any]]]]],
    mm_placeholder_counts: List[dict[str, int]],
    tools: Optional[List[dict[str, Any]]],
    documents: Optional[List[dict[str, str]]],
    chat_template: Optional[str],
    chat_template_kwargs: Optional[dict[str, Any]],
    prompt_token_ids: Optional[List[int]],
    allow_mm_embeddings: bool,
) -> dict[str, Any]:
    """Prepare prompt inputs while overlapping media loading and template rendering.

    The multimodal coroutine is scheduled as a task before the chat template
    is rendered on a worker thread, so both proceed concurrently.

    Args:
        model_type, tokenizer, processor, conversation, add_generation_prompt,
        mm_placeholder_counts, tools, documents, chat_template:
            Forwarded unchanged to the chat-template renderer.
        mm_coroutines: Coroutine that resolves to an
            ``(mm_data, mm_embeddings)`` pair.
        chat_template_kwargs: Extra template arguments; ``None`` is replaced
            by an empty dict before forwarding.
        prompt_token_ids: When not ``None``, used directly as the prompt and
            the chat template is not rendered at all.
        allow_mm_embeddings: Whether a non-empty ``mm_embeddings`` result is
            accepted.

    Returns:
        The value produced by ``prompt_inputs`` for the prompt, with
        ``"multi_modal_data"`` or ``"multi_modal_embeddings"`` set when the
        corresponding multimodal result is non-empty.

    Raises:
        ValueError: If multimodal embeddings are present but not allowed, or
            if both multimodal data and multimodal embeddings are present.
    """
    # Schedule media loading first so it runs while the template is rendered.
    mm_task = asyncio.create_task(mm_coroutines)
    try:
        if prompt_token_ids is not None:
            prompt = prompt_token_ids
        else:
            prompt = await _apply_chat_template_nonblocking(
                model_type=model_type,
                tokenizer=tokenizer,
                processor=processor,
                conversation=conversation,
                add_generation_prompt=add_generation_prompt,
                mm_placeholder_counts=mm_placeholder_counts,
                tools=tools,
                documents=documents,
                chat_template=chat_template,
                chat_template_kwargs=chat_template_kwargs or {},
            )
        prepared_prompt = prompt_inputs(prompt)

        mm_data, mm_embeddings = await mm_task

        # Validate before touching the prompt; the order of these two checks
        # decides which message wins when both conditions hold.
        if mm_embeddings and not allow_mm_embeddings:
            raise ValueError("Cannot use multimodal embeddings as input")
        if mm_data and mm_embeddings:
            raise ValueError(
                "Passing 'multi_modal_data' and 'multi_modal_embeddings' at the same time is not supported."
            )

        if mm_data:
            prepared_prompt["multi_modal_data"] = mm_data
        if mm_embeddings:
            prepared_prompt["multi_modal_embeddings"] = mm_embeddings
        return prepared_prompt
    finally:
        # Stop media loading if we are leaving before it finished.
        if not mm_task.done():
            mm_task.cancel()
        # Always drain the task, even when it is already done: if template
        # rendering failed after the media task had itself failed, nothing
        # else would retrieve that exception and asyncio would log
        # "Task exception was never retrieved". return_exceptions=True keeps
        # the drain from masking the exception currently propagating.
        await asyncio.gather(mm_task, return_exceptions=True)
0 commit comments