1616)
1717from executorch .backends .arm ._passes .size_adjust_input_pass import SizeAdjustInputPass
1818from executorch .backends .arm .constants import DQ_OPS , Q_OPS
19- from executorch .exir .backend .utils import WhyNoPartitionReporter
2019from executorch .exir .dialects ._ops import ops as exir_ops
2120from executorch .exir .pass_base import ExportPass
2221
@@ -51,14 +50,6 @@ def get_dynamic_meandim_decomposition(op) -> tuple:
5150 raise RuntimeError (f"Can't get meandim decomposition for op { op } " )
5251
5352
54- def get_avgpool (op ):
55- if op in (exir_ops .edge .aten .mean .dim , exir_ops .edge .aten .mean .default ):
56- return exir_ops .edge .aten .avg_pool2d .default
57- if op in (torch .ops .aten .mean .dim , torch .ops .aten .mean .default ):
58- return torch .ops .aten .avg_pool2d .default
59- raise RuntimeError (f"Can't get meandim decomposition for op { op } " )
60-
61-
6253def get_view (op ):
6354 if op in (exir_ops .edge .aten .mean .dim , exir_ops .edge .aten .mean .default ):
6455 return exir_ops .edge .aten .view_copy .default
@@ -79,23 +70,21 @@ def get_quantization(op):
7970
8071
8172class DecomposeMeanDimPass (ArmPass ):
82- """Decomposes a meandim into avg_pool and/or sum + mul (1/N).
83-
84- ::
73+ """Decomposes a meandim into sum + mul (1/N).
8574
86- h, w -> avg_pool
87- n, c -> sum + mul(1/N)
75+ Each reduction dimension is handled via REDUCE_SUM followed by
76+ multiplication by 1/N, which works on any axis without layout
77+ constraints (unlike AVG_POOL2D which only pools over spatial H×W).
8878
8979 For rank < 4, the input is reshaped to 4D by padding with dim=1 from the
9080 left.
9181
9282 Example:
9383 x = mean_dim(x, (0,2), keepdim=False) # x = (c,h,w)
9484 Becomes:
95- x = view_copy.default(x, new_shape=(1,c,h,w)) # Reshape to work with avg_pool
96- x = avg_pool2d.default(x, kernel=(1,w), stride=(1,1)) # Reduce w with avg_pool
97- x = sum.dim_IntList(x, dim=1, keepdims=True) # Reduce c with sum
98- x = mul.Tensor(x, 1/c) # Divide by number of channels to get mean
85+ x = view_copy.default(x, new_shape=(1,c,h,w)) # Reshape to 4D
86+ x = sum.dim_IntList(x, dim=(1,3), keepdims=True) # Reduce c,w with sum
87+ x = mul.Tensor(x, 1/(c*w)) # Divide by number of elements to get mean
9988 x = view_copy.default(x, new_shape=(h)) # Squeeze dims since keepdims = False
10089
10190 """
@@ -110,14 +99,6 @@ def __init__(self, graph_module, tosa_spec, *args, **kwargs):
11099 super ().__init__ (* args , ** kwargs )
111100 self ._graph_module = graph_module
112101 self ._tosa_spec = tosa_spec
113- # Lazy import to avoid circular dependency with operator_support
114- from executorch .backends .arm .operator_support .pool_2d_support import (
115- AvgPool2dSupported ,
116- )
117-
118- self ._avg_pool_checker = AvgPool2dSupported (
119- self ._tosa_spec , WhyNoPartitionReporter ()
120- )
121102
122103 def call_operator (self , op , args , kwargs , meta , updated = False ):
123104 if op not in (
@@ -168,12 +149,6 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
168149 x = super ().call_operator (view_op , (x , new_shape ), {}, meta , True )
169150 x = self ._maybe_insert_q_dq_after (x , meta )
170151
171- # Reduce (h,w) dims by avg pool if possible
172- if not has_symbolic_reduce_dim :
173- x , dims_to_reduce = self ._reduce_by_average_pool (
174- op , x , dims_to_reduce , meta
175- )
176-
177152 # Reshape back to 5D if necessary
178153 if len (input_shape ) > 4 :
179154 original_dims = input_shape [:- 3 ]
@@ -259,44 +234,6 @@ def _reduce_by_sum(self, op, input_node, dims, meta):
259234
260235 return super ().call_operator (mul_op , (sum , divisor ), {}, meta , True )
261236
262- def _reduce_by_average_pool (self , op , input_node , dims , meta ):
263- dims_to_reduce_by_avgpool = [dim for dim in dims if dim >= 2 ]
264- if len (dims_to_reduce_by_avgpool ) == 0 :
265- return input_node , dims
266-
267- dims_to_reduce_by_sum = [dim for dim in dims if dim < 2 ]
268-
269- avgpool_op = get_avgpool (op )
270- input_shape = input_node .data .size ()
271-
272- stride = [1 , 1 ]
273- if dims_to_reduce_by_avgpool in ([2 , 3 ], [3 , 2 ]):
274- kernel_size = [input_shape [2 ], input_shape [3 ]]
275- elif dims_to_reduce_by_avgpool == [3 ]:
276- kernel_size = [1 , input_shape [3 ]]
277- elif dims_to_reduce_by_avgpool == [2 ]:
278- kernel_size = [input_shape [2 ], 1 ]
279- else :
280- raise RuntimeError (
281- f"Bad dims { dims_to_reduce_by_avgpool } for { op } decomposition of mean_dim."
282- )
283-
284- args = (input_node , kernel_size , stride )
285-
286- avg_pool_node = self ._graph_module .graph .create_node (
287- "call_function" , avgpool_op , args
288- )
289- is_supported = self ._avg_pool_checker .is_node_tosa_supported (
290- avg_pool_node , self ._tosa_spec
291- )
292-
293- if is_supported :
294- out = super ().call_operator (avgpool_op , args , {}, meta , True )
295- out = self ._maybe_insert_q_dq_after (out , meta )
296- return out , dims_to_reduce_by_sum
297-
298- return input_node , dims
299-
300237 def _maybe_insert_q_dq_after (self , op , meta ):
301238 """If the input node of op is a dequant node, insert a q-dq pair after
302239 op with identical quantization parameters.
0 commit comments