Skip to content

Commit c28d8ad

Browse files
authored
Merge branch 'main' into python_314
2 parents 0dd706e + e3d5de2 commit c28d8ad

190 files changed

Lines changed: 10135 additions & 1021 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

.github/workflows/mlx.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,11 @@ jobs:
6666
echo "::endgroup::"
6767
6868
echo "::group::Build test runners"
69-
${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 ))
69+
${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner mlx_mutable_state_test -j$(( $(sysctl -n hw.ncpu) - 1 ))
70+
echo "::endgroup::"
71+
72+
echo "::group::Run mutable-state (multi-session) unit test"
73+
./cmake-out/backends/mlx/test/mlx_mutable_state_test
7074
echo "::endgroup::"
7175
7276
echo "::group::Run op unit tests"

.github/workflows/pull.yml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,33 @@ jobs:
816816
# Test test_arm_backend.sh with test
817817
backends/arm/test/test_arm_backend.sh "${ARM_TEST}"
818818
819+
test-arm-backend-public-api-backward-compatibility:
820+
name: test-arm-backend-public-api-backward-compatibility
821+
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
822+
permissions:
823+
id-token: write
824+
contents: read
825+
with:
826+
runner: linux.2xlarge.memory
827+
docker-image: ci-image:executorch-ubuntu-22.04-arm-sdk
828+
submodules: 'recursive'
829+
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
830+
timeout: 120
831+
script: |
832+
# The generic Linux job chooses to use base env, not the one setup by the image
833+
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
834+
conda activate "${CONDA_ENV}"
835+
836+
source .ci/scripts/utils.sh
837+
install_executorch "--use-pt-pinned-commit"
838+
839+
.ci/scripts/setup-arm-baremetal-tools.sh --enable-mlsdk-deps --install-mlsdk-deps-with-pip
840+
source examples/arm/arm-scratch/setup_path.sh
841+
842+
backends/arm/scripts/public_api_manifest/validate_all_public_api_manifests.sh
843+
844+
python backends/arm/test/public_api_bc/run_public_api_bc_scenarios.py
845+
819846
test-llama-runner-qnn-linux:
820847
name: test-llama-runner-qnn-linux
821848
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main

backends/arm/TARGETS

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,15 @@ runtime.python_library(
119119
"//executorch/exir:lib",
120120
],
121121
)
122+
runtime.python_library(
123+
name = "public_api",
124+
srcs = ["__init__.py"],
125+
deps = [
126+
":ethosu",
127+
":vgf",
128+
"//executorch/backends/arm/quantizer:lib",
129+
],
130+
)
122131

123132
runtime.python_library(
124133
name = "process_node",

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@
171171
from .rewrite_le_lt_to_ge_gt_pass import RewriteLeLtToGeGtPass # noqa
172172
from .rewrite_matmul import RewriteMatmulPass # noqa
173173
from .rewrite_max_pool2d_pass import RewriteMaxPool2dPass # noqa
174+
from .rewrite_mxfp_conv2d import RewriteMXFPConv2dPass # noqa
174175
from .rewrite_mxfp_linear import RewriteMXFPLinearPass # noqa
175176
from .rewrite_pad import RewritePadPass # noqa
176177
from .rewrite_slice import RewriteSlicePass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@
117117
InsertConstShapesPass,
118118
InsertControlFlowRescalesPass,
119119
InsertDataLayoutCastsPass,
120+
InsertDynamicPaddingPass,
120121
InsertInt32CastsAfterInt64PlaceholdersPass,
121122
InsertRescaleInt32Pass,
122123
InsertRescalePass,
@@ -146,6 +147,7 @@
146147
RewriteLeLtToGeGtPass,
147148
RewriteMatmulPass,
148149
RewriteMaxPool2dPass,
150+
RewriteMXFPConv2dPass,
149151
RewriteMXFPLinearPass,
150152
RewritePadPass,
151153
RewriteSlicePass,
@@ -611,6 +613,7 @@ def _tosa_pipeline(
611613
RewriteMaxPool2dPass(),
612614
DecomposeAdaptiveMaxPool2dPass(),
613615
RewriteConvPass(exported_program),
616+
RewriteMXFPConv2dPass(exported_program),
614617
RewriteMXFPLinearPass(exported_program),
615618
RewriteMatmulPass(),
616619
RewritePadPass(),
@@ -632,6 +635,7 @@ def _tosa_pipeline(
632635
CastInt64BuffersToInt32Pass(exported_program),
633636
FuseEqualPlaceholdersPass(exported_program),
634637
SymbolicToTosaShapesPass(),
638+
InsertDynamicPaddingPass(),
635639
FuseConsecutiveConcatShapesPass(),
636640
EnsureUniqueOutputNodesPass(),
637641
RemoveNoopPass(),

backends/arm/_passes/aten_to_tosa_activation_functions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,21 @@ def rewrite_clamp(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec | Non
128128
exir_ops.backend.tosa.CLAMP.default,
129129
(node.args[0], *min_max_args),
130130
)
131+
132+
133+
def get_activation_replacement(
134+
node: Node, pass_: AtenToDialectPass
135+
) -> DialectNodeSpec | None:
136+
# Dispatch activation rewrites from their ATen target to the matching TOSA
137+
# dialect node builder.
138+
match node.target:
139+
case exir_ops.edge.aten.clamp.default:
140+
return rewrite_clamp(node, pass_)
141+
case exir_ops.edge.aten.erf.default:
142+
return rewrite_erf(node, pass_)
143+
case exir_ops.edge.aten.sigmoid.default:
144+
return rewrite_sigmoid(node, pass_)
145+
case exir_ops.edge.aten.tanh.default:
146+
return rewrite_tanh(node, pass_)
147+
case _:
148+
return None

backends/arm/_passes/aten_to_tosa_tensor_operators.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import cast
7+
8+
from executorch.backends.transforms.aten_to_dialect_pass import (
9+
AtenToDialectPass,
10+
DialectNodeSpec,
11+
)
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from torch.fx import Node
14+
15+
16+
def rewrite_argmax(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec:
17+
input_node = cast(Node, node.args[0])
18+
dim = cast(int, node.kwargs["dim"] if "dim" in node.kwargs else node.args[1])
19+
if dim < 0:
20+
dim += len(input_node.meta["val"].shape)
21+
22+
return DialectNodeSpec(
23+
exir_ops.backend.tosa.ARGMAX.default,
24+
(input_node, dim),
25+
{},
26+
)

backends/arm/_passes/exir_to_tosa_pass.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,38 @@
55

66
import executorch.backends.arm.tosa.dialect # noqa: F401
77
from executorch.backends.arm._passes.aten_to_tosa_activation_functions import (
8-
rewrite_clamp,
9-
rewrite_erf,
10-
rewrite_sigmoid,
11-
rewrite_tanh,
8+
get_activation_replacement,
9+
)
10+
from executorch.backends.arm._passes.aten_to_tosa_tensor_operators import rewrite_argmax
11+
from executorch.backends.transforms.aten_to_dialect_pass import (
12+
AtenToDialectPass,
13+
DialectNodeSpec,
1214
)
13-
from executorch.backends.transforms.aten_to_dialect_pass import AtenToDialectPass
1415
from executorch.exir.dialects._ops import ops as exir_ops
16+
from torch.fx import Node
1517

1618

1719
class ExirToTosaPass(AtenToDialectPass):
1820
"""Rewrite simple EXIR ops to equivalent backend TOSA dialect ops.
1921
20-
Rewrite functions are grouped by op category and registered with the shared
21-
ATen-to-dialect pass infrastructure.
22+
Rewrite functions are registered with the shared ATen-to-dialect pass
23+
infrastructure.
2224
2325
"""
2426

2527

26-
_ACTIVATION_FUNCTION_REWRITES = {
27-
exir_ops.edge.aten.clamp.default: rewrite_clamp,
28-
exir_ops.edge.aten.erf.default: rewrite_erf,
29-
exir_ops.edge.aten.sigmoid.default: rewrite_sigmoid,
30-
exir_ops.edge.aten.tanh.default: rewrite_tanh,
31-
}
28+
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.argmax.default)
29+
def _get_tensor_operators_replacement(
30+
node: Node, pass_: AtenToDialectPass
31+
) -> DialectNodeSpec:
32+
return rewrite_argmax(node, pass_)
3233

33-
_DIRECT_REWRITE_CATEGORIES = {
34-
"activation_functions": _ACTIVATION_FUNCTION_REWRITES,
35-
}
3634

37-
# Register each category's ATen targets with the function that builds the
38-
# corresponding TOSA dialect node spec.
39-
for _rewrite_category in _DIRECT_REWRITE_CATEGORIES.values():
40-
for _edge_target, _rewrite_fn in _rewrite_category.items():
41-
ExirToTosaPass.register_dialect_substitution(_edge_target)(_rewrite_fn)
35+
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.clamp.default)
36+
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.erf.default)
37+
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.sigmoid.default)
38+
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.tanh.default)
39+
def _get_activation_replacement(
40+
node: Node, pass_: AtenToDialectPass
41+
) -> DialectNodeSpec | None:
42+
return get_activation_replacement(node, pass_)

backends/arm/_passes/insert_dynamic_padding.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class InsertDynamicPaddingPass(ArmOpTargetedPass):
2929
_passes_required_after: Set[Type[ExportPass]] = set()
3030
target_ops = (
3131
exir_ops.backend.tosa.CONV2D.default,
32+
exir_ops.backend.tosa.CONV3D.default,
3233
exir_ops.backend.tosa.DEPTHWISE_CONV2D.default,
3334
exir_ops.backend.tosa.MAX_POOL2D.default,
3435
exir_ops.backend.tosa.AVG_POOL2D.default,
@@ -57,11 +58,12 @@ def call_operator(self, op, args, kwargs, meta, updated=False) -> ProxyValue:
5758
if not self._is_dynamic_padding(padding):
5859
return super().call_operator(op, args, kwargs, meta, updated)
5960

60-
# Create a pad op before conv2d
61+
# Create a pad op before the convolution/pool op.
6162
input_tensor = args[0]
6263

6364
zero_padding_pair = [0, 0]
64-
zero_spatial_padding = [0, 0, 0, 0]
65+
spatial_rank = 3 if op == exir_ops.backend.tosa.CONV3D.default else 2
66+
zero_spatial_padding = [0] * (spatial_rank * 2)
6567
N_padding = super().call_shape_operator(
6668
exir_ops.backend.tosa.CONST_SHAPE.default,
6769
(zero_padding_pair,),
@@ -93,7 +95,7 @@ def call_operator(self, op, args, kwargs, meta, updated=False) -> ProxyValue:
9395
meta,
9496
True,
9597
)
96-
new_conv2d_args = list(args)
97-
new_conv2d_args[0] = pad_res
98-
new_conv2d_args[padding_index] = zero_spatial_padding
99-
return super().call_operator(op, tuple(new_conv2d_args), kwargs, meta, updated)
98+
new_args = list(args)
99+
new_args[0] = pad_res
100+
new_args[padding_index] = zero_spatial_padding
101+
return super().call_operator(op, tuple(new_args), kwargs, meta, updated)

backends/arm/_passes/rewrite_conv_pass.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,23 +97,25 @@ def _adjust_pad_if_needed(
9797

9898
if isinstance(mod_remainder, torch.SymInt):
9999
shape_env = get_context_shape_env()
100-
exact_values = evaluate_symbolic_expr_values(
101-
mod_remainder.node.expr, shape_env
102-
)
100+
exact_values = evaluate_symbolic_expr_values(mod_remainder, shape_env)
103101
if exact_values is not None:
104102
mod_remainder_upper = max(exact_values)
103+
if len(exact_values) == 1:
104+
mod_remainder = int(next(iter(exact_values)))
105+
elif mod_remainder_upper == 0:
106+
mod_remainder = 0
107+
else:
108+
return pad - mod_remainder
105109
else:
106-
value_ranges = shape_env.bound_sympy(mod_remainder.node.expr)
107-
mod_remainder_upper = int(value_ranges.upper)
108-
if mod_remainder_upper == 0:
109-
mod_remainder = 0
110-
else:
111-
mod_remainder_upper = mod_remainder
112-
113-
if mod_remainder_upper > pad:
110+
# SizeAdjustInputPass already trims symbolic remainder classes
111+
# that would force negative padding. Keep the symbolic
112+
# expression here instead of asking ShapeEnv to normalize it.
113+
return pad - mod_remainder
114+
if mod_remainder > pad:
114115
raise RuntimeError(
115-
"This case should be handled by the SizeAdjustInputPass, is it enabled?\n"
116+
"This case should be handled by SizeAdjustInputPass, is it enabled?\n"
116117
)
118+
117119
return pad - mod_remainder
118120

119121
def _is_depthwise_conv2d(self, node: torch.fx.Node) -> bool:

0 commit comments

Comments
 (0)