Skip to content

Commit 5b9bf5e

Browse files
committed
Update on "Use caching allocator for runner (#15730)"
Summary: We observed that on iOS it improves perf by 6% because SDPA op does temp allocations. No significant difference on android though. ghstack-source-id: 328001114 exported-using-ghexport Reviewed By: navsud, derekdixu Differential Revision: D86120038 [ghstack-poisoned]
2 parents 7359cf2 + d8b32c6 commit 5b9bf5e

39 files changed

Lines changed: 1366 additions & 640 deletions

.ci/scripts/test_cortex_m_e2e.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/usr/bin/env bash
22
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
# All rights reserved.
45
#
56
# This source code is licensed under the BSD-style license found in the
@@ -18,7 +19,7 @@ mkdir -p "./cortex_m_e2e/${MODEL}"
1819
WORK_DIR=$(realpath "./cortex_m_e2e/${MODEL}")
1920

2021
echo "=== Exporting ${MODEL} with cortex-m55+int8 ==="
21-
python -m examples.arm.aot_arm_compiler \
22+
python -m backends.arm.scripts.aot_arm_compiler \
2223
-m "${MODEL}" \
2324
--target=cortex-m55+int8 \
2425
--quantize \

backends/arm/README.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,13 @@ Setup:
106106
./examples/arm/setup.sh --disable-ethos-u-deps --enable-mlsdk-deps
107107
```
108108

109+
This is the default setup path and installs the MLSDK components from pip.
110+
Developers who need local source builds can use:
111+
112+
```
113+
./backends/arm/scripts/setup-mlsdk-from-source.sh
114+
```
115+
109116
The current flow lowers to TOSA and converts to VGF for use in external projects,
110117
so the `executor_runner` is not typically used here.
111118

@@ -155,7 +162,7 @@ scp -P 2222 arm_test/cmake-out/executor_runner root@127.0.0.1:/tmp/
155162
Create a PTE file:
156163

157164
```
158-
python3 -m examples.arm.aot_arm_compiler \
165+
python3 -m backends.arm.scripts.aot_arm_compiler \
159166
--model_name examples/arm/example_modules/add.py \
160167
--delegate \
161168
--quantize \

backends/arm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@
143143
from .replace_scalar_with_tensor_pass import ( # noqa
144144
ReplaceScalarWithTensorByProfilePass,
145145
)
146+
from .rewrite_avg_pool2d_pass import RewriteAvgPool2dPass # noqa
146147
from .rewrite_bool_bitwise_to_logical_pass import ( # noqa
147148
RewriteBoolBitwiseToLogicalPass,
148149
)
@@ -157,6 +158,7 @@
157158
from .rewrite_inplace_arithmetic_pass import RewriteInplaceArithmeticPass # noqa
158159
from .rewrite_le_lt_to_ge_gt_pass import RewriteLeLtToGeGtPass # noqa
159160
from .rewrite_matmul import RewriteMatmulPass # noqa
161+
from .rewrite_max_pool2d_pass import RewriteMaxPool2dPass # noqa
160162
from .rewrite_pad import RewritePadPass # noqa
161163
from .rewrite_slice import RewriteSlicePass # noqa
162164
from .rewrite_upsample import RewriteUpsamplePass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from collections.abc import Sequence
1111
from dataclasses import dataclass, field
1212

13-
import executorch.backends.arm.tosa.dialect # noqa: unused
1413
from executorch.backends.arm._passes import (
1514
AccumulateIndexPutPass,
1615
AnnotateOutputDimOrderPass,
@@ -126,6 +125,7 @@
126125
RemoveNoopPass,
127126
ReplaceInfAndLimitValuesPass,
128127
ReplaceScalarWithTensorByProfilePass,
128+
RewriteAvgPool2dPass,
129129
RewriteBoolBitwiseToLogicalPass,
130130
RewriteBoolToFp32CastViaInt8Pass,
131131
RewriteConvPass,
@@ -134,6 +134,7 @@
134134
RewriteInplaceArithmeticPass,
135135
RewriteLeLtToGeGtPass,
136136
RewriteMatmulPass,
137+
RewriteMaxPool2dPass,
137138
RewritePadPass,
138139
RewriteSlicePass,
139140
RewriteUpsamplePass,
@@ -143,7 +144,6 @@
143144
UnsqueezeBeforeRepeatPass,
144145
UnsqueezeScalarPlaceholdersPass,
145146
)
146-
147147
from executorch.backends.arm._passes.arm_pass import ArmPass
148148
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
149149
from executorch.backends.arm.common.pipeline_config import (
@@ -175,6 +175,32 @@ class PassInsertions:
175175
after_passes: list = field(default_factory=list)
176176

177177

178+
_registered_pass_insertions: dict[type, PassInsertions] = {}
179+
180+
181+
def register_pass_insertions_before(
182+
target_pass_type: type, passes: list[ExportPass]
183+
) -> None:
184+
"""Register passes to be inserted before a target pass for all pipelines."""
185+
if target_pass_type not in _registered_pass_insertions:
186+
_registered_pass_insertions[target_pass_type] = PassInsertions()
187+
_registered_pass_insertions[target_pass_type].before_passes.extend(passes)
188+
189+
190+
def register_pass_insertions_after(
191+
target_pass_type: type, passes: list[ExportPass]
192+
) -> None:
193+
"""Register passes to be inserted after a target pass for all pipelines."""
194+
if target_pass_type not in _registered_pass_insertions:
195+
_registered_pass_insertions[target_pass_type] = PassInsertions()
196+
_registered_pass_insertions[target_pass_type].after_passes.extend(passes)
197+
198+
199+
def clear_registered_pass_insertions() -> None:
200+
"""Clear all globally registered pass insertions."""
201+
_registered_pass_insertions.clear()
202+
203+
178204
class ArmPassManager(PassManager):
179205
def __init__(self, compile_spec: ArmCompileSpec) -> None:
180206
self.compile_spec = compile_spec
@@ -323,13 +349,17 @@ def _configure_pass_insertions(self, exported_program: ExportedProgram) -> None:
323349
"""Hook for subclasses to configure pass insertions. Called at the START
324350
of pipeline construction, before any passes are added.
325351
326-
Subclasses should override this to call insert_passes_before/after.
352+
Subclasses can override this to call insert_passes_before/after.
327353
328354
Args:
329355
exported_program: The exported program being transformed
330356
331357
"""
332-
pass
358+
for pass_type, insertions in _registered_pass_insertions.items():
359+
if insertions.before_passes:
360+
self.insert_passes_before(pass_type, list(insertions.before_passes))
361+
if insertions.after_passes:
362+
self.insert_passes_after(pass_type, list(insertions.after_passes))
333363

334364
def add_passes(self, passes: Sequence[ExportPass | None]):
335365
for p in passes:
@@ -432,6 +462,8 @@ def _tosa_pipeline(
432462
DecomposeSliceScatterPass(),
433463
AccumulateIndexPutPass(),
434464
DecomposeIndexTensorToGatherPass(),
465+
DecomposeAdaptiveAvgPool2dPass(),
466+
DecomposeAvgPool2dPass(),
435467
Conv1dUnsqueezePass(),
436468
]
437469
)
@@ -468,17 +500,16 @@ def _tosa_pipeline(
468500
DecomposeSoftmaxPass(),
469501
ConvertMinMaxPass(),
470502
DecomposeAnyPass(),
471-
DecomposeAdaptiveAvgPool2dPass(),
472-
DecomposeAvgPool2dPass(),
473503
DecorateFp32toInt32CastingPass(),
474-
ComputeConstantOpsAOTPass(exported_program),
475-
FuseConstantArgsPass(exported_program),
476504
ConvertExpandCopyToRepeatPass(),
477505
UnsqueezeBeforeRepeatPass(),
478506
DecomposeCumsumPass(exported_program),
479507
DecomposeAsStridedCopyPass(),
480508
DecomposeMaxPool2dPass(),
481509
SizeAdjustInputPass(),
510+
RewriteAvgPool2dPass(),
511+
ComputeConstantOpsAOTPass(exported_program),
512+
FuseConstantArgsPass(exported_program),
482513
DecomposeSelectPass(),
483514
ConvertSqueezesToViewPass(),
484515
CastToInt32Pass(),
@@ -496,6 +527,7 @@ def _tosa_pipeline(
496527
self.add_passes(
497528
[
498529
RewriteUpsamplePass(),
530+
RewriteMaxPool2dPass(),
499531
RewriteConvPass(exported_program),
500532
RewriteMatmulPass(),
501533
RewritePadPass(),
@@ -573,6 +605,8 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
573605
DecomposeLayerNormPass(tfa_pass=True),
574606
DecomposeVarPass(tfa_pass=True),
575607
DecomposeMeanDimPass(graph_module, self.tosa_spec, tfa_pass=True),
608+
DecomposeAdaptiveAvgPool2dPass(tfa_pass=True),
609+
DecomposeAvgPool2dPass(tfa_pass=True),
576610
]
577611
)
578612

@@ -598,8 +632,6 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
598632
DecomposeDivPass(tfa_pass=True),
599633
DecomposeLinalgVectorNormPass(tfa_pass=True),
600634
DecomposeSqrtPass(tfa_pass=True),
601-
DecomposeAdaptiveAvgPool2dPass(tfa_pass=True),
602-
DecomposeAvgPool2dPass(tfa_pass=True),
603635
DecomposeSoftmaxPass(
604636
tfa_pass=True,
605637
),

0 commit comments

Comments
 (0)