Skip to content

Commit 4de16d0

Browse files
authored
Add shared fusion infrastructure and QuantFusionPass (pytorch#19724)
Differential Revision: D105728137 Pull Request resolved: pytorch#19724
1 parent 9596866 commit 4de16d0

6 files changed

Lines changed: 289 additions & 2 deletions

File tree

backends/cadence/aot/compiler_funcs.py

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
import torch
1515
from torch._inductor.decomposition import remove_decompositions
1616
from torch.fx import GraphModule
17+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
1718
from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, prepare_qat_pt2e
1819
from torchao.quantization.pt2e.quantizer import Quantizer
1920

@@ -607,3 +608,32 @@ def sink_input_dequant_through_transparent_ops(
607608
graph_module.recompile()
608609

609610
return modified
611+
612+
613+
class QuantFusionPass(PassBase):
614+
"""
615+
Iterates patterns, finds anchor ops in the converted graph, and calls
616+
pattern.fuse() to replace dq-op-q subgraphs with fused ops.
617+
"""
618+
619+
def __init__(self, patterns: Sequence[object]) -> None:
620+
super().__init__()
621+
self.patterns = patterns
622+
623+
def call(self, graph_module: GraphModule) -> Optional[PassResult]:
624+
changed = False
625+
for pattern in self.patterns:
626+
pattern_changed = False
627+
for target in pattern.anchor_ops(): # pyre-ignore[16]
628+
for node in graph_module.graph.find_nodes(
629+
op="call_function", target=target
630+
):
631+
result = pattern.fuse(graph_module, node) # pyre-ignore[16]
632+
if result is not None:
633+
changed = True
634+
pattern_changed = True
635+
if pattern_changed:
636+
graph_module.graph.eliminate_dead_code()
637+
if changed:
638+
graph_module.recompile()
639+
return PassResult(graph_module, changed)

backends/cadence/aot/pass_utils.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -212,3 +212,20 @@ def nodes_not_adjacent_in_gm(
212212
def none_throws(x: Optional[PassResult]) -> PassResult:
    """Return ``x`` unchanged after asserting that it is not ``None``.

    Narrows ``Optional[PassResult]`` to ``PassResult`` for callers.

    Raises:
        AssertionError: If ``x`` is ``None`` (only when assertions are enabled).
    """
    assert x is not None
    return x
215+
216+
217+
def replace_with_op(
    gm: torch.fx.GraphModule,
    insert_after: torch.fx.Node,
    replacement_op: torch._ops.OpOverload,
    args: tuple,  # pyre-ignore[2]
    kwargs: dict,  # pyre-ignore[2]
    node_to_replace: torch.fx.Node,
) -> torch.fx.Node:
    """Insert ``replacement_op`` after ``insert_after`` and replace all uses of
    ``node_to_replace`` with the new node.

    Args:
        gm: Graph module whose graph is mutated in place.
        insert_after: Node after which the new ``call_function`` node is placed.
        replacement_op: Target of the new ``call_function`` node.
        args: Positional arguments of the new node.
        kwargs: Keyword arguments of the new node.
        node_to_replace: Node whose users are redirected to the new node and
            whose metadata is carried over. It is not erased here; the caller
            is expected to remove it (e.g. via dead code elimination).

    Returns:
        The newly created node.
    """
    with gm.graph.inserting_after(insert_after):
        new_node = gm.graph.call_function(replacement_op, args, kwargs)
    # Shallow-copy the metadata instead of aliasing it: both nodes stay in the
    # graph until the old one is eliminated, and sharing one dict would let a
    # later write to either node's meta silently change the other.
    new_node.meta = node_to_replace.meta.copy()
    node_to_replace.replace_all_uses_with(new_node)
    return new_node

backends/cadence/aot/quantizer/BUCK

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,21 @@ fbcode_target(_kind = runtime.python_library,
1414
],
1515
)
1616

17+
# Library target for pattern_utils.py, the shared helpers used when fusing
# dq -> op -> q subgraphs (insert_node_with_meta, find_quant_user, fuse_conv,
# fuse_linear, fuse_matmul).
# NOTE(review): pattern_utils.py as shown imports from pass_utils,
# quantizer/utils and aot/utils but not from compiler_utils -- confirm the
# ":compiler_utils" dep is actually needed.
fbcode_target(_kind = runtime.python_library,
    name = "pattern_utils",
    srcs = [
        "pattern_utils.py",
    ],
    typing = True,
    deps = [
        ":utils",
        "//caffe2:torch",
        "//executorch/backends/cadence/aot:compiler_utils",
        "//executorch/backends/cadence/aot:pass_utils",
        "//executorch/backends/cadence/aot:utils",
    ],
)
31+
1732
fbcode_target(_kind = runtime.python_library,
1833
name = "patterns",
1934
srcs = [
backends/cadence/aot/quantizer/pattern_utils.py

Lines changed: 207 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,207 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import operator
10+
from typing import Any
11+
12+
import torch
13+
from executorch.backends.cadence.aot.pass_utils import get_arg, replace_with_op
14+
from executorch.backends.cadence.aot.quantizer.utils import (
15+
copy_node_metadata,
16+
create_zero_bias_int32,
17+
quantize_tensor_multiplier,
18+
)
19+
from executorch.backends.cadence.aot.utils import is_depthwise_conv
20+
from torch import fx
21+
from torch._ops import OpOverload
22+
23+
DQ_PER_TENSOR: OpOverload = torch.ops.quantized_decomposed.dequantize_per_tensor.default
24+
Q_PER_TENSOR: OpOverload = torch.ops.quantized_decomposed.quantize_per_tensor.default
25+
26+
27+
def insert_node_with_meta(
    gm: fx.GraphModule,
    op: OpOverload,
    args: tuple[Any, ...],
    kwargs: dict[str, Any] | None,
    insert_before: fx.Node,
    like_node: fx.Node,
) -> fx.Node:
    """Create a new node and populate its FakeTensor metadata.

    Runs ``op(*args, **kwargs)`` under ``like_node``'s fake_mode to compute
    ``meta["val"]``, inserts the corresponding ``call_function`` node before
    ``insert_before``, and copies remaining metadata from ``like_node`` via
    ``copy_node_metadata``.

    Args:
        gm: Graph module whose graph receives the new node.
        op: Operator to call.
        args: Positional arguments; ``fx.Node`` entries (including ones nested
            inside lists, tuples or dicts) are replaced by their
            ``meta["val"]`` for the fake execution.
        kwargs: Keyword arguments, resolved the same way; ``None`` means none.
        insert_before: Node before which the new node is inserted.
        like_node: Node providing the fake mode (through ``meta["val"]``) and
            the metadata to copy.

    Returns:
        The newly inserted node.

    Raises:
        AssertionError: If ``like_node`` has no ``meta["val"]`` or that value
            has no fake mode.
    """
    assert "val" in like_node.meta
    fake_mode = like_node.meta["val"].fake_mode
    assert fake_mode is not None

    def _fake_value(n: fx.Node) -> Any:
        return n.meta["val"]

    call_kwargs = kwargs or {}
    # map_arg walks nested containers, so ops that take e.g. a list of tensors
    # see fake tensors rather than raw fx.Node objects.
    fake_args = fx.map_arg(args, _fake_value)
    fake_kwargs = dict(fx.map_arg(call_kwargs, _fake_value))
    # Compute the fake output before touching the graph so that a failure here
    # does not leave a half-initialized node behind.
    with fake_mode:
        fake_val = op(*fake_args, **fake_kwargs)

    with gm.graph.inserting_before(insert_before):
        node = gm.graph.call_function(op, args, call_kwargs)
    node.meta["val"] = fake_val
    copy_node_metadata(node, like_node)
    return node
56+
57+
58+
def find_quant_user(node: fx.Node) -> fx.Node | None:
59+
"""Find the first quantize_per_tensor user of ``node``, traversing through getitem."""
60+
users = list(node.users)
61+
if not users:
62+
return None
63+
user = users[0]
64+
if user.target is operator.getitem:
65+
if user.args[1] == 0:
66+
users = list(user.users)
67+
if not users:
68+
return None
69+
user = users[0]
70+
else:
71+
return None
72+
if user.target == Q_PER_TENSOR:
73+
return user
74+
return None
75+
76+
77+
def fuse_conv(
    pattern: object,
    gm: fx.GraphModule,
    conv_node: fx.Node,
    dq_input: fx.Node,
    dq_weight: fx.Node,
    quant_node: fx.Node,
) -> fx.Node:
    """Fuse a dq->conv->q chain into a single quantized conv op.

    Args:
        pattern: Object providing ``replacement_op()``, the fused op to emit.
        gm: Graph module that is rewritten in place.
        conv_node: The floating-point conv node being fused.
        dq_input: Dequantize node feeding the conv input.
        dq_weight: Dequantize node feeding the conv weight.
        quant_node: Quantize node consuming the conv output; its users are
            redirected to the fused node.

    Returns:
        The fused node, inserted after ``conv_node``.
    """
    # A bias is reused only when it reaches the conv through a
    # dequantize_per_tensor node.
    # NOTE(review): a bias node with any other target is dropped and replaced
    # by a zero bias below -- confirm callers never pass such a conv.
    bias_dq = None
    conv_bias = conv_node.args[2] if len(conv_node.args) > 2 else None
    if conv_bias is not None:
        assert isinstance(conv_bias, fx.Node)
        if conv_bias.target == DQ_PER_TENSOR:
            bias_dq = conv_bias

    w_scale = get_arg(dq_weight, "scale", float)
    x_scale = get_arg(dq_input, "scale", float)
    bias_scale = x_scale * w_scale

    if bias_dq is None:
        # Cadence quantized conv ops require a non-optional bias argument.
        weight_for_bias = get_arg(dq_weight, "input", fx.Node)
        with gm.graph.inserting_before(conv_node):
            bias_q = create_zero_bias_int32(gm, weight_for_bias, bias_scale)
    else:
        bias_q = get_arg(bias_dq, "input", fx.Node)

    # Ratio between the accumulator scale and the output scale, converted to
    # the multiplier/shift pair passed to the fused op.
    requant = torch.tensor([bias_scale / get_arg(quant_node, "scale", float)])
    out_multiplier, out_shift = quantize_tensor_multiplier(requant)

    input_q = get_arg(dq_input, "input", fx.Node)
    weight_q = get_arg(dq_weight, "input", fx.Node)
    fused_args = (input_q, weight_q, bias_q)

    groups = get_arg(conv_node, "groups", int)
    fused_kwargs = {
        "stride": get_arg(conv_node, "stride", list[int]),
        "padding": get_arg(conv_node, "padding", list[int]),
        "dilation": get_arg(conv_node, "dilation", list[int]),
        "groups": groups,
        "input_zero_point": get_arg(dq_input, "zero_point", int),
        "weight_zero_point": get_arg(dq_weight, "zero_point", int),
        "bias_scale": bias_scale,
        "out_scale": get_arg(quant_node, "scale", float),
        "out_zero_point": get_arg(quant_node, "zero_point", int),
        "out_multiplier": out_multiplier[0].item(),
        "out_shift": out_shift[0].item(),
    }

    fused_op = pattern.replacement_op()  # pyre-ignore[16]
    if fused_op == torch.ops.cadence.quantized_conv1d_ncl.per_tensor:
        # 1D NCL convs switch to the dedicated depthwise op when
        # is_depthwise_conv says so for (groups, in_channels).
        activation = get_arg(dq_input, "input", fx.Node)
        activation_shape = activation.meta["val"].shape
        assert len(activation_shape) >= 2
        if is_depthwise_conv(groups, activation_shape[1]):
            fused_op = torch.ops.cadence.quantized_depthwise_conv1d_ncl.per_tensor
    return replace_with_op(
        gm, conv_node, fused_op, fused_args, fused_kwargs, quant_node
    )
131+
132+
133+
def fuse_linear(
    gm: fx.GraphModule,
    dq_input: fx.Node,
    dq_weight: fx.Node,
    dq_bias: fx.Node | None,
    quant_node: fx.Node,
    op_node: fx.Node,
    replacement_op: OpOverload,
    weight_q: fx.Node | None = None,
) -> fx.Node:
    """Fuse a dq->linear->q chain into a single quantized linear op.

    Args:
        gm: Graph module that is rewritten in place.
        dq_input: Dequantize node feeding the activation.
        dq_weight: Dequantize node feeding the weight.
        dq_bias: Dequantize node feeding the bias, or ``None`` to synthesize a
            zero int32 bias.
        quant_node: Quantize node consuming the op output; its users are
            redirected to the fused node.
        op_node: The ``aten.linear`` / ``aten.addmm`` node being fused.
        replacement_op: Fused op to emit.
        weight_q: Optional node to use as the quantized weight instead of the
            input of ``dq_weight``.

    Returns:
        The fused node, inserted after ``op_node``.

    Raises:
        AssertionError: If ``op_node`` is neither linear nor addmm.
    """
    supported_targets = (
        torch.ops.aten.linear.default,
        torch.ops.aten.addmm.default,
    )
    assert (
        op_node.target in supported_targets
    ), f"Expected linear/addmm, got {op_node.target}"

    w_scale = get_arg(dq_weight, "scale", float)
    x_scale = get_arg(dq_input, "scale", float)
    bias_scale = x_scale * w_scale
    requant = torch.tensor([bias_scale / get_arg(quant_node, "scale", float)])
    out_multiplier, out_shift = quantize_tensor_multiplier(requant)

    if dq_bias is None:
        # Cadence quantized linear ops require a non-optional bias argument.
        weight_for_bias = get_arg(dq_weight, "input", fx.Node)
        with gm.graph.inserting_before(op_node):
            bias_q = create_zero_bias_int32(gm, weight_for_bias, bias_scale)
    else:
        bias_q = get_arg(dq_bias, "input", fx.Node)

    # An explicitly supplied quantized weight wins over the dq node's input.
    if weight_q is not None:
        final_weight = weight_q
    else:
        final_weight = get_arg(dq_weight, "input", fx.Node)

    fused_args = (get_arg(dq_input, "input", fx.Node), final_weight, bias_q)
    fused_kwargs = {
        "src_zero_point": get_arg(dq_input, "zero_point", int),
        "weight_zero_point": get_arg(dq_weight, "zero_point", int),
        "out_multiplier": out_multiplier[0].item(),
        "out_shift": out_shift[0].item(),
        "out_zero_point": get_arg(quant_node, "zero_point", int),
        "offset": None,
    }
    return replace_with_op(
        gm, op_node, replacement_op, fused_args, fused_kwargs, quant_node
    )
174+
175+
176+
def fuse_matmul(
    gm: fx.GraphModule,
    anchor_node: fx.Node,
    dq0: fx.Node,
    dq1: fx.Node,
    quant_node: fx.Node,
    replacement_op: OpOverload,
) -> fx.Node:
    """Fuse a dq->matmul->q chain into a single quantized matmul op.

    Args:
        gm: Graph module that is rewritten in place.
        anchor_node: The ``aten.bmm`` / ``aten.matmul`` node being fused.
        dq0: Dequantize node feeding the first operand.
        dq1: Dequantize node feeding the second operand.
        quant_node: Quantize node consuming the matmul output; its users are
            redirected to the fused node.
        replacement_op: Fused op to emit.

    Returns:
        The fused node, inserted after ``anchor_node``.

    Raises:
        AssertionError: If ``anchor_node`` is neither bmm nor matmul.
    """
    supported_targets = (
        torch.ops.aten.bmm.default,
        torch.ops.aten.matmul.default,
    )
    assert (
        anchor_node.target in supported_targets
    ), f"Expected bmm/matmul, got {anchor_node.target}"

    lhs_scale = get_arg(dq0, "scale", float)
    rhs_scale = get_arg(dq1, "scale", float)
    out_scale = get_arg(quant_node, "scale", float)
    requant = torch.tensor([(lhs_scale * rhs_scale) / out_scale])
    out_multiplier, out_shift = quantize_tensor_multiplier(requant)

    lhs_q = get_arg(dq0, "input", fx.Node)
    lhs_zero_point = get_arg(dq0, "zero_point", int)
    rhs_q = get_arg(dq1, "input", fx.Node)
    rhs_zero_point = get_arg(dq1, "zero_point", int)
    # The trailing None fills the fifth positional argument of the fused op
    # (presumably an optional bias -- TODO confirm against the op schema).
    fused_args = (lhs_q, lhs_zero_point, rhs_q, rhs_zero_point, None)

    fused_kwargs = {
        "out_multiplier": out_multiplier[0].item(),
        "out_shift": out_shift[0].item(),
        "out_zero_point": get_arg(quant_node, "zero_point", int),
        "transposed": False,
    }
    return replace_with_op(
        gm, anchor_node, replacement_op, fused_args, fused_kwargs, quant_node
    )

backends/cadence/aot/quantizer/patterns.py

Lines changed: 17 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
import operator
1010
from abc import ABC, abstractmethod
1111
from dataclasses import dataclass, field
12-
from typing import List, Tuple, Union
12+
from typing import List, Optional, Tuple, Union
1313

1414
import torch
1515
from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams
@@ -79,6 +79,22 @@ def replacement_op(self) -> OpOverload:
7979
"""
8080
pass
8181

82+
def anchor_ops(self) -> tuple[OpOverload, ...]:
    """Return the op targets used to locate this pattern's anchor nodes.

    By default these are exactly the entries of ``partition_types()``, in
    the same order, packed into a tuple.
    """
    return (*self.partition_types(),)
84+
85+
def fuse(
    self,
    gm: fx.GraphModule,
    anchor_node: fx.Node,
) -> Optional[fx.Node]:
    """Replace the dq→op→q subgraph around ``anchor_node`` with a fused op.

    Called by ``QuantFusionPass`` for each node matching ``anchor_ops()``.
    Returns the new fused node on success, or ``None`` to skip this match.
    Subclasses override to implement pattern-specific fusion logic.

    Args:
        gm: Graph module containing ``anchor_node``.
        anchor_node: Node whose target is one of ``anchor_ops()``.

    Returns:
        The fused node, or ``None`` when nothing was fused.
    """
    # Base implementation fuses nothing; it always reports "skip".
    return None
97+
8298

8399
class AddmmPattern(QuantizationPattern):
84100
def partition_types(self) -> List[OpOverload]:

backends/cadence/aot/quantizer/utils.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -118,7 +118,9 @@ def create_zero_bias_int32(
118118
bias_scale: float,
119119
) -> fx.Node:
120120
"""
121-
Creates a zero bias tensor with the shape of weight[0]
121+
Creates a zero bias tensor with the shape of weight[0].
122+
Caller is responsible for setting the graph insertion point
123+
(e.g. ``with gm.graph.inserting_before(node):``).
122124
"""
123125
try:
124126
attr_node = getattr(graph_module, weight_node.target)

0 commit comments

Comments
 (0)