|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# pyre-strict |
| 8 | + |
| 9 | +import operator |
| 10 | +from typing import Any |
| 11 | + |
| 12 | +import torch |
| 13 | +from executorch.backends.cadence.aot.pass_utils import get_arg, replace_with_op |
| 14 | +from executorch.backends.cadence.aot.quantizer.utils import ( |
| 15 | + copy_node_metadata, |
| 16 | + create_zero_bias_int32, |
| 17 | + quantize_tensor_multiplier, |
| 18 | +) |
| 19 | +from executorch.backends.cadence.aot.utils import is_depthwise_conv |
| 20 | +from torch import fx |
| 21 | +from torch._ops import OpOverload |
| 22 | + |
| 23 | +DQ_PER_TENSOR: OpOverload = torch.ops.quantized_decomposed.dequantize_per_tensor.default |
| 24 | +Q_PER_TENSOR: OpOverload = torch.ops.quantized_decomposed.quantize_per_tensor.default |
| 25 | + |
| 26 | + |
| 27 | +def insert_node_with_meta( |
| 28 | + gm: fx.GraphModule, |
| 29 | + op: OpOverload, |
| 30 | + args: tuple[Any, ...], |
| 31 | + kwargs: dict[str, Any] | None, |
| 32 | + insert_before: fx.Node, |
| 33 | + like_node: fx.Node, |
| 34 | +) -> fx.Node: |
| 35 | + """Create a new node and populate its FakeTensor metadata. |
| 36 | +
|
| 37 | + Inserts ``op(*args, **kwargs)`` before ``insert_before``, runs the op |
| 38 | + under ``like_node``'s fake_mode to compute ``meta["val"]``, and copies |
| 39 | + remaining metadata from ``like_node``. |
| 40 | + """ |
| 41 | + with gm.graph.inserting_before(insert_before): |
| 42 | + node = gm.graph.call_function(op, args, kwargs or {}) |
| 43 | + assert "val" in like_node.meta |
| 44 | + fake_mode = like_node.meta["val"].fake_mode |
| 45 | + assert fake_mode is not None |
| 46 | + |
| 47 | + def _resolve(x: Any) -> Any: |
| 48 | + return x.meta["val"] if isinstance(x, fx.Node) else x |
| 49 | + |
| 50 | + fake_args = tuple(_resolve(a) for a in args) |
| 51 | + fake_kwargs = {k: _resolve(v) for k, v in (kwargs or {}).items()} |
| 52 | + with fake_mode: |
| 53 | + node.meta["val"] = op(*fake_args, **fake_kwargs) |
| 54 | + copy_node_metadata(node, like_node) |
| 55 | + return node |
| 56 | + |
| 57 | + |
| 58 | +def find_quant_user(node: fx.Node) -> fx.Node | None: |
| 59 | + """Find the first quantize_per_tensor user of ``node``, traversing through getitem.""" |
| 60 | + users = list(node.users) |
| 61 | + if not users: |
| 62 | + return None |
| 63 | + user = users[0] |
| 64 | + if user.target is operator.getitem: |
| 65 | + if user.args[1] == 0: |
| 66 | + users = list(user.users) |
| 67 | + if not users: |
| 68 | + return None |
| 69 | + user = users[0] |
| 70 | + else: |
| 71 | + return None |
| 72 | + if user.target == Q_PER_TENSOR: |
| 73 | + return user |
| 74 | + return None |
| 75 | + |
| 76 | + |
| 77 | +def fuse_conv( |
| 78 | + pattern: object, |
| 79 | + gm: fx.GraphModule, |
| 80 | + conv_node: fx.Node, |
| 81 | + dq_input: fx.Node, |
| 82 | + dq_weight: fx.Node, |
| 83 | + quant_node: fx.Node, |
| 84 | +) -> fx.Node: |
| 85 | + """Fuse a dq->conv->q chain into a single quantized conv op.""" |
| 86 | + dq_bias = None |
| 87 | + if len(conv_node.args) > 2 and conv_node.args[2] is not None: |
| 88 | + bias_arg = conv_node.args[2] |
| 89 | + assert isinstance(bias_arg, fx.Node) |
| 90 | + dq_bias = bias_arg if bias_arg.target == DQ_PER_TENSOR else None |
| 91 | + weight_scale = get_arg(dq_weight, "scale", float) |
| 92 | + input_scale = get_arg(dq_input, "scale", float) |
| 93 | + bias_scale = input_scale * weight_scale |
| 94 | + if dq_bias is not None: |
| 95 | + bias_q = get_arg(dq_bias, "input", fx.Node) |
| 96 | + else: |
| 97 | + # Cadence quantized conv ops require a non-optional bias argument. |
| 98 | + weight_node = get_arg(dq_weight, "input", fx.Node) |
| 99 | + with gm.graph.inserting_before(conv_node): |
| 100 | + bias_q = create_zero_bias_int32(gm, weight_node, bias_scale) |
| 101 | + requantize_scale = bias_scale / get_arg(quant_node, "scale", float) |
| 102 | + requantize_scale_t = torch.tensor([requantize_scale]) |
| 103 | + out_multiplier, out_shift = quantize_tensor_multiplier(requantize_scale_t) |
| 104 | + args = ( |
| 105 | + get_arg(dq_input, "input", fx.Node), |
| 106 | + get_arg(dq_weight, "input", fx.Node), |
| 107 | + bias_q, |
| 108 | + ) |
| 109 | + groups = get_arg(conv_node, "groups", int) |
| 110 | + kwargs = { |
| 111 | + "stride": get_arg(conv_node, "stride", list[int]), |
| 112 | + "padding": get_arg(conv_node, "padding", list[int]), |
| 113 | + "dilation": get_arg(conv_node, "dilation", list[int]), |
| 114 | + "groups": groups, |
| 115 | + "input_zero_point": get_arg(dq_input, "zero_point", int), |
| 116 | + "weight_zero_point": get_arg(dq_weight, "zero_point", int), |
| 117 | + "bias_scale": bias_scale, |
| 118 | + "out_scale": get_arg(quant_node, "scale", float), |
| 119 | + "out_zero_point": get_arg(quant_node, "zero_point", int), |
| 120 | + "out_multiplier": out_multiplier[0].item(), |
| 121 | + "out_shift": out_shift[0].item(), |
| 122 | + } |
| 123 | + replacement_op = pattern.replacement_op() # pyre-ignore[16] |
| 124 | + if replacement_op == torch.ops.cadence.quantized_conv1d_ncl.per_tensor: |
| 125 | + input_node = get_arg(dq_input, "input", fx.Node) |
| 126 | + assert len(input_node.meta["val"].shape) >= 2 |
| 127 | + in_channels = input_node.meta["val"].shape[1] |
| 128 | + if is_depthwise_conv(groups, in_channels): |
| 129 | + replacement_op = torch.ops.cadence.quantized_depthwise_conv1d_ncl.per_tensor |
| 130 | + return replace_with_op(gm, conv_node, replacement_op, args, kwargs, quant_node) |
| 131 | + |
| 132 | + |
| 133 | +def fuse_linear( |
| 134 | + gm: fx.GraphModule, |
| 135 | + dq_input: fx.Node, |
| 136 | + dq_weight: fx.Node, |
| 137 | + dq_bias: fx.Node | None, |
| 138 | + quant_node: fx.Node, |
| 139 | + op_node: fx.Node, |
| 140 | + replacement_op: OpOverload, |
| 141 | + weight_q: fx.Node | None = None, |
| 142 | +) -> fx.Node: |
| 143 | + """Fuse a dq->linear->q chain into a single quantized linear op.""" |
| 144 | + assert op_node.target in ( |
| 145 | + torch.ops.aten.linear.default, |
| 146 | + torch.ops.aten.addmm.default, |
| 147 | + ), f"Expected linear/addmm, got {op_node.target}" |
| 148 | + weight_scale = get_arg(dq_weight, "scale", float) |
| 149 | + input_scale = get_arg(dq_input, "scale", float) |
| 150 | + bias_scale = input_scale * weight_scale |
| 151 | + requantize_scale = bias_scale / get_arg(quant_node, "scale", float) |
| 152 | + requantize_scale_t = torch.tensor([requantize_scale]) |
| 153 | + out_multiplier, out_shift = quantize_tensor_multiplier(requantize_scale_t) |
| 154 | + if dq_bias is not None: |
| 155 | + bias_q = get_arg(dq_bias, "input", fx.Node) |
| 156 | + else: |
| 157 | + # Cadence quantized linear ops require a non-optional bias argument. |
| 158 | + weight_node = get_arg(dq_weight, "input", fx.Node) |
| 159 | + with gm.graph.inserting_before(op_node): |
| 160 | + bias_q = create_zero_bias_int32(gm, weight_node, bias_scale) |
| 161 | + final_weight = ( |
| 162 | + weight_q if weight_q is not None else get_arg(dq_weight, "input", fx.Node) |
| 163 | + ) |
| 164 | + args = (get_arg(dq_input, "input", fx.Node), final_weight, bias_q) |
| 165 | + kwargs = { |
| 166 | + "src_zero_point": get_arg(dq_input, "zero_point", int), |
| 167 | + "weight_zero_point": get_arg(dq_weight, "zero_point", int), |
| 168 | + "out_multiplier": out_multiplier[0].item(), |
| 169 | + "out_shift": out_shift[0].item(), |
| 170 | + "out_zero_point": get_arg(quant_node, "zero_point", int), |
| 171 | + "offset": None, |
| 172 | + } |
| 173 | + return replace_with_op(gm, op_node, replacement_op, args, kwargs, quant_node) |
| 174 | + |
| 175 | + |
| 176 | +def fuse_matmul( |
| 177 | + gm: fx.GraphModule, |
| 178 | + anchor_node: fx.Node, |
| 179 | + dq0: fx.Node, |
| 180 | + dq1: fx.Node, |
| 181 | + quant_node: fx.Node, |
| 182 | + replacement_op: OpOverload, |
| 183 | +) -> fx.Node: |
| 184 | + """Fuse a dq->matmul->q chain into a single quantized matmul op.""" |
| 185 | + assert anchor_node.target in ( |
| 186 | + torch.ops.aten.bmm.default, |
| 187 | + torch.ops.aten.matmul.default, |
| 188 | + ), f"Expected bmm/matmul, got {anchor_node.target}" |
| 189 | + scale0 = get_arg(dq0, "scale", float) |
| 190 | + scale1 = get_arg(dq1, "scale", float) |
| 191 | + requantize_scale = (scale0 * scale1) / get_arg(quant_node, "scale", float) |
| 192 | + requantize_scale_t = torch.tensor([requantize_scale]) |
| 193 | + out_multiplier, out_shift = quantize_tensor_multiplier(requantize_scale_t) |
| 194 | + args = ( |
| 195 | + get_arg(dq0, "input", fx.Node), |
| 196 | + get_arg(dq0, "zero_point", int), |
| 197 | + get_arg(dq1, "input", fx.Node), |
| 198 | + get_arg(dq1, "zero_point", int), |
| 199 | + None, |
| 200 | + ) |
| 201 | + kwargs = { |
| 202 | + "out_multiplier": out_multiplier[0].item(), |
| 203 | + "out_shift": out_shift[0].item(), |
| 204 | + "out_zero_point": get_arg(quant_node, "zero_point", int), |
| 205 | + "transposed": False, |
| 206 | + } |
| 207 | + return replace_with_op(gm, anchor_node, replacement_op, args, kwargs, quant_node) |
0 commit comments