Skip to content

Commit 8e21516

Browse files
committed
[Quantization] MSE-calibrate every per-expert weight in fused-experts MoE
Two-part fix for transformers 5.x fused-experts containers (Qwen3-MoE / Qwen3.5-MoE / Mixtral / DeepSeek / Kimi-K2.x ...) where weight quantizers live in `nn.ModuleList`s (`gate_up_proj_weight_quantizers`, `down_proj_weight_quantizers`):

1. Add `_QuantFusedExperts.iter_weights_for_calibration` that yields per-expert (weight_slice, quantizer) pairs for both projections. The base impl uses singular `*_weight_quantizer` and silently skips fused-experts modules, so weight-only calibration paths never reach per-expert quantizers.

2. Refactor `mse_calibrate`:
   - Add `_bootstrap_uncalibrated_weight_quantizers` after `max_calibrate` to populate `_amax` on quantizers the forward pass didn't reach (dead MoE experts that received no calibration tokens). Runs the existing calibrator on the weight slice surfaced by `iter_weights_for_calibration`.
   - Replace the singular-only `weight_attr_names` discovery + `getattr`-by-name walk with an `iter_weights_for_calibration` walk done inside each parent module's `enable_weight_access_and_writeback` context, so MSE processes every per-expert quantizer (active and dead) and remains FSDP-safe.

Without this, the export-time fallback in `_export_fused_experts` derived separate gate/up amaxes from each half of the fused weight, breaking the gate==up `weight_scale_2` invariant on dead experts.

End-to-end check on Qwen3.5-122B-A10B with `nvfp4_experts_only_mse-fp8_cast_kv`:
- Before: 1/12288 (layer 38 expert 69) gate != up; 0 weights MSE-calibrated
- After: 0/12288 mismatches; 24576 weights MSE-calibrated; ~4.2 min

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
1 parent 3587238 commit 8e21516

2 files changed

Lines changed: 155 additions & 52 deletions

File tree

modelopt/torch/quantization/model_calib.py

Lines changed: 134 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
promote_nvfp4_static_quantizers,
5353
quantizer_attr_names,
5454
reduce_amax,
55-
weight_attr_names,
5655
)
5756
from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper
5857

@@ -84,8 +83,9 @@
8483

8584

8685
def _is_calibrated_nvfp4_static_weight_quantizer(q) -> bool:
87-
"""True for an NVFP4-static weight quantizer that ``max_calibrate`` already
88-
populated with a per-block ``_amax`` and that is currently enabled.
86+
"""Check whether ``q`` is an enabled, calibrated NVFP4-static weight quantizer.
87+
88+
True when ``max_calibrate`` already populated a per-block ``_amax``.
8989
"""
9090
return (
9191
isinstance(q, TensorQuantizer)
@@ -97,9 +97,9 @@ def _is_calibrated_nvfp4_static_weight_quantizer(q) -> bool:
9797

9898

9999
def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]:
100-
"""Find groups of Linear-like submodules whose NVFP4-static weight quantizers
101-
should share ``global_amax`` (Q/K/V under one attention parent; gate/up under
102-
one MLP parent).
100+
"""Find Linear-like submodule groups whose NVFP4-static weight quantizers should share global_amax.
101+
102+
Groups are Q/K/V under one attention parent and gate/up under one MLP parent.
103103
"""
104104
groups: list[list[nn.Module]] = []
105105
wq_attr = quantizer_attr_names("weight").weight_quantizer
@@ -118,6 +118,50 @@ def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]:
118118
return groups
119119

120120

121+
@torch.no_grad()
def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int:
    """Populate ``_amax`` from the weights for weight quantizers that still lack one.

    Walks every :class:`QuantModule` in ``model`` and, for each
    ``(weight, quantizer)`` pair yielded by its
    :meth:`iter_weights_for_calibration`, runs the quantizer's already-attached
    calibrator on that weight when the quantizer has no usable ``_amax``
    (attribute missing, ``None``, or all zeros). This is intended for
    quantizers the forward-pass calibration never exercised, e.g. per-expert
    weight quantizers of MoE experts that received no calibration tokens.

    Quantizers that are not ``TensorQuantizer`` instances, are disabled, are
    dynamic, or have no calibrator are left untouched.

    Args:
        model: Model whose ``QuantModule`` submodules are inspected.

    Returns:
        The number of quantizers whose ``_amax`` was actually loaded from the
        calibrator (mostly useful for diagnostics).
    """
    # Function-scope import so this block does not depend on the file-level
    # import list.
    import warnings

    name_to_module = dict(model.named_modules())
    n_bootstrapped = 0
    for module in name_to_module.values():
        if not isinstance(module, QuantModule):
            continue
        # Read the weights inside the same access/writeback context that
        # ``mse_calibrate`` uses for its weight walk, so both paths see the
        # weights the same way.
        with enable_weight_access_and_writeback(module, model, name_to_module):
            try:
                pairs = list(module.iter_weights_for_calibration())
            except Exception as err:
                # Still best-effort (one failing module must not abort the
                # whole bootstrap), but no longer silent.
                warnings.warn(
                    f"Skipping weight-quantizer bootstrap for {type(module).__name__}: "
                    f"iter_weights_for_calibration() failed with {err!r}"
                )
                continue
            for weight, q in pairs:
                if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
                    continue
                if q._calibrator is None:
                    continue
                amax = getattr(q, "_amax", None)
                if amax is not None and not torch.all(amax == 0):
                    # Already calibrated by the forward pass.
                    continue
                q.disable_quant()
                q.enable_calib()
                try:
                    q(weight)
                    if q._calibrator.compute_amax() is not None:
                        q.load_calib_amax()
                        # Count only quantizers that really received an amax.
                        n_bootstrapped += 1
                finally:
                    # Always return the quantizer to quantize mode and drop
                    # collected statistics, even if the calibration pass raised.
                    q.enable_quant()
                    q.disable_calib()
                    if hasattr(q._calibrator, "reset"):
                        q._calibrator.reset()
    return n_bootstrapped
163+
164+
121165
@torch.no_grad()
122166
def sync_grouped_weight_global_amax(model: nn.Module) -> int:
123167
"""Sync ``global_amax`` across sibling NVFP4-static weight quantizers.
@@ -138,6 +182,9 @@ def sync_grouped_weight_global_amax(model: nn.Module) -> int:
138182
Must be called after ``max_calibrate`` has populated each weight
139183
quantizer's ``_amax``. Idempotent. Returns the number of groups synced.
140184
"""
185+
# Inline import: `modelopt.torch.export.quant_utils` imports
186+
# `enable_stats_collection`/`finish_stats_collection`/`svd` from this module,
187+
# so a top-level import here would create a circular import at startup.
141188
from modelopt.torch.export.quant_utils import preprocess_linear_fusion
142189

143190
n_groups = 0
@@ -439,17 +486,24 @@ def mse_calibrate(
439486
# Step 1: First get initial amax using max calibration
440487
max_calibrate(model, forward_loop, distributed_sync)
441488

489+
# Step 1a: Bootstrap any weight quantizer that didn't receive an _amax from
490+
# the forward-pass max calibration (typical of dead MoE experts in fused-
491+
# experts containers). Without this, the dead-expert per-expert quantizers
492+
# would be silently skipped by step 2's `hasattr(_amax)` gate, leaving the
493+
# export-time fallback to derive separate gate/up amaxes from each half of
494+
# the fused weight (breaking the gate==up weight_scale_2 invariant).
495+
_bootstrap_uncalibrated_weight_quantizers(model)
496+
442497
# Step 1b: Sync global_amax across sibling NVFP4-static weight quantizers
443498
# (q/k/v_proj under one attention block; gate/up — a.k.a. w1/w3 — under one
444499
# MLP block) so their FP8 scale-of-scales matches and the per-block FP8
445500
# round uses a consistent grid. No-op when there are no sibling groups
446501
# (e.g. fused QKV / fused gate_up_proj).
447502
sync_grouped_weight_global_amax(model)
448503

449-
# Step 2: Replace calibrators with MseCalibrator for enabled quantizers
450-
# and identify weight quantizers
451-
weight_quantizers = []
452-
seen_modules = set()
504+
# Step 2: Replace calibrators with MseCalibrator for enabled quantizers.
505+
# (Weight-quantizer discovery + calibration happens in step 3 below using
506+
# iter_weights_for_calibration.)
453507

454508
for name, module in list(model.named_modules()):
455509
if isinstance(module, TensorQuantizer) and not module._disabled:
@@ -506,52 +560,80 @@ def mse_calibrate(
506560
quant_func=partial(_mse_quant_func, quantizer=module),
507561
)
508562

509-
# Identify weight quantizers by checking if they have corresponding weight parameters
563+
# Step 3+4: discover and calibrate weight quantizers via
564+
# iter_weights_for_calibration, which yields (weight_or_slice, quantizer)
565+
# pairs. For non-fused QuantModules, this is one pair per weight (same as
566+
# the previous singular-only walk). For fused-experts containers
567+
# (transformers 5.x: gate_up_proj / down_proj as 3-D Parameters with per-
568+
# expert quantizer ModuleLists) it yields one pair per expert per
569+
# projection — so every per-expert weight quantizer gets MSE-calibrated,
570+
# not just the ones that received forward-pass tokens.
510571
name_to_module = dict(model.named_modules())
572+
weight_calib_seen: set[int] = set()
573+
574+
# Pre-count for an accurate tqdm total (the same iter is cheap to repeat;
575+
# the actual run-time work happens in the second pass).
576+
total_to_calib = 0
511577
for parent_module in name_to_module.values():
512-
if parent_module in seen_modules:
578+
if id(parent_module) in weight_calib_seen or not isinstance(parent_module, QuantModule):
513579
continue
514-
for weight_name in weight_attr_names(parent_module):
515-
weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer
516-
weight_quantizer = getattr(parent_module, weight_quantizer_name, None)
517-
if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled:
518-
if getattr(weight_quantizer, "_calibrator", None) is not None:
519-
weight_quantizers.append((parent_module, weight_name, weight_quantizer))
520-
seen_modules.add(parent_module)
521-
522-
# Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
523-
# This prevents massive memory accumulation seen in large models
524-
for idx, (parent_module, weight_name, weight_quantizer) in enumerate(
525-
tqdm(weight_quantizers, desc="MSE weight calibration")
526-
):
527-
# Enable calibration mode for the weight quantizer
528-
weight_quantizer.disable_quant()
529-
weight_quantizer.enable_calib()
530-
with enable_weight_access_and_writeback(parent_module, model, name_to_module):
531-
weight = getattr(parent_module, weight_name)
532-
weight_quantizer(weight)
533-
534-
# IMMEDIATELY compute amax and reset calibrator to free memory
535-
cal = getattr(weight_quantizer, "_calibrator", None)
536-
if cal is not None and cal.compute_amax() is not None:
537-
weight_quantizer.load_calib_amax()
538-
539-
weight_quantizer.enable_quant()
540-
weight_quantizer.disable_calib()
541-
542-
# Synchronize ALL CUDA devices before resetting to ensure all async operations complete
543-
# This is critical for multi-GPU setups where tensors may be on different devices
544-
if torch.cuda.is_available():
545-
for dev_id in range(torch.cuda.device_count()):
546-
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
547-
548-
if cal is not None and hasattr(cal, "reset"):
549-
cal.reset()
580+
try:
581+
pairs = list(parent_module.iter_weights_for_calibration())
582+
except Exception:
583+
continue
584+
for _, q in pairs:
585+
if (
586+
isinstance(q, TensorQuantizer)
587+
and q.is_enabled
588+
and getattr(q, "_calibrator", None) is not None
589+
):
590+
total_to_calib += 1
550591

551-
if (idx + 1) % 10 == 0 and torch.cuda.is_available():
552-
for dev_id in range(torch.cuda.device_count()):
553-
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
554-
torch.cuda.empty_cache()
592+
pbar = tqdm(total=total_to_calib, desc="MSE weight calibration")
593+
n_calibrated = 0
594+
for parent_module in name_to_module.values():
595+
if id(parent_module) in weight_calib_seen:
596+
continue
597+
weight_calib_seen.add(id(parent_module))
598+
if not isinstance(parent_module, QuantModule):
599+
continue
600+
with enable_weight_access_and_writeback(parent_module, model, name_to_module):
601+
try:
602+
pairs = list(parent_module.iter_weights_for_calibration())
603+
except Exception:
604+
pairs = []
605+
for weight, weight_quantizer in pairs:
606+
if not (
607+
isinstance(weight_quantizer, TensorQuantizer)
608+
and weight_quantizer.is_enabled
609+
and getattr(weight_quantizer, "_calibrator", None) is not None
610+
):
611+
continue
612+
weight_quantizer.disable_quant()
613+
weight_quantizer.enable_calib()
614+
weight_quantizer(weight)
615+
616+
cal = weight_quantizer._calibrator
617+
if cal.compute_amax() is not None:
618+
weight_quantizer.load_calib_amax()
619+
620+
weight_quantizer.enable_quant()
621+
weight_quantizer.disable_calib()
622+
623+
if torch.cuda.is_available():
624+
for dev_id in range(torch.cuda.device_count()):
625+
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
626+
627+
if hasattr(cal, "reset"):
628+
cal.reset()
629+
630+
pbar.update(1)
631+
n_calibrated += 1
632+
if n_calibrated % 10 == 0 and torch.cuda.is_available():
633+
for dev_id in range(torch.cuda.device_count()):
634+
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
635+
torch.cuda.empty_cache()
636+
pbar.close()
555637

556638
if torch.cuda.is_available():
557639
for dev_id in range(torch.cuda.device_count()):

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,27 @@ def forward(self, *args, **kwargs):
900900
self._down_proj_linear = False
901901
return super().forward(*args, **kwargs)
902902

903+
def iter_weights_for_calibration(self):
    """Yield a ``(weight[idx], quantizer)`` pair for each per-expert weight quantizer.

    For each of the two fused projections (``gate_up_proj`` first, then
    ``down_proj``) the fused weight is indexed along its first dimension,
    once per entry of the matching ``*_weight_quantizers`` container, in
    container order. A projection is skipped entirely when either its weight
    or its quantizer container is missing or ``None``.

    This is meant to replace the default
    :meth:`QuantModule.iter_weights_for_calibration` for fused-experts
    modules, so that weight-only calibration paths can reach the per-expert
    quantizers.
    """
    projection_specs = (
        ("gate_up_proj", "gate_up_proj_weight_quantizers"),
        ("down_proj", "down_proj_weight_quantizers"),
    )
    for fused_attr, quantizer_list_attr in projection_specs:
        fused_weight = getattr(self, fused_attr, None)
        expert_quantizers = getattr(self, quantizer_list_attr, None)
        if fused_weight is not None and expert_quantizers is not None:
            # One slice of the fused weight per expert quantizer.
            for expert_idx, expert_quantizer in enumerate(expert_quantizers):
                yield fused_weight[expert_idx], expert_quantizer
923+
903924
def fold_weight(self, keep_attrs: bool = False):
904925
"""Fold per-expert weight quantizers into the fused 3-D weights.
905926

0 commit comments

Comments
 (0)