
Commit 05c6d3b

Add MoE/Nemotron fixes to support Transformers 5.5
Tested with both transformers 4.57 and 5.5.

## Root cause

transformers 5.5 natively supports NemotronHForCausalLM (with `model.` prefix), but all puzzletron checkpoints use the trust_remote_code class (with `backbone.` prefix). Additionally, the native NemotronHConfig does not recognize the `-` pattern character used by NemotronH v2 for MLP layers.

## Fixes

**trust_remote_code model class selection (4 places)**

For trust_remote_code models, always force `AutoModelForCausalLM.from_config(trust_remote_code=True)` instead of the native concrete class, which has a different module structure (`backbone.` vs `model.` prefix). Applied in:
- `sharded_checkpoint_utils.py` create_sharded_model
- `init_child_from_parent.py` (fixes KeyError on backbone.layers.N.mixer.experts keys)
- `checkpoint_utils_hf.py` init_model_from_config (fixes AttributeError in calc_subblock_params_and_memory)
- `tests/_test_utils/torch/puzzletron/utils.py` create_and_save_small_hf_model

**NemotronH embedding key name (singular vs plural)**

`nemotron_h_model_descriptor.py` layer_name_predicates: make the trailing `s` optional (`backbone\.embeddings?\.weight`) to match both the on-disk singular form (`backbone.embedding.weight`) produced by transformers 5.5 revert_weight_conversion and the in-memory plural form.

**Test checkpoint save format**

`utils.py` create_and_save_small_hf_model:
- Use `save_pretrained(save_original_format=False)` to skip transformers 5.5 revert_weight_conversion, which would rename backbone.embeddings.weight -> backbone.embedding.weight and cause load_and_shard_model key mismatches.
- Handle AttributeError from `_tied_weights_keys` being a list (trust_remote_code) vs a dict (transformers v5 expectation) by clearing it and retrying.
- Add a `config.moe_latent_size = None` guard for native NemotronH config access.
- Download trust_remote_code .py files via snapshot_download for models with auto_map, since save_pretrained does not copy them.

**NemotronH v2 tokenizer loading**

`validate_model.py` prepare_dataloader: auto-detect trust_remote_code from the descriptor (args.descriptor is always set in puzzletron configs) when it is not explicitly configured. Fixes NemotronH v2, where the native NemotronHConfig._pattern_to_list only handles {M, E, *} but v2 uses `-` for MLP layers.

**Qwen3VL / transformers 5.x expert hook**

`expert_removal_hooks.py`:
- The gate returns a (logits, aux_loss) tuple in transformers 5.x; unpack it.
- Use hidden_states.shape[-1] instead of self.moe.hidden_size (removed in v5).
- Version-branch the experts call: transformers 5.x uses the grouped_mm signature (hidden_flat, top_k_index, top_k_weights) vs the 4.x loop-based one (hidden_3d, routing_weights_full, router_indices).

**GPT-OSS attention_type**

`gpt_oss_model_descriptor.py`: guard the access with `hasattr(layer, "attention_type")`, since the attribute was removed in transformers v5.4.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 7053c61 commit 05c6d3b

11 files changed

Lines changed: 134 additions & 30 deletions


Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 lm-eval==0.4.10
 math-verify
 ray
+# Likely works for transformers v5 also, but we need to test it
 transformers<5.0

modelopt/torch/prune/importance_hooks/expert_removal_hooks.py

Lines changed: 26 additions & 11 deletions
@@ -20,6 +20,8 @@
 from typing import TYPE_CHECKING

 import torch
+import transformers
+from packaging.version import Version
 from torch import nn

 from .base_hooks import ForwardHook
@@ -359,27 +361,40 @@ def get_router_logits_and_routed_experts(
         Based on Qwen3VLMoeSparseMoe forward pass.
         """
         orig_shape = hidden_states.shape
+        # Use hidden_states.shape[-1] instead of self.moe.hidden_size for transformers v5 compatibility
+        hidden_size = (
+            self.moe.hidden_size if hasattr(self.moe, "hidden_size") else hidden_states.shape[-1]
+        )

         # Flatten to (num_tokens, hidden_size) for processing
-        hidden_states_flat = hidden_states.reshape(-1, self.moe.hidden_size)
+        hidden_states_flat = hidden_states.reshape(-1, hidden_size)

         if router_logits is None:
             router_logits = self.moe.gate(hidden_states_flat)
+            # In transformers v5 the gate returns a (logits, aux_loss) tuple
+            if isinstance(router_logits, tuple):
+                router_logits = router_logits[0]

         routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
-        routing_weights, router_indices = torch.topk(routing_weights, self.moe.top_k, dim=-1)
+        routing_weights, router_indices = torch.topk(
+            routing_weights, self.num_experts_per_tok, dim=-1
+        )
         routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
         routing_weights = routing_weights.to(hidden_states_flat.dtype)
-        router_weights = torch.zeros_like(router_logits).scatter_(
-            1, router_indices, routing_weights
-        )
-
-        # Reshape hidden_states for moe.experts (expects 3D: batch, seq, hidden)
-        # router_weights and router_indices remain 2D (num_tokens, num_experts)
-        batch_size = orig_shape[0] if hidden_states.ndim == 3 else 1
-        hidden_states_3d = hidden_states_flat.reshape(batch_size, -1, self.moe.hidden_size)

-        routed_out = self.moe.experts(hidden_states_3d, router_weights, router_indices)
+        if Version(transformers.__version__) >= Version("5.0"):
+            # transformers 5.x: grouped_mm_experts_forward expects
+            # (hidden_states_flat 2D, top_k_index, top_k_weights)
+            routed_out = self.moe.experts(hidden_states_flat, router_indices, routing_weights)
+        else:
+            # transformers 4.x: loop-based experts expects
+            # (hidden_states_3d 3D, routing_weights_full, router_indices)
+            batch_size = orig_shape[0] if hidden_states.ndim == 3 else 1
+            hidden_states_3d = hidden_states_flat.reshape(batch_size, -1, hidden_size)
+            router_weights = torch.zeros(
+                router_logits.shape, dtype=routing_weights.dtype, device=router_logits.device
+            ).scatter_(1, router_indices, routing_weights)
+            routed_out = self.moe.experts(hidden_states_3d, router_weights, router_indices)

         # Return in same shape as input
         routed_out = routed_out.reshape(*orig_shape)
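
For reference, a minimal sketch (not part of the commit) of the routing math above, showing the tensor shapes each transformers branch hands to the experts module; all sizes here are made up:

```python
import torch

num_tokens, hidden_size, num_experts, top_k = 4, 8, 3, 2
hidden_flat = torch.randn(num_tokens, hidden_size)
router_logits = torch.randn(num_tokens, num_experts)

routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
routing_weights, router_indices = torch.topk(routing_weights, top_k, dim=-1)
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

# transformers 5.x path: experts consume the 2D tokens plus (indices, weights) directly.
print(hidden_flat.shape, router_indices.shape, routing_weights.shape)
# torch.Size([4, 8]) torch.Size([4, 2]) torch.Size([4, 2])

# transformers 4.x path: scatter the top-k weights back to a dense (num_tokens, num_experts)
# matrix and reshape the hidden states to 3D (batch, seq, hidden).
router_weights_full = torch.zeros(num_tokens, num_experts).scatter_(1, router_indices, routing_weights)
hidden_3d = hidden_flat.reshape(1, -1, hidden_size)
print(router_weights_full.shape, hidden_3d.shape)
# torch.Size([4, 3]) torch.Size([1, 4, 8])
```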

modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py

Lines changed: 3 additions & 2 deletions
@@ -54,8 +54,9 @@ class GptOssModelDescriptor(ModelDescriptor):
     @classmethod
     def create_dummy_block(cls, original_layer: GptOssDecoderLayer, block_index: int) -> nn.Module:
         dummy_block = DummyBlock(block_index=block_index)
-        # Required by `GptOssModel.forward`.
-        dummy_block.attention_type = original_layer.attention_type
+        # Required by `GptOssModel.forward` in transformers<5.4
+        if hasattr(original_layer, "attention_type"):
+            dummy_block.attention_type = original_layer.attention_type
         return dummy_block

     @staticmethod

modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py

Lines changed: 1 addition & 1 deletion
@@ -200,7 +200,7 @@ def get_weight_groups(
     def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]:
         layer_name_patterns = {
             "embeddings": re.compile(
-                r"^(model\.embed_tokens\.weight|backbone\.embeddings\.weight)$"
+                r"^(model\.embed_tokens\.weight|backbone\.embeddings?\.weight)$"
             ),
             "lm_head": re.compile(r"^(lm_head\.weight|backbone\.norm_f\.weight)$"),
         }
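
As a quick illustration (not part of the commit), the relaxed pattern accepts both key spellings while still anchoring on the full parameter name:

```python
import re

pattern = re.compile(r"^(model\.embed_tokens\.weight|backbone\.embeddings?\.weight)$")
assert pattern.match("backbone.embeddings.weight")  # in-memory plural form
assert pattern.match("backbone.embedding.weight")   # on-disk singular form (transformers 5.5)
assert not pattern.match("backbone.embeddings.weight.extra")  # still fully anchored
```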

modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py

Lines changed: 11 additions & 5 deletions
@@ -39,7 +39,11 @@
     update_model_config,
 )
 from modelopt.torch.puzzletron.tools.checkpoint_utils import copy_tokenizer, load_state_dict
-from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import _save_checkpoint, load_model_config
+from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import (
+    _get_auto_class_for_trust_remote_code,
+    _save_checkpoint,
+    load_model_config,
+)
 from modelopt.torch.puzzletron.tools.logger import mprint
 from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import _get_model_class_from_config

@@ -126,12 +130,14 @@ def init_child_from_parent(
         model_descriptor=descriptor, block_configs=child_model_config.block_configs
     ):
         model_class = _get_model_class_from_config(child_model_config)
-        # AutoModelForCausalLM uses from_config(); concrete model classes use _from_config()
-        if model_class is AutoModelForCausalLM:
-            trust_remote_code = descriptor.requires_trust_remote_code()
-            child_model = model_class.from_config(
+        trust_remote_code = descriptor.requires_trust_remote_code()
+        if trust_remote_code:
+            auto_cls = _get_auto_class_for_trust_remote_code(child_model_config)
+            child_model = auto_cls.from_config(
                 child_model_config, trust_remote_code=trust_remote_code
             )
+        elif model_class is AutoModelForCausalLM:
+            child_model = AutoModelForCausalLM.from_config(child_model_config)
         else:
             child_model = model_class._from_config(child_model_config)

modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py

Lines changed: 33 additions & 3 deletions
@@ -133,6 +133,33 @@ def _get_model_class_from_config(config: PretrainedConfig) -> type:
     return AutoModelForCausalLM


+def _get_auto_class_for_trust_remote_code(config: PretrainedConfig) -> type:
+    """Pick the right Auto class for a trust_remote_code model by inspecting auto_map.
+
+    When a model requires trust_remote_code, the native transformers class resolved from
+    config.architectures must NOT be used directly — it may have a different module structure
+    than the trust_remote_code class (e.g. NemotronH: native uses ``model.`` prefix, but the
+    trust_remote_code class uses ``backbone.`` prefix, causing key mismatches throughout the
+    pipeline). Instead, we route through the appropriate Auto class so that from_config()
+    resolves the class via auto_map, picking up the correct trust_remote_code implementation.
+
+    Models declare which Auto class they support via config.auto_map. We walk a priority list
+    so that CausalLM models and VL models (AutoModelForConditionalGeneration or similar) are
+    both handled correctly.
+    """
+    auto_map = getattr(config, "auto_map", {})
+    priority = [
+        "AutoModelForCausalLM",
+        "AutoModelForConditionalGeneration",
+        "AutoModelForImageTextToText",
+        "AutoModel",
+    ]
+    for name in priority:
+        if name in auto_map and hasattr(transformers, name):
+            return getattr(transformers, name)
+    return AutoModelForCausalLM
+
+
 def init_model_from_config(
     config: PretrainedConfig,
     *,
@@ -145,10 +172,13 @@ def init_model_from_config(
         Pass True when loading configs that rely on custom modeling code from the checkpoint.
     """
     model_class = _get_model_class_from_config(config)
+    if trust_remote_code:
+        auto_cls = _get_auto_class_for_trust_remote_code(config)
+        return auto_cls.from_config(config, trust_remote_code=trust_remote_code, **kwargs)
     if model_class is AutoModelForCausalLM:
-        return model_class.from_config(config, trust_remote_code=trust_remote_code, **kwargs)
-    # Concrete model classes (e.g. GptOssForCausalLM): _from_config forwards kwargs to __init__,
-    # which does not accept trust_remote_code (only AutoModel uses it when loading custom code).
+        return AutoModelForCausalLM.from_config(config, **kwargs)
+    # Concrete model classes (e.g. GptOssForCausalLM, Qwen3VLMoeForConditionalGeneration):
+    # _from_config forwards kwargs to __init__, which does not accept trust_remote_code.
     return model_class._from_config(config, **kwargs)

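
For context, a minimal sketch (not from the commit) of how the auto_map priority walk resolves an Auto class; the SimpleNamespace config here is a stand-in for a real PretrainedConfig:

```python
import transformers
from types import SimpleNamespace

# Hypothetical config whose auto_map declares a trust_remote_code CausalLM class.
config = SimpleNamespace(auto_map={"AutoConfig": "cfg.MyConfig", "AutoModelForCausalLM": "mod.MyModel"})

priority = ["AutoModelForCausalLM", "AutoModelForConditionalGeneration", "AutoModelForImageTextToText", "AutoModel"]
for name in priority:
    if name in config.auto_map and hasattr(transformers, name):
        print(getattr(transformers, name))  # -> transformers AutoModelForCausalLM class
        break
```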

modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py

Lines changed: 9 additions & 6 deletions
@@ -43,6 +43,9 @@

 import modelopt.torch.utils.distributed as dist
 from modelopt.torch.puzzletron.tools.checkpoint_utils import load_model_config, load_state_dict
+from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import (
+    _get_auto_class_for_trust_remote_code,
+)
 from modelopt.torch.puzzletron.tools.logger import mprint
 from modelopt.torch.puzzletron.utils.dummy_modules import (
     DummyBlock,
@@ -172,8 +175,6 @@ def load_and_shard_model(
         device=runtime.device,
     )

-    new_names = set(shard_state_dict.keys())
-    mprint(f"{new_names=}")
     # strict=False: allows missing lm_head.weight when tie_word_embeddings=True (e.g., Llama 3.2 3B)
     model_shard.load_state_dict(shard_state_dict, strict=False, assign=True)

@@ -239,10 +240,12 @@ def create_sharded_model(
     with EmptyInitOnDevice(device="meta", dtype=dtype):
         # Get model class from config.architectures (works for CausalLM, VL models, etc.)
         model_class = _get_model_class_from_config(model_config)
-        # AutoModelForCausalLM uses from_config(); concrete model classes use _from_config()
-        if model_class is AutoModelForCausalLM:
-            trust_remote_code = descriptor.requires_trust_remote_code()
-            model = model_class.from_config(model_config, trust_remote_code=trust_remote_code)
+        trust_remote_code = descriptor.requires_trust_remote_code()
+        if trust_remote_code:
+            auto_cls = _get_auto_class_for_trust_remote_code(model_config)
+            model = auto_cls.from_config(model_config, trust_remote_code=trust_remote_code)
+        elif model_class is AutoModelForCausalLM:
+            model = AutoModelForCausalLM.from_config(model_config)
         else:
             model = model_class._from_config(model_config)
     create_local_shard_(

modelopt/torch/puzzletron/tools/validate_model.py

Lines changed: 11 additions & 1 deletion
@@ -235,9 +235,19 @@ def prepare_dataloader(
     if tokenizer is None:
         tokenizer_name = getattr(args, "tokenizer_name", None)
         assert (tokenizer_name is not None) or (args.model_name_or_path is not None)
+        # Auto-detect trust_remote_code from the descriptor when not explicitly set.
+        # Required for models like NemotronH v2 whose configs use characters (e.g. '-') that
+        # the native transformers NemotronHConfig._pattern_to_list doesn't support.
+        trust_remote_code = getattr(args, "trust_remote_code", False)
+        if not trust_remote_code and getattr(args, "descriptor", None):
+            try:
+                descriptor_cls = ModelDescriptorFactory.get(args.descriptor)
+                trust_remote_code = descriptor_cls.requires_trust_remote_code()
+            except Exception:
+                pass
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name or args.model_name_or_path,
-            trust_remote_code=getattr(args, "trust_remote_code", False),
+            trust_remote_code=trust_remote_code,
         )

     val_dataloader = create_validation_dataloader(

tests/_test_utils/torch/puzzletron/utils.py

Lines changed: 31 additions & 1 deletion
@@ -19,6 +19,7 @@
 import torch
 from _test_utils.torch.transformers_models import get_tiny_tokenizer
 from datasets import Dataset, DatasetDict
+from huggingface_hub import snapshot_download
 from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedTokenizerBase

 import modelopt.torch.utils.distributed as dist
@@ -135,6 +136,11 @@ def create_and_save_small_hf_model(
 ):
     config.pad_token_id = 0

+    # Ensure moe_latent_size is present: the native transformers NemotronH model (>=5.5)
+    # accesses config.moe_latent_size but older trust_remote_code configs don't define it.
+    if not hasattr(config, "moe_latent_size"):
+        config.moe_latent_size = None
+
     # Set seed for reproducible weight initialization
     torch.manual_seed(42)

@@ -167,14 +173,38 @@ def create_and_save_small_hf_model(
     else:
         os.environ.pop("CUDA_VISIBLE_DEVICES", None)

-    model.to(dtype=torch.bfloat16).save_pretrained(output_path)
+    model.to(dtype=torch.bfloat16)
+    # save_original_format=False: skip transformers' revert_weight_conversion so weights are saved
+    # with in-memory key names (e.g. backbone.embeddings.weight) rather than the on-disk "original"
+    # format (e.g. backbone.embedding.weight for NemotronH). This avoids key mismatches in
+    # load_and_shard_model which looks up shard keys from model.named_parameters().
+    try:
+        model.save_pretrained(output_path, save_original_format=False)
+    except AttributeError:
+        # Workaround: some trust_remote_code models define _tied_weights_keys in an older
+        # format (returning a list) that is incompatible with transformers v5, which
+        # expects _get_tied_weight_keys to return a dict. Clear tied weight keys and retry.
+        for submodule in model.modules():
+            if getattr(submodule, "_tied_weights_keys", None) is not None:
+                submodule._tied_weights_keys = None
+        model.save_pretrained(output_path, save_original_format=False)

     # Save tokenizer
     tokenizer.save_pretrained(output_path)

     # Save config
     config.save_pretrained(output_path)

+    # Download trust_remote_code .py files from HF hub into the checkpoint directory so that
+    # force_cache_dynamic_modules can resolve classes from the local path.
+    # save_pretrained only saves weights + config, not these .py files.
+    if hasattr(config, "auto_map") and isinstance(config.auto_map, dict):
+        snapshot_download(
+            repo_id=hf_model_name,
+            local_dir=output_path,
+            allow_patterns=["*.py"],
+        )
+

 def save_dummy_dataset(dataset_path: Path | str):
     """

tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/gpt-oss-20b.yaml

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,7 @@ scoring:

   eval_samples: 2
   micro_batch_size: 1
+  block_size: 512 # Toy model has max_position_embeddings=512; attention is O(batch*heads*seq^2)
   dataset_path: ${dataset_path}/valid
   seed: 42
   shuffle_seed: 444
@@ -97,6 +98,7 @@ realize_model:
   skip_validation: false # To enable validation of the model solution set `skip_validation` as False
   eval_samples: 2
   micro_batch_size: 1
+  block_size: 512 # Toy model has max_position_embeddings=512; attention is O(batch*heads*seq^2)
   dataset_path: ${dataset_path}/valid
   seed: 42
   shuffle_seed: 444
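
For context, a rough back-of-envelope (not part of the commit; the 8-head count is a made-up toy value) of why capping block_size at the toy model's max_position_embeddings keeps the attention-score term small:

```python
# Attention scores scale as batch * heads * seq^2 elements per layer.
batch, heads = 1, 8
for seq in (512, 8192):
    print(seq, batch * heads * seq * seq)
# 512 2097152
# 8192 536870912
```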

0 commit comments
