Skip to content

Commit 360b53e

Browse files
committed
[Quantization] MSE-calibrate every per-expert weight in fused-experts MoE
Two-part fix for transformers 5.x fused-experts containers (Qwen3-MoE / Qwen3.5-MoE / Mixtral / DeepSeek / Kimi-K2.x ...) where weight quantizers live in `nn.ModuleList`s (`gate_up_proj_weight_quantizers`, `down_proj_weight_quantizers`): 1. Add `_QuantFusedExperts.iter_weights_for_calibration` that yields per-expert (weight_slice, quantizer) pairs for both projections. The base impl uses singular `*_weight_quantizer` and silently skips fused-experts modules, so weight-only calibration paths never reach per-expert quantizers. 2. Refactor `mse_calibrate`: - Add `_bootstrap_uncalibrated_weight_quantizers` after `max_calibrate` to populate `_amax` on quantizers the forward pass didn't reach (dead MoE experts that received no calibration tokens). Runs the existing calibrator on the weight slice surfaced by `iter_weights_for_calibration`. - Replace the singular-only `weight_attr_names` discovery + `getattr`-by- name walk with an `iter_weights_for_calibration` walk done inside each parent module's `enable_weight_access_and_writeback` context, so MSE processes every per-expert quantizer (active and dead) and remains FSDP-safe. Without this, the export-time fallback in `_export_fused_experts` derived separate gate/up amaxes from each half of the fused weight, breaking the gate==up `weight_scale_2` invariant on dead experts. End-to-end check on Qwen3.5-122B-A10B with `nvfp4_experts_only_mse-fp8_cast_kv`: - Before: 1/12288 (layer 38 expert 69) gate \!= up; 0 weights MSE-calibrated - After: 0/12288 mismatches; 24576 weights MSE-calibrated; ~4.2 min Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
1 parent 3587238 commit 360b53e

2 files changed

Lines changed: 152 additions & 52 deletions

File tree

modelopt/torch/quantization/model_calib.py

Lines changed: 131 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
promote_nvfp4_static_quantizers,
5353
quantizer_attr_names,
5454
reduce_amax,
55-
weight_attr_names,
5655
)
5756
from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper
5857

@@ -84,8 +83,9 @@
8483

8584

8685
def _is_calibrated_nvfp4_static_weight_quantizer(q) -> bool:
87-
"""True for an NVFP4-static weight quantizer that ``max_calibrate`` already
88-
populated with a per-block ``_amax`` and that is currently enabled.
86+
"""Check whether ``q`` is an enabled, calibrated NVFP4-static weight quantizer.
87+
88+
True when ``max_calibrate`` already populated a per-block ``_amax``.
8989
"""
9090
return (
9191
isinstance(q, TensorQuantizer)
@@ -97,9 +97,9 @@ def _is_calibrated_nvfp4_static_weight_quantizer(q) -> bool:
9797

9898

9999
def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]:
100-
"""Find groups of Linear-like submodules whose NVFP4-static weight quantizers
101-
should share ``global_amax`` (Q/K/V under one attention parent; gate/up under
102-
one MLP parent).
100+
"""Find Linear-like submodule groups whose NVFP4-static weight quantizers should share global_amax.
101+
102+
Groups are Q/K/V under one attention parent and gate/up under one MLP parent.
103103
"""
104104
groups: list[list[nn.Module]] = []
105105
wq_attr = quantizer_attr_names("weight").weight_quantizer
@@ -118,6 +118,50 @@ def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]:
118118
return groups
119119

120120

121+
@torch.no_grad()
122+
def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int:
123+
"""Run a max-style amax collection on weight quantizers whose ``_amax`` is missing.
124+
125+
Forward-pass max calibration only populates per-expert weight quantizers in
126+
fused-experts containers when tokens are routed to that expert. "Dead"
127+
experts that received no tokens end up with no ``_amax``, which causes
128+
``mse_calibrate``'s subsequent walk to skip them and forces the export-time
129+
fallback to derive separate per-half amax for gate/up. This helper walks
130+
every QuantModule's :meth:`iter_weights_for_calibration` pairs and, for any
131+
quantizer that lacks ``_amax``, runs the existing calibrator (typically
132+
:class:`MaxCalibrator`) on the corresponding weight slice — populating
133+
``_amax`` from the weight rather than from runtime activations.
134+
135+
Returns the number of quantizers bootstrapped (mostly for diagnostics).
136+
"""
137+
n = 0
138+
for module in model.modules():
139+
if not isinstance(module, QuantModule):
140+
continue
141+
try:
142+
pairs = list(module.iter_weights_for_calibration())
143+
except Exception:
144+
continue
145+
for weight, q in pairs:
146+
if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
147+
continue
148+
if q._calibrator is None:
149+
continue
150+
if hasattr(q, "_amax") and q._amax is not None and not torch.all(q._amax == 0):
151+
continue
152+
q.disable_quant()
153+
q.enable_calib()
154+
q(weight)
155+
if q._calibrator.compute_amax() is not None:
156+
q.load_calib_amax()
157+
q.enable_quant()
158+
q.disable_calib()
159+
if hasattr(q._calibrator, "reset"):
160+
q._calibrator.reset()
161+
n += 1
162+
return n
163+
164+
121165
@torch.no_grad()
122166
def sync_grouped_weight_global_amax(model: nn.Module) -> int:
123167
"""Sync ``global_amax`` across sibling NVFP4-static weight quantizers.
@@ -439,17 +483,24 @@ def mse_calibrate(
439483
# Step 1: First get initial amax using max calibration
440484
max_calibrate(model, forward_loop, distributed_sync)
441485

486+
# Step 1a: Bootstrap any weight quantizer that didn't receive an _amax from
487+
# the forward-pass max calibration (typical of dead MoE experts in fused-
488+
# experts containers). Without this, the dead-expert per-expert quantizers
489+
# would be silently skipped by step 2's `hasattr(_amax)` gate, leaving the
490+
# export-time fallback to derive separate gate/up amaxes from each half of
491+
# the fused weight (breaking the gate==up weight_scale_2 invariant).
492+
_bootstrap_uncalibrated_weight_quantizers(model)
493+
442494
# Step 1b: Sync global_amax across sibling NVFP4-static weight quantizers
443495
# (q/k/v_proj under one attention block; gate/up — a.k.a. w1/w3 — under one
444496
# MLP block) so their FP8 scale-of-scales matches and the per-block FP8
445497
# round uses a consistent grid. No-op when there are no sibling groups
446498
# (e.g. fused QKV / fused gate_up_proj).
447499
sync_grouped_weight_global_amax(model)
448500

449-
# Step 2: Replace calibrators with MseCalibrator for enabled quantizers
450-
# and identify weight quantizers
451-
weight_quantizers = []
452-
seen_modules = set()
501+
# Step 2: Replace calibrators with MseCalibrator for enabled quantizers.
502+
# (Weight-quantizer discovery + calibration happens in step 3 below using
503+
# iter_weights_for_calibration.)
453504

454505
for name, module in list(model.named_modules()):
455506
if isinstance(module, TensorQuantizer) and not module._disabled:
@@ -506,52 +557,80 @@ def mse_calibrate(
506557
quant_func=partial(_mse_quant_func, quantizer=module),
507558
)
508559

509-
# Identify weight quantizers by checking if they have corresponding weight parameters
560+
# Step 3+4: discover and calibrate weight quantizers via
561+
# iter_weights_for_calibration, which yields (weight_or_slice, quantizer)
562+
# pairs. For non-fused QuantModules, this is one pair per weight (same as
563+
# the previous singular-only walk). For fused-experts containers
564+
# (transformers 5.x: gate_up_proj / down_proj as 3-D Parameters with per-
565+
# expert quantizer ModuleLists) it yields one pair per expert per
566+
# projection — so every per-expert weight quantizer gets MSE-calibrated,
567+
# not just the ones that received forward-pass tokens.
510568
name_to_module = dict(model.named_modules())
569+
weight_calib_seen: set[int] = set()
570+
571+
# Pre-count for an accurate tqdm total (the same iter is cheap to repeat;
572+
# actually run-time work happens in the second pass).
573+
total_to_calib = 0
511574
for parent_module in name_to_module.values():
512-
if parent_module in seen_modules:
575+
if id(parent_module) in weight_calib_seen or not isinstance(parent_module, QuantModule):
513576
continue
514-
for weight_name in weight_attr_names(parent_module):
515-
weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer
516-
weight_quantizer = getattr(parent_module, weight_quantizer_name, None)
517-
if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled:
518-
if getattr(weight_quantizer, "_calibrator", None) is not None:
519-
weight_quantizers.append((parent_module, weight_name, weight_quantizer))
520-
seen_modules.add(parent_module)
521-
522-
# Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
523-
# This prevents massive memory accumulation seen in large models
524-
for idx, (parent_module, weight_name, weight_quantizer) in enumerate(
525-
tqdm(weight_quantizers, desc="MSE weight calibration")
526-
):
527-
# Enable calibration mode for the weight quantizer
528-
weight_quantizer.disable_quant()
529-
weight_quantizer.enable_calib()
530-
with enable_weight_access_and_writeback(parent_module, model, name_to_module):
531-
weight = getattr(parent_module, weight_name)
532-
weight_quantizer(weight)
533-
534-
# IMMEDIATELY compute amax and reset calibrator to free memory
535-
cal = getattr(weight_quantizer, "_calibrator", None)
536-
if cal is not None and cal.compute_amax() is not None:
537-
weight_quantizer.load_calib_amax()
538-
539-
weight_quantizer.enable_quant()
540-
weight_quantizer.disable_calib()
541-
542-
# Synchronize ALL CUDA devices before resetting to ensure all async operations complete
543-
# This is critical for multi-GPU setups where tensors may be on different devices
544-
if torch.cuda.is_available():
545-
for dev_id in range(torch.cuda.device_count()):
546-
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
547-
548-
if cal is not None and hasattr(cal, "reset"):
549-
cal.reset()
577+
try:
578+
pairs = list(parent_module.iter_weights_for_calibration())
579+
except Exception:
580+
continue
581+
for _, q in pairs:
582+
if (
583+
isinstance(q, TensorQuantizer)
584+
and q.is_enabled
585+
and getattr(q, "_calibrator", None) is not None
586+
):
587+
total_to_calib += 1
550588

551-
if (idx + 1) % 10 == 0 and torch.cuda.is_available():
552-
for dev_id in range(torch.cuda.device_count()):
553-
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
554-
torch.cuda.empty_cache()
589+
pbar = tqdm(total=total_to_calib, desc="MSE weight calibration")
590+
n_calibrated = 0
591+
for parent_module in name_to_module.values():
592+
if id(parent_module) in weight_calib_seen:
593+
continue
594+
weight_calib_seen.add(id(parent_module))
595+
if not isinstance(parent_module, QuantModule):
596+
continue
597+
with enable_weight_access_and_writeback(parent_module, model, name_to_module):
598+
try:
599+
pairs = list(parent_module.iter_weights_for_calibration())
600+
except Exception:
601+
pairs = []
602+
for weight, weight_quantizer in pairs:
603+
if not (
604+
isinstance(weight_quantizer, TensorQuantizer)
605+
and weight_quantizer.is_enabled
606+
and getattr(weight_quantizer, "_calibrator", None) is not None
607+
):
608+
continue
609+
weight_quantizer.disable_quant()
610+
weight_quantizer.enable_calib()
611+
weight_quantizer(weight)
612+
613+
cal = weight_quantizer._calibrator
614+
if cal.compute_amax() is not None:
615+
weight_quantizer.load_calib_amax()
616+
617+
weight_quantizer.enable_quant()
618+
weight_quantizer.disable_calib()
619+
620+
if torch.cuda.is_available():
621+
for dev_id in range(torch.cuda.device_count()):
622+
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
623+
624+
if hasattr(cal, "reset"):
625+
cal.reset()
626+
627+
pbar.update(1)
628+
n_calibrated += 1
629+
if n_calibrated % 10 == 0 and torch.cuda.is_available():
630+
for dev_id in range(torch.cuda.device_count()):
631+
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
632+
torch.cuda.empty_cache()
633+
pbar.close()
555634

556635
if torch.cuda.is_available():
557636
for dev_id in range(torch.cuda.device_count()):

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,27 @@ def forward(self, *args, **kwargs):
900900
self._down_proj_linear = False
901901
return super().forward(*args, **kwargs)
902902

903+
def iter_weights_for_calibration(self):
904+
"""Yield ``(weight_slice, quantizer)`` pairs for every per-expert weight quantizer.
905+
906+
Overrides the default :meth:`QuantModule.iter_weights_for_calibration`,
907+
which uses ``weight_attr_names`` + singular ``*_weight_quantizer`` and
908+
therefore silently skips fused-experts modules. Without this override,
909+
weight-only calibration paths (``mse_calibrate``, ``weight_only_quantize``)
910+
never reach per-expert weight quantizers — leaving any expert that the
911+
forward-pass max-calibration didn't route to with no ``_amax``.
912+
"""
913+
for weight_name, quantizers_name in (
914+
("gate_up_proj", "gate_up_proj_weight_quantizers"),
915+
("down_proj", "down_proj_weight_quantizers"),
916+
):
917+
weight = getattr(self, weight_name, None)
918+
quantizers = getattr(self, quantizers_name, None)
919+
if weight is None or quantizers is None:
920+
continue
921+
for idx, q in enumerate(quantizers):
922+
yield weight[idx], q
923+
903924
def fold_weight(self, keep_attrs: bool = False):
904925
"""Fold per-expert weight quantizers into the fused 3-D weights.
905926

0 commit comments

Comments
 (0)