Skip to content

Commit d9c2c6c

Browse files
author
ssjia
committed
Update base for Update on "[ET-VK] Add symint infrastructure to VulkanBackend and ComputeGraph"
Extend the Vulkan backend runtime infrastructure to better support symbolic integer (symint) arguments. This is a prerequisite for operators that need to handle dynamic shapes via symint values. Changes: - VulkanBackend.cpp: Compute output offset from end of args instead of assuming outputs follow inputs directly. Add scalar-to-tensor input handling so that Int/Bool EValues can populate tensor inputs. Support symint inputs provided as raw Int EValues (not just scalar tensors). Add symint output handling to write values back as tensor or Int EValue. - ComputeGraph.h: Add SymInt case to extract_scalar<T>() so operators can transparently read symint values as scalars. - ComputeGraph.cpp: Add Int fallback in read_symint() so values stored as plain Int (rather than SymInt objects) can be read uniformly. Differential Revision: [D95970167](https://our.internmc.facebook.com/intern/diff/D95970167/) cc manuelcandales digantdesai cbilgin [ghstack-poisoned]
2 parents 8725a18 + 22174fa commit d9c2c6c

43 files changed

Lines changed: 1911 additions & 529 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

backends/arm/_passes/TARGETS

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,55 @@
11
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
22

33
runtime.python_library(
4-
name = "passes",
5-
srcs = glob(["*.py"]),
4+
name = "core",
5+
srcs = [
6+
"arm_pass.py",
7+
"arm_pass_utils.py",
8+
"quant_args.py",
9+
],
610
deps = [
711
"//executorch/backends/arm:common",
812
"//executorch/backends/arm:constants",
913
"//executorch/backends/arm/tosa:utils",
1014
"//executorch/backends/arm/tosa/dialect:lib",
15+
"//executorch/backends/transforms:utils",
16+
"//executorch/exir:lib",
17+
],
18+
)
19+
20+
runtime.python_library(
21+
name = "arm_pass_manager_base",
22+
srcs = ["arm_pass_manager.py"],
23+
deps = [
24+
":core",
25+
"//executorch/backends/arm:common",
26+
"//executorch/backends/arm/tosa:utils",
27+
"//executorch/backends/arm/tosa/dialect:lib",
28+
"//executorch/exir:lib",
29+
],
30+
)
31+
32+
runtime.python_library(
33+
name = "arm_pass_manager_fb",
34+
srcs = [],
35+
deps = [
36+
":arm_pass_manager_base",
37+
# @oss-disable[end= ]: "//executorch/backends/arm/_passes/fb:fb",
38+
],
39+
)
40+
41+
runtime.python_library(
42+
name = "passes",
43+
srcs = glob(["*.py"], exclude = [
44+
"arm_pass.py",
45+
"arm_pass_utils.py",
46+
"quant_args.py",
47+
]),
48+
deps = [
49+
":core",
50+
":arm_pass_manager_base" if runtime.is_oss else ":arm_pass_manager_fb",
51+
"//executorch/backends/arm/tosa:utils",
52+
"//executorch/backends/arm/tosa/dialect:lib",
1153
"//executorch/backends/transforms:fuse_view_copy",
1254
"//executorch/backends/transforms:remove_getitem_op",
1355
"//executorch/backends/transforms:replace_scalar_with_tensor",

backends/arm/_passes/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@
138138
RewriteBoolToFp32CastViaInt8Pass,
139139
)
140140
from .rewrite_conv_pass import RewriteConvPass # noqa
141+
from .rewrite_high_rank_singleton_permute_pass import ( # noqa
142+
RewriteHighRankSingletonPermutePass,
143+
)
141144
from .rewrite_index_put_pass import RewriteIndexPutPass # noqa
142145
from .rewrite_le_lt_to_ge_gt_pass import RewriteLeLtToGeGtPass # noqa
143146
from .rewrite_matmul import RewriteMatmulPass # noqa
@@ -155,4 +158,13 @@
155158
from .control_flow_const_inline import ( # noqa # usort: skip
156159
ControlFlowConstInlinePass,
157160
)
158-
from .arm_pass_manager import ArmPassManager # noqa # usort: skip
161+
162+
# Import all subpackages to allow extensions to patch classes
163+
import importlib # noqa: E402
164+
import pkgutil # noqa: E402
165+
166+
for _, _modname, _ispkg in pkgutil.iter_modules(__path__, __name__ + "."):
167+
if _ispkg:
168+
importlib.import_module(_modname)
169+
170+
from .arm_pass_manager import ArmPassManager # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import logging
99
from collections import defaultdict
1010
from collections.abc import Sequence
11+
from dataclasses import dataclass, field
1112

1213
import executorch.backends.arm.tosa.dialect # noqa: unused
1314
from executorch.backends.arm._passes import (
@@ -121,6 +122,7 @@
121122
RewriteBoolBitwiseToLogicalPass,
122123
RewriteBoolToFp32CastViaInt8Pass,
123124
RewriteConvPass,
125+
RewriteHighRankSingletonPermutePass,
124126
RewriteIndexPutPass,
125127
RewriteLeLtToGeGtPass,
126128
RewriteMatmulPass,
@@ -156,11 +158,21 @@
156158
logger = logging.getLogger(__name__)
157159

158160

161+
@dataclass
162+
class PassInsertions:
163+
"""Holds lists of passes to be inserted before and after a target pass."""
164+
165+
before_passes: list = field(default_factory=list)
166+
after_passes: list = field(default_factory=list)
167+
168+
159169
class ArmPassManager(PassManager):
160170
def __init__(self, compile_spec: ArmCompileSpec) -> None:
161171
self.compile_spec = compile_spec
162172
self.tosa_spec = compile_spec.tosa_spec
163173
self._skip_pass_types: tuple[type, ...] = ()
174+
self._pass_insertions: dict[type, PassInsertions] = {}
175+
self._insertions_applied = False
164176
super().__init__()
165177
self.configure_skip_passes()
166178

@@ -222,6 +234,98 @@ def validate_constraints_mandatory(self):
222234

223235
raise RuntimeError(error_msg)
224236

237+
def insert_passes_before(
238+
self, target_pass_type: type, passes: list[ExportPass]
239+
) -> None:
240+
"""Register passes to be inserted before instances of target_pass_type.
241+
Insertions are deferred and applied via _apply_pass_insertions().
242+
243+
Args:
244+
target_pass_type: The pass class to insert before (e.g., InsertTableOpsPass)
245+
passes: List of pass instances to insert
246+
247+
"""
248+
self._pass_insertions.setdefault(
249+
target_pass_type, PassInsertions()
250+
).before_passes.extend(passes)
251+
252+
def insert_passes_after(
253+
self, target_pass_type: type, passes: list[ExportPass]
254+
) -> None:
255+
"""Register passes to be inserted after instances of target_pass_type.
256+
Insertions are deferred and applied via _apply_pass_insertions().
257+
258+
Args:
259+
target_pass_type: The pass class to insert after
260+
passes: List of pass instances to insert
261+
262+
"""
263+
self._pass_insertions.setdefault(
264+
target_pass_type, PassInsertions()
265+
).after_passes.extend(passes)
266+
267+
def _apply_pass_insertions(self) -> None:
268+
"""Apply all registered pass insertions to the collected passes.
269+
270+
Called ONCE after all add_passes() calls are complete, before execution.
271+
272+
Raises:
273+
ValueError: If any registered target pass type is not found in the pipeline.
274+
275+
"""
276+
if self._insertions_applied or not self._pass_insertions:
277+
return
278+
279+
# Fail fast if any target pass type is missing from the pipeline
280+
existing_pass_types = {type(p) for p in self.passes}
281+
for target_type in self._pass_insertions:
282+
if target_type not in existing_pass_types:
283+
available = [type(p).__name__ for p in self.passes]
284+
raise ValueError(
285+
f"Target pass {target_type.__name__} not found in the pass "
286+
f"pipeline. Available passes: {available}"
287+
)
288+
289+
# Build new pass list with insertions applied
290+
new_passes = []
291+
for pass_obj in self.passes:
292+
pass_type = type(pass_obj)
293+
294+
# Insert passes BEFORE this pass
295+
if pass_type in self._pass_insertions:
296+
insertions = self._pass_insertions[pass_type]
297+
for before_pass in insertions.before_passes:
298+
# Check if we should skip this inserted pass
299+
if type(before_pass) not in self._skip_pass_types:
300+
new_passes.append(before_pass)
301+
302+
# Add the original pass
303+
new_passes.append(pass_obj)
304+
305+
# Insert passes AFTER this pass
306+
if pass_type in self._pass_insertions:
307+
insertions = self._pass_insertions[pass_type]
308+
for after_pass in insertions.after_passes:
309+
# Check if we should skip this inserted pass
310+
if type(after_pass) not in self._skip_pass_types:
311+
new_passes.append(after_pass)
312+
313+
# Replace the passes list
314+
self.passes = new_passes
315+
self._insertions_applied = True
316+
317+
def _configure_pass_insertions(self, exported_program: ExportedProgram) -> None:
318+
"""Hook for subclasses to configure pass insertions. Called at the START
319+
of pipeline construction, before any passes are added.
320+
321+
Subclasses should override this to call insert_passes_before/after.
322+
323+
Args:
324+
exported_program: The exported program being transformed
325+
326+
"""
327+
pass
328+
225329
def add_passes(self, passes: Sequence[ExportPass | None]):
226330
for p in passes:
227331
if p is not None:
@@ -240,6 +344,9 @@ def add_pass(self, pipeline_pass):
240344
def _tosa_pipeline(
241345
self, exported_program: ExportedProgram, graph_module: GraphModule
242346
) -> GraphModule:
347+
# Allow subclasses to configure pass insertions before building pipeline
348+
self._configure_pass_insertions(exported_program)
349+
243350
# Preprocessing passes
244351
self.add_pass(AnnotateOutputDimOrderPass())
245352

@@ -366,6 +473,7 @@ def _tosa_pipeline(
366473
CastToInt32Pass(),
367474
BroadcastArgsPass(),
368475
ConvertPermuteSingletonToViewPass(),
476+
RewriteHighRankSingletonPermutePass(),
369477
FuseViewCopyTransformPass(),
370478
DecomposeConvWithInt16ActivationPass(),
371479
DecomposeSumPass(),
@@ -396,6 +504,9 @@ def _tosa_pipeline(
396504
]
397505
)
398506

507+
# Apply all pass insertions once after all passes are collected
508+
self._apply_pass_insertions()
509+
399510
self.validate_constraints_mandatory()
400511
return self._transform(graph_module)
401512

@@ -468,6 +579,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
468579
DecomposeLeakyReLUPass(tfa_pass=True),
469580
DecomposeLinalgVectorNormPass(tfa_pass=True),
470581
DecomposeSqrtPass(tfa_pass=True),
582+
DecomposeAdaptiveAvgPool2dPass(tfa_pass=True),
471583
DecomposeAvgPool2dPass(tfa_pass=True),
472584
DecomposeSoftmaxUnstablePass(tfa_pass=True),
473585
DecomposeSoftmaxPass(

backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class DecomposeAdaptiveAvgPool2dPass(ArmPass):
4949
_passes_required_after: Set[Type[ExportPass]] = {DecomposeAvgPool2dPass}
5050

5151
def call_operator(self, op, args, kwargs, meta, updated=False):
52-
if op not in (edge_ops + aten_ops):
52+
if op not in (edge_ops + aten_ops) or not self.allowed_to_transform(meta):
5353
return super().call_operator(op, args, kwargs, meta, updated)
5454

5555
avg_pool2d_op, slice_op, cat_op = _get_decomposition(op)

backends/arm/_passes/decompose_slice_scatter_pass.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,16 @@ def _get_slice_scatter_decomposition(op) -> tuple:
4141

4242
def _fixup_start(start, dim_size: int) -> int:
4343
s = 0 if start is None else int(start)
44-
return max(0, min(s % dim_size if s < 0 else s, dim_size))
44+
if s < 0:
45+
s += dim_size
46+
return max(0, min(s, dim_size))
4547

4648

4749
def _fixup_end(end, dim_size: int) -> int:
4850
e = dim_size if end is None else int(end)
49-
return max(0, min(e % dim_size if e < 0 else e, dim_size))
51+
if e < 0:
52+
e += dim_size
53+
return max(0, min(e, dim_size))
5054

5155

5256
class DecomposeSliceScatterPass(ArmPass):
@@ -131,31 +135,34 @@ def call_operator(self, op, args, kwargs, meta):
131135

132136
# ---- fast path: contiguous update (step == 1) ----
133137
if step == 1:
134-
# prefix = input[..., :start_i, ...] along dim_norm
135-
prefix = super().call_operator(
136-
slice_copy_op,
137-
(input, dim_norm, 0, start_i, 1),
138-
{},
139-
meta,
140-
updated=True,
138+
prefix = None
139+
suffix = None
140+
141+
# prefix = input[..., :start_i, ...] along dim_norm (only if non-empty)
142+
if start_i > 0:
143+
prefix = super().call_operator(
144+
slice_copy_op,
145+
(input, dim_norm, 0, start_i, 1),
146+
{},
147+
meta,
148+
updated=True,
149+
)
150+
151+
# suffix = input[..., end_i:, ...] along dim_norm (only if non-empty)
152+
if end_i < dim_size:
153+
suffix = super().call_operator(
154+
slice_copy_op,
155+
(input, dim_norm, end_i, dim_size, 1),
156+
{},
157+
meta,
158+
updated=True,
159+
)
160+
161+
parts = [x for x in (prefix, src, suffix) if x is not None]
162+
163+
return super().call_operator(
164+
cat_op, (parts, dim_norm), {}, meta, updated=True
141165
)
142-
# suffix = input[..., end_i:, ...] along dim_norm
143-
suffix = super().call_operator(
144-
slice_copy_op,
145-
(input, dim_norm, end_i, dim_size, 1),
146-
{},
147-
meta,
148-
updated=True,
149-
)
150-
# concat(prefix, src, suffix) along dim_norm
151-
updated = super().call_operator(
152-
cat_op,
153-
([prefix, src, suffix], dim_norm),
154-
{},
155-
meta,
156-
updated=True,
157-
)
158-
return updated
159166

160167
# ---- general path: strided update (step > 1) ----
161168
# Move updated dim to front to use a single index tensor.

0 commit comments

Comments
 (0)