Skip to content

Commit 311f3a5

Browse files
committed
[Quantization] Always populate weight _amax in max_calibrate
max_calibrate now always runs weight_only_quantize before the optional forward_loop, so every weight quantizer gets _amax regardless of MoE routing. Weight quantizers disabled by the caller (e.g. awq_lite, which runs max_calibrate with weight quantizers disabled) are skipped by weight_only_quantize, so the AWQ dynamic-amax path is unaffected. With _amax guaranteed after calibration, remove two now-redundant band-aids: - _bootstrap_uncalibrated_weight_quantizers (re-ran weight calibration for experts skipped by partial MoE routing); superseded by the always-on weight_only_quantize. - _ensure_weight_quantizer_calibrated and its helpers in export (lazy weight calibration at scale-factor extraction time), plus the GPU test that only exercised that lazy path. Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
1 parent d7e72f4 commit 311f3a5

3 files changed

Lines changed: 14 additions & 209 deletions

File tree

modelopt/torch/export/quant_utils.py

Lines changed: 0 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
from modelopt.torch.quantization.utils import (
4242
QuantizerAttrNames,
4343
quantizer_attr_names,
44-
reduce_block_amax,
4544
representative_weight_quantizer,
4645
weight_attr_names,
4746
)
@@ -241,100 +240,6 @@ def get_scaling_factor(quantizer: TensorQuantizer) -> torch.Tensor:
241240
return scaling_factor
242241

243242

244-
def _get_nvfp4_block_size(
245-
weight_quantizer: NVFP4StaticQuantizer, weight: torch.Tensor, module_name: str = ""
246-
) -> int:
247-
"""Return block size for NVFP4 from quantizer's block_sizes; raise if missing."""
248-
prefix = f"NVFP4StaticQuantizer{f' for {module_name}' if module_name else ''}"
249-
block_sizes = weight_quantizer.block_sizes
250-
if block_sizes is None:
251-
raise ValueError(f"{prefix} has no block_sizes; cannot compute per-block amax from weight.")
252-
block_size = block_sizes.get(-1) or block_sizes.get(weight.dim() - 1)
253-
if block_size is None:
254-
raise ValueError(
255-
f"{prefix} block_sizes has no -1 or last-dim key; cannot compute per-block amax."
256-
)
257-
return block_size
258-
259-
260-
def _set_amax_from_tensor(weight_quantizer: TensorQuantizer, tensor: torch.Tensor) -> None:
261-
"""Set quantizer _amax buffer from tensor; copy in-place if same shape, else replace buffer."""
262-
if (
263-
hasattr(weight_quantizer, "_amax")
264-
and weight_quantizer._amax is not None
265-
and weight_quantizer._amax.shape == tensor.shape
266-
):
267-
weight_quantizer._amax.data.copy_(tensor.to(weight_quantizer._amax.device))
268-
else:
269-
if hasattr(weight_quantizer, "_amax"):
270-
delattr(weight_quantizer, "_amax")
271-
weight_quantizer.register_buffer("_amax", tensor.clone().detach())
272-
273-
274-
def _ensure_weight_quantizer_calibrated(
275-
weight_quantizer: TensorQuantizer, weight: torch.Tensor, module_name: str = ""
276-
) -> None:
277-
"""Calibrate weight quantizer if amax is not set.
278-
279-
This is a lazy calibration pattern used during export when weight quantizers
280-
may not have been calibrated during the main calibration phase.
281-
282-
For NVFP4StaticQuantizer, _amax is per-block amax and _global_amax is the max over
283-
blocks; both are computed from the weight when missing.
284-
285-
Args:
286-
weight_quantizer: The weight quantizer to calibrate
287-
weight: The weight tensor to use for calibration
288-
module_name: Optional module name for better warning messages
289-
"""
290-
if isinstance(weight_quantizer, NVFP4StaticQuantizer):
291-
292-
def _amax_is_invalid(t: torch.Tensor | None) -> bool:
293-
# MCore distcp may register but not fill amax — treat missing/non-finite/negative as recompute.
294-
if t is None:
295-
return True
296-
t = t.detach()
297-
if not torch.is_floating_point(t):
298-
return False
299-
return bool((~torch.isfinite(t) | (t < 0)).any().item())
300-
301-
need_per_block = (
302-
not hasattr(weight_quantizer, "_amax")
303-
or weight_quantizer._amax is None
304-
or _amax_is_invalid(weight_quantizer._amax)
305-
)
306-
need_global = (
307-
not hasattr(weight_quantizer, "_global_amax")
308-
or weight_quantizer.global_amax is None
309-
or _amax_is_invalid(weight_quantizer.global_amax)
310-
)
311-
if not (need_per_block or need_global):
312-
return
313-
block_size = _get_nvfp4_block_size(weight_quantizer, weight, module_name)
314-
warn(
315-
f"NVFP4StaticQuantizer{f' for {module_name}' if module_name else ''} was not fully calibrated. "
316-
f"Computing per-block amax and global_amax from weights. This may occur if: "
317-
f"some experts were not activated during calibration (expected for MoE models), try increasing --calib_size"
318-
)
319-
per_block_amax = reduce_block_amax(weight, block_sizes={-1: block_size})
320-
if need_per_block:
321-
_set_amax_from_tensor(weight_quantizer, per_block_amax.to(weight.device))
322-
if need_global:
323-
weight_quantizer.global_amax = per_block_amax.max()
324-
return
325-
326-
if not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None:
327-
warn(
328-
f"Weight quantizer{f' for {module_name}' if module_name else ''} was not calibrated. "
329-
f"Computing amax from weights. This may occur if: "
330-
f"some experts were not activated during calibration (expected for MoE models), try increasing --calib_size"
331-
)
332-
weight_quantizer.reset_amax()
333-
enable_stats_collection(weight_quantizer)
334-
weight_quantizer(weight)
335-
finish_stats_collection(weight_quantizer)
336-
337-
338243
def get_activation_scaling_factor(
339244
module: nn.Module, input_quantizer_name: str = "input_quantizer"
340245
) -> torch.Tensor:
@@ -379,10 +284,6 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
379284
QUANTIZATION_W4A16_NVFP4,
380285
QUANTIZATION_W4A8_NVFP4_FP8,
381286
]:
382-
# Calibrate weight quantizer if amax is not set
383-
module_name = f"{type(module).__name__}.{weight_name}"
384-
_ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)
385-
386287
if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
387288
# weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
388289
# This is because the kernel dequantizes weight to fp8, which is in range 448.
@@ -424,11 +325,6 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
424325
QUANTIZATION_W4A16_NVFP4,
425326
QUANTIZATION_W4A8_NVFP4_FP8,
426327
]:
427-
# Calibrate weight quantizer if amax is not set
428-
weight = getattr(module, weight_name)
429-
module_name = f"{type(module).__name__}.{weight_name}"
430-
_ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)
431-
432328
if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
433329
# weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
434330
# This is because the kernel dequantizes weight to fp8, which is in range 448.

modelopt/torch/quantization/model_calib.py

Lines changed: 12 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -110,47 +110,6 @@ def _collect_weight_stats(quantizer: nn.Module, weight: torch.Tensor) -> None:
110110
quantizer(weight)
111111

112112

113-
@torch.no_grad()
114-
def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int:
115-
"""Re-run weight calibration on the weight tensor for quantizers missing ``_amax``.
116-
117-
Covers MoE experts that ``max_calibrate`` skipped (no routed tokens) so MSE
118-
doesn't drop them and break the gate==up ``weight_scale_2`` export invariant.
119-
Activation quantizers on those modules remain uncalibrated; emits a warning.
120-
"""
121-
name_to_module = dict(model.named_modules())
122-
n = 0
123-
for module in name_to_module.values():
124-
if not isinstance(module, QuantModule):
125-
continue
126-
with enable_weight_access_and_writeback(module, model, name_to_module):
127-
for weight, q in module.iter_weights_for_calibration():
128-
if (
129-
not isinstance(q, TensorQuantizer)
130-
or q._disabled
131-
or q._dynamic
132-
or q._calibrator is None
133-
):
134-
continue
135-
if weight.is_meta:
136-
continue
137-
amax = q.amax
138-
if amax is not None and (amax.is_meta or not torch.all(amax == 0)):
139-
continue
140-
_run_and_load_max_stats(q, partial(_collect_weight_stats, weight=weight))
141-
if hasattr(q._calibrator, "reset"):
142-
q._calibrator.reset()
143-
n += 1
144-
if n > 0:
145-
warnings.warn(
146-
f"Bootstrapped {n} weight quantizer(s) with no routed calibration tokens; "
147-
f"their activation quantizers (if any) remain uncalibrated. "
148-
f"Increase calib size/seq len to activate all experts.",
149-
stacklevel=2,
150-
)
151-
return n
152-
153-
154113
@torch.no_grad()
155114
def _sync_grouped_weight_global_amax(model: nn.Module) -> int:
156115
"""Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers.
@@ -304,7 +263,18 @@ def max_calibrate(
304263
See :class:`MaxCalibConfig <modelopt.torch.quantization.config.MaxCalibConfig>` for
305264
details on the remaining arguments.
306265
"""
307-
_run_and_load_max_stats(model, forward_loop)
266+
# Always run weight calibration on the weight tensor directly so every weight
267+
# quantizer gets ``_amax``, regardless of MoE routing. Downstream algorithms
268+
# (MSE, AWQ, export) then no longer need to patch in a missing ``_amax``. Weight
269+
# quantizers also exercised by ``forward_loop`` see the same weight twice;
270+
# MaxCalibrator's reduction is idempotent (max of identical values), so the extra
271+
# pass is a no-op for their stats. Disabled weight quantizers (e.g. AWQ's, which
272+
# call this with weight quantizers disabled) are skipped by ``weight_only_quantize``.
273+
enable_stats_collection(model)
274+
weight_only_quantize(model)
275+
if forward_loop is not None:
276+
forward_loop(model)
277+
finish_stats_collection(model)
308278

309279
# Sync quantizer amax across local experts within each rank (for SequentialMLP)
310280
for name, module in model.named_modules():
@@ -314,8 +284,6 @@ def max_calibrate(
314284
# Fail fast on NVFP4 static-block with TP>1 (sharded_state_dict treats _amax as replicated).
315285
_check_nvfp4_static_tp_supported(model)
316286

317-
_bootstrap_uncalibrated_weight_quantizers(model)
318-
319287
# Promote eligible static-block NVFP4 weight quantizers to NVFP4StaticQuantizer so
320288
# the static blockwise fake-quant path is used in forward and export picks up the
321289
# two-level (per-block + global) scaling. Run before the ``distributed_sync`` early

tests/gpu/torch/export/test_export_weight_gpu.py

Lines changed: 2 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,16 @@
1717

1818
import torch
1919
import torch.nn as nn
20-
from _test_utils.torch.export.utils import ToyModel, partial_nvfp4_config, partial_w4a8_config
20+
from _test_utils.torch.export.utils import ToyModel, partial_w4a8_config
2121
from torch.nn import functional as F
2222
from torch.nn import init
2323

2424
import modelopt.torch.quantization as mtq
2525
from modelopt.torch.export.unified_export_hf import _export_quantized_weight
26-
from modelopt.torch.quantization.nn import NVFP4StaticQuantizer
2726
from modelopt.torch.quantization.nn.modules.quant_module import QuantModule, QuantModuleRegistry
2827
from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer
2928
from modelopt.torch.quantization.tensor_quant import QUANT_DESC_8BIT_PER_TENSOR
30-
from modelopt.torch.quantization.utils import quantizer_attr_names, reduce_block_amax
29+
from modelopt.torch.quantization.utils import quantizer_attr_names
3130

3231

3332
class ToyLinear(nn.Module):
@@ -122,61 +121,3 @@ def test_export_per_block_quantized_weight():
122121
assert hasattr(model.linears[2], quantizer_attrs.output_quantizer)
123122
assert not getattr(model.linears[2], quantizer_attrs.output_quantizer).is_enabled
124123
assert not hasattr(model.linears[2], quantizer_attrs.output_scale)
125-
126-
127-
def test_export_nvfp4_static_weight_dynamic_vs_static_match():
128-
"""Dynamic vs static NVFP4 export: same weight and scales after export even when amaxs are
129-
cleared on one layer (lazy calibration via _ensure_weight_quantizer_calibrated fills them from weights).
130-
"""
131-
device = "cuda"
132-
dims = [32, 32, 32, 32]
133-
block_size = 16
134-
calib_input = torch.randn(1, 4, 32, device=device)
135-
nvfp4_layer_indices = [1, 2] # layers with NVFP4 enabled in partial_nvfp4_config
136-
137-
torch.manual_seed(42)
138-
model_dynamic = ToyModel(dims=dims).to(device)
139-
mtq.quantize(model_dynamic, partial_nvfp4_config, lambda x: x(calib_input))
140-
141-
torch.manual_seed(42)
142-
model_static = ToyModel(dims=dims).to(device)
143-
mtq.quantize(model_static, partial_nvfp4_config, lambda x: x(calib_input))
144-
145-
# Convert NVFP4 layers to NVFP4StaticQuantizer with per-block and global amax
146-
for idx in nvfp4_layer_indices:
147-
layer = model_static.linears[idx]
148-
weight = layer.weight.data
149-
per_block_amax = reduce_block_amax(weight, block_sizes={-1: block_size})
150-
tq = layer.weight_quantizer
151-
if hasattr(tq, "_amax"):
152-
delattr(tq, "_amax")
153-
tq.register_buffer("_amax", per_block_amax.to(weight.device).clone().detach())
154-
NVFP4StaticQuantizer.from_tensor_quantizer(tq, global_amax=per_block_amax.max())
155-
156-
# Clear amaxs on layer 1 to exercise lazy calibration during export
157-
for linear, is_static in [(model_dynamic.linears[1], False), (model_static.linears[1], True)]:
158-
wq = linear.weight_quantizer
159-
if hasattr(wq, "_amax"):
160-
delattr(wq, "_amax")
161-
if is_static and hasattr(wq, "_global_amax"):
162-
delattr(wq, "_global_amax")
163-
164-
quantizer_attrs = quantizer_attr_names("weight")
165-
for idx in nvfp4_layer_indices:
166-
_export_quantized_weight(model_dynamic.linears[idx], torch.float32, "weight")
167-
_export_quantized_weight(model_static.linears[idx], torch.float32, "weight")
168-
169-
for idx in nvfp4_layer_indices:
170-
dyn_linear = model_dynamic.linears[idx]
171-
sta_linear = model_static.linears[idx]
172-
assert torch.equal(dyn_linear.weight, sta_linear.weight), (
173-
f"Layer {idx}: exported NVFP4 weight should match (dynamic vs static)"
174-
)
175-
assert torch.allclose(
176-
getattr(dyn_linear, quantizer_attrs.weight_scale).float(),
177-
getattr(sta_linear, quantizer_attrs.weight_scale).float(),
178-
), f"Layer {idx}: weight_scale should match"
179-
assert torch.allclose(
180-
getattr(dyn_linear, quantizer_attrs.weight_scale_2).float(),
181-
getattr(sta_linear, quantizer_attrs.weight_scale_2).float(),
182-
), f"Layer {idx}: weight_scale_2 should match"

0 commit comments

Comments
 (0)