Skip to content

Commit c072133

Browse files
Arm backend: Fix transposed conv lowering
Avoid accidentally slicing transpose convolutions. Input slicing is used for regular convolutions to trim unused edge regions for TOSA lowering, but that is not valid for transpose convolutions. Shrinking the input to a transpose convolution changes the expanded output shape and can break output_padding cases. Signed-off-by: Christoffer J.L <christoffer.johanssonlundqvist@arm.com> Change-Id: Ib3e349310bf1672ea1cce366154101aed1415e6e
1 parent 69989b7 commit c072133

2 files changed

Lines changed: 45 additions & 2 deletions

File tree

backends/arm/_passes/size_adjust_input_pass.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,13 @@ def _greater_than(input: SymIntLike, other: int) -> bool | torch.SymBool:
6363

6464

6565
def get_slices_convolution(conv_node: torch.fx.Node) -> Slices:
66-
slices = []
66+
slices: Slices = []
6767

68-
input_node, weight, _, stride_hw, pad_hw, dilation_hw, _, _, _ = conv_node.args
68+
input_node, weight, _, stride_hw, pad_hw, dilation_hw, transposed, _, _ = (
69+
conv_node.args
70+
)
71+
if transposed:
72+
return slices
6973
weight_shape = cast(torch.fx.Node, weight).meta["val"].shape
7074
input_shape = cast(torch.fx.Node, input_node).meta["val"].shape
7175
spatial_rank = len(input_shape) - 2

backends/arm/test/passes/test_size_adjust_input_pass.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
4242
return self.conv(x)
4343

4444

45+
class TransposeConvModule(torch.nn.Module):
46+
def __init__(self) -> None:
47+
super().__init__()
48+
self.conv = torch.nn.ConvTranspose2d(
49+
in_channels=3,
50+
out_channels=6,
51+
kernel_size=3,
52+
stride=2,
53+
output_padding=1,
54+
)
55+
56+
def forward(self, x: torch.Tensor) -> torch.Tensor:
57+
return self.conv(x)
58+
59+
4560
def _needs_truncation(input_length, kernel_size, stride, padding):
4661
return _greater_than((input_length + 2 * padding - kernel_size) % stride, padding)
4762

@@ -115,6 +130,30 @@ def test_size_adjust_input_static_conv_no_adjustment_needed():
115130
), "No slice nodes should be inserted when no adjustment is needed"
116131

117132

133+
def test_size_adjust_input_skips_transpose_conv2d() -> None:
134+
model = TransposeConvModule()
135+
example_inputs = (torch.randn(1, 3, 16, 16),)
136+
edge_model = to_edge(export(model, example_inputs))
137+
edge_model = edge_model.transform([SizeAdjustInputPass()])
138+
gm = edge_model.exported_program().graph_module
139+
140+
conv_node = next(
141+
n
142+
for n in gm.graph.nodes
143+
if n.op == "call_function"
144+
and n.target == exir_ops.edge.aten.convolution.default
145+
)
146+
input_node = conv_node.args[0]
147+
assert input_node.meta["val"].shape == example_inputs[0].shape
148+
149+
slice_nodes = [
150+
n
151+
for n in gm.graph.nodes
152+
if n.op == "call_function" and n.target == exir_ops.edge.aten.slice_copy.Tensor
153+
]
154+
assert len(slice_nodes) == 0
155+
156+
118157
def test_size_adjust_input_dynamic_conv2d():
119158
kernel_size, stride, padding = 3, 3, 1
120159
model = ConvModule(kernel_size=kernel_size, stride=stride, padding=padding)

0 commit comments

Comments
 (0)