Skip to content

Commit 0e6b67e

Browse files
authored
Add fuse() to QuantizationPatterns (pytorch#19726)
Differential Revision: D105728156 Pull Request resolved: pytorch#19726
1 parent 29c18de commit 0e6b67e

2 files changed

Lines changed: 264 additions & 2 deletions

File tree

backends/cadence/aot/quantizer/BUCK

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,10 @@ fbcode_target(_kind = runtime.python_library,
3636
],
3737
typing = True,
3838
deps = [
39+
":pattern_utils",
3940
":utils",
4041
"//caffe2:torch",
42+
"//executorch/backends/cadence/aot:pass_utils",
4143
],
4244
)
4345

backends/cadence/aot/quantizer/patterns.py

Lines changed: 262 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,19 @@
1212
from typing import List, Optional, Tuple, Union
1313

1414
import torch
15-
from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams
16-
15+
from executorch.backends.cadence.aot.pass_utils import get_arg, replace_with_op
16+
from executorch.backends.cadence.aot.quantizer.pattern_utils import (
17+
DQ_PER_TENSOR,
18+
find_quant_user,
19+
fuse_conv,
20+
fuse_linear,
21+
fuse_matmul,
22+
insert_node_with_meta,
23+
)
24+
from executorch.backends.cadence.aot.quantizer.utils import (
25+
check_out_zero_point_is_min_range,
26+
get_bias_qparams,
27+
)
1728
from torch import fx
1829
from torch._ops import OpOverload
1930
from torchao.quantization.pt2e.quantizer import (
@@ -131,6 +142,41 @@ def get_anchors(
131142
def replacement_op(self) -> OpOverload:
132143
return torch.ops.cadence.quantized_linear.per_tensor
133144

145+
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
146+
assert anchor_node.target == torch.ops.aten.addmm.default
147+
# addmm(bias, input, weight)
148+
bias_node = anchor_node.args[0]
149+
assert isinstance(bias_node, fx.Node)
150+
dq_input = get_arg(anchor_node, "mat1", fx.Node)
151+
if dq_input.target != DQ_PER_TENSOR:
152+
return None
153+
dq_weight = get_arg(anchor_node, "mat2", fx.Node)
154+
if dq_weight.target != DQ_PER_TENSOR:
155+
return None
156+
quant_node = find_quant_user(anchor_node)
157+
if quant_node is None:
158+
return None
159+
dq_bias = bias_node if bias_node.target == DQ_PER_TENSOR else None
160+
weight_q = get_arg(dq_weight, "input", fx.Node)
161+
transposed = insert_node_with_meta(
162+
gm,
163+
torch.ops.aten.transpose.int,
164+
(weight_q, 0, 1),
165+
None,
166+
anchor_node,
167+
weight_q,
168+
)
169+
return fuse_linear(
170+
gm,
171+
dq_input,
172+
dq_weight,
173+
dq_bias,
174+
quant_node,
175+
anchor_node,
176+
self.replacement_op(),
177+
weight_q=transposed,
178+
)
179+
134180

135181
class AddPattern(QuantizationPattern):
136182
def partition_types(self) -> List[OpOverload]:
@@ -169,6 +215,33 @@ def get_anchors(
169215
def replacement_op(self) -> OpOverload:
170216
return torch.ops.cadence.quantized_add.per_tensor
171217

218+
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
219+
# Skip if alpha kwarg is present — changes add semantics.
220+
if anchor_node.kwargs:
221+
return None
222+
dq0 = anchor_node.args[0]
223+
if not isinstance(dq0, fx.Node) or dq0.target != DQ_PER_TENSOR:
224+
return None
225+
dq1 = anchor_node.args[1]
226+
if not isinstance(dq1, fx.Node) or dq1.target != DQ_PER_TENSOR:
227+
return None
228+
quant_node = find_quant_user(anchor_node)
229+
if quant_node is None:
230+
return None
231+
args = (
232+
get_arg(dq0, "input", fx.Node),
233+
get_arg(dq0, "scale", float),
234+
get_arg(dq0, "zero_point", int),
235+
get_arg(dq1, "input", fx.Node),
236+
get_arg(dq1, "scale", float),
237+
get_arg(dq1, "zero_point", int),
238+
get_arg(quant_node, "scale", float),
239+
get_arg(quant_node, "zero_point", int),
240+
)
241+
return replace_with_op(
242+
gm, anchor_node, self.replacement_op(), args, {}, quant_node
243+
)
244+
172245

173246
# This is a base class for Add+ReLU fusion, since it can be used with two different relu aten ops
174247
class AddReluBasePattern(QuantizationPattern):
@@ -212,6 +285,46 @@ def get_anchors(
212285
def replacement_op(self) -> OpOverload:
213286
return torch.ops.cadence.quantized_add.per_tensor
214287

288+
def anchor_ops(self) -> tuple[OpOverload, ...]:
289+
return (torch.ops.aten.add.Tensor,)
290+
291+
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
292+
add_users = list(anchor_node.users)
293+
if len(add_users) != 1:
294+
return None
295+
relu_node = add_users[0]
296+
if relu_node.target != self.partition_types()[1]:
297+
return None
298+
if len(anchor_node.kwargs) > 0:
299+
return None
300+
dq0 = anchor_node.args[0]
301+
if not isinstance(dq0, fx.Node) or dq0.target != DQ_PER_TENSOR:
302+
return None
303+
dq1 = anchor_node.args[1]
304+
if not isinstance(dq1, fx.Node) or dq1.target != DQ_PER_TENSOR:
305+
return None
306+
quant_node = find_quant_user(relu_node)
307+
if quant_node is None:
308+
return None
309+
if not check_out_zero_point_is_min_range(
310+
get_arg(quant_node, "zero_point", int),
311+
get_arg(quant_node, "dtype", torch.dtype),
312+
):
313+
return None
314+
args = (
315+
get_arg(dq0, "input", fx.Node),
316+
get_arg(dq0, "scale", float),
317+
get_arg(dq0, "zero_point", int),
318+
get_arg(dq1, "input", fx.Node),
319+
get_arg(dq1, "scale", float),
320+
get_arg(dq1, "zero_point", int),
321+
get_arg(quant_node, "scale", float),
322+
get_arg(quant_node, "zero_point", int),
323+
)
324+
return replace_with_op(
325+
gm, anchor_node, self.replacement_op(), args, {}, quant_node
326+
)
327+
215328

216329
# Add + regular relu op fusion
217330
class AddReluPattern0(AddReluBasePattern):
@@ -250,6 +363,18 @@ def replacement_op(self) -> OpOverload:
250363
# we just need to change the name of the op
251364
return torch.ops.cadence.quantized_matmul.default
252365

366+
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
367+
dq0 = anchor_node.args[0]
368+
if not isinstance(dq0, fx.Node) or dq0.target != DQ_PER_TENSOR:
369+
return None
370+
dq1 = anchor_node.args[1]
371+
if not isinstance(dq1, fx.Node) or dq1.target != DQ_PER_TENSOR:
372+
return None
373+
quant_node = find_quant_user(anchor_node)
374+
if quant_node is None:
375+
return None
376+
return fuse_matmul(gm, anchor_node, dq0, dq1, quant_node, self.replacement_op())
377+
253378

254379
class CatPattern(QuantizationPattern):
255380
def partition_types(self) -> List[OpOverload]:
@@ -299,6 +424,25 @@ def get_anchors(
299424
def replacement_op(self) -> OpOverload:
300425
return torch.ops.aten.cat.default
301426

427+
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
428+
cat_inputs = anchor_node.args[0]
429+
if not isinstance(cat_inputs, (list, tuple)) or not cat_inputs:
430+
return None
431+
inputs_q = []
432+
for inp in cat_inputs:
433+
if not isinstance(inp, fx.Node) or inp.target != DQ_PER_TENSOR:
434+
return None
435+
inputs_q.append(get_arg(inp, "input", fx.Node))
436+
quant_node = find_quant_user(anchor_node)
437+
if quant_node is None:
438+
return None
439+
dim = get_arg(anchor_node, "dim", int)
440+
args = (inputs_q,)
441+
kwargs = {"dim": dim}
442+
return replace_with_op(
443+
gm, anchor_node, self.replacement_op(), args, kwargs, quant_node
444+
)
445+
302446

303447
class Conv1dPattern(QuantizationPattern):
304448
def partition_types(self) -> List[OpOverload]:
@@ -341,6 +485,18 @@ def get_anchors(
341485
def replacement_op(self) -> OpOverload:
342486
return torch.ops.cadence.quantized_conv1d_ncl.per_tensor
343487

488+
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
489+
dq_input = anchor_node.args[0]
490+
if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR:
491+
return None
492+
dq_weight = anchor_node.args[1]
493+
if not isinstance(dq_weight, fx.Node) or dq_weight.target != DQ_PER_TENSOR:
494+
return None
495+
quant_node = find_quant_user(anchor_node)
496+
if quant_node is None:
497+
return None
498+
return fuse_conv(self, gm, anchor_node, dq_input, dq_weight, quant_node)
499+
344500

345501
class Conv2dPattern(QuantizationPattern):
346502
def partition_types(self) -> List[OpOverload]:
@@ -383,6 +539,18 @@ def get_anchors(
383539
def replacement_op(self) -> OpOverload:
384540
return torch.ops.cadence.quantized_conv2d_nchw.per_tensor
385541

542+
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
543+
dq_input = anchor_node.args[0]
544+
if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR:
545+
return None
546+
dq_weight = anchor_node.args[1]
547+
if not isinstance(dq_weight, fx.Node) or dq_weight.target != DQ_PER_TENSOR:
548+
return None
549+
quant_node = find_quant_user(anchor_node)
550+
if quant_node is None:
551+
return None
552+
return fuse_conv(self, gm, anchor_node, dq_input, dq_weight, quant_node)
553+
386554

387555
class LayerNormPattern(QuantizationPattern):
388556
def partition_types(self) -> List[OpOverload]:
@@ -421,6 +589,61 @@ def get_anchors(
421589
def replacement_op(self) -> OpOverload:
422590
return torch.ops.cadence.quantized_layer_norm.per_tensor
423591

592+
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
593+
dq_input = anchor_node.args[0]
594+
if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR:
595+
return None
596+
quant_node = find_quant_user(anchor_node)
597+
if quant_node is None:
598+
return None
599+
scale = get_arg(dq_input, "scale", float)
600+
zero_point = get_arg(dq_input, "zero_point", int)
601+
normalized_shape = anchor_node.args[1]
602+
assert isinstance(normalized_shape, list)
603+
weight = (
604+
anchor_node.args[2]
605+
if len(anchor_node.args) > 2 and anchor_node.args[2]
606+
else None
607+
)
608+
bias = (
609+
anchor_node.args[3]
610+
if len(anchor_node.args) > 3 and anchor_node.args[3]
611+
else None
612+
)
613+
input_q = get_arg(dq_input, "input", fx.Node)
614+
# Default weight=1 and bias=0 must be float32 — cadence::quantized_layer_norm
615+
# expects float affine parameters, not quantized values.
616+
if not weight:
617+
weight = insert_node_with_meta(
618+
gm,
619+
torch.ops.aten.full.default,
620+
(normalized_shape, 1),
621+
{"dtype": torch.float32},
622+
anchor_node,
623+
input_q,
624+
)
625+
if not bias:
626+
bias = insert_node_with_meta(
627+
gm,
628+
torch.ops.aten.full.default,
629+
(normalized_shape, 0),
630+
{"dtype": torch.float32},
631+
anchor_node,
632+
input_q,
633+
)
634+
args = (input_q, scale, zero_point)
635+
kwargs = {
636+
"normalized_shape": normalized_shape,
637+
"weight": weight,
638+
"bias": bias,
639+
"eps": get_arg(anchor_node, "eps", float),
640+
"output_scale": get_arg(quant_node, "scale", float),
641+
"output_zero_point": get_arg(quant_node, "zero_point", int),
642+
}
643+
return replace_with_op(
644+
gm, anchor_node, self.replacement_op(), args, kwargs, quant_node
645+
)
646+
424647

425648
class LinearPattern(QuantizationPattern):
426649
def partition_types(self) -> List[OpOverload]:
@@ -463,6 +686,31 @@ def get_anchors(
463686
def replacement_op(self) -> OpOverload:
464687
return torch.ops.cadence.quantized_linear.per_tensor
465688

689+
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
690+
dq_input = anchor_node.args[0]
691+
if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR:
692+
return None
693+
dq_weight = anchor_node.args[1]
694+
if not isinstance(dq_weight, fx.Node) or dq_weight.target != DQ_PER_TENSOR:
695+
return None
696+
quant_node = find_quant_user(anchor_node)
697+
if quant_node is None:
698+
return None
699+
dq_bias: fx.Node | None = None
700+
if len(anchor_node.args) > 2:
701+
bias_arg = anchor_node.args[2]
702+
if isinstance(bias_arg, fx.Node) and bias_arg.target == DQ_PER_TENSOR:
703+
dq_bias = bias_arg
704+
return fuse_linear(
705+
gm,
706+
dq_input,
707+
dq_weight,
708+
dq_bias,
709+
quant_node,
710+
anchor_node,
711+
self.replacement_op(),
712+
)
713+
466714

467715
class MatmulPattern(QuantizationPattern):
468716
def partition_types(self) -> List[OpOverload]:
@@ -488,6 +736,18 @@ def replacement_op(self) -> OpOverload:
488736
# TODO: T240804887 This is actually a per-tensor variant, we just need to change the name of the op
489737
return torch.ops.cadence.quantized_matmul.default
490738

739+
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
740+
dq0 = anchor_node.args[0]
741+
if not isinstance(dq0, fx.Node) or dq0.target != DQ_PER_TENSOR:
742+
return None
743+
dq1 = anchor_node.args[1]
744+
if not isinstance(dq1, fx.Node) or dq1.target != DQ_PER_TENSOR:
745+
return None
746+
quant_node = find_quant_user(anchor_node)
747+
if quant_node is None:
748+
return None
749+
return fuse_matmul(gm, anchor_node, dq0, dq1, quant_node, self.replacement_op())
750+
491751

492752
class MaxPool2dPattern(QuantizationPattern):
493753
"""

0 commit comments

Comments
 (0)