|
12 | 12 | from typing import List, Optional, Tuple, Union |
13 | 13 |
|
14 | 14 | import torch |
| 15 | +from executorch.backends.cadence.aot.compiler_utils import get_shape |
15 | 16 | from executorch.backends.cadence.aot.pass_utils import get_arg, replace_with_op |
16 | 17 | from executorch.backends.cadence.aot.quantizer.pattern_utils import ( |
17 | 18 | DQ_PER_TENSOR, |
|
24 | 25 | from executorch.backends.cadence.aot.quantizer.utils import ( |
25 | 26 | check_out_zero_point_is_min_range, |
26 | 27 | get_bias_qparams, |
| 28 | + quantize_tensor_multiplier, |
27 | 29 | ) |
28 | 30 | from torch import fx |
29 | 31 | from torch._ops import OpOverload |
@@ -806,6 +808,40 @@ def get_anchors( |
def replacement_op(self) -> OpOverload:
    """Return the ``cadence.quantized_max_pool2d_nchw`` default overload."""
    fused_op = torch.ops.cadence.quantized_max_pool2d_nchw.default
    return fused_op
808 | 810 |
|
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Delegate the rewrite of ``anchor_node`` to ``_fuse_max_pool2d``.

    Returns whatever that helper returns (a node, or ``None``).
    """
    fused = _fuse_max_pool2d(gm=gm, anchor_node=anchor_node)
    return fused
| 813 | + |
| 814 | + |
def _fuse_max_pool2d(gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Replace a max-pool node fed by a per-tensor dequantize with the quantized op.

    Shared by both MaxPool2d pattern variants.

    Args:
        gm: Graph module that owns ``anchor_node``.
        anchor_node: The pooling node to rewrite. Its first positional
            argument must be a ``DQ_PER_TENSOR`` node for the rewrite to apply.

    Returns:
        The result of ``replace_with_op`` for the new
        ``cadence.quantized_max_pool2d_nchw`` call, or ``None`` when the first
        argument is not a per-tensor dequantize node or ``find_quant_user``
        yields nothing for ``anchor_node``.
    """
    source = anchor_node.args[0]
    if not (isinstance(source, fx.Node) and source.target == DQ_PER_TENSOR):
        return None

    quant_node = find_quant_user(anchor_node)
    if quant_node is None:
        return None

    # The four geometry arguments are read as int lists; ceil_mode is the
    # only boolean. Insertion order matches the order they are read in.
    pool_kwargs = {
        name: get_arg(anchor_node, name, list[int])
        for name in ("kernel_size", "stride", "padding", "dilation")
    }
    pool_kwargs["ceil_mode"] = get_arg(anchor_node, "ceil_mode", bool)

    # The fused op consumes the dequantize node's own input directly.
    pool_args = (get_arg(source, "input", fx.Node),)
    return replace_with_op(
        gm,
        anchor_node,
        torch.ops.cadence.quantized_max_pool2d_nchw.default,
        pool_args,
        pool_kwargs,
        quant_node,
    )
| 844 | + |
809 | 845 |
|
810 | 846 | class MaxPool2dWithoutIndicesPattern(QuantizationPattern): |
811 | 847 | """ |
@@ -845,8 +881,8 @@ def get_anchors( |
def replacement_op(self) -> OpOverload:
    """Return the ``cadence.quantized_max_pool2d_nchw`` default overload."""
    fused_op = torch.ops.cadence.quantized_max_pool2d_nchw.default
    return fused_op
847 | 883 |
|
848 | | - |
849 | | -# This is a base class for ReLU |
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Delegate the rewrite of ``anchor_node`` to ``_fuse_max_pool2d``.

    Returns whatever that helper returns (a node, or ``None``).
    """
    fused = _fuse_max_pool2d(gm=gm, anchor_node=anchor_node)
    return fused
850 | 886 |
|
851 | 887 |
|
852 | 888 | # This is a base class for ReLU, since it can be used with two different aten ops |
@@ -874,6 +910,28 @@ def get_anchors( |
def replacement_op(self) -> OpOverload:
    """Return the ``cadence.quantized_relu`` per-tensor overload."""
    fused_op = torch.ops.cadence.quantized_relu.per_tensor
    return fused_op
876 | 912 |
|
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Replace a relu fed by a per-tensor dequantize with ``replacement_op()``.

    Args:
        gm: Graph module that owns ``anchor_node``.
        anchor_node: The relu node to rewrite.

    Returns:
        The result of ``replace_with_op``, or ``None`` when the first
        argument is not a ``DQ_PER_TENSOR`` node or ``find_quant_user``
        yields nothing for ``anchor_node``.
    """
    dequant = anchor_node.args[0]
    if not (isinstance(dequant, fx.Node) and dequant.target == DQ_PER_TENSOR):
        return None

    quant_node = find_quant_user(anchor_node)
    if quant_node is None:
        return None

    # Requantization factor is input scale over output scale, then handed to
    # quantize_tensor_multiplier to obtain a (multiplier, shift) pair.
    in_scale = get_arg(dequant, "scale", float)
    out_scale = get_arg(quant_node, "scale", float)
    ratio = torch.tensor([in_scale / out_scale])
    multipliers, shifts = quantize_tensor_multiplier(ratio)

    relu_args = (get_arg(dequant, "input", fx.Node),)
    relu_kwargs = {
        "X_zero_point": get_arg(dequant, "zero_point", int),
        "out_zero_point": get_arg(quant_node, "zero_point", int),
        "out_multiplier": multipliers[0].item(),
        "out_shift": shifts[0].item(),
    }
    target_op = self.replacement_op()
    return replace_with_op(
        gm, anchor_node, target_op, relu_args, relu_kwargs, quant_node
    )
| 934 | + |
877 | 935 |
|
878 | 936 | # Regular relu op |
879 | 937 | class ReluPattern0(ReluBasePattern): |
@@ -933,6 +991,39 @@ def get_anchors( |
def replacement_op(self) -> OpOverload:
    """Return the ``cadence.quantized_conv2d_nchw`` per-tensor overload."""
    fused_op = torch.ops.cadence.quantized_conv2d_nchw.per_tensor
    return fused_op
935 | 993 |
|
def anchor_ops(self) -> tuple[OpOverload, ...]:
    """Return a one-element tuple holding the first entry of ``partition_types()``."""
    first_op = self.partition_types()[0]
    return (first_op,)
| 996 | + |
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Fuse a conv node, its single relu user and the surrounding q/dq nodes.

    Args:
        gm: Graph module that owns ``anchor_node``.
        anchor_node: The conv node (first entry of ``partition_types()``).

    Returns:
        The result of ``fuse_conv``, or ``None`` when the conv does not have
        exactly one user, that user's target is not ``partition_types()[1]``,
        either of the first two positional arguments is not a
        ``DQ_PER_TENSOR`` node, or ``find_quant_user`` yields nothing for the
        relu node.
    """
    users = list(anchor_node.users)
    if len(users) != 1:
        return None
    (relu_node,) = users
    if relu_node.target != self.partition_types()[1]:
        return None

    # Read both operands before validating either one, so a short argument
    # list fails the same way regardless of what the first operand is.
    dq_input, dq_weight = anchor_node.args[0], anchor_node.args[1]
    if not (isinstance(dq_input, fx.Node) and dq_input.target == DQ_PER_TENSOR):
        return None
    if not (isinstance(dq_weight, fx.Node) and dq_weight.target == DQ_PER_TENSOR):
        return None

    # The quantize node is looked up on the relu, not on the conv.
    quant_node = find_quant_user(relu_node)
    if quant_node is None:
        return None

    out_zero_point = get_arg(quant_node, "zero_point", int)
    out_dtype = get_arg(quant_node, "dtype", torch.dtype)
    check_out_zero_point_is_min_range(out_zero_point, out_dtype)
    return fuse_conv(self, gm, anchor_node, dq_input, dq_weight, quant_node)
| 1026 | + |
936 | 1027 |
|
937 | 1028 | # Conv1d + regular relu op fusion |
938 | 1029 | class Conv1dReluPattern0(ConvReluBasePattern): |
@@ -987,6 +1078,56 @@ def get_anchors( |
def replacement_op(self) -> OpOverload:
    """Return the ``cadence.quantized_softmax`` per-tensor overload."""
    fused_op = torch.ops.cadence.quantized_softmax.per_tensor
    return fused_op
989 | 1080 |
|
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Replace a softmax fed by a per-tensor dequantize with ``replacement_op()``.

    Two helper ``aten.full`` nodes are inserted for the fused op: an int32
    mask whose last dimension is the quantize input's last dimension divided
    by 16, and a one-element int64 position tensor filled with 0.

    Args:
        gm: Graph module that owns ``anchor_node``.
        anchor_node: The softmax node to rewrite.

    Returns:
        The result of ``replace_with_op``, or ``None`` when the first
        argument is not a ``DQ_PER_TENSOR`` node, ``find_quant_user`` yields
        nothing, or ``get_shape`` returns an empty/falsy shape.

    Raises:
        AssertionError: If the last dimension of the shape is not a multiple
            of 16.
    """
    dq_input = anchor_node.args[0]
    if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR:
        return None
    quant_node = find_quant_user(anchor_node)
    if quant_node is None:
        return None
    input_q = get_arg(dq_input, "input", fx.Node)
    quant_input = get_arg(quant_node, "input", fx.Node)
    mask_shape = get_shape(gm, quant_input)
    if not mask_shape:
        return None
    mask_shape = list(mask_shape)
    # Softmax mask is packed 16 elements per int32 word.
    elements_per_word = 16
    # Raised explicitly rather than via `assert`: under `python -O` an assert
    # is removed and the floor division below would silently truncate the
    # mask. The exception type and message are unchanged.
    if mask_shape[-1] % elements_per_word != 0:
        raise AssertionError(
            f"Softmax mask dimension must be divisible by 16, got {mask_shape[-1]}"
        )
    mask_shape[-1] = mask_shape[-1] // elements_per_word
    mask_tensor = insert_node_with_meta(
        gm,
        torch.ops.aten.full.default,
        (mask_shape, 0.0),
        {"dtype": torch.int32},
        anchor_node,
        input_q,
    )
    # Initial position for streaming softmax (unused, set to 0).
    pos_tensor = insert_node_with_meta(
        gm,
        torch.ops.aten.full.default,
        ([1], 0),
        {"dtype": torch.int64},
        anchor_node,
        input_q,
    )
    args = (
        input_q,
        mask_tensor,
        get_arg(anchor_node, "dim", int),
        # NOTE(review): meaning of this literal 0 is fixed by the op schema,
        # which is not visible here; confirm against the op definition.
        0,
        pos_tensor,
        get_arg(dq_input, "scale", float),
        get_arg(dq_input, "zero_point", int),
        get_arg(quant_node, "scale", float),
        get_arg(quant_node, "zero_point", int),
    )
    return replace_with_op(
        gm, anchor_node, self.replacement_op(), args, {}, quant_node
    )
| 1130 | + |
990 | 1131 |
|
991 | 1132 | class MixedW8A32LinearPattern(QuantizationPattern): |
992 | 1133 | def partition_types(self) -> List[OpOverload]: |
@@ -1041,6 +1182,36 @@ def get_anchors( |
def replacement_op(self) -> OpOverload:
    """Return the ``cadence.quantized_w8a32_linear`` default overload."""
    fused_op = torch.ops.cadence.quantized_w8a32_linear.default
    return fused_op
1043 | 1184 |
|
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Replace a linear node with dequantized weight and bias by ``replacement_op()``.

    The input (first positional argument) is passed through as-is; only the
    weight and bias come from ``DQ_PER_TENSOR`` nodes.

    Args:
        gm: Graph module that owns ``anchor_node``.
        anchor_node: The linear node; must have exactly three positional
            arguments and no keyword arguments.

    Returns:
        The result of ``replace_with_op``, or ``None`` when the argument
        layout differs or the weight/bias is not a ``DQ_PER_TENSOR`` node.
    """
    if anchor_node.kwargs or len(anchor_node.args) != 3:
        return None

    input_node, dq_weight, dq_bias = anchor_node.args
    if not (isinstance(dq_weight, fx.Node) and dq_weight.target == DQ_PER_TENSOR):
        return None
    if not (isinstance(dq_bias, fx.Node) and dq_bias.target == DQ_PER_TENSOR):
        return None

    # Narrowing assert for the type checker.
    assert isinstance(input_node, fx.Node)

    linear_args = (
        input_node,
        get_arg(dq_weight, "input", fx.Node),
        get_arg(dq_weight, "scale", float),
        get_arg(dq_bias, "input", fx.Node),
        get_arg(dq_bias, "scale", float),
    )
    target_op = self.replacement_op()
    return replace_with_op(gm, anchor_node, target_op, linear_args, {}, anchor_node)
| 1214 | + |
1044 | 1215 |
|
1045 | 1216 | class MixedW8A32ConvPattern(QuantizationPattern): |
1046 | 1217 | def partition_types(self) -> List[OpOverload]: |
@@ -1115,6 +1286,57 @@ def get_anchors( |
def replacement_op(self) -> OpOverload:
    """Return the ``cadence.quantized_w8a32_conv`` default overload."""
    fused_op = torch.ops.cadence.quantized_w8a32_conv.default
    return fused_op
1117 | 1288 |
|
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Replace a conv node with dequantized weight and bias by ``replacement_op()``.

    The input and the quantized weight are each routed through an inserted
    ``aten.permute`` node (``[0, 2, 1]`` and ``[2, 0, 1]`` respectively)
    before being handed to the fused op.

    Args:
        gm: Graph module that owns ``anchor_node``.
        anchor_node: The conv node; must have exactly three positional
            arguments and no keyword arguments.

    Returns:
        The result of ``replace_with_op``, or ``None`` when the argument
        layout differs or the weight/bias is not a ``DQ_PER_TENSOR`` node.

    Raises:
        AssertionError: If the input is not an ``fx.Node``, or if stride,
            padding, dilation or groups differ from ``[1]``, ``[0]``, ``[1]``
            and ``1``.
    """
    if len(anchor_node.args) != 3 or len(anchor_node.kwargs) > 0:
        return None
    _arg1 = anchor_node.args[1]
    dq_weight = (
        _arg1
        if isinstance(_arg1, fx.Node) and _arg1.target == DQ_PER_TENSOR
        else None
    )
    _arg2 = anchor_node.args[2]
    dq_bias = (
        _arg2
        if isinstance(_arg2, fx.Node) and _arg2.target == DQ_PER_TENSOR
        else None
    )
    if dq_weight is None or dq_bias is None:
        return None

    # These checks are raised explicitly instead of using `assert`, which is
    # removed under `python -O` and would let an unsupported conv be fused
    # silently. AssertionError is kept so the exception type is unchanged.
    input_node = anchor_node.args[0]
    if not isinstance(input_node, fx.Node):
        raise AssertionError(
            f"expected conv input to be an fx.Node, got {type(input_node).__name__}"
        )
    required_settings = (
        ("stride", list[int], [1]),
        ("padding", list[int], [0]),
        ("dilation", list[int], [1]),
        ("groups", int, 1),
    )
    for name, expected_type, required in required_settings:
        actual = get_arg(anchor_node, name, expected_type)
        if actual != required:
            raise AssertionError(
                f"quantized_w8a32_conv fusion requires {name}={required}, got {actual}"
            )

    weight_q = get_arg(dq_weight, "input", fx.Node)
    # NOTE(review): the permutations presumably convert to the layout the
    # fused kernel expects; confirm against the op definition.
    transposed_inputs = insert_node_with_meta(
        gm,
        torch.ops.aten.permute.default,
        (input_node, [0, 2, 1]),
        None,
        anchor_node,
        input_node,
    )
    transposed_weights = insert_node_with_meta(
        gm,
        torch.ops.aten.permute.default,
        (weight_q, [2, 0, 1]),
        None,
        anchor_node,
        weight_q,
    )
    args = (
        transposed_inputs,
        transposed_weights,
        get_arg(dq_weight, "scale", float),
        get_arg(dq_bias, "input", fx.Node),
        get_arg(dq_bias, "scale", float),
    )
    return replace_with_op(
        gm, anchor_node, self.replacement_op(), args, {}, anchor_node
    )
| 1339 | + |
1118 | 1340 |
|
1119 | 1341 | class MixedW8A32GruPattern(QuantizationPattern): |
1120 | 1342 | def partition_types(self) -> List[OpOverload]: |
@@ -1187,6 +1409,42 @@ def __init__(self, args, meta): |
def replacement_op(self) -> OpOverload:
    """Return the ``cadence.quantized_w8a32_gru`` default overload."""
    fused_op = torch.ops.cadence.quantized_w8a32_gru.default
    return fused_op
1189 | 1411 |
|
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Replace a GRU node with dequantized weights and biases by ``replacement_op()``.

    The fused op receives the input and hidden-state arguments unchanged,
    followed by a ``(quantized tensor, scale)`` pair for each of the four
    parameters, in the order w_ih, w_hh, b_ih, b_hh.

    Args:
        gm: Graph module that owns ``anchor_node``.
        anchor_node: The GRU node; its third positional argument must be a
            list or tuple whose first four entries are ``DQ_PER_TENSOR`` nodes.

    Returns:
        The result of ``replace_with_op``, or ``None`` when the node has
        keyword arguments, fewer than three positional arguments, or its
        parameter list does not match the expected layout.
    """
    if len(anchor_node.kwargs) > 0:
        return None
    # Without a third positional argument there is no parameter list to
    # inspect; treat that as "pattern does not apply" instead of raising.
    if len(anchor_node.args) < 3:
        return None
    params = anchor_node.args[2]
    # GRU requires 4 weight/bias params: w_ih, w_hh, b_ih, b_hh
    if not isinstance(params, (list, tuple)) or len(params) < 4:
        return None

    dequant_nodes: list[fx.Node] = []
    for candidate in params[:4]:
        if not isinstance(candidate, fx.Node) or candidate.target != DQ_PER_TENSOR:
            return None
        dequant_nodes.append(candidate)

    input_node = anchor_node.args[0]
    hidden_node = anchor_node.args[1]
    fused_args = [input_node, hidden_node]
    # Every parameter contributes both its quantized tensor and its scale.
    # The b_hh scale was previously omitted, leaving the fused call one
    # argument short compared with the other three parameters.
    for dequant in dequant_nodes:
        fused_args.append(get_arg(dequant, "input", fx.Node))
        fused_args.append(get_arg(dequant, "scale", float))

    return replace_with_op(
        gm, anchor_node, self.replacement_op(), tuple(fused_args), {}, anchor_node
    )
| 1447 | + |
1190 | 1448 |
|
1191 | 1449 | class RmsNormPattern(QuantizationPattern): |
1192 | 1450 | """Pattern that preserves rms_norm from decomposition without matching anything.""" |
|
0 commit comments