@@ -95,19 +95,33 @@ def get_experts_list(module: torch.nn.Module, model_type: str):
             "qwen2moeforcausallm",
             "qwen3moeforcausallm",
             "qwen3nextforcausallm",
+            "qwen3vlmoe",
         ]
     ):
         linear_names = ["gate_proj", "down_proj", "up_proj"]
     else:
         raise NotImplementedError(f"{model_type} not supported")
 
-    # Common logic for all supported model types
-    experts_list.extend(
-        [
-            [_get_expert_attr(module.experts, i, linear_name) for i in range(len(module.experts))]
-            for linear_name in linear_names
-        ]
-    )
+    # Check if experts use per-linear ModuleList structure (e.g., Qwen3VLMoeTextExperts)
+    # where experts.gate_proj is a ModuleList, instead of experts[i].gate_proj
+    first_linear = linear_names[0]
+    if hasattr(module.experts, first_linear) and isinstance(
+        getattr(module.experts, first_linear), nn.ModuleList
+    ):
+        experts_list.extend(
+            [list(getattr(module.experts, linear_name)) for linear_name in linear_names]
+        )
+    else:
+        # Standard per-expert structure: experts[i].linear_name
+        experts_list.extend(
+            [
+                [
+                    _get_expert_attr(module.experts, i, linear_name)
+                    for i in range(len(module.experts))
+                ]
+                for linear_name in linear_names
+            ]
+        )
 
     return experts_list
 
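For reviewers, a minimal sketch of the two expert-container layouts this branch distinguishes. The toy classes below are illustrative stand-ins, not the real Qwen3VLMoe modules; only the attribute structure matters.

```python
import torch.nn as nn

NUM_EXPERTS, HIDDEN, INTER = 4, 8, 16


class PerExpertMLP(nn.Module):
    """Standard layout: each expert owns its projections -> experts[i].gate_proj."""

    def __init__(self):
        super().__init__()
        self.gate_proj = nn.Linear(HIDDEN, INTER, bias=False)
        self.up_proj = nn.Linear(HIDDEN, INTER, bias=False)
        self.down_proj = nn.Linear(INTER, HIDDEN, bias=False)


class PerLinearExperts(nn.Module):
    """Per-linear layout: one module whose projections are ModuleLists -> experts.gate_proj[i]."""

    def __init__(self):
        super().__init__()
        self.gate_proj = nn.ModuleList(nn.Linear(HIDDEN, INTER, bias=False) for _ in range(NUM_EXPERTS))
        self.up_proj = nn.ModuleList(nn.Linear(HIDDEN, INTER, bias=False) for _ in range(NUM_EXPERTS))
        self.down_proj = nn.ModuleList(nn.Linear(INTER, HIDDEN, bias=False) for _ in range(NUM_EXPERTS))


standard_experts = nn.ModuleList(PerExpertMLP() for _ in range(NUM_EXPERTS))
per_linear_experts = PerLinearExperts()

# The new branch keys off exactly this check: experts.<first_linear> being an
# nn.ModuleList signals the per-linear layout; otherwise the per-expert layout
# is assumed and experts is indexed directly.
assert isinstance(getattr(per_linear_experts, "gate_proj", None), nn.ModuleList)
assert not isinstance(getattr(standard_experts, "gate_proj", None), nn.ModuleList)
```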
@@ -1150,6 +1164,24 @@ def set_expert_quantizer_amax(
 _GATE_UP_PAIRS = [("gate_proj", "up_proj"), ("w1", "w3")]
 
 
+def _sync_gate_up_pair(gate_linear, up_linear) -> bool:
+    """Sync weight quantizer amaxes for a single gate/up pair. Returns True if synced."""
+    gate_wq = getattr(gate_linear, "weight_quantizer", None)
+    up_wq = getattr(up_linear, "weight_quantizer", None)
+    if gate_wq is None or up_wq is None:
+        return False
+    gate_amax = getattr(gate_wq, "amax", None)
+    up_amax = getattr(up_wq, "amax", None)
+    if gate_amax is None or up_amax is None:
+        return False
+    if not torch.equal(gate_amax, up_amax):
+        shared_amax = torch.max(gate_amax, up_amax)
+        gate_wq.amax = shared_amax
+        up_wq.amax = shared_amax.clone()
+        return True
+    return False
+
+
 def sync_moe_gate_up_amax(model: nn.Module) -> int:
     """Take element-wise max of gate and up weight quantizer amaxes per expert.
 
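A toy walk-through of the helper's behavior on one gate/up pair. FakeQuantizer is a stand-in that only mimics the amax attribute of a real weight quantizer, and the snippet assumes _sync_gate_up_pair from the hunk above is in scope:

```python
import torch
import torch.nn as nn


class FakeQuantizer:
    """Stand-in object; only the amax attribute is relevant to the helper."""

    def __init__(self, amax):
        self.amax = amax


gate = nn.Linear(8, 16, bias=False)
up = nn.Linear(8, 16, bias=False)
gate.weight_quantizer = FakeQuantizer(torch.tensor([0.8, 1.2]))
up.weight_quantizer = FakeQuantizer(torch.tensor([1.0, 0.9]))

# The amaxes differ, so the helper writes the element-wise max back to both
# quantizers and reports that a sync happened.
assert _sync_gate_up_pair(gate, up)
assert torch.equal(gate.weight_quantizer.amax, torch.tensor([1.0, 1.2]))
assert torch.equal(up.weight_quantizer.amax, torch.tensor([1.0, 1.2]))

# A second call is a no-op because the amaxes are now identical.
assert not _sync_gate_up_pair(gate, up)
```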
@@ -1162,35 +1194,43 @@ def sync_moe_gate_up_amax(model: nn.Module) -> int:
     (e.g. Qwen MoE, DeepSeek). Models with already-fused gate_up_proj
     (e.g. Llama4, GptOss) are unaffected.
 
+    Supports both standard per-expert structure (experts[i].gate_proj) and
+    per-linear ModuleList structure (experts.gate_proj[i], e.g. Qwen3VLMoeTextExperts).
+
     Returns:
         Number of expert gate/up pairs whose amaxes were synced.
     """
     synced = 0
     for _, sub_module in model.named_modules():
         if not (is_moe(sub_module) and hasattr(sub_module, "experts")):
             continue
-        if not hasattr(sub_module.experts, "__iter__"):
-            continue
-        for expert in sub_module.experts:
-            for gate_name, up_name in _GATE_UP_PAIRS:
-                gate_linear = getattr(expert, gate_name, None)
-                up_linear = getattr(expert, up_name, None)
-                if gate_linear is None or up_linear is None:
-                    continue
-                gate_wq = getattr(gate_linear, "weight_quantizer", None)
-                up_wq = getattr(up_linear, "weight_quantizer", None)
-                if gate_wq is None or up_wq is None:
-                    break
-                gate_amax = getattr(gate_wq, "amax", None)
-                up_amax = getattr(up_wq, "amax", None)
-                if gate_amax is None or up_amax is None:
+
+        experts = sub_module.experts
+
+        # Check for per-linear ModuleList structure (e.g., Qwen3VLMoeTextExperts)
+        # where experts.gate_proj is a ModuleList instead of experts[i].gate_proj
+        is_modulelist_pattern = False
+        for gate_name, up_name in _GATE_UP_PAIRS:
+            gate_list = getattr(experts, gate_name, None)
+            up_list = getattr(experts, up_name, None)
+            if isinstance(gate_list, nn.ModuleList) and isinstance(up_list, nn.ModuleList):
+                for gate_linear, up_linear in zip(gate_list, up_list):
+                    if _sync_gate_up_pair(gate_linear, up_linear):
+                        synced += 1
+                is_modulelist_pattern = True
+                break  # Found matching pair pattern, no need to check others
+
+        # Standard per-expert structure: experts[i].gate_proj
+        if not is_modulelist_pattern and hasattr(experts, "__iter__"):
+            for expert in experts:
+                for gate_name, up_name in _GATE_UP_PAIRS:
+                    gate_linear = getattr(expert, gate_name, None)
+                    up_linear = getattr(expert, up_name, None)
+                    if gate_linear is None or up_linear is None:
+                        continue
+                    if _sync_gate_up_pair(gate_linear, up_linear):
+                        synced += 1
                     break
-                if not torch.equal(gate_amax, up_amax):
-                    shared_amax = torch.max(gate_amax, up_amax)
-                    gate_wq.amax = shared_amax
-                    up_wq.amax = shared_amax.clone()
-                    synced += 1
-                break
     return synced
 
 
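For context, a rough numeric sketch (hypothetical amax values, simple per-tensor int8 scaling) of why gate and up must agree on amax once the two projections are fused into a single gate_up_proj at export time:

```python
import torch

# Hypothetical calibrated amax values for one expert's gate_proj and up_proj weights.
gate_amax = torch.tensor(0.8)
up_amax = torch.tensor(1.2)

# A fused gate_up_proj weight carries a single scale, so the only choice that
# avoids clipping either half is the element-wise max of the two amaxes; the
# half with the smaller dynamic range merely loses a little resolution.
shared_amax = torch.max(gate_amax, up_amax)
for name, amax in [("gate", gate_amax), ("up", up_amax), ("shared", shared_amax)]:
    print(f"{name} int8 scale: {amax.item() / 127.0:.6f}")
```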