Skip to content

Commit 34e160d

Browse files
authored
Merge branch 'main' into android-combined-v2
2 parents 04727f0 + 48a8d58 commit 34e160d

42 files changed

Lines changed: 1417 additions & 401 deletions

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 7 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
)
1717
from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass
1818
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
19-
from executorch.exir.backend.utils import WhyNoPartitionReporter
2019
from executorch.exir.dialects._ops import ops as exir_ops
2120
from executorch.exir.pass_base import ExportPass
2221

@@ -51,14 +50,6 @@ def get_dynamic_meandim_decomposition(op) -> tuple:
5150
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
5251

5352

54-
def get_avgpool(op):
55-
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
56-
return exir_ops.edge.aten.avg_pool2d.default
57-
if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default):
58-
return torch.ops.aten.avg_pool2d.default
59-
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
60-
61-
6253
def get_view(op):
6354
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
6455
return exir_ops.edge.aten.view_copy.default
@@ -79,23 +70,21 @@ def get_quantization(op):
7970

8071

8172
class DecomposeMeanDimPass(ArmPass):
82-
"""Decomposes a meandim into avg_pool and/or sum + mul (1/N).
83-
84-
::
73+
"""Decomposes a meandim into sum + mul (1/N).
8574
86-
h, w -> avg_pool
87-
n, c -> sum + mul(1/N)
75+
Each reduction dimension is handled via REDUCE_SUM followed by
76+
multiplication by 1/N, which works on any axis without layout
77+
constraints (unlike AVG_POOL2D which only pools over spatial H×W).
8878
8979
For rank < 4, the input is reshaped to 4D by padding with dim=1 from the
9080
left.
9181
9282
Example:
9383
x = mean_dim(x, (0,2), keepdim=False) # x = (c,h,w)
9484
Becomes:
95-
x = view_copy.default(x, new_shape=(1,c,h,w)) # Reshape to work with avg_pool
96-
x = avg_pool2d.default(x, kernel=(1,w), stride=(1,1)) # Reduce w with avg_pool
97-
x = sum.dim_IntList(x, dim=1, keepdims=True) # Reduce c with sum
98-
x = mul.Tensor(x, 1/c) # Divide by number of channels to get mean
85+
x = view_copy.default(x, new_shape=(1,c,h,w)) # Reshape to 4D
86+
x = sum.dim_IntList(x, dim=(1,3), keepdims=True) # Reduce c,w with sum
87+
x = mul.Tensor(x, 1/(c*w)) # Divide by number of elements to get mean
9988
x = view_copy.default(x, new_shape=(h)) # Squeeze dims since keepdims = False
10089
10190
"""
@@ -110,14 +99,6 @@ def __init__(self, graph_module, tosa_spec, *args, **kwargs):
11099
super().__init__(*args, **kwargs)
111100
self._graph_module = graph_module
112101
self._tosa_spec = tosa_spec
113-
# Lazy import to avoid circular dependency with operator_support
114-
from executorch.backends.arm.operator_support.pool_2d_support import (
115-
AvgPool2dSupported,
116-
)
117-
118-
self._avg_pool_checker = AvgPool2dSupported(
119-
self._tosa_spec, WhyNoPartitionReporter()
120-
)
121102

122103
def call_operator(self, op, args, kwargs, meta, updated=False):
123104
if op not in (
@@ -168,12 +149,6 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
168149
x = super().call_operator(view_op, (x, new_shape), {}, meta, True)
169150
x = self._maybe_insert_q_dq_after(x, meta)
170151

171-
# Reduce (h,w) dims by avg pool if possible
172-
if not has_symbolic_reduce_dim:
173-
x, dims_to_reduce = self._reduce_by_average_pool(
174-
op, x, dims_to_reduce, meta
175-
)
176-
177152
# Reshape back to 5D if necessary
178153
if len(input_shape) > 4:
179154
original_dims = input_shape[:-3]
@@ -259,44 +234,6 @@ def _reduce_by_sum(self, op, input_node, dims, meta):
259234

260235
return super().call_operator(mul_op, (sum, divisor), {}, meta, True)
261236

262-
def _reduce_by_average_pool(self, op, input_node, dims, meta):
263-
dims_to_reduce_by_avgpool = [dim for dim in dims if dim >= 2]
264-
if len(dims_to_reduce_by_avgpool) == 0:
265-
return input_node, dims
266-
267-
dims_to_reduce_by_sum = [dim for dim in dims if dim < 2]
268-
269-
avgpool_op = get_avgpool(op)
270-
input_shape = input_node.data.size()
271-
272-
stride = [1, 1]
273-
if dims_to_reduce_by_avgpool in ([2, 3], [3, 2]):
274-
kernel_size = [input_shape[2], input_shape[3]]
275-
elif dims_to_reduce_by_avgpool == [3]:
276-
kernel_size = [1, input_shape[3]]
277-
elif dims_to_reduce_by_avgpool == [2]:
278-
kernel_size = [input_shape[2], 1]
279-
else:
280-
raise RuntimeError(
281-
f"Bad dims {dims_to_reduce_by_avgpool} for {op} decomposition of mean_dim."
282-
)
283-
284-
args = (input_node, kernel_size, stride)
285-
286-
avg_pool_node = self._graph_module.graph.create_node(
287-
"call_function", avgpool_op, args
288-
)
289-
is_supported = self._avg_pool_checker.is_node_tosa_supported(
290-
avg_pool_node, self._tosa_spec
291-
)
292-
293-
if is_supported:
294-
out = super().call_operator(avgpool_op, args, {}, meta, True)
295-
out = self._maybe_insert_q_dq_after(out, meta)
296-
return out, dims_to_reduce_by_sum
297-
298-
return input_node, dims
299-
300237
def _maybe_insert_q_dq_after(self, op, meta):
301238
"""If the input node of op is a dequant node, insert a q-dq pair after
302239
op with identical quantization parameters.

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 63 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,19 @@ def _get_special_dtype(qspec: QuantArgs) -> TosaSpecialDtype | None:
4040
return None
4141

4242

43+
def _merge_qparams(qspec_1: QuantArgs, qspec_2: QuantArgs) -> QuantArgs:
44+
"""Merge two QuantArgs when inputs are quantized differently.
45+
46+
Requires same dtype; picks the first's parameters by default.
47+
48+
"""
49+
if qspec_1.dtype != qspec_2.dtype:
50+
raise RuntimeError(
51+
f"Cannot merge qparams of different dtypes: {qspec_1.dtype} vs {qspec_2.dtype}"
52+
)
53+
return qspec_1
54+
55+
4356
def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
4457
"""Get the input quantization parameters from a node, set by the
4558
'FoldAndAnnotateQParamsPass'.
@@ -121,57 +134,72 @@ def __init__(
121134
super().__init__(*args, **kwargs)
122135
self.exported_program = exported_program
123136

124-
def fold_and_annotate_arg(
125-
self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int
126-
) -> None:
127-
input_qparams = None
128-
nodes_to_remove = set()
137+
def _extract_input_params(
138+
self, arg_list: list[Node]
139+
) -> tuple[Optional[QuantArgs], set[Node]]:
140+
input_qparams: Optional[QuantArgs] = None
141+
nodes_to_remove: set[Node] = set()
129142
for arg in arg_list:
130143
if not isinstance(arg, Node):
131-
return
132-
133-
arg_quant_params = None
144+
return None, set()
145+
arg_quant: Optional[QuantArgs] = None
134146
if arg.target in DQ_OPS:
135147
args = arg.args
136148
scales = args[1]
137149
if (
138-
isinstance(args[1], Node)
150+
isinstance(scales, Node)
139151
and self.exported_program is not None
140-
and is_param_node(self.exported_program, args[1])
152+
and is_param_node(self.exported_program, scales)
141153
):
142-
scales = get_param_tensor(self.exported_program, args[1])
154+
scales = get_param_tensor(self.exported_program, scales)
143155
zps = args[2]
144156
if (
145-
isinstance(args[2], Node)
157+
isinstance(zps, Node)
146158
and self.exported_program is not None
147-
and is_param_node(self.exported_program, args[2])
159+
and is_param_node(self.exported_program, zps)
148160
):
149-
zps = get_param_tensor(self.exported_program, args[2])
150-
arg_quant_params = QuantArgs.from_operator(
161+
zps = get_param_tensor(self.exported_program, zps)
162+
arg_quant = QuantArgs.from_operator(
151163
arg.target, (args[0], scales, zps, *args[3:])
152164
)
153-
# add arg to nodes_to_remove to fold the dq-node
154165
nodes_to_remove.add(arg)
155-
if input_qparams is not None and input_qparams != arg_quant_params:
156-
# Two args are quantized differently
157-
raise RuntimeError("Input qparams do not match")
158-
input_qparams = arg_quant_params
159-
if input_qparams is not None:
160-
node.meta["input_qparams"][i] = input_qparams
161-
for n in nodes_to_remove:
162-
if n.target not in DQ_OPS:
163-
raise RuntimeError(
164-
f"Expected one of {DQ_OPS} dq_op, got {n.target}"
165-
)
166+
if arg_quant is not None:
167+
if input_qparams is None:
168+
input_qparams = arg_quant
169+
elif input_qparams != arg_quant:
170+
input_qparams = _merge_qparams(input_qparams, arg_quant)
171+
return input_qparams, nodes_to_remove
172+
173+
def _annotate_input_params(
174+
self,
175+
graph_module: GraphModule,
176+
node: Node,
177+
index: int,
178+
input_qparams: QuantArgs,
179+
nodes_to_remove: set[Node],
180+
) -> None:
181+
node.meta["input_qparams"][index] = input_qparams
182+
183+
for dq in nodes_to_remove:
184+
if dq.target not in DQ_OPS:
185+
raise RuntimeError(f"Expected one of {DQ_OPS} dq_op, got {dq.target}")
186+
node.replace_input_with(dq, cast(Node, dq.args[0]))
187+
if not dq.users:
188+
graph_module.graph.erase_node(dq)
189+
190+
special = _get_special_dtype(input_qparams)
191+
if special:
192+
node.all_input_nodes[index].meta[TosaSpecialDtype.meta_key()] = special
166193

167-
node.replace_input_with(n, cast(Node, n.args[0]))
168-
if len(n.users) == 0:
169-
graph_module.graph.erase_node(n)
170-
special_dtype = _get_special_dtype(input_qparams)
171-
if special_dtype:
172-
node.all_input_nodes[i].meta[
173-
TosaSpecialDtype.meta_key()
174-
] = special_dtype
194+
def fold_and_annotate_arg(
195+
self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int
196+
) -> None:
197+
input_qparams, nodes_to_remove = self._extract_input_params(arg_list)
198+
if input_qparams is None:
199+
return
200+
self._annotate_input_params(
201+
graph_module, node, i, input_qparams, nodes_to_remove
202+
)
175203

176204
def _handle_control_flow_node(self, node: Node, graph_module: GraphModule):
177205
"""Fold outmost quant nodes inside submodule.

backends/arm/_passes/normalize_while_initial_args_pass.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -82,6 +82,8 @@ def _normalize_node(self, graph_module: GraphModule, node: Node) -> bool:
8282
new_carried = tuple(carried_inputs + additional_inputs)
8383
node.update_arg(2, new_carried)
8484
node.update_arg(3, ())
85+
# annotate node so later keying of captured vs loop-carried args is possible
86+
node.meta["additional_inputs"] = additional_inputs
8587

8688
body_module_name = str(cast(Node, node.args[1]).target)
8789
body_module = cast(GraphModule, graph_module.get_submodule(body_module_name)) # type: ignore

backends/arm/quantizer/quantization_annotator.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -890,29 +890,33 @@ def any_or_hardtanh_min_zero(n: Node):
890890
submodule_args_pos = -1 if node.target == torch.ops.higher_order.cond else -2
891891
submodule_args = node.args[submodule_args_pos]
892892
output_qspec = output_act_qspec
893-
if len(submodule_args) > 0: # type: ignore[arg-type]
894-
# The way the TOSA backend handles quantized inputs, arrays of input tensors (such as the input to a
895-
# conditional graph) need shared quantization.
896-
shared_qspec = SharedQuantizationSpec(
897-
(cast(list[Node], submodule_args)[0], node)
898-
)
899-
quant_properties.quant_inputs = [
900-
_QuantProperty(
901-
submodule_args_pos,
902-
[
903-
input_act_qspec,
904-
*([shared_qspec] * (len(submodule_args) - 1)), # type: ignore[arg-type]
905-
],
893+
# Annotate each control-flow tensor independently using the default input qspec
894+
if submodule_args:
895+
if node.meta.get("additional_inputs", None):
896+
qspecs = [input_act_qspec] * len(cast(Sequence[Node], submodule_args)) # type: ignore[arg-type]
897+
quant_properties.quant_inputs = [
898+
_QuantProperty(submodule_args_pos, qspecs)
899+
]
900+
else:
901+
shared_qspec = SharedQuantizationSpec(
902+
(cast(list[Node], submodule_args)[0], node)
906903
)
907-
]
908-
if node.target == torch.ops.higher_order.while_loop:
909-
# The output of the while loop body can either re-enter the body, or exit the while loop.
910-
# Therefore, A and B in the diagram below need to share the same quantization parameters.
911-
# A -> while ( RESCALE -> ... RESCALE -> ) -> B
912-
output_qspec = shared_qspec
904+
quant_properties.quant_inputs = [
905+
_QuantProperty(
906+
submodule_args_pos,
907+
[
908+
input_act_qspec,
909+
*([shared_qspec] * (len(submodule_args) - 1)), # type: ignore[arg-type]
910+
],
911+
)
912+
]
913+
if node.target == torch.ops.higher_order.while_loop:
914+
# The output of the while loop body can either re-enter the body, or exit the while loop.
915+
# Therefore, A and B in the diagram below need to share the same quantization parameters.
916+
# A -> while ( RESCALE -> ... RESCALE -> ) -> B
917+
output_qspec = shared_qspec
913918

914919
quant_properties.quant_output = _QuantProperty(0, output_qspec)
915-
916920
else:
917921
return None
918922

backends/arm/scripts/aot_arm_compiler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -847,8 +847,8 @@ def _to_edge_TOSA_delegate(
847847
)
848848

849849
# Replace quantized_decomposed::{quantize,dequantize}_per_tensor nodes
850-
# with cortex_m:: equivalents for int8 QDQ ops remaining outside the
851-
# delegated subgraph.
850+
# with cortex_m:: equivalents for int8/int16 QDQ ops remaining outside
851+
# the delegated subgraph.
852852
edge = _apply_replace_quant_nodes(edge, target, direct_drive)
853853

854854
return model_quant, edge
@@ -955,8 +955,8 @@ def _to_edge_no_delegate(
955955
)
956956

957957
# Replace quantized_decomposed::{quantize,dequantize}_per_tensor nodes
958-
# with cortex_m:: equivalents for int8 QDQ ops remaining outside the
959-
# delegated subgraph.
958+
# with cortex_m:: equivalents for int8/int16 QDQ ops remaining outside
959+
# the delegated subgraph.
960960
edge = _apply_replace_quant_nodes(edge, args.target, args.direct_drive)
961961

962962
return model_quant, edge

backends/arm/test/misc/test_transpose_counts.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ def forward(self, x):
404404
"groupnorm": TransposeCountCase(
405405
GroupNormModule(),
406406
(torch.randn(1, 4, 4, 4),),
407-
1,
407+
0,
408408
),
409409
"multihead_attention_rank2": TransposeCountCase(
410410
MultiheadAttentionModule(),
@@ -430,16 +430,16 @@ def forward(self, x):
430430
Model1ConvMaxPoolResidualLinear(), (torch.randn(2, 8, 64),), 5
431431
),
432432
"model_2_conv_mha_linear_layernorm": TransposeCountCase(
433-
Model2ConvMhaLinearLayerNorm(), (torch.randn(2, 8, 32),), 11
433+
Model2ConvMhaLinearLayerNorm(), (torch.randn(2, 8, 32),), 9
434434
),
435435
"model_3_lstm_linear": TransposeCountCase(
436436
Model3LstmLinear(), (torch.randn(2, 16, 8),), 2
437437
),
438438
"model_4_conv_lstm_linear_layernorm": TransposeCountCase(
439-
Model4ConvLstmLinearLayerNorm(), (torch.randn(2, 8, 32),), 5
439+
Model4ConvLstmLinearLayerNorm(), (torch.randn(2, 8, 32),), 3
440440
),
441441
"model_5_dwconv_gelu_layernorm_avgpool": TransposeCountCase(
442-
Model5DwConvGeluLayerNormAvgPool(), (torch.randn(1, 8, 16, 16),), 6
442+
Model5DwConvGeluLayerNormAvgPool(), (torch.randn(1, 8, 16, 16),), 4
443443
),
444444
"model_6_gru_linear": TransposeCountCase(
445445
Model6GruLinear(), (torch.randn(2, 16, 8),), 2
@@ -521,7 +521,7 @@ def forward(self, x):
521521
"groupnorm_channels_last": TransposeCountCase(
522522
GroupNormModule(),
523523
(torch.randn(1, 4, 4, 4).to(memory_format=torch.channels_last),),
524-
3,
524+
2,
525525
),
526526
"cumsum_rank4_dim3_channels_last": TransposeCountCase(
527527
CumsumModule(),

backends/arm/test/ops/test_cond.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def true_branch(arg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
8282
return arg + torch.sin(arg), arg - torch.sin(arg)
8383

8484
def false_branch(arg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
85-
return arg - arg.mean(), arg + arg.mean()
85+
return arg - torch.cos(arg), arg + torch.cos(arg)
8686

8787
predicate = x.flatten().sum() > 0
8888
return torch.cond(predicate, true_branch, false_branch, [x])

0 commit comments

Comments
 (0)