Skip to content

Commit 68c038b

Browse files
mcremon-meta authored and facebook-github-bot committed
Replace AVG_POOL2D with REDUCE_SUM in DecomposeMeanDimPass (#19242)
Summary: Replace the avg_pool2d decomposition path in DecomposeMeanDimPass with REDUCE_SUM + MUL(1/N) for all mean.dim reductions. AVG_POOL2D can only pool over spatial (H×W) axes in TOSA/NHWC layout, which forces the compiler to insert TRANSPOSE ops when the reduction is over channels (common in LayerNorm). REDUCE_SUM works on any axis without layout constraints, avoiding those transposes entirely. Reviewed By: 3l1 Differential Revision: D101418199
1 parent a3dd0fa commit 68c038b

5 files changed

Lines changed: 14 additions & 79 deletions

File tree

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 7 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
)
1717
from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass
1818
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
19-
from executorch.exir.backend.utils import WhyNoPartitionReporter
2019
from executorch.exir.dialects._ops import ops as exir_ops
2120
from executorch.exir.pass_base import ExportPass
2221

@@ -51,14 +50,6 @@ def get_dynamic_meandim_decomposition(op) -> tuple:
5150
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
5251

5352

54-
def get_avgpool(op):
55-
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
56-
return exir_ops.edge.aten.avg_pool2d.default
57-
if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default):
58-
return torch.ops.aten.avg_pool2d.default
59-
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
60-
61-
6253
def get_view(op):
6354
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
6455
return exir_ops.edge.aten.view_copy.default
@@ -79,23 +70,21 @@ def get_quantization(op):
7970

8071

8172
class DecomposeMeanDimPass(ArmPass):
82-
"""Decomposes a meandim into avg_pool and/or sum + mul (1/N).
83-
84-
::
73+
"""Decomposes a meandim into sum + mul (1/N).
8574
86-
h, w -> avg_pool
87-
n, c -> sum + mul(1/N)
75+
Each reduction dimension is handled via REDUCE_SUM followed by
76+
multiplication by 1/N, which works on any axis without layout
77+
constraints (unlike AVG_POOL2D which only pools over spatial H×W).
8878
8979
For rank < 4, the input is reshaped to 4D by padding with dim=1 from the
9080
left.
9181
9282
Example:
9383
x = mean_dim(x, (0,2), keepdim=False) # x = (c,h,w)
9484
Becomes:
95-
x = view_copy.default(x, new_shape=(1,c,h,w)) # Reshape to work with avg_pool
96-
x = avg_pool2d.default(x, kernel=(1,w), stride=(1,1)) # Reduce w with avg_pool
97-
x = sum.dim_IntList(x, dim=1, keepdims=True) # Reduce c with sum
98-
x = mul.Tensor(x, 1/c) # Divide by number of channels to get mean
85+
x = view_copy.default(x, new_shape=(1,c,h,w)) # Reshape to 4D
86+
x = sum.dim_IntList(x, dim=(1,3), keepdims=True) # Reduce c,w with sum
87+
x = mul.Tensor(x, 1/(c*w)) # Divide by number of elements to get mean
9988
x = view_copy.default(x, new_shape=(h)) # Squeeze dims since keepdims = False
10089
10190
"""
@@ -110,14 +99,6 @@ def __init__(self, graph_module, tosa_spec, *args, **kwargs):
11099
super().__init__(*args, **kwargs)
111100
self._graph_module = graph_module
112101
self._tosa_spec = tosa_spec
113-
# Lazy import to avoid circular dependency with operator_support
114-
from executorch.backends.arm.operator_support.pool_2d_support import (
115-
AvgPool2dSupported,
116-
)
117-
118-
self._avg_pool_checker = AvgPool2dSupported(
119-
self._tosa_spec, WhyNoPartitionReporter()
120-
)
121102

122103
def call_operator(self, op, args, kwargs, meta, updated=False):
123104
if op not in (
@@ -168,12 +149,6 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
168149
x = super().call_operator(view_op, (x, new_shape), {}, meta, True)
169150
x = self._maybe_insert_q_dq_after(x, meta)
170151

171-
# Reduce (h,w) dims by avg pool if possible
172-
if not has_symbolic_reduce_dim:
173-
x, dims_to_reduce = self._reduce_by_average_pool(
174-
op, x, dims_to_reduce, meta
175-
)
176-
177152
# Reshape back to 5D if necessary
178153
if len(input_shape) > 4:
179154
original_dims = input_shape[:-3]
@@ -259,44 +234,6 @@ def _reduce_by_sum(self, op, input_node, dims, meta):
259234

260235
return super().call_operator(mul_op, (sum, divisor), {}, meta, True)
261236

262-
def _reduce_by_average_pool(self, op, input_node, dims, meta):
263-
dims_to_reduce_by_avgpool = [dim for dim in dims if dim >= 2]
264-
if len(dims_to_reduce_by_avgpool) == 0:
265-
return input_node, dims
266-
267-
dims_to_reduce_by_sum = [dim for dim in dims if dim < 2]
268-
269-
avgpool_op = get_avgpool(op)
270-
input_shape = input_node.data.size()
271-
272-
stride = [1, 1]
273-
if dims_to_reduce_by_avgpool in ([2, 3], [3, 2]):
274-
kernel_size = [input_shape[2], input_shape[3]]
275-
elif dims_to_reduce_by_avgpool == [3]:
276-
kernel_size = [1, input_shape[3]]
277-
elif dims_to_reduce_by_avgpool == [2]:
278-
kernel_size = [input_shape[2], 1]
279-
else:
280-
raise RuntimeError(
281-
f"Bad dims {dims_to_reduce_by_avgpool} for {op} decomposition of mean_dim."
282-
)
283-
284-
args = (input_node, kernel_size, stride)
285-
286-
avg_pool_node = self._graph_module.graph.create_node(
287-
"call_function", avgpool_op, args
288-
)
289-
is_supported = self._avg_pool_checker.is_node_tosa_supported(
290-
avg_pool_node, self._tosa_spec
291-
)
292-
293-
if is_supported:
294-
out = super().call_operator(avgpool_op, args, {}, meta, True)
295-
out = self._maybe_insert_q_dq_after(out, meta)
296-
return out, dims_to_reduce_by_sum
297-
298-
return input_node, dims
299-
300237
def _maybe_insert_q_dq_after(self, op, meta):
301238
"""If the input node of op is a dequant node, insert a q-dq pair after
302239
op with identical quantization parameters.

backends/arm/test/misc/test_transpose_counts.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ def forward(self, x):
404404
"groupnorm": TransposeCountCase(
405405
GroupNormModule(),
406406
(torch.randn(1, 4, 4, 4),),
407-
1,
407+
0,
408408
),
409409
"multihead_attention_rank2": TransposeCountCase(
410410
MultiheadAttentionModule(),
@@ -430,16 +430,16 @@ def forward(self, x):
430430
Model1ConvMaxPoolResidualLinear(), (torch.randn(2, 8, 64),), 5
431431
),
432432
"model_2_conv_mha_linear_layernorm": TransposeCountCase(
433-
Model2ConvMhaLinearLayerNorm(), (torch.randn(2, 8, 32),), 11
433+
Model2ConvMhaLinearLayerNorm(), (torch.randn(2, 8, 32),), 9
434434
),
435435
"model_3_lstm_linear": TransposeCountCase(
436436
Model3LstmLinear(), (torch.randn(2, 16, 8),), 2
437437
),
438438
"model_4_conv_lstm_linear_layernorm": TransposeCountCase(
439-
Model4ConvLstmLinearLayerNorm(), (torch.randn(2, 8, 32),), 5
439+
Model4ConvLstmLinearLayerNorm(), (torch.randn(2, 8, 32),), 3
440440
),
441441
"model_5_dwconv_gelu_layernorm_avgpool": TransposeCountCase(
442-
Model5DwConvGeluLayerNormAvgPool(), (torch.randn(1, 8, 16, 16),), 6
442+
Model5DwConvGeluLayerNormAvgPool(), (torch.randn(1, 8, 16, 16),), 4
443443
),
444444
"model_6_gru_linear": TransposeCountCase(
445445
Model6GruLinear(), (torch.randn(2, 16, 8),), 2
@@ -521,7 +521,7 @@ def forward(self, x):
521521
"groupnorm_channels_last": TransposeCountCase(
522522
GroupNormModule(),
523523
(torch.randn(1, 4, 4, 4).to(memory_format=torch.channels_last),),
524-
3,
524+
2,
525525
),
526526
"cumsum_rank4_dim3_channels_last": TransposeCountCase(
527527
CumsumModule(),

backends/arm/test/ops/test_cond.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def true_branch(arg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
8282
return arg + torch.sin(arg), arg - torch.sin(arg)
8383

8484
def false_branch(arg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
85-
return arg - arg.mean(), arg + arg.mean()
85+
return arg - torch.cos(arg), arg + torch.cos(arg)
8686

8787
predicate = x.flatten().sum() > 0
8888
return torch.cond(predicate, true_branch, false_branch, [x])

backends/arm/test/ops/test_layer_norm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,6 @@ def test_native_layer_norm_16a8w_u55_INT(test_data):
204204

205205
u85_xfails_16a8w = {
206206
"randn_last_dim": "MLETORCH-1834 - 16A8W native_layer_norm output diff for certain configurations.",
207-
"randn_last_three_dims": "MLETORCH-1834 - 16A8W native_layer_norm output diff for certain configurations.",
208-
"randn_last_three_dims_no_bias": "MLETORCH-1834 - 16A8W native_layer_norm output diff for certain configurations.",
209207
}
210208

211209

backends/arm/test/passes/test_decompose_meandim_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,12 @@ class MeanDimTensor(torch.nn.Module):
5656
ops_after_pass = {
5757
"torch.ops.aten.sum.dim_IntList": 2,
5858
"torch.ops.aten.mul.Tensor": 1,
59-
"torch.ops.aten.avg_pool2d.default": 1,
6059
"torch.ops.aten.reshape.default": 1,
6160
}
6261

6362
ops_not_after_pass = [
6463
"torch.ops.aten.mean.dim",
64+
"torch.ops.aten.avg_pool2d.default",
6565
]
6666

6767
u55_ops_after_pass = {

0 commit comments

Comments (0)