
Commit 4e56843

Enable Hadamard rotation; also fix a minor QBmm bug
Signed-off-by: cliu-us <cliu@us.ibm.com>
1 parent 13ebd90 commit 4e56843

8 files changed

Lines changed: 332 additions & 144 deletions


fms_mo/modules/bmm.py

Lines changed: 7 additions & 4 deletions
@@ -20,6 +20,7 @@
 
 # Local
 from fms_mo.quant.quantizers import Qbypass, Qdynamic, get_activation_quantizer
+from fms_mo.quant.rotation import RotQuantWrapper
 
 
 class QBmm(nn.Module):
@@ -131,8 +132,10 @@ def __init__(
         )
 
         self.calib_iterator = []  # To simplify update of clipvals in forward()
-        self.quantize_m1 = Qbypass()
-        self.quantize_calib_m1 = Qbypass()
+        quant_m1_def = Qbypass() if "rot_" not in self.qm1_mode else RotQuantWrapper()
+        quant_m2_def = Qbypass() if "rot_" not in self.qm2_mode else RotQuantWrapper()
+        self.quantize_m1 = quant_m1_def
+        self.quantize_calib_m1 = quant_m1_def
         if self.num_bits_m1 not in [32, 16]:
             self.quantize_m1 = get_activation_quantizer(
                 self.qm1_mode if (not m1_bounded or "fp8" in qm1_mode) else "minmax",
@@ -155,8 +158,8 @@ def __init__(
             symmetric=self.symmetric,
         )
 
-        self.quantize_m2 = Qbypass()
-        self.quantize_calib_m2 = Qbypass()
+        self.quantize_m2 = quant_m2_def
+        self.quantize_calib_m2 = quant_m2_def
         if self.num_bits_m2 not in [32, 16]:
             self.quantize_m2 = get_activation_quantizer(
                 self.qm2_mode if (not m2_bounded or "fp8" in qm2_mode) else "minmax",
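The point of this change: when the mode string carries a "rot_" prefix, the 16/32-bit default is no longer a plain pass-through, so the Hadamard rotation still runs on operands that are not quantized. A minimal sketch of the selection logic, using simplified stand-ins for Qbypass and RotQuantWrapper (the real classes live in fms_mo.quant and carry more state):

    import torch
    import torch.nn as nn

    class Qbypass(nn.Module):
        # Identity stand-in: returns the input untouched.
        def forward(self, x):
            return x

    class RotQuantWrapper(nn.Module):
        # Stand-in: applies an orthogonal rotation (e.g. Hadamard); no quantization.
        def __init__(self, rot=None):
            super().__init__()
            self.rot = rot  # rotation matrix; identity behavior if None
        def forward(self, x):
            return x @ self.rot if self.rot is not None else x

    qm1_mode = "rot_maxpertoken"  # hypothetical mode string
    # Same selection as the diff above: rotation modes get a rotation wrapper
    # as the default instead of a plain bypass.
    quant_m1_def = Qbypass() if "rot_" not in qm1_mode else RotQuantWrapper()
    print(type(quant_m1_def).__name__)  # -> RotQuantWrapper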

fms_mo/modules/linear.py

Lines changed: 7 additions & 4 deletions
@@ -36,6 +36,7 @@
     get_weight_quantizer,
     mask_fc_kij,
 )
+from fms_mo.quant.rotation import RotQuantWrapper
 from fms_mo.utils.import_utils import available_packages
 
 if available_packages["triton"]:
@@ -158,8 +159,10 @@ def __init__(
 
         self.calib_iterator = []
         # To simplify update of clipvals in forward()
-        self.quantize_feature = Qbypass()
-        self.quantize_calib_feature = Qbypass()
+        quantA_default = Qbypass() if "rot_" not in self.qa_mode else RotQuantWrapper()
+        quantW_default = Qbypass() if "rot_" not in self.qw_mode else RotQuantWrapper()
+        self.quantize_feature = quantA_default
+        self.quantize_calib_feature = quantA_default
         if self.num_bits_feature not in [32, 16]:
             self.quantize_feature = get_activation_quantizer(
                 self.qa_mode,
@@ -187,8 +190,8 @@ def __init__(
                 quantizer2sync=self.quantize_feature,
             )
 
-        self.quantize_weight = Qbypass()
-        self.quantize_calib_weight = Qbypass()
+        self.quantize_weight = quantW_default
+        self.quantize_calib_weight = quantW_default
         if self.num_bits_weight not in [32, 16]:
             self.quantize_weight = get_weight_quantizer(
                 self.qw_mode,
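This is the same pattern as in QBmm, applied to the activation (qa_mode) and weight (qw_mode) paths of the quantized linear layer. With these defaults in place, enabling rotation should only require the "rot_" prefix on the mode strings; a hypothetical config fragment (the exact qcfg keys below are assumptions for illustration, not taken from this diff):

    # Hypothetical qcfg fragment: "rot_" selects a RotQuantWrapper default on the
    # 16/32-bit path; get_activation_quantizer()/get_weight_quantizer() strip the
    # prefix before dispatching on the remaining mode name.
    qcfg = {
        "qa_mode": "rot_maxpertoken",  # activations: rotate, then per-token absmax
        "qw_mode": "rot_max",          # weights: rotate, then absmax
        "nbits_a": 8,
        "nbits_w": 8,
    }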

fms_mo/prep.py

Lines changed: 7 additions & 1 deletion
@@ -215,7 +215,7 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False):
     base_params = {}
     if hasattr(module, "__constants__"):
         base_params = {k: getattr(module, k) for k in module.__constants__}
-    base_params["bias"] = module.bias is not None
+    base_params["bias"] = getattr(module, "bias", None) is not None
     base_params["device"] = next(module.parameters()).device  # usually cuda
 
     module_output = module
@@ -480,6 +480,12 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False):
             setattr(module_output, k, v)
         module_output._all_weights = module._all_weights
 
+    # For nn.Embedding
+    elif isinstance(module, nn.Embedding):
+        # simplest case, only support rotation for now, no quantization
+        Qemb = mapping.get(nn.Embedding, nn.Embedding)
+        module_output = Qemb(module)
+
     return module_output
 
 
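The getattr change pairs with the new nn.Embedding branch: Embedding defines no bias attribute at all, so the old module.bias lookup raised AttributeError before the module could be wrapped. A quick illustration:

    import torch.nn as nn

    emb = nn.Embedding(100, 64)
    lin = nn.Linear(64, 64)

    # print(emb.bias)  # old code path: AttributeError, Embedding has no bias
    print(getattr(emb, "bias", None) is not None)  # False -> safe default
    print(getattr(lin, "bias", None) is not None)  # True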

fms_mo/quant/quantizers.py

Lines changed: 57 additions & 19 deletions
@@ -77,7 +77,7 @@ def get_activation_quantizer(
     use_rot = False
     if "rot_" in qa_mode or "_rot" in qa_mode:
         use_rot = True
-        qa_mode.replace("rot_", "").replace("_rot", "")
+        qa_mode = qa_mode.replace("rot_", "").replace("_rot", "")
 
     if not use_swcap:
         QPACTLUT = {
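The one-line fix here addresses a classic Python no-op: str.replace returns a new string (strings are immutable), so the unassigned call never actually stripped the prefix:

    qa_mode = "rot_maxpertoken"
    qa_mode.replace("rot_", "")            # result discarded; qa_mode unchanged
    qa_mode = qa_mode.replace("rot_", "")  # qa_mode is now "maxpertoken"

Without the fix, the prefixed string would reach the mode dispatch below intact and either match the wrong branch or raise the "unrecognized mode" error. The same fix is applied to qw_mode in get_weight_quantizer further down.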
@@ -134,23 +134,27 @@ def get_activation_quantizer(
             )
         elif qa_mode == "dorefa":
             act_quantizer = dorefa_quantize_activation
-        elif (
-            qa_mode == "max"
-        ):  # NOTE Need to be careful using this for activation, particular to 1 sided.
-            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
-        elif qa_mode == "minmax":
-            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+        elif "max" in qa_mode:
+            # NOTE Need to be careful using this for activation, particular to 1 sided.
+            if "min" in qa_mode:
+                act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+            elif "pertoken" in qa_mode or "perToken" in qa_mode:
+                act_quantizer = QMaxDynamic(nbits, dim=-1)
+            elif "per_channel" in qa_mode or "perCh" in qa_mode:
+                act_quantizer = QMaxDynamic(nbits, dim=-2)
+            elif "sym" in qa_mode:
+                act_quantizer = Qmax(
+                    nbits,
+                    align_zero=True,
+                    minmax=False,
+                    extend_act_range=extend_act_range,
+                )
+            else:
+                act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
         elif qa_mode == "fix":
             act_quantizer = QFixSymmetric(
                 nbits, init_clip_val=clip_val, align_zero=align_zero
             )
-        elif qa_mode == "maxsym":
-            act_quantizer = Qmax(
-                nbits,
-                align_zero=True,
-                minmax=False,
-                extend_act_range=extend_act_range,
-            )
         elif qa_mode == "pactsym":
             act_quantizer = PACT2Sym(
                 nbits,
@@ -190,8 +194,6 @@ def get_activation_quantizer(
                 perToken=perToken,
                 emulate=True,
             )
-        elif qa_mode == "pertokenmax":
-            act_quantizer = PerTokenMax(nbits)
         else:
             raise ValueError(f"unrecognized activation quantization mode {qa_mode}")
     else:  # swcap-compatible activation quantizers
@@ -266,7 +268,7 @@ def get_weight_quantizer(
     use_rot = False
     if "rot_" in qw_mode or "_rot" in qw_mode:
         use_rot = True
-        qw_mode.replace("rot_", "").replace("_rot", "")
+        qw_mode = qw_mode.replace("rot_", "").replace("_rot", "")
 
     weight_quantizer = None
     if not use_swcap:
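Taken together, these two hunks replace four exact-match branches ("max", "minmax", "maxsym", "pertokenmax") with one substring-dispatched "max" family. A standalone rendering of the new branch order (returns names only; constructor details elided):

    def resolve_max_family(qa_mode: str) -> str:
        # Mirrors the branch order in the diff: "min" wins over "pertoken",
        # which wins over "per_channel", then "sym", then plain absmax.
        assert "max" in qa_mode
        if "min" in qa_mode:
            return "Qmax(minmax=True)"            # e.g. "minmax"
        if "pertoken" in qa_mode or "perToken" in qa_mode:
            return "QMaxDynamic(dim=-1)"          # e.g. "maxpertoken", "pertokenmax"
        if "per_channel" in qa_mode or "perCh" in qa_mode:
            return "QMaxDynamic(dim=-2)"          # e.g. "maxperCh"
        if "sym" in qa_mode:
            return "Qmax(extend_act_range=...)"   # e.g. "maxsym"
        return "Qmax(minmax=False)"               # plain "max"

    for mode in ["minmax", "maxpertoken", "pertokenmax", "maxperCh", "maxsym", "max"]:
        print(mode, "->", resolve_max_family(mode))

Note that the removed "pertokenmax" branch is subsumed: that string contains both "max" and "pertoken", so it now resolves to QMaxDynamic(nbits, dim=-1) instead of the old PerTokenMax(nbits), which implemented the same per-token absmax scheme.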
@@ -3495,7 +3497,7 @@ def __init__(self, num_bits):
         """
         For per-token activation quantization using abs().max() as scale,
         Zero is aligned so that the levels are symmetric around zero (lossing one level)
-        Since the token length is un-known before running, the quatnization is dynamic, meaning
+        Since the token length is un-known before running, the quantization is dynamic, meaning
         no trainable quantization scales and the scales are computed at run time.
         """
         super().__init__()
@@ -3512,6 +3514,42 @@ def __repr__(self):
         return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"
 
 
+class QMaxDynamic(nn.Module):
+    def __init__(self, num_bits, dim=-1):
+        """
+        For per-token or per-channel quantization using abs().max() as scale, usually for
+        activations and possibly for QBmm M2 as well.
+        (reduce) dim = -1 -> abs().max() outputs a column vector (if input is 2D) => per token
+        dim = -2 -> per-channel
+        Zero is aligned so that the levels are symmetric around zero (losing one level).
+        Since the token length is unknown before running, the quantizer can only compute
+        the scales dynamically at run time, meaning no trainable quantization scales are
+        allowed (unless the input seq length is always the same, not just padded).
+        """
+        super().__init__()
+        self.num_bits = num_bits
+        self.levels = 2 ** (self.num_bits - 1) - 1
+        if isinstance(dim, str):
+            if "perCh" in dim or "per_channel" in dim:
+                dim = -2
+            elif "perToken" in dim or "per_token" in dim or "per_Token" in dim:
+                dim = -1
+        if dim in [-1, -2]:  # plain "if", so string aliases resolved above land here too
+            self.reduce_dim = dim
+        else:
+            raise ValueError(
+                f"Reduce dim can only be [-1, -2] or ['perCh', 'perToken'] but found {dim}"
+            )
+
+    def forward(self, input_tensor):
+        amax_dim = input_tensor.abs().max(dim=self.reduce_dim, keepdim=True)[0]
+        scales = amax_dim.clamp(min=1e-5).div(self.levels)
+        return input_tensor.div(scales).round().mul(scales)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"
+
+
 class Qdynamic(nn.Module):
     def __init__(
         self,
@@ -4585,7 +4623,7 @@ def forward(self, x_orig):
 
 class Qbypass(nn.Module):
     """
-    no quantization at all, straight-thru
+    No quantization at all, output the input_tensor directly.
     in place of lambda function when using nbits=32 and 16.
     to avoid issue when pickle (ie torch.save) of lambda
     (seems to be a problem only for DDP)
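The new QMaxDynamic keeps no calibration state, so it can be exercised in isolation. A quick 8-bit per-token round trip that mirrors its forward() (shapes and values are illustrative):

    import torch

    # Symmetric dynamic fake-quant with one scale per token (row)
    # when reducing over the last dimension, as in QMaxDynamic.
    def qmax_dynamic(x, num_bits=8, reduce_dim=-1):
        levels = 2 ** (num_bits - 1) - 1                    # 127 for 8 bits
        amax = x.abs().max(dim=reduce_dim, keepdim=True)[0]
        scales = amax.clamp(min=1e-5) / levels
        return (x / scales).round() * scales

    x = torch.randn(4, 16)       # (tokens, channels)
    xq = qmax_dynamic(x)
    print((x - xq).abs().max())  # small: at most half a step, i.e. scale/2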
