|
4 | 4 | # LICENSE file in the root directory of this source tree. |
5 | 5 |
|
6 | 6 |
|
7 | | -from typing import Set, Type |
| 7 | +from typing import Any, Set, Type |
8 | 8 |
|
9 | 9 | import torch |
10 | 10 | from executorch.backends.arm._passes.arm_pass import ArmPass |
|
17 | 17 | from executorch.exir.dialects._ops import ops as exir_ops |
18 | 18 | from executorch.exir.pass_base import ExportPass |
19 | 19 |
|
# avg_pool2d targets (edge dialect and ATen dialect) that this file decomposes.
edge_avg_pool2d = (exir_ops.edge.aten.avg_pool2d.default,)
aten_avg_pool2d = (torch.ops.aten.avg_pool2d.default,)
22 | 22 |
|
23 | 23 |
|
def get_decomposition(op) -> tuple:
    """Map an avg_pool2d op to the (pad, avg_pool2d, mul) ops of its dialect.

    The returned triple comes from the same dialect as *op* (edge or ATen)
    and is used to decompose avg_pool2d into explicit padding, an unpadded
    pooling op, and an optional rescaling multiply.

    Raises:
        RuntimeError: if *op* is not a recognized avg_pool2d target.
    """
    if op in edge_avg_pool2d:
        ns = exir_ops.edge.aten
        return (
            ns.constant_pad_nd.default,
            ns.avg_pool2d.default,
            ns.mul.Tensor,
        )

    if op in aten_avg_pool2d:
        ns = torch.ops.aten
        return (
            ns.pad.default,
            ns.avg_pool2d.default,
            ns.mul.Tensor,
        )

    raise RuntimeError(f"Can't get avg_pool2d decomposition for op {op}")
40 | 38 |
|
41 | 39 |
|
| 40 | +def _compute_post_pad( |
| 41 | + size: int, |
| 42 | + kernel: int, |
| 43 | + stride: int, |
| 44 | + pad: int, |
| 45 | + ceil_mode: bool, |
| 46 | + divisor_override, |
| 47 | +) -> int: |
| 48 | + |
| 49 | + if pad == 0: |
| 50 | + return pad |
| 51 | + if ceil_mode and divisor_override is None: |
| 52 | + return pad |
| 53 | + |
| 54 | + pad_adjust = adjust_pooling_pad_if_needed(size, kernel, stride, pad, ceil_mode) |
| 55 | + |
| 56 | + # Padding must always be above 0, the above adjustment may return -1 |
| 57 | + if pad_adjust > 0: |
| 58 | + return pad_adjust |
| 59 | + return pad |
| 60 | + |
| 61 | + |
| 62 | +def _get_avgpool_post_pad( |
| 63 | + h, |
| 64 | + w, |
| 65 | + kernel: tuple, |
| 66 | + stride_h, |
| 67 | + stride_w, |
| 68 | + pad_h, |
| 69 | + pad_w, |
| 70 | + ceil_mode, |
| 71 | + count_include_pad, |
| 72 | + divisor_override, |
| 73 | +) -> tuple[list[Any], list[int]]: |
| 74 | + """Compute the post-padding configuration for avg_pool2d when pre- |
| 75 | + materializing explicit zero padding ahead of the pooling operation. |
| 76 | +
|
| 77 | + Given the original spatial dimensions (h, w), pooling kernel size, stride, |
| 78 | + and explicit pre-padding amounts (pad_h, pad_w), this function returns the |
| 79 | + additional padding to apply on the right and bottom edges so that avg_pool2d |
| 80 | + with count_include_pad and/or divisor_override produces the equivalent |
| 81 | + result without built-in padding. |
| 82 | +
|
| 83 | + """ |
| 84 | + |
| 85 | + k_h, k_w = kernel |
| 86 | + post_h, post_w = (0, 0) |
| 87 | + new_pad_h, new_pad_w = pad_h, pad_w |
| 88 | + |
| 89 | + if not count_include_pad: |
| 90 | + return [new_pad_h, new_pad_w], [new_pad_h, new_pad_w] |
| 91 | + |
| 92 | + post_h = _compute_post_pad(h, k_h, stride_h, pad_h, ceil_mode, divisor_override) |
| 93 | + post_w = _compute_post_pad(w, k_w, stride_w, pad_w, ceil_mode, divisor_override) |
| 94 | + |
| 95 | + # Return our pre-padding calculation. Turn off built-in padding. |
| 96 | + return [pad_w, post_w, pad_h, post_h], [0, 0] |
| 97 | + |
| 98 | + |
class DecomposeAvgPool2dPass(ArmPass):
    """Decompose avg_pool2d so padded zeros take part in the average explicitly.

    When count_include_pad (or a divisor_override) requires the padding to be
    included in the averaging, the input is pre-padded with an explicit pad
    op, avg_pool2d then runs with built-in padding disabled, and a final mul
    rescales the result when divisor_override differs from the kernel size.
    """

    _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass}

    def call_operator(self, op, args, kwargs, meta):
        """Rewrite avg_pool2d nodes; any other op passes through unchanged.

        args follows the avg_pool2d schema:
        (input, kernel_size, stride?, padding?, ceil_mode?,
         count_include_pad?, divisor_override?).
        """
        if op not in (
            edge_avg_pool2d + aten_avg_pool2d
        ) or not self.allowed_to_transform(meta):
            return super().call_operator(op, args, kwargs, meta)

        pad_op, avgpool_op, mul_op = get_decomposition(op)

        x = args[0]
        kernel_h, kernel_w = args[1]
        kernel_size = kernel_h * kernel_w

        # Optional args fall back to the avg_pool2d schema defaults.
        if len(args) > 2 and args[2] is not None:
            stride_h, stride_w = args[2]
        else:
            stride_h, stride_w = kernel_h, kernel_w
        pad_h, pad_w = args[3] if len(args) > 3 else (0, 0)
        ceil_mode = args[4] if len(args) > 4 else False
        count_include_pad = args[5] if len(args) > 5 else True
        divisor_override = args[6] if len(args) > 6 else None

        # Only the spatial dims are needed; batch/channel dims were
        # previously bound to unused names.
        _, _, h, w = x.data.shape

        # count_include_pad == False means a different divisor is used for
        # edge elements. When divisor_override is set that divisor is
        # overridden anyway, and a constant divisor is easier to replace,
        # so force count_include_pad == True in that case.
        if divisor_override is not None:
            count_include_pad = True

        pad, new_pad = _get_avgpool_post_pad(
            h,
            w,
            args[1],
            stride_h,
            stride_w,
            pad_h,
            pad_w,
            ceil_mode,
            count_include_pad,
            divisor_override,
        )

        # Materialize padding explicitly so the zeros are always included in
        # the average; built-in padding is disabled via new_pad.
        if count_include_pad and (pad_h > 0 or pad_w > 0):
            if op in aten_avg_pool2d:
                # torch.ops.aten.pad.default: (input, pad, mode, value)
                pad_args = (x, pad, "constant", 0.0)
            else:
                # constant_pad_nd: (input, pad, value)
                pad_args = (x, pad, 0.0)

            x = super().call_operator(
                pad_op,
                pad_args,
                {},
                meta,
                updated=True,
            )

        avgpool_args = (
            x,
            args[1],
            [stride_h, stride_w],
            new_pad,
            ceil_mode,
            False,  # count_include_pad handled by the explicit padding above
        )

        x = super().call_operator(avgpool_op, avgpool_args, kwargs, meta, updated=True)

        # avg_pool divides by kernel_size; multiply by
        # kernel_size / divisor_override to emulate the override divisor.
        if divisor_override is not None and divisor_override != kernel_size:
            x = super().call_operator(
                mul_op,
                (x, super().call_scalar(kernel_size / divisor_override, meta)),
                {},
                meta,
                updated=True,
            )

        return x
0 commit comments