5454from .utils import is_quantized_linear
5555
5656
57+ def _is_fused_experts_module (module : nn .Module ) -> bool :
58+ """Return True if ``module`` is a quantized fused-MoE-experts container.
59+
60+ These modules expose plural ``*_input_quantizer`` and ``*_weight_quantizers``
61+ (an ``nn.ModuleList`` of per-expert quantizers) instead of the singular
62+ ``input_quantizer`` / ``weight_quantizer`` attrs found on standard
63+ ``nn.Linear``-derived QuantModules. AutoQuantize hparam discovery and cost
64+ accounting need to recognize this layout to enumerate fused experts as
65+ search dimensions.
66+ """
67+ # Late import to avoid a circular import at module load time.
68+ try :
69+ from .plugins .huggingface import _QuantFusedExperts
70+ except ImportError :
71+ return False
72+ return isinstance (module , _QuantFusedExperts )
73+
74+
75+ # Quantizer attribute names that participate in AutoQuantize snapshot/restore.
76+ _STD_QUANTIZER_ATTRS = ("input_quantizer" , "weight_quantizer" , "output_quantizer" )
77+ _FUSED_EXPERTS_QUANTIZER_ATTRS = (
78+ "gate_up_proj_input_quantizer" ,
79+ "gate_up_proj_weight_quantizers" ,
80+ "down_proj_input_quantizer" ,
81+ "down_proj_weight_quantizers" ,
82+ )
83+
84+
def _get_quantizer_attrs(module: nn.Module) -> tuple[str, ...]:
    """Select the quantizer attribute names to snapshot/restore for ``module``.

    Args:
        module: The quantized module whose attribute layout is needed.

    Returns:
        ``_FUSED_EXPERTS_QUANTIZER_ATTRS`` (the four plural fused-experts
        attribute names) when ``module`` is a fused-MoE-experts container,
        otherwise ``_STD_QUANTIZER_ATTRS`` (input / weight / output).
    """
    return (
        _FUSED_EXPERTS_QUANTIZER_ATTRS
        if _is_fused_experts_module(module)
        else _STD_QUANTIZER_ATTRS
    )
95+
96+
def _make_fresh_quantizer_for_attr(module: nn.Module, attr_name: str) -> nn.Module:
    """Build a new default quantizer object to assign to ``module.<attr_name>``.

    Args:
        module: The module whose attribute is about to be overwritten.
        attr_name: Name of the quantizer attribute on ``module``.

    Returns:
        A single ``TensorQuantizer()`` when the current attribute is missing or
        is not an ``nn.ModuleList``. When it is an ``nn.ModuleList`` (the
        per-expert quantizers of a fused-experts module), a new
        ``nn.ModuleList`` holding one fresh ``TensorQuantizer()`` per existing
        entry, so the list length stays the same across recipes.
    """
    existing = getattr(module, attr_name, None)
    if not isinstance(existing, nn.ModuleList):
        return TensorQuantizer()
    # One independent quantizer per existing slot; the length is preserved.
    return nn.ModuleList([TensorQuantizer() for _ in range(len(existing))])
108+
109+
def _get_module_weight_numel(module: nn.Module) -> int:
    """Count the elements of the weights that enter AutoQuantize cost accounting.

    Args:
        module: The module to measure.

    Returns:
        For fused-experts modules, the summed ``numel()`` of the
        ``gate_up_proj`` and ``down_proj`` attributes (any that is missing or
        ``None`` contributes nothing). For every other module, the ``numel()``
        of its ``weight`` attribute, or ``0`` when that attribute is missing
        or ``None``.
    """
    if not _is_fused_experts_module(module):
        weight = getattr(module, "weight", None)
        return 0 if weight is None else weight.numel()

    # Fused experts keep their weights in two fused parameters, not ``weight``.
    fused_params = (getattr(module, name, None) for name in ("gate_up_proj", "down_proj"))
    return sum(param.numel() for param in fused_params if param is not None)
126+
127+
57128def estimate_quant_compression (quant_cfg : QuantizeConfig ) -> float :
58129 """Estimate the compression ratio of a quantization configuration.
59130
@@ -231,26 +302,26 @@ def __init__(
231302 # This is a hack; we don't want to make the input_quantizer, weight_quantizer, output_quantizer
232303 # a dynamic attribute for backward compatibility with the model_calib.py
233304 # TODO: Make input_quantizer, weight_quantizer, output_quantizer a dynamic attribute and get rid of this hack
305+ # NOTE: For fused-experts modules, the relevant attrs are plural
306+ # (``*_input_quantizer`` + ``*_weight_quantizers`` ModuleList) — see
307+ # ``_get_quantizer_attrs``. Both layouts share the same snapshot dict
308+ # shape so ``active.setter`` swaps the right child modules.
234309 self ._all_quantizer_choices = {quant_recipe : {} for quant_recipe in self .choices }
235310
236311 quant_recipe : QuantRecipe
237312 for quant_recipe in self .choices :
238313 for quant_module in self .quant_modules :
239- for quantizer_attr_name in [
240- "input_quantizer" ,
241- "weight_quantizer" ,
242- "output_quantizer" ,
243- ]:
244- setattr (quant_module , quantizer_attr_name , TensorQuantizer ())
314+ attr_names = _get_quantizer_attrs (quant_module )
315+ for attr_name in attr_names :
316+ setattr (
317+ quant_module ,
318+ attr_name ,
319+ _make_fresh_quantizer_for_attr (quant_module , attr_name ),
320+ )
245321
246322 set_quantizer_by_cfg (quant_module , quant_recipe .config .quant_cfg )
247323 self ._all_quantizer_choices [quant_recipe ][quant_module ] = {
248- quantizer_attr_name : getattr (quant_module , quantizer_attr_name )
249- for quantizer_attr_name in [
250- "input_quantizer" ,
251- "weight_quantizer" ,
252- "output_quantizer" ,
253- ]
324+ attr_name : getattr (quant_module , attr_name ) for attr_name in attr_names
254325 }
255326
256327 self .active = self .original
@@ -360,6 +431,20 @@ def attrs(self) -> list[str]:
360431 return ["name" , "cost_weight" , * super ().attrs ]
361432
362433
434+ _LINEAR_ATTN_QKVZ_RE = re .compile (r"^(.*?\.linear_attn)\.(?:in_proj_qkv|in_proj_z)$" )
435+ _LINEAR_ATTN_BA_RE = re .compile (r"^(.*?\.linear_attn)\.(?:in_proj_a|in_proj_b)$" )
436+
437+
438+ def _linear_attn_qkvz_group_key (_model , name : str ) -> str | None :
439+ m = _LINEAR_ATTN_QKVZ_RE .match (name )
440+ return f"{ m .group (1 )} /qkvz" if m else None
441+
442+
443+ def _linear_attn_ba_group_key (_model , name : str ) -> str | None :
444+ m = _LINEAR_ATTN_BA_RE .match (name )
445+ return f"{ m .group (1 )} /ba" if m else None
446+
447+
363448class _AutoQuantizeBaseSearcher (BaseSearcher , ABC ):
364449 """Base searcher for AutoQuantize algorithm."""
365450
@@ -381,6 +466,13 @@ class _AutoQuantizeBaseSearcher(BaseSearcher, ABC):
381466 r"^(.*?)\.(gate_proj|up_proj)$" , # gate_proj, up_proj for llama like models
382467 r"^(.*?)\.(\d+\.(w1|w2|w3))$" , # mixtral experts
383468 r"^(.*?)\.((w1_linear|w2_linear|w3_linear)\.\d+)$" , # dbrx experts
469+ # Qwen3.5/3.6 hybrid linear_attn: vLLM fuses (in_proj_qkv, in_proj_z)
470+ # into ``in_proj_qkvz`` and (in_proj_a, in_proj_b) into ``in_proj_ba`` and
471+ # requires fused shards to share quant_algo. Two callables (not one
472+ # regex) so qkv+z and a+b produce DIFFERENT group keys; each pair
473+ # stays with its own fusion partner.
474+ _linear_attn_qkvz_group_key ,
475+ _linear_attn_ba_group_key ,
384476 ]
385477
386478 score_module_rules = []
@@ -411,6 +503,7 @@ def default_state_dict(self) -> SearchStateDict:
411503 "cost" : {},
412504 "active_moe_expert_ratio" : None ,
413505 "cost_denominator" : None ,
506+ "disabled_layers" : None ,
414507 "candidate_stats" : defaultdict (dict ),
415508 "quantizer_states" : {},
416509 "best" : {"recipe" : {}, "constraints" : {}, "score" : float ("inf" ), "is_satisfied" : False },
@@ -433,9 +526,15 @@ def load_search_checkpoint(self) -> bool:
433526
@staticmethod
def _is_auto_quantize_module(module):
    """Return whether ``module`` is a ``QuantModule`` handled by AutoQuantize.

    A module qualifies when it is a ``QuantModule`` and additionally is one of:

    * a quantized linear layer (``is_quantized_linear``) or a
      ``QuantLinearConvBase`` instance, or
    * a fused-MoE-experts container (``_is_fused_experts_module``), i.e. one
      module owning an ``nn.ModuleList`` of per-expert weight quantizers plus
      shared input quantizers. All experts of such a module share a single
      search dimension (one recipe per fused module).
    """
    linear_like = is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)
    supported_layout = linear_like or _is_fused_experts_module(module)
    return supported_layout and isinstance(module, QuantModule)
439538
440539 @staticmethod
441540 def _get_search_recipes (quantization_formats ):
@@ -677,6 +776,7 @@ def before_search(self):
677776 self .cost_model = self .config ["cost_model" ]
678777 self .cost = self .config ["cost" ]
679778 self .active_moe_expert_ratio = self .config ["active_moe_expert_ratio" ]
779+ self .disabled_layers = self .config ["disabled_layers" ]
680780 self .cost_denominator = getattr (self , "cost_denominator" , None )
681781
682782 search_recipes = self ._get_search_recipes (self .config ["quantization_formats" ])
@@ -765,11 +865,9 @@ def _print_recipe_summary(best_recipe, total_cost, total_weight_size, prefix="Au
@staticmethod
def _get_total_weight_size(modules):
    """Sum the weight element counts of the AutoQuantize modules in ``modules``.

    Modules rejected by ``_is_auto_quantize_module`` contribute nothing; the
    rest are measured with ``_get_module_weight_numel``.
    """
    is_searchable = _AutoQuantizeBaseSearcher._is_auto_quantize_module
    return sum(
        _get_module_weight_numel(candidate) for candidate in modules if is_searchable(candidate)
    )
775873
@@ -1372,6 +1470,16 @@ def run_search_with_stats(self, max_weight_size, verbose=False):
13721470AutoQuantizeSearcher = AutoQuantizeGradientSearcher
13731471
13741472
1473+ def _as_list (value ) -> list :
1474+ if value is None :
1475+ return []
1476+ if isinstance (value , list ):
1477+ return value
1478+ if isinstance (value , tuple ):
1479+ return list (value )
1480+ return [value ]
1481+
1482+
13751483def get_auto_quantize_config (search_state , constraints = None , verbose = False ):
13761484 """Build a flat quant config dict from auto_quantize search_state.
13771485
@@ -1401,6 +1509,11 @@ def _cfg_to_dict(v):
14011509 return v
14021510
14031511 quant_cfg : list [dict ] = [{"quantizer_name" : "*" , "enable" : False }]
1512+ quant_cfg .extend (
1513+ {"quantizer_name" : pattern , "enable" : False }
1514+ for pattern in _as_list (search_state .get ("disabled_layers" ))
1515+ )
1516+ per_module_entries : list [dict ] = []
14041517 _per_module_attrs = ("input_quantizer" , "weight_quantizer" , "output_quantizer" )
14051518 # Track global (non per-module) recipe entries. Last recipe wins for each pattern.
14061519 global_entries : dict [str , dict ] = {}
@@ -1421,7 +1534,7 @@ def _cfg_to_dict(v):
14211534 }
14221535 if matched_cfg is not None :
14231536 entry ["cfg" ] = _cfg_to_dict (matched_cfg )
1424- quant_cfg .append (entry )
1537+ per_module_entries .append (entry )
14251538
14261539 # Collect non-per-module entries (e.g. *[kv]_bmm_quantizer) from winning recipes.
14271540 for recipe_entry in recipe .config .quant_cfg :
@@ -1438,7 +1551,10 @@ def _cfg_to_dict(v):
14381551 ge ["cfg" ] = _cfg_to_dict (cfg )
14391552 global_entries [pattern ] = ge
14401553
1554+ # Keep path-scoped recipe entries before explicit module entries so selected
1555+ # modules override default disables such as ``*lm_head*``.
14411556 quant_cfg .extend (global_entries .values ())
1557+ quant_cfg .extend (per_module_entries )
14421558 warnings .warn (
14431559 "get_auto_quantize_config: returned config uses algorithm='max'. "
14441560 "Per-recipe calibration algorithms (e.g. smoothquant, awq) are not preserved. "
@@ -1502,6 +1618,9 @@ def _match_quantizer_cfg(quant_cfg, quantizer_attr):
15021618 matched = None
15031619 matched_enable = None
15041620 for entry in quant_cfg :
1621+ parent_class = entry .get ("parent_class" ) if hasattr (entry , "get" ) else entry .parent_class
1622+ if parent_class is not None :
1623+ continue
15051624 pattern = entry ["quantizer_name" ]
15061625 cfg = entry .get ("cfg" )
15071626 enable = entry .get ("enable" , True )
0 commit comments