Skip to content

Commit fe242d6

Browse files
committed
Replace moe_utils workarounds with a layer-skip hatch in _process_quantized_modules
Reverted the safe-CPU-amax / global_amax-sync / device-pinning patches in moe_utils.py — those were working around a symptom: touching the per-expert quantizers of layers that were never visited by the layerwise loop (their _amax is unset). When MO_DEBUG_MAX_LAYERS=N is set, simply skip _export_fused_experts for any *.layers.{>=N}.* module. Layers 0..N-1 all have _bootstrap_uncalibrated_weight_quantizers + MSE-applied amaxes so the existing main moe_utils.py code path works.
1 parent 7ac315d commit fe242d6

2 files changed

Lines changed: 17 additions & 77 deletions

File tree

modelopt/torch/export/moe_utils.py

Lines changed: 6 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -59,46 +59,9 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
5959
# 2-3. Split + export each per-expert projection.
6060
fused_dim0 = gate_up.shape[1] # 2 * expert_dim
6161

62-
def _safe_amax(quantizer_src: nn.Module) -> torch.Tensor | None:
63-
"""Return _amax as a clean tensor, surfacing any latent CUDA error first.
64-
65-
Layerwise calibration's _save_layer + full_restore can leave the per-expert
66-
``_amax`` as a CUDA tensor reconstructed from a serialized view with non-zero
67-
storage offset. Touching it directly (``torch.all`` / ``deepcopy``) then triggers
68-
``cudaErrorIllegalAddress``. We synchronize first to surface any pending error,
69-
then return the tensor on its original device. Falling back to CPU only on the
70-
error path avoids creating a device mismatch with sibling buffers
71-
(``_global_amax``) that stayed on the original device.
72-
"""
73-
amax = getattr(quantizer_src, "_amax", None)
74-
if amax is None or not isinstance(amax, torch.Tensor):
75-
return None
76-
try:
77-
if amax.is_cuda:
78-
torch.cuda.synchronize(amax.device)
79-
# Force a no-op read to trigger any latent async error.
80-
_ = amax.shape
81-
return amax.detach()
82-
except Exception:
83-
# CUDA tensor was unreadable. Try to recover a CPU copy; if that
84-
# also fails, treat as uncalibrated.
85-
try:
86-
return amax.detach().cpu().float()
87-
except Exception:
88-
return None
89-
9062
for idx in range(n):
9163
expert = nn.Module()
9264

93-
# Pre-extract both per-expert amaxes to CPU *before* the projection loop's
94-
# deepcopy. deepcopy calls .clone() on CUDA tensors — if the stored _amax
95-
# has corrupt storage (under-calibrated experts after layerwise calib), the
96-
# clone triggers an async CUDA illegal-memory-access error. Synchronizing in
97-
# _safe_amax surfaces the error here so subsequent operations work on
98-
# safe CPU float32 tensors.
99-
gu_amax = _safe_amax(module.gate_up_proj_weight_quantizers[idx])
100-
down_amax = _safe_amax(module.down_proj_weight_quantizers[idx])
101-
10265
# If the gate_up source quantizer was never calibrated (rare expert
10366
# that received no calibration tokens), derive its amax once from the
10467
# FUSED tensor so gate and up share the same weight_scale_2 below.
@@ -109,11 +72,11 @@ def _safe_amax(quantizer_src: nn.Module) -> torch.Tensor | None:
10972
# mismatched weight_scale_2 and garbled MoE output at inference.
11073
gate_up_q = module.gate_up_proj_weight_quantizers[idx]
11174
if getattr(gate_up_q, "is_enabled", False) and (
112-
gu_amax is None or bool(torch.all(gu_amax == 0))
75+
not hasattr(gate_up_q, "_amax")
76+
or gate_up_q._amax is None
77+
or torch.all(gate_up_q._amax == 0)
11378
):
11479
gate_up_q.amax = gate_up[idx].abs().amax().to(torch.float32)
115-
# Refresh the CPU amax we'll inject below.
116-
gu_amax = _safe_amax(gate_up_q)
11780
warnings.warn(
11881
f"Expert {idx} gate_up_proj weight quantizer was not calibrated "
11982
f"(amax missing or zero). Using fused-tensor amax as fallback "
@@ -137,23 +100,7 @@ def _safe_amax(quantizer_src: nn.Module) -> torch.Tensor | None:
137100
i_quantizer = gate_up_input_q if is_gate_up else down_input_q
138101

139102
# gate/up share a weight quantizer — clone so each gets independent amax.
140-
# Null _amax on source before deepcopy so the (possibly corrupt) CUDA tensor
141-
# is never cloned; restore afterwards for the sibling projection. The CPU
142-
# amax we pre-extracted gets injected in its place.
143-
if is_gate_up:
144-
_saved_amax = getattr(w_quantizer_src, "_amax", None)
145-
try:
146-
w_quantizer_src._amax = None
147-
w_quantizer = copy.deepcopy(w_quantizer_src)
148-
finally:
149-
w_quantizer_src._amax = _saved_amax
150-
if gu_amax is not None:
151-
w_quantizer._amax = gu_amax
152-
else:
153-
w_quantizer = w_quantizer_src
154-
if down_amax is not None:
155-
# Replace any CUDA-resident _amax with the safe CPU copy.
156-
w_quantizer._amax = down_amax
103+
w_quantizer = copy.deepcopy(w_quantizer_src) if is_gate_up else w_quantizer_src
157104

158105
# For per-channel amax (dim >= 1), proportionally slice dim-0
159106
# to match the split weight.
@@ -162,7 +109,7 @@ def _safe_amax(quantizer_src: nn.Module) -> torch.Tensor | None:
162109
and w_quantizer._amax is not None
163110
and w_quantizer._amax.dim() >= 1
164111
):
165-
amax = w_quantizer._amax # safe-extracted via _safe_amax (CUDA or CPU, recovered if corrupt)
112+
amax = w_quantizer._amax
166113
# Per-block _amax (NVFP4 static) collapses the row axis we want
167114
# to slice on; restore it so dim-0 slicing splits gate/up.
168115
if amax.numel() != fused_total and amax.numel() % fused_total == 0:
@@ -185,14 +132,13 @@ def _safe_amax(quantizer_src: nn.Module) -> torch.Tensor | None:
185132
)
186133

187134
# If the weight quantizer was never calibrated, compute amax from weights.
188-
# All amax tests below operate on the safe CPU tensor injected above.
189135
if (
190136
hasattr(w_quantizer, "is_enabled")
191137
and w_quantizer.is_enabled
192138
and (
193139
not hasattr(w_quantizer, "_amax")
194140
or w_quantizer._amax is None
195-
or bool(torch.all(w_quantizer._amax == 0))
141+
or torch.all(w_quantizer._amax == 0)
196142
)
197143
):
198144
w_quantizer.amax = weight_slice.abs().amax().to(torch.float32)
@@ -203,23 +149,6 @@ def _safe_amax(quantizer_src: nn.Module) -> torch.Tensor | None:
203149
stacklevel=2,
204150
)
205151

206-
# Align _amax and global_amax with the weight slice's device. The
207-
# export math ``per_block_scale * 448 / per_block_scale_max`` reads
208-
# both from the quantizer and would otherwise error if they drifted
209-
# apart (e.g., CPU-offloaded big-model layers + CUDA-resident weight
210-
# slice, or our CPU-injected _amax + the original CUDA global_amax).
211-
# No magnitude floor — that's main's policy for the uncalibrated
212-
# fallback below.
213-
if (
214-
hasattr(w_quantizer, "_amax")
215-
and w_quantizer._amax is not None
216-
):
217-
target_device = weight_slice.device
218-
if w_quantizer._amax.device != target_device:
219-
w_quantizer._amax = w_quantizer._amax.to(target_device)
220-
if hasattr(w_quantizer, "global_amax"):
221-
w_quantizer.global_amax = w_quantizer._amax.float().amax()
222-
223152
wrapper = nn.Module()
224153
wrapper.weight = nn.Parameter(weight_slice.contiguous(), requires_grad=False)
225154
wrapper.weight_quantizer = w_quantizer

modelopt/torch/export/unified_export_hf.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import collections.abc
1919
import json
20+
import os
2021
import re
2122
import tempfile
2223
import warnings
@@ -662,6 +663,16 @@ def _process_quantized_modules(
662663
# _QuantFusedExperts uses plural `gate_up_proj_weight_quantizers` (ModuleList),
663664
# which get_quantization_format's singular-weight_quantizer check misses. Handle
664665
# it explicitly before the format gate so fused-experts get split + quantized.
666+
# Debug hatch (paired with MO_DEBUG_MAX_LAYERS in model_calib.layerwise_calibrate):
667+
# skip _export_fused_experts for layers whose layerwise calibration was never run.
668+
# Those layers' per-expert quantizers have no _amax — touching them triggers the
669+
# uncalibrated-fallback warnings or, with corrupt storage, a CUDA illegal-memory
670+
# error. With the calibrated layers only, every expert has a valid _amax.
671+
_debug_max = int(os.environ.get("MO_DEBUG_MAX_LAYERS", "0") or "0")
672+
if _debug_max > 0:
673+
_m = re.search(r"\.layers\.(\d+)\.", name or "")
674+
if _m and int(_m.group(1)) >= _debug_max:
675+
continue
665676
with fsdp2_aware_weight_update(model, sub_module, reshard=False):
666677
_export_fused_experts(sub_module, dtype)
667678
elif get_quantization_format(sub_module) != QUANTIZATION_NONE:

0 commit comments

Comments
 (0)