Skip to content

Commit deb0e3a

Browse files
committed
Address review: gate fix on is_input_quantized; expand tests
@meenchen: the unconditional input_quantizer.enable() in the uncalibrated branch wrongly turned on input quantization for weight-only AWQ configs (e.g. INT4_AWQ_CFG, where the user's config sets *input_quantizer enable=False and setup() therefore never disabled it). Gate the entire postprocess block — per-channel-amax collapse and enable() — behind module.awq_lite.is_input_quantized so weight-only configs are untouched. @coderabbitai: strengthen the existing regression test to also assert the export-critical scalar amax invariant (axis=None, numel==1) when amax exists, and add a companion test on INT4_AWQ_CFG asserting the uncalibrated linear's input_quantizer stays disabled. The NVFP4 test now requires CUDA (dynamic block quantization is CUDA-only), guarded with pytest.mark.skipif. Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
1 parent af1bfd6 commit deb0e3a

2 files changed

Lines changed: 68 additions & 29 deletions

File tree

modelopt/torch/quantization/model_calib.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,19 +1316,21 @@ def postprocess(module, name):
13161316
dtype=w_dtype,
13171317
device=w_device,
13181318
)
1319-
# Mirror the calibrated postprocess path: collapse any
1320-
# per-channel _amax left over from cache_mode max_calibrate
1321-
# into a per-tensor scalar (so preprocess_linear_fusion's
1322-
# numel==1 assertion passes), and re-enable the quantizer
1323-
# (awq_lite.setup disabled it). Without these, export drops
1324-
# input_scale and per-expert MoE loaders fail.
1325-
if module.input_quantizer.amax is not None:
1326-
act_amax = module.input_quantizer.amax
1327-
module.input_quantizer._amax_for_smoothing = act_amax.cpu()
1328-
module.input_quantizer.reset_amax()
1329-
module.input_quantizer.axis = None
1330-
module.input_quantizer.amax = act_amax.amax()
1331-
module.input_quantizer.enable()
1319+
# Mirror the calibrated postprocess path, gated on
1320+
# is_input_quantized so weight-only AWQ configs (where
1321+
# setup() never disabled input_quantizer) stay untouched.
1322+
# Collapse any per-channel _amax left over from cache_mode
1323+
# max_calibrate into a per-tensor scalar so
1324+
# preprocess_linear_fusion's numel==1 assertion passes, and
1325+
# re-enable the quantizer (awq_lite.setup disabled it).
1326+
if module.awq_lite.is_input_quantized:
1327+
if module.input_quantizer.amax is not None:
1328+
act_amax = module.input_quantizer.amax
1329+
module.input_quantizer._amax_for_smoothing = act_amax.cpu()
1330+
module.input_quantizer.reset_amax()
1331+
module.input_quantizer.axis = None
1332+
module.input_quantizer.amax = act_amax.amax()
1333+
module.input_quantizer.enable()
13321334
else:
13331335
with enable_weight_access_and_writeback(module, model, name_to_module):
13341336
postprocess(module, name)

tests/unit/torch/quantization/test_calib.py

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,21 @@ def test_padded_awq():
312312
model(torch.randn(2, 16, 16))
313313

314314

315+
class _TwoBranchModel(nn.Module):
316+
"""Two parallel linears; only the first is exercised by forward_loop."""
317+
318+
def __init__(self):
319+
super().__init__()
320+
self.calibrated = nn.Linear(16, 16, bias=False)
321+
self.uncalibrated = nn.Linear(16, 16, bias=False)
322+
323+
def forward(self, x, branch="calibrated"):
324+
if branch == "calibrated":
325+
return self.calibrated(x)
326+
return self.uncalibrated(x)
327+
328+
329+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="NVFP4 dynamic block quant is CUDA-only")
315330
def test_awq_lite_uncalibrated_linear_keeps_input_quantizer_enabled():
316331
"""Regression test for NVBug 6143871.
317332
@@ -322,27 +337,16 @@ def test_awq_lite_uncalibrated_linear_keeps_input_quantizer_enabled():
322337
+ _export_quantized_weight) drops the input_scale buffer and inference
323338
runtimes that read per-expert input_scale (e.g. TRT-LLM CutlassFusedMoE)
324339
crash with KeyError on '<idx>.w1.input_scale'.
325-
"""
326-
327-
class _TwoBranchModel(nn.Module):
328-
"""Two parallel linears; only the first is exercised by forward_loop."""
329-
330-
def __init__(self):
331-
super().__init__()
332-
self.calibrated = nn.Linear(16, 16, bias=False)
333-
self.uncalibrated = nn.Linear(16, 16, bias=False)
334-
335-
def forward(self, x, branch="calibrated"):
336-
if branch == "calibrated":
337-
return self.calibrated(x)
338-
return self.uncalibrated(x)
339340
341+
Also asserts the export-critical scalar amax invariant (axis=None,
342+
numel==1) — preprocess_linear_fusion enforces it for fused-expert groups.
343+
"""
340344
torch.manual_seed(0)
341-
model = _TwoBranchModel()
345+
model = _TwoBranchModel().cuda()
342346

343347
def _forward_loop(m):
344348
for _ in range(2):
345-
m(torch.randn(2, 16, 16), branch="calibrated")
349+
m(torch.randn(2, 16, 16, device="cuda"), branch="calibrated")
346350

347351
mtq.quantize(model, mtq.NVFP4_AWQ_LITE_CFG, _forward_loop)
348352

@@ -351,6 +355,39 @@ def _forward_loop(m):
351355
"Uncalibrated linear's input_quantizer must remain enabled after "
352356
"awq_lite postprocess so export emits input_scale (NVBug 6143871)."
353357
)
358+
uncal_q = model.uncalibrated.input_quantizer
359+
# When amax exists (cache-hit but search-miss path), it must be the
360+
# scalar form export expects — preprocess_linear_fusion asserts numel==1.
361+
# When it's None (truly never routed), set_expert_quantizer_amax will
362+
# populate it during export.
363+
if uncal_q.amax is not None:
364+
assert uncal_q.axis is None
365+
assert uncal_q.amax.numel() == 1
366+
367+
368+
def test_awq_lite_uncalibrated_weight_only_keeps_input_quantizer_disabled():
369+
"""Weight-only AWQ companion to NVBug 6143871.
370+
371+
For weight-only AWQ configs (input_quantizer disabled), awq_lite.setup()
372+
never touches the input_quantizer, so the postprocess uncalibrated branch
373+
must NOT enable it — doing so turns on quantization the user's config had
374+
explicitly opted out of.
375+
"""
376+
torch.manual_seed(0)
377+
model = _TwoBranchModel()
378+
379+
def _forward_loop(m):
380+
for _ in range(2):
381+
m(torch.randn(2, 16, 16), branch="calibrated")
382+
383+
mtq.quantize(model, mtq.INT4_AWQ_CFG, _forward_loop)
384+
385+
assert not model.calibrated.input_quantizer.is_enabled
386+
assert not model.uncalibrated.input_quantizer.is_enabled, (
387+
"Weight-only AWQ must not flip on the input_quantizer for "
388+
"uncalibrated layers — that would silently quantize activations "
389+
"the user's config left in full precision."
390+
)
354391

355392

356393
def test_smoothquant_enable_disable():

0 commit comments

Comments
 (0)