Skip to content

Commit a577584

Browse files
pytorchbot and per
authored
Allow symints to be created for arguments (#16774)
### Summary
Add a test for creating args of SymInt type, to be able to use them in view_copy nodes, together with the fix needed to make the test pass.
### Test plan
Tested through CI tests. cc @freddan80 @zingo @oscarandersson8218 @digantdesai Signed-off-by: Per Åstrand <per.astrand@arm.com> Co-authored-by: Per Åstrand <per.astrand@arm.com>
1 parent e006d99 commit a577584

2 files changed

Lines changed: 112 additions & 2 deletions

File tree

exir/pass_base.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3-
# Copyright 2025 Arm Limited and/or its affiliates.
3+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
@@ -191,6 +191,11 @@ def create_arg(self, a: Argument) -> torch.fx.Node:
191191
if not hasattr(a, "constant") or a.constant is None:
192192
raise ExportPassBaseError(f"Cannot add {a} to graph.")
193193
a = a.constant
194+
elif isinstance(a, torch.SymInt):
195+
if a.node.constant is not None:
196+
return a.node.constant
197+
else:
198+
return a
194199
node = super().create_arg(a)
195200
if (
196201
isinstance(a, torch.Tensor)

exir/tests/test_dynamic_shape_propagation.py

Lines changed: 106 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,26 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
33
#
4+
# Copyright 2026 Arm Limited and/or its affiliates.
5+
#
46
# This source code is licensed under the BSD-style license found in the
57
# LICENSE file in the root directory of this source tree.
68

79
# pyre-unsafe
810

911
from unittest import TestCase
1012

13+
import torch
14+
1115
from executorch import exir
1216
from executorch.exir import to_edge
13-
from executorch.exir.passes import DebugPass, HintBasedSymShapeEvalPass, SpecPropPass
17+
from executorch.exir.passes import (
18+
DebugPass,
19+
ExportPass,
20+
HintBasedSymShapeEvalPass,
21+
SpecPropPass,
22+
)
23+
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
1424
from executorch.exir.tests.models import Repeat, TensorItem
1525
from torch.export import export
1626

@@ -67,3 +77,98 @@ def test_unbacked_symint(self):
6777
self.assertEqual(
6878
speclist[0].shape, [100, 100]
6979
) # upper bound of TensorItem model
80+
81+
82+
class TestSymIntViewArgs(TestCase):
83+
class Conv1dToConv2d(torch.nn.Module):
84+
def __init__(self) -> None:
85+
super().__init__()
86+
87+
def forward(self, input: torch.Tensor) -> torch.Tensor:
88+
# Use view to make sure the edge view handles symint shapes correctly.
89+
# input = input.view(input.size(0), input.size(1), input.size(2), 1) # (N, C, H, W)
90+
# weight = torch.randn(1, 16, 3, 1) # (out_channels, in_channels, kH, kW)
91+
# return torch.nn.functional.conv2d(input, weight)
92+
93+
return torch.nn.functional.conv1d(
94+
input, torch.randn(1, 16, 3)
95+
) # (out_channels, in_channels, kW)
96+
97+
def get_random_inputs(self) -> tuple[torch.Tensor]:
98+
return (torch.randn(1, 16, 50),) # (batch_size, channels, width)
99+
100+
def get_dynamic_shape(self) -> tuple[dict[int, torch.export.Dim]]:
101+
dim = torch.export.Dim("width", min=10, max=100)
102+
return ({2: dim},)
103+
104+
def test_symint_viewargs(self):
105+
eager_model = TestSymIntViewArgs.Conv1dToConv2d()
106+
inputs = eager_model.get_random_inputs()
107+
108+
class TestViewCopyPass(ExportPass):
109+
def call_operator(self, op, args, kwargs, meta):
110+
from executorch.exir.dialects._ops import ops as exir_ops
111+
112+
if op != exir_ops.edge.aten.convolution.default:
113+
return super().call_operator(op, args, kwargs, meta)
114+
115+
x = args[0]
116+
x = super().call_operator(
117+
exir_ops.edge.aten.view_copy.default,
118+
(x, list(x.data.shape) + [1]),
119+
{},
120+
meta,
121+
)
122+
123+
w = args[1]
124+
w = super().call_operator(
125+
exir_ops.edge.aten.view_copy.default,
126+
(w, list(w.data.shape) + [1]),
127+
{},
128+
meta,
129+
)
130+
131+
new_args = (
132+
x,
133+
w,
134+
args[2],
135+
args[3] + [1], # stride
136+
args[4] + [0], # padding
137+
args[5] + [1], # dilation
138+
args[6],
139+
args[7] + [0],
140+
args[8],
141+
)
142+
x = super().call_operator(
143+
exir_ops.edge.aten.convolution.default, new_args, kwargs, meta
144+
)
145+
x = super().call_operator(
146+
exir_ops.edge.aten.view_copy.default,
147+
(x, list(x.data.shape)[:-1]),
148+
{},
149+
meta,
150+
)
151+
152+
return x
153+
154+
prog = to_edge(
155+
export(
156+
eager_model,
157+
inputs,
158+
dynamic_shapes=eager_model.get_dynamic_shape(),
159+
strict=True,
160+
),
161+
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
162+
)
163+
new_prog = prog.transform(
164+
[SpecPropPass(), ConstraintBasedSymShapeEvalPass(), TestViewCopyPass()]
165+
)
166+
gm = new_prog.exported_program().graph_module
167+
DebugPass(show_spec=True)(gm)
168+
*_, return_node = gm.graph.nodes
169+
speclist = return_node.meta["spec"]
170+
171+
self.assertEqual(len(speclist), 1)
172+
out_spec = speclist[0]
173+
self.assertTrue(out_spec.is_upper_bound_tensor)
174+
self.assertEqual(out_spec.shape, [1, 1, 98])

0 commit comments

Comments
 (0)