import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import (
    create_node,
    get_first_fake_tensor,
)
from executorch.backends.arm.operators.operator_validation_utils import (
    adjust_pooling_pad_if_needed,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from .fuse_constant_ops_pass import ComputeConstantOpsAOTPass
|
| 22 | +_NCHW_TO_NHWC = [0, 2, 3, 1] |
| 23 | +_NHWC_TO_NCHW = [0, 3, 1, 2] |
| 24 | + |
18 | 25 |
|
class RewriteAvgPool2dPass(ArmPass):
    """Rewrite ``aten.avg_pool2d`` calls to the TOSA AVG_POOL2D op in NHWC layout.

    For every targeted node this pass:
      1. inserts an NCHW -> NHWC ``permute_copy`` on the pooling input,
      2. materializes the input/output zero-points as ``aten.full`` constant
         nodes (so ``ComputeConstantOpsAOTPass`` can fold them AOT),
      3. emits ``tosa.AVG_POOL2D`` with TOSA-style padding and accumulator
         dtype, and
      4. inserts an NHWC -> NCHW ``permute_copy`` on the output before
         replacing all uses of the original node.
    """

    targeted_ops = {exir_ops.edge.aten.avg_pool2d.default}
    _passes_required_after: Set[Type[ExportPass]] = {
        ComputeConstantOpsAOTPass,
    }

    @staticmethod
    def _insert_permute(graph_module, anchor_node, input_node, perm, before=True):
        """Insert a ``permute_copy`` of ``input_node`` next to ``anchor_node``.

        Args:
            graph_module: FX graph module being mutated.
            anchor_node: Node the permute is inserted before/after.
            input_node: Node whose output is permuted.
            perm: Dimension permutation (e.g. ``_NCHW_TO_NHWC``).
            before: Insert before ``anchor_node`` when True, after otherwise.

        Returns:
            The newly created ``permute_copy`` node.
        """
        if before:
            insertion_ctx = graph_module.graph.inserting_before(anchor_node)
        else:
            insertion_ctx = graph_module.graph.inserting_after(anchor_node)
        with insertion_ctx:
            return create_node(
                graph=graph_module.graph,
                op_target=exir_ops.edge.aten.permute_copy.default,
                args=(input_node, perm),
                from_node=input_node,
            )

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Rewrite every targeted avg_pool2d node; return the modified graph."""
        modified = False

        # Iterate over a snapshot since we mutate the graph while walking it.
        for node in list(graph_module.graph.nodes):
            if node.op != "call_function" or node.target not in self.targeted_ops:
                continue

            modified = True
            x = node.args[0]

            pad_h, pad_w = node.args[3]
            # NOTE(review): TOSA pad order is [top, bottom, left, right];
            # initializing as [pad_h, pad_w, pad_h, pad_w] mixes h/w unless
            # pad_h == pad_w — confirm intended (kept from original code).
            pad = [pad_h, pad_w, pad_h, pad_w]

            # Shapes come from the fake tensor recorded during tracing;
            # assumes NCHW input — TODO confirm upstream guarantees rank 4.
            input_fake = get_first_fake_tensor(x)
            _, _, h, w = input_fake.shape
            kernel_h, kernel_w = node.args[1]
            stride_h, stride_w = node.args[2]

            # ceil_mode is an optional positional arg of aten.avg_pool2d.
            ceil_mode = node.args[4] if len(node.args) > 4 else False

            # Grow bottom/right padding if needed so the TOSA output size
            # matches what PyTorch computes (esp. with ceil_mode).
            pad[1] = adjust_pooling_pad_if_needed(
                h, kernel_h, stride_h, pad[1], ceil_mode
            )
            pad[3] = adjust_pooling_pad_if_needed(
                w, kernel_w, stride_w, pad[3], ceil_mode
            )

            # Per-tensor zero-points; default to 0 for non-quantized graphs.
            in_qparams = node.meta.get("input_qparams", {})
            in_zp_val = in_qparams[0].get_zp_per_tensor() if 0 in in_qparams else 0

            out_qparams = node.meta.get("output_qparams", {})
            out_zp_val = out_qparams[0].get_zp_per_tensor() if 0 in out_qparams else 0

            # Accumulator dtype for AVG_POOL2D: INT32 for integer inputs,
            # FP32 otherwise.
            if input_fake.dtype in (torch.int8, torch.int16):
                acc_type = torch.int32
            else:
                acc_type = torch.float32

            # Insert NCHW -> NHWC permute on the input.
            x_permuted = self._insert_permute(
                graph_module, node, x, _NCHW_TO_NHWC, before=True
            )

            # Materialize zp scalars as graph constants using aten.full with
            # explicit dtype matching the input tensor. This ensures the
            # pre-computed buffer placeholders carry the correct type for
            # INT-only TOSA profiles (avoids defaulting to float32).
            zp_kwargs = {"dtype": input_fake.dtype, "device": input_fake.device}
            with graph_module.graph.inserting_before(node):
                input_zp_node = create_node(
                    graph=graph_module.graph,
                    op_target=exir_ops.edge.aten.full.default,
                    args=((1,), in_zp_val),
                    kwargs=zp_kwargs,
                    from_node=node,
                )
                output_zp_node = create_node(
                    graph=graph_module.graph,
                    op_target=exir_ops.edge.aten.full.default,
                    args=((1,), out_zp_val),
                    kwargs=zp_kwargs,
                    from_node=node,
                )

            kernel = list(node.args[1])
            stride = list(node.args[2])

            tosa_args = (
                x_permuted,
                input_zp_node,
                output_zp_node,
                kernel,
                stride,
                pad,
                acc_type,
            )

            # Create the TOSA AVG_POOL2D node, inheriting qparams from the
            # original avg_pool2d node.
            with graph_module.graph.inserting_after(node):
                tosa_op = create_node(
                    graph=graph_module.graph,
                    op_target=exir_ops.backend.tosa.AVG_POOL2D.default,
                    args=tosa_args,
                    from_node=node,
                    inherit_qparams=True,
                )

            # Compute the NHWC FakeTensor for the new node's "val" meta by
            # running the backend op on permuted fake inputs.
            input_fake_nhwc = input_fake.permute(_NCHW_TO_NHWC)
            input_zp_fake = torch.tensor(in_zp_val, dtype=input_fake.dtype)
            output_zp_fake = torch.tensor(out_zp_val, dtype=input_fake.dtype)
            tosa_node_fake = exir_ops.backend.tosa.AVG_POOL2D.default(
                input_fake_nhwc,
                input_zp_fake,
                output_zp_fake,
                kernel,
                stride,
                pad,
                acc_type,
            )
            tosa_op.meta["val"] = tosa_node_fake

            # Insert NHWC -> NCHW permute on the output and splice it in
            # where the original avg_pool2d was consumed.
            output_permute = self._insert_permute(
                graph_module, tosa_op, tosa_op, _NHWC_TO_NCHW, before=False
            )

            node.replace_all_uses_with(output_permute)
            graph_module.graph.erase_node(node)

        if modified:
            graph_module.recompile()
            # Re-run the parent pass machinery (e.g. retracing) on the
            # mutated graph.
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, modified)