Skip to content

Commit b721f1d

Browse files
committed
Add active-MoE AutoQuant cost accounting
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
1 parent 4e34480 commit b721f1d

4 files changed

Lines changed: 329 additions & 11 deletions

File tree

examples/llm_ptq/hf_ptq.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,8 @@ def forward_step(model, batch):
404404
],
405405
method=auto_quantize_method,
406406
checkpoint=auto_quantize_checkpoint,
407+
cost_model=args.auto_quantize_cost_model,
408+
active_moe_expert_ratio=args.auto_quantize_active_moe_expert_ratio,
407409
)
408410

409411
calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
@@ -1401,6 +1403,27 @@ def parse_args() -> argparse.Namespace:
14011403
"(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified."
14021404
),
14031405
)
1406+
parser.add_argument(
1407+
"--auto_quantize_cost_model",
1408+
type=str,
1409+
default="weight",
1410+
choices=["weight", "active_moe"],
1411+
help=(
1412+
"Cost model for auto_quantize effective-bits accounting. 'weight' counts all "
1413+
"quantizable weights equally. 'active_moe' scales routed MoE expert weights by "
1414+
"--auto_quantize_active_moe_expert_ratio, or infers top_k/num_experts from model config."
1415+
),
1416+
)
1417+
parser.add_argument(
1418+
"--auto_quantize_active_moe_expert_ratio",
1419+
type=float,
1420+
default=None,
1421+
help=(
1422+
"Routed MoE expert active ratio for --auto_quantize_cost_model active_moe. "
1423+
"For top-k MoE this is top_k / num_experts. If omitted, common model config "
1424+
"fields such as num_experts_per_tok and num_experts are used when available."
1425+
),
1426+
)
14041427
parser.add_argument(
14051428
"--moe_calib_experts_ratio",
14061429
type=float,
@@ -1434,6 +1457,18 @@ def parse_args() -> argparse.Namespace:
14341457
args = parser.parse_args()
14351458
if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0):
14361459
parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].")
1460+
if args.auto_quantize_active_moe_expert_ratio is not None and not (
1461+
0.0 < args.auto_quantize_active_moe_expert_ratio <= 1.0
1462+
):
1463+
parser.error("--auto_quantize_active_moe_expert_ratio must be in the range (0.0, 1.0].")
1464+
if (
1465+
args.auto_quantize_cost_model == "weight"
1466+
and args.auto_quantize_active_moe_expert_ratio is not None
1467+
):
1468+
parser.error(
1469+
"--auto_quantize_active_moe_expert_ratio requires "
1470+
"--auto_quantize_cost_model active_moe."
1471+
)
14371472

14381473
if args.specdec_offline_dataset is not None and args.sparsity_fmt != "dense":
14391474
parser.error("--specdec_offline_dataset is only supported with --sparsity_fmt dense (PTQ).")

modelopt/torch/quantization/algorithms.py

Lines changed: 119 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,24 @@
4545
from .nn import QuantLinearConvBase, QuantModule, SequentialQuantizer, TensorQuantizer
4646
from .utils import is_quantized_linear
4747

48+
_ROUTED_MOE_EXPERT_NAME_RE = re.compile(r"(^|\.)experts(\.|$)")
49+
50+
51+
def _is_routed_moe_module_name(name: str) -> bool:
52+
"""Return True for routed MoE expert modules, excluding shared experts."""
53+
return "shared_expert" not in name and _ROUTED_MOE_EXPERT_NAME_RE.search(name) is not None
54+
55+
56+
def _get_active_moe_cost_weight(
57+
module_names: Sequence[str], active_moe_expert_ratio: float | None
58+
) -> float:
59+
"""Return cost multiplier for the active-MoE cost model."""
60+
if active_moe_expert_ratio is None:
61+
return 1.0
62+
if any(_is_routed_moe_module_name(n) for n in module_names):
63+
return active_moe_expert_ratio
64+
return 1.0
65+
4866

4967
def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float:
5068
"""Estimate the compression ratio of a quantization configuration.
@@ -204,13 +222,16 @@ def __init__(
204222
score_modules: list[nn.Module] | None = None,
205223
name: str | None = None,
206224
quant_module_names: list[str] | None = None,
225+
cost_weight: float = 1.0,
207226
) -> None:
208227
"""Initializes Hparam with original value and choices."""
209228
choices = sorted({*(choices if choices else []), QuantRecipe(quant_cfg=None)})
210229
super().__init__(choices, original=choices[0])
211230

212231
self.name = name
213232
self.quant_module_names = quant_module_names or []
233+
assert cost_weight > 0.0, "cost_weight must be positive."
234+
self.cost_weight = cost_weight
214235

215236
self.quant_modules = list(set(quant_modules or []))
216237
self.score_modules = list(set(score_modules or self.quant_modules))
@@ -303,15 +324,18 @@ def get_score(self, recipe: QuantRecipe) -> float:
303324
total_score += importance.item()
304325
return total_score
305326

306-
def get_cost(self, recipe: QuantRecipe) -> float:
327+
def get_cost(self, recipe: QuantRecipe, cost_weight: float | None = None) -> float:
307328
"""Get the cost for a given recipe.
308329
309330
The cost is the total weight size of the quantizable modules multiplied by
310331
the compression ratio of the recipe.
311332
"""
333+
cost_weight = self.cost_weight if cost_weight is None else cost_weight
312334
cost = 0
313335
for quant_module in self.quant_modules:
314-
weight_size = _AutoQuantizeBaseSearcher._get_total_weight_size([quant_module])
336+
weight_size = (
337+
_AutoQuantizeBaseSearcher._get_total_weight_size([quant_module]) * cost_weight
338+
)
315339
parallel_state = getattr(quant_module, "parallel_state", None)
316340

317341
if parallel_state is None:
@@ -341,7 +365,7 @@ def get_cost(self, recipe: QuantRecipe) -> float:
341365
@property
342366
def attrs(self) -> list[str]:
343367
"""Return the attributes of the hparam for repr."""
344-
return ["name", *super().attrs]
368+
return ["name", "cost_weight", *super().attrs]
345369

346370

347371
class _AutoQuantizeBaseSearcher(BaseSearcher, ABC):
@@ -381,13 +405,18 @@ def default_search_config(self):
381405
"disabled_layers": None,
382406
"verbose": is_master(),
383407
"checkpoint": None,
408+
"cost_model": "weight",
409+
"active_moe_expert_ratio": None,
384410
}
385411

386412
@property
387413
def default_state_dict(self) -> SearchStateDict:
388414
"""Get the default state dict for AutoQuantize."""
389415
return {
390416
"method": self.method_name,
417+
"cost_model": "weight",
418+
"active_moe_expert_ratio": None,
419+
"cost_denominator": None,
391420
"candidate_stats": defaultdict(dict),
392421
"quantizer_states": {},
393422
"best": {"recipe": {}, "constraints": {}, "score": float("inf"), "is_satisfied": False},
@@ -403,6 +432,18 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
403432
assert config["forward_step"] is not None, (
404433
"`forward_step` must be provided for `auto_quantize`."
405434
)
435+
if config["cost_model"] not in ("weight", "active_moe"):
436+
raise ValueError(
437+
f"Invalid cost_model: {config['cost_model']}. "
438+
"Valid options are 'weight' and 'active_moe'."
439+
)
440+
active_moe_expert_ratio = config["active_moe_expert_ratio"]
441+
if active_moe_expert_ratio is not None and not (0.0 < active_moe_expert_ratio <= 1.0):
442+
raise ValueError("active_moe_expert_ratio must be in the range (0.0, 1.0].")
443+
if config["cost_model"] == "active_moe" and active_moe_expert_ratio is None:
444+
raise ValueError(
445+
"active_moe_expert_ratio must be set when using active_moe cost accounting."
446+
)
406447
return config
407448

408449
def load_search_checkpoint(self) -> bool:
@@ -490,7 +531,9 @@ def _get_score_module_from_name(
490531
)
491532
return quant_module
492533

493-
def insert_hparams_after_merge_rules(self, model, quant_recipes, disabled_layers=None):
534+
def insert_hparams_after_merge_rules(
535+
self, model, quant_recipes, disabled_layers=None, active_moe_expert_ratio=None
536+
):
494537
"""Restrict the search space using the merge rules and insert the hparams for the model."""
495538
# TRTLLM fuses linear layers such as q_proj, k_proj, v_proj into same layer
496539
# Hence we need to restrict the search space so that all these layers share the same recipe
@@ -545,14 +588,17 @@ def insert_hparams_after_merge_rules(self, model, quant_recipes, disabled_layers
545588
quant_modules = [module for module, _, _, _ in module_info_list]
546589
disabled = any(disabled for _, _, disabled, _ in module_info_list)
547590
score_modules = [score_module for _, _, _, score_module in module_info_list]
591+
quant_module_names = [name for _, name, _, _ in module_info_list]
592+
cost_weight = _get_active_moe_cost_weight(quant_module_names, active_moe_expert_ratio)
548593

549594
_quant_recipes = None if disabled else quant_recipes
550595
hparam = QuantRecipeHparam(
551596
_quant_recipes,
552597
quant_modules=quant_modules,
553598
score_modules=score_modules,
554599
name=str(group_key),
555-
quant_module_names=[name for _, name, _, _ in module_info_list],
600+
quant_module_names=quant_module_names,
601+
cost_weight=cost_weight,
556602
)
557603

558604
for module in quant_modules:
@@ -584,23 +630,30 @@ def initialize_candidate_stats(self):
584630
if not isinstance(hparam, QuantRecipeHparam):
585631
continue
586632

587-
formats, scores, costs = [], [], []
633+
formats, scores, costs, active_costs = [], [], [], []
588634
prev_score = float("inf")
635+
constraint_cost_weight = (
636+
hparam.cost_weight if self.config["cost_model"] == "active_moe" else 1.0
637+
)
589638
for recipe in hparam.choices:
590639
formats.append(recipe)
591640

592641
score = hparam.get_score(recipe) # type: ignore [arg-type]
593-
cost = hparam.get_cost(recipe) # type: ignore [arg-type]
642+
cost = hparam.get_cost(recipe, cost_weight=constraint_cost_weight) # type: ignore [arg-type]
643+
active_cost = hparam.get_cost(recipe, cost_weight=hparam.cost_weight) # type: ignore [arg-type]
594644

595645
score = min(score, prev_score) # TODO: Should we get rid of this?
596646
scores.append(score)
597647
costs.append(cost)
648+
active_costs.append(active_cost)
598649
prev_score = score
599650

600651
self.candidate_stats[name]["formats"] = formats
601652
self.candidate_stats[name]["scores"] = scores
602653
self.candidate_stats[name]["costs"] = costs
654+
self.candidate_stats[name]["active_costs"] = active_costs
603655
self.candidate_stats[name]["module_names"] = hparam.quant_module_names
656+
self.candidate_stats[name]["cost_weight"] = hparam.cost_weight
604657

605658
def _run_func(self, func, num_iters=1, desc=""):
606659
for i, data in tqdm(
@@ -625,12 +678,30 @@ def before_search(self):
625678
f"Checkpoint method '{restored_method}' does not match current method "
626679
f"'{self.method_name}'. Use a different checkpoint path."
627680
)
681+
restored_cost_model = getattr(self, "cost_model", "weight")
682+
restored_active_moe_expert_ratio = getattr(self, "active_moe_expert_ratio", None)
683+
if self.candidate_stats and (
684+
restored_cost_model != self.config["cost_model"]
685+
or restored_active_moe_expert_ratio != self.config["active_moe_expert_ratio"]
686+
):
687+
raise ValueError(
688+
"Checkpoint AutoQuantize cost model does not match current search config: "
689+
f"checkpoint=({restored_cost_model}, {restored_active_moe_expert_ratio}), "
690+
f"current=({self.config['cost_model']}, {self.config['active_moe_expert_ratio']}). "
691+
"Use a different checkpoint path."
692+
)
628693
self.method = self.method_name
694+
self.cost_model = self.config["cost_model"]
695+
self.active_moe_expert_ratio = self.config["active_moe_expert_ratio"]
696+
self.cost_denominator = getattr(self, "cost_denominator", None)
629697

630698
search_recipes = self._get_search_recipes(self.config["quantization_formats"])
631699
self._verify_constraint(search_recipes)
632700
self.insert_hparams_after_merge_rules(
633-
self.model, search_recipes, self.config["disabled_layers"]
701+
self.model,
702+
search_recipes,
703+
self.config["disabled_layers"],
704+
self.config["active_moe_expert_ratio"],
634705
)
635706

636707
QuantRecipe.disable_folding_pqs_to_weights()
@@ -720,6 +791,17 @@ def _get_total_weight_size(modules):
720791
for module in modules
721792
)
722793

794+
@staticmethod
795+
def _get_total_weight_size_from_named_modules(named_modules, active_moe_expert_ratio=None):
796+
total_weight_size = 0.0
797+
for name, module in named_modules:
798+
if not _AutoQuantizeBaseSearcher._is_auto_quantize_module(module):
799+
continue
800+
total_weight_size += module.weight.numel() * _get_active_moe_cost_weight(
801+
[name], active_moe_expert_ratio
802+
)
803+
return total_weight_size
804+
723805
def _get_constraints_for_search(self, max_weight_size, lower_bound=None):
724806
constraints = {
725807
"weight_size_after_compression": (
@@ -729,6 +811,12 @@ def _get_constraints_for_search(self, max_weight_size, lower_bound=None):
729811
}
730812
return constraints, "weight_size_after_compression"
731813

814+
def _get_search_lower_bounds(self):
815+
cost_model = getattr(self, "cost_model", getattr(self, "config", {}).get("cost_model"))
816+
if cost_model == "active_moe":
817+
return [0.99, 0.90, None]
818+
return [None, 0.99, 0.90]
819+
732820
@abstractmethod
733821
def run_search_with_stats(self, max_weight_size, verbose=False):
734822
"""Run the search with stats to get the best recipe and whether the constraints are satisfied."""
@@ -742,8 +830,24 @@ def run_search(self):
742830
)
743831

744832
compression = self._get_formatted_weight_compression_constraint()
745-
total_weight_size = self._get_total_weight_size(self.model.modules())
833+
if self.config["cost_model"] == "active_moe":
834+
total_weight_size = self._get_total_weight_size_from_named_modules(
835+
self.model.named_modules(), self.config["active_moe_expert_ratio"]
836+
)
837+
else:
838+
total_weight_size = self._get_total_weight_size(self.model.modules())
839+
self.cost_denominator = total_weight_size
746840
max_weight_size = total_weight_size * compression
841+
if verbose:
842+
print_rank_0(
843+
"AutoQuantize cost model: "
844+
f"{self.config['cost_model']}"
845+
+ (
846+
f" (active_moe_expert_ratio={self.config['active_moe_expert_ratio']})"
847+
if self.config["cost_model"] == "active_moe"
848+
else ""
849+
)
850+
)
747851

748852
# Run the search with stats to get the best recipe and whether the constraints are satisfied
749853
best_recipe_info, is_satisfied = self.run_search_with_stats(max_weight_size, verbose)
@@ -1048,7 +1152,7 @@ def run_search_with_stats(self, max_weight_size, verbose=False):
10481152
"""
10491153
# TODO: Do this only for rank 0 in the respective pipeline group
10501154

1051-
for lower_bound in [None, 0.99, 0.90]:
1155+
for lower_bound in self._get_search_lower_bounds():
10521156
# The LP solver for auto_quantize sometimes fails to find a solution if a lower bound is not
10531157
# specified. I dont know why this happens.
10541158
# As a workaround, lets specify a lower bound for the weight compression if previous
@@ -1377,7 +1481,9 @@ def _resolve_best_recipe(search_state, constraints, verbose=False):
13771481
effective_bits = constraints["effective_bits"]
13781482
compression = effective_bits / 16.0
13791483
candidate_stats = search_state["candidate_stats"]
1380-
total_weight_size = sum(s["costs"][-1] for s in candidate_stats.values())
1484+
total_weight_size = search_state.get("cost_denominator") or sum(
1485+
s["costs"][-1] for s in candidate_stats.values()
1486+
)
13811487
max_weight_size = total_weight_size * compression
13821488
method = search_state["method"]
13831489

@@ -1391,6 +1497,8 @@ def _resolve_best_recipe(search_state, constraints, verbose=False):
13911497
)
13921498

13931499
searcher.candidate_stats = candidate_stats
1500+
searcher.cost_model = search_state.get("cost_model", "weight")
1501+
searcher.config = {"cost_model": searcher.cost_model}
13941502
best_recipe_info, _ = searcher.run_search_with_stats(max_weight_size, verbose=verbose)
13951503

13961504
best_recipe = {name: info["format"] for name, info in best_recipe_info.items()}

0 commit comments

Comments
 (0)