Skip to content

Commit 9f951dd

Browse files
Arm backend: Sympify tosa.RESIZE lowering (#19222)
Modify rewrite_upsample to support dynamic shapes. cc @digantdesai @freddan80 @per @zingo @mansnils @Sebastian-Larsson @robell Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent a6ee309 commit 9f951dd

4 files changed

Lines changed: 104 additions & 17 deletions

File tree

backends/arm/_passes/rewrite_upsample.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,16 @@ def get_resize_parameters_1d(
7272
"We do not support align_corners=True for symbolic shapes."
7373
)
7474

75-
# SymInt seems to not actually work for symbolic expressions, so use the underlying sympy objects instead
75+
# Use the exported SymPy expressions for symbolic shapes.
7676
input_size = (
77-
input_size.node._expr
77+
sympy.sympify(input_size.node.expr)
7878
if isinstance(input_size, torch.SymInt)
79-
else input_size
79+
else sympy.sympify(input_size)
8080
)
8181
output_size = (
82-
output_size.node._expr
82+
sympy.sympify(output_size.node.expr)
8383
if isinstance(output_size, torch.SymInt)
84-
else output_size
84+
else sympy.sympify(output_size)
8585
)
8686
if align_corners and input_size > 1 and output_size > 1:
8787
scale_n = output_size - 1
@@ -91,17 +91,15 @@ def get_resize_parameters_1d(
9191
scale_d = input_size - 1
9292
else:
9393
scale_d = input_size
94-
ratio = scale_n / scale_d
95-
if not sympy.sympify(ratio).is_constant():
94+
ratio = sympy.nsimplify(sympy.simplify(scale_n / scale_d))
95+
if ratio.free_symbols:
9696
raise RuntimeError(
9797
"Resize requires a constant ratio: " + str(ratio) + " is not constant!"
9898
)
99-
gcd = sympy.gcd(scale_n, scale_d)
100-
scale_n = 2 * scale_n // gcd
101-
scale_d = 2 * scale_d // gcd
102-
# These should always be whole integers, based on the above calculations
103-
scale_n = int(scale_n.evalf())
104-
scale_d = int(scale_d.evalf())
99+
ratio_num, ratio_den = ratio.as_numer_denom()
100+
# TOSA encodes resize scales as doubled rationals.
101+
scale_n = int((2 * ratio_num).evalf())
102+
scale_d = int((2 * ratio_den).evalf())
105103

106104
if align_corners:
107105
offset = 0
@@ -111,9 +109,11 @@ def get_resize_parameters_1d(
111109

112110
# Calculate border to maintain the correct the output size.
113111
# Note that this should always result in a constant value, as the ratio is constant.
114-
border = scale_d * (output_size - 1) - scale_n * (input_size - 1) + offset
112+
border = sympy.simplify(
113+
scale_d * (output_size - 1) - scale_n * (input_size - 1) + offset
114+
)
115115

116-
if not sympy.sympify(border).is_constant():
116+
if border.free_symbols:
117117
raise RuntimeError(
118118
"Resize requires a constant border: "
119119
+ str(border)

backends/arm/test/misc/test_tosa_dialect_resize.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import executorch.backends.arm.tosa.dialect # noqa: F401
7-
87
import pytest
8+
import sympy # type: ignore
99
import torch
1010

1111
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
@@ -15,6 +15,21 @@
1515
)
1616
from executorch.exir.dialects._ops import ops as exir_ops
1717
from torch._subclasses.fake_tensor import FakeTensorMode
18+
from torch.fx.experimental.symbolic_shapes import ShapeEnv
19+
20+
21+
def _make_symint(
22+
shape_env: ShapeEnv, symbol: str, hint: int, min: int = 1, max: int = 64
23+
) -> torch.SymInt:
24+
symint = shape_env.create_symintnode(sympy.Symbol(symbol), hint=hint)
25+
shape_env.constrain_symbol_range(
26+
symint.node.expr, compiler_min=min, compiler_max=max
27+
)
28+
return symint
29+
30+
31+
def _expr(sym: torch.SymInt) -> sympy.Expr:
32+
return sympy.sympify(getattr(sym.node, "expr", sym.node._expr))
1833

1934

2035
def test_bilinear_resize_rejects_exact_one_sixteenth_downscale():
@@ -34,3 +49,30 @@ def test_bilinear_resize_rejects_exact_one_sixteenth_downscale():
3449
[-15, -15],
3550
resize_mode="bilinear",
3651
)
52+
53+
54+
def test_resize_accepts_symbolic_scale_and_border_values():
55+
shape_env = ShapeEnv()
56+
scale_y_n = _make_symint(shape_env, "scale_y_n", hint=2, min=1, max=8)
57+
border_y = _make_symint(shape_env, "border_y", hint=1, min=0, max=8)
58+
59+
with TosaLoweringContext(
60+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
61+
), FakeTensorMode(shape_env=shape_env) as mode:
62+
x = mode.from_tensor(torch.empty(size=(1, 3, 4, 2), dtype=torch.float32))
63+
output = exir_ops.backend.tosa.RESIZE.default(
64+
x,
65+
[scale_y_n, 1, 4, 2],
66+
[0, 0],
67+
[border_y, 0],
68+
resize_mode="nearest",
69+
)
70+
71+
assert output.dtype == torch.float32
72+
assert (output.shape[0], output.shape[-1]) == (1, 2)
73+
assert isinstance(output.shape[1], torch.SymInt)
74+
assert output.shape[2] == 7
75+
# The output height is computed as: (input_height - 1) * scale_y_n + border_y + 1.
76+
# As the hegiht is a symbolic expression, we check that the expression is correct by
77+
# comparing it to the expected expression.
78+
assert str(_expr(output.shape[1])) == "(((border_y + 2*scale_y_n)//1)) + 1"
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import pytest
7+
import sympy # type: ignore
8+
import torch
9+
from executorch.backends.arm._passes.rewrite_upsample import RewriteUpsamplePass
10+
from torch.fx.experimental.symbolic_shapes import ShapeEnv
11+
12+
13+
def _make_symint(
14+
shape_env: ShapeEnv, symbol: str, hint: int, min: int = 1, max: int = 64
15+
) -> torch.SymInt:
16+
symint = shape_env.create_symintnode(sympy.Symbol(symbol), hint=hint)
17+
shape_env.constrain_symbol_range(
18+
symint.node.expr, compiler_min=min, compiler_max=max
19+
)
20+
return symint
21+
22+
23+
def test_get_resize_parameters_1d_supports_symbolic_shapes_with_constant_ratio():
24+
shape_env = ShapeEnv()
25+
input_size = _make_symint(shape_env, "input_size", hint=4)
26+
output_size = input_size * 2
27+
28+
scale_n, scale_d, offset, border = RewriteUpsamplePass.get_resize_parameters_1d(
29+
input_size, output_size, align_corners=False
30+
)
31+
32+
assert (scale_n, scale_d, offset, border) == (4, 2, -1, 1)
33+
34+
35+
def test_get_resize_parameters_1d_rejects_non_constant_symbolic_ratio():
36+
shape_env = ShapeEnv()
37+
input_size = _make_symint(shape_env, "input_size", hint=4)
38+
output_size = input_size + 1
39+
40+
with pytest.raises(RuntimeError, match="constant ratio"):
41+
RewriteUpsamplePass.get_resize_parameters_1d(
42+
input_size, output_size, align_corners=False
43+
)

backends/arm/tosa/dialect/ops/resize.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ def _get_output_dtype(
5252

5353
def _validate_resize_parameters(scale, border, resize_mode):
5454
def in_int16_range(values):
55-
return all((x >= -(2**15)) and (x <= 2**15 - 1) for x in values)
55+
return all(
56+
(x >= -(2**15)) and (x <= 2**15 - 1) for x in values if isinstance(x, int)
57+
)
5658

5759
if not in_int16_range(scale):
5860
raise TosaValueError("scale is out of the int16 range", op="RESIZE")

0 commit comments

Comments
 (0)