Skip to content

Commit fa20e3b

Browse files
Fridah-nv (Frida Hou)
authored and committed
add local hessian calibration
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
1 parent c974090 commit fa20e3b

File tree

3 files changed

+417
-2
lines changed

3 files changed

+417
-2
lines changed

modelopt/torch/quantization/config.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,69 @@
388388
"algorithm": "max",
389389
}
390390

391+
NVFP4_WEIGHT_ACT_MSE_CFG = {
392+
"quant_cfg": {
393+
"*weight_quantizer": {
394+
"num_bits": (2, 1),
395+
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
396+
"axis": None,
397+
"enable": True,
398+
},
399+
"*input_quantizer": {
400+
"num_bits": (2, 1),
401+
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
402+
"axis": None,
403+
"enable": True,
404+
},
405+
**_default_disabled_quantizer_cfg,
406+
},
407+
"algorithm": {
408+
"method": "mse",
409+
"step_size": 0.25,
410+
"start_multiplier": 0.25,
411+
"stop_multiplier": 2.0,
412+
},
413+
}
414+
415+
NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG = {
416+
"quant_cfg": {
417+
"*weight_quantizer": {
418+
"num_bits": (2, 1),
419+
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
420+
"axis": None,
421+
"enable": True,
422+
},
423+
"*input_quantizer": {
424+
"enable": False,
425+
},
426+
**_default_disabled_quantizer_cfg,
427+
},
428+
"algorithm": {
429+
"method": "mse",
430+
"fp8_scale_sweep": True,
431+
},
432+
}
433+
434+
435+
NVFP4_LOCAL_HESSIAN_CFG = {
436+
"quant_cfg": {
437+
"*weight_quantizer": {
438+
"num_bits": (2, 1),
439+
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
440+
"axis": None,
441+
"enable": True,
442+
},
443+
"*input_quantizer": {
444+
"enable": False,
445+
},
446+
**_default_disabled_quantizer_cfg,
447+
},
448+
"algorithm": {
449+
"method": "local_hessian",
450+
"fp8_scale_sweep": True,
451+
},
452+
}
453+
391454
NVFP4_AWQ_LITE_CFG = {
392455
"quant_cfg": {
393456
"*weight_quantizer": {
@@ -1060,6 +1123,76 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
10601123
)
10611124

10621125

1126+
class LocalHessianCalibConfig(QuantizeAlgorithmConfig):
1127+
"""Configuration for local Hessian-weighted MSE calibration.
1128+
1129+
This algorithm uses activation information to optimize per-block scales for weight
1130+
quantization. It minimizes the output reconstruction error by weighting the loss
1131+
with the local Hessian matrix computed from input activations.
1132+
1133+
The local Hessian loss for each block is: ``(dw @ H @ dw.T)`` where:
1134+
- ``dw = weight - quantized_weight`` (weight reconstruction error per block)
1135+
- ``H = X @ X.T`` is the local Hessian computed from input activations X
1136+
1137+
This method is particularly effective for NVFP4 weight-only quantization where
1138+
activation information helps select better per-block scales.
1139+
1140+
"""
1141+
1142+
method: Literal["local_hessian"] = ModeloptField("local_hessian")
1143+
1144+
step_size: float | None = ModeloptField(
1145+
default=0.1,
1146+
gt=0.0,
1147+
title="Step size for amax search.",
1148+
description="Step size between amax candidates. The number of candidates is computed as "
1149+
"ceil((stop_multiplier - start_multiplier) / step_size) + 1.",
1150+
)
1151+
1152+
start_multiplier: float | None = ModeloptField(
1153+
default=0.25,
1154+
gt=0.0,
1155+
title="Starting multiplier for amax search.",
1156+
description="Starting multiplier for amax search range (multiplies initial amax).",
1157+
)
1158+
1159+
stop_multiplier: float | None = ModeloptField(
1160+
default=4.0,
1161+
gt=0.0,
1162+
title="Ending multiplier for amax search.",
1163+
description="Ending multiplier for amax search range (multiplies initial amax).",
1164+
)
1165+
1166+
fp8_scale_sweep: bool | None = ModeloptField(
1167+
default=True,
1168+
title="Enable FP8 scale sweep for NVFP4 per-block quantization.",
1169+
description="If True, sweep over all 128 possible FP8 E4M3 scale values "
1170+
"for NVFP4 per-block quantization instead of using multipliers. "
1171+
"This is the recommended setting for NVFP4 quantization.",
1172+
)
1173+
1174+
block_size: int | None = ModeloptField(
1175+
default=16,
1176+
gt=0,
1177+
title="Block size for local Hessian computation.",
1178+
description="The block size used for computing the local Hessian matrix. "
1179+
"This should match the block size used in the quantization config. "
1180+
"Default is 16 for NVFP4.",
1181+
)
1182+
1183+
distributed_sync: bool | None = ModeloptField(
1184+
default=True,
1185+
title="Whether to sync the amax across the distributed processes.",
1186+
description="If True, the amax will be synced across the distributed processes.",
1187+
)
1188+
1189+
debug: bool | None = ModeloptField(
1190+
default=False,
1191+
title="Debug mode.",
1192+
description="If True, module's local Hessian metadata will be kept as a module attribute.",
1193+
)
1194+
1195+
10631196
class SmoothQuantCalibConfig(QuantizeAlgorithmConfig):
10641197
"""The config for ``smoothquant`` algorithm (SmoothQuant).
10651198

modelopt/torch/quantization/mode.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
AWQLiteCalibConfig,
3939
CompressConfig,
4040
GPTQLiteConfig,
41+
LocalHessianCalibConfig,
4142
MaxCalibConfig,
4243
MseCalibConfig,
4344
QuantizeAlgoCfgType,
@@ -56,7 +57,15 @@
5657
restore_svdquant_model,
5758
update_quantize_metadata,
5859
)
59-
from .model_calib import awq, gptq_lite, max_calibrate, mse_calibrate, smoothquant, svdquant
60+
from .model_calib import (
61+
awq,
62+
gptq_lite,
63+
local_hessian_calibrate,
64+
max_calibrate,
65+
mse_calibrate,
66+
smoothquant,
67+
svdquant,
68+
)
6069

6170
__all__ = ["BaseCalibrateModeDescriptor"]
6271

@@ -377,6 +386,22 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]:
377386
_calib_func = mse_calibrate
378387

379388

389+
@CalibrateModeRegistry.register_mode
390+
class LocalHessianModeDescriptor(BaseCalibrateModeDescriptor):
391+
"""Mode for local Hessian-weighted MSE calibration algorithm.
392+
393+
This algorithm uses activation information to optimize per-block scales for weight
394+
quantization by minimizing output reconstruction error instead of weight reconstruction error.
395+
"""
396+
397+
@property
398+
def config_class(self) -> type[QuantizeAlgorithmConfig]:
399+
"""Specifies the config class for the mode."""
400+
return LocalHessianCalibConfig
401+
402+
_calib_func = local_hessian_calibrate
403+
404+
380405
@CalibrateModeRegistry.register_mode
381406
class SmoothQuantModeDescriptor(BaseCalibrateModeDescriptor):
382407
"""Mode for smoothquant calibration algorithm."""

0 commit comments

Comments
 (0)