Skip to content

Commit dad10d9

Browse files
lucasliebmarimuthu-nv
authored andcommitted
[None][feat] Add AutoDeploy custom model for OpenELM family (#198)
Onboard the OpenELM architecture (apple/OpenELM-270M/1_1B/3B-Instruct) as a custom AutoDeploy model. This is a heterogeneous transformer with: - Per-layer varying query/KV head counts (GQA) - Per-layer varying FFN intermediate sizes - Fused QKV projection with Q/K normalization - Shared input/output embeddings (no separate lm_head) - GLU-style FFN (proj_1 = fused gate+up, proj_2 = down) Uses canonical AD IR ops: torch_rmsnorm, torch_rope_with_explicit_cos_sin, torch_attention. Config loaded from checkpoint via trust_remote_code=True. Updated openelm.yaml with attn_backend=flashinfer (trtllm backend produces degenerate output for OpenELM). Works with torch-cudagraph, default batch settings from dashboard_default.yaml. All 3 variants produce coherent generation via build_and_run_ad.py. Signed-off-by: Lucas Liebenwein <lliebenwein@nvidia.com> Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
1 parent b868c10 commit dad10d9

4 files changed

Lines changed: 874 additions & 2 deletions

File tree

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1-
# Configuration for Apple OpenELM models
2-
# These models require Llama-2 tokenizer
1+
# Configuration for Apple OpenELM models (270M, 1.1B, 3B)
2+
# These models use the Llama-2 tokenizer (confirmed by Apple's CoreNet docs).
33
tokenizer: meta-llama/Llama-2-7b-hf
4+
5+
# Override dashboard_default's attn_backend=trtllm which produces degenerate
6+
# output for OpenELM. Use flashinfer which works with torch-cudagraph.
7+
attn_backend: flashinfer

tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .modeling_nemotron_flash import NemotronFlashForCausalLM, NemotronFlashPreTrainedTokenizerFast
2323
from .modeling_nemotron_h import NemotronHForCausalLM
2424
from .modeling_olmo3 import Olmo3ForCausalLM
25+
from .modeling_openelm import OpenELMForCausalLM
2526
from .modeling_qwen2 import Qwen2ForCausalLM
2627
from .modeling_qwen3_5_moe import Qwen3_5MoeForCausalLM, Qwen3_5MoeForConditionalGeneration
2728
from .modeling_qwen3_moe import Qwen3MoeForCausalLM
@@ -60,6 +61,7 @@
6061
"NemotronFlashPreTrainedTokenizerFast",
6162
"NemotronHForCausalLM",
6263
"Olmo3ForCausalLM",
64+
"OpenELMForCausalLM",
6365
"Phi4ForCausalLM",
6466
"Phi4FlashForCausalLM",
6567
"Phi4MMForCausalLM",

0 commit comments

Comments
 (0)