Skip to content

Commit 5de4432

Browse files
committed
Add active-MoE AutoQuant cost accounting
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
1 parent 6bca67e commit 5de4432

8 files changed

Lines changed: 661 additions & 13 deletions

File tree

examples/llm_ptq/hf_ptq.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
104104
"int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
105105
"int8_wo": mtq.INT8_WEIGHT_ONLY_CFG,
106106
"fp8": mtq.FP8_DEFAULT_CFG,
107+
"fp8_w8a8": mtq.FP8_DEFAULT_CFG,
107108
"int4_awq": mtq.INT4_AWQ_CFG,
108109
"w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
109110
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
@@ -350,6 +351,7 @@ def auto_quantize(
350351
qformat
351352
in [
352353
"fp8",
354+
"fp8_w8a8",
353355
"int8_sq",
354356
"int8_wo",
355357
"int4_awq",
@@ -396,9 +398,15 @@ def forward_step(model, batch):
396398
if "parent_class" not in entry and entry["quantizer_name"] != "*lm_head*"
397399
]
398400
enable_linear_attn_big3 = os.environ.get("MODELOPT_AUTOQ_ENABLE_LINEAR_ATTN_BIG3") == "1"
401+
enable_linear_attn_all = os.environ.get("MODELOPT_AUTOQ_ENABLE_LINEAR_ATTN_ALL") == "1"
399402
enable_shared_expert = os.environ.get("MODELOPT_AUTOQ_ENABLE_SHARED_EXPERT") == "1"
403+
if enable_linear_attn_all:
404+
enable_linear_attn_big3 = True
400405
autoq_extra_disabled = [
401406
"*shared_expert_gate*",
407+
# Keep the GDN a/b projections in BF16 even for "all linear_attn"
408+
# searches. Prior healthy NVFP4 controls excluded these small
409+
# projections, while low-end full-search checkpoints quantized them.
402410
"*linear_attn.in_proj_a*",
403411
"*linear_attn.in_proj_b*",
404412
]
@@ -437,6 +445,10 @@ def forward_step(model, batch):
437445
disabled_layers=disabled_layers,
438446
method=auto_quantize_method,
439447
checkpoint=auto_quantize_checkpoint,
448+
cost_model=args.auto_quantize_cost_model,
449+
active_moe_expert_ratio=args.auto_quantize_active_moe_expert_ratio,
450+
cost_lower_bound=args.auto_quantize_cost_lower_bound,
451+
cost_objective=args.auto_quantize_cost_objective,
440452
)
441453

442454
calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
@@ -1454,6 +1466,48 @@ def parse_args() -> argparse.Namespace:
14541466
"(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified."
14551467
),
14561468
)
1469+
parser.add_argument(
1470+
"--auto_quantize_cost_model",
1471+
type=str,
1472+
default="weight",
1473+
choices=["weight", "active_moe"],
1474+
help=(
1475+
"Cost model for auto_quantize effective-bits accounting. 'weight' counts all "
1476+
"quantizable weights equally. 'active_moe' scales routed MoE expert weights by "
1477+
"--auto_quantize_active_moe_expert_ratio, or infers top_k/num_experts from model config."
1478+
),
1479+
)
1480+
parser.add_argument(
1481+
"--auto_quantize_active_moe_expert_ratio",
1482+
type=float,
1483+
default=None,
1484+
help=(
1485+
"Routed MoE expert active ratio for --auto_quantize_cost_model active_moe. "
1486+
"For top-k MoE this is top_k / num_experts. If omitted, common model config "
1487+
"fields such as num_experts_per_tok and num_experts are used when available."
1488+
),
1489+
)
1490+
parser.add_argument(
1491+
"--auto_quantize_cost_lower_bound",
1492+
type=float,
1493+
default=None,
1494+
help=(
1495+
"Optional lower bound, as a fraction of the requested effective-bits budget, "
1496+
"for the auto_quantize LP. Active-MoE cost mode uses a best-effort lower bound "
1497+
"by default when this is omitted."
1498+
),
1499+
)
1500+
parser.add_argument(
1501+
"--auto_quantize_cost_objective",
1502+
type=str,
1503+
default="sensitivity",
1504+
choices=["sensitivity", "active_moe"],
1505+
help=(
1506+
"Objective for auto_quantize LP. 'sensitivity' minimizes quantization sensitivity. "
1507+
"'active_moe' minimizes active routed-MoE cost while the cost model constraint "
1508+
"still controls the requested budget."
1509+
),
1510+
)
14571511
parser.add_argument(
14581512
"--moe_calib_experts_ratio",
14591513
type=float,
@@ -1475,6 +1529,23 @@ def parse_args() -> argparse.Namespace:
14751529
args = parser.parse_args()
14761530
if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0):
14771531
parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].")
1532+
if args.auto_quantize_active_moe_expert_ratio is not None and not (
1533+
0.0 < args.auto_quantize_active_moe_expert_ratio <= 1.0
1534+
):
1535+
parser.error("--auto_quantize_active_moe_expert_ratio must be in the range (0.0, 1.0].")
1536+
if (
1537+
args.auto_quantize_cost_model == "weight"
1538+
and args.auto_quantize_cost_objective != "active_moe"
1539+
and args.auto_quantize_active_moe_expert_ratio is not None
1540+
):
1541+
parser.error(
1542+
"--auto_quantize_active_moe_expert_ratio requires "
1543+
"--auto_quantize_cost_model active_moe or --auto_quantize_cost_objective active_moe."
1544+
)
1545+
if args.auto_quantize_cost_lower_bound is not None and not (
1546+
0.0 < args.auto_quantize_cost_lower_bound <= 1.0
1547+
):
1548+
parser.error("--auto_quantize_cost_lower_bound must be in the range (0.0, 1.0].")
14781549

14791550
if args.specdec_offline_dataset is not None and args.sparsity_fmt != "dense":
14801551
parser.error("--specdec_offline_dataset is only supported with --sparsity_fmt dense (PTQ).")

modelopt/torch/export/layer_utils.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,97 @@ def set_expert_quantizer_amax(
11731173
_GATE_UP_PAIRS = [("gate_proj", "up_proj"), ("w1", "w3")]
11741174

11751175

1176+
# (left_name, right_name) attribute pairs looked up on the same parent module by
# sync_linear_attn_fused_projection_amax when syncing projection scales.
_LINEAR_ATTN_FUSED_PAIRS = [("in_proj_qkv", "in_proj_z"), ("in_proj_b", "in_proj_a")]
1180+
1181+
1182+
def _tensor_values_equal(left: torch.Tensor | None, right: torch.Tensor | None) -> bool:
1183+
if left is None or right is None:
1184+
return left is right
1185+
if left.is_meta or right.is_meta:
1186+
return False
1187+
return torch.equal(left, right)
1188+
1189+
1190+
def _safe_quantizer_amax(quantizer) -> torch.Tensor | None:
1191+
try:
1192+
return getattr(quantizer, "amax", None)
1193+
except AssertionError:
1194+
return None
1195+
1196+
1197+
def _linear_fusion_scales_match(left: nn.Module, right: nn.Module) -> bool:
    """Report whether two projections already share the scales compared here.

    The input quantizer amaxes are compared first (only when both input
    quantizers exist and are enabled). The weight quantizers are then compared
    by the first rule that applies:

    * both are ``SequentialQuantizer``: compare the amax of their last stage,
      when both are non-empty and both last stages are enabled;
    * both expose ``global_amax``: compare those;
    * both are enabled: compare their amax.

    Whenever no comparison applies the pair is reported as matching.
    """

    def _both_enabled(first, second) -> bool:
        return bool(getattr(first, "is_enabled", False) and getattr(second, "is_enabled", False))

    def _amax_equal(first, second) -> bool:
        return _tensor_values_equal(_safe_quantizer_amax(first), _safe_quantizer_amax(second))

    in_left = getattr(left, "input_quantizer", None)
    in_right = getattr(right, "input_quantizer", None)
    if in_left is not None and in_right is not None and _both_enabled(in_left, in_right):
        if not _amax_equal(in_left, in_right):
            return False

    w_left = getattr(left, "weight_quantizer", None)
    w_right = getattr(right, "weight_quantizer", None)
    if w_left is None or w_right is None:
        return True

    if isinstance(w_left, SequentialQuantizer) and isinstance(w_right, SequentialQuantizer):
        if len(w_left) == 0 or len(w_right) == 0:
            return True
        if not _both_enabled(w_left[-1], w_right[-1]):
            return True
        # Only the last stage of each sequential quantizer is compared.
        return _amax_equal(w_left[-1], w_right[-1])

    if hasattr(w_left, "global_amax") and hasattr(w_right, "global_amax"):
        return _tensor_values_equal(w_left.global_amax, w_right.global_amax)

    if _both_enabled(w_left, w_right):
        return _amax_equal(w_left, w_right)

    return True
1233+
1234+
1235+
def sync_linear_attn_fused_projection_amax(model: nn.Module) -> int:
    """Sync quantizer amaxes for GDN projections that serving engines fuse.

    Qwen3.5/Qwen3-Next GDN exports keep ``in_proj_qkv`` and ``in_proj_z`` as
    separate HF tensors, but vLLM fuses them into ``in_proj_qkvz`` at load time.
    Likewise ``in_proj_b`` and ``in_proj_a`` may be fused as ``in_proj_ba``.
    Sharing the quantizer scale domains before export avoids serving-time fused
    loaders having to reconcile different scalar/global scales.

    Every submodule of ``model`` is checked for each name pair in
    ``_LINEAR_ATTN_FUSED_PAIRS``. A pair is processed only when both projections
    exist and report the same quantization format, and that format is neither
    ``None`` nor ``QUANTIZATION_NONE``.

    Returns:
        Number of projection pairs whose scale state changed.
    """
    num_changed = 0
    for _, parent in model.named_modules():
        for first_name, second_name in _LINEAR_ATTN_FUSED_PAIRS:
            first = getattr(parent, first_name, None)
            second = getattr(parent, second_name, None)
            if first is None or second is None:
                continue

            first_format = get_quantization_format(first)
            second_format = get_quantization_format(second)
            unquantized = first_format is None or first_format == QUANTIZATION_NONE
            if unquantized or first_format != second_format:
                continue

            # Record the state before syncing so only real changes are counted;
            # the sync itself runs for every eligible pair.
            already_matched = _linear_fusion_scales_match(first, second)
            preprocess_linear_fusion([first, second])
            num_changed += 0 if already_matched else 1
    return num_changed
1265+
1266+
11761267
def sync_moe_gate_up_amax(model: nn.Module) -> int:
11771268
"""Take element-wise max of gate and up weight quantizer amaxes per expert.
11781269

modelopt/torch/export/unified_export_hf.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
is_moe,
7474
is_quantlinear,
7575
set_expert_quantizer_amax,
76+
sync_linear_attn_fused_projection_amax,
7677
sync_moe_gate_up_amax,
7778
)
7879
from .model_config import (
@@ -810,6 +811,15 @@ def _export_transformers_checkpoint(
810811
f"Taking element-wise max of amaxes for serving-engine fusion."
811812
)
812813

814+
# Safety net for Qwen3.5/Qwen3-Next GDN projections. These remain separate
815+
# HF tensors, but vLLM fuses qkv+z and b+a at load time.
816+
synced = sync_linear_attn_fused_projection_amax(model)
817+
if synced:
818+
warnings.warn(
819+
f"Synced quantizer amax/global_amax for {synced} linear-attention "
820+
f"projection pair(s) that are fused by serving engines."
821+
)
822+
813823
# Process all quantized modules and export weights
814824
_process_quantized_modules(model, dtype, is_modelopt_qlora)
815825

modelopt/torch/opt/searcher.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,16 @@
5757
__all__ = ["BaseSearcher"]
5858

5959

60+
def _get_optional_env_float(name: str) -> float | None:
61+
value = os.environ.get(name)
62+
if not value:
63+
return None
64+
parsed_value = float(value)
65+
if parsed_value <= 0.0:
66+
raise ValueError(f"{name} must be positive, got {parsed_value}.")
67+
return parsed_value
68+
69+
6070
class BaseSearcher(ABC):
6171
"""A basic search interface that can be used to search/optimize a model.
6272
@@ -336,7 +346,14 @@ def __init__(
336346
self.constraints_to_candidate_costs = constraints_to_candidate_costs
337347
self.candidate_scores = candidate_scores
338348
self.objective_type = pulp.LpMinimize if objective_type == "minimize" else pulp.LpMaximize
339-
self.solver = pulp.PULP_CBC_CMD(msg=verbose)
349+
solver_kwargs = {}
350+
cbc_time_limit = _get_optional_env_float("MODELOPT_LPS_CBC_TIME_LIMIT")
351+
cbc_gap_rel = _get_optional_env_float("MODELOPT_LPS_CBC_GAP_REL")
352+
if cbc_time_limit is not None:
353+
solver_kwargs["timeLimit"] = cbc_time_limit
354+
if cbc_gap_rel is not None:
355+
solver_kwargs["gapRel"] = cbc_gap_rel
356+
self.solver = pulp.PULP_CBC_CMD(msg=verbose, **solver_kwargs)
340357

341358
self.num_layers = len(self.candidate_scores)
342359
self.num_candidates_per_layer = list(map(len, self.candidate_scores))

0 commit comments

Comments
 (0)