Commit e140a9b

Update base for Update on "[ET Device Support] DeviceAllocator interface and DeviceAllocatorRegistry"
This diff introduces the `DeviceAllocator` abstract interface and `DeviceAllocatorRegistry` for device-specific memory allocation. This is a foundational abstraction that enables the runtime to dispatch memory operations to the appropriate device backend other than CPU (CUDA, etc.).

**DeviceAllocator interface provides:**
- `init_buffer()` - Initialize memory buffer pools for memory-planned tensors
- `get_offset_address()` - Get pointer to offset within pre-allocated buffer
- `allocate()` / `deallocate()` - Dynamic device memory allocation
- `copy_host_to_device()` / `copy_device_to_host()` - Data transfer between host and device
- `device_type()` - Returns the device type this allocator handles

**DeviceAllocatorRegistry provides:**
- Singleton registry mapping DeviceType → DeviceAllocator
- `register_allocator()` / `get_allocator()` methods
- Fixed-size array indexed by device type (no dynamic allocation, embedded-friendly)

**Design notes:**
- Registry stores raw pointers (non-owning); allocators are expected to be singletons with static lifetime
- Follows ExecuTorch's embedded-first philosophy (no std::unique_ptr, no heap allocation in registry)
- Convenience free functions `register_device_allocator()` and `get_device_allocator()` for ease of use

Differential Revision: [D93635656](https://our.internmc.facebook.com/intern/diff/D93635656/)

[ghstack-poisoned]
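The registry design described above (a fixed-size table indexed by device type that holds non-owning references) can be modeled in a few lines. The real interface is C++ and is not shown in this commit view, so the following is only a Python sketch of the described semantics. The `DeviceType` values, the single-argument `register_allocator()` signature, and the error on duplicate registration are assumptions made for illustration.

```python
from enum import IntEnum
from typing import List, Optional


class DeviceType(IntEnum):
    # Hypothetical enumeration; the real device list is not shown in the commit.
    CPU = 0
    CUDA = 1


NUM_DEVICE_TYPES = len(DeviceType)


class DeviceAllocator:
    """Stand-in allocator that only reports which device it serves."""

    def __init__(self, device_type: DeviceType) -> None:
        self._device_type = device_type

    def device_type(self) -> DeviceType:
        return self._device_type


class DeviceAllocatorRegistry:
    """Fixed-size table indexed by device type; stores references it does not own."""

    def __init__(self) -> None:
        # The table is sized once and never grows.
        self._slots: List[Optional[DeviceAllocator]] = [None] * NUM_DEVICE_TYPES

    def register_allocator(self, allocator: DeviceAllocator) -> None:
        index = int(allocator.device_type())
        if self._slots[index] is not None:
            # Assumption: duplicate registration is treated as an error.
            raise ValueError(f"allocator already registered for slot {index}")
        self._slots[index] = allocator

    def get_allocator(self, device_type: DeviceType) -> Optional[DeviceAllocator]:
        return self._slots[int(device_type)]


# Module-level instance plays the role of the singleton registry.
_REGISTRY = DeviceAllocatorRegistry()


def register_device_allocator(allocator: DeviceAllocator) -> None:
    _REGISTRY.register_allocator(allocator)


def get_device_allocator(device_type: DeviceType) -> Optional[DeviceAllocator]:
    return _REGISTRY.get_allocator(device_type)
```

Because the table only stores references, the caller keeps each allocator alive, which matches the note that allocators are expected to have static lifetime.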
2 parents 5ef857f + 38b40bc commit e140a9b

82 files changed

Lines changed: 3241 additions & 941 deletions

.ci/scripts/wheel/envvar_macos.sh

Lines changed: 5 additions & 0 deletions
@@ -9,3 +9,8 @@
 # any variables so that subprocesses will see them.
 
 source "${GITHUB_WORKSPACE}/${REPOSITORY}/.ci/scripts/wheel/envvar_base.sh"
+
+# Force Apple Clang to avoid Homebrew LLVM, which doesn't properly handle
+# Apple SDK Objective-C framework headers (e.g. NSIntegerMax in NSObjCRuntime.h).
+export CC=/usr/bin/clang
+export CXX=/usr/bin/clang++

.github/workflows/cuda.yml

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ jobs:
         # Build executor_runner (needed by CUDA backend e2e tests)
         cmake --build cmake-out --target executor_runner
 
-        # Run all CUDA backend Python tests (including chunk_gated_delta e2e)
+        # Run CUDA backend Python tests
         python -m pytest backends/cuda/tests backends/cuda/passes/tests -v -o "addopts="
 
   export-model-cuda-artifact:

.github/workflows/pull.yml

Lines changed: 30 additions & 0 deletions
@@ -607,6 +607,36 @@ jobs:
           exit 1
         fi
 
+  test-mcu-cortex-m-backend:
+    name: test-mcu-cortex-m-backend
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
+    with:
+      runner: linux.2xlarge.memory
+      docker-image: ci-image:executorch-ubuntu-22.04-arm-sdk
+      submodules: 'recursive'
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      timeout: 120
+      script: |
+        # The generic Linux job chooses to use base env, not the one setup by the image
+        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
+        conda activate "${CONDA_ENV}"
+
+        source .ci/scripts/utils.sh
+        install_executorch "--use-pt-pinned-commit"
+
+        # Install arm dependencies
+        .ci/scripts/setup-arm-baremetal-tools.sh
+        source examples/arm/arm-scratch/setup_path.sh
+
+        # To build cortex-m test runner
+        backends/cortex_m/test/build_test_runner.sh
+
+        # To run cortex_m tests
+        pytest --config-file=backends/arm/test/pytest.ini backends/cortex_m/test
+
   android:
     uses: ./.github/workflows/_android.yml
     permissions:

.github/workflows/trunk.yml

Lines changed: 0 additions & 30 deletions
@@ -1054,33 +1054,3 @@ jobs:
 
           .ci/scripts/test_model.ps1 -modelName ${{ matrix.model }} -backend ${{ matrix.backend }}
         }"
-
-  test-mcu-cortex-m-backend:
-    name: test-mcu-cortex-m-backend
-    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
-    permissions:
-      id-token: write
-      contents: read
-    with:
-      runner: linux.2xlarge.memory
-      docker-image: ci-image:executorch-ubuntu-22.04-arm-sdk
-      submodules: 'recursive'
-      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
-      timeout: 120
-      script: |
-        # The generic Linux job chooses to use base env, not the one setup by the image
-        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
-        conda activate "${CONDA_ENV}"
-
-        source .ci/scripts/utils.sh
-        install_executorch "--use-pt-pinned-commit"
-
-        # Install arm dependencies
-        .ci/scripts/setup-arm-baremetal-tools.sh
-        source examples/arm/arm-scratch/setup_path.sh
-
-        # To build cortex-m test runner
-        backends/cortex_m/test/build_test_runner.sh
-
-        # To run cortex_m tests
-        pytest --config-file=backends/arm/test/pytest.ini backends/cortex_m/test

backends/arm/MODELS.md

Lines changed: 6 additions & 0 deletions
@@ -1,12 +1,17 @@
 <!-- Copyright 2025-2026 Arm Limited and/or its affiliates. -->
 # The following file contains all models that have been confirmed to be functional and tested for the Arm backend:
+# Note: Deep AutoEncoder requires manual Linear+BatchNorm1d fusion as the quantizer does not yet support this pattern.
+# Note: DS CNN requires AvgPool2d workaround for Ethos-U55 due to stride > 3 limitation.
 - Conformer
+- Deep AutoEncoder
 - Deit Tiny
 - DeepLab v3 (DL3)
+- DS CNN
 - Inception v3 (IC3)
 - Llama
 - Gemma3n
 - Long Short-Term Memory (LSTM)
+- MobileNet V1 0.25
 - MobileNet v2 (MV2)
 - MobileNet v3 (MV3)
 - Some popular torch.nn.functional models (NN functional)
@@ -16,6 +21,7 @@
 - Neural Super Sampler (NSS)
 - Phi-3
 - ResNet 18
+- ResNet-8
 - Wav2Letter (W2L)
 - Stable Diffusion:
   * CLIP Text Encoder (CLIP Text with Projection)

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -102,6 +102,7 @@
     QuantizeClampArgumentsPass,
 )
 from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass  # noqa
+from .fuse_consecutive_rescales_pass import FuseConsecutiveRescalesPass  # noqa
 from .fuse_constant_ops_pass import (  # noqa
     ComputeConstantOpsAOTPass,
     FuseConstantArgsPass,

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
@@ -98,6 +98,7 @@
     DecorateFp32toInt32CastingPass,
     FoldAndAnnotateQParamsPass,
     FuseBatchNorm2dPass,
+    FuseConsecutiveRescalesPass,
     FuseConstantArgsPass,
     FuseDuplicateUsersPass,
     FuseEqualPlaceholdersPass,
@@ -380,6 +381,7 @@ def _tosa_pipeline(
                 # Ticket: MLETORCH-1539
                 DecomposeLinearPass(),
                 InsertRescaleInt32Pass(),
+                FuseConsecutiveRescalesPass(),
                 InsertControlFlowRescalesPass(),
                 DecomposeQuantNodesPass(),
             ]

backends/arm/_passes/fuse_consecutive_rescales_pass.py

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
+# Copyright 2026 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import cast, Set, Type
+
+import torch
+from executorch.backends.arm._passes.arm_pass import ArmPass
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx import GraphModule, Node
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+# TOSA RESCALE argument positions:
+# args[0] = input tensor (Node)
+# args[1] = output dtype (e.g., torch.int8, torch.int32)
+# args[2] = scale list (List[float]; per-tensor when len == 1)
+# args[3] = input zero point (int)
+# args[4] = output zero point (int)
+_ARG_INPUT = 0
+_ARG_OUTPUT_DTYPE = 1
+_ARG_SCALE = 2
+_ARG_INPUT_ZP = 3
+_ARG_OUTPUT_ZP = 4
+
+
+class FuseConsecutiveRescalesPass(ArmPass):
+    """Fuse consecutive RESCALE(INT32->INT8/INT16) -> RESCALE(INT8/INT16->INT32)
+    pairs.
+
+    InsertRescaleInt32Pass wraps each quantized arithmetic and comparison
+    operator (add, sub, mul, abs, eq, ge, gt, le, lt, max, min, sum) with
+    input rescales (INT8/INT16->INT32) and an output rescale
+    (INT32->INT8/INT16). When two such ops are chained (e.g., add1 -> add2),
+    the output rescale of add1 feeds directly into an input rescale of add2,
+    creating a redundant INT32->INT8/INT16->INT32 round-trip that loses
+    precision.
+
+    This pass detects such pairs and handles two cases:
+
+    - **Identity** (composed scale ~1.0, matching zero points): Removes both
+      RESCALEs and directly wires R1's input to R2's users. This eliminates
+      the entire round-trip. Bypassing the intermediate INT8/INT16 clamp can
+      in theory cause up to ~120 INT8 steps of output difference when all
+      inputs are near the clamp boundary; in practice, observed differences
+      are 0-1 steps for typical distributions. Tests use qtol=1.
+
+    - **Non-identity**: Leaves the pair unchanged. The Vela NPU compiler
+      cannot correctly process INT32->INT32 RESCALE (produces all-zero NPU
+      outputs), so non-identity pairs retain their INT8/INT16 intermediate.
+
+    Handles multi-user R1 nodes: when R1 feeds both RESCALE and
+    non-RESCALE users, each R1->R2 RESCALE pair is fused individually
+    while preserving R1 for its non-RESCALE users.
+
+    """
+
+    _passes_required_after: Set[Type[ExportPass]] = set()
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        graph = graph_module.graph
+        modified = False
+        rescale_before = sum(1 for n in graph.nodes if _is_rescale(n))
+        identity_pairs_fused = 0
+
+        for node in list(graph.nodes):
+            node = cast(Node, node)
+            if not _is_fuseable_r1(node):
+                continue
+
+            r1_input = node.args[_ARG_INPUT]
+            r1_input_zp = node.args[_ARG_INPUT_ZP]
+            r1_scale = float(node.args[_ARG_SCALE][0])  # type: ignore[arg-type]
+
+            node_fused = False
+            for user in list(node.users):
+                if _try_fuse_identity_pair(node, user, r1_input, r1_input_zp, r1_scale):
+                    node_fused = True
+                    identity_pairs_fused += 1
+
+            if node_fused:
+                modified = True
+
+        if modified:
+            graph.eliminate_dead_code()
+            rescale_after = sum(1 for n in graph.nodes if _is_rescale(n))
+            removed = rescale_before - rescale_after
+            logger.info(
+                "FuseConsecutiveRescalesPass: removed %d identity pairs "
+                "(%d RESCALEs: %d -> %d)",
+                identity_pairs_fused,
+                removed,
+                rescale_before,
+                rescale_after,
+            )
+            graph_module.recompile()
+            graph.lint()
+            # Note: we deliberately skip super().call(); retracing is
+            # unnecessary since this pass only rewires edges and removes
+            # nodes without introducing new operations.
+
+        return PassResult(graph_module, modified)
+
+
+def _is_rescale(node: Node) -> bool:
+    return (
+        node.op == "call_function"
+        and node.target == exir_ops.backend.tosa.RESCALE.default
+    )
+
+
+def _is_fuseable_r1(node: Node) -> bool:
+    """Check if node is an R1 candidate.
+
+    R1 is RESCALE(INT32 -> INT8/INT16) with per-tensor scale.
+
+    """
+    if not _is_rescale(node):
+        return False
+    if node.args[_ARG_OUTPUT_DTYPE] not in (torch.int8, torch.int16):
+        return False
+    if len(node.args[_ARG_SCALE]) != 1:  # type: ignore[arg-type]
+        return False
+    r1_input = node.args[_ARG_INPUT]
+    if not isinstance(r1_input, Node) or "val" not in r1_input.meta:
+        return False
+    if r1_input.meta["val"].dtype != torch.int32:
+        return False
+    return True
+
+
+def _try_fuse_identity_pair(
+    r1: Node,
+    r2: Node,
+    r1_input: Node,
+    r1_input_zp: int,
+    r1_scale: float,
+) -> bool:
+    """Try to fuse an R1->R2 identity pair.
+
+    Returns True if fused.
+
+    """
+    if not _is_rescale(r2):
+        return False
+    if r2.args[_ARG_OUTPUT_DTYPE] != torch.int32:
+        return False
+    if r1.args[_ARG_OUTPUT_ZP] != r2.args[_ARG_INPUT_ZP]:
+        return False
+    if len(r2.args[_ARG_SCALE]) != 1:  # type: ignore[arg-type]
+        return False
+
+    r2_scale = float(r2.args[_ARG_SCALE][0])  # type: ignore[arg-type, index]
+    composed_scale = r1_scale * r2_scale
+    r2_output_zp = r2.args[_ARG_OUTPUT_ZP]
+
+    if abs(composed_scale - 1.0) < 1e-6 and r1_input_zp == r2_output_zp:
+        # Identity case: remove both RESCALEs and directly wire
+        # R1's input (INT32) to R2's users. The composed scale
+        # is ~1.0 so the round-trip is a no-op modulo the INT8
+        # clamp. Bypassing the clamp can in theory cause up to
+        # ~120 INT8 steps of difference near clamp boundaries;
+        # observed differences are 0-1 steps. Tests use qtol=1.
+        r2.replace_all_uses_with(r1_input)
+        return True
+
+    # Non-identity: leave the pair unchanged. Creating a
+    # single INT32->INT32 RESCALE with the composed scale would
+    # be semantically correct (and the TOSA ref model handles
+    # it), but the Vela NPU compiler produces all-zero outputs
+    # for INT32->INT32 RESCALE operations.
+    return False
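
To make the docstring's claims concrete, here is a minimal stand-alone sketch of the identity test and of what the INT32->INT8->INT32 round-trip does to a value. It is not the pass itself. The `rescale` helper is a simplified floating-point model (real TOSA RESCALE uses an integer multiplier and shift), and `rescale`, `is_identity_pair`, and `round_trip` are hypothetical names chosen for illustration. Only the 1e-6 tolerance and the zero-point conditions come from the code above.

```python
INT8_MIN, INT8_MAX = -128, 127
INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


def rescale(value: int, scale: float, input_zp: int, output_zp: int,
            lo: int, hi: int) -> int:
    """Simplified model: remove input zero point, scale, round, shift, clamp."""
    scaled = round((value - input_zp) * scale) + output_zp
    return max(lo, min(hi, scaled))


def is_identity_pair(r1_scale: float, r1_input_zp: int, r1_output_zp: int,
                     r2_scale: float, r2_input_zp: int, r2_output_zp: int,
                     tol: float = 1e-6) -> bool:
    """Mirror the three conditions the pass checks before fusing R1->R2."""
    return (
        r1_output_zp == r2_input_zp                # intermediate zero points agree
        and abs(r1_scale * r2_scale - 1.0) < tol   # composed scale is ~1.0
        and r1_input_zp == r2_output_zp            # outer zero points agree
    )


def round_trip(value: int, r1_scale: float, r2_scale: float) -> int:
    """Unfused path: INT32 -> INT8 (R1) -> INT32 (R2), all zero points 0."""
    int8_value = rescale(value, r1_scale, 0, 0, INT8_MIN, INT8_MAX)
    return rescale(int8_value, r2_scale, 0, 0, INT32_MIN, INT32_MAX)


# Scales 0.25 and 4.0 compose to exactly 1.0, so the pair counts as identity.
print(is_identity_pair(0.25, 0, 0, 4.0, 0, 0))
# Rounding loss: 101 becomes 25 in INT8 and comes back as 100.
print(round_trip(101, 0.25, 4.0))
# Clamp loss: 1000 saturates at 127 in INT8 and comes back as 508.
print(round_trip(1000, 0.25, 4.0))
```

After fusion the INT32 value is passed through unchanged, so the fused and unfused graphs differ exactly where the round-trip rounds or saturates. That is the difference the docstring's qtol=1 note refers to.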
