1- # Copyright 2025 Arm Limited and/or its affiliates.
1+ # Copyright 2025-2026 Arm Limited and/or its affiliates.
22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
55
66from typing import Set , Type
77
8+ import sympy # type: ignore
9+
810import torch
911from executorch .backends .arm ._passes import ArmPass
1012from executorch .backends .arm ._passes .arm_pass_utils import (
1113 create_node ,
14+ create_shape_node ,
1215 get_first_fake_tensor ,
1316)
1417from executorch .backends .arm .tosa .mapping import TosaSpecialDtype
15- from executorch .backends .arm .tosa .utils import get_resize_parameters
1618from executorch .exir .dialects ._ops import ops as exir_ops
1719from executorch .exir .pass_base import ExportPass , PassResult
1820
1921
2022class RewriteUpsamplePass (ArmPass ):
21- """Rewrite upsample2d nodes to TOSA.RESIZE nodes."""
23+ """Rewrite upsample2d nodes to TOSA.RESIZE nodes with appropriate
24+ parameters.
25+
26+ For constant parameters, CONST_SHAPE nodes are inserted for the scale,
27+ offset, and border values. For symbolic parameters, the parameters are
28+ directly passed to the TOSA.RESIZE node, and we rely on subsequent passes to
29+ handle them correctly once symbolic shapes are delegated by the TOSA
30+ backend.
31+
32+ """
2233
2334 targeted_ops = (
2435 exir_ops .edge .aten .upsample_nearest2d .vec ,
@@ -27,6 +38,89 @@ class RewriteUpsamplePass(ArmPass):
2738
2839 _passes_required_after : Set [Type [ExportPass ]] = set ()
2940
41+ @staticmethod
42+ def get_resize_parameters_1d (
43+ input_size : int | torch .SymInt ,
44+ output_size : int | torch .SymInt ,
45+ align_corners : bool ,
46+ ):
47+ """Compute resize coefficients for a single spatial dimension.
48+
49+ Args:
50+ input_size (int | torch.SymInt): Input size for the axis, possibly
51+ symbolic.
52+ output_size (int | torch.SymInt): Output size for the axis, possibly
53+ symbolic.
54+ align_corners (bool): Whether the resize should align the corner
55+ pixels.
56+
57+ Returns:
58+ tuple[int, int, int, int]: Numerator, denominator, offset, and border
59+ terms encoded as integers.
60+
61+ Raises:
62+ RuntimeError: If symbolic shapes are used with ``align_corners`` or if
63+ the computed ratio or border is not constant.
64+
65+ """
66+ # We don't support align_corners for symbolic shapes, because handling the edge case where size == 1 is tricky.
67+ if align_corners :
68+ if (not isinstance (input_size , int )) or (not isinstance (output_size , int )):
69+ raise RuntimeError (
70+ "We do not support align_corners=True for symbolic shapes."
71+ )
72+
73+ # SymInt seems to not actually work for symbolic expressions, so use the underlying sympy objects instead
74+ input_size = (
75+ input_size .node ._expr
76+ if isinstance (input_size , torch .SymInt )
77+ else input_size
78+ )
79+ output_size = (
80+ output_size .node ._expr
81+ if isinstance (output_size , torch .SymInt )
82+ else output_size
83+ )
84+ if align_corners and input_size > 1 and output_size > 1 :
85+ scale_n = output_size - 1
86+ else :
87+ scale_n = output_size
88+ if align_corners and input_size > 1 and output_size > 1 :
89+ scale_d = input_size - 1
90+ else :
91+ scale_d = input_size
92+ ratio = scale_n / scale_d
93+ if not sympy .sympify (ratio ).is_constant ():
94+ raise RuntimeError (
95+ "Resize requires a constant ratio: " + str (ratio ) + " is not constant!"
96+ )
97+ gcd = sympy .gcd (scale_n , scale_d )
98+ scale_n = 2 * scale_n // gcd
99+ scale_d = 2 * scale_d // gcd
100+ # These should always be whole integers, based on the above calculations
101+ scale_n = int (scale_n .evalf ())
102+ scale_d = int (scale_d .evalf ())
103+
104+ if align_corners :
105+ offset = 0
106+ else :
107+ # Half pixel centers so input and output sampling positions are offset by 1/2 pixel.
108+ offset = scale_d // 2 - scale_n // 2
109+
110+ # Calculate border to maintain the correct the output size.
111+ # Note that this should always result in a constant value, as the ratio is constant.
112+ border = scale_d * (output_size - 1 ) - scale_n * (input_size - 1 ) + offset
113+
114+ if not sympy .sympify (border ).is_constant ():
115+ raise RuntimeError (
116+ "Resize requires a constant border: "
117+ + str (border )
118+ + " is not constant!"
119+ )
120+
121+ border = int (sympy .sympify (border ).evalf ())
122+ return scale_n , scale_d , offset , border
123+
30124 def call (self , graph_module ):
31125 modified = False
32126 for node in graph_module .graph .nodes :
@@ -39,14 +133,65 @@ def call(self, graph_module):
39133 resize_mode = "bilinear"
40134 else :
41135 x , output_size , scale_factors = node .args
136+ # As per https://docs.pytorch.org/docs/stable/generated/torch.nn.Upsample.html
137+ # align_corners is not valid for nearest mode. Default to False.
42138 align_corners = False
43139 resize_mode = "nearest"
44140
141+ input_size_yx = node .args [0 ].meta ["val" ].shape [2 :]
142+ output_size_yx = node .meta ["val" ].shape [2 :]
143+
144+ scale_y_n , scale_y_d , offset_y , border_y = (
145+ RewriteUpsamplePass .get_resize_parameters_1d (
146+ input_size_yx [0 ], output_size_yx [0 ], align_corners
147+ )
148+ )
149+ scale_x_n , scale_x_d , offset_x , border_x = (
150+ RewriteUpsamplePass .get_resize_parameters_1d (
151+ input_size_yx [1 ], output_size_yx [1 ], align_corners
152+ )
153+ )
154+
155+ scales = [
156+ scale_y_n ,
157+ scale_y_d ,
158+ scale_x_n ,
159+ scale_x_d ,
160+ ]
45161 with graph_module .graph .inserting_before (node ):
162+ if all (isinstance (s , int ) for s in scales ):
163+ scale = create_shape_node (
164+ graph_module .graph ,
165+ op_target = exir_ops .backend .tosa .CONST_SHAPE .default ,
166+ args = (scales ,),
167+ kwargs = {},
168+ from_node = node ,
169+ )
170+ else :
171+ scale = scales
172+ offset = [offset_y , offset_x ]
173+ if all (isinstance (o , int ) for o in offset ):
174+ offset = create_shape_node (
175+ graph_module .graph ,
176+ op_target = exir_ops .backend .tosa .CONST_SHAPE .default ,
177+ args = (offset ,),
178+ kwargs = {},
179+ from_node = node ,
180+ )
181+ border = [border_y , border_x ]
182+ if all (isinstance (b , int ) for b in border ):
183+ border = create_shape_node (
184+ graph_module .graph ,
185+ op_target = exir_ops .backend .tosa .CONST_SHAPE .default ,
186+ args = (border ,),
187+ kwargs = {},
188+ from_node = node ,
189+ )
190+
46191 tosa_resize_node = create_node (
47192 graph_module .graph ,
48193 op_target = exir_ops .backend .tosa .RESIZE .default ,
49- args = (x , output_size , align_corners , scale_factors ),
194+ args = (x , scale , offset , border ),
50195 kwargs = {"resize_mode" : resize_mode },
51196 from_node = node ,
52197 inherit_qparams = True ,
@@ -57,18 +202,8 @@ def call(self, graph_module):
57202 if (
58203 input_dtype == torch .int8 or input_dtype == torch .int16
59204 ) and resize_mode == "bilinear" :
60- input_size = get_first_fake_tensor (x ).shape
61- input_size_xy = input_size [2 :]
62- output_size = get_first_fake_tensor (node ).shape
63- output_size_xy = output_size [2 :]
64- scale_n_yx , _ , _ , _ = get_resize_parameters (
65- input_size_xy = input_size_xy ,
66- output_size_xy = output_size_xy ,
67- resize_mode = 1 ,
68- align_corners = align_corners ,
69- )
70205 output_dtype = get_first_fake_tensor (node ).dtype
71- output_scale = float (1 / (scale_n_yx [ 0 ] * scale_n_yx [ 1 ] ))
206+ output_scale = float (1 / (scale_y_n * scale_x_n ))
72207 with graph_module .graph .inserting_after (tosa_resize_node ):
73208 rescale_node = create_node (
74209 graph_module .graph ,
0 commit comments