5252 promote_nvfp4_static_quantizers ,
5353 quantizer_attr_names ,
5454 reduce_amax ,
55- weight_attr_names ,
5655)
5756from .utils .calib_utils import _GPTQ_HELPER_REGISTRY , GPTQHelper
5857
8483
8584
8685def _is_calibrated_nvfp4_static_weight_quantizer (q ) -> bool :
87- """True for an NVFP4-static weight quantizer that ``max_calibrate`` already
88- populated with a per-block ``_amax`` and that is currently enabled.
86+ """Check whether ``q`` is an enabled, calibrated NVFP4-static weight quantizer.
87+
88+ True when ``max_calibrate`` already populated a per-block ``_amax``.
8989 """
9090 return (
9191 isinstance (q , TensorQuantizer )
@@ -97,9 +97,9 @@ def _is_calibrated_nvfp4_static_weight_quantizer(q) -> bool:
9797
9898
9999def _collect_grouped_linears (model : nn .Module ) -> list [list [nn .Module ]]:
100- """Find groups of Linear-like submodules whose NVFP4-static weight quantizers
101- should share ``global_amax`` (Q/K/V under one attention parent; gate/up under
102- one MLP parent) .
100+ """Find Linear-like submodule groups whose NVFP4-static weight quantizers should share global_amax.
101+
102+ Groups are Q/K/V under one attention parent and gate/up under one MLP parent.
103103 """
104104 groups : list [list [nn .Module ]] = []
105105 wq_attr = quantizer_attr_names ("weight" ).weight_quantizer
@@ -118,6 +118,50 @@ def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]:
118118 return groups
119119
120120
121+ @torch .no_grad ()
122+ def _bootstrap_uncalibrated_weight_quantizers (model : nn .Module ) -> int :
123+ """Run a max-style amax collection on weight quantizers whose ``_amax`` is missing.
124+
125+ Forward-pass max calibration only populates per-expert weight quantizers in
126+ fused-experts containers when tokens are routed to that expert. "Dead"
127+ experts that received no tokens end up with no ``_amax``, which causes
128+ ``mse_calibrate``'s subsequent walk to skip them and forces the export-time
129+ fallback to derive separate per-half amax for gate/up. This helper walks
130+ every QuantModule's :meth:`iter_weights_for_calibration` pairs and, for any
131+ quantizer that lacks ``_amax``, runs the existing calibrator (typically
132+ :class:`MaxCalibrator`) on the corresponding weight slice — populating
133+ ``_amax`` from the weight rather than from runtime activations.
134+
135+ Returns the number of quantizers bootstrapped (mostly for diagnostics).
136+ """
137+ n = 0
138+ for module in model .modules ():
139+ if not isinstance (module , QuantModule ):
140+ continue
141+ try :
142+ pairs = list (module .iter_weights_for_calibration ())
143+ except Exception :
144+ continue
145+ for weight , q in pairs :
146+ if not isinstance (q , TensorQuantizer ) or q ._disabled or q ._dynamic :
147+ continue
148+ if q ._calibrator is None :
149+ continue
150+ if hasattr (q , "_amax" ) and q ._amax is not None and not torch .all (q ._amax == 0 ):
151+ continue
152+ q .disable_quant ()
153+ q .enable_calib ()
154+ q (weight )
155+ if q ._calibrator .compute_amax () is not None :
156+ q .load_calib_amax ()
157+ q .enable_quant ()
158+ q .disable_calib ()
159+ if hasattr (q ._calibrator , "reset" ):
160+ q ._calibrator .reset ()
161+ n += 1
162+ return n
163+
164+
121165@torch .no_grad ()
122166def sync_grouped_weight_global_amax (model : nn .Module ) -> int :
123167 """Sync ``global_amax`` across sibling NVFP4-static weight quantizers.
@@ -439,17 +483,24 @@ def mse_calibrate(
439483 # Step 1: First get initial amax using max calibration
440484 max_calibrate (model , forward_loop , distributed_sync )
441485
486+ # Step 1a: Bootstrap any weight quantizer that didn't receive an _amax from
487+ # the forward-pass max calibration (typical of dead MoE experts in fused-
488+ # experts containers). Without this, the dead-expert per-expert quantizers
489+ # would be silently skipped by step 2's `hasattr(_amax)` gate, leaving the
490+ # export-time fallback to derive separate gate/up amaxes from each half of
491+ # the fused weight (breaking the gate==up weight_scale_2 invariant).
492+ _bootstrap_uncalibrated_weight_quantizers (model )
493+
442494 # Step 1b: Sync global_amax across sibling NVFP4-static weight quantizers
443495 # (q/k/v_proj under one attention block; gate/up — a.k.a. w1/w3 — under one
444496 # MLP block) so their FP8 scale-of-scales matches and the per-block FP8
445497 # round uses a consistent grid. No-op when there are no sibling groups
446498 # (e.g. fused QKV / fused gate_up_proj).
447499 sync_grouped_weight_global_amax (model )
448500
449- # Step 2: Replace calibrators with MseCalibrator for enabled quantizers
450- # and identify weight quantizers
451- weight_quantizers = []
452- seen_modules = set ()
501+ # Step 2: Replace calibrators with MseCalibrator for enabled quantizers.
502+ # (Weight-quantizer discovery + calibration happens in step 3 below using
503+ # iter_weights_for_calibration.)
453504
454505 for name , module in list (model .named_modules ()):
455506 if isinstance (module , TensorQuantizer ) and not module ._disabled :
@@ -506,52 +557,80 @@ def mse_calibrate(
506557 quant_func = partial (_mse_quant_func , quantizer = module ),
507558 )
508559
509- # Identify weight quantizers by checking if they have corresponding weight parameters
560+ # Step 3+4: discover and calibrate weight quantizers via
561+ # iter_weights_for_calibration, which yields (weight_or_slice, quantizer)
562+ # pairs. For non-fused QuantModules, this is one pair per weight (same as
563+ # the previous singular-only walk). For fused-experts containers
564+ # (transformers 5.x: gate_up_proj / down_proj as 3-D Parameters with per-
565+ # expert quantizer ModuleLists) it yields one pair per expert per
566+ # projection — so every per-expert weight quantizer gets MSE-calibrated,
567+ # not just the ones that received forward-pass tokens.
510568 name_to_module = dict (model .named_modules ())
569+ weight_calib_seen : set [int ] = set ()
570+
571+ # Pre-count for an accurate tqdm total (the same iter is cheap to repeat;
572+ # actually run-time work happens in the second pass).
573+ total_to_calib = 0
511574 for parent_module in name_to_module .values ():
512- if parent_module in seen_modules :
575+ if id ( parent_module ) in weight_calib_seen or not isinstance ( parent_module , QuantModule ) :
513576 continue
514- for weight_name in weight_attr_names (parent_module ):
515- weight_quantizer_name = quantizer_attr_names (weight_name ).weight_quantizer
516- weight_quantizer = getattr (parent_module , weight_quantizer_name , None )
517- if isinstance (weight_quantizer , TensorQuantizer ) and weight_quantizer .is_enabled :
518- if getattr (weight_quantizer , "_calibrator" , None ) is not None :
519- weight_quantizers .append ((parent_module , weight_name , weight_quantizer ))
520- seen_modules .add (parent_module )
521-
522- # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
523- # This prevents massive memory accumulation seen in large models
524- for idx , (parent_module , weight_name , weight_quantizer ) in enumerate (
525- tqdm (weight_quantizers , desc = "MSE weight calibration" )
526- ):
527- # Enable calibration mode for the weight quantizer
528- weight_quantizer .disable_quant ()
529- weight_quantizer .enable_calib ()
530- with enable_weight_access_and_writeback (parent_module , model , name_to_module ):
531- weight = getattr (parent_module , weight_name )
532- weight_quantizer (weight )
533-
534- # IMMEDIATELY compute amax and reset calibrator to free memory
535- cal = getattr (weight_quantizer , "_calibrator" , None )
536- if cal is not None and cal .compute_amax () is not None :
537- weight_quantizer .load_calib_amax ()
538-
539- weight_quantizer .enable_quant ()
540- weight_quantizer .disable_calib ()
541-
542- # Synchronize ALL CUDA devices before resetting to ensure all async operations complete
543- # This is critical for multi-GPU setups where tensors may be on different devices
544- if torch .cuda .is_available ():
545- for dev_id in range (torch .cuda .device_count ()):
546- torch .cuda .synchronize (torch .device (f"cuda:{ dev_id } " ))
547-
548- if cal is not None and hasattr (cal , "reset" ):
549- cal .reset ()
577+ try :
578+ pairs = list (parent_module .iter_weights_for_calibration ())
579+ except Exception :
580+ continue
581+ for _ , q in pairs :
582+ if (
583+ isinstance (q , TensorQuantizer )
584+ and q .is_enabled
585+ and getattr (q , "_calibrator" , None ) is not None
586+ ):
587+ total_to_calib += 1
550588
551- if (idx + 1 ) % 10 == 0 and torch .cuda .is_available ():
552- for dev_id in range (torch .cuda .device_count ()):
553- torch .cuda .synchronize (torch .device (f"cuda:{ dev_id } " ))
554- torch .cuda .empty_cache ()
589+ pbar = tqdm (total = total_to_calib , desc = "MSE weight calibration" )
590+ n_calibrated = 0
591+ for parent_module in name_to_module .values ():
592+ if id (parent_module ) in weight_calib_seen :
593+ continue
594+ weight_calib_seen .add (id (parent_module ))
595+ if not isinstance (parent_module , QuantModule ):
596+ continue
597+ with enable_weight_access_and_writeback (parent_module , model , name_to_module ):
598+ try :
599+ pairs = list (parent_module .iter_weights_for_calibration ())
600+ except Exception :
601+ pairs = []
602+ for weight , weight_quantizer in pairs :
603+ if not (
604+ isinstance (weight_quantizer , TensorQuantizer )
605+ and weight_quantizer .is_enabled
606+ and getattr (weight_quantizer , "_calibrator" , None ) is not None
607+ ):
608+ continue
609+ weight_quantizer .disable_quant ()
610+ weight_quantizer .enable_calib ()
611+ weight_quantizer (weight )
612+
613+ cal = weight_quantizer ._calibrator
614+ if cal .compute_amax () is not None :
615+ weight_quantizer .load_calib_amax ()
616+
617+ weight_quantizer .enable_quant ()
618+ weight_quantizer .disable_calib ()
619+
620+ if torch .cuda .is_available ():
621+ for dev_id in range (torch .cuda .device_count ()):
622+ torch .cuda .synchronize (torch .device (f"cuda:{ dev_id } " ))
623+
624+ if hasattr (cal , "reset" ):
625+ cal .reset ()
626+
627+ pbar .update (1 )
628+ n_calibrated += 1
629+ if n_calibrated % 10 == 0 and torch .cuda .is_available ():
630+ for dev_id in range (torch .cuda .device_count ()):
631+ torch .cuda .synchronize (torch .device (f"cuda:{ dev_id } " ))
632+ torch .cuda .empty_cache ()
633+ pbar .close ()
555634
556635 if torch .cuda .is_available ():
557636 for dev_id in range (torch .cuda .device_count ()):
0 commit comments