     promote_nvfp4_static_quantizers,
     quantizer_attr_names,
     reduce_amax,
-    weight_attr_names,
 )
 from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper
 
@@ -84,8 +83,9 @@
 
 
 def _is_calibrated_nvfp4_static_weight_quantizer(q) -> bool:
-    """True for an NVFP4-static weight quantizer that ``max_calibrate`` already
-    populated with a per-block ``_amax`` and that is currently enabled.
+    """Check whether ``q`` is an enabled, calibrated NVFP4-static weight quantizer.
+
+    True when ``max_calibrate`` has already populated a per-block ``_amax``.
     """
     return (
         isinstance(q, TensorQuantizer)
@@ -97,9 +97,9 @@ def _is_calibrated_nvfp4_static_weight_quantizer(q) -> bool:
 
 
 def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]:
-    """Find groups of Linear-like submodules whose NVFP4-static weight quantizers
-    should share ``global_amax`` (Q/K/V under one attention parent; gate/up under
-    one MLP parent).
+    """Find Linear-like submodule groups whose NVFP4-static weight quantizers should share ``global_amax``.
+
+    Groups are Q/K/V under one attention parent and gate/up under one MLP parent.
     """
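+    # Illustrative example of the grouping (module names here are typical
+    # Llama-style names, assumed for illustration, not taken from this change):
+    #   [layers.N.self_attn.q_proj, .k_proj, .v_proj]
+    #   [layers.N.mlp.gate_proj, .up_proj]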
     groups: list[list[nn.Module]] = []
     wq_attr = quantizer_attr_names("weight").weight_quantizer
@@ -118,6 +118,50 @@ def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]:
     return groups
 
 
+@torch.no_grad()
+def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int:
+    """Run a max-style amax collection on weight quantizers whose ``_amax`` is missing.
+
+    Forward-pass max calibration only populates per-expert weight quantizers in
+    fused-experts containers when tokens are routed to that expert. "Dead"
+    experts that received no tokens end up with no ``_amax``, which causes
+    ``mse_calibrate``'s subsequent walk to skip them and forces the export-time
+    fallback to derive separate per-half amax values for gate/up. This helper
+    walks every QuantModule's :meth:`iter_weights_for_calibration` pairs and,
+    for any quantizer that lacks ``_amax``, runs the existing calibrator
+    (typically :class:`MaxCalibrator`) on the corresponding weight slice,
+    populating ``_amax`` from the weight rather than from runtime activations.
+
+    Returns the number of quantizers bootstrapped (mostly for diagnostics).
+    """
+    n = 0
+    for module in model.modules():
+        if not isinstance(module, QuantModule):
+            continue
+        try:
+            pairs = list(module.iter_weights_for_calibration())
+        except Exception:
+            # Defensive: skip modules whose iterator is unavailable or fails.
+            continue
+        for weight, q in pairs:
+            if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
+                continue
+            if q._calibrator is None:
+                continue
+            # An existing, non-zero _amax means max calibration already saw
+            # this quantizer; leave it alone.
+            if hasattr(q, "_amax") and q._amax is not None and not torch.all(q._amax == 0):
+                continue
+            q.disable_quant()
+            q.enable_calib()
+            q(weight)
+            if q._calibrator.compute_amax() is not None:
+                q.load_calib_amax()
+            q.enable_quant()
+            q.disable_calib()
+            if hasattr(q._calibrator, "reset"):
+                q._calibrator.reset()
+            n += 1
+    return n
+
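+# Illustrative only: the ``iter_weights_for_calibration`` contract the helper
+# above and ``mse_calibrate`` below rely on. A plain Linear yields one
+# (weight, quantizer) pair; a fused-experts container yields one pair per
+# expert per projection. This sketch is an assumption about the interface,
+# not code from this change:
+#
+#     def iter_weights_for_calibration(self):
+#         if self.weight.dim() == 3:  # fused experts: (num_experts, out, in)
+#             for expert_weight, q in zip(self.weight, self.weight_quantizer):
+#                 yield expert_weight, q
+#         else:
+#             yield self.weight, self.weight_quantizer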
+
 @torch.no_grad()
 def sync_grouped_weight_global_amax(model: nn.Module) -> int:
     """Sync ``global_amax`` across sibling NVFP4-static weight quantizers.
@@ -138,6 +182,9 @@ def sync_grouped_weight_global_amax(model: nn.Module) -> int:
     Must be called after ``max_calibrate`` has populated each weight
     quantizer's ``_amax``. Idempotent. Returns the number of groups synced.
     """
+    # Inline import: `modelopt.torch.export.quant_utils` imports
+    # `enable_stats_collection`/`finish_stats_collection`/`svd` from this module,
+    # so a top-level import here would fail on the circular import at startup.
     from modelopt.torch.export.quant_utils import preprocess_linear_fusion
 
     n_groups = 0
@@ -439,17 +486,24 @@ def mse_calibrate(
     # Step 1: First get initial amax using max calibration
     max_calibrate(model, forward_loop, distributed_sync)
 
+    # Step 1a: Bootstrap any weight quantizer that did not receive an _amax
+    # from forward-pass max calibration (typical of dead MoE experts in
+    # fused-experts containers). Without this, the dead experts' per-expert
+    # quantizers would be silently skipped by step 2's `hasattr(_amax)` gate,
+    # leaving the export-time fallback to derive separate gate/up amax values
+    # from each half of the fused weight (breaking the gate==up weight_scale_2
+    # invariant).
+    _bootstrap_uncalibrated_weight_quantizers(model)
+
     # Step 1b: Sync global_amax across sibling NVFP4-static weight quantizers
     # (q/k/v_proj under one attention block; gate/up — a.k.a. w1/w3 — under one
     # MLP block) so their FP8 scale-of-scales matches and the per-block FP8
     # round uses a consistent grid. No-op when there are no sibling groups
     # (e.g. fused QKV / fused gate_up_proj).
     sync_grouped_weight_global_amax(model)
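+    # For reference, the two-level NVFP4 scaling this keeps consistent
+    # (simplified sketch; 448 is the FP8-E4M3 max, 6 the FP4-E2M1 max, and
+    # the names below are illustrative):
+    #
+    #     weight_scale_2 = global_amax / (448 * 6)   # FP8 scale-of-scales
+    #     block_scale = fp8_e4m3(block_amax / 6 / weight_scale_2)
+    #
+    # Sharing global_amax across gate/up makes weight_scale_2 identical for
+    # both halves, which a fused gate_up_proj export requires.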
 
-    # Step 2: Replace calibrators with MseCalibrator for enabled quantizers
-    # and identify weight quantizers
-    weight_quantizers = []
-    seen_modules = set()
+    # Step 2: Replace calibrators with MseCalibrator for enabled quantizers.
+    # (Weight-quantizer discovery and calibration happen in steps 3 and 4
+    # below, via iter_weights_for_calibration.)
 
     for name, module in list(model.named_modules()):
         if isinstance(module, TensorQuantizer) and not module._disabled:
@@ -506,52 +560,80 @@ def mse_calibrate(
                 quant_func=partial(_mse_quant_func, quantizer=module),
             )
 
-    # Identify weight quantizers by checking if they have corresponding weight parameters
+    # Steps 3 and 4: discover and calibrate weight quantizers via
+    # iter_weights_for_calibration, which yields (weight_or_slice, quantizer)
+    # pairs. For non-fused QuantModules this is one pair per weight (same as
+    # the previous singular-only walk). For fused-experts containers
+    # (transformers 5.x: gate_up_proj / down_proj as 3-D Parameters with
+    # per-expert quantizer ModuleLists) it yields one pair per expert per
+    # projection, so every per-expert weight quantizer gets MSE-calibrated,
+    # not just the ones that received forward-pass tokens.
     name_to_module = dict(model.named_modules())
+    weight_calib_seen: set[int] = set()
+
+    # Pre-count for an accurate tqdm total (the iteration is cheap to repeat;
+    # the actual runtime work happens in the second pass below). No seen-set
+    # is needed here: dict(model.named_modules()) already deduplicates.
+    total_to_calib = 0
     for parent_module in name_to_module.values():
-        if parent_module in seen_modules:
+        if not isinstance(parent_module, QuantModule):
             continue
-        for weight_name in weight_attr_names(parent_module):
-            weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer
-            weight_quantizer = getattr(parent_module, weight_quantizer_name, None)
-            if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled:
-                if getattr(weight_quantizer, "_calibrator", None) is not None:
-                    weight_quantizers.append((parent_module, weight_name, weight_quantizer))
-        seen_modules.add(parent_module)
-
-    # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
-    # This prevents massive memory accumulation seen in large models
-    for idx, (parent_module, weight_name, weight_quantizer) in enumerate(
-        tqdm(weight_quantizers, desc="MSE weight calibration")
-    ):
-        # Enable calibration mode for the weight quantizer
-        weight_quantizer.disable_quant()
-        weight_quantizer.enable_calib()
-        with enable_weight_access_and_writeback(parent_module, model, name_to_module):
-            weight = getattr(parent_module, weight_name)
-            weight_quantizer(weight)
-
-        # IMMEDIATELY compute amax and reset calibrator to free memory
-        cal = getattr(weight_quantizer, "_calibrator", None)
-        if cal is not None and cal.compute_amax() is not None:
-            weight_quantizer.load_calib_amax()
-
-        weight_quantizer.enable_quant()
-        weight_quantizer.disable_calib()
-
-        # Synchronize ALL CUDA devices before resetting to ensure all async operations complete
-        # This is critical for multi-GPU setups where tensors may be on different devices
-        if torch.cuda.is_available():
-            for dev_id in range(torch.cuda.device_count()):
-                torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
-
-        if cal is not None and hasattr(cal, "reset"):
-            cal.reset()
+        try:
+            pairs = list(parent_module.iter_weights_for_calibration())
+        except Exception:
+            continue
+        for _, q in pairs:
+            if (
+                isinstance(q, TensorQuantizer)
+                and q.is_enabled
+                and getattr(q, "_calibrator", None) is not None
+            ):
+                total_to_calib += 1
 
-        if (idx + 1) % 10 == 0 and torch.cuda.is_available():
-            for dev_id in range(torch.cuda.device_count()):
-                torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
-            torch.cuda.empty_cache()
+    pbar = tqdm(total=total_to_calib, desc="MSE weight calibration")
+    n_calibrated = 0
+    for parent_module in name_to_module.values():
+        if id(parent_module) in weight_calib_seen:
+            continue
+        weight_calib_seen.add(id(parent_module))
+        if not isinstance(parent_module, QuantModule):
+            continue
+        with enable_weight_access_and_writeback(parent_module, model, name_to_module):
+            try:
+                pairs = list(parent_module.iter_weights_for_calibration())
+            except Exception:
+                pairs = []
+            for weight, weight_quantizer in pairs:
+                if not (
+                    isinstance(weight_quantizer, TensorQuantizer)
+                    and weight_quantizer.is_enabled
+                    and getattr(weight_quantizer, "_calibrator", None) is not None
+                ):
+                    continue
+                weight_quantizer.disable_quant()
+                weight_quantizer.enable_calib()
+                weight_quantizer(weight)
+
+                # Compute amax immediately and reset the calibrator afterwards
+                # to avoid accumulating per-weight MSE state in large models.
+                cal = weight_quantizer._calibrator
+                if cal.compute_amax() is not None:
+                    weight_quantizer.load_calib_amax()
+
+                weight_quantizer.enable_quant()
+                weight_quantizer.disable_calib()
+
+                # Synchronize all CUDA devices before resetting so async ops
+                # that may still touch calibrator tensors have completed.
+                if torch.cuda.is_available():
+                    for dev_id in range(torch.cuda.device_count()):
+                        torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
+
+                if hasattr(cal, "reset"):
+                    cal.reset()
+
+                pbar.update(1)
+                n_calibrated += 1
+                if n_calibrated % 10 == 0 and torch.cuda.is_available():
+                    for dev_id in range(torch.cuda.device_count()):
+                        torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
+                    torch.cuda.empty_cache()
+    pbar.close()
 
     if torch.cuda.is_available():
         for dev_id in range(torch.cuda.device_count()):