77 changes: 7 additions & 70 deletions backends/arm/_passes/decompose_meandim_pass.py
@@ -16,7 +16,6 @@
 )
 from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass
 from executorch.backends.arm.constants import DQ_OPS, Q_OPS
-from executorch.exir.backend.utils import WhyNoPartitionReporter
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass

@@ -51,14 +50,6 @@ def get_dynamic_meandim_decomposition(op) -> tuple:
     raise RuntimeError(f"Can't get meandim decomposition for op {op}")


-def get_avgpool(op):
-    if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
-        return exir_ops.edge.aten.avg_pool2d.default
-    if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default):
-        return torch.ops.aten.avg_pool2d.default
-    raise RuntimeError(f"Can't get meandim decomposition for op {op}")
-
-
 def get_view(op):
     if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
         return exir_ops.edge.aten.view_copy.default
@@ -79,23 +70,21 @@ def get_quantization(op):


 class DecomposeMeanDimPass(ArmPass):
-    """Decomposes a meandim into avg_pool and/or sum + mul (1/N).
-
-    ::
-
-        h, w -> avg_pool
-        n, c -> sum + mul(1/N)
+    """Decomposes a meandim into sum + mul (1/N).
+
+    Each reduction dimension is handled via REDUCE_SUM followed by
+    multiplication by 1/N, which works on any axis without layout
+    constraints (unlike AVG_POOL2D, which only pools over spatial H×W).

     For rank < 4, the input is reshaped to 4D by padding with dim=1 from the
     left.

     Example:
         x = mean_dim(x, (0,2), keepdim=False) # x = (c,h,w)
     Becomes:
-        x = view_copy.default(x, new_shape=(1,c,h,w)) # Reshape to work with avg_pool
-        x = avg_pool2d.default(x, kernel=(1,w), stride=(1,1)) # Reduce w with avg_pool
-        x = sum.dim_IntList(x, dim=1, keepdims=True) # Reduce c with sum
-        x = mul.Tensor(x, 1/c) # Divide by number of channels to get mean
+        x = view_copy.default(x, new_shape=(1,c,h,w)) # Reshape to 4D
+        x = sum.dim_IntList(x, dim=(1,3), keepdims=True) # Reduce c,w with sum
+        x = mul.Tensor(x, 1/(c*w)) # Divide by number of elements to get mean
         x = view_copy.default(x, new_shape=(h)) # Squeeze dims since keepdims = False

     """
@@ -110,14 +99,6 @@ def __init__(self, graph_module, tosa_spec, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._graph_module = graph_module
         self._tosa_spec = tosa_spec
-        # Lazy import to avoid circular dependency with operator_support
-        from executorch.backends.arm.operator_support.pool_2d_support import (
-            AvgPool2dSupported,
-        )
-
-        self._avg_pool_checker = AvgPool2dSupported(
-            self._tosa_spec, WhyNoPartitionReporter()
-        )

     def call_operator(self, op, args, kwargs, meta, updated=False):
         if op not in (
@@ -168,12 +149,6 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
             x = super().call_operator(view_op, (x, new_shape), {}, meta, True)
             x = self._maybe_insert_q_dq_after(x, meta)

-        # Reduce (h,w) dims by avg pool if possible
-        if not has_symbolic_reduce_dim:
-            x, dims_to_reduce = self._reduce_by_average_pool(
-                op, x, dims_to_reduce, meta
-            )
-
         # Reshape back to 5D if necessary
         if len(input_shape) > 4:
             original_dims = input_shape[:-3]
@@ -259,44 +234,6 @@ def _reduce_by_sum(self, op, input_node, dims, meta):

         return super().call_operator(mul_op, (sum, divisor), {}, meta, True)

-    def _reduce_by_average_pool(self, op, input_node, dims, meta):
-        dims_to_reduce_by_avgpool = [dim for dim in dims if dim >= 2]
-        if len(dims_to_reduce_by_avgpool) == 0:
-            return input_node, dims
-
-        dims_to_reduce_by_sum = [dim for dim in dims if dim < 2]
-
-        avgpool_op = get_avgpool(op)
-        input_shape = input_node.data.size()
-
-        stride = [1, 1]
-        if dims_to_reduce_by_avgpool in ([2, 3], [3, 2]):
-            kernel_size = [input_shape[2], input_shape[3]]
-        elif dims_to_reduce_by_avgpool == [3]:
-            kernel_size = [1, input_shape[3]]
-        elif dims_to_reduce_by_avgpool == [2]:
-            kernel_size = [input_shape[2], 1]
-        else:
-            raise RuntimeError(
-                f"Bad dims {dims_to_reduce_by_avgpool} for {op} decomposition of mean_dim."
-            )
-
-        args = (input_node, kernel_size, stride)
-
-        avg_pool_node = self._graph_module.graph.create_node(
-            "call_function", avgpool_op, args
-        )
-        is_supported = self._avg_pool_checker.is_node_tosa_supported(
-            avg_pool_node, self._tosa_spec
-        )
-
-        if is_supported:
-            out = super().call_operator(avgpool_op, args, {}, meta, True)
-            out = self._maybe_insert_q_dq_after(out, meta)
-            return out, dims_to_reduce_by_sum
-
-        return input_node, dims
-
     def _maybe_insert_q_dq_after(self, op, meta):
         """If the input node of op is a dequant node, insert a q-dq pair after
         op with identical quantization parameters.
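
Reviewer note: the deleted `_reduce_by_average_pool` leaned on the identity that an `avg_pool2d` whose kernel spans a full spatial dim equals the mean over that dim, which only ever applies to dims 2 and 3; the sum-based path above has no such restriction. A minimal sketch of that identity for the removed `dims_to_reduce_by_avgpool == [3]` case (plain PyTorch with assumed shapes, not part of this diff):

```python
import torch
import torch.nn.functional as F

n, c, h, w = 1, 4, 5, 7
x = torch.randn(n, c, h, w)

# Kernel (1, w) with stride (1, 1) averages across the full width, so each
# output element is exactly the mean of one row: the mean over dim 3.
pooled = F.avg_pool2d(x, kernel_size=(1, w), stride=(1, 1))
torch.testing.assert_close(pooled, x.mean(dim=3, keepdim=True))
```
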
10 changes: 5 additions & 5 deletions backends/arm/test/misc/test_transpose_counts.py
@@ -404,7 +404,7 @@ def forward(self, x):
     "groupnorm": TransposeCountCase(
         GroupNormModule(),
         (torch.randn(1, 4, 4, 4),),
-        1,
+        0,
     ),
     "multihead_attention_rank2": TransposeCountCase(
         MultiheadAttentionModule(),
@@ -430,16 +430,16 @@ def forward(self, x):
         Model1ConvMaxPoolResidualLinear(), (torch.randn(2, 8, 64),), 5
     ),
     "model_2_conv_mha_linear_layernorm": TransposeCountCase(
-        Model2ConvMhaLinearLayerNorm(), (torch.randn(2, 8, 32),), 11
+        Model2ConvMhaLinearLayerNorm(), (torch.randn(2, 8, 32),), 9
     ),
     "model_3_lstm_linear": TransposeCountCase(
         Model3LstmLinear(), (torch.randn(2, 16, 8),), 2
     ),
     "model_4_conv_lstm_linear_layernorm": TransposeCountCase(
-        Model4ConvLstmLinearLayerNorm(), (torch.randn(2, 8, 32),), 5
+        Model4ConvLstmLinearLayerNorm(), (torch.randn(2, 8, 32),), 3
     ),
     "model_5_dwconv_gelu_layernorm_avgpool": TransposeCountCase(
-        Model5DwConvGeluLayerNormAvgPool(), (torch.randn(1, 8, 16, 16),), 6
+        Model5DwConvGeluLayerNormAvgPool(), (torch.randn(1, 8, 16, 16),), 4
     ),
     "model_6_gru_linear": TransposeCountCase(
         Model6GruLinear(), (torch.randn(2, 16, 8),), 2
@@ -521,7 +521,7 @@ def forward(self, x):
     "groupnorm_channels_last": TransposeCountCase(
         GroupNormModule(),
         (torch.randn(1, 4, 4, 4).to(memory_format=torch.channels_last),),
-        3,
+        2,
     ),
     "cumsum_rank4_dim3_channels_last": TransposeCountCase(
         CumsumModule(),
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_cond.py
@@ -82,7 +82,7 @@ def true_branch(arg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
             return arg + torch.sin(arg), arg - torch.sin(arg)

         def false_branch(arg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-            return arg - arg.mean(), arg + arg.mean()
+            return arg - torch.cos(arg), arg + torch.cos(arg)

         predicate = x.flatten().sum() > 0
         return torch.cond(predicate, true_branch, false_branch, [x])
2 changes: 0 additions & 2 deletions backends/arm/test/ops/test_layer_norm.py
@@ -204,8 +204,6 @@ def test_native_layer_norm_16a8w_u55_INT(test_data):

 u85_xfails_16a8w = {
     "randn_last_dim": "MLETORCH-1834 - 16A8W native_layer_norm output diff for certain configurations.",
-    "randn_last_three_dims": "MLETORCH-1834 - 16A8W native_layer_norm output diff for certain configurations.",
-    "randn_last_three_dims_no_bias": "MLETORCH-1834 - 16A8W native_layer_norm output diff for certain configurations.",
 }

2 changes: 1 addition & 1 deletion backends/arm/test/passes/test_decompose_meandim_pass.py
@@ -56,12 +56,12 @@ class MeanDimTensor(torch.nn.Module):
     ops_after_pass = {
         "torch.ops.aten.sum.dim_IntList": 2,
         "torch.ops.aten.mul.Tensor": 1,
-        "torch.ops.aten.avg_pool2d.default": 1,
         "torch.ops.aten.reshape.default": 1,
     }

     ops_not_after_pass = [
         "torch.ops.aten.mean.dim",
+        "torch.ops.aten.avg_pool2d.default",
     ]

     u55_ops_after_pass = {
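
Reviewer note: for anyone reproducing expected-op counts like the ones above outside the repo's test harness, a simplified sketch of the counting idea (hypothetical standalone module, not the project's `MeanDimTensor` or its pass pipeline):

```python
from collections import Counter

import torch

class MeanDim4D(torch.nn.Module):
    def forward(self, x):
        return x.mean(dim=(1, 2), keepdim=False)

# Export to an FX graph and tally call_function targets by name.
ep = torch.export.export(MeanDim4D(), (torch.randn(2, 4, 8, 8),))
counts = Counter(
    str(node.target) for node in ep.graph.nodes if node.op == "call_function"
)
# Before DecomposeMeanDimPass runs, the graph still contains aten.mean.dim;
# after the pass, the tests assert sum.dim_IntList / mul.Tensor counts instead.
print(counts)
```
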