
Commit 36e3724

[None][feat] Add AutoDeploy custom model for OpenELM family
Onboard the OpenELM architecture (apple/OpenELM-270M/1_1B/3B-Instruct) as a custom AutoDeploy model. This is a heterogeneous transformer with:

- Per-layer varying query/KV head counts (GQA)
- Per-layer varying FFN intermediate sizes
- Fused QKV projection with Q/K normalization
- Shared input/output embeddings (no separate lm_head)
- GLU-style FFN (proj_1 = fused gate+up, proj_2 = down)

Uses canonical AD IR ops: torch_rmsnorm, torch_rope_with_explicit_cos_sin, torch_attention.

Config loaded from checkpoint via trust_remote_code=True.

Updated openelm.yaml with attn_backend=flashinfer (trtllm backend produces degenerate output for OpenELM). Works with torch-cudagraph, default batch settings from dashboard_default.yaml. All 3 variants produce coherent generation via build_and_run_ad.py.

Signed-off-by: Lucas Liebenwein <lliebenwein@nvidia.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
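The per-layer heterogeneity listed in the commit message can be illustrated with a minimal sketch. It is not code from this commit: `LayerShape`, `fused_qkv_split`, `glu_split`, and all numeric values are hypothetical, and it assumes the fused QKV output is laid out as Q, then K, then V, and that proj_1 concatenates equally sized gate and up halves.

```python
from dataclasses import dataclass


@dataclass
class LayerShape:
    """Hypothetical per-layer shape record; not a class from the commit."""
    num_query_heads: int
    num_kv_heads: int
    ffn_intermediate: int


def fused_qkv_split(layer: LayerShape, head_dim: int) -> tuple[int, int, int]:
    """Return the (Q, K, V) widths of one layer's fused QKV projection output."""
    q = layer.num_query_heads * head_dim
    kv = layer.num_kv_heads * head_dim  # K and V share the (smaller) KV head count
    return q, kv, kv


def glu_split(layer: LayerShape) -> tuple[int, int]:
    """Return the (gate, up) widths of the fused proj_1 output (assumed equal halves)."""
    return layer.ffn_intermediate, layer.ffn_intermediate


# Illustrative values only, not the real OpenELM configuration.
head_dim = 64
layers = [LayerShape(12, 3, 768), LayerShape(16, 4, 1024), LayerShape(20, 5, 1536)]

for i, layer in enumerate(layers):
    q, k, v = fused_qkv_split(layer, head_dim)
    gate, up = glu_split(layer)
    print(f"layer {i}: qkv_out={q + k + v} (q={q}, k={k}, v={v}), proj_1_out={gate + up}")
```

Because every layer has its own widths, a model of this kind has to size each layer's projections from per-layer config values rather than from a single global head count or intermediate size.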
1 parent fcdea57 commit 36e3724

4 files changed

Lines changed: 874 additions & 2 deletions


Lines changed: 6 additions & 2 deletions
@@ -1,3 +1,7 @@
-# Configuration for Apple OpenELM models
-# These models require Llama-2 tokenizer
+# Configuration for Apple OpenELM models (270M, 1.1B, 3B)
+# These models use the Llama-2 tokenizer (confirmed by Apple's CoreNet docs).
 tokenizer: meta-llama/Llama-2-7b-hf
+
+# Override dashboard_default's attn_backend=trtllm which produces degenerate
+# output for OpenELM. Use flashinfer which works with torch-cudagraph.
+attn_backend: flashinfer
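The override above relies on a model-specific setting taking precedence over the shared default. A minimal sketch of that precedence rule follows, assuming a simple key-by-key merge; `merge_config` is a hypothetical helper, not the AutoDeploy config loader, and the `dashboard_default` dict only contains the one key the commit mentions.

```python
def merge_config(defaults: dict, overrides: dict) -> dict:
    """Return a new dict in which model-specific keys win over shared defaults."""
    merged = dict(defaults)   # copy so the shared defaults stay untouched
    merged.update(overrides)  # model-specific values replace matching keys
    return merged


# Assumed excerpt of the shared defaults; only attn_backend is stated in the commit.
dashboard_default = {"attn_backend": "trtllm"}

# Values taken from the updated OpenELM configuration in this diff.
openelm = {"tokenizer": "meta-llama/Llama-2-7b-hf", "attn_backend": "flashinfer"}

effective = merge_config(dashboard_default, openelm)
print(effective["attn_backend"])
```

Under this assumed merge, other models that do not set `attn_backend` would keep the shared `trtllm` default, while OpenELM alone switches to `flashinfer`.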

tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@
 from .modeling_nemotron_flash import NemotronFlashForCausalLM, NemotronFlashPreTrainedTokenizerFast
 from .modeling_nemotron_h import NemotronHForCausalLM
 from .modeling_olmo3 import Olmo3ForCausalLM
+from .modeling_openelm import OpenELMForCausalLM
 from .modeling_phi4 import Phi4ForCausalLM
 from .modeling_phi4_visionr import Phi4VisionRForConditionalGeneration
 from .modeling_phi4flash import Phi4FlashForCausalLM
@@ -71,6 +72,7 @@
     "NemotronFlashPreTrainedTokenizerFast",
     "NemotronHForCausalLM",
     "Olmo3ForCausalLM",
+    "OpenELMForCausalLM",
     "Phi4ForCausalLM",
     "Phi4FlashForCausalLM",
     "Phi4MMForCausalLM",
