Skip to content

Commit 044a4ab

Browse files
committed
minor
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent 39b8b32 commit 044a4ab

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,20 @@ def is_static_block_quant(self):
525525
and self._fake_quant
526526
)
527527

528+
@property
529+
def rotate_is_enabled(self):
530+
"""Check if rotate is enabled in quant config."""
531+
return self._rotate.get("enable", False) if isinstance(self._rotate, dict) else self._rotate
532+
533+
@property
534+
def rotate_is_fp32(self):
535+
"""Check if rotation needs to be computed in float32."""
536+
return (
537+
self._rotate.get("rotate_fp32", False)
538+
if isinstance(self._rotate, dict) and self.rotate_is_enabled
539+
else False
540+
)
541+
528542
def disable_calib(self):
529543
"""Disable calibration."""
530544
self._if_calib = False
@@ -992,14 +1006,8 @@ def forward(self, inputs):
9921006
inputs = inputs * self.pre_quant_scale
9931007

9941008
# Rotating the input
995-
rotate_fp32 = (
996-
self._rotate.get("rotate_fp32", False) if isinstance(self._rotate, dict) else False
997-
)
998-
rotate_enable = (
999-
self._rotate.get("enable", False) if isinstance(self._rotate, dict) else self._rotate
1000-
)
1001-
if rotate_enable:
1002-
inputs = normalized_hadamard_transform(inputs, rotate_fp32=rotate_fp32)
1009+
if self.rotate_is_enabled:
1010+
inputs = normalized_hadamard_transform(inputs, rotate_fp32=self.rotate_is_fp32)
10031011

10041012
if self._disabled:
10051013
# if quantizer is disabled, we still need to track the input dtype for saving the model
@@ -1111,13 +1119,8 @@ def extra_repr(self):
11111119
if self.pre_quant_scale is not None
11121120
else ""
11131121
)
1114-
if isinstance(self._rotate, dict):
1115-
if self._rotate.get("enable", False):
1116-
s += " rotated"
1117-
if self._rotate.get("rotate_fp32", False):
1118-
s += " (fp32)"
1119-
elif self._rotate:
1120-
s += " rotated"
1122+
s += " rotated" if self.rotate_is_enabled else ""
1123+
s += " (fp32)" if self.rotate_is_fp32 else ""
11211124
s += (
11221125
f" calibrator={self._calibrator.__class__.__name__}"
11231126
if (self._calibrator is not None)

0 commit comments

Comments
 (0)