-
Notifications
You must be signed in to change notification settings - Fork 20
Expand file tree
/
Copy pathquantizers.py
More file actions
5601 lines (4949 loc) · 198 KB
/
quantizers.py
File metadata and controls
5601 lines (4949 loc) · 198 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Util function transformers_prepare_input() is borrowed from huggingface transformers/trainer.py
Trainer class method _prepare_input().
see https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/trainer.py#L3497
Class/function MSEObserver, ObserverBase, fake_quantize_per_channel_affine,
fake_quantize_per_tensor_affine, _transform_to_ch_axis, CyclicTempDecay, LinearTempDecay,
AdaRoundSTE, AdaRoundQuantizer are modified from BRECQ's repo: https://github.com/yhhhli/BRECQ
"""
# pylint: disable=too-many-return-statements
# mypy: disable-error-code="assignment"
# Standard
from collections.abc import Mapping
from typing import Any, Union
import logging
import math
import os
import random
# Third Party
from packaging.version import Version
import numpy as np
import torch
import torch.fx
import torch.nn as nn # pylint: disable=consider-using-from-import
import torch.nn.functional as F
logger = logging.getLogger(__name__)
def get_activation_quantizer(
    qa_mode="PACT",
    nbits=32,
    clip_val=None,
    clip_valn=None,
    non_neg=False,
    align_zero=True,
    extend_act_range=False,
    use_swcap=False,
    use_PT_native_Qfunc=False,
    use_subnormal=False,
):
    """Return a quantizer for activation quantization.

    Mode matching is case-sensitive. NOTE(review): the default
    ``qa_mode="PACT"`` matches none of the (lowercase) modes below and raises
    ValueError, so callers must pass an explicit mode.

    Regular quantizers (``use_swcap=False``, results are dequantized):
        - modes containing "pact" but not "sym": pact, cgpact, pact+
          (PACT when ``non_neg`` else PACT2; "cgpact" sets cggrad, "pact+"
          sets pact_plus; any other such string raises KeyError)
        - pactsym, pactsym+
        - max, minmax, maxsym, pertokenmax
        - lsq+, qsilu, dorefa, fix
        - brecq (PTQ)
        - modes containing "fp8": ``to_custom_fp8`` if the mode also contains
          "custom", otherwise ``to_fp8`` (per-token if it contains "perToken")
    SWCAP quantizers (``use_swcap=True``, do not dequantize):
        - pact, pact+, pactsym
        - sawb, sawb+
        - max

    Args:
        qa_mode: quantization mode string, see above.
        nbits: number of bits.
        clip_val: initial (upper) clip value for quantizers that take one.
        clip_valn: initial lower clip value for quantizers that take one.
        non_neg: pick the 1-sided variant of the PACT family.
        align_zero: forwarded to the quantizers that accept it.
        extend_act_range: forwarded to maxsym and pactsym+.
        use_swcap: select the SWCAP-compatible quantizers.
        use_PT_native_Qfunc: forwarded to the 2-sided PACT quantizer only.
        use_subnormal: forwarded to the custom fp8 quantizer only.

    Returns:
        The activation quantizer. For "dorefa" this is the
        ``dorefa_quantize_activation`` object itself, not a constructed instance.

    Raises:
        ValueError: if ``qa_mode`` is not recognized, or is not supported
            together with ``use_swcap``.
    """
    if not use_swcap:
        if "pact" in qa_mode and "sym" not in qa_mode:
            # 1-sided (non_neg) variants map to PACT, 2-sided ones to PACT2
            QPACTLUT = {
                "pact_uni": PACT,
                "pact_bi": PACT2,
                "cgpact_uni": PACT,
                "cgpact_bi": PACT2,
                "pact+_uni": PACT,
                "pact+_bi": PACT2,
            }
            keyQact = qa_mode + "_uni" if non_neg else qa_mode + "_bi"
            pact_cls = QPACTLUT[keyQact]
            pact_kwargs = {
                "init_clip_val": clip_val,
                "init_clip_valn": clip_valn,
                "dequantize": True,
                "inplace": False,
                "cggrad": "cgpact" in qa_mode,
                "pact_plus": "pact+" in qa_mode,
            }
            if not non_neg:
                # only the 2-sided quantizer receives these options;
                # use_PT_native_Qfunc is only implemented in pact2ste and pactplus2ste
                pact_kwargs["align_zero"] = align_zero
                pact_kwargs["use_PT_native_Qfunc"] = use_PT_native_Qfunc
            act_quantizer = pact_cls(nbits, **pact_kwargs)
        elif qa_mode == "lsq+":
            act_quantizer = LSQPlus(
                nbits,
                init_clip_vals=clip_val,
                init_clip_valb=clip_valn,
                dequantize=True,
                inplace=False,
            )
        elif qa_mode == "qsilu":
            act_quantizer = QSILU(
                nbits,
                init_clip_val=clip_val,
                init_clip_valn=-0.28746,
                dequantize=True,
                inplace=False,
            )
        elif qa_mode == "dorefa":
            act_quantizer = dorefa_quantize_activation
        elif qa_mode == "max":
            # NOTE Need to be careful using this for activation, particular to 1 sided.
            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
        elif qa_mode == "minmax":
            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
        elif qa_mode == "fix":
            act_quantizer = QFixSymmetric(
                nbits, init_clip_val=clip_val, align_zero=align_zero
            )
        elif qa_mode == "maxsym":
            act_quantizer = Qmax(
                nbits,
                align_zero=True,
                minmax=False,
                extend_act_range=extend_act_range,
            )
        elif qa_mode == "pactsym":
            act_quantizer = PACT2Sym(
                nbits,
                init_clip_val=clip_val,
                dequantize=True,
                inplace=False,
            )
        elif qa_mode == "pactsym+":
            act_quantizer = PACTplusSym(
                nbits,
                init_clip_val=clip_val,
                dequantize=True,
                inplace=False,
                intg_zp=align_zero,
                OORgradnoclip=False,
                extend_act_range=extend_act_range,
            )
        elif qa_mode == "brecq":
            act_quantizer = UniformAffineQuantizer(nbits, inited=True)
        elif "fp8" in qa_mode:
            if "custom" in qa_mode:
                act_quantizer = to_custom_fp8(
                    bits=nbits,
                    q_mode=qa_mode,
                    use_subnormal=use_subnormal,
                    scale_to_max="scale" in qa_mode,
                )
            else:
                # qa_mode should be one of:
                # [fp8_e4m3_sat, fp8_e5m2_sat, fp8_e4m3_scale, fp8_e5m2_scale]
                # by default, emulate = True, unless using a GPU that support FP8 computation
                # NOTE: emulate will be similar to dequantize.
                perToken = "perToken" in qa_mode
                act_quantizer = to_fp8(
                    nbits,
                    q_mode=qa_mode,
                    perToken=perToken,
                    emulate=True,
                )
        elif qa_mode == "pertokenmax":
            act_quantizer = PerTokenMax(nbits)
        else:
            raise ValueError(f"unrecognized activation quantization mode {qa_mode}")
    else:  # swcap-compatible activation quantizers
        if qa_mode in ("pact", "pact+"):
            if non_neg:
                assert qa_mode == "pact", "pact+ not yet supported on single side PACT"
                act_quantizer = PACT_sw(
                    nbits,
                    init_clip_val=clip_val,
                    dequantize=False,
                    align_zero=align_zero,
                )
            else:
                pact_plus = qa_mode == "pact+"
                act_quantizer = PACT2_sw(
                    nbits,
                    init_clip_val=clip_val,
                    init_clip_valn=clip_valn,
                    dequantize=False,
                    align_zero=align_zero,
                    pact_plus=pact_plus,
                )
        elif qa_mode == "pactsym":
            act_quantizer = PACT2sym_sw(nbits, init_clip_val=clip_val, dequantize=False)
        elif qa_mode == "sawb":
            act_quantizer = SAWB_sw(
                nbits, dequantize=False, clipSTE=False, recompute=False
            )
        elif qa_mode == "sawb+":
            act_quantizer = SAWB_sw(
                nbits, dequantize=False, clipSTE=True, recompute=False
            )
        elif qa_mode == "max":
            act_quantizer = Qmax_sw(nbits, dequantize=False)
        else:
            raise ValueError(
                f"activation quantization mode {qa_mode} is incompatible with swcap"
            )
    return act_quantizer
def get_weight_quantizer(
    qw_mode="SAWB+",
    nbits=32,
    clip_val=None,
    clip_valn=None,
    align_zero=True,
    w_shape=None,
    use_swcap=False,
    recompute=False,
    perGp=None,
    use_subnormal=False,
):
    """Return a quantizer for weight quantization.

    Mode matching is case-sensitive and checked in this order: substring
    "sawb", substring "max", exact pact / cgpact / pact+ / lsq+ / fix / brecq,
    substring "adaround", substring "fp8".
    NOTE(review): the default ``qw_mode="SAWB+"`` matches none of these
    (lowercase "sawb" is tested) and raises ValueError; pass an explicit mode.

    Regular quantizers (``use_swcap=False``):
        - sawb modes: "+" -> clipSTE, "perCh" -> per-channel (only when
          ``w_shape`` is given), "interp" -> SAWB(interp=True); always
          zero-aligned regardless of ``align_zero``
        - max modes: "min" -> minmax, "perCh" -> per-channel (only when
          ``w_shape`` is given), "perGp" -> per-group (needs ``w_shape`` and
          ``perGp``)
        - pact, pact+, lsq+, fix, brecq
        - adaround modes (flags "SAWB", "perCh", "multimodal", "optim")
        - fp8 modes: "custom" -> to_custom_fp8, otherwise to_fp8 ("perCh" flag)
        - cgpact: not implemented yet, returns None
    SWCAP quantizers (``use_swcap=True``, require ``align_zero``):
        - sawb, sawb+, max

    Args:
        qw_mode: quantization mode string, see above.
        nbits: number of bits.
        clip_val / clip_valn: initial clip values for quantizers that take them.
        align_zero: forwarded to quantizers that accept it.
        w_shape: weight shape; dim 0 sizes per-channel clip values, dims 0-1
            size the per-group clip values.
        use_swcap: select the SWCAP-compatible quantizers.
        recompute: forwarded to the SWCAP quantizers.
        perGp: group size for "perGp" max modes.
        use_subnormal: forwarded to the custom fp8 quantizer.

    Returns:
        The weight quantizer, or None for the unimplemented "cgpact" mode.

    Raises:
        ValueError: for unrecognized / swcap-incompatible modes, or a "perGp"
            mode without ``w_shape`` or with a zero/None ``perGp``.
        AssertionError: if ``use_swcap`` is set while ``align_zero`` is False.
    """
    weight_quantizer = None
    if not use_swcap:
        # only the exact "pact" branch reads this, where it is always False
        cggrad = "cgpact" in qw_mode
        if "sawb" in qw_mode:
            Nch = w_shape[0] if w_shape is not None and "perCh" in qw_mode else False
            clipSTE = "+" in qw_mode
            intp = "interp" in qw_mode
            weight_quantizer = SAWB(
                nbits,
                dequantize=True,
                inplace=False,
                align_zero=True,
                clipSTE=clipSTE,
                perCh=Nch,
                interp=intp,
            )
        elif "max" in qw_mode:
            Nch = w_shape[0] if w_shape is not None and "perCh" in qw_mode else False
            if "perGp" in qw_mode:
                if w_shape is None or not perGp:
                    raise ValueError(
                        f"weight quantization mode {qw_mode} requires w_shape and "
                        "a non-zero perGp"
                    )
                # store clip_val size and group size
                Ngp = [w_shape[0] * w_shape[1] // perGp, perGp]
            else:
                Ngp = False
            weight_quantizer = Qmax(
                nbits,
                align_zero=align_zero,
                minmax="min" in qw_mode,
                perCh=Nch,
                perGp=Ngp,
            )
        elif qw_mode == "pact":
            weight_quantizer = PACT2(
                nbits,
                init_clip_val=clip_val,
                init_clip_valn=clip_valn,
                cggrad=cggrad,
                dequantize=True,
                inplace=False,
            )
        elif qw_mode == "cgpact":
            # TODO check implementation
            # NOTE: not implemented yet; weight_quantizer stays None and is returned
            ...
        elif qw_mode == "pact+":
            weight_quantizer = PACTplusSym(
                nbits,
                init_clip_val=clip_val,
                dequantize=True,
                inplace=False,
                intg_zp=align_zero,
                OORgradnoclip=False,
            )
        elif qw_mode == "lsq+":
            weight_quantizer = LSQPlus(
                nbits,
                init_clip_vals=clip_val,
                init_clip_valb=clip_valn,
                dequantize=True,
                inplace=False,
            )
        elif qw_mode == "fix":
            weight_quantizer = QFixSymmetric(
                nbits, init_clip_val=clip_val, align_zero=align_zero
            )
        elif qw_mode == "brecq":
            weight_quantizer = UniformAffineQuantizer(nbits, inited=True)
        elif "adaround" in qw_mode:
            # use SAWB to determine delta, also allow grad/update for weights
            useSAWB = "SAWB" in qw_mode
            weight_quantizer = AdaRoundQuantizer(
                nbits,
                round_mode="learned_hard_sigmoid" if not useSAWB else "weight_STE",
                useSAWB=useSAWB,
                perCh="perCh" in qw_mode,
                multimodal="multimodal" in qw_mode,
                scalebyoptim="optim" in qw_mode,
            )
        elif "fp8" in qw_mode:
            if "custom" in qw_mode:
                weight_quantizer = to_custom_fp8(
                    bits=nbits,
                    q_mode=qw_mode,
                    use_subnormal=use_subnormal,
                    scale_to_max="scale" in qw_mode,
                )
            else:
                # qw_mode should be one of:
                # [fp8_e4m3_sat, fp8_e5m2_sat, fp8_e4m3_scale, fp8_e5m2_scale] + 'perCh'
                # by default, emulate = True, unless using a GPU that support FP8 computation
                # NOTE: emulate will be similar to dequantize.
                Nch = (
                    w_shape[0] if w_shape is not None and "perCh" in qw_mode else False
                )
                weight_quantizer = to_fp8(
                    nbits,
                    q_mode=qw_mode,
                    emulate=True,
                    perCh=Nch,
                )
        else:
            raise ValueError(f"unrecognized weight quantized mode {qw_mode}")
    else:  # swcap-compatible weight quantizers
        assert (
            align_zero
        ), "Error during weight quantizer selection: swcap requires zero alignment"
        if qw_mode == "sawb":
            weight_quantizer = SAWB_sw(
                nbits, dequantize=False, clipSTE=False, recompute=recompute
            )
        elif qw_mode == "sawb+":
            weight_quantizer = SAWB_sw(
                nbits, dequantize=False, clipSTE=True, recompute=recompute
            )
        elif qw_mode == "max":
            weight_quantizer = Qmax_sw(nbits, dequantize=False, recompute=recompute)
        else:
            raise ValueError(
                f"weight quantization mode {qw_mode} is incompatible with swcap"
            )
    return weight_quantizer
######SAWB Quantizers#######
class SAWB(nn.Module):
    """SAWB weight quantizer with a custom (straight-through) backward.

    SAWB is only used to quantize weights. ``set_quantizer`` selects the
    autograd function from the flags:
        clipSTE=True,  align_zero=True  -> SAWBPlusZeroPerChSTE if perCh,
                                           else SAWBPlusZeroSTE
        clipSTE=True,  align_zero=False -> SAWBPlusSTE
        clipSTE=False, align_zero=True  -> SAWBZeroSTE
        clipSTE=False, align_zero=False -> SAWBSTE
    The zero-aligned variants use the coded SAWB coefficients (e.g. 103, 403,
    803).

    Args:
        num_bits: 2, 4 or 8; anything else raises ValueError.
        dequantize: forwarded to the autograd function.
        inplace: forwarded to the autograd function.
        align_zero: see the table above.
        clipSTE: see the table above.
        perCh: False, or the number of output channels; sizes the ``clip_val``
            buffer.
        interp: stored on the module. NOTE(review): not read by any code in
            this class.
    """

    def __init__(
        self,
        num_bits,
        dequantize=True,
        inplace=False,
        align_zero=False,
        clipSTE=True,
        perCh=False,
        interp=False,
    ):
        super().__init__()
        if num_bits not in [2, 4, 8]:
            raise ValueError("FMS: SAWB supports 2, 4, and 8-bit quantization only.")
        self.num_bits = num_bits
        self.dequantize = dequantize
        self.inplace = inplace
        self.align_zero = align_zero
        self.clipSTE = clipSTE
        self.perCh = perCh  # if perCh, this will be the number of ch_out
        self.interp = interp
        self.set_quantizer()
        # one clip value per output channel if perCh, else a single value;
        # named clip_val to be consistent with other quantizers
        self.register_buffer(
            "clip_val", torch.zeros(perCh) if perCh else torch.Tensor([0.0])
        )

    def set_quantizer(self):
        """Select ``self.quantizer`` from clipSTE / align_zero / perCh."""
        if self.clipSTE:
            if self.align_zero:
                self.quantizer = (
                    SAWBPlusZeroPerChSTE
                    if self.perCh and self.num_bits in [2, 4, 8]
                    else SAWBPlusZeroSTE
                )
            else:
                self.quantizer = SAWBPlusSTE
        else:
            # if perCh but no sawb+ (e.g. `sawb_perCh`) will use a per-tensor clip
            # copied over each channel
            if self.align_zero:
                self.quantizer = SAWBZeroSTE
            else:
                self.quantizer = SAWBSTE

    def forward(self, input_tensor):
        """Quantize ``input_tensor``; clip values are only recomputed in training mode."""
        input_tensor = self.quantizer.apply(
            input_tensor,
            self.num_bits,
            self.dequantize,
            self.inplace,
            self.clip_val,
            self.training,
        )
        # NOTE: in the past, SAWB didn't check eval/training and recalc clipvals no matter what,
        # now we should pass self.training to avoid confusion
        return input_tensor

    def __repr__(self):
        inplace_str = ", inplace" if self.inplace else ""
        return (
            f"{self.__class__.__name__}(num_bits={self.num_bits}, "
            f"quantizer={self.quantizer}{inplace_str})"
        )
class SAWBPlusZeroSTE(torch.autograd.Function):
    """SAWB+ with zero alignment (symmetric) and no gradient clipping.

    Supported bits: 2, 4, 7, 8 (SAWB codes 103, 403, 703, 803).
    Any other bit-width falls back to ``x.abs().max()`` instead of SAWB.
    "dequantize=False" is functional: the quantized values are returned
    shifted by ``-(2**num_bits - 2) / 2``.

    forward args:
        input_tensor: tensor to quantize.
        num_bits: number of bits.
        dequantize: return dequantized values (True) or shifted levels (False).
        inplace: marks input_tensor dirty and is forwarded to
            linear_quantize / linear_dequantize.
        objSAWB_clip_val: clip-value buffer of the owning module; it is
            written with the clip value used.
        istraining: recompute the clip value only when True.
    """

    @staticmethod
    def forward(
        ctx, input_tensor, num_bits, dequantize, inplace, objSAWB_clip_val, istraining
    ):
        if inplace:
            ctx.mark_dirty(input_tensor)
        n_steps = 2**num_bits - 2
        zp = 0.0

        if not istraining:
            # eval mode: reuse the stored clip value instead of recomputing it
            clip = objSAWB_clip_val
        else:
            sawb_code = {2: 103, 4: 403, 7: 703, 8: 803}.get(num_bits)
            if sawb_code is None:
                clip = input_tensor.abs().max()
            else:
                clip, _ = sawb_params_code(num_bits, sawb_code, input_tensor)

        # SAWB may yield a non-positive clip value; fall back to abs-max then
        if clip <= 0:
            clip = input_tensor.abs().max()
        if clip.dim() == 0:
            clip = clip.unsqueeze(dim=0)
        objSAWB_clip_val.copy_(clip)

        # map [-clip, clip] onto [0, 1] before quantizing
        unit = input_tensor.mul(1 / clip).clamp(-1, 1).mul(0.5).add(0.5)
        levels = linear_quantize(unit, n_steps, zp, inplace)
        if not dequantize:
            levels -= n_steps / 2
            return levels
        levels = linear_dequantize(levels, n_steps, zp, inplace)
        return (2 * levels - 1) * clip

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through: the gradient reaches input_tensor unchanged
        return (grad_output,) + (None,) * 5
class SAWBPlusZeroPerChSTE(torch.autograd.Function):
    """Per-channel SAWB with zero alignment and a straight-through backward.

    Uses the symmetric integer range [-(2**(num_bits-1) - 1), 2**(num_bits-1) - 1],
    e.g. 15 bins [-7, 7] for 4 bits, with one clip value per index of dim 0.

    forward args:
        input_tensor: tensor to quantize; channels are taken along dim 0.
        num_bits: 2/4/8 use SAWB codes 103/403/803; any other value falls back
            to the per-channel abs-max.
        dequantize: True -> fake-quantized tensor cast to clip_val's dtype;
            False -> integer representation from a torch.qint8 quantized tensor.
        inplace: only used to mark input_tensor dirty.
        objSAWB_clip_val: per-channel clip buffer of the owning module; it is
            overwritten with the clip values used.
        istraining: recompute clip values only when True.
    """
    @staticmethod
    def forward(
        ctx, input_tensor, num_bits, dequantize, inplace, objSAWB_clip_val, istraining
    ):
        # assert num_bits in [4, 8], "only implemented for 4bit and 8bit"
        if inplace:
            ctx.mark_dirty(input_tensor)
        if istraining:
            # only recalc clipvals under training mode
            SAWBcode_mapping = {8: 803, 4: 403, 2: 103}
            if num_bits in [2, 4, 8]:
                sawb_code = SAWBcode_mapping[num_bits]
                clip_val, _ = sawb_params_code(
                    num_bits, sawb_code, input_tensor, perCh=True
                )
            else:
                # no SAWB code for this bit-width: use the per-channel abs-max
                clip_val = torch.max(
                    input_tensor.abs().reshape([input_tensor.shape[0], -1]), dim=1
                ).values
            assert (
                len(clip_val) == input_tensor.shape[0]
            ), f"dimension error, input_tensor{input_tensor.shape}, clipval{clip_val.shape}"
        else:
            # do not recalc clipval when under eval mode
            clip_val = objSAWB_clip_val
        objSAWB_clip_val.copy_(clip_val)
        # symmetric integer range, e.g. [-7, 7] for 4 bits
        int_l = -(2 ** (num_bits - 1)) + 1
        int_u = -int_l
        # per-channel step size: [-clip, clip] split into 2**num_bits - 2 steps
        scale = clip_val * 2 / (2**num_bits - 2)
        # original SAWB assumes odd number of bins when calc clip_val
        zero_point = torch.zeros_like(scale)  # SAWB always centers around 0 and align 0
        if dequantize:
            output = torch.fake_quantize_per_channel_affine(
                input_tensor.float(),
                scale.float(),
                zero_point.float(),
                axis=0,
                quant_min=int_l,
                quant_max=int_u,
            ).to(
                clip_val.dtype
            )  # NOTE computed on float32 copies, then cast back to clip_val's dtype
        else:
            # NOTE(review): quantize_per_channel is given a float zero_point and a
            # qint8 dtype; confirm this is accepted on the target torch version and
            # that num_bits <= 8 here (qint8 cannot hold wider ranges).
            output = torch.quantize_per_channel(
                input_tensor, scale, zero_point, 0, torch.qint8
            ).int_repr()
            # NOTE return will be a torch.int8 tensor
        return output
    @staticmethod
    def backward(ctx, grad_output):
        # straight-through: gradient passes to input_tensor without clipping
        grad_input_tensor = grad_output.clone()
        return grad_input_tensor, None, None, None, None, None
class SAWBZeroSTE(torch.autograd.Function):
    """SAWB with zero alignment (symmetric) and gradient clipping.

    Supported bits: 2, 4, 7, 8 (SAWB codes 103, 403, 703, 803).
    Other bit-widths fall back to ``x.abs().max()`` instead of SAWB.
    "dequantize=False" is functional: quantized values are returned shifted by
    ``-(2**num_bits - 2) / 2``.

    forward args:
        input_tensor: tensor to quantize.
        num_bits: number of bits.
        dequantize: return dequantized values (True) or shifted levels (False).
        inplace: marks input_tensor dirty and is forwarded to
            linear_quantize / linear_dequantize.
        objSAWB_clip_val: clip-value buffer of the owning module. It holds one
            value, or one value per output channel when the owning SAWB was
            built with perCh (e.g. "sawb_perCh"). It is written with the
            clip value used.
        istraining: recompute the (per-tensor) clip value only when True.

    backward zeroes the input gradient outside [-clip_val, clip_val].
    """

    @staticmethod
    def forward(
        ctx, input_tensor, num_bits, dequantize, inplace, objSAWB_clip_val, istraining
    ):
        if inplace:
            ctx.mark_dirty(input_tensor)
        scale = 2**num_bits - 2
        zero_point = 0.0
        if istraining:
            bits2code = {2: 103, 4: 403, 7: 703, 8: 803}
            if num_bits in bits2code:
                clip_val, _ = sawb_params_code(
                    num_bits, bits2code[num_bits], input_tensor
                )
            else:
                clip_val = input_tensor.abs().max()
        else:
            # do not recalc clipval when under eval mode
            clip_val = objSAWB_clip_val
        if len(clip_val.shape) == 0:
            clip_val = clip_val.unsqueeze(dim=0)
        objSAWB_clip_val.copy_(clip_val)
        if clip_val.numel() > 1:
            # A per-channel buffer (one entry per dim-0 channel) would otherwise
            # broadcast against the *last* dim of input_tensor, which fails for
            # non-square weights; align it with dim 0 instead.
            clip_val = clip_val.reshape([-1] + [1] * (input_tensor.dim() - 1))
        output = input_tensor.mul(1 / clip_val).clamp(-1, 1).mul(0.5).add(0.5)
        output = linear_quantize(output, scale, zero_point, inplace)
        if dequantize:
            output = linear_dequantize(output, scale, zero_point, inplace)
            output = (2 * output - 1) * clip_val
        else:
            output -= scale / 2
        ctx.save_for_backward(input_tensor, clip_val)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, clip_val = ctx.saved_tensors
        grad_input_tensor = grad_output.clone()
        # zero the gradient for inputs clipped in forward
        grad_input_tensor = torch.where(
            input_tensor < -clip_val,
            torch.zeros_like(grad_input_tensor),
            grad_input_tensor,
        )
        grad_input_tensor = torch.where(
            input_tensor > clip_val,
            torch.zeros_like(grad_input_tensor),
            grad_input_tensor,
        )
        return grad_input_tensor, None, None, None, None, None
def sawb_params_code(num_bits, code, out, perCh=False):
    """Compute SAWB clip value(s) for a given SAWB code.

    clip = c1 * mean(|x|) + c0 * sqrt(mean(x**2)), with (c0, c1) looked up
    from ``code``. Computed under ``torch.no_grad()``.

    Args:
        num_bits: number of bits; only used to compute ``nspace``.
        code: SAWB code, one of 102, 103, 403, 703, 803.
        out: tensor to analyze.
        perCh: if True, compute one clip value per index of dim 0, reducing
            over all remaining dims.

    Returns:
        perCh=True: ``(clip_val_vec, None)``.
        perCh=False: ``(clip_val, nspace)``, with nspace = 2**num_bits - 1 for
        code 102 and 2**num_bits - 2 for the other codes.

    Raises:
        ValueError: if ``code`` is unknown.
    """
    coeff_table = {
        102: (3.12, -2.064),  # [-a, -a/3, a/3, a] equivalent to 2 bits
        103: (2.6, -1.71),  # [-a, 0, a]
        403: (12.035, -12.03),  # [-a, -6/7a, ..., 0, ..., 6/7a, a]
        703: (28.24, -30.81),
        803: (31.76, -35.04),
    }
    with torch.no_grad():
        if code not in coeff_table:
            raise ValueError(f"SAWB not implemented for code={code}")
        c_std, c_mu = coeff_table[code]

        if perCh:
            # conv W=[ch_o, ch_i, ki, ij], linear W=[ch_o, ch_i]: keep only ch_out
            dims = list(range(1, out.dim()))
            mu_vec = out.abs().mean(dim=dims)
            std_vec = (out**2).mean(dim=dims).sqrt()
            return c_mu * mu_vec + c_std * std_vec, None

        flat = out.flatten()
        mu = flat.abs().mean()
        std = flat.mul(flat).mean().sqrt()
        clip_val = c_mu * mu + c_std * std
        nspace = 2**num_bits - 1 if code == 102 else 2**num_bits - 2
        return clip_val, nspace
class SAWBPlusSTE(torch.autograd.Function):
    """
    SAWB+: no zero alignment and no gradient clipping
    Incorrect behavior for "dequantize=False" - do not use

    forward args:
        input_tensor: tensor to quantize.
        num_bits: 2-5 use ``sawb_params``; other values use ``x.abs().max()``.
        dequantize: if True, map quantized values back to [-clip, clip].
        inplace: marks input_tensor dirty and is forwarded to
            linear_quantize / linear_dequantize.
        objSAWB_clip_val: clip-value buffer of the owning module. It holds one
            value, or one per output channel when the owning SAWB was built
            with perCh. It is written with the clip value used.
        istraining: recompute the (per-tensor) clip value only when True.
    """

    @staticmethod
    def forward(
        ctx, input_tensor, num_bits, dequantize, inplace, objSAWB_clip_val, istraining
    ):
        if inplace:
            ctx.mark_dirty(input_tensor)
        scale, zero_point = asymmetric_linear_quantization_params(
            num_bits, saturation_min=0, saturation_max=1, signed=False
        )  # returns scale = 2^bits-1, zero_point = 0
        if istraining:
            # only recalc clipval under training mode
            if num_bits in [2, 3, 4, 5]:  # 8
                clip_val = sawb_params(num_bits, input_tensor)
            else:
                clip_val = input_tensor.abs().max()
        else:
            # do not recalc clipval when under eval mode
            clip_val = objSAWB_clip_val
        if len(clip_val.shape) == 0:
            clip_val = clip_val.unsqueeze(dim=0)
        objSAWB_clip_val.copy_(clip_val)
        if clip_val.numel() > 1:
            # A per-channel buffer (one entry per dim-0 channel) would otherwise
            # broadcast against the *last* dim of input_tensor, which fails for
            # non-square weights; align it with dim 0 instead.
            clip_val = clip_val.reshape([-1] + [1] * (input_tensor.dim() - 1))
        output = input_tensor.mul(1 / clip_val).clamp(-1, 1).mul(0.5).add(0.5)
        output = linear_quantize(output, scale, zero_point, inplace)
        if dequantize:
            output = linear_dequantize(output, scale, zero_point, inplace)
            output = (2 * output - 1) * clip_val
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through: no gradient clipping
        grad_input_tensor = grad_output.clone()
        return grad_input_tensor, None, None, None, None, None
class SAWBSTE(torch.autograd.Function):
    """
    SAWB without zero alignment
    Incorrect behavior for "dequantize=False" - do not use

    forward args:
        input_tensor: tensor to quantize.
        num_bits: 2-5 use ``sawb_params``; other values use ``x.abs().max()``.
        dequantize: if True, map quantized values back to [-clip, clip].
        inplace: marks input_tensor dirty and is forwarded to
            linear_quantize / linear_dequantize.
        objSAWB_clip_val: clip-value buffer of the owning module. It holds one
            value, or one per output channel when the owning SAWB was built
            with perCh. It is written with the clip value used.
        istraining: recompute the (per-tensor) clip value only when True.

    backward zeroes the input gradient outside [-clip_val, clip_val].
    """

    @staticmethod
    def forward(
        ctx, input_tensor, num_bits, dequantize, inplace, objSAWB_clip_val, istraining
    ):
        if inplace:
            ctx.mark_dirty(input_tensor)
        scale, zero_point = asymmetric_linear_quantization_params(
            num_bits, saturation_min=0, saturation_max=1, signed=False
        )
        if istraining:
            if num_bits in [2, 3, 4, 5]:
                clip_val = sawb_params(num_bits, input_tensor)
            else:
                clip_val = input_tensor.abs().max()
        else:
            # do not recalc clipval when under eval mode
            clip_val = objSAWB_clip_val
        if len(clip_val.shape) == 0:
            clip_val = clip_val.unsqueeze(dim=0)
        objSAWB_clip_val.copy_(clip_val)
        if clip_val.numel() > 1:
            # A per-channel buffer (one entry per dim-0 channel) would otherwise
            # broadcast against the *last* dim of input_tensor, which fails for
            # non-square weights; align it with dim 0 instead.
            clip_val = clip_val.reshape([-1] + [1] * (input_tensor.dim() - 1))
        output = input_tensor.mul(1 / clip_val).clamp(-1, 1).mul(0.5).add(0.5)
        output = linear_quantize(output, scale, zero_point, inplace)
        if dequantize:
            output = linear_dequantize(output, scale, zero_point, inplace)
            output = (2 * output - 1) * clip_val
        ctx.save_for_backward(input_tensor, clip_val)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, clip_val = ctx.saved_tensors
        grad_input_tensor = grad_output.clone()
        # zero the gradient for inputs clipped in forward
        grad_input_tensor = torch.where(
            input_tensor < -clip_val,
            torch.zeros_like(grad_input_tensor),
            grad_input_tensor,
        )
        grad_input_tensor = torch.where(
            input_tensor > clip_val,
            torch.zeros_like(grad_input_tensor),
            grad_input_tensor,
        )
        return grad_input_tensor, None, None, None, None, None
def sawb_params(num_bits, out):
    """Compute the per-tensor SAWB clip value for ``num_bits`` bits.

    clip = c1 * mean(|x|) + c0 * sqrt(mean(x**2)), with (c0, c1) taken from a
    per-bit-width coefficient table. Computed under ``torch.no_grad()``.

    Args:
        num_bits: bit-width; coefficients exist for 2, 3, 4, 5 and 8.
        out: tensor to analyze (flattened).

    Returns:
        0-dim tensor holding the clip value.

    Raises:
        ValueError: if no coefficients exist for ``num_bits`` (previously
            only >8 raised ValueError, and 1/6/7 raised a bare KeyError).
    """
    dic_coeff = {
        2: (3.12, -2.064),
        3: (7.509, -6.892),
        4: (12.68, -12.80),
        5: (17.74, -18.64),
        8: (31.76, -35.04),
    }
    if num_bits not in dic_coeff:
        raise ValueError(f"SAWB not implemented for num_bits={num_bits}")
    with torch.no_grad():
        x = out.flatten()
        mu = x.abs().mean()
        std = x.mul(x).mean().sqrt()
        coeff = dic_coeff[num_bits]
        clip_val = coeff[1] * mu + coeff[0] * std
    return clip_val
#####################################
##############1-side PACT###############
class PACT(nn.Module):
    """1-sided original PACT.

    PACT is only used to quantize activations.

    Args:
        num_bits: number of bits.
        init_clip_val: initial clip value. A tensor is wrapped as-is in an
            nn.Parameter; anything else is wrapped as a 1-element tensor.
        init_clip_valn: accepted for interface parity with 2-sided PACT; unused.
        dequantize: forwarded to the autograd function.
        inplace: forwarded to the autograd function.
        cggrad: use CGPACT_STE (takes precedence over the other flags).
        grad_scale: use CGPACT_gScale_STE (when neither cggrad nor pact_plus).
        pact_plus: use PACTplusSTE (when not cggrad).
    """

    def __init__(
        self,
        num_bits,
        init_clip_val,
        init_clip_valn=0,  # pylint: disable=unused-argument
        dequantize=True,
        inplace=False,
        cggrad=False,
        grad_scale=False,
        pact_plus=False,
    ):
        super().__init__()
        self.num_bits = num_bits
        if isinstance(init_clip_val, torch.Tensor):
            self.clip_val = nn.Parameter(init_clip_val)
        else:
            self.clip_val = nn.Parameter(torch.Tensor([init_clip_val]))
        self.dequantize = dequantize
        self.inplace = inplace
        self.cggrad = cggrad
        self.grad_scale = grad_scale
        # precedence: cggrad > pact_plus > grad_scale > plain PACT
        self.quantizer = (
            CGPACT_STE
            if self.cggrad
            else PACTplusSTE
            if pact_plus
            else CGPACT_gScale_STE
            if self.grad_scale
            else PACT_STE
        )

    def forward(self, input_tensor):
        """Quantize ``input_tensor`` with the selected PACT autograd function."""
        input_tensor = self.quantizer.apply(
            input_tensor,
            self.clip_val,
            self.num_bits,
            self.dequantize,
            self.inplace,
        )
        return input_tensor

    def __repr__(self):
        inplace_str = ", inplace" if self.inplace else ""
        return (
            f"{self.__class__.__name__}(num_bits={self.num_bits}, clip_val={self.clip_val[0]}, "
            f"cggrad={self.cggrad}, grad_scale={self.grad_scale}, "
            f"quantizer={self.quantizer}{inplace_str})"
        )
class PACT_STE(torch.autograd.Function):
    """1-sided original PACT autograd function.

    forward: casts clip_val to input_tensor's dtype and rejects negative
    inputs with ValueError. It then clips from above at clip_val, quantizes
    with the unsigned parameters from
    ``asymmetric_linear_quantization_params(num_bits, 0, clip_val)`` and, if
    ``dequantize``, applies ``linear_dequantize``.

    backward: the input gradient passes through inside [0, clip_val] and is
    zeroed outside. The clip_val gradient is the sum of grad_output over
    elements with input_tensor >= clip_val.
    """

    @staticmethod
    def forward(ctx, input_tensor, clip_val, num_bits, dequantize, inplace):
        clip_val = clip_val.to(input_tensor.dtype)
        ctx.save_for_backward(input_tensor, clip_val)
        if inplace:
            ctx.mark_dirty(input_tensor)
        scale, zero_point = asymmetric_linear_quantization_params(
            num_bits, saturation_min=0, saturation_max=clip_val.data, signed=False
        )
        # clip_val is a tensor after ``.to`` above, so only the upper bound is
        # clipped; negative inputs are rejected instead of clamped.
        if input_tensor.min() < 0:
            raise ValueError(
                "FMS: input_tensor to single_side PACT should be non-negative."
            )
        upper = torch.ones_like(input_tensor) * clip_val
        clipped = torch.where(input_tensor > clip_val, upper, input_tensor)
        levels = linear_quantize(clipped, scale, zero_point, inplace)
        if not dequantize:
            return levels
        return linear_dequantize(levels, scale, zero_point, inplace)

    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, clip_val = ctx.saved_tensors
        zeros = torch.zeros_like(grad_output)
        out_of_range = (input_tensor < 0) | (input_tensor > clip_val)
        grad_input_tensor = torch.where(out_of_range, zeros, grad_output)
        # clip_val only receives gradient from elements at or above it
        grad_alpha = torch.where(input_tensor < clip_val, zeros, grad_output)
        grad_alpha = grad_alpha.sum().expand_as(clip_val)
        # NOTE: forward has 5 inputs; the extra trailing None is kept as before
        return grad_input_tensor, grad_alpha, None, None, None, None
class CGPACT_STE(torch.autograd.Function):
    """1-sided CGPACT
    use calibrated clip_val gradient to update clip_val

    forward: clips input_tensor from above at clip_val (negative inputs raise
    ValueError) and quantizes with ``linear_quantize_residual``. The returned
    residual, divided by 2**num_bits - 1, is kept on ``ctx.residual`` for
    backward. If ``dequantize``, ``linear_dequantize`` is applied.
    NOTE(review): the residual's exact meaning comes from
    linear_quantize_residual (not visible here); presumably the rounding
    residual — confirm.
    Unlike PACT_STE, clip_val is not cast to input_tensor's dtype.

    backward: the input gradient passes through inside [0, clip_val] and is
    zeroed outside. The clip_val gradient sums grad_output * residual over
    elements below clip_val plus grad_output over elements at/above clip_val.
    """
    @staticmethod
    def forward(ctx, input_tensor, clip_val, num_bits, dequantize, inplace):
        ctx.save_for_backward(input_tensor, clip_val)
        if inplace:
            ctx.mark_dirty(input_tensor)
        scale, zero_point = asymmetric_linear_quantization_params(
            num_bits, 0, clip_val.data, signed=False
        )
        # NOTE(review): clip_val.data is accessed above, so clip_val is
        # presumably always a tensor and the else branch below unreachable.
        if isinstance(clip_val, torch.Tensor):
            # error text uses a legacy class name; kept unchanged
            if input_tensor.min() < 0:
                raise ValueError(
                    "FMS: input_tensor to ClippedLinearQuantization should be non-negative."
                )
            output = torch.where(
                input_tensor > clip_val,
                torch.ones_like(input_tensor) * clip_val,
                input_tensor,
            )
        else:
            output = clamp(input_tensor, 0, clip_val.data, inplace=inplace)
        output, ctx.residual = linear_quantize_residual(
            output, scale, zero_point, inplace
        )
        # normalize the residual by the number of quantization steps, in place
        with torch.no_grad():
            n = 2**num_bits - 1
            ctx.residual /= n
        if dequantize:
            output = linear_dequantize(output, scale, zero_point, inplace)
        return output
    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, clip_val = ctx.saved_tensors
        grad_input_tensor = grad_output.clone()
        # zero the input gradient for negative inputs and inputs above clip_val
        grad_input_tensor = torch.where(
            input_tensor < 0, torch.zeros_like(grad_input_tensor), grad_input_tensor
        )
        grad_input_tensor = torch.where(
            input_tensor > clip_val,
            torch.zeros_like(grad_input_tensor),
            grad_input_tensor,
        )
        # calibrated gradient: in-range elements contribute grad * residual,
        # clipped elements contribute the full grad
        grad_alpha = grad_output.clone()
        grad_alpha = torch.where(
            input_tensor < clip_val, grad_alpha * ctx.residual, grad_alpha
        )
        grad_alpha = grad_alpha.sum().expand_as(clip_val)
        return grad_input_tensor, grad_alpha, None, None, None
class CGPACT_gScale_STE(torch.autograd.Function):
"""1-sided CGPACT
use calibrated clip_val gradient to update clip_val with scaled gradient
"""
@staticmethod
def forward(ctx, input_tensor, clip_val, num_bits, dequantize, inplace):
ctx.save_for_backward(input_tensor, clip_val)
if inplace:
ctx.mark_dirty(input_tensor)
scale, zero_point = asymmetric_linear_quantization_params(
num_bits, saturation_min=0, saturation_max=clip_val.data, signed=False
)
if isinstance(clip_val, torch.Tensor):
if input_tensor.min() < 0:
raise ValueError(
"FMS: input_tensor to ClippedLinearQuantization should be non-negative."
)
output = torch.where(
input_tensor > clip_val,
torch.ones_like(input_tensor) * clip_val,
input_tensor,
)
else:
output = clamp(input_tensor, 0, clip_val.data, inplace=inplace)
output, ctx.residual = linear_quantize_residual(
output, scale, zero_point, inplace
)
with torch.no_grad():
n = 2**num_bits - 1
ctx.residual /= n
if dequantize:
output = linear_dequantize(output, scale, zero_point, inplace)
return output
@staticmethod
def backward(ctx, grad_output):
input_tensor, clip_val = ctx.saved_tensors
grad_input_tensor = grad_output.clone()
grad_input_tensor = torch.where(
input_tensor < 0, torch.zeros_like(grad_input_tensor), grad_input_tensor
)
grad_input_tensor = torch.where(
input_tensor > clip_val,
torch.zeros_like(grad_input_tensor),
grad_input_tensor,
)
grad_alpha = grad_output.clone()
grad_alpha = torch.where(
input_tensor < clip_val, grad_alpha * ctx.residual, grad_alpha
)
grad_alpha = grad_alpha.sum().expand_as(clip_val)