Skip to content

Commit b868c10

Browse files
lucasliebmarimuthu-nv
authored andcommitted
[None][feat] Add AD custom model for Llama 4 family (#237)
Add a lean, prefill-only custom model for the Llama 4 family (Scout-17B-16E, Maverick-17B-128E) using AutoDeploy canonical ops. Replaces the previous MoE/vision patches with a self-contained implementation. Key features: - GQA with complex-frequency RoPE (torch_rope_with_complex_freqs) - NoPE layers with attention temperature tuning - L2 QK normalization on RoPE layers (mean-based, plain PyTorch) - MoE with stacked expert weights (bmm) matching HF checkpoint format; AD MatchBmmMoePattern transform handles conversion at deployment - Multimodal wrapper (ForConditionalGeneration) for weight compat - Fix multimodal processor chat template for text-only prompts Includes hierarchical unit tests (block, layer, full model, export) covering RoPE/NoPE layers, MoE/dense layers, and dynamic shapes. Signed-off-by: Lucas Liebenwein <lliebenwein@nvidia.com> Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
1 parent ff33927 commit b868c10

5 files changed

Lines changed: 1501 additions & 240 deletions

File tree

tensorrt_llm/_torch/auto_deploy/llm.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,25 @@ def __call__(
4646
# multi_modal_data should not be present in the messages field
4747
assert "multi_modal_data" not in inputs, f"unexpected multi_modal_data key in {inputs=}"
4848

49+
# Normalize message content to list-of-dicts format for multimodal
50+
# processors (e.g., Llama4) that expect {"type": "text", "text": "..."}
51+
# instead of plain strings when tokenize=True.
52+
messages = inputs["messages"]
53+
for msg in messages:
54+
if isinstance(msg.get("content"), str):
55+
msg["content"] = [{"type": "text", "text": msg["content"]}]
56+
4957
# TODO: we don't really need this but it makes for a good sanity check. Consider
5058
# removing this in the future if we need to speed things up.
5159
prompt = self.processor.apply_chat_template(
52-
inputs["messages"],
60+
messages,
5361
add_generation_prompt=True,
5462
tokenize=False,
5563
)
5664
inputs["prompt"] = prompt
5765

5866
all_args = self.processor.apply_chat_template(
59-
inputs["messages"],
67+
messages,
6068
add_generation_prompt=True,
6169
tokenize=True,
6270
return_dict=True,

tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .modeling_internlm3 import InternLM3ForCausalLM
1616
from .modeling_kimi_k2 import KimiK2ForCausalLM, KimiK25ForConditionalGeneration
1717
from .modeling_llama3 import Llama3ForCausalLM
18+
from .modeling_llama4 import Llama4ForCausalLM, Llama4ForConditionalGeneration
1819
from .modeling_minimax_m2 import MiniMaxM2ForCausalLM
1920
from .modeling_mistral import MistralForCausalLM
2021
from .modeling_mistral3 import Mistral3ForConditionalGeneration, Mistral3TextForCausalLM
@@ -49,6 +50,8 @@
4950
"KimiK2ForCausalLM",
5051
"KimiK25ForConditionalGeneration",
5152
"Llama3ForCausalLM",
53+
"Llama4ForCausalLM",
54+
"Llama4ForConditionalGeneration",
5255
"MiniMaxM2ForCausalLM",
5356
"MistralForCausalLM",
5457
"Mistral3ForConditionalGeneration",

0 commit comments

Comments
 (0)