
Commit 5508c32

[Quantization] MSE-calibrate every per-expert weight in fused-experts MoE (#1421)
### What does this PR do?

Type of change: Bug fix

Two-part fix for transformers 5.x fused-experts containers (Qwen3-MoE / Qwen3.5-MoE / Mixtral / DeepSeek / Kimi-K2.x ...) where weight quantizers live in `nn.ModuleList`s (`gate_up_proj_weight_quantizers`, `down_proj_weight_quantizers`):

1. **Per-expert weight iteration for calibration.** Add `_QuantFusedExperts.iter_weights_for_calibration` that yields per-expert `(weight_slice, quantizer)` pairs for both projections. The base impl uses singular `*_weight_quantizer` and silently skips fused-experts modules, so weight-only calibration paths never reached per-expert quantizers.
2. **`mse_calibrate` refactor.**
   - Add `_bootstrap_uncalibrated_weight_quantizers` after `max_calibrate` to populate `_amax` on quantizers the forward pass didn't reach (dead MoE experts that received no calibration tokens). Runs the existing calibrator on the weight slice surfaced by `iter_weights_for_calibration`.
   - Replace the singular-only `weight_attr_names` discovery + `getattr`-by-name walk with an `iter_weights_for_calibration` walk done inside each parent module's `enable_weight_access_and_writeback` context, so MSE processes every per-expert quantizer (active and dead) and remains FSDP-safe.

Without this, the export-time fallback in `_export_fused_experts` derived separate gate/up amaxes from each half of the fused weight, breaking the `gate==up` `weight_scale_2` invariant on dead experts.

Also includes:

- `_sanitize_generation_config_for_save` in `unified_export_hf`: coerces `do_sample=True` when an upstream `generation_config.json` has `top_k`/`top_p` set, so newer transformers' strict validate doesn't block `save_pretrained`.
- Small companion plumbing in `moe_utils.py`, `tensor_quantizer.py`, and `core_utils.py` to support the per-expert iteration and bootstrap path.

### Usage

```python
import modelopt.torch.quantization as mtq
from modelopt.recipe import load_config

# Recipe `nvfp4_experts_only_mse-kv_fp8_cast` (already on main) now correctly
# MSE-calibrates every per-expert weight quantizer in fused-experts MoE models.
cfg = load_config("general/ptq/nvfp4_experts_only_mse-kv_fp8_cast")
mtq.quantize(model, cfg, forward_loop=calibration_forward_loop)
```

### Testing

**Original validation: Qwen3.5-122B-A10B with `nvfp4_experts_only_mse-fp8_cast_kv`:**

- **Before:** 1/12288 (layer 38 expert 69) `gate != up`; 0 weights MSE-calibrated.
- **After:** 0/12288 mismatches; 24576 weights MSE-calibrated; ~4.2 min.

**End-to-end pipeline validation: Qwen3.5-35B-A3B (40 layers × 256 experts × 2 projections = 20,480 per-expert weight quantizers), TRT-LLM 1.3.0rc13 + transformers 5.6 docker, single B200:**

| | Path A (4-sample calib, deliberately undercalibrated) | Path B (zero forward-pass tokens) |
|---|---|---|
| Per-expert weight quantizers calibrated | 20,480 / 20,480 | 20,480 / 20,480 |
| Missing `_amax` | 0 | 0 |
| All-zero `_amax` | 0 | 0 |
| `mtq.quantize` time | 25–34 s | 23 s |

- **Cross-path diff:** every per-expert weight amax matches **bit-for-bit** between the two paths (`n=20480 exact=20480 diff=0 max_rel=0`). With 8/256 experts routed per token and 4 calib samples, almost all experts are "dead" in Path A. Bootstrap fills them from `max(|weight|)`, MSE searches deterministically from there → identical to Path B which bootstraps everything.
- **Export to HF NVFP4 checkpoint** succeeded (~95 s, 22 GB checkpoint). Resulting `generation_config.json` has `do_sample: true` (upstream had `top_k=20` + `top_p=0.95` which would have failed strict validate).
- **TRT-LLM inference loaded the checkpoint and generated text:** `"Born in north-east France, Soyer trained as a"` → `" tailor. Demonstrating his craft at a young age, at 20 he moved to Paris at the requests of the noble people of Picardy."` (coherent grammar; factually wrong as expected with 4-sample calib, but no NaN/Inf in logits, no scale-mismatch crash). 92 GB GPU memory used.

### Before your PR is "*Ready for review*"

Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`).

Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ✅
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A
- Did you write any new necessary tests?: ❌ <!-- relies on existing recipe-level integration coverage; verified end-to-end on Qwen3.5-122B-A10B and Qwen3.5-35B-A3B + TRT-LLM 1.3.0rc13 -->
- Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A
- Did you get Claude approval on this PR?: ❌ <!-- will run `/claude review` -->

### Additional Information

Follow-up to PR #1407 (MSE+FP8-cast-KV recipes). The recipe YAML files landed there; this PR fixes the calibration codepath so the MSE recipes actually exercise per-expert weight quantizers in fused-experts MoE containers.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **Bug Fixes**
  * Fixed generation configuration validation for HuggingFace model exports.
  * Improved handling of quantization shape mismatches during expert weight export.
* **New Features**
  * Enhanced calibration process with automatic population of missing expert quantizers.
  * Added grouped quantizer synchronization for improved multi-expert quantization.
* **Tests**
  * Added regression tests for fused expert export and calibration correctness.

[![Review Change Stack](https://storage.googleapis.com/coderabbit_public_assets/review-stack-in-coderabbit-ui.svg)](https://app.coderabbit.ai/change-stack/NVIDIA/Model-Optimizer/pull/1421)

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
1 parent 2ce745a commit 5508c32

7 files changed

Lines changed: 399 additions & 71 deletions


modelopt/torch/export/moe_utils.py

Lines changed: 9 additions & 1 deletion
@@ -110,11 +110,19 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
                 and w_quantizer._amax.dim() >= 1
             ):
                 amax = w_quantizer._amax
+                # Per-block _amax (NVFP4 static) collapses the row axis we want
+                # to slice on; restore it so dim-0 slicing splits gate/up.
+                if amax.numel() != fused_total and amax.numel() % fused_total == 0:
+                    amax = amax.contiguous().view(fused_total, amax.numel() // fused_total)
                 amax_dim0 = amax.shape[0]
                 if fused_total % amax_dim0 == 0:
                     slice_start = fused_start * amax_dim0 // fused_total
                     slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total
-                    w_quantizer.amax = amax[slice_start:slice_end].contiguous()
+                    sliced = amax[slice_start:slice_end].contiguous()
+                    # The amax setter refuses shape changes; drop _amax first.
+                    if hasattr(w_quantizer, "_amax"):
+                        delattr(w_quantizer, "_amax")
+                    w_quantizer.amax = sliced
                 else:
                     warnings.warn(
                         f"Expert {idx} {proj_name}: fused amax dim0 ({amax_dim0}) does not "
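The row-axis restoration in this hunk can be illustrated without torch. The following is a minimal sketch, assuming a fused gate/up weight with `fused_total` rows whose per-block amax arrives as a flat list. The function name `split_fused_amax` and the toy sizes are hypothetical and are not part of ModelOpt; nested Python lists stand in for the `view(fused_total, -1)` call.

```python
def split_fused_amax(flat_amax, fused_total, fused_start, n_rows):
    """Return the amax rows covering weight rows [fused_start, fused_start + n_rows)."""
    numel = len(flat_amax)
    if numel != fused_total and numel % fused_total == 0:
        # Flat per-block layout: regroup into one list of block amaxes per weight row.
        per_row = numel // fused_total
        amax = [flat_amax[r * per_row:(r + 1) * per_row] for r in range(fused_total)]
    else:
        amax = list(flat_amax)
    amax_dim0 = len(amax)
    if fused_total % amax_dim0 != 0:
        raise ValueError("amax dim0 does not divide the fused row count")
    start = fused_start * amax_dim0 // fused_total
    end = (fused_start + n_rows) * amax_dim0 // fused_total
    return amax[start:end]


# Toy case: 4 fused rows (2 gate + 2 up), 2 blocks per row, so 8 flat amax values.
flat = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
gate_amax = split_fused_amax(flat, fused_total=4, fused_start=0, n_rows=2)
up_amax = split_fused_amax(flat, fused_total=4, fused_start=2, n_rows=2)
```

Without the regrouping step, the flat list has 8 entries while the weight has 4 rows, so the divisibility check `fused_total % amax_dim0 == 0` would fail and the slice would fall through to the warning branch.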

modelopt/torch/export/unified_export_hf.py

Lines changed: 15 additions & 0 deletions
@@ -1134,6 +1134,19 @@ def _unpatch_revert_weight_conversion(patches: list[tuple[Any, Any]]) -> None:
         mod.revert_weight_conversion = original


+def _sanitize_generation_config_for_save(model: torch.nn.Module) -> None:
+    """Force ``do_sample=True`` when generation_config has ``top_k``/``top_p`` set.
+
+    Newer transformers reject ``do_sample=False`` mixed with sampling attrs in
+    ``save_pretrained``'s strict validate.
+    """
+    gc = getattr(model, "generation_config", None)
+    if gc is None:
+        return
+    if getattr(gc, "top_k", None) is not None or getattr(gc, "top_p", None) is not None:
+        gc.do_sample = True
+
+
 def export_speculative_decoding(
     model: torch.nn.Module,
     dtype: torch.dtype | None = None,
@@ -1228,6 +1241,8 @@ def export_hf_checkpoint(
         # modeling_utils does `from core_model_loading import revert_weight_conversion`.
         _patches = _patch_revert_weight_conversion()

+        _sanitize_generation_config_for_save(model)
+
         try:
             model.save_pretrained(
                 export_dir,
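The sanitizer's behaviour can be exercised on plain objects. This is a minimal sketch that copies the logic of `_sanitize_generation_config_for_save` under a hypothetical name and uses `types.SimpleNamespace` stand-ins; a real model carries a transformers `GenerationConfig`, which is not imported here.

```python
from types import SimpleNamespace


def sanitize_generation_config(model):
    """Set do_sample=True when top_k or top_p is present on the generation config."""
    gc = getattr(model, "generation_config", None)
    if gc is None:
        return
    if getattr(gc, "top_k", None) is not None or getattr(gc, "top_p", None) is not None:
        gc.do_sample = True


# Hypothetical stand-ins for three kinds of model.
sampled = SimpleNamespace(
    generation_config=SimpleNamespace(do_sample=False, top_k=20, top_p=0.95)
)
greedy = SimpleNamespace(
    generation_config=SimpleNamespace(do_sample=False, top_k=None, top_p=None)
)
bare = SimpleNamespace()  # no generation_config attribute at all

for m in (sampled, greedy, bare):
    sanitize_generation_config(m)
```

Only the first object is modified: the config with `top_k`/`top_p` set gets `do_sample=True`, the greedy config is left alone, and the object without a `generation_config` is skipped.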

modelopt/torch/quantization/model_calib.py

Lines changed: 147 additions & 63 deletions
@@ -52,7 +52,6 @@
     promote_nvfp4_static_quantizers,
     quantizer_attr_names,
     reduce_amax,
-    weight_attr_names,
 )
 from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper

@@ -66,6 +65,107 @@
     "svdquant",
 ]

+
+def _is_calibrated_nvfp4_static(q) -> bool:
+    """True iff ``q`` is an enabled NVFP4-static weight quantizer with ``_amax`` set."""
+    return (
+        isinstance(q, TensorQuantizer)
+        and not q._disabled
+        and q.is_nvfp4_static
+        and getattr(q, "_amax", None) is not None
+    )
+
+
+def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]:
+    """Collect sibling groups (Q/K/V, gate/up) with calibrated NVFP4-static weight quantizers."""
+    # Inline: layer_utils → quant_utils → model_calib cycle.
+    from modelopt.torch.export.layer_utils import _GATE_UP_PAIRS
+
+    # Reuses the existing gate/up pairs and adds Q/K/V (no equivalent constant
+    # in export). Single source for the gate/up half avoids parallel lists.
+    patterns: tuple[tuple[str, ...], ...] = (("q_proj", "k_proj", "v_proj"), *_GATE_UP_PAIRS)
+    groups: list[list[nn.Module]] = []
+    wq_attr = quantizer_attr_names("weight").weight_quantizer
+    for parent in model.modules():
+        for sibling_names in patterns:
+            members = [
+                child
+                for child in (getattr(parent, n, None) for n in sibling_names)
+                if child is not None and _is_calibrated_nvfp4_static(getattr(child, wq_attr, None))
+            ]
+            if len(members) >= 2:
+                groups.append(members)
+    return groups
+
+
+@torch.no_grad()
+def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int:
+    """Re-run weight calibration on the weight tensor for quantizers missing ``_amax``.
+
+    Covers MoE experts that ``max_calibrate`` skipped (no routed tokens) so MSE
+    doesn't drop them and break the gate==up ``weight_scale_2`` export invariant.
+    Activation quantizers on those modules remain uncalibrated; emits a warning.
+    """
+    name_to_module = dict(model.named_modules())
+    n = 0
+    for module in name_to_module.values():
+        if not isinstance(module, QuantModule):
+            continue
+        with enable_weight_access_and_writeback(module, model, name_to_module):
+            for weight, q in module.iter_weights_for_calibration():
+                if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
+                    continue
+                if q._calibrator is None:
+                    continue
+                if getattr(q, "_amax", None) is not None and not torch.all(q._amax == 0):
+                    continue
+                q.disable_quant()
+                q.enable_calib()
+                q(weight)
+                if q._calibrator.compute_amax() is not None:
+                    q.load_calib_amax()
+                q.enable_quant()
+                q.disable_calib()
+                if hasattr(q._calibrator, "reset"):
+                    q._calibrator.reset()
+                n += 1
+    if n > 0:
+        warnings.warn(
+            f"Bootstrapped {n} weight quantizer(s) with no routed calibration tokens; "
+            f"their activation quantizers (if any) remain uncalibrated. "
+            f"Increase calib size/seq len to activate all experts.",
+            stacklevel=2,
+        )
+    return n
+
+
+@torch.no_grad()
+def _sync_grouped_weight_global_amax(model: nn.Module) -> int:
+    """Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers.
+
+    Run after ``max_calibrate``. Sibling discovery is name-based via
+    ``_collect_grouped_linears``; non-matching architectures (wqkv, fused
+    qkv_proj, DeepSeek variants, single-Linear fused gate_up_proj) silently
+    fall back to per-module global_amax. Fused-experts containers already
+    share a single quantizer across gate/up halves and need no sync.
+    """
+    # quant_utils imports back from this module; top-level would cycle.
+    from modelopt.torch.export.quant_utils import preprocess_linear_fusion
+
+    wq_attr = quantizer_attr_names("weight").weight_quantizer
+    n_groups = 0
+    for group in _collect_grouped_linears(model):
+        for child in group:
+            wq = getattr(child, wq_attr)
+            if not isinstance(wq, NVFP4StaticQuantizer):
+                NVFP4StaticQuantizer.from_tensor_quantizer(
+                    wq, global_amax=reduce_amax(wq._amax, axis=None)
+                )
+        preprocess_linear_fusion(group)
+        n_groups += 1
+    return n_groups
+
+
 CalibratorFactory: TypeAlias = Callable[
     [torch.Tensor, int | tuple | list | None, Callable[..., torch.Tensor]], _Calibrator
 ]
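The skip-or-fill decision in `_bootstrap_uncalibrated_weight_quantizers` can be sketched on toy data. The example below is a simplification under stated assumptions: `ToyQuantizer` is a hypothetical stand-in that holds one scalar amax per weight (the real quantizers hold per-block tensors), and the fill value is `max(|weight|)`, which is what the commit message says the bootstrap produces.

```python
class ToyQuantizer:
    """Hypothetical stand-in for a per-expert weight quantizer; only tracks amax."""

    def __init__(self, amax=None):
        self.amax = amax


def needs_bootstrap(q):
    # A missing or all-zero amax means the forward pass never calibrated this quantizer.
    return q.amax is None or q.amax == 0.0


def bootstrap(pairs):
    """Fill amax from max(|w|) for uncalibrated quantizers; return how many were filled."""
    n = 0
    for weight, q in pairs:
        if not needs_bootstrap(q):
            continue
        q.amax = max(abs(w) for w in weight)
        n += 1
    return n


# Expert 0 was routed tokens; experts 1 and 2 are "dead" (missing or zero amax).
weights = [[0.5, -2.0, 1.0], [0.25, -0.75, 0.1], [3.0, -1.0, 0.0]]
quantizers = [ToyQuantizer(amax=2.0), ToyQuantizer(), ToyQuantizer(amax=0.0)]
filled = bootstrap(zip(weights, quantizers))
```

Already-calibrated quantizers are left untouched, so the bootstrap only affects experts that `max_calibrate` did not reach.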
@@ -346,32 +446,23 @@ def mse_calibrate(
     See :class:`MseCalibConfig <modelopt.torch.quantization.config.MseCalibConfig>` for
     details on the remaining arguments.
     """
-    # Step 1: First get initial amax using max calibration
+    # Step 1: max calibrate, bootstrap dead-expert weight quantizers,
+    # unify grouped NVFP4 global_amax so MSE sees a consistent FP8 grid.
     max_calibrate(model, forward_loop, distributed_sync)
+    _bootstrap_uncalibrated_weight_quantizers(model)
+    _sync_grouped_weight_global_amax(model)

-    # Step 2: Replace calibrators with MseCalibrator for enabled quantizers
-    # and identify weight quantizers
-    weight_quantizers = []
-    seen_modules = set()
-
+    # Step 2: replace calibrators with MseCalibrator for enabled quantizers.
     for name, module in list(model.named_modules()):
         if isinstance(module, TensorQuantizer) and not module._disabled:
             if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
-                # Get the initial amax from max calibration
                 initial_amax = module._amax.clone().detach()
+                is_nvfp4_static = module.is_nvfp4_static

-                is_nvfp4_static = (
-                    module.is_static_block_quant
-                    and module._num_bits == (2, 1)
-                    and module._block_sizes is not None
-                    and module._block_sizes.get("scale_bits") == (4, 3)
-                )
-
-                if is_nvfp4_static:
-                    # Compute and set global_amax
+                # Promote standalone NVFP4-static quantizers; grouped siblings
+                # already promoted by _sync_grouped_weight_global_amax above.
+                if is_nvfp4_static and not isinstance(module, NVFP4StaticQuantizer):
                     global_amax = reduce_amax(initial_amax, axis=None)
-
-                    # Convert to NVFP4StaticQuantizer in-place
                     NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)

                 if fp8_scale_sweep:
@@ -412,52 +503,48 @@ def mse_calibrate(
                         quant_func=partial(_mse_quant_func, quantizer=module),
                     )

-    # Identify weight quantizers by checking if they have corresponding weight parameters
+    # Step 3: calibrate weight quantizers via iter_weights_for_calibration.
     name_to_module = dict(model.named_modules())
+    seen_modules: set[int] = set()
+    pbar = tqdm(desc="MSE weight calibration")
+    n_calibrated = 0
     for parent_module in name_to_module.values():
-        if parent_module in seen_modules:
+        if id(parent_module) in seen_modules or not isinstance(parent_module, QuantModule):
             continue
-        for weight_name in weight_attr_names(parent_module):
-            weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer
-            weight_quantizer = getattr(parent_module, weight_quantizer_name, None)
-            if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled:
-                if getattr(weight_quantizer, "_calibrator", None) is not None:
-                    weight_quantizers.append((parent_module, weight_name, weight_quantizer))
-        seen_modules.add(parent_module)
-
-    # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
-    # This prevents massive memory accumulation seen in large models
-    for idx, (parent_module, weight_name, weight_quantizer) in enumerate(
-        tqdm(weight_quantizers, desc="MSE weight calibration")
-    ):
-        # Enable calibration mode for the weight quantizer
-        weight_quantizer.disable_quant()
-        weight_quantizer.enable_calib()
+        seen_modules.add(id(parent_module))
         with enable_weight_access_and_writeback(parent_module, model, name_to_module):
-            weight = getattr(parent_module, weight_name)
-            weight_quantizer(weight)
+            for weight, weight_quantizer in parent_module.iter_weights_for_calibration():
+                if not (
+                    isinstance(weight_quantizer, TensorQuantizer)
+                    and weight_quantizer.is_enabled
+                    and getattr(weight_quantizer, "_calibrator", None) is not None
+                ):
+                    continue
+                weight_quantizer.disable_quant()
+                weight_quantizer.enable_calib()
+                weight_quantizer(weight)

-        # IMMEDIATELY compute amax and reset calibrator to free memory
-        cal = getattr(weight_quantizer, "_calibrator", None)
-        if cal is not None and cal.compute_amax() is not None:
-            weight_quantizer.load_calib_amax()
+                cal = weight_quantizer._calibrator
+                if cal.compute_amax() is not None:
+                    weight_quantizer.load_calib_amax()

-        weight_quantizer.enable_quant()
-        weight_quantizer.disable_calib()
+                weight_quantizer.enable_quant()
+                weight_quantizer.disable_calib()

-        # Synchronize ALL CUDA devices before resetting to ensure all async operations complete
-        # This is critical for multi-GPU setups where tensors may be on different devices
-        if torch.cuda.is_available():
-            for dev_id in range(torch.cuda.device_count()):
-                torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
+                if torch.cuda.is_available():
+                    for dev_id in range(torch.cuda.device_count()):
+                        torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))

-        if cal is not None and hasattr(cal, "reset"):
-            cal.reset()
+                if hasattr(cal, "reset"):
+                    cal.reset()

-        if (idx + 1) % 10 == 0 and torch.cuda.is_available():
-            for dev_id in range(torch.cuda.device_count()):
-                torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
-            torch.cuda.empty_cache()
+                pbar.update(1)
+                n_calibrated += 1
+                if n_calibrated % 10 == 0 and torch.cuda.is_available():
+                    for dev_id in range(torch.cuda.device_count()):
+                        torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
+                    torch.cuda.empty_cache()
+    pbar.close()

     if torch.cuda.is_available():
         for dev_id in range(torch.cuda.device_count()):
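The commit message notes that dead experts end up with identical amax values whether or not a forward pass ran, because the MSE search is deterministic once it starts from the same bootstrapped amax. A minimal sketch of that idea follows. It assumes a uniform symmetric integer grid and a small set of amax multipliers; this is a simplification for illustration, not the NVFP4/FP8 sweep that `MseCalibrator` performs, and the function names are hypothetical.

```python
def fake_quant(values, amax, levels=7):
    """Symmetric uniform fake-quantization with `levels` positive steps, clipped at amax."""
    step = amax / levels
    out = []
    for v in values:
        q = round(v / step)
        q = max(-levels, min(levels, q))
        out.append(q * step)
    return out


def mse_search(values, initial_amax, multipliers):
    """Return the candidate amax (initial_amax * m) with the lowest squared error."""
    best_amax, best_err = None, float("inf")
    for m in multipliers:
        amax = initial_amax * m
        err = sum((v - w) ** 2 for v, w in zip(values, fake_quant(values, amax)))
        if err < best_err:
            best_amax, best_err = amax, err
    return best_amax


w = [0.1, -0.2, 0.3, 4.0]
start = max(abs(v) for v in w)  # the bootstrapped starting point, max(|weight|)
path_a = mse_search(w, start, [0.5, 0.75, 1.0])
path_b = mse_search(w, start, [0.5, 0.75, 1.0])
```

Because the search depends only on the weight and the starting amax, two runs that bootstrap the same expert from the same weight pick the same candidate.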
@@ -612,6 +699,8 @@ def forward(self, input, *args, **kwargs):
     print_rank_0("local_hessian: Running max calibration for all quantizers...")
     max_calibrate(model, forward_loop, distributed_sync)

+    _sync_grouped_weight_global_amax(model)
+
     # Setup helpers for all quantized linear modules
     name_to_module = dict(model.named_modules())
     weight_quantizers_info = []
@@ -666,14 +755,9 @@ def quant_func(x, amax, quantizer=weight_quantizer):

                 return xq

-            is_nvfp4_static = (
-                weight_quantizer.is_static_block_quant
-                and weight_quantizer._num_bits == (2, 1)
-                and weight_quantizer._block_sizes is not None
-                and weight_quantizer._block_sizes.get("scale_bits") == (4, 3)
-            )
+            is_nvfp4_static = weight_quantizer.is_nvfp4_static

-            if is_nvfp4_static:
+            if is_nvfp4_static and not isinstance(weight_quantizer, NVFP4StaticQuantizer):
                 global_amax = reduce_amax(initial_amax, axis=None)
                 NVFP4StaticQuantizer.from_tensor_quantizer(weight_quantizer, global_amax=global_amax)

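The grouped sync gives sibling projections (Q/K/V, gate/up) one shared `global_amax`. The sketch below illustrates the idea with plain dictionaries. It assumes the shared value is the maximum over the group; that reduction is an assumption about what `preprocess_linear_fusion` does, not something this diff shows, and `sync_global_amax` is a hypothetical name.

```python
def sync_global_amax(groups):
    """Give every member of each sibling group the group's largest global_amax."""
    n_groups = 0
    for group in groups:
        if len(group) < 2:
            # A lone module has nothing to sync with and keeps its own value.
            continue
        shared = max(member["global_amax"] for member in group)
        for member in group:
            member["global_amax"] = shared
        n_groups += 1
    return n_groups


q = {"global_amax": 1.5}
k = {"global_amax": 2.0}
v = {"global_amax": 0.5}
solo = {"global_amax": 9.0}
synced = sync_global_amax([[q, k, v], [solo]])
```

Running the sync before the MSE step means every sibling searches over the same scale grid, which is the motivation given in the `mse_calibrate` Step 1 comment.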
modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 10 additions & 0 deletions
@@ -514,6 +514,16 @@ def is_mx_format(self):
             and self.block_sizes.get("scale_bits", None) == (8, 0)
         )

+    @property
+    def is_nvfp4_static(self):
+        """True for E2M1 weights + E4M3 per-block scales in static layout (format-only check)."""
+        return (
+            self.is_static_block_quant
+            and self._num_bits == (2, 1)
+            and self._block_sizes is not None
+            and self._block_sizes.get("scale_bits") == (4, 3)
+        )
+
     def is_mxfp(self, bits):
         """Check if is MXFP4/MXFP6/MXFP8."""
         if bits == 4:
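The new property replaces three inlined copies of the same four-part check. The sketch below reproduces the predicate on a hypothetical dataclass, `ToyQuantizerConfig`, that holds only the fields the check reads; the block-size dictionaries are illustrative values, not configs taken from ModelOpt.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyQuantizerConfig:
    """Hypothetical stand-in holding only the fields the format check reads."""

    num_bits: object
    block_sizes: Optional[dict]
    is_static_block_quant: bool

    @property
    def is_nvfp4_static(self) -> bool:
        # E2M1 elements (num_bits == (2, 1)) with E4M3 block scales (scale_bits == (4, 3)).
        return (
            self.is_static_block_quant
            and self.num_bits == (2, 1)
            and self.block_sizes is not None
            and self.block_sizes.get("scale_bits") == (4, 3)
        )


nvfp4 = ToyQuantizerConfig((2, 1), {"scale_bits": (4, 3)}, True)
other_scale = ToyQuantizerConfig((2, 1), {"scale_bits": (8, 0)}, True)
int8 = ToyQuantizerConfig(8, None, False)
```

The check looks only at the format fields, which matches the "format-only check" note in the docstring: it says nothing about whether `_amax` has been populated.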

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 18 additions & 0 deletions
@@ -900,6 +900,24 @@ def forward(self, *args, **kwargs):
         self._down_proj_linear = False
         return super().forward(*args, **kwargs)

+    def iter_weights_for_calibration(self):
+        """Yield ``(weight_slice, quantizer)`` per-expert pairs.
+
+        The base impl uses singular ``*_weight_quantizer`` and skips fused-
+        experts modules, so weight-only calibration never reaches per-expert
+        quantizers without this override.
+        """
+        for weight_name, quantizers_name in (
+            ("gate_up_proj", "gate_up_proj_weight_quantizers"),
+            ("down_proj", "down_proj_weight_quantizers"),
+        ):
+            weight = getattr(self, weight_name, None)
+            quantizers = getattr(self, quantizers_name, None)
+            if weight is None or quantizers is None:
+                continue
+            for idx, q in enumerate(quantizers):
+                yield weight[idx], q
+
     def fold_weight(self, keep_attrs: bool = False):
         """Fold per-expert weight quantizers into the fused 3-D weights.

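The override pairs slice `idx` of each fused weight with quantizer `idx` of the matching list. A minimal sketch on a toy container shows the iteration order. `ToyFusedExperts` is hypothetical: nested lists stand in for the fused 3-D tensors and strings stand in for the quantizer modules.

```python
class ToyFusedExperts:
    """Hypothetical container: fused weights indexed by expert, one quantizer per expert."""

    def __init__(self, gate_up_proj, down_proj):
        self.gate_up_proj = gate_up_proj
        self.down_proj = down_proj
        self.gate_up_proj_weight_quantizers = [
            f"gate_up_q{i}" for i in range(len(gate_up_proj))
        ]
        self.down_proj_weight_quantizers = [f"down_q{i}" for i in range(len(down_proj))]

    def iter_weights_for_calibration(self):
        for weight_name, quantizers_name in (
            ("gate_up_proj", "gate_up_proj_weight_quantizers"),
            ("down_proj", "down_proj_weight_quantizers"),
        ):
            weight = getattr(self, weight_name, None)
            quantizers = getattr(self, quantizers_name, None)
            if weight is None or quantizers is None:
                continue
            for idx, q in enumerate(quantizers):
                # Expert idx's slice of the fused weight goes with expert idx's quantizer.
                yield weight[idx], q


experts = ToyFusedExperts(gate_up_proj=[[1, 2], [3, 4]], down_proj=[[5], [6]])
pairs = list(experts.iter_weights_for_calibration())
```

Every expert is yielded regardless of whether it was routed any tokens, which is what lets both the bootstrap and the MSE loop reach dead experts.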

modelopt/torch/quantization/utils/core_utils.py

Lines changed: 1 addition & 7 deletions
@@ -957,13 +957,7 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int:
     for _name, module in list(model.named_modules()):
         if isinstance(module, TensorQuantizer) and not module._disabled:
             if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
-                is_nvfp4_static = (
-                    module.is_static_block_quant
-                    and module._num_bits == (2, 1)
-                    and module._block_sizes is not None
-                    and module._block_sizes.get("scale_bits") == (4, 3)
-                )
-                if is_nvfp4_static:
+                if module.is_nvfp4_static:
                     initial_amax = module._amax.clone().detach()
                     global_amax = reduce_amax(initial_amax, axis=None)
                     NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)
