Skip to content

Commit 5ba141e

Browse files
committed
[TRTLLM-13024][perf] Make chat template application non-blocking
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
1 parent 01d8ccb commit 5ba141e

1 file changed

Lines changed: 97 additions & 50 deletions

File tree

tensorrt_llm/serve/openai_server.py

Lines changed: 97 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from datetime import datetime
1515
from http import HTTPStatus
1616
from pathlib import Path
17-
from typing import (Annotated, Any, AsyncGenerator, AsyncIterator, List,
18-
Optional, Union)
17+
from typing import (Annotated, Any, AsyncGenerator, AsyncIterator, Coroutine,
18+
List, Optional, Union)
1919

2020
import uvicorn
2121
from fastapi import Body, FastAPI, Request
@@ -1238,32 +1238,21 @@ async def chat_stream_generator(
12381238
self.multimodal_server_config,
12391239
request_media_io_kwargs=request.media_io_kwargs)
12401240

1241-
if request.prompt_token_ids is not None:
1242-
prompt = request.prompt_token_ids
1243-
else:
1244-
prompt: str = apply_chat_template(
1245-
model_type=resolve_top_level_model_type(self.model_config),
1246-
tokenizer=self.tokenizer,
1247-
processor=self.processor,
1248-
conversation=conversation,
1249-
add_generation_prompt=request.add_generation_prompt,
1250-
mm_placeholder_counts=mm_placeholder_counts,
1251-
tools=tool_dicts,
1252-
documents=request.documents,
1253-
chat_template=request.chat_template or self.chat_template,
1254-
chat_template_kwargs=request.chat_template_kwargs or {},
1255-
)
1256-
prompt = prompt_inputs(prompt)
1257-
1258-
mm_data, mm_embeddings = await mm_coroutines
1259-
if mm_data:
1260-
prompt["multi_modal_data"] = mm_data
1261-
if mm_embeddings:
1262-
prompt["multi_modal_embeddings"] = mm_embeddings
1263-
if mm_data and mm_embeddings:
1264-
raise ValueError(
1265-
"Passing 'multi_modal_data' and 'multi_modal_embeddings' at the same time is not supported."
1266-
)
1241+
prompt = await _prepare_chat_prompt_inputs_nonblocking(
1242+
model_type=resolve_top_level_model_type(self.model_config),
1243+
tokenizer=self.tokenizer,
1244+
processor=self.processor,
1245+
conversation=conversation,
1246+
add_generation_prompt=request.add_generation_prompt,
1247+
mm_coroutines=mm_coroutines,
1248+
mm_placeholder_counts=mm_placeholder_counts,
1249+
tools=tool_dicts,
1250+
documents=request.documents,
1251+
chat_template=request.chat_template or self.chat_template,
1252+
chat_template_kwargs=request.chat_template_kwargs,
1253+
prompt_token_ids=request.prompt_token_ids,
1254+
allow_mm_embeddings=True,
1255+
)
12671256

12681257
postproc_args.reasoning_parser = self.generator.args.reasoning_parser
12691258
postproc_args.tool_parser = self.tool_parser
@@ -1388,28 +1377,21 @@ async def create_mm_embedding_response(promise: RequestOutput):
13881377
self.multimodal_server_config,
13891378
request_media_io_kwargs=request.media_io_kwargs)
13901379

1391-
if request.prompt_token_ids is not None:
1392-
prompt = request.prompt_token_ids
1393-
else:
1394-
prompt: str = apply_chat_template(
1395-
model_type=resolve_top_level_model_type(self.model_config),
1396-
tokenizer=self.tokenizer,
1397-
processor=self.processor,
1398-
conversation=conversation,
1399-
add_generation_prompt=request.add_generation_prompt,
1400-
mm_placeholder_counts=mm_placeholder_counts,
1401-
tools=tool_dicts,
1402-
documents=request.documents,
1403-
chat_template=request.chat_template,
1404-
chat_template_kwargs=request.chat_template_kwargs or {},
1405-
)
1406-
prompt = prompt_inputs(prompt)
1407-
1408-
mm_data, mm_embeddings = await mm_coroutines
1409-
if mm_embeddings:
1410-
raise ValueError("Cannot use multimodal embeddings as input")
1411-
if mm_data is not None:
1412-
prompt["multi_modal_data"] = mm_data
1380+
prompt = await _prepare_chat_prompt_inputs_nonblocking(
1381+
model_type=resolve_top_level_model_type(self.model_config),
1382+
tokenizer=self.tokenizer,
1383+
processor=self.processor,
1384+
conversation=conversation,
1385+
add_generation_prompt=request.add_generation_prompt,
1386+
mm_coroutines=mm_coroutines,
1387+
mm_placeholder_counts=mm_placeholder_counts,
1388+
tools=tool_dicts,
1389+
documents=request.documents,
1390+
chat_template=request.chat_template or self.chat_template,
1391+
chat_template_kwargs=request.chat_template_kwargs,
1392+
prompt_token_ids=request.prompt_token_ids,
1393+
allow_mm_embeddings=False,
1394+
)
14131395

14141396
promise = self.generator.generate_async(inputs=prompt, )
14151397
asyncio.create_task(self.await_disconnected(raw_request, promise))
@@ -2191,3 +2173,68 @@ async def _register_after_serving():
21912173

21922174
asyncio.create_task(_register_after_serving())
21932175
await server.serve(sockets=sockets)
2176+
2177+
2178+
async def _apply_chat_template_nonblocking(**kwargs: Any
                                           ) -> Union[str, List[str]]:
    """Render a chat template on a worker thread instead of the event loop.

    All keyword arguments are forwarded unchanged to ``apply_chat_template``,
    which is executed through ``asyncio.to_thread`` so the caller's event loop
    stays free to service other requests while rendering is in progress.

    Returns:
        Whatever ``apply_chat_template`` returns for the given arguments.
    """
    # Thread-safety note: the offloaded call shares the server's tokenizer and
    # processor objects with other requests, so the main hazard would be hidden
    # mutable state inside a custom tokenizer.
    # For the way openai_server.py uses it, that hazard is considered low:
    # apply_chat_template reads tokenizer/processor configuration and mutates
    # only the per-request conversation dictionaries.
    rendered = await asyncio.to_thread(apply_chat_template, **kwargs)
    return rendered
2186+
2187+
2188+
async def _prepare_chat_prompt_inputs_nonblocking(
    *,
    model_type: str,
    tokenizer: Any,
    processor: Any,
    conversation: List[ConversationMessage],
    add_generation_prompt: bool,
    mm_coroutines: Coroutine[Any, Any, tuple[Optional[dict[str, List[Any]]],
                                             Optional[dict[str, List[Any]]]]],
    mm_placeholder_counts: List[dict[str, int]],
    tools: Optional[List[dict[str, Any]]],
    documents: Optional[List[dict[str, str]]],
    chat_template: Optional[str],
    chat_template_kwargs: Optional[dict[str, Any]],
    prompt_token_ids: Optional[List[int]],
    allow_mm_embeddings: bool,
) -> dict[str, Any]:
    """Build the prompt inputs for a chat request without stalling the event loop.

    The multimodal coroutine is scheduled as a task before anything else, so
    media loading can make progress while the chat template is rendered on a
    worker thread. Once both pieces are available, the multimodal payload is
    attached to the prepared prompt.

    Args:
        model_type, tokenizer, processor, conversation, add_generation_prompt,
        mm_placeholder_counts, tools, documents, chat_template: Forwarded as-is
            to the chat-template renderer.
        mm_coroutines: Awaitable yielding a ``(mm_data, mm_embeddings)`` pair.
        chat_template_kwargs: Extra renderer arguments; a falsy value is
            replaced by an empty dict.
        prompt_token_ids: When not ``None``, used directly as the prompt and
            template rendering is skipped entirely.
        allow_mm_embeddings: Whether a non-empty ``mm_embeddings`` result is
            acceptable for this endpoint.

    Returns:
        The object produced by ``prompt_inputs``, with ``"multi_modal_data"``
        and/or ``"multi_modal_embeddings"`` set when the corresponding
        multimodal result is non-empty.

    Raises:
        ValueError: If embeddings are present but not allowed, or if both
            multimodal data and multimodal embeddings are present.
    """
    # Start media loading immediately; it runs concurrently with rendering below.
    media_task = asyncio.create_task(mm_coroutines)
    try:
        if prompt_token_ids is None:
            template_args = dict(
                model_type=model_type,
                tokenizer=tokenizer,
                processor=processor,
                conversation=conversation,
                add_generation_prompt=add_generation_prompt,
                mm_placeholder_counts=mm_placeholder_counts,
                tools=tools,
                documents=documents,
                chat_template=chat_template,
                chat_template_kwargs=chat_template_kwargs or {},
            )
            raw_prompt = await _apply_chat_template_nonblocking(**template_args)
        else:
            # Pre-tokenized request: nothing to render.
            raw_prompt = prompt_token_ids
        result = prompt_inputs(raw_prompt)

        media, embeddings = await media_task
        if media:
            result["multi_modal_data"] = media
        if embeddings:
            if not allow_mm_embeddings:
                raise ValueError("Cannot use multimodal embeddings as input")
            result["multi_modal_embeddings"] = embeddings
            # Data and embeddings are mutually exclusive inputs.
            if media:
                raise ValueError(
                    "Passing 'multi_modal_data' and 'multi_modal_embeddings' at the same time is not supported."
                )
        return result
    finally:
        # On any early exit (error or cancellation) make sure the media task is
        # stopped, then drain it so its outcome is always retrieved.
        still_running = not media_task.done()
        if still_running:
            media_task.cancel()
        await asyncio.gather(media_task, return_exceptions=True)

0 commit comments

Comments
 (0)