Skip to content

Commit 800c4f2

Browse files
lucasliebmarimuthu-nv
authored and committed
[None][feat] Add AD custom model for InternLM3 family (#222)
* [None][feat] Add AD custom model for InternLM3 family Add a lean prefill-only custom model implementation for the InternLM3 architecture (GQA + SwiGLU MLP + RMSNorm + dynamic NTK-scaled RoPE) using AutoDeploy canonical ops (torch_attention, torch_rmsnorm, torch_rope_with_explicit_cos_sin). Includes hierarchical equivalence tests (block, layer, full model, export) and bundles a minimal InternLM3Config since the model is not natively in transformers. Signed-off-by: Lucas Liebenwein <lliebenwein@nvidia.com> Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> * [None][feat] Address review: remove bundled config, document inline refs Remove the bundled InternLM3Config from the modeling file. The AD pipeline loads the config from the HF checkpoint via trust_remote_code=True (same pattern as DeciLM). The test file now loads InternLM3Config dynamically from the HF cache. Inline HF reference classes are kept because the HF modeling_internlm3.py cannot be imported on the installed transformers version (requires LossKwargs from transformers >=4.48). Signed-off-by: Lucas Liebenwein <lliebenwein@nvidia.com> Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --------- Signed-off-by: Lucas Liebenwein <lliebenwein@nvidia.com> Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
1 parent 39f8f89 commit 800c4f2

3 files changed

Lines changed: 963 additions & 0 deletions

File tree

tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .modeling_granite_moe_hybrid import GraniteMoeHybridForCausalLM
1111
from .modeling_hunyuan_dense import HunYuanDenseForCausalLM
1212
from .modeling_hunyuan_moe import HunYuanMoEForCausalLM
13+
from .modeling_internlm3 import InternLM3ForCausalLM
1314
from .modeling_kimi_k2 import KimiK2ForCausalLM, KimiK25ForConditionalGeneration
1415
from .modeling_llama3 import Llama3ForCausalLM
1516
from .modeling_mistral import MistralForCausalLM
@@ -39,6 +40,7 @@
3940
"GraniteMoeHybridForCausalLM",
4041
"HunYuanDenseForCausalLM",
4142
"HunYuanMoEForCausalLM",
43+
"InternLM3ForCausalLM",
4244
"KimiK2ForCausalLM",
4345
"KimiK25ForConditionalGeneration",
4446
"Llama3ForCausalLM",
Lines changed: 359 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,359 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Slimmed down PyTorch InternLM3 model implementation for auto_deploy export.
17+
18+
Source:
19+
https://huggingface.co/internlm/internlm3-8b-instruct
20+
21+
This implementation differs from the original HuggingFace version in the following ways:
22+
* Simplified for prefill-only inference (no KV caching)
23+
* Uses auto_deploy custom ops for export compatibility
24+
* Removed flash attention variants (uses torch_attention custom op)
25+
* Removed gradient checkpointing and training code paths
26+
* Removed attention dropout (inference only)
27+
* No repeat_kv — AD attention ops handle GQA natively
28+
29+
Config is loaded from the HF checkpoint via trust_remote_code=True (not bundled here).
30+
31+
The InternLM3 model uses GQA with SwiGLU MLP, RMSNorm, and dynamic NTK-scaled RoPE.
32+
"""
33+
34+
from dataclasses import dataclass
35+
from typing import Optional, Tuple
36+
37+
import torch
38+
from torch import nn
39+
from transformers.activations import ACT2FN
40+
from transformers.generation import GenerationMixin
41+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
42+
from transformers.modeling_utils import PreTrainedModel
43+
from transformers.utils import ModelOutput
44+
45+
from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
46+
47+
48+
class InternLM3RMSNorm(nn.Module):
    """RMS normalization delegating to the AutoDeploy ``torch_rmsnorm`` reference op.

    Args:
        hidden_size: Length of the learnable scale vector ``weight`` (initialized to ones).
        eps: Epsilon passed through to the op; kept on the module as ``variance_epsilon``.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` normalized by the ``auto_deploy::torch_rmsnorm`` op."""
        rmsnorm_op = torch.ops.auto_deploy.torch_rmsnorm
        return rmsnorm_op(hidden_states, self.weight, self.variance_epsilon)
60+
61+
62+
class InternLM3RotaryEmbedding(nn.Module):
    """Precomputed rotary position embedding tables for InternLM3.

    The rope variant is picked from ``config.rope_scaling`` (key ``"rope_type"``,
    then ``"type"``, otherwise ``"default"``) and resolved through the
    transformers ``ROPE_INIT_FUNCTIONS`` table. cos/sin tables covering
    ``config.max_position_embeddings`` positions are built once at construction
    time; ``forward`` only casts them and gathers the rows for ``position_ids``,
    so the result can be shared by all layers.

    Buffer names use the ``_ad_`` prefix to work with AutoDeploy's lift_to_meta.
    """

    def __init__(self, config):
        super().__init__()
        scaling_cfg = getattr(config, "rope_scaling", None)
        if isinstance(scaling_cfg, dict):
            rope_type = scaling_cfg.get("rope_type", scaling_cfg.get("type", "default"))
        else:
            rope_type = "default"

        rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(config, device=None)

        # One row of angles per position; the frequency half is duplicated along the last dim.
        positions = torch.arange(config.max_position_embeddings, dtype=inv_freq.dtype)
        angles = torch.outer(positions, inv_freq)
        table = torch.cat((angles, angles), dim=-1)

        # Non-persistent: the tables are recomputed from the config, not stored in checkpoints.
        cos_table = table.cos() * self.attention_scaling
        sin_table = table.sin() * self.attention_scaling
        self.register_buffer("_ad_cos_cached", cos_table, persistent=False)
        self.register_buffer("_ad_sin_cached", sin_table, persistent=False)

    def forward(
        self, x: torch.Tensor, position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)`` rows selected by ``position_ids``.

        Args:
            x: Only its dtype and device are used, as the cast target for the tables.
            position_ids: Index tensor used to gather rows from the cached tables.
        """
        target = {"dtype": x.dtype, "device": x.device}
        cos_table = self._ad_cos_cached.to(**target)
        sin_table = self._ad_sin_cached.to(**target)
        return cos_table[position_ids], sin_table[position_ids]
96+
97+
98+
class InternLM3MLP(nn.Module):
    """Gated feed-forward block for InternLM3 (SwiGLU-style).

    Computes ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``. The activation is
    looked up in ``ACT2FN`` by ``config.hidden_act`` and all three projections
    share the ``config.bias`` setting.
    """

    def __init__(self, config):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner
        self.gate_proj = nn.Linear(hidden, inner, bias=config.bias)
        self.up_proj = nn.Linear(hidden, inner, bias=config.bias)
        self.down_proj = nn.Linear(inner, hidden, bias=config.bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``x`` and project back to ``hidden_size``."""
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
112+
113+
114+
class InternLM3Attention(nn.Module):
    """Grouped Query Attention for InternLM3.

    Uses AD canonical ops for RoPE (``torch_rope_with_explicit_cos_sin``) and
    attention (``torch_attention``). K/V keep ``num_key_value_heads`` heads; GQA
    is handled natively by torch_attention — no repeat_kv needed.

    Args:
        config: Model config providing ``hidden_size``, ``num_attention_heads``,
            ``num_key_value_heads``, ``qkv_bias`` and ``bias``. ``head_dim`` is
            optional: when it is missing or ``None`` it falls back to
            ``hidden_size // num_attention_heads``.
        layer_idx: Index of the owning decoder layer; stored but not used here.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads

        # Tolerate configs that omit head_dim (or set it to None) instead of
        # failing at construction time.
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.num_heads
        self.head_dim = head_dim
        self.scaling = self.head_dim ** (-0.5)

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=config.qkv_bias
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias
        )
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Run causal self-attention over ``hidden_states``.

        Args:
            hidden_states: 3-D input unpacked as ``[batch, seq, hidden]``.
            position_embeddings: ``(cos, sin)`` pair already indexed by position_ids.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        bsz, q_len, _ = hidden_states.size()

        # Project Q/K/V and reshape to [B, S, N, head_dim] (BSND layout)
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim)
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim)

        # Get pre-sliced cos/sin from position_embeddings (already indexed by position_ids)
        cos, sin = position_embeddings  # [B, S, head_dim]

        # Apply RoPE using custom op (BSND layout, unsqueeze_dim=2)
        q, k = torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin(
            q,
            k,
            cos,
            sin,
            2,  # unsqueeze_dim=2 for BSND layout
        )

        # Attention using custom op with GQA support (BSND layout)
        attn_output = torch.ops.auto_deploy.torch_attention(
            q,  # [B, S, N, head_dim]
            k,  # [B, S, N_kv, head_dim]
            v,  # [B, S, N_kv, head_dim]
            None,  # attn_mask
            0.0,  # dropout_p
            True,  # is_causal
            self.scaling,  # scale
            None,  # sinks
            None,  # sliding_window
            None,  # logit_cap
            "bsnd",  # layout
        )

        # Reshape [B, S, N, head_dim] -> [B, S, N * head_dim] and project
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)

        return attn_output
187+
188+
189+
class InternLM3DecoderLayer(nn.Module):
    """Single InternLM3 decoder block.

    Two residual sub-blocks, each normalized before its sub-module: self-attention
    (after ``input_layernorm``), then the MLP (after ``post_attention_layernorm``).
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = InternLM3Attention(config, layer_idx=layer_idx)
        self.mlp = InternLM3MLP(config)

        norm_eps = config.rms_norm_eps
        self.input_layernorm = InternLM3RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = InternLM3RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Apply attention and MLP, each with a residual connection.

        Args:
            hidden_states: Input activations of the layer.
            position_embeddings: ``(cos, sin)`` pair forwarded unchanged to the attention module.
        """
        # Attention sub-block
        attn_out = self.self_attn(self.input_layernorm(hidden_states), position_embeddings)
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out
221+
222+
223+
@dataclass
class InternLM3Output(ModelOutput):
    """Result container returned by ``InternLM3Model.forward``."""

    # Hidden states produced after the model's final norm.
    last_hidden_state: Optional[torch.FloatTensor] = None
228+
229+
230+
@dataclass
class InternLM3CausalLMOutput(ModelOutput):
    """Result container returned by ``InternLM3ForCausalLM.forward``."""

    # Output of the language-modeling head, cast to float32 by the producer.
    logits: Optional[torch.FloatTensor] = None
235+
236+
237+
class InternLM3PreTrainedModel(PreTrainedModel):
    """Common ``PreTrainedModel`` base for the InternLM3 classes.

    Declares the base-model prefix, the module type that is listed as
    non-splittable, and the weight initialization used for linear and embedding layers.
    """

    base_model_prefix = "model"
    _no_split_modules = ["InternLM3DecoderLayer"]
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialize ``module`` in place.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from a zero-mean normal
        with std ``config.initializer_range``; linear biases and the embedding
        padding row (when present) are zeroed. Other module types are left untouched.
        """
        std = self.config.initializer_range
        is_linear = isinstance(module, nn.Linear)
        if not is_linear and not isinstance(module, nn.Embedding):
            return

        module.weight.data.normal_(mean=0.0, std=std)
        if is_linear:
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
254+
255+
256+
class InternLM3Model(InternLM3PreTrainedModel):
    """InternLM3 decoder stack: token embedding, decoder layers, final RMSNorm.

    A single rotary embedding module lives at model level; its output is
    computed once per forward pass and handed to every decoder layer.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                InternLM3DecoderLayer(config, layer_idx=layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = InternLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Rotary tables are shared by all layers, so the module is owned here.
        self.rotary_emb = InternLM3RotaryEmbedding(config)

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module with ``value``."""
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor,
        position_ids: torch.LongTensor,
        **kwargs,
    ) -> InternLM3Output:
        """Embed ``input_ids`` and run them through all decoder layers.

        Args:
            input_ids: Token ids fed to the embedding table.
            position_ids: Positions used to select the rotary cos/sin rows; required.
            **kwargs: Accepted for call compatibility; not used by this method.

        Returns:
            ``InternLM3Output`` whose ``last_hidden_state`` is the output of the final norm.
        """
        assert position_ids is not None, "position_ids must be provided for AD export"

        # Cast to compute dtype for FP8 models
        compute_dtype = self.norm.weight.dtype
        hidden_states = self.embed_tokens(input_ids).to(compute_dtype)

        # Position embeddings are computed once and reused by every layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for layer in self.layers:
            hidden_states = layer(hidden_states, position_embeddings)

        return InternLM3Output(last_hidden_state=self.norm(hidden_states))
309+
310+
311+
class InternLM3ForCausalLM(InternLM3PreTrainedModel, GenerationMixin):
    """InternLM3 decoder (``self.model``) topped with a bias-free linear LM head."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.model = InternLM3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module with ``value``."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head with ``new_embeddings``."""
        self.lm_head = new_embeddings

    def get_decoder(self):
        """Return the wrapped ``InternLM3Model``."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor,
        position_ids: torch.LongTensor,
        **kwargs,
    ) -> InternLM3CausalLMOutput:
        """Compute float32 logits for ``input_ids``.

        Args:
            input_ids: Token ids forwarded to the decoder.
            position_ids: Positions forwarded to the decoder; required.
            **kwargs: Forwarded unchanged to ``self.model``.

        Returns:
            ``InternLM3CausalLMOutput`` holding the LM-head output cast with ``.float()``.
        """
        assert position_ids is not None, "position_ids must be provided for AD export"

        decoder_out = self.model(input_ids=input_ids, position_ids=position_ids, **kwargs)
        logits = self.lm_head(decoder_out.last_hidden_state).float()
        return InternLM3CausalLMOutput(logits=logits)
356+
357+
358+
# Register with AutoModelForCausalLMFactory
359+
# Import-time side effect: associates the key "InternLM3Config" with
# InternLM3ForCausalLM in the factory's custom-model registry.
# NOTE(review): presumably the key is matched against the class name of the
# config loaded from the checkpoint — confirm against AutoModelForCausalLMFactory.
AutoModelForCausalLMFactory.register_custom_model_cls("InternLM3Config", InternLM3ForCausalLM)

0 commit comments

Comments
 (0)