Skip to content

Commit 45efcca

Browse files
[None][fix] Address review comments on async chat-template offload
- resource_governor: resolve the top-level model type (resolve_top_level_ model_type) in _convert_messages, matching the serving call sites, instead of the raw model_config.model_type. - responses_utils: unpack the (mm_data, mm_embeddings) tuple from the asyncio.gather result so _create_input_tokens returns mm_data (not the whole tuple) as its contract states. - tests: add async regression coverage for both gather paths (ResourceGovernor._convert_messages and _create_input_tokens). Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
1 parent 5fc1eb9 commit 45efcca

3 files changed

Lines changed: 106 additions & 3 deletions

File tree

tensorrt_llm/serve/resource_governor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@
3030
from tensorrt_llm.executor.request import TruncateKVCacheRequest
3131
from tensorrt_llm.inputs.utils import ConversationMessage, async_apply_chat_template
3232
from tensorrt_llm.logger import logger
33-
from tensorrt_llm.serve.chat_utils import parse_chat_messages_coroutines
33+
from tensorrt_llm.serve.chat_utils import (
34+
parse_chat_messages_coroutines,
35+
resolve_top_level_model_type,
36+
)
3437
from tensorrt_llm.serve.openai_protocol import KVCacheTruncateRequest
3538

3639

@@ -102,7 +105,7 @@ async def _convert_messages(
102105
messages, self.model_config, None
103106
)
104107
token_task = async_apply_chat_template(
105-
model_type=self.model_config.model_type,
108+
model_type=resolve_top_level_model_type(self.model_config),
106109
tokenizer=self.tokenizer,
107110
processor=self.processor,
108111
conversation=conversation,

tensorrt_llm/serve/responses_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,9 @@ async def _create_input_tokens(
835835
mm_placeholder_counts=mm_placeholder_counts,
836836
enable_tokenize=True,
837837
)
838-
token_ids, mm_data = await asyncio.gather(token_task, mm_coroutines)
838+
token_ids, (mm_data,
839+
_mm_embeddings) = await asyncio.gather(token_task,
840+
mm_coroutines)
839841

840842
return token_ids, mm_data
841843

tests/unittest/inputs/test_chat_template_dispatch.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,3 +357,101 @@ def apply_chat_template(self, **_):
357357
assert result == "rendered"
358358
assert tokenizer.worker_thread_id is not None
359359
assert tokenizer.worker_thread_id != event_loop_thread_id
360+
361+
362+
class TestServingChatTemplateGather:
363+
"""Cover the asyncio.gather integration in the serving chat-template paths."""
364+
365+
@pytest.mark.asyncio
366+
async def test_resource_governor_convert_messages(self, monkeypatch):
367+
from unittest.mock import Mock
368+
369+
import tensorrt_llm.serve.resource_governor as rg
370+
371+
governor = object.__new__(rg.ResourceGovernor)
372+
governor.model_config = Mock()
373+
governor.tokenizer = Mock()
374+
governor.processor = None
375+
376+
async def fake_mm_coroutine():
377+
# parse_chat_messages_coroutines' coroutine yields
378+
# (mm_data, mm_embeddings).
379+
return ({"image": ["data"]}, None)
380+
381+
monkeypatch.setattr(
382+
rg,
383+
"parse_chat_messages_coroutines",
384+
lambda messages, model_config, _: ([], fake_mm_coroutine(), [{}]),
385+
)
386+
# Must resolve the top-level model type, matching the serving call
387+
# sites (not the raw model_config.model_type).
388+
monkeypatch.setattr(rg, "resolve_top_level_model_type", lambda cfg: "resolved-model-type")
389+
390+
captured = {}
391+
392+
async def fake_async_apply(**kwargs):
393+
captured.update(kwargs)
394+
return [1, 2, 3]
395+
396+
monkeypatch.setattr(rg, "async_apply_chat_template", fake_async_apply)
397+
398+
token_ids = await governor._convert_messages(
399+
messages=[{"role": "user", "content": "hi"}],
400+
tool_dicts=None,
401+
add_generation_prompt=True,
402+
documents=None,
403+
chat_template=None,
404+
chat_template_kwargs=None,
405+
)
406+
407+
# Returns only token_ids, not the (mm_data, mm_embeddings) tuple.
408+
assert token_ids == [1, 2, 3]
409+
# Uses the top-level resolver and forwards the real placeholder counts.
410+
assert captured["model_type"] == "resolved-model-type"
411+
assert captured["mm_placeholder_counts"] == [{}]
412+
413+
@pytest.mark.asyncio
414+
async def test_responses_create_input_tokens_unpacks_mm_tuple(self, monkeypatch):
415+
"""_create_input_tokens must return mm_data, not the whole gather tuple."""
416+
from unittest.mock import Mock
417+
418+
import tensorrt_llm.serve.responses_utils as ru
419+
420+
async def fake_create_input_messages(request, prev_msgs):
421+
return [{"role": "user", "content": "hi"}]
422+
423+
async def fake_mm_coroutine():
424+
return ({"image": ["data"]}, {"image": ["embed"]})
425+
426+
monkeypatch.setattr(ru, "_create_input_messages", fake_create_input_messages)
427+
monkeypatch.setattr(
428+
ru,
429+
"parse_chat_messages_coroutines",
430+
lambda messages, model_config: ([], fake_mm_coroutine(), [{}]),
431+
)
432+
monkeypatch.setattr(ru, "resolve_top_level_model_type", lambda cfg: "resolved-model-type")
433+
monkeypatch.setattr(ru, "_get_chat_completion_function_tools", lambda tools: [])
434+
435+
async def fake_async_apply(**kwargs):
436+
return [1, 2, 3]
437+
438+
monkeypatch.setattr(ru, "async_apply_chat_template", fake_async_apply)
439+
440+
request = Mock()
441+
request.tools = None
442+
request.store = False
443+
444+
token_ids, mm_data = await ru._create_input_tokens(
445+
request=request,
446+
prev_response=None,
447+
prev_msgs=None,
448+
conversation_store=None,
449+
enable_store=False,
450+
tokenizer=Mock(),
451+
model_config=Mock(),
452+
processor=None,
453+
)
454+
455+
assert token_ids == [1, 2, 3]
456+
# mm_data is the data dict, not the (mm_data, mm_embeddings) tuple.
457+
assert mm_data == {"image": ["data"]}

0 commit comments

Comments
 (0)