Skip to content

Commit baef63e

Browse files
committed
add Global Hessian with heavy debug code
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
1 parent a56e251 commit baef63e

4 files changed

Lines changed: 347 additions & 48 deletions

File tree

examples/llm_ptq/hf_ptq.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
import modelopt.torch.quantization as mtq
4949
import modelopt.torch.sparsity as mts
5050
from modelopt.torch.export import (
51-
export_hf_checkpoint,
51+
export_hf_vllm_fq_checkpoint,
5252
export_tensorrt_llm_checkpoint,
5353
get_model_type,
5454
)
@@ -77,6 +77,9 @@
7777
"int4_awq": mtq.INT4_AWQ_CFG,
7878
"w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
7979
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
80+
"nvfp4_mse": mtq.NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG,
81+
"nvfp4_lo_he": mtq.NVFP4_LOCAL_HESSIAN_CFG,
82+
"nvfp4_gl_he": mtq.NVFP4_GLOBAL_HESSIAN_CFG,
8083
"nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG,
8184
"fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
8285
"fp8_pc_pt": mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG,
@@ -139,10 +142,10 @@ def make_calib_dataloader(
139142
assert tokenizer is not None and isinstance(
140143
tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)
141144
), "The PreTrainedTokenizer must be set"
142-
# Labels are only needed for gradient-based auto_quantize
145+
# Labels are needed for gradient-based auto_quantize or global hessian calibration
143146
include_labels = (
144147
args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient"
145-
)
148+
) or args.qformat == "nvfp4_gl_he" # Global hessian needs labels for backward pass
146149
calib_dataloader = get_dataset_dataloader(
147150
dataset_name=args.dataset,
148151
tokenizer=tokenizer,
@@ -432,8 +435,18 @@ def mono_quantize(
432435

433436
if not use_calibration:
434437
warnings.warn("Dynamic quantization. Calibration skipped.")
438+
439+
# Check if we need backward pass for global hessian calibration
440+
algorithm_cfg = quant_cfg.get("algorithm", {})
441+
use_global_hessian = (
442+
algorithm_cfg.get("method") == "local_hessian"
443+
and algorithm_cfg.get("hessian_type") == "global"
444+
)
445+
435446
calibrate_loop = (
436-
create_forward_loop(dataloader=calib_dataloader) if use_calibration else None
447+
create_forward_loop(dataloader=calib_dataloader, enable_backward=use_global_hessian)
448+
if use_calibration
449+
else None
437450
)
438451

439452
if calibration_only:
@@ -535,7 +548,7 @@ def export_quantized(
535548
"They will be set at deployment time."
536549
)
537550

538-
export_hf_checkpoint(
551+
export_hf_vllm_fq_checkpoint(
539552
full_model,
540553
export_dir=export_path,
541554
)

modelopt/torch/quantization/config.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,27 @@
446446
},
447447
"algorithm": {
448448
"method": "local_hessian",
449+
"hessian_type": "local",
450+
"fp8_scale_sweep": True,
451+
},
452+
}
453+
454+
NVFP4_GLOBAL_HESSIAN_CFG = {
455+
"quant_cfg": {
456+
"*weight_quantizer": {
457+
"num_bits": (2, 1),
458+
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
459+
"axis": None,
460+
"enable": True,
461+
},
462+
"*input_quantizer": {
463+
"enable": False,
464+
},
465+
**_default_disabled_quantizer_cfg,
466+
},
467+
"algorithm": {
468+
"method": "local_hessian",
469+
"hessian_type": "global",
449470
"fp8_scale_sweep": True,
450471
},
451472
}
@@ -1125,23 +1146,42 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
11251146

11261147

11271148
class LocalHessianCalibConfig(QuantizeAlgorithmConfig):
1128-
"""Configuration for local Hessian-weighted MSE calibration.
1149+
"""Configuration for Hessian-weighted MSE calibration.
11291150
11301151
This algorithm uses activation information to optimize per-block scales for weight
11311152
quantization. It minimizes the output reconstruction error by weighting the loss
1132-
with the local Hessian matrix computed from input activations.
1153+
with the Hessian matrix computed from input activations (and optionally output gradients).
11331154
1134-
The local Hessian loss for each block is: ``(dw @ H @ dw.T)`` where:
1155+
The Hessian loss for each block is: ``(dw @ H @ dw.T)`` where:
11351156
- ``dw = weight - quantized_weight`` (weight reconstruction error per block)
1136-
- ``H = X @ X.T`` is the local Hessian computed from input activations X
1157+
- ``H`` is the Hessian matrix (local or global, depending on ``hessian_type``)
1158+
1159+
Two Hessian types are supported:
1160+
1161+
- **local**: ``H = X @ X.T`` - uses only input activations. Faster, no backward pass needed.
1162+
- **global**: ``H = (X * grad²) @ X.T`` - weights by output gradient squared.
1163+
More accurate as it accounts for output importance, but requires backward pass.
11371164
11381165
This method is particularly effective for NVFP4 weight-only quantization where
11391166
activation information helps select better per-block scales.
1140-
11411167
"""
11421168

11431169
method: Literal["local_hessian"] = ModeloptField("local_hessian")
11441170

1171+
hessian_type: Literal["local", "global"] = ModeloptField(
1172+
default="local",
1173+
title="Type of Hessian to compute.",
1174+
description="""Type of Hessian matrix to use for weighting quantization errors:
1175+
1176+
- ``"local"``: H = X @ X.T - Only uses input activations. Fast, forward-pass only.
1177+
- ``"global"``: H = (X * grad²) @ X.T - Weights by output gradient squared.
1178+
More accurate as it captures output importance, but requires backward pass
1179+
during calibration.
1180+
1181+
The global Hessian is closer to the true Fisher Information and typically
1182+
gives better results, but at the cost of running backward passes.""",
1183+
)
1184+
11451185
step_size: float | None = ModeloptField(
11461186
default=0.1,
11471187
gt=0.0,
@@ -1175,8 +1215,8 @@ class LocalHessianCalibConfig(QuantizeAlgorithmConfig):
11751215
block_size: int | None = ModeloptField(
11761216
default=16,
11771217
gt=0,
1178-
title="Block size for local Hessian computation.",
1179-
description="The block size used for computing the local Hessian matrix. "
1218+
title="Block size for Hessian computation.",
1219+
description="The block size used for computing the Hessian matrix. "
11801220
"This should match the block size used in the quantization config. "
11811221
"Default is 16 for NVFP4.",
11821222
)
@@ -1190,7 +1230,7 @@ class LocalHessianCalibConfig(QuantizeAlgorithmConfig):
11901230
debug: bool | None = ModeloptField(
11911231
default=False,
11921232
title="Debug mode.",
1193-
description="If True, module's local Hessian metadata will be kept as a module attribute.",
1233+
description="If True, module's Hessian metadata will be kept as a module attribute.",
11941234
)
11951235

11961236

0 commit comments

Comments (0)