@@ -382,9 +382,18 @@ def forward_step(model, batch):
382382 f"Invalid auto_quantize_method: { auto_quantize_method } . Must be 'gradient' or 'kl_div'"
383383 )
384384
385+ auto_quantize_constraints = {
386+ "effective_bits" : args .auto_quantize_bits ,
387+ "cost_model" : args .auto_quantize_cost_model ,
388+ }
389+ if args .auto_quantize_active_moe_expert_ratio is not None :
390+ auto_quantize_constraints ["cost" ] = {
391+ "active_moe_expert_ratio" : args .auto_quantize_active_moe_expert_ratio
392+ }
393+
385394 language_model , _ = mtq .auto_quantize (
386395 language_model ,
387- constraints = { "effective_bits" : args . auto_quantize_bits } ,
396+ constraints = auto_quantize_constraints ,
388397 data_loader = calib_dataloader ,
389398 forward_step = forward_step ,
390399 loss_func = loss_func , # Only used for gradient-based method
@@ -1401,6 +1410,27 @@ def parse_args() -> argparse.Namespace:
14011410 "(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified."
14021411 ),
14031412 )
1413+ parser .add_argument (
1414+ "--auto_quantize_cost_model" ,
1415+ type = str ,
1416+ default = "weight" ,
1417+ choices = ["weight" , "active_moe" ],
1418+ help = (
1419+ "Cost model for auto_quantize effective-bits accounting. 'weight' counts all "
1420+ "quantizable weights equally. 'active_moe' scales routed MoE expert weights by "
1421+ "--auto_quantize_active_moe_expert_ratio, or infers top_k/num_experts from model config."
1422+ ),
1423+ )
1424+ parser .add_argument (
1425+ "--auto_quantize_active_moe_expert_ratio" ,
1426+ type = float ,
1427+ default = None ,
1428+ help = (
1429+ "Routed MoE expert active ratio for --auto_quantize_cost_model active_moe. "
1430+ "For top-k MoE this is top_k / num_experts. If omitted, common model config "
1431+ "fields such as num_experts_per_tok and num_experts are used when available."
1432+ ),
1433+ )
14041434 parser .add_argument (
14051435 "--moe_calib_experts_ratio" ,
14061436 type = float ,
@@ -1434,6 +1464,18 @@ def parse_args() -> argparse.Namespace:
14341464 args = parser .parse_args ()
14351465 if args .moe_calib_experts_ratio is not None and not (0.0 < args .moe_calib_experts_ratio <= 1.0 ):
14361466 parser .error ("--moe_calib_experts_ratio must be in the range (0.0, 1.0]." )
1467+ if args .auto_quantize_active_moe_expert_ratio is not None and not (
1468+ 0.0 < args .auto_quantize_active_moe_expert_ratio <= 1.0
1469+ ):
1470+ parser .error ("--auto_quantize_active_moe_expert_ratio must be in the range (0.0, 1.0]." )
1471+ if (
1472+ args .auto_quantize_cost_model == "weight"
1473+ and args .auto_quantize_active_moe_expert_ratio is not None
1474+ ):
1475+ parser .error (
1476+ "--auto_quantize_active_moe_expert_ratio requires "
1477+ "--auto_quantize_cost_model active_moe."
1478+ )
14371479
14381480 if args .specdec_offline_dataset is not None and args .sparsity_fmt != "dense" :
14391481 parser .error ("--specdec_offline_dataset is only supported with --sparsity_fmt dense (PTQ)." )
0 commit comments