
Commit daf0144

Support Qwen3VLMoeTextExperts ModuleList pattern in export resmooth/fusion
- Add "qwen3vlmoe" to get_experts_list() model type recognition - Handle per-linear ModuleList expert structure (experts.gate_proj[i]) in addition to standard per-expert structure (experts[i].gate_proj) - Extend expert naming regex in requantize_resmooth_fused_llm_layers to match "experts.gate_proj.0" pattern for uncalibrated expert fusion - Update sync_moe_gate_up_amax to sync gate/up weight quantizer amaxes for ModuleList-pattern experts Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
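For context, the two expert layouts this commit distinguishes look roughly like this (a minimal sketch; StandardExpert, PerLinearExperts, and the dimensions are illustrative stand-ins, not the actual Hugging Face classes):

    import torch.nn as nn

    # Standard per-expert layout (e.g. Qwen3 MoE): index the expert first,
    # then pick the projection: experts[i].gate_proj
    class StandardExpert(nn.Module):
        def __init__(self):
            super().__init__()
            self.gate_proj = nn.Linear(16, 32)
            self.up_proj = nn.Linear(16, 32)
            self.down_proj = nn.Linear(32, 16)

    standard_experts = nn.ModuleList(StandardExpert() for _ in range(4))

    # Per-linear layout (Qwen3VLMoeTextExperts-style): pick the projection
    # first, then index the expert: experts.gate_proj[i]
    class PerLinearExperts(nn.Module):
        def __init__(self, num_experts: int = 4):
            super().__init__()
            self.gate_proj = nn.ModuleList(nn.Linear(16, 32) for _ in range(num_experts))
            self.up_proj = nn.ModuleList(nn.Linear(16, 32) for _ in range(num_experts))
            self.down_proj = nn.ModuleList(nn.Linear(32, 16) for _ in range(num_experts))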

2 files changed: 95 additions & 45 deletions


modelopt/torch/export/layer_utils.py

Lines changed: 68 additions & 28 deletions
@@ -95,19 +95,33 @@ def get_experts_list(module: torch.nn.Module, model_type: str):
             "qwen2moeforcausallm",
             "qwen3moeforcausallm",
             "qwen3nextforcausallm",
+            "qwen3vlmoe",
         ]
     ):
         linear_names = ["gate_proj", "down_proj", "up_proj"]
     else:
         raise NotImplementedError(f" {model_type} not supported")

-    # Common logic for all supported model types
-    experts_list.extend(
-        [
-            [_get_expert_attr(module.experts, i, linear_name) for i in range(len(module.experts))]
-            for linear_name in linear_names
-        ]
-    )
+    # Check if experts use per-linear ModuleList structure (e.g., Qwen3VLMoeTextExperts)
+    # where experts.gate_proj is a ModuleList, instead of experts[i].gate_proj
+    first_linear = linear_names[0]
+    if hasattr(module.experts, first_linear) and isinstance(
+        getattr(module.experts, first_linear), nn.ModuleList
+    ):
+        experts_list.extend(
+            [list(getattr(module.experts, linear_name)) for linear_name in linear_names]
+        )
+    else:
+        # Standard per-expert structure: experts[i].linear_name
+        experts_list.extend(
+            [
+                [
+                    _get_expert_attr(module.experts, i, linear_name)
+                    for i in range(len(module.experts))
+                ]
+                for linear_name in linear_names
+            ]
+        )

     return experts_list
@@ -1150,6 +1164,24 @@ def set_expert_quantizer_amax(
 _GATE_UP_PAIRS = [("gate_proj", "up_proj"), ("w1", "w3")]


+def _sync_gate_up_pair(gate_linear, up_linear) -> bool:
+    """Sync weight quantizer amaxes for a single gate/up pair. Returns True if synced."""
+    gate_wq = getattr(gate_linear, "weight_quantizer", None)
+    up_wq = getattr(up_linear, "weight_quantizer", None)
+    if gate_wq is None or up_wq is None:
+        return False
+    gate_amax = getattr(gate_wq, "amax", None)
+    up_amax = getattr(up_wq, "amax", None)
+    if gate_amax is None or up_amax is None:
+        return False
+    if not torch.equal(gate_amax, up_amax):
+        shared_amax = torch.max(gate_amax, up_amax)
+        gate_wq.amax = shared_amax
+        up_wq.amax = shared_amax.clone()
+        return True
+    return False
+
+
 def sync_moe_gate_up_amax(model: nn.Module) -> int:
     """Take element-wise max of gate and up weight quantizer amaxes per expert.
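To see what the helper does, here is its core sync step exercised on stand-in quantizers (SimpleNamespace only mocks the real quantizer objects, which is an assumption about their type; only the amax attribute matters here):

    import torch
    from types import SimpleNamespace

    gate_wq = SimpleNamespace(amax=torch.tensor([1.0, 4.0]))
    up_wq = SimpleNamespace(amax=torch.tensor([3.0, 2.0]))

    # Element-wise max, as in _sync_gate_up_pair above.
    shared_amax = torch.max(gate_wq.amax, up_wq.amax)  # tensor([3., 4.])
    gate_wq.amax = shared_amax
    up_wq.amax = shared_amax.clone()  # cloned so later in-place edits stay independent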
@@ -1162,35 +1194,43 @@ def sync_moe_gate_up_amax(model: nn.Module) -> int:
     (e.g. Qwen MoE, DeepSeek). Models with already-fused gate_up_proj
     (e.g. Llama4, GptOss) are unaffected.

+    Supports both standard per-expert structure (experts[i].gate_proj) and
+    per-linear ModuleList structure (experts.gate_proj[i], e.g. Qwen3VLMoeTextExperts).
+
     Returns:
         Number of expert gate/up pairs whose amaxes were synced.
     """
     synced = 0
     for _, sub_module in model.named_modules():
         if not (is_moe(sub_module) and hasattr(sub_module, "experts")):
             continue
-        if not hasattr(sub_module.experts, "__iter__"):
-            continue
-        for expert in sub_module.experts:
-            for gate_name, up_name in _GATE_UP_PAIRS:
-                gate_linear = getattr(expert, gate_name, None)
-                up_linear = getattr(expert, up_name, None)
-                if gate_linear is None or up_linear is None:
-                    continue
-                gate_wq = getattr(gate_linear, "weight_quantizer", None)
-                up_wq = getattr(up_linear, "weight_quantizer", None)
-                if gate_wq is None or up_wq is None:
-                    break
-                gate_amax = getattr(gate_wq, "amax", None)
-                up_amax = getattr(up_wq, "amax", None)
-                if gate_amax is None or up_amax is None:
+
+        experts = sub_module.experts
+
+        # Check for per-linear ModuleList structure (e.g., Qwen3VLMoeTextExperts)
+        # where experts.gate_proj is a ModuleList instead of experts[i].gate_proj
+        is_modulelist_pattern = False
+        for gate_name, up_name in _GATE_UP_PAIRS:
+            gate_list = getattr(experts, gate_name, None)
+            up_list = getattr(experts, up_name, None)
+            if isinstance(gate_list, nn.ModuleList) and isinstance(up_list, nn.ModuleList):
+                for gate_linear, up_linear in zip(gate_list, up_list):
+                    if _sync_gate_up_pair(gate_linear, up_linear):
+                        synced += 1
+                is_modulelist_pattern = True
+                break  # Found matching pair pattern, no need to check others
+
+        # Standard per-expert structure: experts[i].gate_proj
+        if not is_modulelist_pattern and hasattr(experts, "__iter__"):
+            for expert in experts:
+                for gate_name, up_name in _GATE_UP_PAIRS:
+                    gate_linear = getattr(expert, gate_name, None)
+                    up_linear = getattr(expert, up_name, None)
+                    if gate_linear is None or up_linear is None:
+                        continue
+                    if _sync_gate_up_pair(gate_linear, up_linear):
+                        synced += 1
                     break
-            if not torch.equal(gate_amax, up_amax):
-                shared_amax = torch.max(gate_amax, up_amax)
-                gate_wq.amax = shared_amax
-                up_wq.amax = shared_amax.clone()
-                synced += 1
-            break
     return synced

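Why the sync matters: when gate_proj and up_proj are later exported as one fused gate_up_proj tensor, both halves must share a single quantization scale. A quick numeric check, assuming an FP8 E4M3 target (max representable magnitude 448):

    import torch

    gate_amax, up_amax = torch.tensor(1.0), torch.tensor(4.0)
    shared_amax = torch.max(gate_amax, up_amax)  # 4.0 covers both halves

    scale = shared_amax / 448.0  # one scale for the fused tensor
    # With the shared scale neither half clips; reusing gate's own
    # scale (1.0 / 448.0) would saturate any up_proj weight above 1.0.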
modelopt/torch/export/unified_export_hf.py

Lines changed: 27 additions & 17 deletions
@@ -396,27 +396,37 @@ def llm_dummy_forward():

     # The dummy forward may not be able to activate all the experts.
     # Process experts by naming rules like experts.0, experts.1, etc.
+    # Also handle ModuleList pattern: experts.gate_proj.0, experts.up_proj.0, etc.
     for name, modules_fused in fused_linears.items():
+        # Determine expert naming pattern:
+        # Standard: "experts.0.gate_proj" → expert index right after "experts."
+        # ModuleList: "experts.gate_proj.0" → expert index after linear name
         if re.search(r"experts?\.\d+", name):
-            expert_id = 0
-            while True:
-                new_expert_name = re.sub(r"(experts?\.)\d+", rf"\g<1>{expert_id}", name, count=1)
-                if new_expert_name in fused_linears:
-                    expert_id += 1
-                    continue
-                if new_expert_name not in module_names:
-                    break
-
-                new_expert_modules = []
-                for name_fused in modules_fused:
-                    new_expert_name = re.sub(r"(experts?\.)\d+", rf"\g<1>{expert_id}", name_fused)
-                    assert new_expert_name in module_names
-                    new_expert_modules.append(model.get_submodule(new_expert_name))
-
-                with fsdp2_aware_weight_update(model, new_expert_modules):
-                    preprocess_linear_fusion(new_expert_modules)
+            pattern = r"(experts?\.)\d+"
+        elif re.search(r"experts?\.[a-zA-Z_]\w*\.\d+", name):
+            pattern = r"(experts?\.[a-zA-Z_]\w*\.)\d+"
+        else:
+            continue

+        expert_id = 0
+        while True:
+            new_expert_name = re.sub(pattern, rf"\g<1>{expert_id}", name, count=1)
+            if new_expert_name in fused_linears:
                 expert_id += 1
+                continue
+            if new_expert_name not in module_names:
+                break
+
+            new_expert_modules = []
+            for name_fused in modules_fused:
+                new_expert_name = re.sub(pattern, rf"\g<1>{expert_id}", name_fused)
+                assert new_expert_name in module_names
+                new_expert_modules.append(model.get_submodule(new_expert_name))
+
+            with fsdp2_aware_weight_update(model, new_expert_modules):
+                preprocess_linear_fusion(new_expert_modules)
+
+            expert_id += 1


 def _export_quantized_weight(
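The two substitution patterns can be checked in isolation. The module names below are illustrative; note that a standard-layout name never matches the ModuleList regex (and vice versa), so the if/elif dispatch above is unambiguous:

    import re

    standard = "model.layers.0.mlp.experts.7.gate_proj"
    modulelist = "model.layers.0.mlp.experts.gate_proj.7"

    # Standard layout: the expert index directly follows "experts."
    print(re.sub(r"(experts?\.)\d+", r"\g<1>0", standard, count=1))
    # model.layers.0.mlp.experts.0.gate_proj

    # ModuleList layout: the expert index follows the projection name
    print(re.sub(r"(experts?\.[a-zA-Z_]\w*\.)\d+", r"\g<1>0", modulelist, count=1))
    # model.layers.0.mlp.experts.gate_proj.0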
