-
Notifications
You must be signed in to change notification settings - Fork 1k
Expand file tree
/
Copy pathpatterns.py
More file actions
913 lines (756 loc) · 30.3 KB
/
patterns.py
File metadata and controls
913 lines (756 loc) · 30.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Tuple, Union
import torch
from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams
from torch import fx
from torch._ops import OpOverload
from torchao.quantization.pt2e.quantizer import (
DerivedQuantizationSpec,
SharedQuantizationSpec,
)
@dataclass
class PartitionAnchors:
    """
    Describes which arguments of a matched partition play which role.

    All list fields except output hold (node, args_index) pairs, where node is from
    the given partition and node.args[args_index] is an input to the partition. Assumes
    a single output.
    Quantizer uses inputs, weights and biases for quantization annotation. The others
    field contains tensor inputs that aren't quantized, and the literals field
    is used for other types of input values as well as handling default parameters.
    """

    # Inputs can share quantization parameters
    # Each entry is (node, index) or (node, index, SharedQuantizationSpec). The
    # index is either a plain int or an (int, int) pair; CatPattern in this file
    # uses the pair form (0, i) for the i-th element of node.args[0].
    inputs: List[
        Union[
            Tuple[fx.Node, Union[int, Tuple[int, int]]],
            Tuple[
                fx.Node,
                Union[int, Tuple[int, int]],
                SharedQuantizationSpec,
            ],
        ]
    ] = field(default_factory=list)
    # (node, index) pairs for the weight arguments.
    weights: List[Tuple[fx.Node, int]] = field(default_factory=list)
    # (node, index) pairs for bias arguments, optionally carrying a qspec
    # derived from other edges.
    biases: List[
        Union[Tuple[fx.Node, int], Tuple[fx.Node, int, DerivedQuantizationSpec]]
    ] = field(default_factory=list)
    # Tensor inputs that are not quantized.
    others: List[Tuple[fx.Node, int]] = field(default_factory=list)
    # Non-tensor input values and default parameters.
    literals: List[Tuple[fx.Node, int]] = field(default_factory=list)
    # Either (node,) or (node, SharedQuantizationSpec); may be left empty.
    output: List[Union[Tuple[fx.Node], Tuple[fx.Node, SharedQuantizationSpec]]] = field(
        default_factory=list
    )
    # Patterns in this file set this to True when they decline to annotate a match.
    empty: bool = False
class QuantizationPattern(ABC):
    """
    Abstract interface for a quantization pattern.

    A concrete pattern names the ops it matches, describes the role of each
    argument of the match through PartitionAnchors, and names the op the match
    should be fused into.
    """

    @abstractmethod
    def partition_types(self) -> list[OpOverload]:
        """
        List of types to be passed to find_sequential_partitions_aten.
        """

    @abstractmethod
    def get_anchors(
        self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """
        Return the anchors for a matched partition together with a node of
        that partition (the concrete patterns in this file return the node
        whose output is annotated, or the matched node when they bail).
        """

    @abstractmethod
    def replacement_op(self) -> OpOverload:
        """
        Operator (most likely a custom one) that this partition should be fused into in
        the backend. Refer to the QuantFusion pass for examples.
        """
class AddmmPattern(QuantizationPattern):
    """Pattern for aten.addmm, fused into the Cadence quantized linear op."""

    def partition_types(self) -> List[OpOverload]:
        """Return the single aten overload this pattern matches."""
        return [torch.ops.aten.addmm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """
        Anchor arg 1 as the input, arg 2 as the weight and arg 0 as the bias.

        The bias qspec is derived from the input and weight edges.
        """
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        node = fused_partition[0].nodes[-1]

        # Edges whose qparams feed get_bias_qparams.
        input_edge = (node.args[1], node)
        weight_edge = (node.args[2], node)
        bias_qspec = DerivedQuantizationSpec(
            derived_from=[input_edge, weight_edge],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )

        anchors = PartitionAnchors(
            inputs=[(node, 1)],
            weights=[(node, 2)],
            biases=[(node, 0, bias_qspec)],
            output=[(node,)],
        )
        return anchors, node

    def replacement_op(self) -> OpOverload:
        """Return the op the matched partition is fused into."""
        cadence_ops = torch.ops.cadence
        return cadence_ops.quantized_linear.per_tensor
class AddPattern(QuantizationPattern):
    """Pattern for a tensor-tensor aten.add without keyword arguments."""

    def partition_types(self) -> List[OpOverload]:
        """Return the single aten overload this pattern matches."""
        return [torch.ops.aten.add.Tensor]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """
        Anchor both operands as inputs, or return empty anchors when the add
        is not a plain tensor-tensor add.
        """
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        node = fused_partition[0].nodes[-1]

        lhs_is_node = isinstance(node.args[0], fx.Node)
        both_are_nodes = lhs_is_node and isinstance(node.args[1], fx.Node)

        # Decline when an operand is not a graph node, or when the add
        # carries kwargs (e.g. alpha).
        if len(node.kwargs) > 0 or not both_are_nodes:
            return PartitionAnchors(empty=True), node

        anchors = PartitionAnchors(
            inputs=[(node, 0), (node, 1)],
            weights=[],
            biases=[],
            output=[(node,)],
        )
        return anchors, node

    def replacement_op(self) -> OpOverload:
        """Return the op the matched partition is fused into."""
        cadence_ops = torch.ops.cadence
        return cadence_ops.quantized_add.per_tensor
# This is a base class for Add+ReLU fusion, since it can be used with two different relu aten ops
class AddReluBasePattern(QuantizationPattern):
    """
    Shared implementation for add followed by relu.

    Subclasses only supply partition_types, which selects the relu overload.
    """

    @abstractmethod
    def partition_types(self) -> List[OpOverload]:
        """Return the [add, relu] overloads to match; supplied by subclasses."""

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """
        Anchor the two add operands as inputs and the relu node as output.

        Returns empty anchors (paired with the add node) when the add is not
        a plain tensor-tensor add.
        """
        # Partition 0 holds the add, partition 1 holds the relu.
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        add_op = fused_partition[0].nodes[-1]
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        relu_op = fused_partition[1].nodes[-1]

        lhs_is_node = isinstance(add_op.args[0], fx.Node)
        both_are_nodes = lhs_is_node and isinstance(add_op.args[1], fx.Node)

        # Decline when an operand is not a graph node, or when the add
        # carries kwargs (e.g. alpha).
        if len(add_op.kwargs) > 0 or not both_are_nodes:
            return PartitionAnchors(empty=True), add_op

        anchors = PartitionAnchors(
            inputs=[(add_op, 0), (add_op, 1)],
            weights=[],
            biases=[],
            # The fused result is what the relu produces.
            output=[(relu_op,)],
        )
        return anchors, relu_op

    def replacement_op(self) -> OpOverload:
        """Return the op the matched partition is fused into."""
        cadence_ops = torch.ops.cadence
        return cadence_ops.quantized_add.per_tensor
# Add + regular relu op fusion
class AddReluPattern0(AddReluBasePattern):
    """Add followed by aten.relu.default."""

    def partition_types(self) -> List[OpOverload]:
        """Return the add overload followed by the relu overload."""
        aten = torch.ops.aten
        return [aten.add.Tensor, aten.relu.default]
# Add + alternate relu op fusion
class AddReluPattern1(AddReluBasePattern):
    """Add followed by aten.relu_.default."""

    def partition_types(self) -> List[OpOverload]:
        """Return the add overload followed by the relu_ overload."""
        aten = torch.ops.aten
        return [aten.add.Tensor, aten.relu_.default]
class BmmPattern(QuantizationPattern):
    """Pattern for aten.bmm, fused into the Cadence quantized matmul op."""

    def partition_types(self) -> List[OpOverload]:
        """Return the single aten overload this pattern matches."""
        return [torch.ops.aten.bmm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """Anchor both operands as inputs; there are no weights or biases."""
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        node = fused_partition[0].nodes[-1]

        anchors = PartitionAnchors(
            inputs=[(node, 0), (node, 1)],
            weights=[],
            biases=[],
            output=[(node,)],
        )
        return anchors, node

    def replacement_op(self) -> OpOverload:
        """Return the op the matched partition is fused into."""
        # TODO: T240804887 This is actually a per-tensor variant,
        # we just need to change the name of the op
        cadence_ops = torch.ops.cadence
        return cadence_ops.quantized_matmul.default
class CatPattern(QuantizationPattern):
    """
    Pattern for aten.cat in which every concatenated tensor and the output
    share the quantization spec of the first tensor.
    """

    def partition_types(self) -> List[OpOverload]:
        """Return the single aten overload this pattern matches."""
        return [torch.ops.aten.cat.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """
        Anchor each element of the tensor list (args[0]) as an input.

        Inputs are addressed with the (0, i) index form, meaning element i of
        args[0].
        """
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        node = fused_partition[0].nodes[-1]

        # The first tensor carries no explicit quant spec and inherits the
        # overall one. Every later tensor, and the output as well, shares
        # the spec of that first edge.
        input_anchors: List[
            Union[
                Tuple[fx.Node, Union[int, Tuple[int, int]]],
                Tuple[
                    fx.Node,
                    Union[int, Tuple[int, int]],
                    SharedQuantizationSpec,
                ],
            ]
        ] = [(node, (0, 0))]
        input_anchors.extend(
            (node, (0, position), SharedQuantizationSpec((node.args[0][0], node)))
            for position in range(1, len(node.args[0]))
        )

        output_anchor = (node, SharedQuantizationSpec((node.args[0][0], node)))
        anchors = PartitionAnchors(
            inputs=input_anchors,
            weights=[],
            biases=[],
            output=[output_anchor],
        )
        return anchors, node

    def replacement_op(self) -> OpOverload:
        """Return the replacement op, which is aten.cat itself."""
        return torch.ops.aten.cat.default
class Conv1dPattern(QuantizationPattern):
    """Pattern for aten.conv1d, fused into the Cadence quantized conv1d op."""

    def partition_types(self) -> List[OpOverload]:
        """Return the single aten overload this pattern matches."""
        return [torch.ops.aten.conv1d.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """
        Anchor arg 0 as the input, arg 1 as the weight and, when present,
        arg 2 as the bias with a qspec derived from input and weight.
        """
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        node = fused_partition[0].nodes[-1]

        # Edges whose qparams feed get_bias_qparams.
        input_edge = (node.args[0], node)
        weight_edge = (node.args[1], node)
        bias_qspec = DerivedQuantizationSpec(
            derived_from=[input_edge, weight_edge],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )

        # The bias list stays empty when no bias was passed or it is None.
        has_bias = len(node.args) > 2 and node.args[2] is not None
        bias = [(node, 2, bias_qspec)] if has_bias else []

        anchors = PartitionAnchors(
            inputs=[(node, 0)],
            weights=[(node, 1)],
            # pyre-fixme[6]: Incompatible parameter type
            biases=bias,
            output=[(node,)],
        )
        return anchors, node

    def replacement_op(self) -> OpOverload:
        """Return the op the matched partition is fused into."""
        cadence_ops = torch.ops.cadence
        return cadence_ops.quantized_conv1d_ncl.per_tensor
class Conv2dPattern(QuantizationPattern):
    """Pattern for aten.conv2d, fused into the Cadence quantized conv2d op."""

    def partition_types(self) -> List[OpOverload]:
        """Return the single aten overload this pattern matches."""
        return [torch.ops.aten.conv2d.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """
        Anchor arg 0 as the input, arg 1 as the weight and, when present,
        arg 2 as the bias with a qspec derived from input and weight.
        """
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        node = fused_partition[0].nodes[-1]

        # Edges whose qparams feed get_bias_qparams.
        input_edge = (node.args[0], node)
        weight_edge = (node.args[1], node)
        bias_qspec = DerivedQuantizationSpec(
            derived_from=[input_edge, weight_edge],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )

        # The bias list stays empty when no bias was passed or it is None.
        has_bias = len(node.args) > 2 and node.args[2] is not None
        bias = [(node, 2, bias_qspec)] if has_bias else []

        anchors = PartitionAnchors(
            inputs=[(node, 0)],
            weights=[(node, 1)],
            # pyre-fixme[6]: Incompatible parameter type
            biases=bias,
            output=[(node,)],
        )
        return anchors, node

    def replacement_op(self) -> OpOverload:
        """Return the op the matched partition is fused into."""
        cadence_ops = torch.ops.cadence
        return cadence_ops.quantized_conv2d_nchw.per_tensor
class LayerNormPattern(QuantizationPattern):
    """Pattern for aten.layer_norm, fused into the Cadence quantized layer norm op."""

    def partition_types(self) -> List[OpOverload]:
        """Return the single aten overload this pattern matches."""
        return [torch.ops.aten.layer_norm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """
        Anchor arg 0 as the only quantized input.

        The normalized shape, weight and bias go into ``others``, in that
        order, because the kernel consumes weights unquantized.
        """
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        node = fused_partition[0].nodes[-1]
        call_args = node.args

        # Arg 1 (normalized_shape) is always recorded.
        unquantized = [(node, 1)]
        # Arg 2 (weight) and arg 3 (bias) are recorded only when supplied
        # and truthy.
        for position in (2, 3):
            if len(call_args) > position and call_args[position]:
                unquantized.append((node, position))

        anchors = PartitionAnchors(
            inputs=[(node, 0)],
            weights=[],
            biases=[],
            # Ordering: normalized_shape, weights, bias
            others=unquantized,
            output=[(node,)],
        )
        return anchors, node

    def replacement_op(self) -> OpOverload:
        """Return the op the matched partition is fused into."""
        cadence_ops = torch.ops.cadence
        return cadence_ops.quantized_layer_norm.per_tensor
class LinearPattern(QuantizationPattern):
    """Pattern for aten.linear, fused into the Cadence quantized linear op."""

    def partition_types(self) -> List[OpOverload]:
        """Return the single aten overload this pattern matches."""
        return [torch.ops.aten.linear.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """
        Anchor arg 0 as the input, arg 1 as the weight and, when present,
        arg 2 as the bias with a qspec derived from input and weight.
        """
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        linear_node = fused_partition[0].nodes[-1]

        bias_qspec = DerivedQuantizationSpec(
            derived_from=[
                (linear_node.args[0], linear_node),
                (linear_node.args[1], linear_node),
            ],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )

        # Keep bias empty if not supplied. An explicit None in the bias slot
        # counts as "not supplied", matching the conv patterns in this file;
        # otherwise a bias anchor would point at a None argument.
        bias = []
        if len(linear_node.args) > 2 and linear_node.args[2] is not None:
            bias = [(linear_node, 2, bias_qspec)]

        return (
            PartitionAnchors(
                inputs=[(linear_node, 0)],
                weights=[(linear_node, 1)],
                # pyre-fixme[6]: Incompatible parameter type
                biases=bias,
                output=[(linear_node,)],
            ),
            linear_node,
        )

    def replacement_op(self) -> OpOverload:
        """Return the op the matched partition is fused into."""
        return torch.ops.cadence.quantized_linear.per_tensor
class MatmulPattern(QuantizationPattern):
    """Pattern for aten.matmul, fused into the Cadence quantized matmul op."""

    def partition_types(self) -> List[OpOverload]:
        """Return the single aten overload this pattern matches."""
        return [torch.ops.aten.matmul.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """Anchor both operands as inputs; there are no weights or biases."""
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        node = fused_partition[0].nodes[-1]

        anchors = PartitionAnchors(
            inputs=[(node, 0), (node, 1)],
            weights=[],
            biases=[],
            output=[(node,)],
        )
        return anchors, node

    def replacement_op(self) -> OpOverload:
        """Return the op the matched partition is fused into."""
        # TODO: T240804887 This is actually a per-tensor variant, we just need to change the name of the op
        cadence_ops = torch.ops.cadence
        return cadence_ops.quantized_matmul.default
class MaxPool2dPattern(QuantizationPattern):
    """
    Quantized max pooling, matching the variant that also returns indices.

    Taking a maximum preserves ordering, so max(a, b) computed on quantized
    values equals the quantization of max(dequant(a), dequant(b)) as long as
    both sides use the same scale/zero_point. Pooling can therefore run
    directly on quantized values with no requantization step.
    The input and output share quantization parameters.
    """

    def partition_types(self) -> List[OpOverload]:
        """Return the single aten overload this pattern matches."""
        return [torch.ops.aten.max_pool2d_with_indices.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """
        Anchor arg 0 as the input, every later positional arg as a literal,
        and tie the output qspec to the input edge.
        """
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        node = fused_partition[0].nodes[-1]

        # kernel_size, stride, padding, dilation, ceil_mode are literals
        literal_anchors = [(node, position) for position in range(1, len(node.args))]

        # Output reuses the input's quantization parameters since max is
        # order-preserving.
        output_anchor = (node, SharedQuantizationSpec((node.args[0], node)))

        anchors = PartitionAnchors(
            inputs=[(node, 0)],
            weights=[],
            biases=[],
            literals=literal_anchors,
            output=[output_anchor],
        )
        return anchors, node

    def replacement_op(self) -> OpOverload:
        """Return the op the matched partition is fused into."""
        cadence_ops = torch.ops.cadence
        return cadence_ops.quantized_max_pool2d_nchw.default
class MaxPool2dWithoutIndicesPattern(QuantizationPattern):
    """
    Quantized max pooling, matching the variant without indices.

    Mirrors MaxPool2dPattern, except that it matches aten.max_pool2d.default,
    which returns a single tensor rather than a (values, indices) tuple.
    """

    def partition_types(self) -> List[OpOverload]:
        """Return the single aten overload this pattern matches."""
        return [torch.ops.aten.max_pool2d.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """
        Anchor arg 0 as the input, every later positional arg as a literal,
        and tie the output qspec to the input edge.
        """
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        node = fused_partition[0].nodes[-1]

        literal_anchors = [(node, position) for position in range(1, len(node.args))]
        output_anchor = (node, SharedQuantizationSpec((node.args[0], node)))

        anchors = PartitionAnchors(
            inputs=[(node, 0)],
            weights=[],
            biases=[],
            literals=literal_anchors,
            output=[output_anchor],
        )
        return anchors, node

    def replacement_op(self) -> OpOverload:
        """Return the op the matched partition is fused into."""
        cadence_ops = torch.ops.cadence
        return cadence_ops.quantized_max_pool2d_nchw.default
# This is a base class for ReLU, since it can be used with two different aten ops
class ReluBasePattern(QuantizationPattern):
    """
    Shared implementation for a standalone relu.

    Subclasses only supply partition_types, which selects the relu overload.
    """

    @abstractmethod
    def partition_types(self) -> List[OpOverload]:
        """Return the relu overload to match; supplied by subclasses."""

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """Anchor arg 0 as the input and the relu node itself as the output."""
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        node = fused_partition[0].nodes[-1]

        anchors = PartitionAnchors(
            inputs=[(node, 0)],
            weights=[],
            biases=[],
            output=[(node,)],
        )
        return anchors, node

    def replacement_op(self) -> OpOverload:
        """Return the op the matched partition is fused into."""
        cadence_ops = torch.ops.cadence
        return cadence_ops.quantized_relu.per_tensor
# Regular relu op
class ReluPattern0(ReluBasePattern):
    """Standalone aten.relu.default."""

    def partition_types(self) -> List[OpOverload]:
        """Return the relu overload this pattern matches."""
        relu_op = torch.ops.aten.relu.default
        return [relu_op]
# Alternate relu op
class ReluPattern1(ReluBasePattern):
    """Standalone aten.relu_.default."""

    def partition_types(self) -> List[OpOverload]:
        """Return the relu_ overload this pattern matches."""
        relu_op = torch.ops.aten.relu_.default
        return [relu_op]
# This is a base class for Conv+ReLU fusion, since it can be used with two different relu aten ops
class ConvReluBasePattern(QuantizationPattern):
    """
    Shared implementation for a convolution followed by relu.

    Subclasses supply partition_types; the default replacement op is the
    quantized conv2d, which subclasses may override.
    """

    @abstractmethod
    def partition_types(self) -> List[OpOverload]:
        """Return the [conv, relu] overloads to match; supplied by subclasses."""

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """
        Anchor the conv's arg 0 as input, arg 1 as weight, arg 2 (when
        present) as bias, and the relu node as the output.
        """
        # Partition 0 holds the conv, partition 1 holds the relu.
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        conv_op = fused_partition[0].nodes[-1]
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        relu_op = fused_partition[1].nodes[-1]

        # Edges whose qparams feed get_bias_qparams.
        input_edge = (conv_op.args[0], conv_op)
        weight_edge = (conv_op.args[1], conv_op)
        bias_qspec = DerivedQuantizationSpec(
            derived_from=[input_edge, weight_edge],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )

        # The bias list stays empty when no bias was passed or it is None.
        has_bias = len(conv_op.args) > 2 and conv_op.args[2] is not None
        bias = [(conv_op, 2, bias_qspec)] if has_bias else []

        anchors = PartitionAnchors(
            inputs=[(conv_op, 0)],
            weights=[(conv_op, 1)],
            # pyre-fixme[6]: Incompatible parameter type
            biases=bias,
            # The fused result is what the relu produces.
            output=[(relu_op,)],
        )
        return anchors, relu_op

    def replacement_op(self) -> OpOverload:
        """Return the op the matched partition is fused into."""
        cadence_ops = torch.ops.cadence
        return cadence_ops.quantized_conv2d_nchw.per_tensor
# Conv1d + regular relu op fusion
class Conv1dReluPattern0(ConvReluBasePattern):
    """Conv1d followed by aten.relu.default."""

    def partition_types(self) -> List[OpOverload]:
        """Return the conv1d overload followed by the relu overload."""
        aten = torch.ops.aten
        return [aten.conv1d.default, aten.relu.default]

    def replacement_op(self) -> OpOverload:
        """Return the quantized conv1d op instead of the base class's conv2d op."""
        cadence_ops = torch.ops.cadence
        return cadence_ops.quantized_conv1d_ncl.per_tensor
# Conv1d + alternate relu op fusion
class Conv1dReluPattern1(ConvReluBasePattern):
    """Conv1d followed by aten.relu_.default."""

    def partition_types(self) -> List[OpOverload]:
        """Return the conv1d overload followed by the relu_ overload."""
        aten = torch.ops.aten
        return [aten.conv1d.default, aten.relu_.default]

    def replacement_op(self) -> OpOverload:
        """Return the quantized conv1d op instead of the base class's conv2d op."""
        cadence_ops = torch.ops.cadence
        return cadence_ops.quantized_conv1d_ncl.per_tensor
# Conv2d + regular relu op fusion
class Conv2dReluPattern0(ConvReluBasePattern):
    """Conv2d followed by aten.relu.default."""

    def partition_types(self) -> List[OpOverload]:
        """Return the conv2d overload followed by the relu overload."""
        aten = torch.ops.aten
        return [aten.conv2d.default, aten.relu.default]
# Conv2d + alternate relu op fusion
class Conv2dReluPattern1(ConvReluBasePattern):
    """Conv2d followed by aten.relu_.default."""

    def partition_types(self) -> List[OpOverload]:
        """Return the conv2d overload followed by the relu_ overload."""
        aten = torch.ops.aten
        return [aten.conv2d.default, aten.relu_.default]
class SoftmaxPattern(QuantizationPattern):
    """Pattern for aten._softmax, fused into the Cadence quantized softmax op."""

    def partition_types(self) -> List[OpOverload]:
        """Return the single aten overload this pattern matches."""
        return [torch.ops.aten._softmax.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """Anchor arg 0 as the input and the softmax node itself as the output."""
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        node = fused_partition[0].nodes[-1]

        anchors = PartitionAnchors(
            inputs=[(node, 0)],
            weights=[],
            biases=[],
            output=[(node,)],
        )
        return anchors, node

    def replacement_op(self) -> OpOverload:
        """Return the op the matched partition is fused into."""
        cadence_ops = torch.ops.cadence
        return cadence_ops.quantized_softmax.per_tensor
class MixedW8A32LinearPattern(QuantizationPattern):
    """
    Pattern for aten.linear where only the weight and bias are anchored for
    quantization; the activation (arg 0) is passed through ``others``.
    """

    def partition_types(self) -> List[OpOverload]:
        """Return the single aten overload this pattern matches."""
        return [torch.ops.aten.linear.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """
        Return weight/bias anchors, or empty anchors when the match is unsupported.

        The match is declined when the call does not have exactly three
        positional args and no kwargs, when the input has no shape metadata,
        when the input's last dimension is not a multiple of 4, or when the
        input has a second-to-last dimension that is not 1.
        """
        # pyre-ignore[29]
        linear_layer = fused_partition[0].nodes[-1]
        declined = (PartitionAnchors(empty=True), linear_layer)

        # Bail if the arguments have different shapes than expected
        if len(linear_layer.args) != 3 or len(linear_layer.kwargs) > 0:
            return declined

        input_node = linear_layer.args[0]
        # Without shape metadata the checks below cannot run, so decline
        # instead of raising KeyError.
        tensor_meta = input_node.meta.get("tensor_meta", None)
        if tensor_meta is None:
            return declined
        input_shape = tensor_meta.shape

        # Bail if the input's last dimension is missing or not a multiple
        # of 4 (SIMD)
        if len(input_shape) == 0 or input_shape[-1] % 4 != 0:
            return declined

        # Currently only supporting vector-matrix multiplication. Index -2
        # only exists for rank >= 2; a rank-1 input is already a vector.
        if len(input_shape) > 1 and input_shape[-2] != 1:
            return declined

        return (
            PartitionAnchors(
                inputs=[],
                weights=[(linear_layer, 1)],
                biases=[(linear_layer, 2)],
                output=[],
                others=[(linear_layer, 0)],
            ),
            linear_layer,
        )

    def replacement_op(self) -> OpOverload:
        """Return the op the matched partition is fused into."""
        return torch.ops.cadence.quantized_w8a32_linear.default
class MixedW8A32ConvPattern(QuantizationPattern):
    """
    Pattern for aten.conv1d where only the weight and bias are anchored for
    quantization; the activation (arg 0) is passed through ``others``.
    """

    def partition_types(self) -> List[OpOverload]:
        """Return the single aten overload this pattern matches."""
        return [torch.ops.aten.conv1d.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """
        Return weight/bias anchors, or empty anchors when the match is unsupported.

        Shape checks only run when the corresponding node carries
        ``tensor_meta``; the input-length check additionally needs the weight
        shape, since it compares against the kernel size.
        """
        # pyre-ignore[29]
        conv_layer = fused_partition[0].nodes[-1]
        declined = (PartitionAnchors(empty=True), conv_layer)

        # Bail if the arguments have different shapes than expected
        # Stride, padding, dilation and groups are not supported
        if len(conv_layer.args) != 3 or len(conv_layer.kwargs) > 0:
            return declined

        cnn_weights = conv_layer.args[1]
        # Stays None when the weight node has no shape metadata, so the
        # input check below can tell that the kernel size is unknown.
        cnn_weights_shape = None
        if "tensor_meta" in cnn_weights.meta:
            cnn_weights_shape = cnn_weights.meta["tensor_meta"].shape
            # Bail if the channels are not multiple of 4 (SIMD)
            if cnn_weights_shape[0] % 4 != 0:
                return declined
            if cnn_weights_shape[1] % 4 != 0:
                return declined
            # Bail if the kernel size is not 3
            if cnn_weights_shape[2] != 3:
                return declined

        inputs = conv_layer.args[0]
        # Previously this read cnn_weights_shape even when it had never been
        # assigned, raising UnboundLocalError when only the input had metadata.
        if cnn_weights_shape is not None and "tensor_meta" in inputs.meta:
            inputs_shape = inputs.meta["tensor_meta"].shape
            # Bail if length != kernel size - Not yet supported
            if inputs_shape[-1] != cnn_weights_shape[2]:
                return declined

        return (
            PartitionAnchors(
                inputs=[],
                weights=[(conv_layer, 1)],
                biases=[(conv_layer, 2)],
                output=[],
                others=[(conv_layer, 0)],
            ),
            conv_layer,
        )

    def replacement_op(self) -> OpOverload:
        """Return the op the matched partition is fused into."""
        return torch.ops.cadence.quantized_w8a32_conv.default
class MixedW8A32GruPattern(QuantizationPattern):
    """
    Pattern for aten.gru.input where the entries of the parameter list
    (arg 2) are anchored as weights and biases, and args 0 and 1 are passed
    through ``others``.
    """

    def partition_types(self) -> List[OpOverload]:
        """Return the single aten overload this pattern matches."""
        return [torch.ops.aten.gru.input]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """
        Return weight/bias anchors, or empty anchors when the match is unsupported.

        The match is declined when the call has kwargs, when arg 0 or arg 1
        lacks shape metadata or has a last dimension that is not a multiple
        of 4, or when the parameter list has fewer than the four entries
        anchored below.
        """
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        gru_layer = fused_partition[0].nodes[-1]
        declined = (PartitionAnchors(empty=True), gru_layer)

        if len(gru_layer.kwargs) > 0:
            return declined

        # Bail if input or states are not multiple of 4 (SIMD)
        tensor_meta_0 = gru_layer.args[0].meta.get("tensor_meta", None)
        if tensor_meta_0 is None or tensor_meta_0.shape[-1] % 4 != 0:
            return declined

        tensor_meta_1 = gru_layer.args[1].meta.get("tensor_meta", None)
        if tensor_meta_1 is None or tensor_meta_1.shape[-1] % 4 != 0:
            return declined

        gru_params = tuple(gru_layer.args[2])
        # Indices 0..3 of the parameter list are anchored below; with fewer
        # entries the index-2 lookup used to raise IndexError.
        if len(gru_params) < 4:
            return declined

        # Exposes the parameter list through the same .args/.meta attributes
        # a node has, so (wrapper, index) can address entries of the list.
        class Wrapper:  # noqa: B903
            def __init__(self, args, meta):
                self.args = args
                self.meta = meta

        wrapper = Wrapper(gru_params, gru_layer.meta)

        # Using SharedQuantizationSpec so that bias_hh has the same observer as bias_ih
        # Both biases get the same quantization scale to match the cpp operator
        bias_ih_node = wrapper.args[2]
        bias_ih_edge = (bias_ih_node, gru_layer)
        shared_bias_qspec = SharedQuantizationSpec(edge_or_node=bias_ih_edge)

        return (
            PartitionAnchors(
                inputs=[],
                # pyre-fixme[6]: Expected `List[Tuple[Node, int]]` but got `List[Tuple[Wrapper, int]]`.
                weights=[(wrapper, 0), (wrapper, 1)],
                # pyre-fixme[6]: Expected `List[Union[Tuple[Node, int], Tuple[Node, int, DerivedQuantizationSpec]]]` but got `List[Tuple[Wrapper, int]]`.
                biases=[
                    (wrapper, 2),  # bias_ih gets normal qspec
                    (
                        wrapper,
                        3,
                        shared_bias_qspec,
                    ),  # bias_hh shares observer with bias_ih
                ],
                output=[],
                others=[(gru_layer, 0), (gru_layer, 1)],
            ),
            gru_layer,
        )

    def replacement_op(self) -> OpOverload:
        """Return the op the matched partition is fused into."""
        return torch.ops.cadence.quantized_w8a32_gru.default
class RmsNormPattern(QuantizationPattern):
    """Pattern that preserves rms_norm from decomposition without matching anything."""

    def partition_types(self) -> list[torch._ops.OpOverload]:
        """Return the rms_norm overload this pattern is registered for."""
        rms_norm_op = torch.ops.aten.rms_norm.default
        return [rms_norm_op]

    def get_anchors(
        self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """Always return empty anchors and no node."""
        no_anchors = PartitionAnchors(empty=True)
        return no_anchors, None  # pyre-ignore[7]

    def replacement_op(self) -> torch._ops.OpOverload:
        """Return the replacement op, which is aten.rms_norm itself."""
        rms_norm_op = torch.ops.aten.rms_norm.default
        return rms_norm_op