-
Notifications
You must be signed in to change notification settings - Fork 20
Expand file tree
/
Copy pathptq.py
More file actions
2687 lines (2407 loc) · 103 KB
/
ptq.py
File metadata and controls
2687 lines (2407 loc) · 103 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Post-Training Quantization (PTQ) functions
Class StraightThrough, function _fold_bn, fold_bn_into_conv, reset_bn, and
search_fold_and_remove_bn are modified from QDROP repo https://github.com/wimh966/QDrop
"""
# Standard
from functools import partial
from typing import Optional, Union
import logging
import math
import random
import sys
# Third Party
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import numpy as np
import pandas as pd
# from numpy.lib.function_base import iterable
import torch
import torch.nn as nn
import torch.nn.functional as F
# Local
from fms_mo.modules import QBmm, QLinear
from fms_mo.modules.conv import QConv2dPTQv2
from fms_mo.quant.quantizers import (
AdaRoundQuantizer,
Qdynamic,
get_activation_quantizer,
lp_loss,
)
from fms_mo.utils.import_utils import available_packages
from fms_mo.utils.utils import move_to, patch_torch_bmm
logger = logging.getLogger(__name__)
try:
    # Third Party
    from piqa.piqa import SSIM  # ,MS_SSIM unused-import

    piqa_installed = True
except ImportError:
    # piqa is optional: without it the SSIM-based PTQ losses are unavailable and
    # PTQLossFunc falls back to MSE for those methods.
    piqa_installed = False
# TODO: this function is not used. Can be removed.
def set_seed(seed):
    """
    Seed Python's ``random``, NumPy and PyTorch, and switch cuDNN to deterministic mode.

    Only part of PyTorch's reproducibility checklist is covered here; see
    https://pytorch.org/docs/stable/notes/randomness.html for the remaining items.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True
# --------------------------------------------------------
# --------- PTQ-related util functions -------------------
# --------------------------------------------------------
class LinearTempDecay:
    """
    Linear annealing schedule (used e.g. for AdaRound's ``beta`` and ``lambda``).

    The value is ``b_range[0]`` before iteration ``warm_up * t_max``, decays
    linearly to ``b_range[1]`` at iteration ``hold * t_max``, and stays at
    ``b_range[1]`` afterwards. The last returned value is kept in ``curr_beta``.

    Args:
        t_max: number of iterations the fractions below refer to.
        warm_up: fraction of ``t_max`` at which the decay starts.
        hold: fraction of ``t_max`` at which the decay ends.
        b_range: (start value, end value).
    """

    def __init__(self, t_max=20000, warm_up=0.2, hold=1.0, b_range=(20, 2)):
        self.t_max = hold * t_max  # iteration at which the decay ends
        self.start_decay = warm_up * t_max  # iteration at which the decay starts
        # NOTE from warm_up to warm_up2, round_loss starts to work but no decay in beta and lambda
        # see PTQLossFunc for more details
        self.start_b = b_range[0]
        self.end_b = b_range[1]
        self.curr_beta = 0.0

    def __call__(self, t):
        """Return (and store in ``curr_beta``) the scheduled value at iteration ``t``."""
        if t < self.start_decay:
            self.curr_beta = self.start_b
        elif t > self.t_max:
            self.curr_beta = self.end_b
        else:
            decay_len = self.t_max - self.start_decay
            # A zero-length decay window (hold == warm_up) is an instantaneous drop;
            # avoid 0/0 and report the end value at that single iteration.
            rel_t = (t - self.start_decay) / decay_len if decay_len > 0 else 1.0
            self.curr_beta = self.end_b + (self.start_b - self.end_b) * max(
                0.0, (1 - rel_t)
            )
        return self.curr_beta
class CyclicTempDecay:
    """
    Cyclic annealing schedule between ``b_maxmin[0]`` (max) and ``b_maxmin[1]`` (min).

    Before iteration ``warm_up * t_max`` the value is the max. Between
    ``warm_up * t_max`` and ``hold * t_max`` it cycles max -> min -> max, either
    linearly ("V") or along a cosine ("cos"). After ``hold * t_max`` the last
    computed value is held.

    ``style`` is ``shape`` or ``shape*N``: shape is "cos" (anything else is treated
    as "V"), and N is the number of cycles in the window (may be fractional, e.g.
    "cos*0.5" gives a single max -> min half cycle).

    Args:
        t_max: number of iterations the fractions below refer to.
        warm_up: fraction of ``t_max`` at which cycling starts.
        hold: fraction of ``t_max`` at which cycling stops.
        b_maxmin: (max value, min value); max must be strictly greater.
        style: cycle shape and count, see above.
    """

    def __init__(self, t_max=20000, warm_up=0.2, hold=1.0, b_maxmin=(20, 2), style="V"):
        self.t_max = hold * t_max
        self.start_decay = warm_up * t_max
        # Annealing only happens between [warmup, hold]*t_max, from max_b -> min_b -> max_b
        # e.g. usually no round_loss before 0.2*tmax (controlled by PTQloss_func),
        # here we still return b_max, and hold the last value after hold*tmax
        self.max_b = b_maxmin[0]
        self.min_b = b_maxmin[1]
        assert self.max_b > self.min_b, "max_b is smaller than min_b, please check!"
        self.curr_beta = 0.0
        # style can be 'V', 'V*2', 'V*3'... 'cos*0.5', 'cos*2', 'cos' ... shape*N where
        # N means number of cycles, N could be 0.5
        if "*" not in style:
            style = style + "*1"
        self.style_cycles = style.split("*")
        # True division: floor division truncated the period (restarting an extra
        # partial cycle right before t_max) and collapsed it to 0 when the window was
        # shorter than the number of cycles, causing a ZeroDivisionError in __call__.
        self.period = (self.t_max - self.start_decay) / float(self.style_cycles[1])

    def __call__(self, t):
        """Return (and store in ``curr_beta``) the scheduled value at iteration ``t``."""
        if t < self.start_decay:
            self.curr_beta = self.max_b
        elif self.start_decay <= t < self.t_max:
            # position inside the current cycle, in [0, 1)
            rel_t = ((t - self.start_decay) % self.period) / self.period
            if self.style_cycles[0] == "cos":
                self.curr_beta = (
                    self.min_b
                    + (self.max_b - self.min_b)
                    * (math.cos(math.pi * 2 * rel_t) + 1.0)
                    * 0.5
                )
            else:  # V-shape
                self.curr_beta = self.min_b + (self.max_b - self.min_b) * abs(
                    1.0 - 2 * rel_t
                )
        # else: t >= t_max -> hold, i.e. return the last saved curr_beta
        return self.curr_beta
class PTQLossFunc(nn.Module):
    """
    Loss functions for PTQ block-sequential optimization.

    Supported ``method`` values, checked in this order:
      - "mse", "mae" / "l1", "normalized_change": plain reconstruction losses.
      - any name containing "ssim" (only when piqa is installed): 1 - SSIM and its
        variants "ssimlog", "ssimp0.2", "ssimp0.5", "ssimp2", or a dict of terms
        for "0.01ssim+mse" / "0.1ssim+mse".
      - "fisher_diag", "fisher_full": reconstruction error weighted by ``grad``.
      - any name containing "adaround": reconstruction loss plus the AdaRound
        rounding regularizer; returns a dict with "total", "reconstruct" and
        "rounding" (plus monitored terms for "adaroundMonAll*").
    Anything else (including "ssim" names without piqa) falls back to MSE and logs
    a message on every call.

    Args:
        method: loss name, see above.
        Ntotal_iters: total number of optimization iterations; the warm-up and
            decay milestones of the beta/lambda schedules are fractions of it.
        layers: modules whose ``quantize_weight`` (when it is an
            AdaRoundQuantizer) contributes to the rounding loss.
        p: exponent passed to ``lp_loss``.
        isOptimConv: stored as ``self.isOptimConv``; not read inside this class.
        adaR_anneal: annealing config with keys "warmup", "warmup2", "hold",
            "beta" and "lambda" (the last two as {"range": (start, end), "style": s},
            s in "constant", "linear", or a CyclicTempDecay style). ``None`` (the
            default) means warmup=0.1, warmup2=0.0, hold=0.9, beta (20, 2) "linear",
            lambda (0.01, 0.01) "constant".
    """

    def __init__(
        self,
        method="mse",
        Ntotal_iters=20000,
        layers=None,
        p=2.0,
        isOptimConv=False,
        adaR_anneal=None,
    ):
        super().__init__()
        if adaR_anneal is None:
            # built per instance instead of a shared mutable default argument
            adaR_anneal = {
                "warmup": 0.1,
                "warmup2": 0.0,
                "hold": 0.9,
                "beta": {"range": (20, 2), "style": "linear"},
                "lambda": {"range": (0.01, 0.01), "style": "constant"},
            }
        self.method = method
        self.p = p
        self.count = 0  # number of __call__ invocations, drives the schedules
        self.Ntotal = Ntotal_iters
        self.layers = layers
        self.isOptimConv = isOptimConv
        self.warmup = adaR_anneal["warmup"]
        self.warmup2 = adaR_anneal["warmup2"]
        self.hold = adaR_anneal["hold"]
        # the rounding loss is only added once count exceeds loss_start
        self.loss_start = int(Ntotal_iters * self.warmup)
        self.reset_ReSig = None
        if self.warmup2 >= self.warmup:
            self.reset_ReSig = int(Ntotal_iters * self.warmup2)
        # NOTE, round_loss starts from warmup but decay could start from warmup2, controlled
        # by LinearTempDecay() and CyclicTempDecay() when using delayed-decay (warmup2 !=0),
        # we may further switch the formula, e.g. from 1ReSig to 3ReSig at decay_start
        self.beta = adaR_anneal["beta"]  # brecq's settings was (20, 2)
        self.lambda_eq21 = adaR_anneal["lambda"]
        self.beta_decay = self._make_anneal_schedule(self.beta, Ntotal_iters)
        self.lambda_decay = self._make_anneal_schedule(self.lambda_eq21, Ntotal_iters)

    def _make_anneal_schedule(self, sched_cfg, Ntotal_iters):
        """Build the callable iteration -> value schedule described by ``sched_cfg``."""
        warm_up = max(self.warmup, self.warmup2)
        style = sched_cfg["style"]
        if style == "constant":
            return lambda x: sched_cfg["range"][0]
        if style == "linear":
            return LinearTempDecay(
                Ntotal_iters,
                warm_up=warm_up,
                hold=self.hold,
                b_range=sched_cfg["range"],
            )
        # cyclic styles: 'cos', 'V', 'cos*0.5', 'cos*2', 'V*2', ...
        return CyclicTempDecay(
            Ntotal_iters,
            warm_up=warm_up,
            hold=self.hold,
            b_maxmin=sched_cfg["range"],
            style=style,
        )

    def __call__(self, im1, im2, grad=None, gt=None):
        """
        Compute the loss between ``im1`` (q_out) and ``im2`` (fp_out).

        Args:
            im1: output of the quantized module.
            im2: reference full-precision output.
            grad: gradient tensor, only used by the "fisher_*" methods.
            gt: labels, only used by "adaroundMonAll*_ce".

        Returns:
            A scalar tensor, or a dict of loss terms for the "*ssim+mse"
            combinations and the "adaround*" methods.
        """
        self.count += 1
        if self.method == "mse":
            return F.mse_loss(im1, im2)
        if self.method in ["mae", "l1"]:
            return F.l1_loss(im1, im2)
        if self.method == "normalized_change":
            return torch.norm(im1 - im2) / torch.norm(im2)
        if "ssim" in self.method and piqa_installed:
            # The inputs can have very different numerical ranges (one is the original
            # fp tensor, the other may be clipped to [clip_valn, clip_val]); rescale
            # each one to [0, 1] with its own min/max before computing SSIM.
            im_min = [im1.min(), im2.min()]
            im_max = [im1.max(), im2.max()]
            im_range = [im_max[0] - im_min[0], im_max[1] - im_min[1]]
            loss_func = SSIM(n_channels=im1.shape[1], value_range=1).to(im1.device)
            im_scaled = [
                (im1 - im_min[0]) / im_range[0],
                (im2 - im_min[1]) / im_range[1],
            ]
            ssimloss = 1.0 - loss_func(*im_scaled)
            if self.method == "ssimlog":
                loss = torch.log(ssimloss)
            elif self.method == "ssimp0.2":
                loss = torch.pow(ssimloss, 0.2)
            elif self.method == "ssimp0.5":
                loss = torch.pow(ssimloss, 0.5)
            elif self.method == "ssimp2":
                loss = torch.pow(ssimloss, 2)
            elif self.method == "0.01ssim+mse":
                loss = {"mse": F.mse_loss(im1, im2), "0.01ssim": 0.01 * ssimloss}
            elif self.method == "0.1ssim+mse":
                loss = {"mse": F.mse_loss(im1, im2), "0.1ssim": 0.1 * ssimloss}
            else:
                # 'ssim' or any other 'ssimxxx' defaults to the simple form
                loss = ssimloss
            return loss
        if self.method == "fisher_diag":
            return ((im1 - im2).pow(2) * grad.pow(2)).sum(1).mean()
        if self.method == "fisher_full":
            a = (im1 - im2).abs()
            grad = grad.abs()
            batch_dotprod = torch.sum(a * grad, (1, 2, 3)).view(-1, 1, 1, 1)
            return (batch_dotprod * a * grad).mean() / 100
        if "adaround" in self.method:
            # default is mse + rounding loss as in the original paper
            round_loss = torch.tensor(0.0, device=im1.device)
            losses = {}
            if self.count > self.loss_start:
                # beta and lambda can be annealed separately
                b = self.beta_decay(self.count)
                # eq21 in the adaround paper, brecq's settings
                lambda_eq21 = self.lambda_decay(self.count)
                for layer in self.layers or []:
                    if hasattr(layer, "quantize_weight") and isinstance(
                        layer.quantize_weight, AdaRoundQuantizer
                    ):
                        if self.count == self.reset_ReSig:
                            layer.quantize_weight.reset_ReSig_param(3)
                        # h from eq23, now supports multimodal
                        round_vals = layer.quantize_weight.get_soft_targets()
                        if layer.quantize_weight.multimodal:
                            # fold values outside [0, 1] back into that interval
                            round_vals = (
                                round_vals
                                + (round_vals < 0.0).to(torch.float32)
                                - (round_vals > 1.0).to(torch.float32)
                            )
                        round_loss += (
                            lambda_eq21
                            * (1 - ((round_vals - 0.5).abs() * 2).pow(b)).sum()
                        )  # eq24
            if self.method == "adaroundKL":
                rec_loss = F.kl_div(
                    torch.log(im1 + 1e-6), im2 + 1e-6, reduction="batchmean"
                )
            elif self.method == "adaroundCos":
                # Cosine *distance*: minimizing it pulls q_out toward fp_out. The raw
                # similarity (previously used) would be maximized-away instead.
                rec_loss = 1.0 - torch.mean(
                    F.cosine_similarity(im1, im2)  # pylint: disable=not-callable
                )
            elif self.method == "adaroundL1":
                rec_loss = F.l1_loss(im1, im2)
            elif self.method.startswith("adaroundMonAll"):
                # monitor all losses; the suffix selects which one is optimized
                with torch.no_grad():
                    losses["l1"] = F.l1_loss(im1, im2)
                    losses["mse"] = F.mse_loss(im1, im2)
                with torch.set_grad_enabled(self.method.endswith("_norm")):
                    losses["norm"] = torch.norm(im1 - im2) / torch.norm(im2)
                with torch.set_grad_enabled(self.method.endswith("_cos")):
                    losses["cos"] = 1.0 - torch.mean(
                        F.cosine_similarity(im1, im2)  # pylint: disable=not-callable
                    )
                with torch.set_grad_enabled(self.method.endswith("_ce")):
                    # ce loss only works for the last layer: im1 should be q_out.
                    # NOTE(review): the 0.0 factor makes this term (and hence rec_loss
                    # for "_ce") identically zero; confirm whether that is intended.
                    losses["qce"] = (
                        0.0 * F.cross_entropy(im1, gt) if gt is not None else None
                    )
                if self.method.endswith("_norm"):
                    rec_loss = losses["norm"]
                elif self.method.endswith("_cos"):
                    rec_loss = losses["cos"]
                elif self.method.endswith("_ce") and gt is not None:
                    # only the last layer gets gt; others default to lp_loss below
                    rec_loss = losses["qce"]
                else:
                    rec_loss = lp_loss(im1, im2, p=self.p)
            else:
                # use brecq and qdrop's implementation
                rec_loss = lp_loss(im1, im2, p=self.p)
            losses["total"] = rec_loss + round_loss
            losses["reconstruct"] = rec_loss.detach()  # for tensorboard plot only
            losses["rounding"] = round_loss.detach()  # for tensorboard plot only
            return losses
        # method not defined!
        logger.info(f'PTQ Loss method {self.method} not defined. Use "MSE" instead.')
        return F.mse_loss(im1, im2)
class PTQHookRecInOut(nn.Module):
    """
    Forward post-hook that caches a module's input(s) and, optionally, its output.

    The cached input may be FP or quantized depending on upstream modules; the
    output is meant to be the FP reference (set PTQ_mode to 'fp32_out' before
    running the hooks). Tensors are detached, moved to CPU and appended to lists
    that the caller must have created in ``qcfg``:
      - single input: ``qcfg["cached_input"]``
      - more than one input: the first two go to ``qcfg["cached_input0"]`` and
        ``qcfg["cached_input1"]``
      - output (first element if it is a tuple): ``qcfg["cached_output"]``

    Args:
        qcfg: dict holding the cache lists.
        name: optional label, stored only.
        cls2rec: class or tuple of classes whose ``ptqmode`` is inspected; nothing
            is recorded if any such submodule has ``ptqmode == 'q_out'``.
        recInOnly: if True, record inputs only.
    """

    def __init__(self, qcfg, name=None, cls2rec=(nn.Conv2d,), recInOnly=False):
        super().__init__()
        self.name = name
        self.qcfg = qcfg
        self.cls2rec = cls2rec
        self.rec_input_only = recInOnly

    def __call__(self, mod, x, output):
        # don't record anything if any checked submodule is producing quantized output
        submods = [m for m in mod.modules() if isinstance(m, self.cls2rec)]
        if any(sm.ptqmode == "q_out" for sm in submods):
            return
        if len(x) > 1:
            # transformers have more than one input, e.g. masks, etc...
            self.qcfg["cached_input0"].append(x[0].detach().cpu())
            self.qcfg["cached_input1"].append(x[1].detach().cpu())
        else:
            self.qcfg["cached_input"].append(x[0].detach().cpu())
        if not self.rec_input_only:
            out = output[0] if isinstance(output, tuple) else output
            self.qcfg["cached_output"].append(out.detach().cpu())
class PTQHookRecInOutLMv2(nn.Module):
    """Simplified recording hook for PTQ.

    Records the leading non-None inputs as one tuple per call, no matter how many
    inputs there are; special handling (reshape/cat/shuffling, ...) is left for later.
    Entries are appended to ``qcfg["cached_input"]`` and (unless ``recInOnly``)
    ``qcfg["cached_output"]``; both lists must already exist.

    Cached tensors are detached and kept on "cuda" when CUDA is available with more
    than 20 GB free, otherwise on "cpu".

    Args:
        qcfg: dict holding the cache lists.
        name: optional label, stored only.
        cls2rec: class or tuple of classes whose ``ptqmode`` is inspected; nothing
            is recorded if any such submodule has ``ptqmode == 'q_out'``.
        recInOnly: if True, record inputs only.
    """

    def __init__(self, qcfg, name=None, cls2rec=(nn.Conv2d,), recInOnly=False):
        super().__init__()
        self.name = name
        self.qcfg = qcfg
        self.cls2rec = cls2rec
        self.rec_input_only = recInOnly
        # number of leading non-None inputs; determined on the first recorded call
        self.num_valid_input = -1

    def __call__(self, mod, inputs, output):
        # make sure this module/block's ptqmode is not 'q_out'
        submods = [m for m in mod.modules() if isinstance(m, self.cls2rec)]
        if any(sm.ptqmode == "q_out" for sm in submods):
            # don't record input/output if any of the submods has ptqmode =='q_out'
            return
        # inputs should always be a tuple of tensors, but some could be None;
        # count the leading valid ones once, then stick to that number
        if self.num_valid_input == -1:
            valid_inps = [inp is not None for inp in inputs]
            if False in valid_inps:
                self.num_valid_input = valid_inps.index(False)
            else:
                self.num_valid_input = len(valid_inps)  # if all True => all valid
        assert all(
            isinstance(inp_i, torch.Tensor) for inp_i in inputs[: self.num_valid_input]
        )
        # cache on GPU if possible (1 block for SQUAD/BERT 500 batches*12/batch = ~10G).
        # mem_get_info() raises on machines without CUDA, so check availability first.
        cache_device = "cpu"
        if torch.cuda.is_available():
            GPUmem_available, _GPUmem_total = torch.cuda.mem_get_info()
            if GPUmem_available / 1e9 > 20:
                cache_device = "cuda"
        self.qcfg["cached_input"].append(
            tuple(
                inp_i.detach().to(cache_device)
                for inp_i in inputs[: self.num_valid_input]
            )
        )
        # output could be a tuple (first element is recorded) or simply a tensor
        assert isinstance(output, (torch.Tensor, tuple))
        if not self.rec_input_only:
            self.qcfg["cached_output"].append(
                output[0].detach().to(cache_device)
                if isinstance(output, tuple)
                else output.detach().to(cache_device)
            )
class PTQHookRecQOut(nn.Module):
    """
    Forward hook that holds the hooked module's latest output in ``qcfg["cached_qout"]``.

    Meant for ptq_loss_func == 'fisher_diag', to temporarily hold the module's
    "Q_out". The output is stored as-is (no detach() or cpu()) because a backward
    pass through it is needed; each call replaces the previously held tensor.
    """

    def __init__(self, qcfg):
        super().__init__()
        self.qcfg = qcfg

    def __call__(self, mod, x, output):
        # keep exactly one output: overwrite instead of appending
        store = self.qcfg
        store["cached_qout"] = output
class PTQHookRecGrad(nn.Module):
    """
    Backward hook that holds the first element of ``grad_output`` in
    ``qcfg["cached_grad_out"]``.

    The gradient is stored as-is (no detach() or cpu()); each call replaces the
    previously held tensor.
    """

    def __init__(self, qcfg):
        super().__init__()
        self.qcfg = qcfg

    def __call__(self, mod, grad_input, grad_output):
        leading_grad = grad_output[0]
        self.qcfg["cached_grad_out"] = leading_grad
def update_train_or_ptq_mode(
    mod,
    ptqmode=None,
    set_mod_state="train",
    submod_names=("dummy",),
    class2change=(nn.Conv2d,),
):
    """
    Set the PTQ mode and/or train/eval state of every ``class2change`` module in ``mod``.

    ``mod.modules()`` includes ``mod`` itself, so this works both for a single
    layer and for a block.

    Args:
        mod: root module to search.
        ptqmode: if "fp32_out" or "q_out", assigned to each matching layer's
            ``ptqmode``; any other value leaves ``ptqmode`` unchanged.
        set_mod_state: "train" or "eval" to call ``.train()`` / ``.eval()``
            (regardless of the current state); anything else changes nothing.
        submod_names: attribute names on each layer (e.g. 'quantize_calib_feature')
            whose state should be set; a name the layer lacks falls back to the
            layer itself, so the default "dummy" targets the layer.
        class2change: class or tuple of classes selecting the layers.

    Returns:
        The list of matching layers, in ``mod.modules()`` order.
    """
    layers = [c for c in mod.modules() if isinstance(c, class2change)]
    for layer in layers:
        # --- PTQ state ---
        if ptqmode in ("fp32_out", "q_out"):
            layer.ptqmode = ptqmode
        # --- module state ---
        for smname in submod_names:
            sm = getattr(layer, smname, layer)  # default to layer itself
            if set_mod_state == "train":
                sm.train()
            elif set_mod_state == "eval":
                sm.eval()
    return layers
def _ptq_mod_optim_collect_params(layers, qcfg, optim_mode):
    """
    Collect the tensors that ptq_mod_optim() will optimize for ``layers``.

    Args:
        layers: modules of the block being optimized.
        qcfg: config dict; reads "PTQ_freezeweight" and "PTQ_freezebias".
        optim_mode: "Wonly" skips activation-quantizer params, "Aonly" skips
            weights and weight-quantizer params; anything else collects both.

    Returns:
        (ws, cva, cvw, biases, param_names): tensor lists for weights,
        activation-quantizer params, weight-quantizer params and biases, plus
        ``param_names`` = [weight names, Qa names, Qw names, bias names] aligned
        with those lists (used for TensorBoard tags).
    """
    param_names = [[], [], [], []]  # 0=weights, 1=Qa, 2=Qw, 3=bias
    ws, cva, cvw, biases = [], [], [], []
    for idx, layer in enumerate(layers):
        # layers with quantize_m1/m2 (presumably QBmm) only have activation quantizers
        has_two_act_quantizers = hasattr(layer, "quantize_m1")
        if optim_mode != "Wonly":
            if hasattr(layer, "quantize_feature"):
                qfeat = layer.quantize_feature
                if hasattr(qfeat, "clip_val"):
                    cva.append(qfeat.clip_val)
                    param_names[1].append(f"cv{idx}")
                if hasattr(qfeat, "clip_valn"):
                    cva.append(qfeat.clip_valn)
                    param_names[1].append(f"cvn{idx}")
                if hasattr(qfeat, "delta"):
                    cva.append(qfeat.delta)
                    param_names[1].append(f"delta{idx}")
            elif has_two_act_quantizers:
                if hasattr(layer.quantize_m1, "clip_val"):
                    cva.append(layer.quantize_m1.clip_val)
                    param_names[1].append(f"cv{idx}")
                if hasattr(layer.quantize_m2, "clip_val"):
                    cva.append(layer.quantize_m2.clip_val)
                    param_names[1].append(f"cv{idx}")
                if hasattr(layer.quantize_m1, "clip_valn"):
                    cva.append(layer.quantize_m1.clip_valn)
                    param_names[1].append(f"cvn{idx}")
                if hasattr(layer.quantize_m2, "clip_valn"):
                    cva.append(layer.quantize_m2.clip_valn)
                    param_names[1].append(f"cvn{idx}")
            else:
                logger.info(f"Layer {layer} has no trainable parameter for quantization")
        if optim_mode != "Aonly" and not has_two_act_quantizers:
            # sym BRECQ or PACT+ for weight
            if "brecq" in layer.qw_mode:
                cvw.append(layer.quantize_weight.delta)
                param_names[2].append(f"deltaW{idx}")
            if "adaround" in layer.qw_mode:
                cvw.append(layer.quantize_weight.alpha)
                param_names[2].append(f"alphaW{idx}")
            if "pact+" in layer.qw_mode:
                cvw.append(layer.quantize_weight.clip_val)
                param_names[2].append(f"cvW{idx}")
            if not qcfg["PTQ_freezeweight"]:
                ws.append(layer.weight)
                param_names[0].append(f"W{idx}")
        if (
            not has_two_act_quantizers
            and layer.bias is not None
            and not qcfg["PTQ_freezebias"]
        ):
            biases.append(layer.bias)
            param_names[3].append(f"bias{idx}")
    return ws, cva, cvw, biases, param_names


def _ptq_mod_optim_log_tb(
    qcfg, m, layers, mod_name, Niter, loss4plot, optim_w, optim_a, param_names, loss_func
):
    """Write losses, quantizer buffers/params, learning rates and AdaRound stats to qcfg["tb_writer"]."""
    tb = qcfg["tb_writer"]
    # plot loss terms
    for k, v in loss4plot.items():
        tb.add_scalar(f"{mod_name}/PTQloss_{k}", v, Niter)
    # plot cv, delta, zp buffers
    for k, v in m.named_buffers():
        if any(kb in k for kb in ["delta", "zero_point", "clip_val"]):
            if len(v.shape) > 0 and v.shape[0] > 1:  # perCh
                tb.add_histogram(f"{mod_name}/{k}", v, Niter)
            else:
                tb.add_scalar(f"{mod_name}/{k}", v, Niter)
    # activation-quantizer params and lr
    for p, pname in zip(optim_a.param_groups[0]["params"], param_names[1]):
        tb.add_scalar(f"{mod_name}/{pname}", p, Niter)
    tb.add_scalar(f"{mod_name}/LR_cv_a", optim_a.param_groups[0]["lr"], Niter)
    # weights and lr
    for p, pname in zip(optim_w.param_groups[0]["params"], param_names[0]):
        tb.add_histogram(f"{mod_name}/{pname}", p, Niter)
    tb.add_scalar(f"{mod_name}/LR_w", optim_w.param_groups[0]["lr"], Niter)
    # weight-quantizer params and lr
    for p, pname in zip(optim_w.param_groups[1]["params"], param_names[2]):
        if "alpha" in pname:
            tb.add_histogram(f"{mod_name}/{pname}", p, Niter)
        else:
            tb.add_scalar(f"{mod_name}/{pname}", p, Niter)
    tb.add_scalar(f"{mod_name}/LR_cvw", optim_w.param_groups[1]["lr"], Niter)
    if "adaround" in qcfg["qw_mode"]:
        curr_beta = loss_func.beta_decay(loss_func.count)
        tb.add_scalar(f"{mod_name}/AdaR_beta", curr_beta, Niter)
    for lidx, layer in enumerate(layers):
        if not hasattr(layer, "quantize_m1"):
            tb.add_histogram(f"{mod_name}/W{lidx}", layer.weight, Niter)
            if hasattr(layer.quantize_weight, "get_hard_targets"):
                nzs = torch.count_nonzero(
                    layer.quantize_weight.get_soft_targets()
                    - layer.quantize_weight.get_hard_targets()
                )
                tb.add_scalar(
                    f"{mod_name}/W{lidx}_AdaR_nonzeros(soft-hard)%",
                    nzs / layer.weight.numel(),
                    Niter,
                )


def ptq_mod_optim(m, layers, qcfg, optim_mode="both", **kwargs):
    """
    Block-wise PTQ optimization mainly for vision models.

    Trains the quantizer parameters (and, unless frozen, weights/biases) of
    ``layers`` so that module ``m`` fed with cached inputs reproduces the cached
    FP outputs.

    Args:
        m: the module/block being optimized.
        layers: layers inside ``m`` whose parameters are collected for training.
        qcfg: config/cache dict. Reads "cached_input" (or "cached_input0"/"1"),
            "cached_output", "cached_lbls", "cached_qinput" (QDrop), "ptq_nbatch",
            "ptq_batchsize", "ptq_nouterloop", "ptq_ninnerloop", "ptq_lrw",
            "ptq_lrcv_w", "ptq_lrcv_a", "pact_a_decay", "ptq_coslr", "ptq_qdrop",
            "ptq_loss_func", "tb_writer", "qw_mode", "PTQ_freezeweight",
            "PTQ_freezebias"; resets "cached_qout" and "cached_grad_out".
        optim_mode: "Wonly", "Aonly", or anything else for both.
        **kwargs: must provide "curr_dev", "loss_func", "mod_name" and
            "Nsteps2acc" (iterations to accumulate gradients before stepping).

    Returns:
        The scalar loss of the last iteration, or None if no iteration ran.
    """
    curr_dev = kwargs["curr_dev"]
    loss_func = kwargs["loss_func"]
    mod_name = kwargs["mod_name"]
    Nsteps2acc = kwargs["Nsteps2acc"]
    # NOTE(review): qcfg is indexed like a dict everywhere, and hasattr() does not
    # look at dict keys, so for a plain dict this is always False and the
    # transformers branches below are never taken. Kept as-is; confirm the intended
    # check (e.g. `"cached_input1" in qcfg`) before changing it.
    isTransformers = hasattr(qcfg, "cached_input1")
    # we may want to special-handle the loss of the last FC (e.g. ce loss with labels)
    isLastFC = isinstance(m, nn.Linear) and mod_name == "fc"

    ws, cva, cvw, biases, param_names = _ptq_mod_optim_collect_params(
        layers, qcfg, optim_mode
    )
    optim_w = torch.optim.Adam(
        [
            # default lr was 1e-5 for AdaQuant, BRECQ didn't optimize weights
            {"params": ws, "lr": qcfg["ptq_lrw"]},
            {"params": cvw, "lr": qcfg["ptq_lrcv_w"]},  # 1e-3 for BRECQ
            {"params": biases, "lr": 1e-3},  # default is 1e-3 from AdaQuant
        ]
    )
    # separate w and a optimizers as in QDROP
    optim_a = torch.optim.Adam(
        [
            {
                "params": cva,
                "lr": qcfg["ptq_lrcv_a"],
                "weight_decay": qcfg["pact_a_decay"],
            }
        ]
    )  # lr was 1e-1 or 1e-3 in AdaQuant, 4e-5 for BRECQ
    scheduler = []
    if "W" in qcfg["ptq_coslr"]:
        scheduler.append(
            torch.optim.lr_scheduler.CosineAnnealingLR(
                optim_w, T_max=qcfg["ptq_nouterloop"], eta_min=0.0
            )
        )
    if "A" in qcfg["ptq_coslr"]:
        scheduler.append(
            torch.optim.lr_scheduler.CosineAnnealingLR(
                optim_a, T_max=qcfg["ptq_nouterloop"], eta_min=0.0
            )
        )

    # fine-grained shuffling: a fresh permutation of all cached samples per pass,
    # reshaped into rows of ptq_batchsize indices (one row per outer iteration)
    data_seq = [
        torch.randperm(qcfg["ptq_nbatch"] * qcfg["ptq_batchsize"])
        for _ in range(qcfg["ptq_nouterloop"] // qcfg["ptq_nbatch"] + 1)
    ]
    data_seq = torch.cat(data_seq).reshape([-1, qcfg["ptq_batchsize"]])
    pbar_desc = f"Phase 2.2: PTQ optimizing module {mod_name}. loss="
    pbar2 = tqdm(
        data_seq[: qcfg["ptq_nouterloop"]],
        desc=pbar_desc + " ",
        leave=False,
        mininterval=5,
    )
    total_iters = qcfg["ptq_nouterloop"] * qcfg["ptq_ninnerloop"]
    PTQloss = None

    # prep for grad accum
    optim_w.zero_grad()
    optim_a.zero_grad()
    for i_outer, data_idx in enumerate(pbar2):
        # --- fetch the cached data for this shuffled mini-batch
        if isTransformers:  # special handling for transformers
            inp = (
                qcfg["cached_input0"][data_idx].to(curr_dev),
                qcfg["cached_input1"][data_idx].to(curr_dev),
            )
        else:
            inp = qcfg["cached_input"][data_idx].to(curr_dev)
        fp_out = qcfg["cached_output"][data_idx].to(curr_dev)
        # gt is only used for the last FC layer, e.g. to compute a ce loss
        gt = qcfg["cached_lbls"][data_idx].to(curr_dev) if isLastFC else None
        if qcfg["ptq_qdrop"]:
            # QDrop: per-element 50/50 mask choosing between quantized and fp input
            dropout_mask_in = torch.bernoulli(
                torch.full_like(inp, 0.5)
            ).bool()  # FIXME use variable dropout rate as in NWQ?

        for j_inner in range(qcfg["ptq_ninnerloop"]):
            grad = None  # never filled in here; only "fisher_*" losses would read it
            qcfg["cached_qout"] = []
            qcfg["cached_grad_out"] = []
            Niter = i_outer * qcfg["ptq_ninnerloop"] + j_inner
            if isTransformers:
                with patch_torch_bmm(qcfg):
                    q_out = m(*inp)
            else:
                q_out = m(inp)
            if qcfg["ptq_qdrop"]:
                # With QDrop, "inp" is the all-fp32 input (previous modules set to
                # fp32_out) and "qinp" the input with all previous modules quantized.
                # NOTE(review): the forward above is overwritten here; it is kept in
                # case the forward pass has side effects.
                qinp = qcfg["cached_qinput"][data_idx].to(curr_dev)
                mixed_inp = torch.where(dropout_mask_in, qinp, inp)
                q_out = m(mixed_inp)

            if isTransformers:
                PTQloss = loss_func(q_out[0], fp_out, grad)
                logger.info(f"Loss is {PTQloss}")
            else:
                # only "fisher_diag" needs grad, only "ce" losses need gt
                PTQloss = loss_func(q_out, fp_out, grad=grad, gt=gt)

            # Some losses return a dict of terms (e.g. mse+ssim, adaround*): plot each
            # tensor term and optimize "total", or the sum of the terms.
            loss4plot = {}
            if isinstance(PTQloss, dict):
                loss4plot = {
                    k: v.item()
                    for k, v in PTQloss.items()
                    if isinstance(v, torch.Tensor)
                }
                if "total" in PTQloss:
                    PTQloss = PTQloss["total"]
                else:
                    # torch.sum() does not accept dict values; add the tensor terms
                    PTQloss = sum(
                        v for v in PTQloss.values() if isinstance(v, torch.Tensor)
                    )
            else:  # single term: plot it under the loss function's name
                loss4plot[qcfg["ptq_loss_func"]] = PTQloss.item()
            PTQloss.backward()

            # step after every Nsteps2acc accumulated iterations and on the very
            # last iteration (counted over outer *and* inner loops)
            if (Niter + 1) % Nsteps2acc == 0 or Niter + 1 == total_iters:
                optim_w.step()
                optim_w.zero_grad()
                optim_a.step()
                optim_a.zero_grad()

            # --- tensorboard output
            if qcfg["tb_writer"] and (
                (qcfg["ptq_ninnerloop"] == 1 and Niter % 100 == 0)
                or (qcfg["ptq_ninnerloop"] > 1 and i_outer % 10 == 0)
            ):
                pbar2.set_description(pbar_desc + f"{PTQloss:.3f}")
                _ptq_mod_optim_log_tb(
                    qcfg,
                    m,
                    layers,
                    mod_name,
                    Niter,
                    loss4plot,
                    optim_w,
                    optim_a,
                    param_names,
                    loss_func,
                )
        # schedulers are set up based on Nouterloop, not inner
        for s in scheduler:
            s.step()

    # Once finished optimizing this module, switch all AdaR quantizers (if any)
    # to hard rounding (soft_targets = False)
    if "adaround" in qcfg["qw_mode"]:
        for layer in layers:
            if not hasattr(layer, "quantize_m1") and hasattr(
                layer.quantize_weight, "soft_targets"
            ):
                layer.quantize_weight.soft_targets = False
    return PTQloss
def calib_ptq_bn_tune(
qcfg, model, loader, PTQmod_candidates, batch_size, pre_cache_func=None
):
# from detectron2.utils.events import EventStorage
# from detectron2.layers import FrozenBatchNorm2d
# --- Prep --- set up calib, PTQ post-fwd-hooks, can set up block-wise optimization as well
if qcfg["PTQ_fold_BN"]:
mods_folded = []
search_fold_and_remove_bn(model, mods_folded)
logger.info(f"--- Quantized model after BN folding--- \n {model}\n")
else:
BNmods = [
m
for k, m in model.named_modules()
if isinstance(m, nn.BatchNorm2d) or "norm" in k
]
# re-init alpha and delta for all adaR in case any changes in weights
# weight changes could be due to a) bn folding or b) load pre-trained after qmodel_prep
logger.info(
" --- check and re-initialize AdaRound delta and alpha for all layers in PTQmod_candidates"
)
for m in PTQmod_candidates:
# all the sub-modules, including quantizers, and m itself, will be included in m.modules()
for sm in m.modules():
if isinstance(sm, (QConv2dPTQv2, QLinear)) and "adaround" in sm.qw_mode:
sm.quantize_weight.init_delta(sm.weight, sm.qw_mode)
sm.quantize_weight.init_alpha(sm.weight)
curr_dev = next(model.parameters()).device
if qcfg["PTQ_keepBNfrozenDuringOptim"]:
model.eval()
torch.set_grad_enabled(False)
# --- Phase 0 --- cache images
stratified_loader = False
loader_len = (
len(loader.dataset.dataset.dataset)
if "detectron2.modeling" in sys.modules
else len(loader)
) # NOTE detectron2 needs special handling
if (
qcfg["ptq_nbatch"] > 0 and loader_len < batch_size * qcfg["ptq_nbatch"]
): # if we need more than what loader has -> cache all
# NOTE original training set has ~117000 images, stratified subset
# will be a little larger than ptq_nbatch
stratified_loader = True
qcfg["ptq_nbatch"] = loader_len # cache all images in the loader
Nbatch_to_cache = qcfg["ptq_nbatch"]
else:
Nbatch_to_cache = max(
qcfg["ptq_nbatch"], qcfg["qmodel_calibration_new"] + qcfg["BN_tune"]
)
# cache images (to be placed on CPU mem)
qcfg["cached_imgs"] = []
qcfg["cached_lbls"] = []
pbar = tqdm(
loader, desc="Phase 0: PTQ caching images from loader", total=Nbatch_to_cache
)
for data_mb, _ in zip(pbar, range(Nbatch_to_cache)):
if pre_cache_func is not None:
imgs = pre_cache_func(data_mb)
qcfg["cached_imgs"].append(imgs)
else:
imgs, lbls = data_mb # unpack (imgs, lbls)
qcfg["cached_imgs"].append(imgs)
qcfg["cached_lbls"].append(lbls)
Nimgs_per_batch = len(qcfg["cached_imgs"][0])
# --- prep for fine-grained shuffling, cat [tensor(NCHW), tensor(NCHW), ...]
# into a single tensor(Nbatch,NCHW)only works for same size imgs, e.g. ImgNet !!
qcfg["cached_imgs"] = torch.stack(qcfg["cached_imgs"])
# using torch.stack, final shape for cached_imgs = [num_batch, batchsize, C, H, W]
# but in PTQoptim, "cached_input" will be 1) torch.cat into
# [num_batch*batchsize, C,H,W] then 2) shuffled
qcfg["cached_lbls"] = torch.cat(
qcfg["cached_lbls"]
) # easier if we just torch.cat lables into shape of [num_batch*batchsize]
# --- Phase 1 --- calibration of clip vals
if qcfg["qmodel_calibration_new"] > 0:
if not qcfg["PTQ_fold_BN"]:
BNmean = [m.running_mean.mean() for m in BNmods]
logger.info(
"Before calibration, (BN running mean).abs().mean() =",
torch.stack(BNmean).abs().mean(),
)
# set all QConv2d.Qdynamic under model to training mode, so that they
# will calc and update clip_vals
update_train_or_ptq_mode(model, set_mod_state="train", class2change=Qdynamic)
# this func does the following things:
# 1) make a list of m and its children if they are instances of class2change
# 2) set those layers' ptqmode to the given mode, if 'PTQmod=xxx' is specified and is
# in ['fp32_out', 'q_out']
# 3) for each layer, set layer.train() or layer.eval() if set_mod_state is specified.
pbar = tqdm(
qcfg["cached_imgs"],
desc="Phase 1: calibration",
total=qcfg["qmodel_calibration_new"],
)
for data_mb, Niters in zip(pbar, range(qcfg["qmodel_calibration_new"])):
if isinstance(data_mb, torch.Tensor):
data_mb = data_mb.to(
curr_dev
) # usually detectron2 will move data to device for us
model(
data_mb
) # just fwd(), no need to save outputs. kwargs for yolo test is just augment=True
# record clipvals
cv_sum_table = {}
Qmods = {
k: m
for k, m in model.named_modules()
if isinstance(m, (QConv2dPTQv2, QLinear))
}
for modname, m in Qmods.items():
cv_sum_table[modname] = [
None,
None,
None,
None,
] # will store "cv_a, cvn_a, cv_w, cvn_w"
Qparams = {k: v for k, v in m.named_parameters() if "quantize_" in k}
for k, v in Qparams.items():
if "alpha" not in k:
var_name = k.split("quantize_")[1]
var_idx = ("weight" in var_name) * 2 + (
"clip_valn" in var_name or "zero_point" in var_name
)
cv_sum_table[modname][var_idx] = v.item()
qcfg["tb_writer"].add_scalar(f"{modname}/{var_name}", v, Niters)
else:
# special handle for adaround, delta and zp are buffers, not parameters,
# use mean() in case perCh
cv_sum_table[modname][2] = m.quantize_weight.delta.mean().item()
cv_sum_table[modname][3] = (
m.quantize_weight.zero_point.mean().item()
)
qcfg["tb_writer"].add_scalar(
f"{modname}/delta", m.quantize_weight.delta.mean(), Niters
)
pd.options.display.float_format = "{:.4f}".format
dfCV = pd.DataFrame(cv_sum_table).T
dfCV.columns = (
["cv_a", "cvn_a", "cv_w", "cvn_w"]
if qcfg["qw_mode"] != "adaround"
else ["cv_a", "cvn_a", "w_delta", "w_zp"]
)
logger.info(dfCV)
if not qcfg["PTQ_fold_BN"]:
BNmeanAfterCalib = {
k: m.running_mean.mean()
for k, m in model.named_modules()
if isinstance(m, (nn.BatchNorm2d,)) # FrozenBatchNorm2d,
}
logger.info(
f"After calibration {qcfg['qmodel_calibration_new']},"
"(BN running mean).abs().mean() =",
torch.stack(list(BNmeanAfterCalib.values())).abs().mean(),
)
# --- Phase 2 --- PTQ
if (
qcfg["ptq_nbatch"] > 0 and qcfg["ptq_nouterloop"] > 0
): # default Ninner = 1 if not specified
Nsteps2acc = max(
qcfg.get("PTQ_Nimgs2acc", Nimgs_per_batch) // Nimgs_per_batch, 1
)
loss_scaling_acc = 1.0 / Nsteps2acc
Ntotal_iters = qcfg["ptq_nouterloop"] * qcfg["ptq_ninnerloop"]
Nouter_new = math.ceil(
qcfg["ptq_nouterloop"] / np.lcm(Nsteps2acc, qcfg["ptq_nbatch"])
) * np.lcm(Nsteps2acc, qcfg["ptq_nbatch"])
if stratified_loader:
logger.info(
f"Using stratified dataloader, Nouterloop is adjusted from"
f"{qcfg['ptq_nouterloop']} to {Nouter_new}"
)
qcfg["ptq_nouterloop"] = Nouter_new
# in detectron2, only model.train() will output losses,
# otherwise only output predictions (instances)
if "fisher" in qcfg["ptq_loss_func"]: