Skip to content

Commit de5d980

Browse files
authored
Merge branch 'main' into arm-backend-stat-cache-integration-llama
2 parents 91206c1 + 170677f commit de5d980

14 files changed

Lines changed: 571 additions & 432 deletions

backends/arm/_passes/arm_pass_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import torch.fx
1515
from executorch.backends.arm.common.debug import get_node_debug_info
1616
from executorch.backends.arm.common.type import ensure_type
17+
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
1718
from executorch.exir import ExportedProgram
1819
from executorch.exir.dialects._ops import ops as exir_ops
1920
from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -172,6 +173,30 @@ def create_node(
172173
return node
173174

174175

176+
def create_shape_node(
177+
graph: torch.fx.Graph,
178+
op_target: EdgeOpOverload,
179+
args: tuple = (),
180+
kwargs: Optional[dict] = None,
181+
from_node: Optional[torch.fx.Node] = None,
182+
):
183+
"""Adds a shape node to 'graph'.
184+
185+
graph.inserting_before/after() should be used before the call to decide
186+
where to insert the node.
187+
188+
"""
189+
node = create_node(
190+
graph=graph,
191+
op_target=op_target,
192+
args=args,
193+
kwargs=kwargs,
194+
from_node=from_node,
195+
)
196+
node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE
197+
return node
198+
199+
175200
def insert_q_dq_pair(
176201
graph: torch.fx.Graph,
177202
anchor: torch.fx.Node,

backends/arm/_passes/rewrite_upsample.py

Lines changed: 150 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,35 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

66
from typing import Set, Type
77

8+
import sympy # type: ignore
9+
810
import torch
911
from executorch.backends.arm._passes import ArmPass
1012
from executorch.backends.arm._passes.arm_pass_utils import (
1113
create_node,
14+
create_shape_node,
1215
get_first_fake_tensor,
1316
)
1417
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
15-
from executorch.backends.arm.tosa.utils import get_resize_parameters
1618
from executorch.exir.dialects._ops import ops as exir_ops
1719
from executorch.exir.pass_base import ExportPass, PassResult
1820

1921

2022
class RewriteUpsamplePass(ArmPass):
21-
"""Rewrite upsample2d nodes to TOSA.RESIZE nodes."""
23+
"""Rewrite upsample2d nodes to TOSA.RESIZE nodes with appropriate
24+
parameters.
25+
26+
For constant parameters, CONST_SHAPE nodes are inserted for the scale,
27+
offset, and border values. For symbolic parameters, the parameters are
28+
directly passed to the TOSA.RESIZE node, and we rely on subsequent passes to
29+
handle them correctly once symbolic shapes are delegated by the TOSA
30+
backend.
31+
32+
"""
2233

2334
targeted_ops = (
2435
exir_ops.edge.aten.upsample_nearest2d.vec,
@@ -27,6 +38,89 @@ class RewriteUpsamplePass(ArmPass):
2738

2839
_passes_required_after: Set[Type[ExportPass]] = set()
2940

41+
@staticmethod
42+
def get_resize_parameters_1d(
43+
input_size: int | torch.SymInt,
44+
output_size: int | torch.SymInt,
45+
align_corners: bool,
46+
):
47+
"""Compute resize coefficients for a single spatial dimension.
48+
49+
Args:
50+
input_size (int | torch.SymInt): Input size for the axis, possibly
51+
symbolic.
52+
output_size (int | torch.SymInt): Output size for the axis, possibly
53+
symbolic.
54+
align_corners (bool): Whether the resize should align the corner
55+
pixels.
56+
57+
Returns:
58+
tuple[int, int, int, int]: Numerator, denominator, offset, and border
59+
terms encoded as integers.
60+
61+
Raises:
62+
RuntimeError: If symbolic shapes are used with ``align_corners`` or if
63+
the computed ratio or border is not constant.
64+
65+
"""
66+
# We don't support align_corners for symbolic shapes, because handling the edge case where size == 1 is tricky.
67+
if align_corners:
68+
if (not isinstance(input_size, int)) or (not isinstance(output_size, int)):
69+
raise RuntimeError(
70+
"We do not support align_corners=True for symbolic shapes."
71+
)
72+
73+
# SymInt seems to not actually work for symbolic expressions, so use the underlying sympy objects instead
74+
input_size = (
75+
input_size.node._expr
76+
if isinstance(input_size, torch.SymInt)
77+
else input_size
78+
)
79+
output_size = (
80+
output_size.node._expr
81+
if isinstance(output_size, torch.SymInt)
82+
else output_size
83+
)
84+
if align_corners and input_size > 1 and output_size > 1:
85+
scale_n = output_size - 1
86+
else:
87+
scale_n = output_size
88+
if align_corners and input_size > 1 and output_size > 1:
89+
scale_d = input_size - 1
90+
else:
91+
scale_d = input_size
92+
ratio = scale_n / scale_d
93+
if not sympy.sympify(ratio).is_constant():
94+
raise RuntimeError(
95+
"Resize requires a constant ratio: " + str(ratio) + " is not constant!"
96+
)
97+
gcd = sympy.gcd(scale_n, scale_d)
98+
scale_n = 2 * scale_n // gcd
99+
scale_d = 2 * scale_d // gcd
100+
# These should always be whole integers, based on the above calculations
101+
scale_n = int(scale_n.evalf())
102+
scale_d = int(scale_d.evalf())
103+
104+
if align_corners:
105+
offset = 0
106+
else:
107+
# Half pixel centers so input and output sampling positions are offset by 1/2 pixel.
108+
offset = scale_d // 2 - scale_n // 2
109+
110+
# Calculate border to maintain the correct output size.
111+
# Note that this should always result in a constant value, as the ratio is constant.
112+
border = scale_d * (output_size - 1) - scale_n * (input_size - 1) + offset
113+
114+
if not sympy.sympify(border).is_constant():
115+
raise RuntimeError(
116+
"Resize requires a constant border: "
117+
+ str(border)
118+
+ " is not constant!"
119+
)
120+
121+
border = int(sympy.sympify(border).evalf())
122+
return scale_n, scale_d, offset, border
123+
30124
def call(self, graph_module):
31125
modified = False
32126
for node in graph_module.graph.nodes:
@@ -39,14 +133,65 @@ def call(self, graph_module):
39133
resize_mode = "bilinear"
40134
else:
41135
x, output_size, scale_factors = node.args
136+
# As per https://docs.pytorch.org/docs/stable/generated/torch.nn.Upsample.html
137+
# align_corners is not valid for nearest mode. Default to False.
42138
align_corners = False
43139
resize_mode = "nearest"
44140

141+
input_size_yx = node.args[0].meta["val"].shape[2:]
142+
output_size_yx = node.meta["val"].shape[2:]
143+
144+
scale_y_n, scale_y_d, offset_y, border_y = (
145+
RewriteUpsamplePass.get_resize_parameters_1d(
146+
input_size_yx[0], output_size_yx[0], align_corners
147+
)
148+
)
149+
scale_x_n, scale_x_d, offset_x, border_x = (
150+
RewriteUpsamplePass.get_resize_parameters_1d(
151+
input_size_yx[1], output_size_yx[1], align_corners
152+
)
153+
)
154+
155+
scales = [
156+
scale_y_n,
157+
scale_y_d,
158+
scale_x_n,
159+
scale_x_d,
160+
]
45161
with graph_module.graph.inserting_before(node):
162+
if all(isinstance(s, int) for s in scales):
163+
scale = create_shape_node(
164+
graph_module.graph,
165+
op_target=exir_ops.backend.tosa.CONST_SHAPE.default,
166+
args=(scales,),
167+
kwargs={},
168+
from_node=node,
169+
)
170+
else:
171+
scale = scales
172+
offset = [offset_y, offset_x]
173+
if all(isinstance(o, int) for o in offset):
174+
offset = create_shape_node(
175+
graph_module.graph,
176+
op_target=exir_ops.backend.tosa.CONST_SHAPE.default,
177+
args=(offset,),
178+
kwargs={},
179+
from_node=node,
180+
)
181+
border = [border_y, border_x]
182+
if all(isinstance(b, int) for b in border):
183+
border = create_shape_node(
184+
graph_module.graph,
185+
op_target=exir_ops.backend.tosa.CONST_SHAPE.default,
186+
args=(border,),
187+
kwargs={},
188+
from_node=node,
189+
)
190+
46191
tosa_resize_node = create_node(
47192
graph_module.graph,
48193
op_target=exir_ops.backend.tosa.RESIZE.default,
49-
args=(x, output_size, align_corners, scale_factors),
194+
args=(x, scale, offset, border),
50195
kwargs={"resize_mode": resize_mode},
51196
from_node=node,
52197
inherit_qparams=True,
@@ -57,18 +202,8 @@ def call(self, graph_module):
57202
if (
58203
input_dtype == torch.int8 or input_dtype == torch.int16
59204
) and resize_mode == "bilinear":
60-
input_size = get_first_fake_tensor(x).shape
61-
input_size_xy = input_size[2:]
62-
output_size = get_first_fake_tensor(node).shape
63-
output_size_xy = output_size[2:]
64-
scale_n_yx, _, _, _ = get_resize_parameters(
65-
input_size_xy=input_size_xy,
66-
output_size_xy=output_size_xy,
67-
resize_mode=1,
68-
align_corners=align_corners,
69-
)
70205
output_dtype = get_first_fake_tensor(node).dtype
71-
output_scale = float(1 / (scale_n_yx[0] * scale_n_yx[1]))
206+
output_scale = float(1 / (scale_y_n * scale_x_n))
72207
with graph_module.graph.inserting_after(tosa_resize_node):
73208
rescale_node = create_node(
74209
graph_module.graph,

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,9 @@ def _propagate_dim_order_to_shape_args(self, node: torch.fx.Node) -> None:
436436
raise RuntimeError(
437437
f"Conflicting dim orders {arg.meta['tosa_dim_order']} and {dim_order} for shape node {arg.name}"
438438
)
439+
if node.target == exir_ops.backend.tosa.RESIZE.default:
440+
# RESIZE's shape input is expected to be in HW order, so we need to override the dim order to be the identity for it regardless of the user node's dim order.
441+
dim_order = tuple(range(len(arg.meta["val"])))
439442
arg.meta["tosa_dim_order"] = dim_order
440443
self._propagate_dim_order_to_shape_args(arg)
441444

backends/arm/operators/op_tosa_resize.py

Lines changed: 6 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,8 @@
1515
)
1616
from executorch.backends.arm.operators.operator_validation_utils import (
1717
validate_num_inputs,
18-
validate_same_dtype,
19-
validate_valid_dtype,
2018
)
2119
from executorch.backends.arm.tosa.mapping import TosaArg
22-
from executorch.backends.arm.tosa.utils import get_resize_parameters
2320

2421

2522
@register_node_visitor
@@ -36,81 +33,12 @@ def define_node(
3633
inputs: List[TosaArg],
3734
output: TosaArg,
3835
) -> None:
39-
validate_num_inputs(self.target, inputs, [3, 4])
40-
supported_input_dtypes = [
41-
ts.DType.INT8,
42-
ts.DType.FP16,
43-
ts.DType.FP32,
44-
ts.DType.BF16,
45-
]
46-
if self.tosa_spec.support_extension("int16"):
47-
supported_input_dtypes.append(ts.DType.INT16)
48-
if self.tosa_spec.support_extension("bf16"):
49-
supported_input_dtypes.append(ts.DType.BF16)
50-
validate_valid_dtype(
51-
self.target,
52-
[inputs[0]],
53-
supported_input_dtypes,
54-
self.tosa_spec,
55-
)
56-
supported_output_dtypes = [ts.DType.FP16, ts.DType.FP32, ts.DType.BF16]
36+
x, scales, offset, border = inputs
37+
validate_num_inputs(self.target, inputs, [4])
5738
if node.kwargs.get("resize_mode") == "bilinear":
5839
resize_mode = ts.ResizeMode.BILINEAR
59-
align_corners = bool(node.args[2])
60-
supported_output_dtypes.append(ts.DType.INT32)
61-
if self.tosa_spec.support_extension("int16"):
62-
supported_output_dtypes.append(ts.DType.INT48)
6340
else:
6441
resize_mode = ts.ResizeMode.NEAREST
65-
align_corners = False
66-
validate_same_dtype(self.target, [inputs[0], output], ts)
67-
supported_output_dtypes.append(ts.DType.INT8)
68-
if self.tosa_spec.support_extension("int16"):
69-
supported_output_dtypes.append(ts.DType.INT16)
70-
validate_valid_dtype(
71-
self.target, [output], supported_output_dtypes, self.tosa_spec
72-
)
73-
# tosa_shape output is NHWC, take HW
74-
input_size_yx = tuple([inputs[0].shape[dim] for dim in inputs[0].dim_order])[
75-
1:3
76-
]
77-
output_size_yx = tuple([output.shape[dim] for dim in output.dim_order])[1:3]
78-
79-
# Align corners shouldn't make a difference for nearest upsampling. We set to False so
80-
# half pixel centers are used for resize parameter logic.
81-
scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters(
82-
input_size_yx, output_size_yx, resize_mode, align_corners=align_corners
83-
)
84-
85-
def in_int16_range(x):
86-
return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)
87-
88-
if not in_int16_range(scale_n_yx):
89-
raise ValueError("scale_n_yx is out of the int16 range")
90-
if not in_int16_range(scale_d_yx):
91-
raise ValueError("scale_d_yx is out of the int16 range")
92-
if not in_int16_range(border_yx):
93-
raise ValueError("border_yx is out of the int16 range")
94-
95-
scale_n_vals = [int(v) for v in scale_n_yx.tolist()]
96-
scale_d_vals = [int(v) for v in scale_d_yx.tolist()]
97-
scales = [
98-
scale_n_vals[0],
99-
scale_d_vals[0],
100-
scale_n_vals[1],
101-
scale_d_vals[1],
102-
]
103-
scales_tensor = tosa_graph.addConst(
104-
[len(scales)], ts.DType.SHAPE, scales, output.name + "_scales"
105-
)
106-
offset = [int(v) for v in offset_yx.tolist()]
107-
offset_tensor = tosa_graph.addConst(
108-
[len(offset)], ts.DType.SHAPE, offset, output.name + "_offset"
109-
)
110-
border = [int(v) for v in border_yx.tolist()]
111-
border_tensor = tosa_graph.addConst(
112-
[len(border)], ts.DType.SHAPE, border, output.name + "_border"
113-
)
11442
attr = ts.TosaSerializerAttribute()
11543
attr.ResizeAttribute(resize_mode)
11644

@@ -119,10 +47,10 @@ def in_int16_range(x):
11947
tosa_graph,
12048
ts.Op.RESIZE,
12149
[
122-
inputs[0].name,
123-
scales_tensor.name,
124-
offset_tensor.name,
125-
border_tensor.name,
50+
x.name,
51+
scales.name,
52+
offset.name,
53+
border.name,
12654
],
12755
[output.name],
12856
attr,

0 commit comments

Comments
 (0)