Skip to content

Commit b655321

Browse files
authored
[Issue 543] [Bug fix] Fix dynamic input quant for AWQ (#726)
## What does this PR do? **Type of change:** Bug fix <!-- Use one of the following: Bug fix, new feature, new example, new tests, documentation. --> **Overview:** Dynamic input quantizers, e.g., MXFP4, are not restored after AWQ. This PR fix the issue. ## Usage <!-- You can potentially add a usage example below. --> ```python # Add a code snippet demonstrating how to use this ``` ## Testing <!-- Mention how have you tested your change if applicable. --> Tested with MXFP4, NVFP4, int4 ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
1 parent 883c873 commit b655321

1 file changed

Lines changed: 12 additions & 8 deletions

File tree

modelopt/torch/quantization/model_calib.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -751,14 +751,18 @@ def postprocess(module, name):
751751
delattr(module.weight_quantizer, "_pre_quant_scale")
752752
if hasattr(module.input_quantizer, "_pre_quant_scale"):
753753
delattr(module.input_quantizer, "_pre_quant_scale")
754-
if module.awq_lite.is_input_quantized and module.input_quantizer.amax is not None:
755-
act_amax = module.input_quantizer.amax
756-
# TODO: make this a buffer after we support only heterogeneous checkpointing for MCore
757-
module.input_quantizer._amax_for_smoothing = act_amax.cpu()
758-
module.input_quantizer.reset_amax()
759-
module.input_quantizer.axis = None
760-
module.input_quantizer.amax = act_amax.amax()
761-
module.input_quantizer.enable()
754+
if module.awq_lite.is_input_quantized:
755+
if module.input_quantizer.amax is not None:
756+
act_amax = module.input_quantizer.amax
757+
# TODO: make this a buffer after we support only heterogeneous checkpointing for MCore
758+
module.input_quantizer._amax_for_smoothing = act_amax.cpu()
759+
module.input_quantizer.reset_amax()
760+
module.input_quantizer.axis = None
761+
module.input_quantizer.amax = act_amax.amax()
762+
module.input_quantizer.enable()
763+
# for dynamic quantization, there is no amax, so we just enable the quantizer
764+
else:
765+
module.input_quantizer.enable()
762766

763767
if module.awq_lite.is_enabled:
764768
apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)

0 commit comments

Comments
 (0)