Commit b353110

kinjalpatel27 authored and danielkorzekwa committed
Added support to rotate in fp32 (optional) (#885)
## What does this PR do?

**Type of change:** New Feature

**Overview:** This PR adds support for performing the RHT rotation in float32 when enabled by the quantization configuration. It also makes the `rotate` argument in the quantization configuration of type bool (for backward compatibility) or dict (with an added option for float32 rotation).

## Usage

```
python hf_ptq.py --pyt_ckpt_path meta-llama/Llama-3.2-3B-Instruct --qformat nvfp4 --export_fmt hf --dataset cnn_dailymail --export_path test --trust_remote_code --inference_pipeline_parallel 1 --batch_size 1 --calib_size 4 --kv_cache_qformat nvfp4_rotate
```

Updated `NVFP4_KV_ROTATE_CFG` locally with `"rotate": {"enable": True, "rotate_fp32": True}`:

```
...
model.layers.27.self_attn.k_bmm_quantizer TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=8.3750 rotated (fp32) calibrator=MaxCalibrator quant)
...
```

Updated `NVFP4_KV_ROTATE_CFG` locally with `"rotate": {"enable": True, "rotate_fp32": False}`:

```
model.layers.27.self_attn.k_bmm_quantizer TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=8.3750 rotated calibrator=MaxCalibrator quant)
```

## Testing

Updated the unit test in `tests/gpu/torch/quantization/test_hadamard.py`.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No (updated existing test)
- **Did you add or update any necessary documentation?**: Yes
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes

## Summary by CodeRabbit

* **New Features**
  * Added the capability to rotate inputs prior to quantization for RHT (Randomized Hadamard Transform).
  * Introduced granular rotation configuration options enabling FP32 casting for improved numerical stability during transforms.
* **Tests**
  * Expanded test coverage for rotation functionality with parameterized FP32 casting scenarios.

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
1 parent eace1ae commit b353110

File tree

5 files changed: +49 −12 lines

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions

```diff
@@ -9,6 +9,7 @@ NVIDIA Model Optimizer Changelog (Linux)
 - User does not need to manually register MOE modules to cover experts calibration coverage in PTQ workflow.
 - ``hf_ptq.py`` now saves the quantization summary and moe expert token count table to the export directory.
 - Add sparse attention optimization for transformer models (``modelopt.torch.sparsity.attention_sparsity``). This reduces computational cost by skipping attention computation. Supports calibration for threshold selection on HuggingFace models. See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
+- Add support for rotating the input before quantization for RHT.

 0.42 (2026-02-xx)
 ^^^^^^^^^^^^^^^^^
```

modelopt/torch/quantization/config.py

Lines changed: 10 additions & 4 deletions

```diff
@@ -1033,14 +1033,20 @@ def validate_calibrator(cls, v, info: ValidationInfo):
         assert v in ["max", "histogram"]
         return v

-    rotate: bool = ModeloptField(
+    rotate: bool | dict[str, bool] = ModeloptField(
         default=False,
-        title="""If rotate the input before quantization.""",
-        description=""""If true, the input of the quantizer will be rotated with a hadamard matrix
+        title="""Configuration for rotating the input before quantization.""",
+        description="""Can be a boolean or a dictionary with the following keys:
+        - "enable": Boolean to enable/disable rotation (default: False)
+        - "rotate_fp32": Boolean to compute rotation in float32 precision (default: False)
+
+        If a boolean is provided, it is treated as the "enable" value with "rotate_fp32" defaulting to False.
+
+        When enabled, the input of the quantizer will be rotated with a hadamard matrix
         given by scipy.linalg.hadamard, i.e.
         ``input = input @ scipy.linalg.hadamard(input.shape[-1]) / sqrt(input.shape[-1])``.

-        This can be used for ratation based PTQ methods, e.g. QuaRot or SpinQuant.
+        This can be used for rotation based PTQ methods, e.g. QuaRot or SpinQuant.
         See https://arxiv.org/abs/2404.00456 for example.""",
     )
```
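The docstring formula can be checked with a small pure-Python sketch. This uses a Sylvester construction in place of `scipy.linalg.hadamard` (equivalent for power-of-two sizes); the helper names are illustrative, not from the codebase:

```python
import math

def hadamard(n):
    # Sylvester construction of the n x n Hadamard matrix; n must be a
    # power of two, matching scipy.linalg.hadamard for such n.
    H = [[1.0]]
    while len(H) < n:
        H = [row + row for row in H] + [row + [-v for v in row] for row in H]
    return H

def rotate(x):
    # input @ hadamard(n) / sqrt(n), the formula from the field docstring.
    n = len(x)
    H = hadamard(n)
    return [sum(x[i] * H[i][j] for i in range(n)) / math.sqrt(n) for j in range(n)]

x = [1.0, -2.0, 0.5, 3.0]
y = rotate(x)
# The normalized Hadamard matrix is orthogonal and symmetric, so rotating
# twice recovers the input and the L2 norm is preserved: the rotation
# reshapes outliers for quantization without changing what is represented.
roundtrip = rotate(y)
```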

modelopt/torch/quantization/nn/functional.py

Lines changed: 6 additions & 2 deletions

```diff
@@ -93,7 +93,7 @@ def backward(ctx, grad_outputs):
         return fast_hadamard_transform.hadamard_transform(grad_outputs)  # type: ignore[name-defined]


-def normalized_hadamard_transform(inputs):
+def normalized_hadamard_transform(inputs, rotate_fp32=False):
     """Normalized fast hadamard transform."""
     global fast_hadamard_transform
     try:
@@ -104,6 +104,10 @@ def normalized_hadamard_transform(inputs):
             "`pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git`"
         )

-    return FastHadamardTransform.apply(inputs) / torch.sqrt(
+    dtype = inputs.dtype
+    if rotate_fp32:
+        inputs = inputs.to(torch.float32)
+    outputs = FastHadamardTransform.apply(inputs) / torch.sqrt(
         torch.tensor(inputs.shape[-1], dtype=torch.float32)
     )
+    return outputs.to(dtype) if rotate_fp32 else outputs
```
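The patched function follows a cast-to-fp32, transform, cast-back pattern: the caller-visible dtype never changes, only the intermediate arithmetic is upcast. A hypothetical NumPy mirror of that pattern (the real code dispatches to the fast-hadamard-transform CUDA kernel; the explicit matrix here is only for illustration):

```python
import math
import numpy as np

def normalized_hadamard_transform_ref(inputs, rotate_fp32=False):
    # Mirror of the patched flow: remember the input dtype, optionally
    # upcast to float32, transform, then cast back.
    dtype = inputs.dtype
    if rotate_fp32:
        inputs = inputs.astype(np.float32)
    n = inputs.shape[-1]
    # Explicit Hadamard matrix (Sylvester construction) stands in for the kernel.
    H = np.array([[1.0]], dtype=inputs.dtype)
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    outputs = (inputs @ H) / math.sqrt(n)
    return outputs.astype(dtype) if rotate_fp32 else outputs

x = np.array([1.0, -2.0, 0.5, 3.0], dtype=np.float16)
y = normalized_hadamard_transform_ref(x, rotate_fp32=True)
print(y.dtype)  # float16 -- the fp32 math is invisible to the caller
```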

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 18 additions & 3 deletions

```diff
@@ -529,6 +529,20 @@ def is_static_block_quant(self):
             and self._fake_quant
         )

+    @property
+    def rotate_is_enabled(self):
+        """Check if rotate is enabled in quant config."""
+        return self._rotate.get("enable", False) if isinstance(self._rotate, dict) else self._rotate
+
+    @property
+    def rotate_is_fp32(self):
+        """Check if rotation needs to be computed in float32."""
+        return (
+            self._rotate.get("rotate_fp32", False)
+            if isinstance(self._rotate, dict) and self.rotate_is_enabled
+            else False
+        )
+
     def disable_calib(self):
         """Disable calibration."""
         self._if_calib = False
@@ -996,8 +1010,8 @@ def forward(self, inputs):
             inputs = inputs * self.pre_quant_scale

         # Rotating the input
-        if self._rotate:
-            inputs = normalized_hadamard_transform(inputs)
+        if self.rotate_is_enabled:
+            inputs = normalized_hadamard_transform(inputs, rotate_fp32=self.rotate_is_fp32)

         if self._disabled:
             # if quantizer is disabled, we still need to track the input dtype for saving the model
@@ -1109,7 +1123,8 @@ def extra_repr(self):
             if self.pre_quant_scale is not None
             else ""
         )
-        s += " rotated" if self._rotate else ""
+        s += " rotated" if self.rotate_is_enabled else ""
+        s += " (fp32)" if self.rotate_is_fp32 else ""
         s += (
             f" calibrator={self._calibrator.__class__.__name__}"
             if (self._calibrator is not None)
```
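The two new properties resolve the bool-or-dict `rotate` value. Their logic can be sketched standalone (written as free functions here; in the diff they are properties reading `self._rotate` on `TensorQuantizer`):

```python
def rotate_is_enabled(rotate):
    # A dict carries an explicit "enable" key; a bare bool is the enable flag itself.
    return rotate.get("enable", False) if isinstance(rotate, dict) else rotate

def rotate_is_fp32(rotate):
    # fp32 rotation applies only when rotation is enabled and was requested
    # via the dict form; a bare bool never turns it on.
    return (
        rotate.get("rotate_fp32", False)
        if isinstance(rotate, dict) and rotate_is_enabled(rotate)
        else False
    )

for cfg in (True, False, {"enable": True, "rotate_fp32": True}, {"enable": False, "rotate_fp32": True}):
    print(cfg, rotate_is_enabled(cfg), rotate_is_fp32(cfg))
```

Note the guard on `rotate_is_enabled` inside `rotate_is_fp32`: `{"enable": False, "rotate_fp32": True}` resolves to no rotation at all, so `extra_repr` never prints `(fp32)` for a disabled rotation.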

tests/gpu/torch/quantization/test_hadamard.py

Lines changed: 14 additions & 3 deletions

```diff
@@ -41,21 +41,32 @@ def test_hadamard_transform(dim):
     xxt_h = x_h @ x_h.T
     # The numerical error can be large, especially for 16-bit floats.
     assert torch.allclose(xxt_h, xxt, atol=0.05)
+    x_h_fp32 = normalized_hadamard_transform(x, rotate_fp32=True)
+    xxt_h_fp32 = x_h_fp32 @ x_h_fp32.T
+    assert torch.allclose(xxt_h_fp32, xxt, atol=0.05)


-def test_kv_rotate():
+@pytest.mark.parametrize(
+    "rotate_fp32",
+    [True, False],
+)
+def test_kv_rotate(rotate_fp32):
     mtq.plugins.register_attention_for_kv_quant(SDPAAttention)
     model = nn.Sequential(SDPAAttention())
     mtq.replace_quant_module(model)

     set_quantizer_by_cfg(model, {"*": {"enable": False}})
     dummy_input = SDPAAttention.get_input(device="cuda")
     output_ref = model(dummy_input)
+    if rotate_fp32:
+        rotate = {"enable": True, "rotate_fp32": True}
+    else:
+        rotate = True
     with set_quantizer_by_cfg_context(
         model,
         {
             "*[qk]_bmm_quantizer": {
-                "rotate": True,
+                "rotate": rotate,
             },
         },
     ):
@@ -67,7 +78,7 @@ def test_kv_rotate():
         model,
         {
             "*k_bmm_quantizer": {
-                "rotate": True,
+                "rotate": rotate,
             },
         },
     ):
```

0 commit comments