Skip to content

Commit 5fc1eb9

Browse files
[None][perf] offload chat template rendering in serving
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
1 parent 416bdb2 commit 5fc1eb9

5 files changed

Lines changed: 91 additions & 16 deletions

File tree

tensorrt_llm/inputs/utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,37 @@ def apply_chat_template(
705705
return result
706706

707707

708+
async def async_apply_chat_template(
709+
*,
710+
model_type: str,
711+
tokenizer: Union[TransformersTokenizer, TokenizerBase],
712+
processor: ProcessorMixin,
713+
conversation: list[ConversationMessage],
714+
add_generation_prompt: bool,
715+
mm_placeholder_counts: list[dict[str, int]],
716+
tools: Optional[list[dict[str, Any]]] = None,
717+
documents: Optional[list[dict[str, str]]] = None,
718+
chat_template: Optional[str] = None,
719+
chat_template_kwargs: Optional[dict[str, Any]] = None,
720+
enable_tokenize: bool = False,
721+
) -> (str | List[str]):
722+
"""Apply chat template without blocking the event loop."""
723+
return await asyncio.to_thread(
724+
apply_chat_template,
725+
model_type=model_type,
726+
tokenizer=tokenizer,
727+
processor=processor,
728+
conversation=conversation,
729+
add_generation_prompt=add_generation_prompt,
730+
mm_placeholder_counts=mm_placeholder_counts,
731+
tools=tools,
732+
documents=documents,
733+
chat_template=chat_template,
734+
chat_template_kwargs=chat_template_kwargs,
735+
enable_tokenize=enable_tokenize,
736+
)
737+
738+
708739
def default_multimodal_input_loader(
709740
*,
710741
tokenizer: Optional[Union[TransformersTokenizer, TokenizerBase]],

tensorrt_llm/serve/openai_server.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@
3535
from tensorrt_llm.inputs import prompt_inputs
3636
from tensorrt_llm.inputs.data import TokensPrompt
3737
from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
38-
from tensorrt_llm.inputs.utils import ConversationMessage, apply_chat_template
38+
from tensorrt_llm.inputs.utils import (ConversationMessage,
39+
async_apply_chat_template)
3940
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
4041
from tensorrt_llm.llmapi import MultimodalEncoder, SchedulingParams, tracing
4142
from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig,
@@ -1241,7 +1242,7 @@ async def chat_stream_generator(
12411242
if request.prompt_token_ids is not None:
12421243
prompt = request.prompt_token_ids
12431244
else:
1244-
prompt: str = apply_chat_template(
1245+
prompt_task = async_apply_chat_template(
12451246
model_type=resolve_top_level_model_type(self.model_config),
12461247
tokenizer=self.tokenizer,
12471248
processor=self.processor,
@@ -1253,9 +1254,12 @@ async def chat_stream_generator(
12531254
chat_template=request.chat_template or self.chat_template,
12541255
chat_template_kwargs=request.chat_template_kwargs or {},
12551256
)
1257+
prompt, (mm_data, mm_embeddings) = await asyncio.gather(
1258+
prompt_task, mm_coroutines)
12561259
prompt = prompt_inputs(prompt)
12571260

1258-
mm_data, mm_embeddings = await mm_coroutines
1261+
if request.prompt_token_ids is not None:
1262+
mm_data, mm_embeddings = await mm_coroutines
12591263
if mm_data:
12601264
prompt["multi_modal_data"] = mm_data
12611265
if mm_embeddings:
@@ -1394,7 +1398,7 @@ async def create_mm_embedding_response(promise: RequestOutput):
13941398
if request.prompt_token_ids is not None:
13951399
prompt = request.prompt_token_ids
13961400
else:
1397-
prompt: str = apply_chat_template(
1401+
prompt_task = async_apply_chat_template(
13981402
model_type=resolve_top_level_model_type(self.model_config),
13991403
tokenizer=self.tokenizer,
14001404
processor=self.processor,
@@ -1406,9 +1410,12 @@ async def create_mm_embedding_response(promise: RequestOutput):
14061410
chat_template=request.chat_template,
14071411
chat_template_kwargs=request.chat_template_kwargs or {},
14081412
)
1413+
prompt, (mm_data, mm_embeddings) = await asyncio.gather(
1414+
prompt_task, mm_coroutines)
14091415
prompt = prompt_inputs(prompt)
14101416

1411-
mm_data, mm_embeddings = await mm_coroutines
1417+
if request.prompt_token_ids is not None:
1418+
mm_data, mm_embeddings = await mm_coroutines
14121419
if mm_embeddings:
14131420
raise ValueError("Cannot use multimodal embeddings as input")
14141421
if mm_data is not None:

tensorrt_llm/serve/resource_governor.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
LLM/Proxy/Worker chain.
2020
"""
2121

22+
import asyncio
2223
import traceback
2324
from http import HTTPStatus
2425
from typing import Callable, List, Optional
@@ -27,7 +28,7 @@
2728
from starlette.responses import JSONResponse, Response
2829

2930
from tensorrt_llm.executor.request import TruncateKVCacheRequest
30-
from tensorrt_llm.inputs.utils import ConversationMessage, apply_chat_template
31+
from tensorrt_llm.inputs.utils import ConversationMessage, async_apply_chat_template
3132
from tensorrt_llm.logger import logger
3233
from tensorrt_llm.serve.chat_utils import parse_chat_messages_coroutines
3334
from tensorrt_llm.serve.openai_protocol import KVCacheTruncateRequest
@@ -86,7 +87,7 @@ def _put_or_unavailable(self, request: TruncateKVCacheRequest) -> Optional[Respo
8687
queue.put(request)
8788
return None
8889

89-
def _convert_messages(
90+
async def _convert_messages(
9091
self,
9192
messages,
9293
tool_dicts,
@@ -97,20 +98,24 @@ def _convert_messages(
9798
) -> List[int]:
9899
"""Convert chat messages to token IDs via chat template + tokenization."""
99100
conversation: List[ConversationMessage] = []
100-
conversation, _, __ = parse_chat_messages_coroutines(messages, self.model_config, None)
101-
return apply_chat_template(
101+
conversation, mm_coroutines, mm_placeholder_counts = parse_chat_messages_coroutines(
102+
messages, self.model_config, None
103+
)
104+
token_task = async_apply_chat_template(
102105
model_type=self.model_config.model_type,
103106
tokenizer=self.tokenizer,
104107
processor=self.processor,
105108
conversation=conversation,
106109
add_generation_prompt=add_generation_prompt,
107-
mm_placeholder_counts=[],
110+
mm_placeholder_counts=mm_placeholder_counts,
108111
tools=tool_dicts,
109112
documents=documents,
110113
chat_template=chat_template,
111114
chat_template_kwargs=chat_template_kwargs or {},
112115
enable_tokenize=True,
113116
)
117+
token_ids, _ = await asyncio.gather(token_task, mm_coroutines)
118+
return token_ids
114119

115120
async def _truncate_kv_cache(self, request: KVCacheTruncateRequest) -> Response:
116121
try:
@@ -120,7 +125,7 @@ async def _truncate_kv_cache(self, request: KVCacheTruncateRequest) -> Response:
120125
chat_template_kwargs = request.chat_template_kwargs or {}
121126

122127
messages_to_retain = (
123-
self._convert_messages(
128+
await self._convert_messages(
124129
request.messages_to_retain,
125130
tool_dicts,
126131
request.add_generation_prompt,
@@ -133,7 +138,7 @@ async def _truncate_kv_cache(self, request: KVCacheTruncateRequest) -> Response:
133138
)
134139

135140
messages = (
136-
self._convert_messages(
141+
await self._convert_messages(
137142
request.messages,
138143
tool_dicts,
139144
request.add_generation_prompt,

tensorrt_llm/serve/responses_utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242

4343
from tensorrt_llm.bindings import steady_clock_now
4444
from tensorrt_llm.executor import GenerationResult
45-
from tensorrt_llm.inputs.utils import apply_chat_template
45+
from tensorrt_llm.inputs.utils import async_apply_chat_template
4646
from tensorrt_llm.llmapi import SamplingParams
4747
from tensorrt_llm.llmapi.llm import RequestOutput
4848
from tensorrt_llm.llmapi.reasoning_parser import (BaseReasoningParser,
@@ -821,13 +821,11 @@ async def _create_input_tokens(
821821

822822
conversation, mm_coroutines, mm_placeholder_counts = parse_chat_messages_coroutines(
823823
messages, model_config)
824-
mm_data = await mm_coroutines
825-
826824
tools_dict = [
827825
tool.model_dump()
828826
for tool in _get_chat_completion_function_tools(request.tools)
829827
]
830-
token_ids = apply_chat_template(
828+
token_task = async_apply_chat_template(
831829
model_type=resolve_top_level_model_type(model_config),
832830
tokenizer=tokenizer,
833831
processor=processor,
@@ -837,6 +835,7 @@ async def _create_input_tokens(
837835
mm_placeholder_counts=mm_placeholder_counts,
838836
enable_tokenize=True,
839837
)
838+
token_ids, mm_data = await asyncio.gather(token_task, mm_coroutines)
840839

841840
return token_ids, mm_data
842841

tests/unittest/inputs/test_chat_template_dispatch.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
"""Tests for content-format-driven chat template dispatch and placeholder handling."""
44

5+
import threading
6+
57
import pytest
68

79
from tensorrt_llm.inputs.content_format import ContentFormat
@@ -15,6 +17,7 @@
1517
_build_openai_content,
1618
_resolve_content_format,
1719
add_multimodal_placeholders,
20+
async_apply_chat_template,
1821
interleave_mm_placeholders,
1922
)
2023

@@ -324,3 +327,33 @@ def test_excess_existing_placeholders_preserved(self):
324327
)
325328
assert result == text
326329
assert result.count("<image>") == 3
330+
331+
332+
class TestAsyncApplyChatTemplate:
333+
@pytest.mark.asyncio
334+
async def test_runs_in_worker_thread(self):
335+
event_loop_thread_id = threading.current_thread().ident
336+
337+
class TrackingTokenizer:
338+
def __init__(self):
339+
self.worker_thread_id = None
340+
341+
def apply_chat_template(self, **_):
342+
self.worker_thread_id = threading.current_thread().ident
343+
return "rendered"
344+
345+
tokenizer = TrackingTokenizer()
346+
347+
result = await async_apply_chat_template(
348+
model_type="test_string_model",
349+
tokenizer=tokenizer,
350+
processor=None,
351+
conversation=[ConversationMessage(role="user", content="hello", media=[])],
352+
add_generation_prompt=True,
353+
mm_placeholder_counts=[{}],
354+
chat_template="{{ messages }}",
355+
)
356+
357+
assert result == "rendered"
358+
assert tokenizer.worker_thread_id is not None
359+
assert tokenizer.worker_thread_id != event_loop_thread_id

0 commit comments

Comments
 (0)