4545from .nn import QuantLinearConvBase , QuantModule , SequentialQuantizer , TensorQuantizer
4646from .utils import is_quantized_linear
4747
48+ _ROUTED_MOE_EXPERT_NAME_RE = re .compile (r"(^|\.)experts(\.|$)" )
49+
50+
51+ def _is_routed_moe_module_name (name : str ) -> bool :
52+ """Return True for routed MoE expert modules, excluding shared experts."""
53+ return "shared_expert" not in name and _ROUTED_MOE_EXPERT_NAME_RE .search (name ) is not None
54+
55+
56+ def _get_active_moe_cost_weight (
57+ module_names : Sequence [str ], active_moe_expert_ratio : float | None
58+ ) -> float :
59+ """Return cost multiplier for the active-MoE cost model."""
60+ if active_moe_expert_ratio is None :
61+ return 1.0
62+ if any (_is_routed_moe_module_name (n ) for n in module_names ):
63+ return active_moe_expert_ratio
64+ return 1.0
65+
4866
4967def estimate_quant_compression (quant_cfg : QuantizeConfig ) -> float :
5068 """Estimate the compression ratio of a quantization configuration.
@@ -204,13 +222,16 @@ def __init__(
204222 score_modules : list [nn .Module ] | None = None ,
205223 name : str | None = None ,
206224 quant_module_names : list [str ] | None = None ,
225+ cost_weight : float = 1.0 ,
207226 ) -> None :
208227 """Initializes Hparam with original value and choices."""
209228 choices = sorted ({* (choices if choices else []), QuantRecipe (quant_cfg = None )})
210229 super ().__init__ (choices , original = choices [0 ])
211230
212231 self .name = name
213232 self .quant_module_names = quant_module_names or []
233+ assert cost_weight > 0.0 , "cost_weight must be positive."
234+ self .cost_weight = cost_weight
214235
215236 self .quant_modules = list (set (quant_modules or []))
216237 self .score_modules = list (set (score_modules or self .quant_modules ))
@@ -303,15 +324,18 @@ def get_score(self, recipe: QuantRecipe) -> float:
303324 total_score += importance .item ()
304325 return total_score
305326
306- def get_cost (self , recipe : QuantRecipe ) -> float :
327+ def get_cost (self , recipe : QuantRecipe , cost_weight : float | None = None ) -> float :
307328 """Get the cost for a given recipe.
308329
309330 The cost is the total weight size of the quantizable modules multiplied by
310331 the compression ratio of the recipe.
311332 """
333+ cost_weight = self .cost_weight if cost_weight is None else cost_weight
312334 cost = 0
313335 for quant_module in self .quant_modules :
314- weight_size = _AutoQuantizeBaseSearcher ._get_total_weight_size ([quant_module ])
336+ weight_size = (
337+ _AutoQuantizeBaseSearcher ._get_total_weight_size ([quant_module ]) * cost_weight
338+ )
315339 parallel_state = getattr (quant_module , "parallel_state" , None )
316340
317341 if parallel_state is None :
@@ -341,7 +365,7 @@ def get_cost(self, recipe: QuantRecipe) -> float:
341365 @property
342366 def attrs (self ) -> list [str ]:
343367 """Return the attributes of the hparam for repr."""
344- return ["name" , * super ().attrs ]
368+ return ["name" , "cost_weight" , * super ().attrs ]
345369
346370
347371class _AutoQuantizeBaseSearcher (BaseSearcher , ABC ):
@@ -381,13 +405,18 @@ def default_search_config(self):
381405 "disabled_layers" : None ,
382406 "verbose" : is_master (),
383407 "checkpoint" : None ,
408+ "cost_model" : "weight" ,
409+ "active_moe_expert_ratio" : None ,
384410 }
385411
386412 @property
387413 def default_state_dict (self ) -> SearchStateDict :
388414 """Get the default state dict for AutoQuantize."""
389415 return {
390416 "method" : self .method_name ,
417+ "cost_model" : "weight" ,
418+ "active_moe_expert_ratio" : None ,
419+ "cost_denominator" : None ,
391420 "candidate_stats" : defaultdict (dict ),
392421 "quantizer_states" : {},
393422 "best" : {"recipe" : {}, "constraints" : {}, "score" : float ("inf" ), "is_satisfied" : False },
@@ -403,6 +432,18 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
403432 assert config ["forward_step" ] is not None , (
404433 "`forward_step` must be provided for `auto_quantize`."
405434 )
435+ if config ["cost_model" ] not in ("weight" , "active_moe" ):
436+ raise ValueError (
437+ f"Invalid cost_model: { config ['cost_model' ]} . "
438+ "Valid options are 'weight' and 'active_moe'."
439+ )
440+ active_moe_expert_ratio = config ["active_moe_expert_ratio" ]
441+ if active_moe_expert_ratio is not None and not (0.0 < active_moe_expert_ratio <= 1.0 ):
442+ raise ValueError ("active_moe_expert_ratio must be in the range (0.0, 1.0]." )
443+ if config ["cost_model" ] == "active_moe" and active_moe_expert_ratio is None :
444+ raise ValueError (
445+ "active_moe_expert_ratio must be set when using active_moe cost accounting."
446+ )
406447 return config
407448
408449 def load_search_checkpoint (self ) -> bool :
@@ -490,7 +531,9 @@ def _get_score_module_from_name(
490531 )
491532 return quant_module
492533
493- def insert_hparams_after_merge_rules (self , model , quant_recipes , disabled_layers = None ):
534+ def insert_hparams_after_merge_rules (
535+ self , model , quant_recipes , disabled_layers = None , active_moe_expert_ratio = None
536+ ):
494537 """Restrict the search space using the merge rules and insert the hparams for the model."""
495538 # TRTLLM fuses linear layers such as q_proj, k_proj, v_proj into same layer
496539 # Hence we need to restrict the search space so that all these layers share the same recipe
@@ -545,14 +588,17 @@ def insert_hparams_after_merge_rules(self, model, quant_recipes, disabled_layers
545588 quant_modules = [module for module , _ , _ , _ in module_info_list ]
546589 disabled = any (disabled for _ , _ , disabled , _ in module_info_list )
547590 score_modules = [score_module for _ , _ , _ , score_module in module_info_list ]
591+ quant_module_names = [name for _ , name , _ , _ in module_info_list ]
592+ cost_weight = _get_active_moe_cost_weight (quant_module_names , active_moe_expert_ratio )
548593
549594 _quant_recipes = None if disabled else quant_recipes
550595 hparam = QuantRecipeHparam (
551596 _quant_recipes ,
552597 quant_modules = quant_modules ,
553598 score_modules = score_modules ,
554599 name = str (group_key ),
555- quant_module_names = [name for _ , name , _ , _ in module_info_list ],
600+ quant_module_names = quant_module_names ,
601+ cost_weight = cost_weight ,
556602 )
557603
558604 for module in quant_modules :
@@ -584,23 +630,30 @@ def initialize_candidate_stats(self):
584630 if not isinstance (hparam , QuantRecipeHparam ):
585631 continue
586632
587- formats , scores , costs = [], [], []
633+ formats , scores , costs , active_costs = [], [], [], []
588634 prev_score = float ("inf" )
635+ constraint_cost_weight = (
636+ hparam .cost_weight if self .config ["cost_model" ] == "active_moe" else 1.0
637+ )
589638 for recipe in hparam .choices :
590639 formats .append (recipe )
591640
592641 score = hparam .get_score (recipe ) # type: ignore [arg-type]
593- cost = hparam .get_cost (recipe ) # type: ignore [arg-type]
642+ cost = hparam .get_cost (recipe , cost_weight = constraint_cost_weight ) # type: ignore [arg-type]
643+ active_cost = hparam .get_cost (recipe , cost_weight = hparam .cost_weight ) # type: ignore [arg-type]
594644
595645 score = min (score , prev_score ) # TODO: Should we get rid of this?
596646 scores .append (score )
597647 costs .append (cost )
648+ active_costs .append (active_cost )
598649 prev_score = score
599650
600651 self .candidate_stats [name ]["formats" ] = formats
601652 self .candidate_stats [name ]["scores" ] = scores
602653 self .candidate_stats [name ]["costs" ] = costs
654+ self .candidate_stats [name ]["active_costs" ] = active_costs
603655 self .candidate_stats [name ]["module_names" ] = hparam .quant_module_names
656+ self .candidate_stats [name ]["cost_weight" ] = hparam .cost_weight
604657
605658 def _run_func (self , func , num_iters = 1 , desc = "" ):
606659 for i , data in tqdm (
@@ -625,12 +678,30 @@ def before_search(self):
625678 f"Checkpoint method '{ restored_method } ' does not match current method "
626679 f"'{ self .method_name } '. Use a different checkpoint path."
627680 )
681+ restored_cost_model = getattr (self , "cost_model" , "weight" )
682+ restored_active_moe_expert_ratio = getattr (self , "active_moe_expert_ratio" , None )
683+ if self .candidate_stats and (
684+ restored_cost_model != self .config ["cost_model" ]
685+ or restored_active_moe_expert_ratio != self .config ["active_moe_expert_ratio" ]
686+ ):
687+ raise ValueError (
688+ "Checkpoint AutoQuantize cost model does not match current search config: "
689+ f"checkpoint=({ restored_cost_model } , { restored_active_moe_expert_ratio } ), "
690+ f"current=({ self .config ['cost_model' ]} , { self .config ['active_moe_expert_ratio' ]} ). "
691+ "Use a different checkpoint path."
692+ )
628693 self .method = self .method_name
694+ self .cost_model = self .config ["cost_model" ]
695+ self .active_moe_expert_ratio = self .config ["active_moe_expert_ratio" ]
696+ self .cost_denominator = getattr (self , "cost_denominator" , None )
629697
630698 search_recipes = self ._get_search_recipes (self .config ["quantization_formats" ])
631699 self ._verify_constraint (search_recipes )
632700 self .insert_hparams_after_merge_rules (
633- self .model , search_recipes , self .config ["disabled_layers" ]
701+ self .model ,
702+ search_recipes ,
703+ self .config ["disabled_layers" ],
704+ self .config ["active_moe_expert_ratio" ],
634705 )
635706
636707 QuantRecipe .disable_folding_pqs_to_weights ()
@@ -720,6 +791,17 @@ def _get_total_weight_size(modules):
720791 for module in modules
721792 )
722793
794+ @staticmethod
795+ def _get_total_weight_size_from_named_modules (named_modules , active_moe_expert_ratio = None ):
796+ total_weight_size = 0.0
797+ for name , module in named_modules :
798+ if not _AutoQuantizeBaseSearcher ._is_auto_quantize_module (module ):
799+ continue
800+ total_weight_size += module .weight .numel () * _get_active_moe_cost_weight (
801+ [name ], active_moe_expert_ratio
802+ )
803+ return total_weight_size
804+
723805 def _get_constraints_for_search (self , max_weight_size , lower_bound = None ):
724806 constraints = {
725807 "weight_size_after_compression" : (
@@ -729,6 +811,12 @@ def _get_constraints_for_search(self, max_weight_size, lower_bound=None):
729811 }
730812 return constraints , "weight_size_after_compression"
731813
814+ def _get_search_lower_bounds (self ):
815+ cost_model = getattr (self , "cost_model" , getattr (self , "config" , {}).get ("cost_model" ))
816+ if cost_model == "active_moe" :
817+ return [0.99 , 0.90 , None ]
818+ return [None , 0.99 , 0.90 ]
819+
732820 @abstractmethod
733821 def run_search_with_stats (self , max_weight_size , verbose = False ):
734822 """Run the search with stats to get the best recipe and whether the constraints are satisfied."""
@@ -742,8 +830,24 @@ def run_search(self):
742830 )
743831
744832 compression = self ._get_formatted_weight_compression_constraint ()
745- total_weight_size = self ._get_total_weight_size (self .model .modules ())
833+ if self .config ["cost_model" ] == "active_moe" :
834+ total_weight_size = self ._get_total_weight_size_from_named_modules (
835+ self .model .named_modules (), self .config ["active_moe_expert_ratio" ]
836+ )
837+ else :
838+ total_weight_size = self ._get_total_weight_size (self .model .modules ())
839+ self .cost_denominator = total_weight_size
746840 max_weight_size = total_weight_size * compression
841+ if verbose :
842+ print_rank_0 (
843+ "AutoQuantize cost model: "
844+ f"{ self .config ['cost_model' ]} "
845+ + (
846+ f" (active_moe_expert_ratio={ self .config ['active_moe_expert_ratio' ]} )"
847+ if self .config ["cost_model" ] == "active_moe"
848+ else ""
849+ )
850+ )
747851
748852 # Run the search with stats to get the best recipe and whether the constraints are satisfied
749853 best_recipe_info , is_satisfied = self .run_search_with_stats (max_weight_size , verbose )
@@ -1048,7 +1152,7 @@ def run_search_with_stats(self, max_weight_size, verbose=False):
10481152 """
10491153 # TODO: Do this only for rank 0 in the respective pipeline group
10501154
1051- for lower_bound in [ None , 0.99 , 0.90 ] :
1155+ for lower_bound in self . _get_search_lower_bounds () :
10521156 # The LP solver for auto_quantize sometimes fails to find a solution if a lower bound is not
10531157 # specified. I dont know why this happens.
10541158 # As a workaround, lets specify a lower bound for the weight compression if previous
@@ -1377,7 +1481,9 @@ def _resolve_best_recipe(search_state, constraints, verbose=False):
13771481 effective_bits = constraints ["effective_bits" ]
13781482 compression = effective_bits / 16.0
13791483 candidate_stats = search_state ["candidate_stats" ]
1380- total_weight_size = sum (s ["costs" ][- 1 ] for s in candidate_stats .values ())
1484+ total_weight_size = search_state .get ("cost_denominator" ) or sum (
1485+ s ["costs" ][- 1 ] for s in candidate_stats .values ()
1486+ )
13811487 max_weight_size = total_weight_size * compression
13821488 method = search_state ["method" ]
13831489
@@ -1391,6 +1497,8 @@ def _resolve_best_recipe(search_state, constraints, verbose=False):
13911497 )
13921498
13931499 searcher .candidate_stats = candidate_stats
1500+ searcher .cost_model = search_state .get ("cost_model" , "weight" )
1501+ searcher .config = {"cost_model" : searcher .cost_model }
13941502 best_recipe_info , _ = searcher .run_search_with_stats (max_weight_size , verbose = verbose )
13951503
13961504 best_recipe = {name : info ["format" ] for name , info in best_recipe_info .items ()}
0 commit comments