Commit ada1e26

cjluo-nv and claude authored
[NVBug: 6000530] Fix AWQ crash for uncalibrated MoE experts (#1142)
## Summary

- Fixes NVBugs 6000530: `AttributeError: 'float' object has no attribute 'pow'` when running AWQ lite with `moe_calib_experts_ratio < 1.0` on MoE models (e.g. Qwen3-30B-A3B).
- **Root cause**: When `moe_calib_experts_ratio=0.5`, some MoE experts receive zero tokens during the AWQ cache phase, leaving `act_scale` as a Python float `0.0` instead of a tensor. This causes two failures:
  1. **Search phase crash**: Uncalibrated experts crash in `get_scale()` because `float.pow()` doesn't exist.
  2. **Export crash**: Calibrated experts have `pre_quant_scale` but uncalibrated ones don't, causing `torch.stack()` to fail on mixed `None`/tensor values in `preprocess_linear_fusion()`.
- **Fix**: Handle uncalibrated experts (`num_cache_steps == 0`) in two stages:
  1. **Before search**: Disable AWQ search (`is_enabled = False`) to prevent the `get_scale()` crash on a float `act_scale`.
  2. **During postprocessing**: Max calibrate weights and apply a neutral (all-ones) `pre_quant_scale` so export can stack scaling factors consistently across all experts.

  The `pre_quant_scale` buffer must be registered outside `enable_weight_access_and_writeback` because HF accelerate's `post_forward` hook drops newly-registered submodule buffers.

## Test plan

- [x] Reproduce with `Qwen/Qwen3-30B-A3B`, `--qformat int4_awq`, `--moe_calib_experts_ratio 0.5`; verify no crash during calibration and export

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
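The search-phase failure and the first stage of the fix can be sketched in plain Python. This is an illustration only, not ModelOpt code: `ExpertState`, `cache_tokens`, and `disable_uncalibrated` are hypothetical stand-ins, and Python lists take the place of tensors. It assumes, as the summary describes, that `act_scale` starts as the float `0.0` and is only replaced once an expert sees tokens.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class ExpertState:
    """Hypothetical stand-in for a module's awq_lite helper (not the ModelOpt class)."""

    name: str
    num_cache_steps: int = 0
    # Starts as a Python float; the cache phase replaces it with per-channel values.
    act_scale: Union[float, List[float]] = 0.0
    is_enabled: bool = True


def cache_tokens(expert: ExpertState, channel_means: List[float]) -> None:
    """Accumulate one batch of activation statistics (a list stands in for a tensor)."""
    if expert.num_cache_steps == 0:
        expert.act_scale = list(channel_means)
    else:
        expert.act_scale = [a + b for a, b in zip(expert.act_scale, channel_means)]
    expert.num_cache_steps += 1


def disable_uncalibrated(experts: List[ExpertState]) -> List[str]:
    """Stage 1 of the fix: skip the search for experts that cached nothing."""
    skipped = []
    for expert in experts:
        if expert.num_cache_steps == 0:
            expert.is_enabled = False
            skipped.append(expert.name)
    return skipped


experts = [ExpertState("expert_0"), ExpertState("expert_1")]
cache_tokens(experts[0], [0.5, 2.0])  # only expert_0 is routed any tokens

# The unguarded failure mode: a plain float has no tensor-style .pow() method.
try:
    experts[1].act_scale.pow(0.5)
    crashed = False
except AttributeError:
    crashed = True

skipped = disable_uncalibrated(experts)
```

Here `expert_1` never leaves its float default, so the tensor-style call fails, and the guard marks it as disabled before any search would touch it.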
1 parent 74a8694 commit ada1e26

1 file changed

Lines changed: 37 additions & 9 deletions

modelopt/torch/quantization/model_calib.py

@@ -1179,6 +1179,17 @@ def sync_act_scale_across_dp(module, data_parallel_group):
                     module.parallel_state.data_parallel_group,
                 )
 
+    # Disable AWQ search for uncalibrated experts (num_cache_steps == 0) to
+    # prevent get_scale() crash on float act_scale. Max calibration and neutral
+    # pre_quant_scale are applied in the postprocessing loop below.
+    for name, module in model.named_modules():
+        if (
+            is_quantized_linear(module)
+            and hasattr(module, "awq_lite")
+            and module.awq_lite.num_cache_steps == 0
+        ):
+            module.awq_lite.is_enabled = False
+
     AWQLiteHelper.cache_mode = False
     print_rank_0("awq_lite: Searching parameters...")
     with torch.no_grad():
@@ -1212,16 +1223,33 @@ def postprocess(module, name):
     for name, module in model.named_modules():
         if hasattr(module, "awq_lite"):
             if module.awq_lite.num_cache_steps == 0:
-                module.awq_lite.is_enabled = False
-            elif module.awq_lite.num_search_steps == 0:
-                module.awq_lite.is_enabled = False
-                warnings.warn(
-                    "awq_lite: Calling `forward_loop(model)` the second time did not forward data through the"
-                    f" {name}. Please provide a valid `forward_loop` function that can be used to"
-                    " forward data through the model many times."
+                # Uncalibrated expert: max calibrate weights and apply neutral
+                # (all-ones) pre_quant_scale for export consistency.
+                # NOTE: ones_scale must be registered OUTSIDE enable_weight_access_and_writeback
+                # because HF accelerate post_forward drops newly-registered submodule buffers.
+                with enable_weight_access_and_writeback(module, model, name_to_module):
+                    max_calibrate(module, lambda module: module.weight_quantizer(module.weight))
+                    w_shape, w_dtype, w_device = (
+                        module.weight.shape[1],
+                        module.weight.dtype,
+                        module.weight.device,
+                    )
+                module.input_quantizer._enable_pre_quant_scale = True
+                module.input_quantizer.pre_quant_scale = torch.ones(
+                    w_shape,
+                    dtype=w_dtype,
+                    device=w_device,
                 )
-            with enable_weight_access_and_writeback(module, model, name_to_module):
-                postprocess(module, name)
+            else:
+                if module.awq_lite.num_search_steps == 0:
+                    module.awq_lite.is_enabled = False
+                    warnings.warn(
+                        "awq_lite: Calling `forward_loop(model)` the second time did not forward"
+                        f" data through the {name}. Please provide a valid `forward_loop` function"
+                        " that can be used to forward data through the model many times."
+                    )
+                with enable_weight_access_and_writeback(module, model, name_to_module):
+                    postprocess(module, name)
 
             module.awq_lite.cleanup()
             if not debug:
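
The reason for an all-ones `pre_quant_scale` in the hunk above can be illustrated with a small stand-alone sketch. It is not the export code: `apply_pre_quant_scale`, `stack_scales`, and `fill_neutral` are hypothetical helpers, lists stand in for tensors, and `stack_scales` only imitates the requirement that every expert supplies a scale before stacking. The vector length mirrors `module.weight.shape[1]`, the input-feature dimension used in the diff.

```python
from typing import List, Optional

Scale = List[float]


def apply_pre_quant_scale(x: List[float], scale: Scale) -> List[float]:
    """Scale one activation row per input channel."""
    return [xi * si for xi, si in zip(x, scale)]


def stack_scales(scales: List[Optional[Scale]]) -> List[Scale]:
    """Stand-in for stacking per-expert scales at export; rejects missing entries."""
    if any(s is None for s in scales):
        raise TypeError("every expert needs a pre_quant_scale before stacking")
    return [list(s) for s in scales]


def fill_neutral(scales: List[Optional[Scale]], in_features: int) -> List[Scale]:
    """Give experts without a scale an all-ones vector of length in_features."""
    return [s if s is not None else [1.0] * in_features for s in scales]


in_features = 3
per_expert = [[0.5, 2.0, 1.5], None]  # the second expert saw no tokens

# Mixed None/scale entries cannot be stacked.
try:
    stack_scales(per_expert)
    stack_failed = False
except TypeError:
    stack_failed = True

# After filling in neutral scales, stacking succeeds for all experts.
stacked = stack_scales(fill_neutral(per_expert, in_features))

# Multiplying by ones leaves the uncalibrated expert's input untouched.
x = [0.25, -4.0, 8.0]
unchanged = apply_pre_quant_scale(x, stacked[1])
```

Because the filled-in scale is the multiplicative identity, the uncalibrated expert's inputs pass through unchanged while every expert now has a scale of the same shape.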
