Skip to content

Commit 161376d

Browse files
bdemirbBaris Demir
andauthored
Arm backend: Reject symbolic shapes without shape extension (#19689)
This patch prevents unsupported dynamic-shape PAD cases from being delegated to TOSA specs that do not support the shape extension. Dynamic PAD lowering needs symbolic shape handling. When symbolic tensor shapes are serialized for a spec without the TOSA shape extension, symbolic dimensions can be emitted as -1. For PAD this can make Vela validate output shapes with checks such as -1 == -1 + 1, which reports a PAD output shape mismatch instead of rejecting the unsupported dynamic-shape case earlier. Add a shared TOSA partitioner support check that rejects nodes with symbolic tensor input or output shapes when the active TOSA spec does not enable the shape extension. Known supported symbolic-shape cases, such as upsample_nearest2d.vec and quantize/dequantize boundaries, remain delegable. This keeps the restriction tied to the TOSA spec capability while preserving existing dynamic upsample coverage. The regression coverage verifies that U55 and U85 dynamic PAD cases are not delegated and that static 5D PAD still delegates through Ethos-U lowering. Signed-off-by: Baris Demir <baris.demir@arm.com> Change-Id: If2912c1badcb6528c4869e5bc716aa8da136ff2a cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani Co-authored-by: Baris Demir <baris.demir@arm.com>
1 parent ee61747 commit 161376d

2 files changed

Lines changed: 198 additions & 0 deletions

File tree

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,8 @@ def tosa_support_factory(
310310
negative_checks.append(EthosU55NotSupported(reporter))
311311
negative_checks.append(EthosU55DtypeSupport(reporter))
312312
negative_checks.append(EthosU55CastCheck(reporter))
313+
if not tosa_spec.support_extension("shape"):
314+
negative_checks.append(SymbolicShapeSupportCheck(reporter))
313315

314316
return chain(
315317
reporter.wrap_check(
@@ -320,6 +322,72 @@ def tosa_support_factory(
320322
)
321323

322324

325+
class SymbolicShapeSupportCheck(OperatorSupportBase):
326+
"""Reject symbolic tensor shapes for specs without the shape extension."""
327+
328+
def __init__(self, reporter: WhyNoPartitionReporter):
329+
"""Initialize the check with a reporter.
330+
331+
Args:
332+
reporter (WhyNoPartitionReporter): Reporter for rejection reasons.
333+
334+
"""
335+
self.reporter = reporter
336+
337+
@staticmethod
338+
def _has_symbolic_shape(node: fx.Node) -> bool:
339+
val = node.meta.get("val")
340+
vals = val if isinstance(val, (list, tuple)) else (val,)
341+
for node_val in vals:
342+
if isinstance(node_val, torch.SymInt):
343+
return True
344+
345+
shape = getattr(node_val, "shape", None)
346+
if shape is not None and any(
347+
isinstance(dim, torch.SymInt) for dim in shape
348+
):
349+
return True
350+
351+
return False
352+
353+
def is_node_supported(
354+
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
355+
) -> bool:
356+
"""Return False for nodes with symbolic tensor input or output shapes.
357+
358+
Dynamic shapes require the TOSA shape extension. Reject nodes with
359+
symbolic tensor dimensions before partitioning when the active spec
360+
does not enable that extension.
361+
362+
Args:
363+
submodules (typing.Mapping[str, torch.nn.Module]): Exported modules.
364+
node (fx.Node): FX node to check.
365+
366+
Returns:
367+
bool: False if rejected by constraints; otherwise, True.
368+
369+
"""
370+
if node.op in ("placeholder", "output"):
371+
return True
372+
if node.op == "call_function" and node.target in (*Q_OPS, *DQ_OPS):
373+
return True
374+
375+
if self._has_symbolic_shape(node) or any(
376+
self._has_symbolic_shape(input_node) for input_node in node.all_input_nodes
377+
):
378+
if node.target == exir_ops.edge.aten.upsample_nearest2d.vec:
379+
return True
380+
381+
self.reporter.report_reject(
382+
node,
383+
"Node has symbolic shape but the TOSA spec does not support "
384+
"the shape extension.",
385+
)
386+
return False
387+
388+
return True
389+
390+
323391
class TOSAProINTSupportList(OperatorSupportBase):
324392
"""Provide the INT profile support list for TOSA.
325393

backends/arm/test/ops/test_constant_pad_nd.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,23 @@
99

1010
import torch
1111
import torch.nn.functional as F
12+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
13+
SymbolicShapeSupportCheck,
14+
)
1215
from executorch.backends.arm.quantizer.arm_quantizer import (
1316
get_symmetric_a16w8_quantization_config,
1417
)
1518
from executorch.backends.arm.test import common
19+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
1620
from executorch.backends.arm.test.tester.test_pipeline import (
1721
TosaPipelineFP,
1822
TosaPipelineINT,
1923
VgfPipeline,
2024
)
25+
from executorch.exir import to_edge
26+
from executorch.exir.backend.utils import WhyNoPartitionReporter
27+
from executorch.exir.dialects._ops import ops as exir_ops
28+
from torch.export import Dim, export
2129

2230
aten_op = "torch.ops.aten.pad.default"
2331
exir_op = "executorch_exir_dialects_edge__ops_aten_pad_default"
@@ -143,6 +151,128 @@ def forward(self, x: torch.Tensor):
143151
return F.pad(x, pad=self.pad, mode=self.mode, value=self.value)
144152

145153

154+
class RawConstantPadND(torch.nn.Module):
155+
def __init__(self, pad: Tuple, value: float = 0.0):
156+
super().__init__()
157+
self.pad = pad
158+
self.value = value
159+
160+
def forward(self, x: torch.Tensor):
161+
return F.pad(x, pad=self.pad, mode="constant", value=self.value)
162+
163+
164+
def _constant_pad_nd_node(
165+
module: torch.nn.Module,
166+
example_inputs: tuple[torch.Tensor, ...],
167+
dynamic_shapes=None,
168+
) -> torch.fx.Node:
169+
edge = to_edge(
170+
export(module, example_inputs, dynamic_shapes=dynamic_shapes, strict=True)
171+
)
172+
return next(
173+
n
174+
for n in edge.exported_program().graph.nodes
175+
if n.target == exir_ops.edge.aten.constant_pad_nd.default
176+
)
177+
178+
179+
def _is_tosa_without_shape_extension_supported(node: torch.fx.Node) -> bool:
180+
return SymbolicShapeSupportCheck(WhyNoPartitionReporter()).is_node_supported(
181+
{}, node
182+
)
183+
184+
185+
def test_constant_pad_nd_no_target_u55_symbolic_padded_axis_not_delegated():
186+
input_tensor = torch.rand(1, 3, 8, 8, 5)
187+
width = Dim("width", min=4, max=10)
188+
node = _constant_pad_nd_node(
189+
RawConstantPadND((0, 1, 0, 0, 0, 0, 0, 0)),
190+
(input_tensor,),
191+
dynamic_shapes={"x": {4: width}},
192+
)
193+
194+
assert not _is_tosa_without_shape_extension_supported(node)
195+
196+
197+
def test_constant_pad_nd_no_target_u55_symbolic_unpadded_axis_not_delegated():
198+
input_tensor = torch.rand(1, 3, 8, 8, 5)
199+
width = Dim("width", min=4, max=10)
200+
node = _constant_pad_nd_node(
201+
RawConstantPadND((0, 0, 1, 0, 0, 0, 0, 0)),
202+
(input_tensor,),
203+
dynamic_shapes={"x": {4: width}},
204+
)
205+
206+
assert not _is_tosa_without_shape_extension_supported(node)
207+
208+
209+
def test_constant_pad_nd_no_target_u55_static_padded_axis_supported():
210+
input_tensor = torch.rand(1, 3, 8, 8, 5)
211+
node = _constant_pad_nd_node(
212+
RawConstantPadND((0, 1, 0, 0, 0, 0, 0, 0)),
213+
(input_tensor,),
214+
)
215+
216+
assert _is_tosa_without_shape_extension_supported(node)
217+
218+
219+
def test_constant_pad_nd_u55_INT_dynamic_padded_axis_not_delegated():
220+
input_tensor = torch.rand(1, 3, 8, 8, 5)
221+
width = Dim("width", min=4, max=10)
222+
tester = ArmTester(
223+
RawConstantPadND((0, 1, 0, 0, 0, 0, 0, 0)),
224+
(input_tensor,),
225+
common.get_u55_compile_spec(),
226+
dynamic_shapes=({4: width},),
227+
)
228+
229+
tester.quantize().export().to_edge().partition()
230+
targets = {
231+
node.target
232+
for node in tester.stages[tester.cur].artifact.exported_program().graph.nodes
233+
}
234+
235+
assert exir_ops.edge.aten.constant_pad_nd.default in targets
236+
assert torch.ops.higher_order.executorch_call_delegate not in targets
237+
238+
239+
def test_constant_pad_nd_u85_INT_dynamic_padded_axis_not_delegated():
240+
input_tensor = torch.rand(1, 3, 8, 8, 5)
241+
width = Dim("width", min=4, max=10)
242+
tester = ArmTester(
243+
RawConstantPadND((0, 1, 0, 0, 0, 0, 0, 0)),
244+
(input_tensor,),
245+
common.get_u85_compile_spec(),
246+
dynamic_shapes=({4: width},),
247+
)
248+
249+
tester.quantize().export().to_edge().partition()
250+
targets = {
251+
node.target
252+
for node in tester.stages[tester.cur].artifact.exported_program().graph.nodes
253+
}
254+
255+
assert exir_ops.edge.aten.constant_pad_nd.default in targets
256+
assert torch.ops.higher_order.executorch_call_delegate not in targets
257+
258+
259+
def test_constant_pad_nd_u55_INT_static_5d_padded_axis_delegated():
260+
input_tensor = torch.rand(1, 3, 8, 8, 5)
261+
tester = ArmTester(
262+
RawConstantPadND((0, 1, 0, 0, 0, 0, 0, 0)),
263+
(input_tensor,),
264+
common.get_u55_compile_spec(),
265+
)
266+
267+
tester.quantize().export().to_edge_transform_and_lower()
268+
targets = {
269+
node.target
270+
for node in tester.stages[tester.cur].artifact.exported_program().graph.nodes
271+
}
272+
273+
assert torch.ops.higher_order.executorch_call_delegate in targets
274+
275+
146276
@common.parametrize(
147277
"test_data",
148278
test_data_suite | test_data_suite_bf16 | test_data_suite_fp16,

0 commit comments

Comments
 (0)