Skip to content

Commit adcce61

Browse files
authored
add local hessian calibration (#788)
## What does this PR do? **Type of change:** new feature <!-- Use one of the following: Bug fix, new feature, new example, new tests, documentation. --> **Overview:** Add a new calibration method for weight scale search. It considers activation information by weighting scale candidates with a local Hessian matrix. Initial experiments with Qwen3 8B NVFP4 show improvements. ## Usage <!-- You can potentially add a usage example below. --> Use the `NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG` quantization config for quantization and evaluation. ## Testing <!-- Mention how have you tested your change if applicable. --> ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **New Features** * Added local Hessian-weighted MSE calibration pathway for NVFP4 per-block quantization with configurable amax search parameters and FP8 scale sweep support. * **Tests** * Added test coverage for the new local Hessian weight-only quantization configuration. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
1 parent ac7c985 commit adcce61

File tree

4 files changed

+436
-8
lines changed

4 files changed

+436
-8
lines changed

modelopt/torch/quantization/config.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,27 @@
419419
"algorithm": "max",
420420
}
421421

422+
# NVFP4 W4A4 quantization config whose weight scales are calibrated with the
# local-Hessian-weighted MSE search (``"method": "local_hessian"``; see
# LocalHessianCalibConfig). Both weight and input quantizers use num_bits=(2, 1)
# with 16-element per-block scales of scale_bits=(4, 3) (FP8 E4M3 scales);
# weight scales are "static" while input scales are computed dynamically.
NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
            "axis": None,  # per-block quantization, no per-axis reduction
            "enable": True,
        },
        "*input_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        # Disable every other quantizer via the shared default mapping.
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": {
        "method": "local_hessian",
        # Sweep the discrete FP8 E4M3 scale values instead of the
        # multiplier-based amax search — recommended for NVFP4 per-block scales.
        "fp8_scale_sweep": True,
    },
}
422443

423444
MAMBA_MOE_NVFP4_AGGRESSIVE_CFG = {
424445
"quant_cfg": {
@@ -1138,6 +1159,73 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
11381159
)
11391160

11401161

1162+
class LocalHessianCalibConfig(QuantizeAlgorithmConfig):
    """Configuration for local Hessian-weighted MSE calibration.

    This algorithm uses activation information to optimize per-block scales for weight
    quantization. It minimizes the output reconstruction error by weighting the loss
    with the local Hessian matrix computed from input activations.

    The local Hessian loss for each block is: ``(dw @ H @ dw.T)`` where:

    - ``dw = weight - quantized_weight`` (weight reconstruction error per block)
    - ``H = X @ X.T`` is the local Hessian computed from input activations X
    """

    # Algorithm discriminator; matches the "method" key used in quantize configs
    # (e.g. NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG) to select this calibrator.
    method: Literal["local_hessian"] = ModeloptField("local_hessian")

    # Multiplier-based amax search parameters (step_size / start_multiplier /
    # stop_multiplier) are only relevant when fp8_scale_sweep is disabled.
    step_size: float | None = ModeloptField(
        default=0.1,
        gt=0.0,
        title="Step size for amax search.",
        description="Step size between amax candidates. The number of candidates is computed as "
        "ceil((stop_multiplier - start_multiplier) / step_size) + 1.",
    )

    start_multiplier: float | None = ModeloptField(
        default=0.25,
        gt=0.0,
        title="Starting multiplier for amax search.",
        description="Starting multiplier for amax search range (multiplies initial amax).",
    )

    stop_multiplier: float | None = ModeloptField(
        default=4.0,
        gt=0.0,
        title="Ending multiplier for amax search.",
        description="Ending multiplier for amax search range (multiplies initial amax).",
    )

    fp8_scale_sweep: bool | None = ModeloptField(
        default=True,
        title="Enable FP8 scale sweep for NVFP4 per-block quantization.",
        description="If True, sweep over all 128 possible FP8 E4M3 scale values "
        "for NVFP4 per-block quantization instead of using multipliers. "
        "This is the recommended setting for NVFP4 quantization.",
    )

    block_size: int | None = ModeloptField(
        default=16,
        gt=0,
        title="Block size for local Hessian computation.",
        description="The block size used for computing the local Hessian matrix. "
        "This should match the block size used in the quantization config. "
        "Default is 16 for NVFP4.",
    )

    distributed_sync: bool | None = ModeloptField(
        default=True,
        title="Whether to sync the amax across the distributed processes.",
        description="If True, the amax will be synced across the distributed processes.",
    )

    debug: bool | None = ModeloptField(
        default=False,
        title="Debug mode.",
        description="If True, module's local Hessian metadata will be kept as a module attribute.",
    )
1227+
1228+
11411229
class SmoothQuantCalibConfig(QuantizeAlgorithmConfig):
11421230
"""The config for ``smoothquant`` algorithm (SmoothQuant).
11431231

modelopt/torch/quantization/mode.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
AWQLiteCalibConfig,
3939
CompressConfig,
4040
GPTQLiteConfig,
41+
LocalHessianCalibConfig,
4142
MaxCalibConfig,
4243
MseCalibConfig,
4344
QuantizeAlgoCfgType,
@@ -56,7 +57,15 @@
5657
restore_svdquant_model,
5758
update_quantize_metadata,
5859
)
59-
from .model_calib import awq, gptq_lite, max_calibrate, mse_calibrate, smoothquant, svdquant
60+
from .model_calib import (
61+
awq,
62+
gptq_lite,
63+
local_hessian_calibrate,
64+
max_calibrate,
65+
mse_calibrate,
66+
smoothquant,
67+
svdquant,
68+
)
6069

6170
__all__ = ["BaseCalibrateModeDescriptor"]
6271

@@ -377,6 +386,22 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]:
377386
_calib_func = mse_calibrate
378387

379388

389+
@CalibrateModeRegistry.register_mode
class LocalHessianModeDescriptor(BaseCalibrateModeDescriptor):
    """Mode for local Hessian-weighted MSE calibration algorithm.

    This algorithm uses activation information to optimize per-block scales for weight
    quantization by minimizing output reconstruction error instead of weight reconstruction error.
    """

    @property
    def config_class(self) -> type[QuantizeAlgorithmConfig]:
        """Specifies the config class for the mode."""
        return LocalHessianCalibConfig

    # Calibration entry point invoked by the base descriptor machinery,
    # mirroring the other mode descriptors (e.g. ``_calib_func = mse_calibrate``).
    _calib_func = local_hessian_calibrate
404+
380405
@CalibrateModeRegistry.register_mode
381406
class SmoothQuantModeDescriptor(BaseCalibrateModeDescriptor):
382407
"""Mode for smoothquant calibration algorithm."""

0 commit comments

Comments
 (0)