Skip to content

Commit c5c5b3a

Browse files
Arm backend: Add util for symbolic range eval (#19108)
Adds util for computing a value range from a symbolic expression. cc @digantdesai @freddan80 @per @zingo @mansnils @Sebastian-Larsson @robell Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent b6a47aa commit c5c5b3a

4 files changed

Lines changed: 226 additions & 3 deletions

File tree

backends/arm/_passes/rewrite_conv_pass.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
get_input_qparams,
2222
get_output_qparams,
2323
)
24+
from executorch.backends.arm._passes.symbolic_value_range import (
25+
evaluate_symbolic_expr_values,
26+
)
2427
from executorch.backends.arm.constants import HWCM_ORDER, NHWC_INVERSE_ORDER
2528
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
2629
from executorch.backends.arm.tosa.specification import get_context_shape_env
@@ -83,16 +86,22 @@ def _adjust_pad_if_needed(
8386

8487
if isinstance(mod_remainder, torch.SymInt):
8588
shape_env = get_context_shape_env()
86-
value_ranges = shape_env.bound_sympy(mod_remainder.node.expr)
87-
mod_remainder_upper = int(value_ranges.upper)
89+
exact_values = evaluate_symbolic_expr_values(
90+
mod_remainder.node.expr, shape_env
91+
)
92+
if exact_values is not None:
93+
mod_remainder_upper = max(exact_values)
94+
else:
95+
value_ranges = shape_env.bound_sympy(mod_remainder.node.expr)
96+
mod_remainder_upper = int(value_ranges.upper)
8897
if mod_remainder_upper == 0:
8998
mod_remainder = 0
9099
else:
91100
mod_remainder_upper = mod_remainder
92101

93102
if mod_remainder_upper > pad:
94103
raise RuntimeError(
95-
"This case should be handled by the SizeAdjustInputPass, is it enabled?"
104+
"This case should be handled by the SizeAdjustInputPass, is it enabled?\n"
96105
)
97106
return pad - mod_remainder
98107

backends/arm/_passes/size_adjust_input_pass.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55
from typing import cast, Sequence, Set, Type, TypeAlias
66

7+
import torch
78
import torch.fx
89
from executorch.backends.arm._passes import ArmPass
910
from executorch.backends.arm._passes.arm_pass_utils import (
@@ -12,6 +13,9 @@
1213
)
1314
from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass
1415
from executorch.backends.arm._passes.rewrite_max_pool2d_pass import RewriteMaxPool2dPass
16+
from executorch.backends.arm._passes.symbolic_value_range import (
17+
evaluate_symbolic_expr_values,
18+
)
1519
from executorch.backends.arm.tosa.specification import get_context_shape_env
1620
from executorch.exir.dialects._ops import ops as exir_ops
1721
from executorch.exir.pass_base import ExportPass, PassResult
@@ -49,6 +53,9 @@ def _greater_than(input: SymIntLike, other: int) -> bool | torch.SymBool:
4953
"""Returns whether an int or SymInt is greater than another value."""
5054
if isinstance(input, torch.SymInt):
5155
shape_env = get_context_shape_env()
56+
exact_values = evaluate_symbolic_expr_values(input.node.expr, shape_env)
57+
if exact_values is not None:
58+
return max(exact_values) > other
5259
value_ranges = shape_env.bound_sympy(input.node.expr)
5360
return value_ranges.upper > other
5461
else:
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Optional
7+
8+
import sympy # type: ignore[import-untyped]
9+
import torch
10+
from torch.fx.experimental.symbolic_shapes import ShapeEnv
11+
from torch.utils._sympy.interp import sympy_interp
12+
13+
_MAX_SET_SIZE = 256
14+
_ExactValues = Optional[frozenset[sympy.Basic]]
15+
16+
17+
def _expr_to_int(sym_expr: sympy.Basic) -> Optional[int]:
18+
if isinstance(sym_expr, int):
19+
return sym_expr
20+
if isinstance(sym_expr, sympy.Integer):
21+
return int(sym_expr)
22+
if getattr(sym_expr, "is_integer", False) and sym_expr.is_number:
23+
return int(sym_expr)
24+
return None
25+
26+
27+
def _symbol_values(symbol: sympy.Symbol, shape_env: ShapeEnv) -> _ExactValues:
28+
value_range = shape_env.var_to_range.get(symbol)
29+
if value_range is None or not value_range.is_int:
30+
return None
31+
32+
lower = _expr_to_int(value_range.lower)
33+
upper = _expr_to_int(value_range.upper)
34+
if lower is None or upper is None or upper < lower:
35+
return None
36+
if upper - lower + 1 > _MAX_SET_SIZE:
37+
return None
38+
39+
return frozenset(sympy.Integer(value) for value in range(lower, upper + 1))
40+
41+
42+
def _map_values(values: _ExactValues, fn) -> _ExactValues:
43+
if values is None:
44+
return None
45+
46+
result = {sympy.simplify(fn(value)) for value in values}
47+
if len(result) > _MAX_SET_SIZE:
48+
return None
49+
return frozenset(result)
50+
51+
52+
def _combine_values(lhs: _ExactValues, rhs: _ExactValues, fn) -> _ExactValues:
53+
if lhs is None or rhs is None:
54+
return None
55+
if len(lhs) * len(rhs) > _MAX_SET_SIZE * _MAX_SET_SIZE:
56+
return None
57+
58+
result = {sympy.simplify(fn(a, b)) for a in lhs for b in rhs}
59+
if len(result) > _MAX_SET_SIZE:
60+
return None
61+
return frozenset(result)
62+
63+
64+
class _ExactValueAnalysis:
65+
@staticmethod
66+
def constant(value, dtype) -> frozenset[sympy.Basic]:
67+
return frozenset({sympy.sympify(value)})
68+
69+
@staticmethod
70+
def add(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues:
71+
return _combine_values(lhs, rhs, lambda a, b: a + b)
72+
73+
@staticmethod
74+
def mul(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues:
75+
return _combine_values(lhs, rhs, lambda a, b: a * b)
76+
77+
@staticmethod
78+
def mod(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues:
79+
if rhs is None or any(value == 0 for value in rhs):
80+
return None
81+
return _combine_values(lhs, rhs, lambda a, b: sympy.Mod(a, b))
82+
83+
@staticmethod
84+
def pow(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues:
85+
return _combine_values(lhs, rhs, lambda a, b: a**b)
86+
87+
@staticmethod
88+
def floor_to_int(values: _ExactValues, dtype) -> _ExactValues:
89+
return _map_values(values, sympy.floor)
90+
91+
@staticmethod
92+
def sym_sum(args: list[_ExactValues]) -> _ExactValues:
93+
acc: _ExactValues = frozenset({sympy.Integer(0)})
94+
for arg in args:
95+
acc = _ExactValueAnalysis.add(acc, arg)
96+
if acc is None:
97+
return None
98+
return acc
99+
100+
101+
def evaluate_symbolic_expr_values(
102+
expr: sympy.Basic | torch.SymInt,
103+
shape_env: ShapeEnv,
104+
) -> Optional[set[int]]:
105+
"""Return a best-effort finite set of possible integer values.
106+
107+
The helper first relies on ``bound_sympy`` for cheap singleton detection.
108+
When interval bounds are not precise enough, it falls back to a small
109+
exact-set analysis over bounded symbols using ``sympy_interp``.
110+
111+
"""
112+
root_expr = sympy.simplify(
113+
expr.node.expr if isinstance(expr, torch.SymInt) else expr
114+
)
115+
value_range = shape_env.bound_sympy(root_expr)
116+
if value_range.is_int and value_range.is_singleton():
117+
singleton = _expr_to_int(value_range.lower)
118+
return {singleton} if singleton is not None else None
119+
120+
exact_values = sympy_interp(
121+
_ExactValueAnalysis,
122+
{
123+
symbol: _symbol_values(symbol, shape_env)
124+
for symbol in root_expr.free_symbols
125+
},
126+
root_expr,
127+
missing_handler=lambda symbol: _symbol_values(symbol, shape_env),
128+
)
129+
if exact_values is None:
130+
return None
131+
132+
result: set[int] = set()
133+
for value in exact_values:
134+
integer_value = _expr_to_int(sympy.simplify(value))
135+
if integer_value is None:
136+
return None
137+
result.add(integer_value)
138+
return result
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import sympy # type: ignore[import-untyped]
7+
import torch
8+
from executorch.backends.arm._passes.symbolic_value_range import (
9+
evaluate_symbolic_expr_values,
10+
)
11+
from torch.fx.experimental.symbolic_shapes import ShapeEnv
12+
13+
14+
def _make_shape_env(
15+
*,
16+
symbol_name: str = "s89",
17+
hint: int = 2,
18+
compiler_min: int = 1,
19+
compiler_max: int = 2,
20+
) -> tuple[ShapeEnv, torch.SymInt]:
21+
shape_env = ShapeEnv()
22+
symint = shape_env.create_symintnode(sympy.Symbol(symbol_name), hint=hint)
23+
shape_env.constrain_symbol_range(
24+
symint.node.expr,
25+
compiler_min=compiler_min,
26+
compiler_max=compiler_max,
27+
)
28+
return shape_env, symint
29+
30+
31+
def test_evaluate_symbolic_expr_values_returns_singleton_for_constant_expr() -> None:
32+
shape_env, symint = _make_shape_env()
33+
34+
assert evaluate_symbolic_expr_values(
35+
symint.node.expr - symint.node.expr, shape_env
36+
) == {0}
37+
assert evaluate_symbolic_expr_values(
38+
sympy.floor(symint.node.expr / symint.node.expr), shape_env
39+
) == {1}
40+
41+
42+
def test_evaluate_symbolic_expr_values_returns_singleton_for_singleton_symint() -> None:
43+
shape_env, symint = _make_shape_env(hint=3, compiler_min=3, compiler_max=3)
44+
45+
assert evaluate_symbolic_expr_values(symint, shape_env) == {3}
46+
assert evaluate_symbolic_expr_values(symint.node.expr, shape_env) == {3}
47+
48+
49+
def test_evaluate_symbolic_expr_values_enumerates_non_singleton_symint() -> None:
50+
shape_env, symint = _make_shape_env(hint=3, compiler_min=2, compiler_max=6)
51+
52+
assert evaluate_symbolic_expr_values(symint, shape_env) == {2, 3, 4, 5, 6}
53+
assert evaluate_symbolic_expr_values(symint.node.expr, shape_env) == {2, 3, 4, 5, 6}
54+
55+
56+
def test_evaluate_symbolic_expr_values_tracks_exact_modulo_residue() -> None:
57+
shape_env, symint = _make_shape_env(hint=3, compiler_min=2, compiler_max=6)
58+
expr = sympy.Mod(16 * symint.node.expr - 7, 4)
59+
60+
value_range = shape_env.bound_sympy(expr)
61+
assert value_range.lower == 0
62+
assert value_range.upper == 3
63+
assert evaluate_symbolic_expr_values(expr, shape_env) == {1}
64+
65+
66+
def test_evaluate_symbolic_expr_values_bails_out_for_large_symbol_ranges() -> None:
67+
shape_env, symint = _make_shape_env(hint=3, compiler_min=1, compiler_max=400)
68+
69+
assert evaluate_symbolic_expr_values(symint, shape_env) is None

0 commit comments

Comments
 (0)