Skip to content

Commit b036594

Browse files
yechank-nvidia2ez4bz
authored andcommitted
[None][refactor] Add explicit Qwen VL LLM compile hook
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
1 parent f0a7053 commit b036594

3 files changed

Lines changed: 23 additions & 13 deletions

File tree

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,6 +1373,14 @@ def vocab_size_padded(self) -> int:
13731373
def infer_max_seq_len(self) -> int:
13741374
return self.llm.infer_max_seq_len()
13751375

1376+
def apply_llm_torch_compile(self, *, backend: Any, fullgraph: bool) -> None:
1377+
# TODO: Move this hook to MultimodalModelMixin once multimodal models
1378+
# consistently expose an LLM compile contract.
1379+
"""Compile only the LLM decoder; the vision encoder stays eager."""
1380+
self.llm.model = torch.compile(self.llm.model,
1381+
backend=backend,
1382+
fullgraph=fullgraph)
1383+
13761384
@nvtx_range("Qwen2.5-VL prepare_mrope_config")
13771385
def prepare_mrope_config(
13781386
self,

tensorrt_llm/_torch/models/modeling_qwen3vl.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,12 @@ def vocab_size_padded(self) -> int:
11211121
def infer_max_seq_len(self) -> int:
11221122
return self.llm.infer_max_seq_len()
11231123

1124+
def apply_llm_torch_compile(self, *, backend: Any, fullgraph: bool) -> None:
1125+
# TODO: Move this hook to MultimodalModelMixin once multimodal models
1126+
# consistently expose an LLM compile contract.
1127+
"""Compile only the LLM decoder; the vision encoder stays eager."""
1128+
self.llm.model = torch.compile(self.llm.model, backend=backend, fullgraph=fullgraph)
1129+
11241130
def init_mrope_embedding(self, model_config: ModelConfig[PretrainedConfig]):
11251131
config = model_config.pretrained_config.text_config
11261132
pos_embd_params = PositionalEmbeddingParams(

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -458,24 +458,20 @@ def __init__(
458458
capture_num_tokens=self._piecewise_cuda_graph_num_tokens,
459459
max_num_streams=torch_compile_max_num_streams,
460460
mapping=self.mapping)
461+
apply_llm_torch_compile = getattr(self.model,
462+
"apply_llm_torch_compile",
463+
None)
461464
if isinstance(self.model, DecoderModelForCausalLM):
462465
self.model.model = torch.compile(
463466
self.model.model,
464467
backend=self._torch_compile_backend,
465468
fullgraph=torch_compile_fullgraph)
466-
elif hasattr(self.model, "llm") and isinstance(
467-
getattr(self.model.llm, "model", None),
468-
torch.nn.Module):
469-
# Multi-modal wrapper (e.g. Qwen2/3-VL): compile only the
470-
# text decoder. Tracing the outer wrapper pulls the
471-
# vision-tower output path + `fuse_input_embeds` into
472-
# the same graph, which lets the vision hidden_dim
473-
# propagate into the LM o_proj fake-tensor trace and
474-
# blows up the piecewise CUDA graph warmup.
475-
self.model.llm.model = torch.compile(
476-
self.model.llm.model,
477-
backend=self._torch_compile_backend,
478-
fullgraph=torch_compile_fullgraph)
469+
elif callable(apply_llm_torch_compile):
470+
# TODO: Move this contract to MultimodalModelMixin once
471+
# multimodal models consistently expose their LLM compile
472+
# scope through the mixin.
473+
apply_llm_torch_compile(backend=self._torch_compile_backend,
474+
fullgraph=torch_compile_fullgraph)
479475
else:
480476
self.model = torch.compile(
481477
self.model,

0 commit comments

Comments
 (0)