Skip to content

Commit 37effad

Browse files
Arm backend: Support dynamic fulls (pytorch#19912)
Support fulls with dynamic shapes by creating a full with size (1,) followed by a dynamic repeat/tile. cc @digantdesai @freddan80 @per @zingo @mansnils @Sebastian-Larsson @robell @rascani Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent 6043775 commit 37effad

4 files changed

Lines changed: 235 additions & 0 deletions

File tree

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from .decompose_cumsum_pass import DecomposeCumsumPass # noqa
4343
from .decompose_div_pass import DecomposeDivPass # noqa
4444
from .decompose_div_tensor_mode import DecomposeDivTensorModePass # noqa
45+
from .decompose_dynamic_full_pass import DecomposeDynamicFullPass # noqa
4546
from .decompose_einsum_pass import DecomposeEinsumPass # noqa
4647
from .decompose_elu_pass import ConvertEluFamilyToEluPass, DecomposeEluPass # noqa
4748
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
DecomposeCumsumPass,
5050
DecomposeDivPass,
5151
DecomposeDivTensorModePass,
52+
DecomposeDynamicFullPass,
5253
DecomposeEinsumPass,
5354
DecomposeEluPass,
5455
DecomposeEmbeddingPass,
@@ -496,6 +497,7 @@ def _tosa_pipeline(
496497
ConvertMinMaxPass(),
497498
DecomposeAnyPass(),
498499
DecorateFp32toInt32CastingPass(),
500+
DecomposeDynamicFullPass(),
499501
ConvertExpandCopyToRepeatPass(),
500502
UnsqueezeBeforeRepeatPass(),
501503
DecomposeCumsumPass(exported_program),
@@ -582,6 +584,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
582584
DecomposeIndexCopyPass(tfa_pass=True),
583585
DecomposeSelectScatterPass(tfa_pass=True),
584586
DecomposeSliceScatterPass(tfa_pass=True),
587+
DecomposeDynamicFullPass(tfa_pass=True),
585588
ConvertInt64ConstOpsToInt32Pass(tfa_pass=True),
586589
ConvertInt64OutputOpsToInt32Pass(tfa_pass=True),
587590
InsertInt32CastsAfterInt64PlaceholdersPass(tfa_pass=True),
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, Set, Type
7+
8+
import torch
9+
from executorch.backends.arm._passes.arm_pass import ArmPass
10+
from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import (
11+
UnsqueezeBeforeRepeatPass,
12+
)
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
from executorch.exir.pass_base import ExportPass
15+
16+
17+
class DecomposeDynamicFullPass(ArmPass):
18+
"""Rewrite dynamic-shape `full` into scalar `full` plus `repeat`."""
19+
20+
_passes_required_after: Set[Type[ExportPass]] = {UnsqueezeBeforeRepeatPass}
21+
22+
full_targets = {
23+
torch.ops.aten.full.default,
24+
exir_ops.edge.aten.full.default,
25+
}
26+
repeat = exir_ops.edge.aten.repeat.default
27+
28+
@staticmethod
29+
def _has_symbolic_extent(size: Any) -> bool:
30+
return isinstance(size, (list, tuple)) and any(
31+
not isinstance(dim, int) for dim in size
32+
)
33+
34+
def call_operator(self, op, args, kwargs, meta, updated=False):
35+
if op not in self.full_targets:
36+
return super().call_operator(op, args, kwargs, meta, updated)
37+
38+
size, fill_value = args[:2]
39+
if not self._has_symbolic_extent(size):
40+
return super().call_operator(op, args, kwargs, meta, updated)
41+
42+
scalar_full = super().call_operator(
43+
op=op,
44+
args=((1,), fill_value),
45+
kwargs=kwargs,
46+
meta=meta,
47+
updated=True,
48+
)
49+
return super().call_operator(
50+
op=self.repeat,
51+
args=(scalar_full, size),
52+
kwargs={},
53+
meta=meta,
54+
updated=True,
55+
)
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.backends.arm._passes import ArmPassManager, DecomposeDynamicFullPass
8+
from executorch.backends.arm.test import common
9+
from executorch.exir import EdgeCompileConfig, to_edge
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
12+
13+
class _DynamicFull(torch.nn.Module):
14+
def forward(self, x: torch.Tensor) -> torch.Tensor:
15+
return torch.full(x.shape, 3.5)
16+
17+
18+
class _DynamicIntegerFull(torch.nn.Module):
19+
def forward(self, x: torch.Tensor) -> torch.Tensor:
20+
return torch.full(x.shape, 3)
21+
22+
23+
class _DynamicFullLike(torch.nn.Module):
24+
def forward(self, x: torch.Tensor) -> torch.Tensor:
25+
return torch.full_like(x, 3.5)
26+
27+
28+
class _StaticFull(torch.nn.Module):
29+
def forward(self) -> torch.Tensor:
30+
return torch.full((2, 3), 3.5)
31+
32+
33+
def _export_dynamic_full() -> torch.export.ExportedProgram:
34+
return torch.export.export(
35+
_DynamicFull().eval(),
36+
(torch.randn(2, 3, 4),),
37+
dynamic_shapes={
38+
"x": {
39+
0: torch.export.Dim("batch", min=1, max=8),
40+
2: torch.export.Dim("height", min=1, max=16),
41+
}
42+
},
43+
)
44+
45+
46+
def test_decompose_dynamic_full_to_scalar_full_and_repeat() -> None:
47+
exported_program = _export_dynamic_full()
48+
49+
result = DecomposeDynamicFullPass()(exported_program.graph_module)
50+
assert result is not None
51+
52+
full_nodes = [
53+
node
54+
for node in result.graph_module.graph.nodes
55+
if node.op == "call_function" and node.target == torch.ops.aten.full.default
56+
]
57+
repeat_nodes = [
58+
node
59+
for node in result.graph_module.graph.nodes
60+
if node.op == "call_function"
61+
and node.target == exir_ops.edge.aten.repeat.default
62+
]
63+
64+
assert len(full_nodes) == 1
65+
assert len(repeat_nodes) == 1
66+
assert full_nodes[0].args[0] == (1,)
67+
68+
repeat_sizes = repeat_nodes[0].args[1]
69+
assert isinstance(repeat_sizes, list)
70+
assert len(repeat_sizes) == 3
71+
assert repeat_sizes[1] == 3
72+
assert getattr(repeat_sizes[0], "target", None) == torch.ops.aten.sym_size.int
73+
assert getattr(repeat_sizes[2], "target", None) == torch.ops.aten.sym_size.int
74+
75+
result.graph_module.graph.lint()
76+
77+
78+
def test_annotation_pipeline_converts_dynamic_integer_full_to_int32() -> None:
79+
exported_program = torch.export.export(
80+
_DynamicIntegerFull().eval(),
81+
(torch.randn(2, 3, 4),),
82+
dynamic_shapes={
83+
"x": {
84+
0: torch.export.Dim("batch", min=1, max=8),
85+
2: torch.export.Dim("height", min=1, max=16),
86+
}
87+
},
88+
)
89+
90+
graph_module = ArmPassManager(
91+
common.get_tosa_compile_spec("TOSA-1.0+INT")
92+
).transform_for_annotation_pipeline(exported_program.graph_module)
93+
94+
full_nodes = [
95+
node
96+
for node in graph_module.graph.nodes
97+
if node.op == "call_function" and node.target == torch.ops.aten.full.default
98+
]
99+
repeat_nodes = [
100+
node
101+
for node in graph_module.graph.nodes
102+
if node.op == "call_function"
103+
and node.target == exir_ops.edge.aten.repeat.default
104+
]
105+
106+
assert len(full_nodes) == 1
107+
assert len(repeat_nodes) == 1
108+
assert full_nodes[0].args[0] == (1,)
109+
assert full_nodes[0].kwargs["dtype"] == torch.int32
110+
assert full_nodes[0].meta["val"].dtype == torch.int32
111+
112+
113+
def test_backend_pipeline_decomposes_dynamic_full_like() -> None:
114+
exported_program = torch.export.export(
115+
_DynamicFullLike().eval(),
116+
(torch.randn(2, 3, 4),),
117+
dynamic_shapes={
118+
"x": {
119+
0: torch.export.Dim("batch", min=1, max=8),
120+
2: torch.export.Dim("height", min=1, max=16),
121+
}
122+
},
123+
)
124+
edge_program = to_edge(exported_program, compile_config=EdgeCompileConfig())
125+
graph_module = ArmPassManager(
126+
common.get_tosa_compile_spec("TOSA-1.0+FP")
127+
).transform_to_backend_pipeline(
128+
edge_program.exported_program(),
129+
edge_program.exported_program().graph_module,
130+
)
131+
132+
full_nodes = [
133+
node
134+
for node in graph_module.graph.nodes
135+
if node.op == "call_function" and node.target == exir_ops.edge.aten.full.default
136+
]
137+
full_like_nodes = [
138+
node
139+
for node in graph_module.graph.nodes
140+
if node.op == "call_function"
141+
and node.target == exir_ops.edge.aten.full_like.default
142+
]
143+
repeat_nodes = [
144+
node
145+
for node in graph_module.graph.nodes
146+
if node.op == "call_function"
147+
and node.target == exir_ops.edge.aten.repeat.default
148+
]
149+
150+
assert not full_nodes
151+
assert not full_like_nodes
152+
assert len(repeat_nodes) == 1
153+
assert repeat_nodes[0].args[1][1] == 3
154+
155+
156+
def test_decompose_dynamic_full_leaves_static_full_unchanged() -> None:
157+
exported_program = torch.export.export(_StaticFull().eval(), ())
158+
159+
result = DecomposeDynamicFullPass()(exported_program.graph_module)
160+
assert result is not None
161+
162+
full_nodes = [
163+
node
164+
for node in result.graph_module.graph.nodes
165+
if node.op == "call_function" and node.target == torch.ops.aten.full.default
166+
]
167+
repeat_nodes = [
168+
node
169+
for node in result.graph_module.graph.nodes
170+
if node.op == "call_function"
171+
and node.target == exir_ops.edge.aten.repeat.default
172+
]
173+
174+
assert len(full_nodes) == 1
175+
assert full_nodes[0].args[0] == [2, 3]
176+
assert not repeat_nodes

0 commit comments

Comments
 (0)