@@ -98,6 +98,8 @@ def get_experts_list(module: torch.nn.Module, model_type: str):
9898 ]
9999 ):
100100 linear_names = ["gate_proj" , "down_proj" , "up_proj" ]
101+ elif "deepseek" in model_type :
102+ linear_names = ["gate_proj" , "down_proj" , "up_proj" ]
101103 else :
102104 raise NotImplementedError (f" { model_type } not supported" )
103105
@@ -150,6 +152,33 @@ def check_model_compatibility(module_list: list[nn.Module]) -> tuple[bool, bool,
150152
151153def get_transformer_layers (model : nn .Module ) -> list [nn .Module ]:
152154 """Returns the root module of the transformer model."""
155+ if "Megatron" in type (model ).__name__ :
156+ if hasattr (model , "model" ) and "GPTModel" in type (model .model ).__name__ :
157+ # NEMO mcore models can be handled with the following branch.
158+ model = model .model
159+
160+ # NEMO non mcore models, we need to find the language_model module first.
161+ children = [model ]
162+ language_model = None
163+ while children and not language_model :
164+ next_children = []
165+ for child in children :
166+ if type (child ).__name__ == "TransformerLanguageModel" :
167+ language_model = child
168+ break
169+ next_children .extend (list (child .children ()))
170+ children = next_children
171+ if language_model :
172+ warn ("Warning: this is an old NEMO checkpoint format and will be deprecated soon." )
173+ layers = list (language_model .embedding .children ()) + list (
174+ language_model .encoder .children ()
175+ )
176+
177+ if hasattr (language_model , "output_layer" ):
178+ layers .append (language_model .output_layer )
179+
180+ return layers
181+
153182 if "GPTModel" in type (model ).__name__ :
154183 # mcore models
155184 layers = []
@@ -298,14 +327,20 @@ def is_mlp(module: nn.Module) -> bool:
298327 return any (key in type (module ).__name__ .upper () for key in ("MLP" , "T5DENSE" ))
299328
300329
330+ def _is_deepseek_moe_name (module_name : str ) -> bool :
331+ return "deepseek" in module_name and "moe" in module_name
332+
333+
def is_moe(module: nn.Module) -> bool:
    """Returns whether the module is an MOE layer.

    Detection order:
      1. Common class-name patterns ("...SparseMoeBlock", "...MoELayer...").
      2. DeepSeek-style MoE blocks, detected structurally by the presence of
         both a ``gate`` (router) and an ``experts`` submodule, since their
         class names do not follow the common patterns above.
      3. An explicit allow-list for non-standard names.
    """
    name = type(module).__name__.lower()
    # Auto-detect common MoE patterns
    if name.endswith("sparsemoeblock") or "moelayer" in name:
        return True
    if _is_deepseek_moe_name(name) and hasattr(module, "gate") and hasattr(module, "experts"):
        return True
    # Explicit matches for non-standard naming.
    # NOTE: keys must not contain whitespace — they are matched as substrings
    # of the lowercased class name, and a trailing space can never match.
    return any(key in name for key in ["arcticmoe", "dbrxffn", "gptossmoe"])
309344
310345
311346def is_quantlinear (module : nn .Module ) -> bool :
@@ -358,7 +393,7 @@ def build_qkv(
358393 num_kv_heads = ext_config .num_kv_heads
359394
360395 if "ColumnParallelLinear" in type (qkv_module ).__name__ :
361- # For Megatron-core model, num_kv_heads/num_attention_heads is the first dimension of QKV
396+ # For NEMO model, num_kv_heads/num_attention_heads is the first dimension of QKV
362397 model_metadata_config ["head_is_first_dim" ] = True
363398
364399 qkv_weight = qkv_module .weight
@@ -965,14 +1000,17 @@ def module_match_name_list(module, name_list):
9651000 """
9661001 return any (name .lower () in type (module ).__name__ .lower () for name in name_list )
9671002
968- if module_match_name_list (
1003+ module_name = type (module ).__name__ .lower ()
1004+
1005+ if _is_deepseek_moe_name (module_name ):
1006+ return ["gate_proj" , "down_proj" , "up_proj" ]
1007+ elif module_match_name_list (
9691008 module ,
9701009 [
9711010 "Qwen2MoeSparseMoeBlock" ,
9721011 "Qwen3MoeSparseMoeBlock" ,
9731012 "Qwen3NextSparseMoeBlock" ,
9741013 "Qwen3_5MoeSparseMoeBlock" ,
975- "DeepseekMoE" ,
9761014 ],
9771015 ):
9781016 return ["gate_proj" , "down_proj" , "up_proj" ]
@@ -1455,7 +1493,7 @@ def _set_layer_config_from_metaconfig(layer_config, metaconfig):
14551493 if k in metaconfig :
14561494 setattr (layer_config , name , metaconfig [k ])
14571495
1458- # MCore use "rope" as an alias for "rope_gpt_neox"
1496+ # MCore / NeMo use "rope" as an alias for "rope_gpt_neox"
14591497 if layer_config .position_embedding_type == "rope" :
14601498 layer_config .position_embedding_type = "rope_gpt_neox"
14611499
0 commit comments