Skip to content

Commit 05b8602

Browse files
committed
Update base for Update on "Fix SLEEF preprocessor macro name to match ATen vec headers"
The ATen NEON vectorized math headers (vec128_float_neon.h) check for AT_BUILD_ARM_VEC256_WITH_SLEEF to enable SLEEF intrinsics for exp(), log(), etc. ExecuTorch's get_vec_preprocessor_flags() was defining ET_BUILD_ARM_VEC256_WITH_SLEEF (wrong prefix), so the USE_SLEEF macro always took the fallback path: map(std::exp) — scalar exp called per-element with full vector load/store overhead wrapping it. With this fix, Vectorized<float>::exp() correctly dispatches to Sleef_expf4_u10 on ARM, which is the intended behavior. Differential Revision: [D96044314](https://our.internmc.facebook.com/intern/diff/D96044314/) [ghstack-poisoned]
2 parents 90c2ca5 + bf64fa1 commit 05b8602

66 files changed

Lines changed: 3040 additions & 647 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

.ci/docker/build.sh

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/bin/bash
22
# Copyright (c) Meta Platforms, Inc. and affiliates.
33
# All rights reserved.
4+
# Copyright 2026 Arm Limited and/or its affiliates.
45
#
56
# This source code is licensed under the BSD-style license found in the
67
# LICENSE file in the root directory of this source tree.
@@ -94,11 +95,6 @@ BUILD_DOCS=1
9495
# Copy requirements-lintrunner.txt from root to here
9596
cp ../../requirements-lintrunner.txt ./
9697

97-
# Copy arm setup script from root to here
98-
# TODO(huydhn): Figure out a way to rebuild the Docker image automatically
99-
# with a new image hash when the content here is updated
100-
cp -r ../../examples/arm/ ./arm
101-
10298
docker build \
10399
--no-cache \
104100
--progress=plain \

backends/apple/metal/runtime/metal_backend.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <c10/util/safe_numerics.h>
910
#include <dlfcn.h>
1011
#include <executorch/runtime/backend/interface.h>
1112
#include <executorch/runtime/core/error.h>
@@ -459,8 +460,10 @@ class ET_EXPERIMENTAL MetalBackend final
459460

460461
ET_LOG(Debug, "MetalBackend n_outputs %zd generated", n_outputs);
461462

463+
size_t n_io_sum = 0;
462464
ET_CHECK_OR_RETURN_ERROR(
463-
n_inputs + n_outputs == args.size(),
465+
!c10::add_overflows(n_inputs, n_outputs, &n_io_sum) &&
466+
n_io_sum == args.size(),
464467
InvalidArgument,
465468
"number of user input %zd and output %zd generated from AOT Inductor does not match ET runner's %zd. Exit.",
466469
n_inputs,

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import logging
7-
from typing import Set, Type
7+
from collections.abc import Mapping
8+
from typing import Sequence, Set, Type
89

910
import torch._export.utils
1011
import torch.fx
@@ -18,6 +19,7 @@
1819
from executorch.backends.arm._passes.fuse_equal_placeholders_pass import (
1920
FuseEqualPlaceholdersPass,
2021
)
22+
from executorch.backends.arm.tosa.dialect.shape import meta_has_shape_mark
2123
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
2224
from executorch.backends.transforms.utils import (
2325
create_constant_placeholder,
@@ -53,6 +55,36 @@ def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None:
5355
super().__init__(*args, **kwargs)
5456
self.exported_program = exported_program
5557

58+
@staticmethod
59+
def _is_tosa_dialect_op(target) -> bool:
60+
target_str = str(target)
61+
return (
62+
"executorch.exir.dialects.backend._ops.tosa." in target_str
63+
or "<EdgeOpOverload: tosa." in target_str
64+
)
65+
66+
@staticmethod
67+
def _arg_contains_symbolic_shape(arg) -> bool:
68+
if isinstance(arg, torch.fx.Node):
69+
if meta_has_shape_mark(arg.meta):
70+
return True
71+
return FuseConstantArgsPass._arg_contains_symbolic_shape(
72+
arg.meta.get("val")
73+
)
74+
if isinstance(arg, torch.SymInt):
75+
return True
76+
if isinstance(arg, Mapping):
77+
return any(
78+
FuseConstantArgsPass._arg_contains_symbolic_shape(k)
79+
or FuseConstantArgsPass._arg_contains_symbolic_shape(v)
80+
for k, v in arg.items()
81+
)
82+
if isinstance(arg, Sequence) and not isinstance(arg, (str, bytes)):
83+
return any(
84+
FuseConstantArgsPass._arg_contains_symbolic_shape(v) for v in arg
85+
)
86+
return False
87+
5688
def _propagate_special_dtype(self, from_nodes, to_node, data):
5789
"""Propagate special dtype meta if it exists."""
5890
special_dtypes = set()
@@ -142,13 +174,13 @@ def call(self, graph_module):
142174
for node in graph_module.graph.nodes:
143175
if node.op != "call_function":
144176
continue
145-
if node.target in [
146-
exir_ops.backend.tosa.MATMUL.default,
147-
exir_ops.backend.tosa.RESCALE.default,
148-
exir_ops.backend.tosa.RESIZE.default,
149-
exir_ops.backend.tosa.TABLE.default,
150-
exir_ops.backend.tosa.TRANSPOSE.default,
151-
]:
177+
# Don't fuse TOSA dialect ops as they do not have eager forward functions.
178+
# Also don't fuse ops whose explicit args/kwargs include symbolic shape values.
179+
if (
180+
self._is_tosa_dialect_op(node.target)
181+
or self._arg_contains_symbolic_shape(node.args)
182+
or self._arg_contains_symbolic_shape(node.kwargs)
183+
):
152184
continue
153185

154186
input_nodes = node.all_input_nodes
@@ -164,7 +196,6 @@ def call(self, graph_module):
164196
)
165197
if not all(input_nodes_constant):
166198
continue
167-
168199
try:
169200
did_fuse = self._fuse_nodes(node)
170201
if did_fuse:

backends/arm/_passes/rewrite_conv_pass.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
get_input_qparams,
2222
get_output_qparams,
2323
)
24+
from executorch.backends.arm._passes.symbolic_value_range import (
25+
evaluate_symbolic_expr_values,
26+
)
2427
from executorch.backends.arm.constants import HWCM_ORDER, NHWC_INVERSE_ORDER
2528
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
2629
from executorch.backends.arm.tosa.specification import get_context_shape_env
@@ -83,16 +86,22 @@ def _adjust_pad_if_needed(
8386

8487
if isinstance(mod_remainder, torch.SymInt):
8588
shape_env = get_context_shape_env()
86-
value_ranges = shape_env.bound_sympy(mod_remainder.node.expr)
87-
mod_remainder_upper = int(value_ranges.upper)
89+
exact_values = evaluate_symbolic_expr_values(
90+
mod_remainder.node.expr, shape_env
91+
)
92+
if exact_values is not None:
93+
mod_remainder_upper = max(exact_values)
94+
else:
95+
value_ranges = shape_env.bound_sympy(mod_remainder.node.expr)
96+
mod_remainder_upper = int(value_ranges.upper)
8897
if mod_remainder_upper == 0:
8998
mod_remainder = 0
9099
else:
91100
mod_remainder_upper = mod_remainder
92101

93102
if mod_remainder_upper > pad:
94103
raise RuntimeError(
95-
"This case should be handled by the SizeAdjustInputPass, is it enabled?"
104+
"This case should be handled by the SizeAdjustInputPass, is it enabled?\n"
96105
)
97106
return pad - mod_remainder
98107

backends/arm/_passes/size_adjust_input_pass.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55
from typing import cast, Sequence, Set, Type, TypeAlias
66

7+
import torch
78
import torch.fx
89
from executorch.backends.arm._passes import ArmPass
910
from executorch.backends.arm._passes.arm_pass_utils import (
@@ -12,6 +13,9 @@
1213
)
1314
from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass
1415
from executorch.backends.arm._passes.rewrite_max_pool2d_pass import RewriteMaxPool2dPass
16+
from executorch.backends.arm._passes.symbolic_value_range import (
17+
evaluate_symbolic_expr_values,
18+
)
1519
from executorch.backends.arm.tosa.specification import get_context_shape_env
1620
from executorch.exir.dialects._ops import ops as exir_ops
1721
from executorch.exir.pass_base import ExportPass, PassResult
@@ -49,6 +53,9 @@ def _greater_than(input: SymIntLike, other: int) -> bool | torch.SymBool:
4953
"""Returns whether an int or SymInt is greater than another value."""
5054
if isinstance(input, torch.SymInt):
5155
shape_env = get_context_shape_env()
56+
exact_values = evaluate_symbolic_expr_values(input.node.expr, shape_env)
57+
if exact_values is not None:
58+
return max(exact_values) > other
5259
value_ranges = shape_env.bound_sympy(input.node.expr)
5360
return value_ranges.upper > other
5461
else:
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Optional
7+
8+
import sympy # type: ignore[import-untyped]
9+
import torch
10+
from torch.fx.experimental.symbolic_shapes import ShapeEnv
11+
from torch.utils._sympy.interp import sympy_interp
12+
13+
# Upper bound on the size of any exact-value set; analyses that would produce
# a larger set degrade to None ("unknown") instead.
_MAX_SET_SIZE = 256
# An exact-value set (frozen so it can be shared/reused safely), or None when
# the possible values cannot be enumerated.
_ExactValues = Optional[frozenset[sympy.Basic]]
15+
16+
17+
def _expr_to_int(sym_expr: sympy.Basic) -> Optional[int]:
18+
if isinstance(sym_expr, int):
19+
return sym_expr
20+
if isinstance(sym_expr, sympy.Integer):
21+
return int(sym_expr)
22+
if getattr(sym_expr, "is_integer", False) and sym_expr.is_number:
23+
return int(sym_expr)
24+
return None
25+
26+
27+
def _symbol_values(symbol: sympy.Symbol, shape_env: ShapeEnv) -> _ExactValues:
    """Enumerate every possible value of *symbol* as a set of sympy Integers.

    Returns None when the symbol has no integer range recorded in
    *shape_env*, when either endpoint is not a concrete integer, or when the
    range would contain more than ``_MAX_SET_SIZE`` values.
    """
    bounds = shape_env.var_to_range.get(symbol)
    if bounds is None:
        return None
    if not bounds.is_int:
        return None

    lo, hi = _expr_to_int(bounds.lower), _expr_to_int(bounds.upper)
    if lo is None or hi is None:
        return None
    # Reject empty ranges and ranges too large to enumerate cheaply.
    if hi < lo or hi - lo + 1 > _MAX_SET_SIZE:
        return None

    return frozenset(map(sympy.Integer, range(lo, hi + 1)))
40+
41+
42+
def _map_values(values: _ExactValues, fn) -> _ExactValues:
    """Apply *fn* to each exact value, simplifying results; None stays None."""
    if values is None:
        return None
    mapped = frozenset(sympy.simplify(fn(v)) for v in values)
    # Degrade to "unknown" rather than carry an oversized set forward.
    return None if len(mapped) > _MAX_SET_SIZE else mapped
return frozenset(result)
50+
51+
52+
def _combine_values(lhs: _ExactValues, rhs: _ExactValues, fn) -> _ExactValues:
    """Combine two exact-value sets pairwise with *fn*; None propagates."""
    if lhs is None or rhs is None:
        return None
    # Refuse pathologically large cross products before doing the work.
    if len(lhs) * len(rhs) > _MAX_SET_SIZE * _MAX_SET_SIZE:
        return None

    combined = set()
    for a in lhs:
        for b in rhs:
            combined.add(sympy.simplify(fn(a, b)))
    if len(combined) > _MAX_SET_SIZE:
        return None
    return frozenset(combined)
62+
63+
64+
class _ExactValueAnalysis:
    # Interpretation domain handed to ``sympy_interp``: each handler maps the
    # exact-value sets of its operands to the exact-value set of the result,
    # with None meaning "unknown / too many values". Handler names and
    # signatures follow the ``torch.utils._sympy.interp`` protocol.

    @staticmethod
    def constant(value, dtype) -> frozenset[sympy.Basic]:
        # A literal has exactly one possible value; dtype is unused here.
        return frozenset({sympy.sympify(value)})

    @staticmethod
    def add(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues:
        return _combine_values(lhs, rhs, lambda a, b: a + b)

    @staticmethod
    def mul(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues:
        return _combine_values(lhs, rhs, lambda a, b: a * b)

    @staticmethod
    def mod(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues:
        # Bail out if any divisor could be zero: Mod would be undefined.
        if rhs is None or any(value == 0 for value in rhs):
            return None
        return _combine_values(lhs, rhs, lambda a, b: sympy.Mod(a, b))

    @staticmethod
    def pow(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues:
        return _combine_values(lhs, rhs, lambda a, b: a**b)

    @staticmethod
    def floor_to_int(values: _ExactValues, dtype) -> _ExactValues:
        # dtype is part of the interp handler signature but irrelevant here.
        return _map_values(values, sympy.floor)

    @staticmethod
    def sym_sum(args: list[_ExactValues]) -> _ExactValues:
        # Fold a variadic sum; any unknown operand poisons the whole result.
        acc: _ExactValues = frozenset({sympy.Integer(0)})
        for arg in args:
            acc = _ExactValueAnalysis.add(acc, arg)
            if acc is None:
                return None
        return acc
99+
100+
101+
def evaluate_symbolic_expr_values(
    expr: sympy.Basic | torch.SymInt,
    shape_env: ShapeEnv,
) -> Optional[set[int]]:
    """Return a best-effort finite set of possible integer values.

    The helper first relies on ``bound_sympy`` for cheap singleton detection.
    When interval bounds are not precise enough, it falls back to a small
    exact-set analysis over bounded symbols using ``sympy_interp``.

    """
    # Accept either a raw sympy expression or a torch SymInt wrapper.
    root_expr = sympy.simplify(
        expr.node.expr if isinstance(expr, torch.SymInt) else expr
    )
    # Fast path: if interval analysis already pins a single integer, use it.
    value_range = shape_env.bound_sympy(root_expr)
    if value_range.is_int and value_range.is_singleton():
        singleton = _expr_to_int(value_range.lower)
        return {singleton} if singleton is not None else None

    # Slow path: interpret the expression over per-symbol exact-value sets.
    # The missing_handler covers symbols not present in the prebuilt mapping.
    exact_values = sympy_interp(
        _ExactValueAnalysis,
        {
            symbol: _symbol_values(symbol, shape_env)
            for symbol in root_expr.free_symbols
        },
        root_expr,
        missing_handler=lambda symbol: _symbol_values(symbol, shape_env),
    )
    if exact_values is None:
        return None

    # Convert every resulting sympy value to a plain int; give up entirely if
    # any value does not simplify to a concrete integer.
    result: set[int] = set()
    for value in exact_values:
        integer_value = _expr_to_int(sympy.simplify(value))
        if integer_value is None:
            return None
        result.add(integer_value)
    return result

0 commit comments

Comments
 (0)