Skip to content

Commit bd24e79

Browse files
authored
Add fuse() to remaining QuantizationPatterns (pytorch#19727)
Summary: Add `fuse()` implementations to the remaining Cadence `QuantizationPattern` subclasses:
- `MaxPool2dPattern`, `MaxPool2dWithoutIndicesPattern` — order-preserving pool on quantized values
- `ReluBasePattern` (inherited by `ReluPattern0`/`1`) — relu with requantization
- `ConvReluBasePattern` (inherited by `Conv1d`/`2dReluPattern0`/`1`) — conv+relu fusion with `anchor_ops()` override to match only the conv op
- `SoftmaxPattern` — softmax with dummy mask/pos tensors and fake_mode metadata
- `MixedW8A32LinearPattern` — weight-only quantized linear (no input/output quant)
- `MixedW8A32ConvPattern` — weight-only quantized conv1d with NCL→NLC permutation
- `MixedW8A32GruPattern` — weight-only quantized GRU with 4 dequantized params

Reviewed By: DrJessop
Differential Revision: D105728177
1 parent 5395f20 commit bd24e79

1 file changed

Lines changed: 260 additions & 2 deletions

File tree

backends/cadence/aot/quantizer/patterns.py

Lines changed: 260 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import List, Optional, Tuple, Union
1313

1414
import torch
15+
from executorch.backends.cadence.aot.compiler_utils import get_shape
1516
from executorch.backends.cadence.aot.pass_utils import get_arg, replace_with_op
1617
from executorch.backends.cadence.aot.quantizer.pattern_utils import (
1718
DQ_PER_TENSOR,
@@ -24,6 +25,7 @@
2425
from executorch.backends.cadence.aot.quantizer.utils import (
2526
check_out_zero_point_is_min_range,
2627
get_bias_qparams,
28+
quantize_tensor_multiplier,
2729
)
2830
from torch import fx
2931
from torch._ops import OpOverload
@@ -806,6 +808,40 @@ def get_anchors(
806808
def replacement_op(self) -> OpOverload:
    """Return the Cadence op overload this pattern is replaced with."""
    cadence_ops = torch.ops.cadence
    return cadence_ops.quantized_max_pool2d_nchw.default
808810

811+
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Fuse the pooling node at ``anchor_node`` through ``_fuse_max_pool2d``.

    The helper's result is returned unchanged, including ``None`` when it
    declines to fuse.
    """
    fused = _fuse_max_pool2d(gm, anchor_node)
    return fused
813+
814+
815+
def _fuse_max_pool2d(gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Shared fuse logic used by both MaxPool2d pattern variants.

    Args:
        gm: Graph module that owns ``anchor_node``.
        anchor_node: The pooling node to replace.

    Returns:
        The result of ``replace_with_op``, or ``None`` when the anchor's
        first argument is not a per-tensor dequantize node or no quantize
        user is found for the anchor.
    """
    source = anchor_node.args[0]
    if not (isinstance(source, fx.Node) and source.target == DQ_PER_TENSOR):
        return None

    q_user = find_quant_user(anchor_node)
    if q_user is None:
        return None

    # Pooling hyper-parameters are forwarded to the quantized op as kwargs.
    pool_kwargs = {
        name: get_arg(anchor_node, name, list[int])
        for name in ("kernel_size", "stride", "padding", "dilation")
    }
    pool_kwargs["ceil_mode"] = get_arg(anchor_node, "ceil_mode", bool)

    # The fused op consumes the tensor that fed the dequantize node.
    quantized_input = get_arg(source, "input", fx.Node)
    return replace_with_op(
        gm,
        anchor_node,
        torch.ops.cadence.quantized_max_pool2d_nchw.default,
        (quantized_input,),
        pool_kwargs,
        q_user,
    )
844+
809845

810846
class MaxPool2dWithoutIndicesPattern(QuantizationPattern):
811847
"""
@@ -845,8 +881,8 @@ def get_anchors(
845881
def replacement_op(self) -> OpOverload:
    """Return the Cadence op overload this pattern is replaced with."""
    cadence_ops = torch.ops.cadence
    return cadence_ops.quantized_max_pool2d_nchw.default
847883

848-
849-
# This is a base class for ReLU
884+
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Fuse the pooling node at ``anchor_node`` through ``_fuse_max_pool2d``.

    The helper's result is returned unchanged, including ``None`` when it
    declines to fuse.
    """
    fused = _fuse_max_pool2d(gm, anchor_node)
    return fused
850886

851887

852888
# This is a base class for ReLU, since it can be used with two different aten ops
@@ -874,6 +910,28 @@ def get_anchors(
874910
def replacement_op(self) -> OpOverload:
    """Return the Cadence op overload this pattern is replaced with."""
    cadence_ops = torch.ops.cadence
    return cadence_ops.quantized_relu.per_tensor
876912

913+
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Replace a dequantize -> relu -> quantize chain with the quantized relu op.

    The ratio of the dequantize scale to the quantize scale is converted by
    ``quantize_tensor_multiplier`` into the multiplier/shift pair that the
    replacement op receives.

    Args:
        gm: Graph module that owns ``anchor_node``.
        anchor_node: The relu node to replace.

    Returns:
        The result of ``replace_with_op``, or ``None`` when the anchor's
        first argument is not a per-tensor dequantize node or no quantize
        user is found for the anchor.
    """
    producer = anchor_node.args[0]
    if not (isinstance(producer, fx.Node) and producer.target == DQ_PER_TENSOR):
        return None

    q_user = find_quant_user(anchor_node)
    if q_user is None:
        return None

    # Fold input and output scales into one requantization factor.
    in_scale = get_arg(producer, "scale", float)
    out_scale = get_arg(q_user, "scale", float)
    ratio = torch.tensor([in_scale / out_scale])
    multiplier, shift = quantize_tensor_multiplier(ratio)

    quantized_input = get_arg(producer, "input", fx.Node)
    x_zero_point = get_arg(producer, "zero_point", int)
    out_zero_point = get_arg(q_user, "zero_point", int)
    relu_kwargs = {
        "X_zero_point": x_zero_point,
        "out_zero_point": out_zero_point,
        "out_multiplier": multiplier[0].item(),
        "out_shift": shift[0].item(),
    }
    return replace_with_op(
        gm,
        anchor_node,
        self.replacement_op(),
        (quantized_input,),
        relu_kwargs,
        q_user,
    )
934+
877935

878936
# Regular relu op
879937
class ReluPattern0(ReluBasePattern):
@@ -933,6 +991,39 @@ def get_anchors(
933991
def replacement_op(self) -> OpOverload:
    """Return the Cadence op overload this pattern is replaced with."""
    cadence_ops = torch.ops.cadence
    return cadence_ops.quantized_conv2d_nchw.per_tensor
935993

994+
def anchor_ops(self) -> tuple[OpOverload, ...]:
    """Return a one-element tuple holding only the first partition type.

    NOTE(review): the first partition type is presumably the conv op, so
    anchoring happens on the conv rather than the relu — confirm against
    ``partition_types``.
    """
    first_op = self.partition_types()[0]
    return (first_op,)
996+
997+
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Fuse the anchor node and its single relu user via ``fuse_conv``.

    Args:
        gm: Graph module that owns ``anchor_node``.
        anchor_node: The node matched by ``anchor_ops`` (the first
            partition type).

    Returns:
        The result of ``fuse_conv``, or ``None`` when the anchor does not
        have exactly one user, that user is not the second partition type,
        the input or weight is not a per-tensor dequantize node, or the
        relu has no quantize user.
    """

    def _as_dequant(candidate):
        # Keep the argument only when it is a per-tensor dequantize node.
        is_dq = isinstance(candidate, fx.Node) and candidate.target == DQ_PER_TENSOR
        return candidate if is_dq else None

    users = list(anchor_node.users)
    if len(users) != 1:
        return None
    (relu_node,) = users
    if relu_node.target != self.partition_types()[1]:
        return None

    dq_input = _as_dequant(anchor_node.args[0])
    dq_weight = _as_dequant(anchor_node.args[1])
    if dq_input is None or dq_weight is None:
        return None

    # The quantize node hangs off the relu, not off the anchor itself.
    quant_node = find_quant_user(relu_node)
    if quant_node is None:
        return None

    out_zero_point = get_arg(quant_node, "zero_point", int)
    out_dtype = get_arg(quant_node, "dtype", torch.dtype)
    check_out_zero_point_is_min_range(out_zero_point, out_dtype)

    return fuse_conv(self, gm, anchor_node, dq_input, dq_weight, quant_node)
1026+
9361027

9371028
# Conv1d + regular relu op fusion
9381029
class Conv1dReluPattern0(ConvReluBasePattern):
@@ -987,6 +1078,56 @@ def get_anchors(
9871078
def replacement_op(self) -> OpOverload:
    """Return the Cadence op overload this pattern is replaced with."""
    cadence_ops = torch.ops.cadence
    return cadence_ops.quantized_softmax.per_tensor
9891080

1081+
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Replace a dequantize -> softmax -> quantize chain with the quantized op.

    Two constant ``aten.full`` nodes are inserted for the replacement op: a
    zero-filled int32 mask whose last dimension is the quantize input's last
    dimension divided by 16, and a single-element int64 position tensor.

    Args:
        gm: Graph module that owns ``anchor_node``.
        anchor_node: The softmax node to replace.

    Returns:
        The result of ``replace_with_op``, or ``None`` when the anchor's
        first argument is not a per-tensor dequantize node, no quantize
        user is found, or ``get_shape`` yields nothing for the quantize
        node's input.

    Raises:
        AssertionError: If the last dimension is not divisible by 16.
    """
    producer = anchor_node.args[0]
    if not (isinstance(producer, fx.Node) and producer.target == DQ_PER_TENSOR):
        return None

    q_user = find_quant_user(anchor_node)
    if q_user is None:
        return None

    quantized_input = get_arg(producer, "input", fx.Node)
    shape = get_shape(gm, get_arg(q_user, "input", fx.Node))
    if not shape:
        return None

    mask_shape = list(shape)
    # Softmax mask is packed 16 elements per int32 word.
    assert (
        mask_shape[-1] % 16 == 0
    ), f"Softmax mask dimension must be divisible by 16, got {mask_shape[-1]}"
    mask_shape[-1] //= 16

    def _full(size, fill_value, dtype):
        # Both constants are inserted the same way; only shape/fill/dtype differ.
        return insert_node_with_meta(
            gm,
            torch.ops.aten.full.default,
            (size, fill_value),
            {"dtype": dtype},
            anchor_node,
            quantized_input,
        )

    mask_tensor = _full(mask_shape, 0.0, torch.int32)
    # Initial position for streaming softmax (unused, set to 0).
    pos_tensor = _full([1], 0, torch.int64)

    softmax_args = (
        quantized_input,
        mask_tensor,
        get_arg(anchor_node, "dim", int),
        0,
        pos_tensor,
        get_arg(producer, "scale", float),
        get_arg(producer, "zero_point", int),
        get_arg(q_user, "scale", float),
        get_arg(q_user, "zero_point", int),
    )
    return replace_with_op(
        gm, anchor_node, self.replacement_op(), softmax_args, {}, q_user
    )
1130+
9901131

9911132
class MixedW8A32LinearPattern(QuantizationPattern):
9921133
def partition_types(self) -> List[OpOverload]:
@@ -1041,6 +1182,36 @@ def get_anchors(
10411182
def replacement_op(self) -> OpOverload:
    """Return the Cadence op overload this pattern is replaced with."""
    cadence_ops = torch.ops.cadence
    return cadence_ops.quantized_w8a32_linear.default
10431184

1185+
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Replace a linear node with dequantized weight and bias by the w8a32 op.

    The activation is forwarded untouched; only the weight and bias
    dequantize nodes are folded into the replacement op's arguments.

    Args:
        gm: Graph module that owns ``anchor_node``.
        anchor_node: The linear node to replace.

    Returns:
        The result of ``replace_with_op``, or ``None`` when the node does
        not have exactly three positional args and no kwargs, or when the
        weight or bias is not a per-tensor dequantize node.
    """
    if len(anchor_node.args) != 3 or anchor_node.kwargs:
        return None

    activation, weight_arg, bias_arg = anchor_node.args
    for param in (weight_arg, bias_arg):
        if not (isinstance(param, fx.Node) and param.target == DQ_PER_TENSOR):
            return None
    assert isinstance(activation, fx.Node)

    linear_args = (
        activation,
        get_arg(weight_arg, "input", fx.Node),
        get_arg(weight_arg, "scale", float),
        get_arg(bias_arg, "input", fx.Node),
        get_arg(bias_arg, "scale", float),
    )
    # No output quantize node here: the anchor itself is passed last.
    return replace_with_op(
        gm, anchor_node, self.replacement_op(), linear_args, {}, anchor_node
    )
1214+
10441215

10451216
class MixedW8A32ConvPattern(QuantizationPattern):
10461217
def partition_types(self) -> List[OpOverload]:
@@ -1115,6 +1286,57 @@ def get_anchors(
11151286
def replacement_op(self) -> OpOverload:
    """Return the Cadence op overload this pattern is replaced with."""
    cadence_ops = torch.ops.cadence
    return cadence_ops.quantized_w8a32_conv.default
11171288

1289+
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Replace a conv node with dequantized weight and bias by the w8a32 op.

    The activation is permuted with ``[0, 2, 1]`` and the quantized weight
    with ``[2, 0, 1]`` before being handed to the replacement op.
    NOTE(review): presumably an NCL -> NLC layout change — confirm against
    the kernel.

    Args:
        gm: Graph module that owns ``anchor_node``.
        anchor_node: The conv node to replace.

    Returns:
        The result of ``replace_with_op``, or ``None`` when the node does
        not have exactly three positional args and no kwargs, or when the
        weight or bias is not a per-tensor dequantize node.

    Raises:
        AssertionError: If the activation is not an fx node, or stride,
            padding, dilation or groups differ from ``[1]``, ``[0]``,
            ``[1]`` and ``1`` respectively.
    """
    if len(anchor_node.args) != 3 or anchor_node.kwargs:
        return None

    activation, weight_arg, bias_arg = anchor_node.args
    for param in (weight_arg, bias_arg):
        if not (isinstance(param, fx.Node) and param.target == DQ_PER_TENSOR):
            return None
    assert isinstance(activation, fx.Node)

    # Only the trivial convolution configuration is accepted.
    for name, expected in (("stride", [1]), ("padding", [0]), ("dilation", [1])):
        assert get_arg(anchor_node, name, list[int]) == expected
    assert get_arg(anchor_node, "groups", int) == 1

    def _permuted(node, dims):
        # Insert a permute of ``node`` ahead of the anchor.
        return insert_node_with_meta(
            gm,
            torch.ops.aten.permute.default,
            (node, dims),
            None,
            anchor_node,
            node,
        )

    quantized_weight = get_arg(weight_arg, "input", fx.Node)
    permuted_activation = _permuted(activation, [0, 2, 1])
    permuted_weight = _permuted(quantized_weight, [2, 0, 1])

    conv_args = (
        permuted_activation,
        permuted_weight,
        get_arg(weight_arg, "scale", float),
        get_arg(bias_arg, "input", fx.Node),
        get_arg(bias_arg, "scale", float),
    )
    # No output quantize node here: the anchor itself is passed last.
    return replace_with_op(
        gm, anchor_node, self.replacement_op(), conv_args, {}, anchor_node
    )
1339+
11181340

11191341
class MixedW8A32GruPattern(QuantizationPattern):
11201342
def partition_types(self) -> List[OpOverload]:
@@ -1187,6 +1409,42 @@ def __init__(self, args, meta):
11871409
def replacement_op(self) -> OpOverload:
    """Return the Cadence op overload this pattern is replaced with."""
    cadence_ops = torch.ops.cadence
    return cadence_ops.quantized_w8a32_gru.default
11891411

1412+
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
1413+
if len(anchor_node.kwargs) > 0:
1414+
return None
1415+
params = anchor_node.args[2]
1416+
# GRU requires 4 weight/bias params: w_ih, w_hh, b_ih, b_hh
1417+
if not isinstance(params, (list, tuple)) or len(params) < 4:
1418+
return None
1419+
dq_w_ih = params[0]
1420+
if not isinstance(dq_w_ih, fx.Node) or dq_w_ih.target != DQ_PER_TENSOR:
1421+
return None
1422+
dq_w_hh = params[1]
1423+
if not isinstance(dq_w_hh, fx.Node) or dq_w_hh.target != DQ_PER_TENSOR:
1424+
return None
1425+
dq_b_ih = params[2]
1426+
if not isinstance(dq_b_ih, fx.Node) or dq_b_ih.target != DQ_PER_TENSOR:
1427+
return None
1428+
dq_b_hh = params[3]
1429+
if not isinstance(dq_b_hh, fx.Node) or dq_b_hh.target != DQ_PER_TENSOR:
1430+
return None
1431+
input_node = anchor_node.args[0]
1432+
hidden_node = anchor_node.args[1]
1433+
args = (
1434+
input_node,
1435+
hidden_node,
1436+
get_arg(dq_w_ih, "input", fx.Node),
1437+
get_arg(dq_w_ih, "scale", float),
1438+
get_arg(dq_w_hh, "input", fx.Node),
1439+
get_arg(dq_w_hh, "scale", float),
1440+
get_arg(dq_b_ih, "input", fx.Node),
1441+
get_arg(dq_b_ih, "scale", float),
1442+
get_arg(dq_b_hh, "input", fx.Node),
1443+
)
1444+
return replace_with_op(
1445+
gm, anchor_node, self.replacement_op(), args, {}, anchor_node
1446+
)
1447+
11901448

11911449
class RmsNormPattern(QuantizationPattern):
11921450
"""Pattern that preserves rms_norm from decomposition without matching anything."""

0 commit comments

Comments
 (0)