Skip to content

Commit 5c5024f

Browse files
[None][perf] offload chat template rendering in serving
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
1 parent 9a7f76f commit 5c5024f

5 files changed

Lines changed: 91 additions & 16 deletions

File tree

tensorrt_llm/inputs/utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,37 @@ def apply_chat_template(
706706
return result
707707

708708

709+
async def async_apply_chat_template(
710+
*,
711+
model_type: str,
712+
tokenizer: Union[TransformersTokenizer, TokenizerBase],
713+
processor: ProcessorMixin,
714+
conversation: list[ConversationMessage],
715+
add_generation_prompt: bool,
716+
mm_placeholder_counts: list[dict[str, int]],
717+
tools: Optional[list[dict[str, Any]]] = None,
718+
documents: Optional[list[dict[str, str]]] = None,
719+
chat_template: Optional[str] = None,
720+
chat_template_kwargs: Optional[dict[str, Any]] = None,
721+
enable_tokenize: bool = False,
722+
) -> (str | List[str]):
723+
"""Apply chat template without blocking the event loop."""
724+
return await asyncio.to_thread(
725+
apply_chat_template,
726+
model_type=model_type,
727+
tokenizer=tokenizer,
728+
processor=processor,
729+
conversation=conversation,
730+
add_generation_prompt=add_generation_prompt,
731+
mm_placeholder_counts=mm_placeholder_counts,
732+
tools=tools,
733+
documents=documents,
734+
chat_template=chat_template,
735+
chat_template_kwargs=chat_template_kwargs,
736+
enable_tokenize=enable_tokenize,
737+
)
738+
739+
709740
def default_multimodal_input_loader(
710741
*,
711742
tokenizer: Optional[Union[TransformersTokenizer, TokenizerBase]],

tensorrt_llm/serve/openai_server.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434
from tensorrt_llm.inputs import prompt_inputs
3535
from tensorrt_llm.inputs.data import TokensPrompt
3636
from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
37-
from tensorrt_llm.inputs.utils import ConversationMessage, apply_chat_template
37+
from tensorrt_llm.inputs.utils import (ConversationMessage,
38+
async_apply_chat_template)
3839
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
3940
from tensorrt_llm.llmapi import MultimodalEncoder, SchedulingParams, tracing
4041
from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig,
@@ -1200,7 +1201,7 @@ async def chat_stream_generator(
12001201
if request.prompt_token_ids is not None:
12011202
prompt = request.prompt_token_ids
12021203
else:
1203-
prompt: str = apply_chat_template(
1204+
prompt_task = async_apply_chat_template(
12041205
model_type=resolve_top_level_model_type(self.model_config),
12051206
tokenizer=self.tokenizer,
12061207
processor=self.processor,
@@ -1212,9 +1213,12 @@ async def chat_stream_generator(
12121213
chat_template=request.chat_template or self.chat_template,
12131214
chat_template_kwargs=request.chat_template_kwargs or {},
12141215
)
1216+
prompt, (mm_data, mm_embeddings) = await asyncio.gather(
1217+
prompt_task, mm_coroutines)
12151218
prompt = prompt_inputs(prompt)
12161219

1217-
mm_data, mm_embeddings = await mm_coroutines
1220+
if request.prompt_token_ids is not None:
1221+
mm_data, mm_embeddings = await mm_coroutines
12181222
if mm_data:
12191223
prompt["multi_modal_data"] = mm_data
12201224
if mm_embeddings:
@@ -1350,7 +1354,7 @@ async def create_mm_embedding_response(promise: RequestOutput):
13501354
if request.prompt_token_ids is not None:
13511355
prompt = request.prompt_token_ids
13521356
else:
1353-
prompt: str = apply_chat_template(
1357+
prompt_task = async_apply_chat_template(
13541358
model_type=resolve_top_level_model_type(self.model_config),
13551359
tokenizer=self.tokenizer,
13561360
processor=self.processor,
@@ -1362,9 +1366,12 @@ async def create_mm_embedding_response(promise: RequestOutput):
13621366
chat_template=request.chat_template,
13631367
chat_template_kwargs=request.chat_template_kwargs or {},
13641368
)
1369+
prompt, (mm_data, mm_embeddings) = await asyncio.gather(
1370+
prompt_task, mm_coroutines)
13651371
prompt = prompt_inputs(prompt)
13661372

1367-
mm_data, mm_embeddings = await mm_coroutines
1373+
if request.prompt_token_ids is not None:
1374+
mm_data, mm_embeddings = await mm_coroutines
13681375
if mm_embeddings:
13691376
raise ValueError("Cannot use multimodal embeddings as input")
13701377
if mm_data is not None:

tensorrt_llm/serve/resource_governor.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
LLM/Proxy/Worker chain.
2020
"""
2121

22+
import asyncio
2223
import traceback
2324
from http import HTTPStatus
2425
from typing import Callable, List, Optional
@@ -27,7 +28,7 @@
2728
from starlette.responses import JSONResponse, Response
2829

2930
from tensorrt_llm.executor.request import TruncateKVCacheRequest
30-
from tensorrt_llm.inputs.utils import ConversationMessage, apply_chat_template
31+
from tensorrt_llm.inputs.utils import ConversationMessage, async_apply_chat_template
3132
from tensorrt_llm.logger import logger
3233
from tensorrt_llm.serve.chat_utils import parse_chat_messages_coroutines
3334
from tensorrt_llm.serve.openai_protocol import KVCacheTruncateRequest
@@ -86,7 +87,7 @@ def _put_or_unavailable(self, request: TruncateKVCacheRequest) -> Optional[Respo
8687
queue.put(request)
8788
return None
8889

89-
def _convert_messages(
90+
async def _convert_messages(
9091
self,
9192
messages,
9293
tool_dicts,
@@ -97,20 +98,24 @@ def _convert_messages(
9798
) -> List[int]:
9899
"""Convert chat messages to token IDs via chat template + tokenization."""
99100
conversation: List[ConversationMessage] = []
100-
conversation, _, __ = parse_chat_messages_coroutines(messages, self.model_config, None)
101-
return apply_chat_template(
101+
conversation, mm_coroutines, mm_placeholder_counts = parse_chat_messages_coroutines(
102+
messages, self.model_config, None
103+
)
104+
token_task = async_apply_chat_template(
102105
model_type=self.model_config.model_type,
103106
tokenizer=self.tokenizer,
104107
processor=self.processor,
105108
conversation=conversation,
106109
add_generation_prompt=add_generation_prompt,
107-
mm_placeholder_counts=[],
110+
mm_placeholder_counts=mm_placeholder_counts,
108111
tools=tool_dicts,
109112
documents=documents,
110113
chat_template=chat_template,
111114
chat_template_kwargs=chat_template_kwargs or {},
112115
enable_tokenize=True,
113116
)
117+
token_ids, _ = await asyncio.gather(token_task, mm_coroutines)
118+
return token_ids
114119

115120
async def _truncate_kv_cache(self, request: KVCacheTruncateRequest) -> Response:
116121
try:
@@ -120,7 +125,7 @@ async def _truncate_kv_cache(self, request: KVCacheTruncateRequest) -> Response:
120125
chat_template_kwargs = request.chat_template_kwargs or {}
121126

122127
messages_to_retain = (
123-
self._convert_messages(
128+
await self._convert_messages(
124129
request.messages_to_retain,
125130
tool_dicts,
126131
request.add_generation_prompt,
@@ -133,7 +138,7 @@ async def _truncate_kv_cache(self, request: KVCacheTruncateRequest) -> Response:
133138
)
134139

135140
messages = (
136-
self._convert_messages(
141+
await self._convert_messages(
137142
request.messages,
138143
tool_dicts,
139144
request.add_generation_prompt,

tensorrt_llm/serve/responses_utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242

4343
from tensorrt_llm.bindings import steady_clock_now
4444
from tensorrt_llm.executor import GenerationResult
45-
from tensorrt_llm.inputs.utils import apply_chat_template
45+
from tensorrt_llm.inputs.utils import async_apply_chat_template
4646
from tensorrt_llm.llmapi import SamplingParams
4747
from tensorrt_llm.llmapi.llm import RequestOutput
4848
from tensorrt_llm.llmapi.reasoning_parser import (BaseReasoningParser,
@@ -821,13 +821,11 @@ async def _create_input_tokens(
821821

822822
conversation, mm_coroutines, mm_placeholder_counts = parse_chat_messages_coroutines(
823823
messages, model_config)
824-
mm_data = await mm_coroutines
825-
826824
tools_dict = [
827825
tool.model_dump()
828826
for tool in _get_chat_completion_function_tools(request.tools)
829827
]
830-
token_ids = apply_chat_template(
828+
token_task = async_apply_chat_template(
831829
model_type=resolve_top_level_model_type(model_config),
832830
tokenizer=tokenizer,
833831
processor=processor,
@@ -837,6 +835,7 @@ async def _create_input_tokens(
837835
mm_placeholder_counts=mm_placeholder_counts,
838836
enable_tokenize=True,
839837
)
838+
token_ids, mm_data = await asyncio.gather(token_task, mm_coroutines)
840839

841840
return token_ids, mm_data
842841

tests/unittest/inputs/test_chat_template_dispatch.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
"""Tests for content-format-driven chat template dispatch and placeholder handling."""
44

5+
import threading
6+
57
import pytest
68

79
from tensorrt_llm.inputs.content_format import ContentFormat
@@ -15,6 +17,7 @@
1517
_build_openai_content,
1618
_resolve_content_format,
1719
add_multimodal_placeholders,
20+
async_apply_chat_template,
1821
interleave_mm_placeholders,
1922
)
2023

@@ -324,3 +327,33 @@ def test_excess_existing_placeholders_preserved(self):
324327
)
325328
assert result == text
326329
assert result.count("<image>") == 3
330+
331+
332+
class TestAsyncApplyChatTemplate:
333+
@pytest.mark.asyncio
334+
async def test_runs_in_worker_thread(self):
335+
event_loop_thread_id = threading.current_thread().ident
336+
337+
class TrackingTokenizer:
338+
def __init__(self):
339+
self.worker_thread_id = None
340+
341+
def apply_chat_template(self, **_):
342+
self.worker_thread_id = threading.current_thread().ident
343+
return "rendered"
344+
345+
tokenizer = TrackingTokenizer()
346+
347+
result = await async_apply_chat_template(
348+
model_type="test_string_model",
349+
tokenizer=tokenizer,
350+
processor=None,
351+
conversation=[ConversationMessage(role="user", content="hello", media=[])],
352+
add_generation_prompt=True,
353+
mm_placeholder_counts=[{}],
354+
chat_template="{{ messages }}",
355+
)
356+
357+
assert result == "rendered"
358+
assert tokenizer.worker_thread_id is not None
359+
assert tokenizer.worker_thread_id != event_loop_thread_id

0 commit comments

Comments
 (0)