-
Notifications
You must be signed in to change notification settings - Fork 992
Expand file tree
/
Copy pathpatterns.py
More file actions
1258 lines (993 loc) · 38.7 KB
/
patterns.py
File metadata and controls
1258 lines (993 loc) · 38.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2025-2026 NXP
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import torch
from executorch.backends.nxp.quantizer.utils import get_bias_qparams
from torch import fx
from torch._ops import OpOverload
from torch.fx import Node
from torchao.quantization.pt2e import (
FakeQuantize,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
)
from torchao.quantization.pt2e.quantizer import (
DerivedQuantizationSpec,
FixedQParamsQuantizationSpec,
QuantizationSpec,
SharedQuantizationSpec,
)
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
@dataclass
class NodeArgsIdx:
"""
Specifies indexes to args paramater of Node in node input annotation.
Attributes:
idx (int): Index to Node's args paramater (list). Selects an input Node or a list of Nodes at the index.
inner_idx (int): If specified, index to a list pointed by 'idx' attribute. Selects an input Node at the index.
Default: None.
"""
idx: int
inner_idx: int = None
@dataclass
class PartitionAnchors:
    """
    Describes the values of a matched partition that receive quantization annotations.

    All fields except `output` (and the `empty` flag) are lists of `(node, node_args_idx)` or
    `(node, node_args_idx, quantization_spec)` tuples. Here `node` is from the given partition and
    `node.args[idx]` (or `node.args[idx][inner_idx]` when `inner_idx` is set) is an input to the
    partition. The optional third element carries an explicit quantization spec for that input.
    A single output is assumed.

    The quantizer uses `inputs`, `weights` and `biases` for quantization annotation. The `others`
    field contains tensor inputs that are not quantized. The `literals` field is used for other
    types of input values, and also for handling default parameters.
    """

    # Inputs can share quantization parameters
    inputs: list[
        tuple[fx.Node, NodeArgsIdx]
        | tuple[fx.Node, NodeArgsIdx, SharedQuantizationSpec],
    ] = field(default_factory=list)
    weights: list[
        tuple[fx.Node, NodeArgsIdx]
        | tuple[fx.Node, NodeArgsIdx, QuantizationSpec | FakeQuantize],
    ] = field(default_factory=list)
    biases: list[
        tuple[fx.Node, NodeArgsIdx]
        | tuple[fx.Node, NodeArgsIdx, DerivedQuantizationSpec],
    ] = field(default_factory=list)
    others: list[tuple[fx.Node, NodeArgsIdx]] = field(default_factory=list)
    literals: list[tuple[fx.Node, NodeArgsIdx]] = field(default_factory=list)
    # `(node,)` or `(node, spec)` entries describing the partition's output.
    output: list[
        tuple[fx.Node]
        | tuple[
            fx.Node,
            FixedQParamsQuantizationSpec | SharedQuantizationSpec,
        ],
    ] = field(default_factory=list)
    # NOTE(review): not read anywhere in this file — presumably lets a pattern signal "nothing
    # to annotate"; confirm against the quantizer that consumes these anchors.
    empty: bool = False
class QuantizationPattern(ABC):
    """
    Abstract base for the quantization patterns in this module.

    A pattern names the aten ops it matches (`partition_types`) and, for a matched partition,
    reports which node arguments and outputs should be annotated (`get_anchors`).
    """

    def __init__(self, is_qat: bool = False):
        # True when the pattern is used for quantization-aware training.
        self.is_qat = is_qat

    @abstractmethod
    def partition_types(self) -> list[OpOverload]:
        """
        List of types to be passed to find_sequential_partitions_aten.
        """

    @abstractmethod
    def get_anchors(
        self, gm: torch.fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors | None:
        """Return the anchors of `fused_partition`, or None when it cannot be annotated."""
class SharedSpecPattern(QuantizationPattern):
    """
    Quantization pattern for ops whose output reuses the quantization of their single input.

    The input and the output share the same quantization parameters (scale and zero-point), which
    are taken from the already-annotated producer of the input.
    """

    @abstractmethod
    def partition_types(self) -> list[torch.nn.Module]:
        pass

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors | None:
        partition = fused_partition[0]
        op_node = partition.nodes[-1]
        assert len(partition.input_nodes) == 1
        producer = partition.input_nodes[0]

        # An unannotated producer has no q-params that could be shared.
        if Q_ANNOTATION_KEY not in producer.meta:
            return None

        shared_spec = SharedQuantizationSpec(producer)
        return PartitionAnchors(
            inputs=[(op_node, NodeArgsIdx(0))],
            weights=[],
            biases=[],
            output=[(op_node, shared_spec)],
        )
class SingleInputBasicPattern(QuantizationPattern):
    """
    Quantization pattern for single-input ops with basic quantization.

    The first argument and the output of the matched node are annotated without explicit specs.
    """

    @abstractmethod
    def partition_types(self) -> list[OpOverload]:
        pass

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors | None:
        target = fused_partition[0].nodes[-1]
        first_arg = (target, NodeArgsIdx(0))
        return PartitionAnchors(
            inputs=[first_arg],
            weights=[],
            biases=[],
            output=[(target,)],
        )
class BatchNormPattern(QuantizationPattern):
    """
    Quantization pattern for `aten.batch_norm`, active only in QAT mode.

    Only the output of the batch norm node is annotated.
    """

    def __init__(self, is_qat: bool):
        super().__init__(is_qat=is_qat)

    def partition_types(self) -> list[OpOverload]:
        # BatchNorm quantization is needed only in QAT mode; otherwise nothing is matched.
        if not self.is_qat:
            return []
        return [torch.ops.aten.batch_norm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors | None:
        bn_node = fused_partition[0].nodes[-1]
        return PartitionAnchors(inputs=[], weights=[], biases=[], output=[(bn_node,)])
def get_anchors_for_fixed_quant_specs(
    fused_partition: list[fx.GraphModule],
    scale: float,
    zero_point: int,
    quant_min: int = -128,
    quant_max: int = 127,
    is_qat: bool = False,
) -> PartitionAnchors:
    """
    Build anchors for a single-input op whose output has fixed quantization parameters.

    Args:
        fused_partition: Matched partitions; the last node of the first one is annotated.
        scale: Fixed output scale.
        zero_point: Fixed output zero point.
        quant_min: Lower bound of the int8 output range.
        quant_max: Upper bound of the int8 output range.
        is_qat: Accepted for API symmetry.

    Returns:
        Anchors annotating argument 0 with default quantization and the output with a fixed,
        per-tensor affine int8 spec.
    """
    partition = fused_partition[0]
    target_node = partition.nodes[-1]
    assert len(partition.input_nodes) == 1

    fixed_output_spec = FixedQParamsQuantizationSpec(
        dtype=torch.int8,
        qscheme=torch.per_tensor_affine,
        scale=scale,
        zero_point=zero_point,
        quant_min=quant_min,
        quant_max=quant_max,
    )
    # NOTE(review): `is_qat` is not used — the same fixed spec is produced for PTQ and QAT.
    # Confirm this is intended.
    return PartitionAnchors(
        inputs=[(target_node, NodeArgsIdx(0))],
        weights=[],
        biases=[],
        output=[(target_node, fixed_output_spec)],
    )
class AbsPattern(SharedSpecPattern):
    """Shared-spec quantization pattern for the `aten.abs` operator."""

    def partition_types(self):
        abs_op = torch.ops.aten.abs.default
        return [abs_op]
class AdaptiveAvgPoolPattern(SharedSpecPattern):
    """Shared-spec quantization pattern for the `aten.adaptive_avg_pool2d` operator."""

    def partition_types(self):
        pool_op = torch.ops.aten.adaptive_avg_pool2d.default
        return [pool_op]
class AddmmPattern(QuantizationPattern):
    """
    Quantization pattern for `aten.addmm` (`bias + mat1 @ mat2`).

    `mat1` (arg 1) is annotated as the input and `mat2` (arg 2) as the weight. The `bias` (arg 0)
    gets an int32 spec derived from the input and weight q-params. If the only user of the addmm
    is an activation that Neutron supports as fused, that activation is annotated and the addmm
    output is left without an output anchor.
    """

    def __init__(self, neutron_quantizer, is_qat: bool):
        super().__init__(is_qat=is_qat)
        self.neutron_quantizer = neutron_quantizer
        target_spec = self.neutron_quantizer.neutron_target_spec
        self.neutron_target_info = target_spec.neutron_target_info

    def partition_types(self) -> list[OpOverload]:
        return [torch.ops.aten.addmm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors:
        addmm_node = fused_partition[0].nodes[-1]
        mat1, mat2 = addmm_node.args[1], addmm_node.args[2]

        bias_qspec = DerivedQuantizationSpec(
            derived_from=[(mat1, addmm_node), (mat2, addmm_node)],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )

        users = addmm_node.users
        follower = next(iter(users)) if len(users) == 1 else None
        if (
            follower is not None
            and self.neutron_target_info.is_supported_fused_activation__aten(follower)
        ):
            # Quantize together with the activation: its own quantizer annotates it, and its
            # input spec is cleared so nothing is quantized between addmm and activation.
            self.neutron_quantizer.op_to_quantizer[follower.target].annotate(gm)
            output = []
            follower.meta["quantization_annotation"].input_qspec_map = {}
        else:
            output = [(addmm_node,)]

        return PartitionAnchors(
            inputs=[(addmm_node, NodeArgsIdx(1))],
            weights=[(addmm_node, NodeArgsIdx(2))],
            biases=[(addmm_node, NodeArgsIdx(0), bias_qspec)],
            output=output,
        )
class AddTensorPattern(QuantizationPattern):
    """
    Quantization pattern for `aten.add.Tensor`. Accepts 1 or 2 input nodes.

    Every annotated operand and the output receive basic quantization.
    """

    def partition_types(self) -> list[torch.nn.Module]:
        return [torch.ops.aten.add.Tensor]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors | None:
        partition = fused_partition[0]
        add_node = partition.nodes[-1]

        # Both operands are annotated only when the partition has two input nodes.
        if len(partition.input_nodes) == 2:
            operand_indices = (0, 1)
        else:
            operand_indices = (0,)
        inputs = [(add_node, NodeArgsIdx(i)) for i in operand_indices]

        return PartitionAnchors(
            inputs=inputs,
            weights=[],
            biases=[],
            output=[(add_node,)],
        )
class BMMPattern(QuantizationPattern):
    """
    Quantization pattern for the batched matrix multiplication `aten.bmm`.

    Both operands and the output receive basic quantization.
    """

    def partition_types(self) -> list[torch.nn.Module]:
        return [torch.ops.aten.bmm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors | None:
        bmm_node = fused_partition[0].nodes[-1]
        operands = [(bmm_node, NodeArgsIdx(i)) for i in (0, 1)]
        return PartitionAnchors(inputs=operands, biases=[], output=[(bmm_node,)])
class SubTensorPattern(QuantizationPattern):
    """
    Quantization pattern for `aten.sub.Tensor`. Accepts 1 or 2 input nodes.

    Every annotated operand and the output receive basic quantization.
    """

    def partition_types(self) -> list[torch.nn.Module]:
        return [torch.ops.aten.sub.Tensor]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors | None:
        partition = fused_partition[0]
        sub_node = partition.nodes[-1]
        # With two input nodes both operands are annotated, otherwise only the first one.
        operand_count = 2 if len(partition.input_nodes) == 2 else 1
        return PartitionAnchors(
            inputs=[(sub_node, NodeArgsIdx(i)) for i in range(operand_count)],
            weights=[],
            biases=[],
            output=[(sub_node,)],
        )
class AvgPool1DPattern(SharedSpecPattern):
    """Shared-spec quantization pattern for the `aten.avg_pool1d` operator."""

    def partition_types(self):
        pool_op = torch.ops.aten.avg_pool1d.default
        return [pool_op]
class AvgPool2DPattern(SharedSpecPattern):
    """Shared-spec quantization pattern for the `aten.avg_pool2d` operator."""

    def partition_types(self):
        pool_op = torch.ops.aten.avg_pool2d.default
        return [pool_op]
class CatPattern(QuantizationPattern):
    """
    Quantization pattern for `aten.cat`, designed for the `NeutronAtenQuantizer`.

    The node may have any number of inputs, and all of them are annotated. When at least one input
    is already annotated, the first such input's q-params are shared by all inputs and the output.
    """

    def partition_types(self) -> list[OpOverload]:
        return [torch.ops.aten.cat.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors | None:
        cat_node = fused_partition[0].nodes[-1]
        tensors = cat_node.args[0]
        indices = range(len(tensors))

        share_from = next(
            (t for t in tensors if "quantization_annotation" in t.meta), None
        )

        if share_from is None:
            # No input was annotated, so there are no q-params to share. The conversion to IR
            # will have to re-quantize the inputs if necessary.
            inputs = [(cat_node, NodeArgsIdx(0, i)) for i in indices]
            outputs = [(cat_node,)]
        else:
            inputs = [
                (cat_node, NodeArgsIdx(0, i), SharedQuantizationSpec(share_from))
                for i in indices
            ]
            outputs = [(cat_node, SharedQuantizationSpec(share_from))]

        return PartitionAnchors(
            inputs=inputs,
            weights=[],
            biases=[],
            output=outputs,
        )
class ClampPattern(SingleInputBasicPattern):
    """Basic single-input quantization pattern for the `aten.clamp` operator."""

    def partition_types(self):
        clamp_op = torch.ops.aten.clamp.default
        return [clamp_op]
def _is_batch_norm(node_: Node) -> bool:
return node_.op == "call_function" and node_.target in [
torch.ops.aten.batch_norm.default,
torch.ops.aten.native_batch_norm.default,
torch.ops.aten._native_batch_norm_legit_no_training.default,
]
class ConvPattern(QuantizationPattern):
    """
    Base quantization pattern for convolution operators.

    Anchors:
      - the input (arg 0) gets basic quantization;
      - the weight (arg 1) is int8, per-channel symmetric along axis 0, and fake-quantized in
        QAT mode;
      - the bias (arg 2), when present and not None, is int32, derived from the input and weight
        q-params.
    In QAT mode the output is left unannotated when its only user is a batch norm.
    """

    @abstractmethod
    def partition_types(self) -> list[OpOverload]:
        pass

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors:
        conv_node = fused_partition[0].nodes[-1]
        conv_args = conv_node.args

        bias_quantization_qspec = DerivedQuantizationSpec(
            derived_from=[(conv_args[0], conv_node), (conv_args[1], conv_node)],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31) + 1,
            quant_max=2**31 - 1,
            qscheme=torch.per_channel_symmetric,
            ch_axis=0,
        )

        if self.is_qat:
            weight_ctr = FakeQuantize.with_args(
                observer=MovingAveragePerChannelMinMaxObserver
            )
        else:
            weight_ctr = PerChannelMinMaxObserver
        weight_quantization_spec = QuantizationSpec(
            dtype=torch.int8,
            observer_or_fake_quant_ctr=weight_ctr,
            quant_min=-127,
            quant_max=127,
            qscheme=torch.per_channel_symmetric,
            ch_axis=0,
        )

        # A missing or None bias gets no anchor.
        has_bias = len(conv_args) > 2 and conv_args[2] is not None
        bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)] if has_bias else []

        # For QAT to be numerically correct, nothing may be quantized between the convolution
        # and a batch norm that directly follows it.
        users = conv_node.users
        output_specs = [(conv_node,)]
        if self.is_qat and len(users) == 1 and _is_batch_norm(next(iter(users))):
            output_specs = []

        return PartitionAnchors(
            inputs=[(conv_node, NodeArgsIdx(0))],
            weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
            biases=bias,
            output=output_specs,
        )
class Conv1dPattern(ConvPattern):
    """Convolution quantization pattern for `aten.conv1d`."""

    def partition_types(self) -> list[OpOverload]:
        conv_op = torch.ops.aten.conv1d.default
        return [conv_op]
class ConvTranspose1dPattern(ConvPattern):
    """Convolution quantization pattern for `aten.conv_transpose1d`."""

    def partition_types(self) -> list[OpOverload]:
        conv_op = torch.ops.aten.conv_transpose1d.default
        return [conv_op]
class Conv2dPattern(ConvPattern):
    """
    Quantization pattern for `aten.conv2d`.

    Anchors the input (arg 0), the per-channel int8 weight (arg 1) and the optional int32 bias
    (arg 2). Two kinds of following activation can be fused:
      - an activation Neutron supports as fused that directly follows the convolution;
      - in QAT mode, such an activation that follows a batch norm after the convolution.
    When fused, the activation's quantizer annotates the graph and the convolution output is left
    unannotated, so both are quantized together.
    """

    def __init__(self, neutron_quantizer, is_qat: bool = False):
        super().__init__(is_qat=is_qat)
        self.neutron_quantizer = neutron_quantizer
        self.neutron_target_info = (
            self.neutron_quantizer.neutron_target_spec.neutron_target_info
        )

    def partition_types(self) -> list[OpOverload]:
        return [torch.ops.aten.conv2d.default]

    def _find_fusable_activation(self, conv_node: Node) -> Node | None:
        """
        Return the activation to quantize together with `conv_node`, or None.

        Matches `conv -> activation`, and in QAT mode also `conv -> batch_norm -> activation`.
        Each intermediate node must have exactly one user. Otherwise other consumers would see a
        tensor whose quantization was removed by the fusion. A batch norm without users also
        previously made `next(iter(...))` raise StopIteration.
        """
        if len(conv_node.users) != 1:
            return None
        successor = next(iter(conv_node.users))
        if self.neutron_target_info.is_supported_fused_activation__aten(successor):
            return successor
        if self.is_qat and _is_batch_norm(successor) and len(successor.users) == 1:
            candidate = next(iter(successor.users))
            if self.neutron_target_info.is_supported_fused_activation__aten(candidate):
                return candidate
        return None

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors:
        conv_node = fused_partition[0].nodes[-1]
        bias_quantization_qspec = DerivedQuantizationSpec(
            derived_from=[
                (conv_node.args[0], conv_node),
                (conv_node.args[1], conv_node),
            ],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31) + 1,
            quant_max=2**31 - 1,
            qscheme=torch.per_channel_symmetric,
            ch_axis=0,
        )

        weight_observer_or_fake_quant_ctr = (
            FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver)
            if self.is_qat
            else PerChannelMinMaxObserver
        )
        weight_quantization_spec = QuantizationSpec(
            dtype=torch.int8,
            observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr,
            quant_min=-127,
            quant_max=127,
            qscheme=torch.per_channel_symmetric,
            ch_axis=0,
        )

        # Keep bias empty if not supplied
        bias = []
        if len(conv_node.args) > 2 and conv_node.args[2] is not None:
            bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)]

        # If the following node is a fusable activation, quantize together with activation
        output = [(conv_node,)]
        activation = self._find_fusable_activation(conv_node)
        if activation is not None:
            activation_quantizer = self.neutron_quantizer.op_to_quantizer[
                activation.target
            ]
            activation_quantizer.annotate(gm)
            output = []
            activation.meta["quantization_annotation"].input_qspec_map = {}

            # When fused through a batch norm (QAT), the batch norm must quantize neither its
            # input nor its output.
            bn = next(iter(conv_node.users))
            if _is_batch_norm(bn):
                bn_quantizer = self.neutron_quantizer.op_to_quantizer[bn.target]
                bn_quantizer.annotate(gm)
                bn.meta["quantization_annotation"].input_qspec_map = {}
                bn.meta["quantization_annotation"].output_qspec = None

        # In order for QAT to be numerically correct, there should be no quantization between
        # convolution node and batch norm node.
        if self.is_qat:
            conv_users = conv_node.users
            possibly_bn = list(conv_users.keys())[0] if len(conv_users) == 1 else None
            if possibly_bn and _is_batch_norm(possibly_bn):
                output = []

        return PartitionAnchors(
            inputs=[(conv_node, NodeArgsIdx(0))],
            weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
            biases=bias,
            output=output,
        )
class ConvTranspose2dPattern(QuantizationPattern):
    """
    Quantization pattern for `aten.conv_transpose2d`.

    Anchors:
      - the input (arg 0) gets basic quantization;
      - the weight (arg 1) is int8, per-channel symmetric along axis 1;
      - the bias (arg 2), when present and not None, is int32, derived from the input and weight
        q-params.
    In QAT mode the weights are fake-quantized, and the output is left unannotated when its only
    user is a batch norm.
    """

    def partition_types(self) -> list[OpOverload]:
        return [torch.ops.aten.conv_transpose2d.input]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors:
        conv_node = fused_partition[0].nodes[-1]
        bias_quantization_qspec = DerivedQuantizationSpec(
            derived_from=[
                (conv_node.args[0], conv_node),
                (conv_node.args[1], conv_node),
            ],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31) + 1,
            quant_max=2**31 - 1,
            qscheme=torch.per_channel_symmetric,
            ch_axis=0,
        )

        # In QAT mode the weights must be fake-quantized (as in `ConvPattern`), so that training
        # sees the quantization error. A plain observer only records statistics.
        weight_observer_or_fake_quant_ctr = (
            FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver)
            if self.is_qat
            else PerChannelMinMaxObserver
        )
        # Per-channel along axis 1: PyTorch lays out transposed-convolution weights as
        # (in_channels, out_channels / groups, kH, kW).
        weight_quantization_spec = QuantizationSpec(
            dtype=torch.int8,
            observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr,
            quant_min=-127,
            quant_max=127,
            qscheme=torch.per_channel_symmetric,
            ch_axis=1,
        )

        # Keep bias empty if not supplied
        bias = []
        if len(conv_node.args) > 2 and conv_node.args[2] is not None:
            bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)]

        output_specs = [(conv_node,)]
        # In order for QAT to be numerically correct, there should be no quantization between
        # convolution node and batch norm node.
        if self.is_qat:
            conv_users = conv_node.users
            possibly_bn = list(conv_users.keys())[0] if len(conv_users) == 1 else None
            if possibly_bn and _is_batch_norm(possibly_bn):
                output_specs = []

        return PartitionAnchors(
            inputs=[(conv_node, NodeArgsIdx(0))],
            weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
            biases=bias,
            output=output_specs,
        )
class DropoutPattern(SharedSpecPattern):
    """Shared-spec quantization pattern for the `aten.dropout` operator."""

    def partition_types(self):
        dropout_op = torch.ops.aten.dropout.default
        return [dropout_op]
class FlattenPattern(SharedSpecPattern):
    """Shared-spec quantization pattern for the `aten.flatten.using_ints` operator."""

    def partition_types(self):
        flatten_op = torch.ops.aten.flatten.using_ints
        return [flatten_op]
class HardTanhPattern(SingleInputBasicPattern):
    """Basic single-input quantization pattern for the `aten.hardtanh` operator."""

    def partition_types(self):
        hardtanh_op = torch.ops.aten.hardtanh.default
        return [hardtanh_op]

    def replacement_op(self):
        # Always raises; no replacement op is provided by this pattern.
        raise AssertionError()
class HardTanhInPlacePattern(SingleInputBasicPattern):
    """
    Quantizer for HardTanh operator with param inplace=True.

    The anchors come from `SingleInputBasicPattern.get_anchors`: argument 0 and the output,
    both without explicit specs.
    """

    def partition_types(self):
        return [torch.ops.aten.hardtanh_.default]

    def replacement_op(self):
        # Always raises; no replacement op is provided by this pattern.
        raise AssertionError()
class LeakyReluPattern(SingleInputBasicPattern):
    """Basic single-input quantization pattern for the `aten.leaky_relu` operator."""

    def partition_types(self):
        leaky_relu_op = torch.ops.aten.leaky_relu.default
        return [leaky_relu_op]
class LeakyReluInPlacePattern(SingleInputBasicPattern):
    """Basic single-input quantization pattern for the in-place `aten.leaky_relu_` operator."""

    def partition_types(self):
        leaky_relu_op = torch.ops.aten.leaky_relu_.default
        return [leaky_relu_op]
class LinearPattern(QuantizationPattern):
    """
    Quantization pattern for `aten.linear`.

    Annotates the input (arg 0), the weight (arg 1) and, when present and not None, the bias
    (arg 2) with an int32 spec derived from the input and weight q-params. Fusion with a following
    activation happens when all of these hold:
      - the linear node has a single user;
      - its output has rank <= 2;
      - that user is an activation Neutron supports as fused.
    In that case the activation's quantizer annotates the graph and the linear output is left
    unannotated. In QAT mode the output is also left unannotated when the only user is a batch norm.
    """

    def __init__(self, neutron_quantizer, is_qat: bool = False):
        super().__init__(is_qat=is_qat)
        self.neutron_quantizer = neutron_quantizer
        self.neutron_target_info = (
            self.neutron_quantizer.neutron_target_spec.neutron_target_info
        )

    def partition_types(self) -> list[OpOverload]:
        return [torch.ops.aten.linear.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors:
        linear_node = fused_partition[0].nodes[-1]

        bias_qspec = DerivedQuantizationSpec(
            derived_from=[
                (linear_node.args[0], linear_node),
                (linear_node.args[1], linear_node),
            ],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )

        # Keep bias empty if not supplied. `aten.linear` may also carry an explicit `None` bias,
        # which must not be annotated (the convolution patterns apply the same check).
        bias = []
        if len(linear_node.args) > 2 and linear_node.args[2] is not None:
            bias = [(linear_node, NodeArgsIdx(2), bias_qspec)]

        # If the following node is a fusable activation, quantize together with activation
        output = [(linear_node,)]
        if (
            len(linear_node.users) == 1
            and len(linear_node.meta["val"].shape) <= 2
            and self.neutron_target_info.is_supported_fused_activation__aten(
                activation := next(iter(linear_node.users))
            )
        ):
            activation_quantizer = self.neutron_quantizer.op_to_quantizer[
                activation.target
            ]
            activation_quantizer.annotate(gm)
            output = []
            activation.meta["quantization_annotation"].input_qspec_map = {}

        # In order for QAT to be numerically correct, there should be no quantization between
        # linear node and batch norm node.
        if self.is_qat:
            linear_users = linear_node.users
            possibly_bn = (
                list(linear_users.keys())[0] if len(linear_users) == 1 else None
            )
            if possibly_bn and _is_batch_norm(possibly_bn):
                output = []

        return PartitionAnchors(
            inputs=[(linear_node, NodeArgsIdx(0))],
            weights=[(linear_node, NodeArgsIdx(1))],
            biases=bias,
            output=output,
        )
class MaxPool1DPattern(SharedSpecPattern):
    """Shared-spec quantization pattern for the `aten.max_pool1d` operator."""

    def partition_types(self):
        pool_op = torch.ops.aten.max_pool1d.default
        return [pool_op]
class MaxPool2DPattern(SharedSpecPattern):
    """Shared-spec quantization pattern for the `aten.max_pool2d` operator."""

    def partition_types(self):
        pool_op = torch.ops.aten.max_pool2d.default
        return [pool_op]
class MeanDimPattern(SharedSpecPattern):
    """Shared-spec quantization pattern for the `aten.mean.dim` operator."""

    def partition_types(self):
        mean_op = torch.ops.aten.mean.dim
        return [mean_op]
class MmPattern(QuantizationPattern):
    """
    Quantization pattern for `aten.mm`.

    The first operand is annotated as the input and the second as the weight. If the only user is
    an activation that Neutron supports as fused, that activation is annotated and the mm output is
    left without an output anchor.
    """

    def __init__(self, neutron_quantizer, is_qat: bool = False):
        super().__init__(is_qat=is_qat)
        self.neutron_quantizer = neutron_quantizer
        target_spec = self.neutron_quantizer.neutron_target_spec
        self.neutron_target_info = target_spec.neutron_target_info

    def partition_types(self) -> list[OpOverload]:
        return [torch.ops.aten.mm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors:
        mm_node = fused_partition[0].nodes[-1]

        users = mm_node.users
        follower = next(iter(users)) if len(users) == 1 else None
        if (
            follower is not None
            and self.neutron_target_info.is_supported_fused_activation__aten(follower)
        ):
            # Quantize together with the activation: its quantizer annotates it and its input
            # spec is cleared, so nothing is quantized between mm and the activation.
            self.neutron_quantizer.op_to_quantizer[follower.target].annotate(gm)
            output = []
            follower.meta["quantization_annotation"].input_qspec_map = {}
        else:
            output = [(mm_node,)]

        return PartitionAnchors(
            inputs=[(mm_node, NodeArgsIdx(0))],
            weights=[(mm_node, NodeArgsIdx(1))],
            biases=[],
            output=output,
        )
class MulTensorPattern(QuantizationPattern):
    """
    Quantization pattern for `aten.mul.Tensor`. Accepts 1 or 2 input nodes.

    Tensor operands get a fixed int8 quantization (scale 1/256, zero point 0), which is also
    forced onto the outputs of already-annotated producers. A non-tensor operand (e.g. the
    scalar in `x * 2.0`) is not annotated. The output is annotated without an explicit spec.
    """

    def partition_types(self) -> list[torch.nn.Module]:
        return [torch.ops.aten.mul.Tensor]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors | None:
        node = fused_partition[0].nodes[-1]
        input_nodes = node.all_input_nodes

        qspec = FixedQParamsQuantizationSpec(
            dtype=torch.int8,
            scale=1.0 / 256.0,
            zero_point=0,
            quant_min=-128,
            quant_max=127,
            qscheme=torch.per_tensor_affine,
        )

        # The "Mul" operator in Neutron IR requires a specific scale and zero_point
        # (defined above) for its inputs.
        # Since these input nodes have already been annotated by their own patterns
        # which didn't take the requirements of "Mul" into account, we need to overwrite
        # the existing "quantization_annotation".
        for input_node in input_nodes:
            if "quantization_annotation" in input_node.meta:
                input_node.meta["quantization_annotation"].output_qspec = qspec

        # Only tensor (Node) operands can carry an input annotation.
        inputs = [
            (node, NodeArgsIdx(arg_idx), qspec)
            for arg_idx, arg in enumerate(node.args[:2])
            if isinstance(arg, Node)
        ]

        return PartitionAnchors(
            inputs=inputs,
            weights=[],
            biases=[],
            output=[
                (node,),
            ],
        )
class NegPattern(SharedSpecPattern):
    """Shared-spec quantization pattern for the `aten.neg` operator."""

    def partition_types(self):
        neg_op = torch.ops.aten.neg.default
        return [neg_op]
class PadPattern(SharedSpecPattern):
    """Shared-spec quantization pattern for the `aten.pad` operator."""

    def partition_types(self):
        pad_op = torch.ops.aten.pad.default
        return [pad_op]
class PermutePattern(SharedSpecPattern):
    """Shared-spec quantization pattern for the `aten.permute` operator."""

    def partition_types(self):
        permute_op = torch.ops.aten.permute.default
        return [permute_op]
class PReLUPattern(QuantizationPattern):
    """
    Quantization pattern for `aten.prelu`.

    Arg 0 is annotated as the input, arg 1 (the slope) as the weight, and the output with basic
    quantization.
    """

    def partition_types(self) -> list[OpOverload]:
        return [torch.ops.aten.prelu.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors:
        prelu_node = fused_partition[0].nodes[-1]
        return PartitionAnchors(
            inputs=[(prelu_node, NodeArgsIdx(0))],
            weights=[(prelu_node, NodeArgsIdx(1))],
            biases=[],
            output=[(prelu_node,)],
        )
class TransposeIntPattern(SharedSpecPattern):
    """Shared-spec quantization pattern for the `aten.transpose.int` operator."""

    def partition_types(self) -> list[OpOverload]:
        transpose_op = torch.ops.aten.transpose.int
        return [transpose_op]
class ReluPattern(SingleInputBasicPattern):
    """Basic single-input quantization pattern for the `aten.relu` operator."""

    def partition_types(self):
        relu_op = torch.ops.aten.relu.default
        return [relu_op]
class ReluInPlacePattern(SingleInputBasicPattern):
    """Basic single-input quantization pattern for the in-place `aten.relu_` operator."""

    def partition_types(self):
        relu_op = torch.ops.aten.relu_.default
        return [relu_op]
class ReshapePattern(SharedSpecPattern):
    """Shared-spec quantization pattern for the `aten.reshape` operator."""

    def partition_types(self):
        reshape_op = torch.ops.aten.reshape.default
        return [reshape_op]
class ViewPattern(SharedSpecPattern):
    """Shared-spec quantization pattern for the `aten.view` operator."""

    def partition_types(self):
        view_op = torch.ops.aten.view.default
        return [view_op]
class SliceTensorPattern(SharedSpecPattern):
    """Shared-spec quantization pattern for the `aten.slice.Tensor` operator."""

    def partition_types(self):
        slice_op = torch.ops.aten.slice.Tensor
        return [slice_op]
class SoftMaxPattern(QuantizationPattern):
"""
Quantizer for Softmax operator.
The quantization of Softmax output is fixed to scale 1/256, zero point -128, dtype int8.
"""
def partition_types(self) -> list[OpOverload]:
return [torch.ops.aten.softmax.int]