4545 AUTO_QUANTIZE_CONSTRAINT_KEYS ,
4646 COST_MODEL_ACTIVE_MOE ,
4747 COST_MODEL_WEIGHT ,
48+ _get_module_weight_numel ,
4849 get_auto_quantize_cost_model ,
4950 normalize_auto_quantize_constraints ,
5051)
5455from .utils import is_quantized_linear
5556
5657
58+ def _is_fused_experts_module (module : nn .Module ) -> bool :
59+ """Return True if ``module`` is a quantized fused-MoE-experts container.
60+
61+ These modules expose plural ``*_input_quantizer`` and ``*_weight_quantizers``
62+ (an ``nn.ModuleList`` of per-expert quantizers) instead of the singular
63+ ``input_quantizer`` / ``weight_quantizer`` attrs found on standard
64+ ``nn.Linear``-derived QuantModules. AutoQuantize hparam discovery and cost
65+ accounting need to recognize this layout to enumerate fused experts as
66+ search dimensions.
67+ """
68+ # Late import to avoid a circular import at module load time.
69+ try :
70+ from .plugins .huggingface import _QuantFusedExperts
71+ except ImportError :
72+ return False
73+ return isinstance (module , _QuantFusedExperts )
74+
75+
76+ # Quantizer attribute names that participate in AutoQuantize snapshot/restore.
77+ _STD_QUANTIZER_ATTRS = ("input_quantizer" , "weight_quantizer" , "output_quantizer" )
78+ _FUSED_EXPERTS_QUANTIZER_ATTRS = (
79+ "gate_up_proj_input_quantizer" ,
80+ "gate_up_proj_weight_quantizers" ,
81+ "down_proj_input_quantizer" ,
82+ "down_proj_weight_quantizers" ,
83+ )
84+
85+
def _get_quantizer_attrs(module: nn.Module) -> tuple[str, ...]:
    """Pick the quantizer attribute names AutoQuantize must snapshot/restore.

    Args:
        module: The quantized module whose attribute layout is being queried.

    Returns:
        ``_FUSED_EXPERTS_QUANTIZER_ATTRS`` (the four plural quantizer attrs)
        when ``module`` is a fused MoE experts container, otherwise
        ``_STD_QUANTIZER_ATTRS`` (the input/weight/output trio).
    """
    return (
        _FUSED_EXPERTS_QUANTIZER_ATTRS
        if _is_fused_experts_module(module)
        else _STD_QUANTIZER_ATTRS
    )
96+
97+
def _make_fresh_quantizer_for_attr(module: nn.Module, attr_name: str) -> nn.Module:
    """Build a default quantizer object to assign over ``module.<attr_name>``.

    Args:
        module: The module whose attribute is about to be overwritten.
        attr_name: Name of the quantizer attribute on ``module``.

    Returns:
        A new ``nn.ModuleList`` of default ``TensorQuantizer`` objects with the
        same length as the current value when that value is an
        ``nn.ModuleList``; otherwise a single default ``TensorQuantizer``
        (also when the attribute does not exist yet).
    """
    existing = getattr(module, attr_name, None)
    if not isinstance(existing, nn.ModuleList):
        return TensorQuantizer()
    # Keep one quantizer per existing entry so the list length is preserved.
    replacements = [TensorQuantizer() for _ in range(len(existing))]
    return nn.ModuleList(replacements)
109+
110+
57111def estimate_quant_compression (quant_cfg : QuantizeConfig ) -> float :
58112 """Estimate the compression ratio of a quantization configuration.
59113
@@ -231,26 +285,26 @@ def __init__(
231285 # This is a hack; We dont want to make the input_quantizer, weight_quantizer, output_quantizer
232286 # a dynamic attribute for backward compatibility with the model_calib.py
233287 # TODO: Make input_quantizer, weight_quantizer, output_quantizer a dynamic attribute and get rid of this hack
288+ # NOTE: For fused-experts modules, the relevant attrs are plural
289+ # (``*_input_quantizer`` + ``*_weight_quantizers`` ModuleList) — see
290+ # ``_get_quantizer_attrs``. Both layouts share the same snapshot dict
291+ # shape so ``active.setter`` swaps the right child modules.
234292 self ._all_quantizer_choices = {quant_recipe : {} for quant_recipe in self .choices }
235293
236294 quant_recipe : QuantRecipe
237295 for quant_recipe in self .choices :
238296 for quant_module in self .quant_modules :
239- for quantizer_attr_name in [
240- "input_quantizer" ,
241- "weight_quantizer" ,
242- "output_quantizer" ,
243- ]:
244- setattr (quant_module , quantizer_attr_name , TensorQuantizer ())
297+ attr_names = _get_quantizer_attrs (quant_module )
298+ for attr_name in attr_names :
299+ setattr (
300+ quant_module ,
301+ attr_name ,
302+ _make_fresh_quantizer_for_attr (quant_module , attr_name ),
303+ )
245304
246305 set_quantizer_by_cfg (quant_module , quant_recipe .config .quant_cfg )
247306 self ._all_quantizer_choices [quant_recipe ][quant_module ] = {
248- quantizer_attr_name : getattr (quant_module , quantizer_attr_name )
249- for quantizer_attr_name in [
250- "input_quantizer" ,
251- "weight_quantizer" ,
252- "output_quantizer" ,
253- ]
307+ attr_name : getattr (quant_module , attr_name ) for attr_name in attr_names
254308 }
255309
256310 self .active = self .original
@@ -360,6 +414,20 @@ def attrs(self) -> list[str]:
360414 return ["name" , "cost_weight" , * super ().attrs ]
361415
362416
417+ _LINEAR_ATTN_QKVZ_RE = re .compile (r"^(.*?\.linear_attn)\.(?:in_proj_qkv|in_proj_z)$" )
418+ _LINEAR_ATTN_BA_RE = re .compile (r"^(.*?\.linear_attn)\.(?:in_proj_a|in_proj_b)$" )
419+
420+
def _linear_attn_qkvz_group_key(_model, name: str) -> str | None:
    """Compute the grouping key for a linear-attention qkv/z projection.

    Args:
        _model: Unused.
        name: Module name to test against ``_LINEAR_ATTN_QKVZ_RE``.

    Returns:
        ``"<prefix>/qkvz"``, where ``<prefix>`` is the first regex group, when
        ``name`` matches; ``None`` otherwise.
    """
    match = _LINEAR_ATTN_QKVZ_RE.match(name)
    if not match:
        return None
    return f"{match.group(1)}/qkvz"
424+
425+
def _linear_attn_ba_group_key(_model, name: str) -> str | None:
    """Compute the grouping key for a linear-attention a/b projection.

    Args:
        _model: Unused.
        name: Module name to test against ``_LINEAR_ATTN_BA_RE``.

    Returns:
        ``"<prefix>/ba"``, where ``<prefix>`` is the first regex group, when
        ``name`` matches; ``None`` otherwise.
    """
    match = _LINEAR_ATTN_BA_RE.match(name)
    if not match:
        return None
    return f"{match.group(1)}/ba"
429+
430+
363431class _AutoQuantizeBaseSearcher (BaseSearcher , ABC ):
364432 """Base searcher for AutoQuantize algorithm."""
365433
@@ -381,6 +449,13 @@ class _AutoQuantizeBaseSearcher(BaseSearcher, ABC):
381449 r"^(.*?)\.(gate_proj|up_proj)$" , # gate_proj, up_proj for llama like models
382450 r"^(.*?)\.(\d+\.(w1|w2|w3))$" , # mixtral experts
383451 r"^(.*?)\.((w1_linear|w2_linear|w3_linear)\.\d+)$" , # dbrx experts
452+ # Qwen3.5/3.6 hybrid linear_attn: vLLM fuses (in_proj_qkv, in_proj_z)
453+ # into ``in_proj_qkvz`` and (in_proj_a, in_proj_b) into ``in_proj_ba`` and
454+ # requires fused shards to share quant_algo. Two callables (not one
455+ # regex) so qkv+z and a+b produce DIFFERENT group keys; each pair
456+ # stays with its own fusion partner.
457+ _linear_attn_qkvz_group_key ,
458+ _linear_attn_ba_group_key ,
384459 ]
385460
386461 score_module_rules = []
@@ -411,6 +486,7 @@ def default_state_dict(self) -> SearchStateDict:
411486 "cost" : {},
412487 "active_moe_expert_ratio" : None ,
413488 "cost_denominator" : None ,
489+ "disabled_layers" : None ,
414490 "candidate_stats" : defaultdict (dict ),
415491 "quantizer_states" : {},
416492 "best" : {"recipe" : {}, "constraints" : {}, "score" : float ("inf" ), "is_satisfied" : False },
@@ -433,9 +509,15 @@ def load_search_checkpoint(self) -> bool:
433509
@staticmethod
def _is_auto_quantize_module(module):
    """Decide whether ``module`` takes part in the AutoQuantize search.

    A module qualifies only if it is a ``QuantModule`` and is one of:

    * a quantized linear module (``is_quantized_linear``) or a
      ``QuantLinearConvBase`` instance, or
    * a fused MoE experts container: a single ``QuantModule`` holding
      per-expert weight quantizers in an ``nn.ModuleList`` plus shared input
      quantizers. All experts in such a module share one search dimension
      (one recipe per fused module).
    """
    if not isinstance(module, QuantModule):
        return False
    if is_quantized_linear(module) or isinstance(module, QuantLinearConvBase):
        return True
    return _is_fused_experts_module(module)
439521
440522 @staticmethod
441523 def _get_search_recipes (quantization_formats ):
@@ -677,6 +759,7 @@ def before_search(self):
677759 self .cost_model = self .config ["cost_model" ]
678760 self .cost = self .config ["cost" ]
679761 self .active_moe_expert_ratio = self .config ["active_moe_expert_ratio" ]
762+ self .disabled_layers = self .config ["disabled_layers" ]
680763 self .cost_denominator = getattr (self , "cost_denominator" , None )
681764
682765 search_recipes = self ._get_search_recipes (self .config ["quantization_formats" ])
@@ -765,11 +848,9 @@ def _print_recipe_summary(best_recipe, total_cost, total_weight_size, prefix="Au
@staticmethod
def _get_total_weight_size(modules):
    """Sum the weight element counts of the AutoQuantize modules in ``modules``.

    Modules rejected by ``_is_auto_quantize_module`` contribute nothing; the
    count for each remaining module comes from ``_get_module_weight_numel``.
    """
    is_searchable = _AutoQuantizeBaseSearcher._is_auto_quantize_module
    return sum(
        _get_module_weight_numel(candidate)
        for candidate in modules
        if is_searchable(candidate)
    )
775856
@@ -1372,6 +1453,16 @@ def run_search_with_stats(self, max_weight_size, verbose=False):
13721453AutoQuantizeSearcher = AutoQuantizeGradientSearcher
13731454
13741455
1456+ def _as_list (value ) -> list :
1457+ if value is None :
1458+ return []
1459+ if isinstance (value , list ):
1460+ return value
1461+ if isinstance (value , tuple ):
1462+ return list (value )
1463+ return [value ]
1464+
1465+
13751466def get_auto_quantize_config (search_state , constraints = None , verbose = False ):
13761467 """Build a flat quant config dict from auto_quantize search_state.
13771468
@@ -1401,6 +1492,11 @@ def _cfg_to_dict(v):
14011492 return v
14021493
14031494 quant_cfg : list [dict ] = [{"quantizer_name" : "*" , "enable" : False }]
1495+ quant_cfg .extend (
1496+ {"quantizer_name" : pattern , "enable" : False }
1497+ for pattern in _as_list (search_state .get ("disabled_layers" ))
1498+ )
1499+ per_module_entries : list [dict ] = []
14041500 _per_module_attrs = ("input_quantizer" , "weight_quantizer" , "output_quantizer" )
14051501 # Track global (non per-module) recipe entries. Last recipe wins for each pattern.
14061502 global_entries : dict [str , dict ] = {}
@@ -1421,7 +1517,7 @@ def _cfg_to_dict(v):
14211517 }
14221518 if matched_cfg is not None :
14231519 entry ["cfg" ] = _cfg_to_dict (matched_cfg )
1424- quant_cfg .append (entry )
1520+ per_module_entries .append (entry )
14251521
14261522 # Collect non-per-module entries (e.g. *[kv]_bmm_quantizer) from winning recipes.
14271523 for recipe_entry in recipe .config .quant_cfg :
@@ -1438,7 +1534,10 @@ def _cfg_to_dict(v):
14381534 ge ["cfg" ] = _cfg_to_dict (cfg )
14391535 global_entries [pattern ] = ge
14401536
1537+ # Keep path-scoped recipe entries before explicit module entries so selected
1538+ # modules override default disables such as ``*lm_head*``.
14411539 quant_cfg .extend (global_entries .values ())
1540+ quant_cfg .extend (per_module_entries )
14421541 warnings .warn (
14431542 "get_auto_quantize_config: returned config uses algorithm='max'. "
14441543 "Per-recipe calibration algorithms (e.g. smoothquant, awq) are not preserved. "
@@ -1502,6 +1601,9 @@ def _match_quantizer_cfg(quant_cfg, quantizer_attr):
15021601 matched = None
15031602 matched_enable = None
15041603 for entry in quant_cfg :
1604+ parent_class = entry .get ("parent_class" ) if hasattr (entry , "get" ) else entry .parent_class
1605+ if parent_class is not None :
1606+ continue
15051607 pattern = entry ["quantizer_name" ]
15061608 cfg = entry .get ("cfg" )
15071609 enable = entry .get ("enable" , True )
0 commit comments