Skip to content

Commit 97d9777

Browse files
committed
Update base for Update on "[ET Device Support] TensorImpl carries device info"
This diff extends `TensorImpl` to carry device information, enabling the runtime tensor to track which device its data resides on (CPU, CUDA, etc.). This is a prerequisite for parsing device info from the schema and allocating device memory. Differential Revision: [D93635655](https://our.internmc.facebook.com/intern/diff/D93635655/) [ghstack-poisoned]
2 parents 7ddf345 + 22174fa commit 97d9777

43 files changed

Lines changed: 1911 additions & 529 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

backends/arm/_passes/TARGETS

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,55 @@
11
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
22

33
runtime.python_library(
4-
name = "passes",
5-
srcs = glob(["*.py"]),
4+
name = "core",
5+
srcs = [
6+
"arm_pass.py",
7+
"arm_pass_utils.py",
8+
"quant_args.py",
9+
],
610
deps = [
711
"//executorch/backends/arm:common",
812
"//executorch/backends/arm:constants",
913
"//executorch/backends/arm/tosa:utils",
1014
"//executorch/backends/arm/tosa/dialect:lib",
15+
"//executorch/backends/transforms:utils",
16+
"//executorch/exir:lib",
17+
],
18+
)
19+
20+
runtime.python_library(
21+
name = "arm_pass_manager_base",
22+
srcs = ["arm_pass_manager.py"],
23+
deps = [
24+
":core",
25+
"//executorch/backends/arm:common",
26+
"//executorch/backends/arm/tosa:utils",
27+
"//executorch/backends/arm/tosa/dialect:lib",
28+
"//executorch/exir:lib",
29+
],
30+
)
31+
32+
runtime.python_library(
33+
name = "arm_pass_manager_fb",
34+
srcs = [],
35+
deps = [
36+
":arm_pass_manager_base",
37+
# @oss-disable[end= ]: "//executorch/backends/arm/_passes/fb:fb",
38+
],
39+
)
40+
41+
runtime.python_library(
42+
name = "passes",
43+
srcs = glob(["*.py"], exclude = [
44+
"arm_pass.py",
45+
"arm_pass_utils.py",
46+
"quant_args.py",
47+
]),
48+
deps = [
49+
":core",
50+
":arm_pass_manager_base" if runtime.is_oss else ":arm_pass_manager_fb",
51+
"//executorch/backends/arm/tosa:utils",
52+
"//executorch/backends/arm/tosa/dialect:lib",
1153
"//executorch/backends/transforms:fuse_view_copy",
1254
"//executorch/backends/transforms:remove_getitem_op",
1355
"//executorch/backends/transforms:replace_scalar_with_tensor",

backends/arm/_passes/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@
138138
RewriteBoolToFp32CastViaInt8Pass,
139139
)
140140
from .rewrite_conv_pass import RewriteConvPass # noqa
141+
from .rewrite_high_rank_singleton_permute_pass import ( # noqa
142+
RewriteHighRankSingletonPermutePass,
143+
)
141144
from .rewrite_index_put_pass import RewriteIndexPutPass # noqa
142145
from .rewrite_le_lt_to_ge_gt_pass import RewriteLeLtToGeGtPass # noqa
143146
from .rewrite_matmul import RewriteMatmulPass # noqa
@@ -155,4 +158,13 @@
155158
from .control_flow_const_inline import ( # noqa # usort: skip
156159
ControlFlowConstInlinePass,
157160
)
158-
from .arm_pass_manager import ArmPassManager # noqa # usort: skip
161+
162+
# Import all subpackages to allow extensions to patch classes
163+
import importlib # noqa: E402
164+
import pkgutil # noqa: E402
165+
166+
for _, _modname, _ispkg in pkgutil.iter_modules(__path__, __name__ + "."):
167+
if _ispkg:
168+
importlib.import_module(_modname)
169+
170+
from .arm_pass_manager import ArmPassManager # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import logging
99
from collections import defaultdict
1010
from collections.abc import Sequence
11+
from dataclasses import dataclass, field
1112

1213
import executorch.backends.arm.tosa.dialect # noqa: unused
1314
from executorch.backends.arm._passes import (
@@ -121,6 +122,7 @@
121122
RewriteBoolBitwiseToLogicalPass,
122123
RewriteBoolToFp32CastViaInt8Pass,
123124
RewriteConvPass,
125+
RewriteHighRankSingletonPermutePass,
124126
RewriteIndexPutPass,
125127
RewriteLeLtToGeGtPass,
126128
RewriteMatmulPass,
@@ -156,11 +158,21 @@
156158
logger = logging.getLogger(__name__)
157159

158160

161+
@dataclass
162+
class PassInsertions:
163+
"""Holds lists of passes to be inserted before and after a target pass."""
164+
165+
before_passes: list = field(default_factory=list)
166+
after_passes: list = field(default_factory=list)
167+
168+
159169
class ArmPassManager(PassManager):
160170
def __init__(self, compile_spec: ArmCompileSpec) -> None:
161171
self.compile_spec = compile_spec
162172
self.tosa_spec = compile_spec.tosa_spec
163173
self._skip_pass_types: tuple[type, ...] = ()
174+
self._pass_insertions: dict[type, PassInsertions] = {}
175+
self._insertions_applied = False
164176
super().__init__()
165177
self.configure_skip_passes()
166178

@@ -222,6 +234,98 @@ def validate_constraints_mandatory(self):
222234

223235
raise RuntimeError(error_msg)
224236

237+
def insert_passes_before(
238+
self, target_pass_type: type, passes: list[ExportPass]
239+
) -> None:
240+
"""Register passes to be inserted before instances of target_pass_type.
241+
Insertions are deferred and applied via _apply_pass_insertions().
242+
243+
Args:
244+
target_pass_type: The pass class to insert before (e.g., InsertTableOpsPass)
245+
passes: List of pass instances to insert
246+
247+
"""
248+
self._pass_insertions.setdefault(
249+
target_pass_type, PassInsertions()
250+
).before_passes.extend(passes)
251+
252+
def insert_passes_after(
253+
self, target_pass_type: type, passes: list[ExportPass]
254+
) -> None:
255+
"""Register passes to be inserted after instances of target_pass_type.
256+
Insertions are deferred and applied via _apply_pass_insertions().
257+
258+
Args:
259+
target_pass_type: The pass class to insert after
260+
passes: List of pass instances to insert
261+
262+
"""
263+
self._pass_insertions.setdefault(
264+
target_pass_type, PassInsertions()
265+
).after_passes.extend(passes)
266+
267+
def _apply_pass_insertions(self) -> None:
268+
"""Apply all registered pass insertions to the collected passes.
269+
270+
Called ONCE after all add_passes() calls are complete, before execution.
271+
272+
Raises:
273+
ValueError: If any registered target pass type is not found in the pipeline.
274+
275+
"""
276+
if self._insertions_applied or not self._pass_insertions:
277+
return
278+
279+
# Fail fast if any target pass type is missing from the pipeline
280+
existing_pass_types = {type(p) for p in self.passes}
281+
for target_type in self._pass_insertions:
282+
if target_type not in existing_pass_types:
283+
available = [type(p).__name__ for p in self.passes]
284+
raise ValueError(
285+
f"Target pass {target_type.__name__} not found in the pass "
286+
f"pipeline. Available passes: {available}"
287+
)
288+
289+
# Build new pass list with insertions applied
290+
new_passes = []
291+
for pass_obj in self.passes:
292+
pass_type = type(pass_obj)
293+
294+
# Insert passes BEFORE this pass
295+
if pass_type in self._pass_insertions:
296+
insertions = self._pass_insertions[pass_type]
297+
for before_pass in insertions.before_passes:
298+
# Check if we should skip this inserted pass
299+
if type(before_pass) not in self._skip_pass_types:
300+
new_passes.append(before_pass)
301+
302+
# Add the original pass
303+
new_passes.append(pass_obj)
304+
305+
# Insert passes AFTER this pass
306+
if pass_type in self._pass_insertions:
307+
insertions = self._pass_insertions[pass_type]
308+
for after_pass in insertions.after_passes:
309+
# Check if we should skip this inserted pass
310+
if type(after_pass) not in self._skip_pass_types:
311+
new_passes.append(after_pass)
312+
313+
# Replace the passes list
314+
self.passes = new_passes
315+
self._insertions_applied = True
316+
317+
def _configure_pass_insertions(self, exported_program: ExportedProgram) -> None:
318+
"""Hook for subclasses to configure pass insertions. Called at the START
319+
of pipeline construction, before any passes are added.
320+
321+
Subclasses should override this to call insert_passes_before/after.
322+
323+
Args:
324+
exported_program: The exported program being transformed
325+
326+
"""
327+
pass
328+
225329
def add_passes(self, passes: Sequence[ExportPass | None]):
226330
for p in passes:
227331
if p is not None:
@@ -240,6 +344,9 @@ def add_pass(self, pipeline_pass):
240344
def _tosa_pipeline(
241345
self, exported_program: ExportedProgram, graph_module: GraphModule
242346
) -> GraphModule:
347+
# Allow subclasses to configure pass insertions before building pipeline
348+
self._configure_pass_insertions(exported_program)
349+
243350
# Preprocessing passes
244351
self.add_pass(AnnotateOutputDimOrderPass())
245352

@@ -366,6 +473,7 @@ def _tosa_pipeline(
366473
CastToInt32Pass(),
367474
BroadcastArgsPass(),
368475
ConvertPermuteSingletonToViewPass(),
476+
RewriteHighRankSingletonPermutePass(),
369477
FuseViewCopyTransformPass(),
370478
DecomposeConvWithInt16ActivationPass(),
371479
DecomposeSumPass(),
@@ -396,6 +504,9 @@ def _tosa_pipeline(
396504
]
397505
)
398506

507+
# Apply all pass insertions once after all passes are collected
508+
self._apply_pass_insertions()
509+
399510
self.validate_constraints_mandatory()
400511
return self._transform(graph_module)
401512

@@ -468,6 +579,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
468579
DecomposeLeakyReLUPass(tfa_pass=True),
469580
DecomposeLinalgVectorNormPass(tfa_pass=True),
470581
DecomposeSqrtPass(tfa_pass=True),
582+
DecomposeAdaptiveAvgPool2dPass(tfa_pass=True),
471583
DecomposeAvgPool2dPass(tfa_pass=True),
472584
DecomposeSoftmaxUnstablePass(tfa_pass=True),
473585
DecomposeSoftmaxPass(

backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class DecomposeAdaptiveAvgPool2dPass(ArmPass):
4949
_passes_required_after: Set[Type[ExportPass]] = {DecomposeAvgPool2dPass}
5050

5151
def call_operator(self, op, args, kwargs, meta, updated=False):
52-
if op not in (edge_ops + aten_ops):
52+
if op not in (edge_ops + aten_ops) or not self.allowed_to_transform(meta):
5353
return super().call_operator(op, args, kwargs, meta, updated)
5454

5555
avg_pool2d_op, slice_op, cat_op = _get_decomposition(op)

backends/arm/_passes/decompose_slice_scatter_pass.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,16 @@ def _get_slice_scatter_decomposition(op) -> tuple:
4141

4242
def _fixup_start(start, dim_size: int) -> int:
4343
s = 0 if start is None else int(start)
44-
return max(0, min(s % dim_size if s < 0 else s, dim_size))
44+
if s < 0:
45+
s += dim_size
46+
return max(0, min(s, dim_size))
4547

4648

4749
def _fixup_end(end, dim_size: int) -> int:
4850
e = dim_size if end is None else int(end)
49-
return max(0, min(e % dim_size if e < 0 else e, dim_size))
51+
if e < 0:
52+
e += dim_size
53+
return max(0, min(e, dim_size))
5054

5155

5256
class DecomposeSliceScatterPass(ArmPass):
@@ -131,31 +135,34 @@ def call_operator(self, op, args, kwargs, meta):
131135

132136
# ---- fast path: contiguous update (step == 1) ----
133137
if step == 1:
134-
# prefix = input[..., :start_i, ...] along dim_norm
135-
prefix = super().call_operator(
136-
slice_copy_op,
137-
(input, dim_norm, 0, start_i, 1),
138-
{},
139-
meta,
140-
updated=True,
138+
prefix = None
139+
suffix = None
140+
141+
# prefix = input[..., :start_i, ...] along dim_norm (only if non-empty)
142+
if start_i > 0:
143+
prefix = super().call_operator(
144+
slice_copy_op,
145+
(input, dim_norm, 0, start_i, 1),
146+
{},
147+
meta,
148+
updated=True,
149+
)
150+
151+
# suffix = input[..., end_i:, ...] along dim_norm (only if non-empty)
152+
if end_i < dim_size:
153+
suffix = super().call_operator(
154+
slice_copy_op,
155+
(input, dim_norm, end_i, dim_size, 1),
156+
{},
157+
meta,
158+
updated=True,
159+
)
160+
161+
parts = [x for x in (prefix, src, suffix) if x is not None]
162+
163+
return super().call_operator(
164+
cat_op, (parts, dim_norm), {}, meta, updated=True
141165
)
142-
# suffix = input[..., end_i:, ...] along dim_norm
143-
suffix = super().call_operator(
144-
slice_copy_op,
145-
(input, dim_norm, end_i, dim_size, 1),
146-
{},
147-
meta,
148-
updated=True,
149-
)
150-
# concat(prefix, src, suffix) along dim_norm
151-
updated = super().call_operator(
152-
cat_op,
153-
([prefix, src, suffix], dim_norm),
154-
{},
155-
meta,
156-
updated=True,
157-
)
158-
return updated
159166

160167
# ---- general path: strided update (step > 1) ----
161168
# Move updated dim to front to use a single index tensor.

0 commit comments

Comments
 (0)