@@ -525,6 +525,20 @@ def is_static_block_quant(self):
525525 and self ._fake_quant
526526 )
527527
528+ @property
529+ def rotate_is_enabled (self ):
530+ """Check if rotate is enabled in quant config."""
531+ return self ._rotate .get ("enable" , False ) if isinstance (self ._rotate , dict ) else self ._rotate
532+
533+ @property
534+ def rotate_is_fp32 (self ):
535+ """Check if rotation needs to be computed in float32."""
536+ return (
537+ self ._rotate .get ("rotate_fp32" , False )
538+ if isinstance (self ._rotate , dict ) and self .rotate_is_enabled
539+ else False
540+ )
541+
528542 def disable_calib (self ):
529543 """Disable calibration."""
530544 self ._if_calib = False
@@ -992,14 +1006,8 @@ def forward(self, inputs):
9921006 inputs = inputs * self .pre_quant_scale
9931007
9941008 # Rotating the input
995- rotate_fp32 = (
996- self ._rotate .get ("rotate_fp32" , False ) if isinstance (self ._rotate , dict ) else False
997- )
998- rotate_enable = (
999- self ._rotate .get ("enable" , False ) if isinstance (self ._rotate , dict ) else self ._rotate
1000- )
1001- if rotate_enable :
1002- inputs = normalized_hadamard_transform (inputs , rotate_fp32 = rotate_fp32 )
1009+ if self .rotate_is_enabled :
1010+ inputs = normalized_hadamard_transform (inputs , rotate_fp32 = self .rotate_is_fp32 )
10031011
10041012 if self ._disabled :
10051013 # if quantizer is disabled, we still need to track the input dtype for saving the model
@@ -1111,13 +1119,8 @@ def extra_repr(self):
11111119 if self .pre_quant_scale is not None
11121120 else ""
11131121 )
1114- if isinstance (self ._rotate , dict ):
1115- if self ._rotate .get ("enable" , False ):
1116- s += " rotated"
1117- if self ._rotate .get ("rotate_fp32" , False ):
1118- s += " (fp32)"
1119- elif self ._rotate :
1120- s += " rotated"
1122+ s += " rotated" if self .rotate_is_enabled else ""
1123+ s += " (fp32)" if self .rotate_is_fp32 else ""
11211124 s += (
11221125 f" calibrator={ self ._calibrator .__class__ .__name__ } "
11231126 if (self ._calibrator is not None )
0 commit comments