Commit b6824d1
Arm backend: Add support for all padding modes (pytorch#18521)
The pad operator with different modes was decomposed in the to-edge step into a number of operations that were not handled well by the Arm backend. Instead, avoid the default decomposition and decompose the replicate, circular, and reflect modes using slice + concat operators.

Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent 59838fc commit b6824d1
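
For intuition, here is a minimal standalone sketch of the same slice + concat decomposition in plain PyTorch, checked against F.pad. This is not the backend pass itself: pad_via_slice_cat and slice_at are illustrative names, with narrow standing in for the slice_copy operator the pass emits.

import torch
import torch.nn.functional as F


def pad_via_slice_cat(x, dim: int, left: int, right: int, mode: str):
    size = x.shape[dim]

    def slice_at(i: int):
        # One-element slice along `dim`; stands in for slice_copy in the pass.
        return x.narrow(dim, i, 1)

    if mode == "replicate":
        lefts = [slice_at(0)] * left  # repeat the edge element
        rights = [slice_at(size - 1)] * right
    elif mode == "circular":
        lefts = [slice_at(size - left + i) for i in range(left)]  # wrap around
        rights = [slice_at(i) for i in range(right)]
    elif mode == "reflect":
        lefts = [slice_at(left - i) for i in range(left)]  # mirror, edge excluded
        rights = [slice_at(size - 2 - i) for i in range(right)]
    else:
        raise ValueError(f"Unsupported mode '{mode}'")
    return torch.cat(lefts + [x] + rights, dim=dim)


x = torch.arange(12.0).reshape(3, 4)
for mode in ("replicate", "circular", "reflect"):
    got = pad_via_slice_cat(x, dim=1, left=2, right=2, mode=mode)
    want = F.pad(x.unsqueeze(0), (2, 2), mode=mode).squeeze(0)  # F.pad wants 3D here
    assert torch.equal(got, want), mode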

7 files changed: 308 additions & 35 deletions

backends/arm/_passes/rewrite_pad.py

Lines changed: 120 additions & 10 deletions

@@ -3,6 +3,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from collections.abc import Sequence
 from typing import Set, Type
 
 import torch
@@ -20,18 +21,10 @@ class RewritePadPass(ArmPass):
     _passes_required_after: Set[Type[ExportPass]] = set()
     targeted_ops = {
         exir_ops.edge.aten.constant_pad_nd.default,
+        exir_ops.edge.aten.pad.default,
     }
 
-    def call_operator(self, op, args, kwargs, meta, updated=False):
-        if op not in self.targeted_ops:
-            return super().call_operator(op, args, kwargs, meta)
-
-        if len(args) == 3:
-            input_tensor, pad, value = args
-        else:
-            input_tensor, pad = args
-            value = 0
-
+    def _rewrite_constant_pad(self, input_tensor, pad, value, meta):
         output_dtype = meta["val"].dtype
         if output_dtype in (torch.int8, torch.int16):
             input_qparams = meta.data.get("input_qparams", {})
@@ -65,3 +58,120 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
             meta,
             True,
         )
+
+    def _slice_idx(self, x, dim: int, idx: int, meta):
+        return super().call_operator(
+            exir_ops.edge.aten.slice_copy.Tensor,
+            (x, dim, idx, idx + 1),
+            {},
+            meta,
+            True,
+        )
+
+    def _pad_along_dim(
+        self,
+        x,
+        dim: int,
+        left: int,
+        right: int,
+        mode: str,
+        meta,
+    ):
+        if left == 0 and right == 0:
+            return x
+
+        size = x.data.shape[dim]
+        if isinstance(size, torch.SymInt):
+            raise ValueError(f"Pad mode '{mode}' does not support symbolic shape yet.")
+        if not isinstance(size, int):
+            raise ValueError(f"Expected integer dim size for pad rewrite, got {size}.")
+
+        left_tensors = []
+        right_tensors = []
+
+        if mode == "replicate":
+            left_tensors = [self._slice_idx(x, dim, 0, meta) for _ in range(left)]
+            right_tensors = [
+                self._slice_idx(x, dim, size - 1, meta) for _ in range(right)
+            ]
+        elif mode == "circular":
+            left_tensors = [
+                self._slice_idx(x, dim, size - left + i, meta) for i in range(left)
+            ]
+            right_tensors = [self._slice_idx(x, dim, i, meta) for i in range(right)]
+        elif mode == "reflect":
+            if left >= size or right >= size:
+                raise ValueError(
+                    f"Pad mode 'reflect' requires pad < input size, got left={left}, right={right}, size={size}."
+                )
+            left_tensors = [
+                self._slice_idx(x, dim, left - i, meta) for i in range(left)
+            ]
+            right_tensors = [
+                self._slice_idx(x, dim, size - 2 - i, meta) for i in range(right)
+            ]
+        else:
+            raise ValueError(f"Unsupported pad mode '{mode}'.")
+
+        return super().call_operator(
+            exir_ops.edge.aten.cat.default,
+            (left_tensors + [x] + right_tensors, dim),
+            {},
+            meta,
+            True,
+        )
+
+    def _rewrite_non_constant_pad(
+        self,
+        input_tensor,
+        pad: Sequence[int],
+        mode: str,
+        meta,
+    ):
+        if len(pad) % 2 != 0:
+            raise ValueError(f"Invalid pad spec length {len(pad)} for mode '{mode}'.")
+
+        output = input_tensor
+        pairs = [(pad[i], pad[i + 1]) for i in range(0, len(pad), 2)]
+        rank = len(input_tensor.data.shape)
+        for pair_idx, (left, right) in enumerate(pairs):
+            if not isinstance(left, int) or not isinstance(right, int):
+                raise ValueError(
+                    f"Pad mode '{mode}' expects integer pad values, got ({left}, {right})."
+                )
+            # F.pad pad tuples are ordered from the innermost dimension outward.
+            dim = rank - 1 - pair_idx
+            output = self._pad_along_dim(output, dim, left, right, mode, meta)
+        return output
+
+    def call_operator(self, op, args, kwargs, meta, updated=False):
+        if op not in self.targeted_ops:
+            return super().call_operator(op, args, kwargs, meta)
+
+        if op == exir_ops.edge.aten.constant_pad_nd.default:
+            if len(args) == 3:
+                input_tensor, pad, value = args
+            else:
+                input_tensor, pad = args
+                value = 0
+            return self._rewrite_constant_pad(input_tensor, pad, value, meta)
+
+        if len(args) < 2:
+            raise ValueError(
+                f"Expected at least 2 args for aten.pad.default, got {args}"
+            )
+
+        input_tensor, pad = args[:2]
+        mode = args[2] if len(args) > 2 else kwargs.get("mode", "constant")
+        value = args[3] if len(args) > 3 else kwargs.get("value", 0)
+
+        if not isinstance(mode, str):
+            raise ValueError(f"Expected string mode in aten.pad.default, got {mode}")
+
+        if mode == "constant":
+            return self._rewrite_constant_pad(input_tensor, pad, value, meta)
+
+        if mode in ("reflect", "replicate", "circular"):
+            return self._rewrite_non_constant_pad(input_tensor, pad, mode, meta)
+
+        raise ValueError(f"Unsupported pad mode '{mode}' in aten.pad.default.")

backends/arm/operator_support/tosa_profile_supported_op_lists.py

Lines changed: 2 additions & 0 deletions

@@ -98,6 +98,7 @@
     exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
     exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
     exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
+    exir_ops.edge.aten.pad.default,
     exir_ops.edge.aten.constant_pad_nd.default,
     exir_ops.edge.aten.amax.default,
     exir_ops.edge.aten.amin.default,
@@ -219,6 +220,7 @@
     exir_ops.edge.aten.pow.Tensor_Scalar,
     exir_ops.edge.aten.pow.Tensor_Tensor,
     operator.getitem,
+    exir_ops.edge.aten.pad.default,
     exir_ops.edge.aten.constant_pad_nd.default,
     exir_ops.edge.aten.amax.default,
     exir_ops.edge.aten.amin.default,

backends/arm/test/ops/test_constant_pad_nd.py

Lines changed: 87 additions & 22 deletions

@@ -22,47 +22,113 @@
 input_t1 = Tuple[torch.Tensor]  # Input x
 
 test_data_suite = {
-    "4dim_last1dim": lambda: (torch.rand(1, 1, 16, 16), (1, 1, 0, 0, 0, 0, 0, 0), 1),
-    "4dim_last2dim": lambda: (torch.rand(1, 1, 16, 16), (1, 0, 1, 0, 0, 0, 0, 0), 2),
-    "4dim_last3dim": lambda: (torch.rand(1, 1, 16, 16), (1, 1, 0, 2, 0, 2, 0, 0), 3),
-    "4dim_last4dim": lambda: (torch.rand(1, 1, 16, 16), (1, 0, 1, 1, 0, 2, 0, 2), 4),
-    "3dim_last1dim": lambda: (torch.rand(1, 1, 16), (1, 1, 0, 0, 0, 0), 1),
-    "3dim_last2dim": lambda: (torch.rand(1, 1, 16), (1, 0, 1, 1, 0, 0), 2),
-    "3dim_last3dim": lambda: (torch.rand(1, 1, 16), (1, 0, 1, 0, 1, 1), 3),
-    "2dim_last1dim": lambda: (torch.rand(1, 1, 16), (1, 1, 0, 0), 1),
-    "2dim_last2dim": lambda: (torch.rand(1, 1, 16), (1, 0, 1, 1), 2),
+    "4dim_last1dim": lambda: (
+        torch.rand(1, 1, 16, 16),
+        (1, 1, 0, 0, 0, 0, 0, 0),
+        1,
+        "constant",
+    ),
+    "4dim_last2dim": lambda: (
+        torch.rand(1, 1, 16, 16),
+        (1, 0, 1, 0, 0, 0, 0, 0),
+        2,
+        "constant",
+    ),
+    "4dim_last3dim": lambda: (
+        torch.rand(1, 1, 16, 16),
+        (1, 1, 0, 2, 0, 2, 0, 0),
+        3,
+        "constant",
+    ),
+    "4dim_last4dim": lambda: (
+        torch.rand(1, 1, 16, 16),
+        (1, 0, 1, 1, 0, 2, 0, 2),
+        4,
+        "constant",
+    ),
+    "3dim_last1dim": lambda: (torch.rand(1, 1, 16), (1, 1, 0, 0, 0, 0), 1, "constant"),
+    "3dim_last2dim": lambda: (torch.rand(1, 1, 16), (1, 0, 1, 1, 0, 0), 2, "constant"),
+    "3dim_last3dim": lambda: (torch.rand(1, 1, 16), (1, 0, 1, 0, 1, 1), 3, "constant"),
+    "2dim_last1dim": lambda: (torch.rand(1, 1, 16), (1, 1, 0, 0), 1, "constant"),
+    "2dim_last2dim": lambda: (torch.rand(1, 1, 16), (1, 0, 1, 1), 2, "constant"),
+    "4dim_reflect": lambda: (
+        torch.rand(6, 6, 6, 6),
+        (3, 3, 3, 3, 3, 3),
+        None,
+        "reflect",
+    ),
+    "4dim_replicate": lambda: (
+        torch.rand(3, 3, 3, 3),
+        (3, 3, 3, 3, 3, 3),
+        None,
+        "replicate",
+    ),
+    "4dim_circular": lambda: (
+        torch.rand(3, 3, 3, 3),
+        (3, 3, 3, 3, 3, 3),
+        None,
+        "circular",
+    ),
+    "2dim_reflect": lambda: (
+        torch.rand(6, 6),
+        (3, 3),
+        None,
+        "reflect",
+    ),
+    "2dim_replicate": lambda: (
+        torch.rand(3, 3),
+        (3, 3),
+        None,
+        "replicate",
+    ),
+    "2dim_circular": lambda: (
+        torch.rand(3, 3),
+        (3, 3),
+        None,
+        "circular",
+    ),
 }
 
 test_data_suite_bf16 = {
     "4dim_last1dim_bf16": lambda: (
         torch.rand(1, 1, 8, 8, dtype=torch.bfloat16),
         (1, 1, 0, 0, 0, 0, 0, 0),
         1.0,
+        "constant",
     ),
     "3dim_last1dim_bf16": lambda: (
         torch.rand(1, 1, 8, dtype=torch.bfloat16),
         (1, 0, 1, 0, 0, 0),
         -0.5,
+        "constant",
     ),
 }
 test_data_suite_fp16 = {
     "4dim_last1dim_fp16": lambda: (
         torch.rand(1, 1, 8, 8, dtype=torch.float16),
         (1, 1, 0, 0, 0, 0, 0, 0),
         1.0,
+        "constant",
     ),
     "3dim_last1dim_fp16": lambda: (
         torch.rand(1, 1, 8, dtype=torch.float16),
         (1, 0, 1, 0, 0, 0),
         -0.5,
+        "constant",
     ),
 }
 
 
 class ConstantPadND(torch.nn.Module):
-    def __init__(self, pad: Tuple, value: float | None = None):
+    def __init__(
+        self,
+        pad: Tuple,
+        value: float | None = None,
+        mode: str = "constant",
+    ):
         super().__init__()
         self.value = value
+        self.mode = mode
         nonzero_idx = len(pad)
         for i in range(0, len(pad), 2):
             if pad[i] + pad[i + 1] == 0:
@@ -71,18 +137,17 @@ def __init__(self, pad: Tuple, value: float | None = None):
         self.pad = pad[:nonzero_idx]
 
     def forward(self, x: torch.Tensor):
-        x = F.pad(x, pad=self.pad, mode="constant", value=self.value)
-        return x
+        return F.pad(x, pad=self.pad, mode=self.mode, value=self.value)
 
 
 @common.parametrize(
     "test_data",
     test_data_suite | test_data_suite_bf16 | test_data_suite_fp16,
 )
 def test_constant_pad_nd_tosa_FP(test_data: Tuple):
-    test_data, padding, value = test_data()
+    test_data, padding, value, mode = test_data()
     pipeline = TosaPipelineFP[input_t1](
-        ConstantPadND(padding, value),
+        ConstantPadND(padding, value, mode),
         (test_data,),
         aten_op,
         exir_op,
@@ -93,9 +158,9 @@ def test_constant_pad_nd_tosa_FP(test_data: Tuple):
 
 @common.parametrize("test_data", test_data_suite)
 def test_constant_pad_nd_tosa_INT(test_data: Tuple):
-    test_data, padding, value = test_data()
+    test_data, padding, value, mode = test_data()
     pipeline = TosaPipelineINT[input_t1](
-        ConstantPadND(padding, value),
+        ConstantPadND(padding, value, mode),
         (test_data,),
         aten_op,
         exir_op,
@@ -106,9 +171,9 @@ def test_constant_pad_nd_tosa_INT(test_data: Tuple):
 @common.parametrize("test_data", test_data_suite)
 def test_constant_pad_nd_tosa_INT_a16w8(test_data: Tuple):
     """Test constant_pad_nd op with int16 I/O quantization for TOSA INT."""
-    test_data, padding, value = test_data()
+    test_data, padding, value, mode = test_data()
     pipeline = TosaPipelineINT[input_t1](
-        ConstantPadND(padding, value),
+        ConstantPadND(padding, value, mode),
         (test_data,),
         aten_op,
         exir_op,
@@ -120,9 +185,9 @@ def test_constant_pad_nd_tosa_INT_a16w8(test_data: Tuple):
 @common.parametrize("test_data", test_data_suite | test_data_suite_fp16)
 @common.SkipIfNoModelConverter
 def test_constant_pad_nd_vgf_no_quant(test_data: Tuple):
-    inp, padding, value = test_data()
+    inp, padding, value, mode = test_data()
     pipeline = VgfPipeline[input_t1](
-        ConstantPadND(padding, value),
+        ConstantPadND(padding, value, mode),
         (inp,),
         aten_op,
         exir_op,
@@ -134,9 +199,9 @@ def test_constant_pad_nd_vgf_no_quant(test_data: Tuple):
 @common.parametrize("test_data", test_data_suite)
 @common.SkipIfNoModelConverter
 def test_constant_pad_nd_vgf_quant(test_data: Tuple):
-    inp, padding, value = test_data()
+    inp, padding, value, mode = test_data()
     pipeline = VgfPipeline[input_t1](
-        ConstantPadND(padding, value),
+        ConstantPadND(padding, value, mode),
         (inp,),
         aten_op,
         exir_op,
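
For reference, a minimal eager-mode stand-in for the updated ConstantPadND module, assuming only torch: PadModule is an illustrative name, and the TOSA/VGF pipeline classes are omitted. It shows what the new 2dim_reflect case exercises.

import torch
import torch.nn.functional as F


class PadModule(torch.nn.Module):
    """Illustrative stand-in for the test's ConstantPadND module."""

    def __init__(self, pad, value=None, mode="constant"):
        super().__init__()
        self.pad, self.value, self.mode = pad, value, mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Non-constant modes ignore `value`, and F.pad rejects a nonzero value
        # for them, which is why the new test cases pass None.
        return F.pad(x, pad=self.pad, mode=self.mode, value=self.value)


# Mirrors "2dim_reflect": a (3, 3) pad on the last dim grows it from 6 to 12.
out = PadModule(pad=(3, 3), mode="reflect")(torch.rand(6, 6))
assert out.shape == (6, 12)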
