Skip to content

Commit 8e2b7b2

Browse files
authored
[#14672][fix] AutoDeploy: Vendor OpenELMConfig locally to fix OpenELM config loading (#15175)
Signed-off-by: Pedro Lapagesse <pedrolap@umich.edu>
1 parent fb7a1d0 commit 8e2b7b2

3 files changed

Lines changed: 332 additions & 33 deletions

File tree

examples/auto_deploy/model_registry/models.yaml

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,15 @@ models:
1313
- name: Qwen/Qwen3-0.6B
1414
config_id: default_ws_1
1515
yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'enable_sharder_ir.yaml']
16-
# TypeError: OpenELMConfig.__post_init__() got an unexpected keyword argument 'use_cache'. See https://github.com/NVIDIA/TensorRT-LLM/issues/14672
17-
# - name: apple/OpenELM-270M-Instruct
18-
# config_id: openelm
19-
# yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'openelm.yaml']
20-
# TypeError: OpenELMConfig.__post_init__() got an unexpected keyword argument 'use_cache'. See https://github.com/NVIDIA/TensorRT-LLM/issues/14672
21-
# - name: apple/OpenELM-1_1B-Instruct
22-
# config_id: openelm
23-
# yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'openelm.yaml']
24-
# TypeError: OpenELMConfig.__post_init__() got an unexpected keyword argument 'use_cache'. See https://github.com/NVIDIA/TensorRT-LLM/issues/14672
25-
# - name: apple/OpenELM-3B-Instruct
26-
# config_id: openelm
27-
# yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'openelm.yaml']
16+
- name: apple/OpenELM-270M-Instruct
17+
config_id: openelm
18+
yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'openelm.yaml']
19+
- name: apple/OpenELM-1_1B-Instruct
20+
config_id: openelm
21+
yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'openelm.yaml']
22+
- name: apple/OpenELM-3B-Instruct
23+
config_id: openelm
24+
yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'openelm.yaml']
2825
# ImportError: cannot import name 'LossKwargs' from 'transformers.utils' (/usr/local/lib/python3.12/dist-packages/transformers/utils/__init__.py)
2926
# - name: microsoft/Phi-4-mini-instruct
3027
# config_id: default_ws_1

tensorrt_llm/_torch/auto_deploy/models/custom/modeling_openelm.py

Lines changed: 162 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,26 @@
2828
- Q/K normalization before RoPE
2929
- Shared input/output embeddings (no separate lm_head)
3030
31-
Config is loaded from the HF checkpoint via trust_remote_code=True.
31+
The config is bundled locally as ``OpenELMConfig`` (registered with ``AutoConfig``)
32+
rather than loaded from Apple's hub remote code. Apple's ``configuration_openelm.py``
33+
was written for transformers 4.x and its private ``__post_init__(self)`` collides
34+
with the transformers 5.x strict-dataclass ``PreTrainedConfig`` (which forwards
35+
unrecognized kwargs such as ``use_cache`` to ``self.__post_init__(**kwargs)``),
36+
raising ``TypeError`` before any model code runs. Vendoring the config locally
37+
mirrors the pattern used by the other AutoDeploy custom models (e.g. EXAONE) and
38+
removes Apple's frozen remote code from the execution path entirely.
3239
Uses AutoDeploy canonical IR ops for export compatibility.
3340
"""
3441

3542
from dataclasses import dataclass
36-
from typing import Optional, Tuple
43+
from numbers import Number
44+
from typing import List, Optional, Tuple, Union
3745

46+
import numpy as np
3847
import torch
3948
import torch.nn.functional as F
4049
from torch import nn
50+
from transformers import AutoConfig, PretrainedConfig
4151
from transformers.activations import ACT2FN
4252
from transformers.generation import GenerationMixin
4353
from transformers.modeling_utils import PreTrainedModel
@@ -60,6 +70,148 @@ def _make_divisible(v, divisor=8, min_value=None):
6070
return new_v
6171

6272

73+
def _compute_heads(model_dim: int, head_dim: int) -> int:
74+
"""Number of heads given model/head dim (from HF OpenELM config)."""
75+
if model_dim % head_dim != 0:
76+
raise ValueError(
77+
f"Model dimension should be divisible by head dimension. "
78+
f"Got: {model_dim} and {head_dim}."
79+
)
80+
return model_dim // head_dim
81+
82+
83+
# =============================================================================
84+
# Config (vendored locally; see module docstring)
85+
# =============================================================================
86+
87+
88+
class OpenELMConfig(PretrainedConfig):
    """OpenELM configuration, vendored from Apple's ``configuration_openelm.py``.

    Bundled with the custom model implementation so AutoDeploy never executes
    Apple's hub remote code (which is incompatible with transformers 5.x). The
    per-layer derivation that Apple performed in ``__post_init__`` is inlined
    into ``__init__`` here to avoid the transformers 5.x dataclass collision on
    the reserved ``__post_init__`` name.

    Derivation logic is adapted from Apple's OpenELM config
    (Copyright (C) 2024 Apple Inc.; https://huggingface.co/apple/OpenELM-270M-Instruct).

    Args:
        vocab_size: Stored unchanged as ``self.vocab_size``.
        max_context_length: Stored unchanged as ``self.max_context_length``.
        num_transformer_layers: Number of layers; sets the length of every
            derived per-layer list.
        model_dim: Base width that the QKV multipliers scale.
        head_dim: Width of one attention head.
        qkv_multipliers: Either a single number applied to every layer, or a
            two-element list/tuple ``[first, last]`` that is linearly
            interpolated across layers (each value rounded to 2 decimals).
        num_query_heads: Accepted only so that ``config.json`` files carrying a
            precomputed list still load; the value is ignored and
            ``self.num_query_heads`` is always re-derived.
        num_gqa_groups: Divisor used to derive ``self.num_kv_heads`` from
            ``self.num_query_heads``.
        ffn_multipliers: A single number (repeated per layer), a two-element
            list/tuple (interpolated like ``qkv_multipliers``), or a list with
            exactly ``num_transformer_layers`` entries (kept as given).
        ffn_with_glu: Stored unchanged.
        ffn_dim_divisor: Stored unchanged.
        activation_fn_name: Stored unchanged.
        normalization_layer_name: Stored unchanged.
        normalize_qk_projections: Stored unchanged.
        share_input_output_layers: Stored unchanged.
        rope_freq_constant: Stored unchanged.
        rope_max_length: Stored unchanged.
        initializer_range: Stored unchanged.
        use_cache: Forwarded to ``PretrainedConfig.__init__``.
        bos_token_id: Forwarded to ``PretrainedConfig.__init__``.
        eos_token_id: Forwarded to ``PretrainedConfig.__init__``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        NotImplementedError: If ``qkv_multipliers`` is neither a number nor a
            two-element list/tuple, or ``ffn_multipliers`` is neither a number
            nor a list/tuple.
        ValueError: If a derived query dimension is not divisible by
            ``head_dim``, if a per-layer ``ffn_multipliers`` list has the wrong
            length, or if a layer's query-head count is not divisible by its
            KV-head count.
    """

    model_type = "openelm"

    @staticmethod
    def _interpolate_multipliers(
        bounds: Union[List[Number], Tuple[Number, ...]], num_layers: int
    ) -> List[float]:
        """Linearly spread ``bounds[0]..bounds[1]`` over ``num_layers``, rounded to 2 decimals."""
        return [
            round(v, 2)
            for v in np.linspace(bounds[0], bounds[1], num=num_layers, dtype=float)
        ]

    def __init__(
        self,
        vocab_size: int = 32000,
        max_context_length: int = 2048,
        num_transformer_layers: int = 12,
        model_dim: int = 2048,
        head_dim: int = 128,
        qkv_multipliers: Union[Number, List[Number]] = 1.0,
        num_query_heads: Union[int, None] = None,
        num_gqa_groups: int = 1,
        ffn_multipliers: Union[Number, List[Number]] = 4.0,
        ffn_with_glu: bool = True,
        ffn_dim_divisor: int = 256,
        activation_fn_name: str = "swish",
        normalization_layer_name: str = "rms_norm",
        normalize_qk_projections: bool = False,
        share_input_output_layers: bool = False,
        rope_freq_constant: int = 10000,
        rope_max_length: int = 4096,
        initializer_range: float = 0.02,
        use_cache: bool = True,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs,
    ) -> None:
        self.vocab_size = vocab_size
        self.max_context_length = max_context_length
        self.num_transformer_layers = num_transformer_layers
        self.model_dim = model_dim
        self.head_dim = head_dim
        self.qkv_multipliers = qkv_multipliers
        self.num_gqa_groups = num_gqa_groups
        self.ffn_multipliers = ffn_multipliers
        self.ffn_with_glu = ffn_with_glu
        self.ffn_dim_divisor = ffn_dim_divisor
        self.activation_fn_name = activation_fn_name
        self.normalization_layer_name = normalization_layer_name
        self.normalize_qk_projections = normalize_qk_projections
        self.share_input_output_layers = share_input_output_layers
        self.rope_freq_constant = rope_freq_constant
        self.rope_max_length = rope_max_length
        self.initializer_range = initializer_range
        # NOTE: the `num_query_heads` parameter stays accepted for config.json schema
        # fidelity (published checkpoints carry the precomputed list), but the
        # per-layer derivation below is the source of truth, same as Apple's original.

        # --- per-layer derivation (inlined from Apple's __post_init__) ---
        # NOTE(review): the `None` fallback only affects the rounding divisor; the
        # KV-head division further down still needs `num_gqa_groups` to be an int.
        head_multiple_of = self.num_gqa_groups if self.num_gqa_groups is not None else 2

        if isinstance(self.qkv_multipliers, Number):
            # One multiplier shared by every layer.
            qkv_dim = _make_divisible(
                self.model_dim * self.qkv_multipliers,
                divisor=self.head_dim * head_multiple_of,
            )
            query_dims = [int(qkv_dim)] * self.num_transformer_layers
        elif isinstance(self.qkv_multipliers, (tuple, list)) and len(self.qkv_multipliers) == 2:
            # [first, last] pair: interpolate one multiplier per layer.
            per_layer_qkv = self._interpolate_multipliers(
                self.qkv_multipliers, self.num_transformer_layers
            )
            query_dims = [
                int(_make_divisible(self.model_dim * m, divisor=self.head_dim * head_multiple_of))
                for m in per_layer_qkv
            ]
        else:
            raise NotImplementedError(
                f"QKV multipliers should be a single number or a list of exactly two numbers. "
                f"Got: {self.qkv_multipliers}."
            )

        self.num_query_heads = [int(_compute_heads(q_dim, self.head_dim)) for q_dim in query_dims]
        self.num_kv_heads = [q // self.num_gqa_groups for q in self.num_query_heads]

        if isinstance(self.ffn_multipliers, Number):
            self.ffn_multipliers = [self.ffn_multipliers] * self.num_transformer_layers
        elif isinstance(self.ffn_multipliers, (tuple, list)):
            if len(self.ffn_multipliers) == 2:
                self.ffn_multipliers = self._interpolate_multipliers(
                    self.ffn_multipliers, self.num_transformer_layers
                )
            elif len(self.ffn_multipliers) != self.num_transformer_layers:
                # Explicit raise (not `assert`) so the check survives `python -O`.
                raise ValueError(
                    f"Per-layer FFN multipliers must have one entry per transformer layer. "
                    f"Got {len(self.ffn_multipliers)} entries for "
                    f"{self.num_transformer_layers} layers."
                )
        else:
            raise NotImplementedError(
                f"FFN multipliers should be a single number or a list of exactly two numbers. "
                f"Got: {self.ffn_multipliers}."
            )

        # Every layer's query heads must split evenly across its KV heads.
        for layer_idx, (q_heads, kv_heads) in enumerate(
            zip(self.num_query_heads, self.num_kv_heads)
        ):
            if q_heads % kv_heads != 0:
                raise ValueError(
                    f"num_query_heads ({q_heads}) must be divisible by num_kv_heads "
                    f"({kv_heads}) at layer {layer_idx}."
                )

        super().__init__(
            use_cache=use_cache,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
213+
214+
63215
# =============================================================================
64216
# RMSNorm (canonical AD op)
65217
# =============================================================================
@@ -395,5 +547,12 @@ def forward(
395547
return OpenELMCausalLMOutput(logits=logits)
396548

397549

398-
# Register with AutoModelForCausalLMFactory
550+
# Register the vendored config with AutoConfig so AutoConfig.from_pretrained resolves
551+
# `model_type: openelm` to this local class instead of executing Apple's hub remote
552+
# code. Because this class's module is not under `transformers.`, transformers treats
553+
# it as `explicit_local_code` and uses it even when trust_remote_code=True and the
554+
# checkpoint's config.json declares an auto_map.
555+
# NOTE(review): `exist_ok=True` presumably makes this registration tolerant of
# "openelm" already being registered (e.g. module re-import) — confirm against
# the AutoConfig.register documentation for the pinned transformers version.
AutoConfig.register("openelm", OpenELMConfig, exist_ok=True)
556+
557+
# Register with AutoModelForCausalLMFactory (keyed by config class name "OpenELMConfig")
399558
AutoModelForCausalLMFactory.register_custom_model_cls("OpenELMConfig", OpenELMForCausalLM)

0 commit comments

Comments
 (0)