Skip to content

Commit 4d1380a

Browse files
committed
add local hessian calibration
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
1 parent aafd388 commit 4d1380a

File tree

3 files changed

+416
-2
lines changed

3 files changed

+416
-2
lines changed

modelopt/torch/quantization/config.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,69 @@
387387
"algorithm": "max",
388388
}
389389

390+
# MSE calibration config for NVFP4 weight + activation quantization:
# 4-bit (E2M1) values in blocks of 16 with FP8 (E4M3) per-block scales.
# Weight scales are static (calibrated); activation scales are dynamic (runtime).
NVFP4_WEIGHT_ACT_MSE_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),  # FP4 E2M1: 2 exponent bits, 1 mantissa bit
            # 16-element blocks along the last dim; static FP8 E4M3 block scales
            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {
            "num_bits": (2, 1),
            # activations use dynamic (per-call) block scales instead of static ones
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        **_default_disabled_quantizer_cfg,  # all other quantizers disabled
    },
    # MSE amax search: candidates span [start_multiplier, stop_multiplier] x initial
    # amax, in increments of step_size (see MseCalibConfig field descriptions).
    "algorithm": {
        "method": "mse",
        "step_size": 0.25,
        "start_multiplier": 0.25,
        "stop_multiplier": 2.0,
    },
}
413+
414+
# Weight-only NVFP4 MSE calibration config. Activations are left unquantized;
# per-block weight scales are chosen by sweeping the 128 representable FP8 E4M3
# scale values (fp8_scale_sweep) rather than by a multiplier-based amax search.
NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),  # FP4 E2M1
            # 16-element blocks, static FP8 E4M3 block scales
            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {
            "enable": False,  # weight-only: activations stay in high precision
        },
        **_default_disabled_quantizer_cfg,  # all other quantizers disabled
    },
    "algorithm": {
        "method": "mse",
        "fp8_scale_sweep": True,
    },
}
432+
433+
434+
# Weight-only NVFP4 config using local-Hessian-weighted calibration
# (see LocalHessianCalibConfig): block scales are selected to minimize output
# reconstruction error weighted by H = X @ X.T from input activations.
# Identical quant_cfg to NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG; only the search
# objective differs.
NVFP4_LOCAL_HESSIAN_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),  # FP4 E2M1
            # 16-element blocks, static FP8 E4M3 block scales
            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {
            "enable": False,  # weight-only: activations stay in high precision
        },
        **_default_disabled_quantizer_cfg,  # all other quantizers disabled
    },
    "algorithm": {
        "method": "local_hessian",
        "fp8_scale_sweep": True,  # sweep all 128 FP8 E4M3 scale candidates per block
    },
}
452+
390453
NVFP4_AWQ_LITE_CFG = {
391454
"quant_cfg": {
392455
"*weight_quantizer": {
@@ -1059,6 +1122,76 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
10591122
)
10601123

10611124

1125+
class LocalHessianCalibConfig(QuantizeAlgorithmConfig):
    """Configuration for local Hessian-weighted MSE calibration.

    This algorithm uses activation information to optimize per-block scales for weight
    quantization. It minimizes the output reconstruction error by weighting the loss
    with the local Hessian matrix computed from input activations.

    The local Hessian loss for each block is: ``(dw @ H @ dw.T)`` where:

    - ``dw = weight - quantized_weight`` (weight reconstruction error per block)
    - ``H = X @ X.T`` is the local Hessian computed from input activations X

    This method is particularly effective for NVFP4 weight-only quantization where
    activation information helps select better per-block scales.
    """

    # Discriminator: selects this algorithm in the calibrate-mode registry.
    method: Literal["local_hessian"] = ModeloptField("local_hessian")

    # NOTE: per the fp8_scale_sweep description below, the multiplier-based amax
    # search (step_size / start_multiplier / stop_multiplier) is bypassed in favor
    # of the FP8 scale sweep when fp8_scale_sweep is True (the default).
    step_size: float | None = ModeloptField(
        default=0.1,
        gt=0.0,
        title="Step size for amax search.",
        description="Step size between amax candidates. The number of candidates is computed as "
        "ceil((stop_multiplier - start_multiplier) / step_size) + 1.",
    )

    start_multiplier: float | None = ModeloptField(
        default=0.25,
        gt=0.0,
        title="Starting multiplier for amax search.",
        description="Starting multiplier for amax search range (multiplies initial amax).",
    )

    stop_multiplier: float | None = ModeloptField(
        default=4.0,
        gt=0.0,
        title="Ending multiplier for amax search.",
        description="Ending multiplier for amax search range (multiplies initial amax).",
    )

    fp8_scale_sweep: bool | None = ModeloptField(
        default=True,
        title="Enable FP8 scale sweep for NVFP4 per-block quantization.",
        description="If True, sweep over all 128 possible FP8 E4M3 scale values "
        "for NVFP4 per-block quantization instead of using multipliers. "
        "This is the recommended setting for NVFP4 quantization.",
    )

    # Must agree with the block size in the quant_cfg (16 for NVFP4 — see
    # NVFP4_LOCAL_HESSIAN_CFG's block_sizes entry).
    block_size: int | None = ModeloptField(
        default=16,
        gt=0,
        title="Block size for local Hessian computation.",
        description="The block size used for computing the local Hessian matrix. "
        "This should match the block size used in the quantization config. "
        "Default is 16 for NVFP4.",
    )

    distributed_sync: bool | None = ModeloptField(
        default=True,
        title="Whether to sync the amax across the distributed processes.",
        description="If True, the amax will be synced across the distributed processes.",
    )

    debug: bool | None = ModeloptField(
        default=False,
        title="Debug mode.",
        description="If True, module's local Hessian metadata will be kept as a module attribute.",
    )
1193+
1194+
10621195
class SmoothQuantCalibConfig(QuantizeAlgorithmConfig):
10631196
"""The config for ``smoothquant`` algorithm (SmoothQuant).
10641197

modelopt/torch/quantization/mode.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
AWQFullCalibConfig,
3838
AWQLiteCalibConfig,
3939
CompressConfig,
40+
LocalHessianCalibConfig,
4041
MaxCalibConfig,
4142
MseCalibConfig,
4243
QuantizeAlgoCfgType,
@@ -55,7 +56,14 @@
5556
restore_svdquant_model,
5657
update_quantize_metadata,
5758
)
58-
from .model_calib import awq, max_calibrate, mse_calibrate, smoothquant, svdquant
59+
from .model_calib import (
60+
awq,
61+
local_hessian_calibrate,
62+
max_calibrate,
63+
mse_calibrate,
64+
smoothquant,
65+
svdquant,
66+
)
5967

6068
__all__ = ["BaseCalibrateModeDescriptor"]
6169

@@ -376,6 +384,22 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]:
376384
_calib_func = mse_calibrate
377385

378386

387+
@CalibrateModeRegistry.register_mode
class LocalHessianModeDescriptor(BaseCalibrateModeDescriptor):
    """Mode for local Hessian-weighted MSE calibration algorithm.

    This algorithm uses activation information to optimize per-block scales for weight
    quantization by minimizing output reconstruction error instead of weight
    reconstruction error.
    """

    @property
    def config_class(self) -> type[QuantizeAlgorithmConfig]:
        """Specifies the config class for the mode."""
        return LocalHessianCalibConfig

    # Calibration entry point dispatched by the mode framework for this descriptor.
    _calib_func = local_hessian_calibrate
401+
402+
379403
@CalibrateModeRegistry.register_mode
380404
class SmoothQuantModeDescriptor(BaseCalibrateModeDescriptor):
381405
"""Mode for smoothquant calibration algorithm."""

0 commit comments

Comments
 (0)