@@ -77,7 +77,7 @@ def get_activation_quantizer(
     use_rot = False
     if "rot_" in qa_mode or "_rot" in qa_mode:
         use_rot = True
-        qa_mode.replace("rot_", "").replace("_rot", "")
+        qa_mode = qa_mode.replace("rot_", "").replace("_rot", "")
 
     if not use_swcap:
         QPACTLUT = {
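
The fix above works because Python strings are immutable: `str.replace` returns a new string and leaves the original untouched, so the old code computed the stripped mode and then discarded it. A minimal sketch of the failure mode (the `mode` variable here is illustrative):

    mode = "rot_pact"
    mode.replace("rot_", "")         # returns "pact", but the result is discarded
    print(mode)                      # still "rot_pact"
    mode = mode.replace("rot_", "")  # the fix: bind the returned string
    print(mode)                      # "pact"
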
@@ -134,23 +134,27 @@ def get_activation_quantizer(
             )
         elif qa_mode == "dorefa":
             act_quantizer = dorefa_quantize_activation
-        elif (
-            qa_mode == "max"
-        ):  # NOTE Need to be careful using this for activation, particular to 1 sided.
-            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
-        elif qa_mode == "minmax":
-            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+        elif "max" in qa_mode:
+            # NOTE Need to be careful using this for activations, particularly 1-sided ones.
+            if "min" in qa_mode:
+                act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+            elif "pertoken" in qa_mode or "perToken" in qa_mode:
+                act_quantizer = QMaxDynamic(nbits, dim=-1)
+            elif "per_channel" in qa_mode or "perCh" in qa_mode:
+                act_quantizer = QMaxDynamic(nbits, dim=-2)
+            elif "sym" in qa_mode:
+                act_quantizer = Qmax(
+                    nbits,
+                    align_zero=True,
+                    minmax=False,
+                    extend_act_range=extend_act_range,
+                )
+            else:
+                act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
         elif qa_mode == "fix":
             act_quantizer = QFixSymmetric(
                 nbits, init_clip_val=clip_val, align_zero=align_zero
             )
-        elif qa_mode == "maxsym":
-            act_quantizer = Qmax(
-                nbits,
-                align_zero=True,
-                minmax=False,
-                extend_act_range=extend_act_range,
-            )
         elif qa_mode == "pactsym":
             act_quantizer = PACT2Sym(
                 nbits,
@@ -190,8 +194,6 @@ def get_activation_quantizer(
                 perToken=perToken,
                 emulate=True,
             )
-        elif qa_mode == "pertokenmax":
-            act_quantizer = PerTokenMax(nbits)
         else:
             raise ValueError(f"unrecognized activation quantization mode {qa_mode}")
     else:  # swcap-compatible activation quantizers
@@ -266,7 +268,7 @@ def get_weight_quantizer(
     use_rot = False
     if "rot_" in qw_mode or "_rot" in qw_mode:
         use_rot = True
-        qw_mode.replace("rot_", "").replace("_rot", "")
+        qw_mode = qw_mode.replace("rot_", "").replace("_rot", "")
 
     weight_quantizer = None
     if not use_swcap:
@@ -3495,7 +3497,7 @@ def __init__(self, num_bits):
         """
         For per-token activation quantization using abs().max() as scale,
         zero is aligned so that the levels are symmetric around zero (losing one level).
-        Since the token length is un-known before running, the quatnization is dynamic, meaning
+        Since the token length is unknown before running, the quantization is dynamic, meaning
         no trainable quantization scales; the scales are computed at run time.
         """
         super().__init__()
@@ -3512,6 +3514,42 @@ def __repr__(self):
         return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"
 
 
+class QMaxDynamic(nn.Module):
+    def __init__(self, num_bits, dim=-1):
+        """
+        For per-token or per-channel quantization using abs().max() as scale, usually for
+        activations; could be used for Qbmm M2 as well.
+        (reduce) dim = -1 -> abs().max() yields a column vector (for 2D input) => per-token
+        dim = -2 -> per-channel
+        Zero is aligned so that the levels are symmetric around zero (losing one level).
+        Since the token length is unknown before running, the quantizer can only compute the
+        scales dynamically at run time, meaning no trainable quantization scales are allowed
+        (unless the input sequence length is always the same, not just padded to a fixed length).
+        """
+        super().__init__()
+        self.num_bits = num_bits
+        self.levels = 2 ** (self.num_bits - 1) - 1
+        if isinstance(dim, str):
+            if "perCh" in dim or "per_channel" in dim:
+                dim = -2
+            elif "perToken" in dim or "per_token" in dim or "per_Token" in dim:
+                dim = -1
+        if dim in [-1, -2]:
+            self.reduce_dim = dim
+        else:
+            raise ValueError(
+                f"Reduce dim can only be [-1, -2] or ['perCh', 'perToken'] but found {dim}"
+            )
+
+    def forward(self, input_tensor):
+        amax_dim = input_tensor.abs().max(dim=self.reduce_dim, keepdim=True)[0]
+        scales = amax_dim.clamp(min=1e-5).div(self.levels)
+        return input_tensor.div(scales).round().mul(scales)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"
+
+
 class Qdynamic(nn.Module):
     def __init__(
         self,
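
A quick usage sketch for the new `QMaxDynamic` (assumes `torch` is available and the class is importable; the shapes are illustrative):

    import torch

    q = QMaxDynamic(num_bits=8, dim=-1)   # per-token: reduce over the last dim
    x = torch.randn(4, 16)                # [tokens, hidden]
    xq = q(x)                             # fake-quantized output, same shape as x

    # each row is snapped to multiples of its own scale = row_abs_max / 127
    scales = x.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-5) / 127
    assert torch.allclose(xq, (x / scales).round() * scales)
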
@@ -4585,7 +4623,7 @@ def forward(self, x_orig):
 
 class Qbypass(nn.Module):
     """
-    no quantization at all, straight-thru
+    No quantization at all; output the input_tensor directly.
     Used in place of a lambda function when nbits is 32 or 16,
     to avoid issues when pickling (i.e. torch.save) a lambda
     (seems to be a problem only for DDP).
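
The DDP remark above comes down to pickling: `torch.save` pickles the model, and a plain lambda cannot be pickled, while an `nn.Module` subclass such as `Qbypass` can. A standalone illustration using `pickle` directly:

    import pickle

    identity = lambda x: x
    try:
        pickle.dumps(identity)
    except (pickle.PicklingError, AttributeError) as exc:
        print("lambda is not picklable:", exc)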