2828- Q/K normalization before RoPE
2929- Shared input/output embeddings (no separate lm_head)
3030
31- Config is loaded from the HF checkpoint via trust_remote_code=True.
31+ The config is bundled locally as ``OpenELMConfig`` (registered with ``AutoConfig``)
32+ rather than loaded from Apple's hub remote code. Apple's ``configuration_openelm.py``
33+ was written for transformers 4.x and its private ``__post_init__(self)`` collides
34+ with the transformers 5.x strict-dataclass ``PreTrainedConfig`` (which forwards
35+ unrecognized kwargs such as ``use_cache`` to ``self.__post_init__(**kwargs)``),
36+ raising ``TypeError`` before any model code runs. Vendoring the config locally
37+ mirrors the pattern used by the other AutoDeploy custom models (e.g. EXAONE) and
38+ removes Apple's frozen remote code from the execution path entirely.
3239Uses AutoDeploy canonical IR ops for export compatibility.
3340"""
3441
3542from dataclasses import dataclass
36- from typing import Optional , Tuple
43+ from numbers import Number
44+ from typing import List , Optional , Tuple , Union
3745
46+ import numpy as np
3847import torch
3948import torch .nn .functional as F
4049from torch import nn
50+ from transformers import AutoConfig , PretrainedConfig
4151from transformers .activations import ACT2FN
4252from transformers .generation import GenerationMixin
4353from transformers .modeling_utils import PreTrainedModel
@@ -60,6 +70,148 @@ def _make_divisible(v, divisor=8, min_value=None):
6070 return new_v
6171
6272
73+ def _compute_heads (model_dim : int , head_dim : int ) -> int :
74+ """Number of heads given model/head dim (from HF OpenELM config)."""
75+ if model_dim % head_dim != 0 :
76+ raise ValueError (
77+ f"Model dimension should be divisible by head dimension. "
78+ f"Got: { model_dim } and { head_dim } ."
79+ )
80+ return model_dim // head_dim
81+
82+
83+ # =============================================================================
84+ # Config (vendored locally; see module docstring)
85+ # =============================================================================
86+
87+
88+ class OpenELMConfig (PretrainedConfig ):
89+ """OpenELM configuration, vendored from Apple's ``configuration_openelm.py``.
90+
91+ Bundled with the custom model implementation so AutoDeploy never executes
92+ Apple's hub remote code (which is incompatible with transformers 5.x). The
93+ per-layer derivation that Apple performed in ``__post_init__`` is inlined
94+ into ``__init__`` here to avoid the transformers 5.x dataclass collision on
95+ the reserved ``__post_init__`` name.
96+
97+ Derivation logic is adapted from Apple's OpenELM config
98+ (Copyright (C) 2024 Apple Inc.; https://huggingface.co/apple/OpenELM-270M-Instruct).
99+ """
100+
101+ model_type = "openelm"
102+
103+ def __init__ (
104+ self ,
105+ vocab_size : int = 32000 ,
106+ max_context_length : int = 2048 ,
107+ num_transformer_layers : int = 12 ,
108+ model_dim : int = 2048 ,
109+ head_dim : int = 128 ,
110+ qkv_multipliers : Union [Number , List [Number ]] = 1.0 ,
111+ num_query_heads : Union [int , None ] = None ,
112+ num_gqa_groups : int = 1 ,
113+ ffn_multipliers : Union [Number , List [Number ]] = 4.0 ,
114+ ffn_with_glu : bool = True ,
115+ ffn_dim_divisor : int = 256 ,
116+ activation_fn_name : str = "swish" ,
117+ normalization_layer_name : str = "rms_norm" ,
118+ normalize_qk_projections : bool = False ,
119+ share_input_output_layers : bool = False ,
120+ rope_freq_constant : int = 10000 ,
121+ rope_max_length : int = 4096 ,
122+ initializer_range : float = 0.02 ,
123+ use_cache : bool = True ,
124+ bos_token_id : int = 1 ,
125+ eos_token_id : int = 2 ,
126+ ** kwargs ,
127+ ) -> None :
128+ self .vocab_size = vocab_size
129+ self .max_context_length = max_context_length
130+ self .num_transformer_layers = num_transformer_layers
131+ self .model_dim = model_dim
132+ self .head_dim = head_dim
133+ self .qkv_multipliers = qkv_multipliers
134+ self .num_gqa_groups = num_gqa_groups
135+ self .ffn_multipliers = ffn_multipliers
136+ self .ffn_with_glu = ffn_with_glu
137+ self .ffn_dim_divisor = ffn_dim_divisor
138+ self .activation_fn_name = activation_fn_name
139+ self .normalization_layer_name = normalization_layer_name
140+ self .normalize_qk_projections = normalize_qk_projections
141+ self .share_input_output_layers = share_input_output_layers
142+ self .rope_freq_constant = rope_freq_constant
143+ self .rope_max_length = rope_max_length
144+ self .initializer_range = initializer_range
145+ # NOTE: the `num_query_heads` parameter stays accepted for config.json schema
146+ # fidelity (published checkpoints carry the precomputed list), but the
147+ # per-layer derivation below is the source of truth, same as Apple's original.
148+
149+ # --- per-layer derivation (inlined from Apple's __post_init__) ---
150+ head_multiple_of = self .num_gqa_groups if self .num_gqa_groups is not None else 2
151+
152+ if isinstance (self .qkv_multipliers , Number ):
153+ qkv_dim = _make_divisible (
154+ self .model_dim * self .qkv_multipliers ,
155+ divisor = self .head_dim * head_multiple_of ,
156+ )
157+ query_dims = [int (qkv_dim )] * self .num_transformer_layers
158+ elif isinstance (self .qkv_multipliers , (tuple , list )) and len (self .qkv_multipliers ) == 2 :
159+ qkv_multipliers = [
160+ round (v , 2 )
161+ for v in np .linspace (
162+ self .qkv_multipliers [0 ],
163+ self .qkv_multipliers [1 ],
164+ num = self .num_transformer_layers ,
165+ dtype = float ,
166+ )
167+ ]
168+ query_dims = [
169+ int (_make_divisible (self .model_dim * m , divisor = self .head_dim * head_multiple_of ))
170+ for m in qkv_multipliers
171+ ]
172+ else :
173+ raise NotImplementedError (
174+ f"QKV multipliers should be a single number or a list of exactly two numbers. "
175+ f"Got: { self .qkv_multipliers } ."
176+ )
177+
178+ self .num_query_heads = [int (_compute_heads (q_dim , self .head_dim )) for q_dim in query_dims ]
179+ self .num_kv_heads = [q // self .num_gqa_groups for q in self .num_query_heads ]
180+
181+ if isinstance (self .ffn_multipliers , Number ):
182+ self .ffn_multipliers = [self .ffn_multipliers ] * self .num_transformer_layers
183+ elif isinstance (self .ffn_multipliers , (tuple , list )):
184+ if len (self .ffn_multipliers ) == 2 :
185+ self .ffn_multipliers = [
186+ round (v , 2 )
187+ for v in np .linspace (
188+ self .ffn_multipliers [0 ],
189+ self .ffn_multipliers [1 ],
190+ num = self .num_transformer_layers ,
191+ dtype = float ,
192+ )
193+ ]
194+ else :
195+ assert len (self .ffn_multipliers ) == self .num_transformer_layers , (
196+ f"{ len (self .ffn_multipliers )= } !={ self .num_transformer_layers = } "
197+ )
198+ else :
199+ raise NotImplementedError (
200+ f"FFN multipliers should be a single number or a list of exactly two numbers. "
201+ f"Got: { self .ffn_multipliers } ."
202+ )
203+
204+ for layer_idx in range (len (query_dims )):
205+ assert self .num_query_heads [layer_idx ] % self .num_kv_heads [layer_idx ] == 0
206+
207+ super ().__init__ (
208+ use_cache = use_cache ,
209+ bos_token_id = bos_token_id ,
210+ eos_token_id = eos_token_id ,
211+ ** kwargs ,
212+ )
213+
214+
63215# =============================================================================
64216# RMSNorm (canonical AD op)
65217# =============================================================================
@@ -395,5 +547,12 @@ def forward(
395547 return OpenELMCausalLMOutput (logits = logits )
396548
397549
398- # Register with AutoModelForCausalLMFactory
550+ # Register the vendored config with AutoConfig so AutoConfig.from_pretrained resolves
551+ # `model_type: openelm` to this local class instead of executing Apple's hub remote
552+ # code. Because this class's module is not under `transformers.`, transformers treats
553+ # it as `explicit_local_code` and uses it even when trust_remote_code=True and the
554+ # checkpoint's config.json declares an auto_map.
555+ AutoConfig .register ("openelm" , OpenELMConfig , exist_ok = True )
556+
557+ # Register with AutoModelForCausalLMFactory (keyed by config class name "OpenELMConfig")
399558AutoModelForCausalLMFactory .register_custom_model_cls ("OpenELMConfig" , OpenELMForCausalLM )
0 commit comments