Skip to content

Commit 8568135

Browse files
committed
feat: added transposed conv 1d support + refactored conv_converter
1 parent 063f9c9 commit 8568135

9 files changed

Lines changed: 827 additions & 335 deletions

File tree

Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
# Copyright 2026 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from torch._subclasses import FakeTensor, FakeTensorMode
8+
from torch.fx import GraphModule, Node
9+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
10+
11+
12+
Conv1dArgs = tuple[Node, Node, (Node | None), list[int], list[int], list[int], int]
13+
Conv1dTranspArgs = tuple[
14+
Node, Node, (Node | None), list[int], list[int], list[int], int, list[int]
15+
]
16+
17+
18+
class ConvertConv1dToConv2dPass(PassBase):
19+
r"""
20+
The NXP backend supports only 2D convolutions. Rewrite 1D convolutions into an equivalent 2D form by
21+
inserting a singleton spatial dimension and then removing it again.
22+
23+
x W x W
24+
[N, C1, H1] [I/O, I/O, k] [N, C1, H1] [I/O, I/O, k]
25+
│ │ │ │
26+
│ │ ┌────────▼─────────┐ ┌─────────▼────────┐
27+
│ │ │ unsqueeze(x, 2) │ │ unsqueeze(x, 2) │
28+
│ │ └────────▼─────────┘ └─────────▼────────┘
29+
│ │ │ │
30+
│ │ [N, C1, 1, H1] [I/O, I/O, 1, k]
31+
│ │ │ │
32+
└────────┐ ┌────────┘ └──────────┐ ┌──────────┘
33+
│ │ │ │
34+
┌────────▼───────▼───────┐ ┌────────▼─────▼────────┐
35+
│ convolution ◄──B [O] replace │ convolution ◄──B [O]
36+
│ (1D/transposed 1D) │ ────────────────► │ (2D/transposed 2D) │
37+
└────────────┬───────────┘ with └───────────┬───────────┘
38+
│ │
39+
│ [N, C2, 1, H2]
40+
│ │
41+
│ ┌────────▼─────────┐
42+
│ │ squeeze(x, 2) │
43+
│ └────────┬─────────┘
44+
│ │
45+
▼ ▼
46+
[N, C2, H2] [N, C2, H2]
47+
y y
48+
"""
49+
50+
@staticmethod
51+
def _is_conv_1d(node: Node) -> bool:
52+
return node.target == torch.ops.aten.conv1d.default
53+
54+
@staticmethod
55+
def _is_conv_transposed_1d(node: Node) -> bool:
56+
return node.target == torch.ops.aten.conv_transpose1d.default
57+
58+
@staticmethod
59+
def _listify(x: int | list[int] | tuple[int]) -> list[int]:
60+
if isinstance(x, int):
61+
return [x]
62+
63+
return list(x)
64+
65+
@staticmethod
66+
def _get_node_shape(node: Node):
67+
return node.meta["val"].shape if hasattr(node, "meta") else node.shape
68+
69+
@staticmethod
70+
def _get_node_dtype(node: Node):
71+
return node.meta["val"].dtype if hasattr(node, "meta") else node.dtype
72+
73+
def _create_some_conv_2d_node(self, target, *conv_args):
74+
# some_conv_2d_node = could be regular 2d conv or transposed 2d conv
75+
some_conv_node = self.graph_module.graph.call_function(target, conv_args)
76+
some_conv_node.meta["source_fn_stack"] = [(some_conv_node.name, target)]
77+
78+
# take out the bias node argument if bias=False, cannot calculate fake tensor for None
79+
has_b_node = len(conv_args) >= 3 and conv_args[2] is not None
80+
if has_b_node:
81+
node_args = conv_args[:3]
82+
scalar_args = conv_args[3:]
83+
else:
84+
node_args = conv_args[:2]
85+
scalar_args = conv_args[2:]
86+
87+
with FakeTensorMode() as mode:
88+
node_arg_shapes = [self._get_node_shape(arg) for arg in node_args]
89+
node_arg_dtypes = [self._get_node_dtype(arg) for arg in node_args]
90+
fake_node_args = [
91+
FakeTensor.from_tensor(torch.empty(shape, dtype=dtype), mode)
92+
for shape, dtype in zip(node_arg_shapes, node_arg_dtypes)
93+
]
94+
95+
# insert back the bias node argument (= None) if it was taken out earlier
96+
node_args = fake_node_args if has_b_node else fake_node_args + [None]
97+
output = target(*fake_node_args, *scalar_args)
98+
99+
some_conv_node.meta["val"] = FakeTensor.from_tensor(
100+
torch.empty(output.shape, dtype=output.dtype), mode
101+
)
102+
103+
return some_conv_node
104+
105+
def _create_sq_or_unsq_node(self, target, *sq_or_unsq_args) -> Node:
106+
sq_or_unsq_node = self.graph_module.graph.call_function(target, sq_or_unsq_args)
107+
108+
sq_or_unsq_node.meta["source_fn_stack"] = [(sq_or_unsq_node.name, target)]
109+
with FakeTensorMode() as mode:
110+
inp_node = sq_or_unsq_args[0]
111+
fake_input = FakeTensor.from_tensor(
112+
torch.empty(
113+
self._get_node_shape(inp_node), dtype=self._get_node_dtype(inp_node)
114+
),
115+
mode,
116+
)
117+
118+
output = target(fake_input, *sq_or_unsq_args[1:])
119+
sq_or_unsq_node.meta["val"] = FakeTensor.from_tensor(
120+
torch.empty(output.shape, dtype=output.dtype), mode
121+
)
122+
123+
return sq_or_unsq_node
124+
125+
@staticmethod
126+
def _get_conv_1d_transp_args(node: Node):
127+
args = node.args
128+
listify_fn = ConvertConv1dToConv2dPass._listify
129+
130+
b_node = None if len(args) < 3 else args[2]
131+
stride = [1] if len(args) < 4 else listify_fn(args[3])
132+
padding = [0] if len(args) < 5 else listify_fn(args[4])
133+
output_padding = [0] if len(args) < 6 else listify_fn(args[5])
134+
groups = 1 if len(args) < 7 else args[6]
135+
dilation = [1] if len(args) < 8 else listify_fn(args[7])
136+
137+
return (
138+
args[0],
139+
args[1],
140+
b_node,
141+
stride,
142+
padding,
143+
output_padding,
144+
groups,
145+
dilation,
146+
)
147+
148+
@staticmethod
149+
def _get_conv_1d_args(node: Node) -> Conv1dArgs:
150+
args = node.args
151+
listify_fn = ConvertConv1dToConv2dPass._listify
152+
153+
b_node = None if len(args) < 3 else args[2]
154+
stride = [1] if len(args) < 4 else listify_fn(args[3])
155+
padding = [0] if len(args) < 5 else listify_fn(args[4])
156+
dilation = [1] if len(args) < 6 else listify_fn(args[5])
157+
groups = 1 if len(args) < 7 else args[6]
158+
159+
return args[0], args[1], b_node, stride, padding, dilation, groups
160+
161+
def _convert_scalar_1d_args_to_2d(self, old_1d_node: Node):
162+
if self._is_conv_transposed_1d(old_1d_node):
163+
_, _, _, stride, pad, output_pad, groups, dil = (
164+
self._get_conv_1d_transp_args(old_1d_node)
165+
)
166+
167+
# conversion of 1d args to 2d, ie. padding with default values
168+
stride = [1] + stride
169+
pad = [0] + pad
170+
output_pad = [0] + output_pad
171+
dil = [1] + dil
172+
173+
return stride, pad, output_pad, groups, dil
174+
175+
else:
176+
_, _, _, stride, pad, dil, groups = self._get_conv_1d_args(old_1d_node)
177+
178+
# conversion of 1d args to 2d, ie. padding with default values
179+
stride = [1] + stride
180+
pad = [0] + pad
181+
dil = [1] + dil
182+
183+
return stride, pad, dil, groups
184+
185+
def _convert_node_1d_args_to_2d(self, old_1d_node: Node):
186+
if self._is_conv_transposed_1d(old_1d_node):
187+
input_node, w_node, b_node, _, _, _, _, _ = self._get_conv_1d_transp_args(
188+
old_1d_node
189+
)
190+
else:
191+
input_node, w_node, b_node, _, _, _, _ = self._get_conv_1d_args(old_1d_node)
192+
193+
with self.graph_module.graph.inserting_before(old_1d_node):
194+
unsqueeze_target = torch.ops.aten.unsqueeze.default
195+
196+
# weights = [i/o, i/o, k] => [i/o, i/o, 1, k]
197+
w_unsq_args = (w_node, 2)
198+
w_unsq_node = self._create_sq_or_unsq_node(unsqueeze_target, *w_unsq_args)
199+
200+
# input = [n, c, h] => [n, c, 1, h]
201+
inp_unsq_args = (input_node, 2)
202+
inp_unsq_node = self._create_sq_or_unsq_node(
203+
unsqueeze_target, *inp_unsq_args
204+
)
205+
206+
return (inp_unsq_node, w_unsq_node, b_node)
207+
208+
def call(self, graph_module: GraphModule) -> PassResult:
209+
self.graph_module = graph_module
210+
made_changes = False
211+
212+
for node in list(graph_module.graph.nodes):
213+
is_conv_1d = self._is_conv_1d(node)
214+
is_conv_1d_transp = self._is_conv_transposed_1d(node)
215+
216+
# some_1d_conv = regular 1d conv or 1d transposed conv
217+
is_some_1d_conv = is_conv_1d or is_conv_1d_transp
218+
if not is_some_1d_conv:
219+
continue
220+
221+
# invalid number of args
222+
if len(node.args) < 2:
223+
continue
224+
225+
old_1d_node = node
226+
227+
# get input, weight and bias arguments for the new 2d conv
228+
node_args = self._convert_node_1d_args_to_2d(old_1d_node)
229+
# get stride, padding etc. arguments for the new 2d conv
230+
scalar_args = self._convert_scalar_1d_args_to_2d(old_1d_node)
231+
232+
new_2d_target = (
233+
torch.ops.aten.conv_transpose2d.input
234+
if is_conv_1d_transp
235+
else torch.ops.aten.conv2d.default
236+
)
237+
238+
# create the new conv 2d and unsqueeze the input and weights
239+
with self.graph_module.graph.inserting_before(old_1d_node):
240+
new_2d_args = node_args + scalar_args
241+
new_2d_node = self._create_some_conv_2d_node(
242+
new_2d_target, *new_2d_args
243+
)
244+
245+
# the original 1d conv output shape must be retained, thus insert squeeze
246+
with self.graph_module.graph.inserting_after(new_2d_node):
247+
squeeze_target = torch.ops.aten.squeeze.dim
248+
249+
out_sq_args = (new_2d_node, 2)
250+
out_sq_node = self._create_sq_or_unsq_node(squeeze_target, *out_sq_args)
251+
252+
old_1d_node.replace_all_uses_with(out_sq_node)
253+
graph_module.graph.erase_node(old_1d_node)
254+
255+
made_changes = True
256+
257+
graph_module.recompile()
258+
graph_module.graph.eliminate_dead_code()
259+
return PassResult(graph_module, made_changes)

backends/nxp/aten_passes/neutron_aten_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
import torch
99

10+
from executorch.backends.nxp.aten_passes.convert_1d_conv_to_2d import (
11+
ConvertConv1dToConv2dPass,
12+
)
1013
from executorch.backends.nxp.aten_passes.convert_div_to_mul import ConvertDivToMulPass
1114
from executorch.backends.nxp.aten_passes.decompose_split_to_slices_pass import (
1215
DecomposeSplitToSlicesPass,
@@ -49,6 +52,7 @@ def _get_default_passes(neutron_target_spec, qat_mode: bool = False) -> list[Pas
4952
FuseLinearAndAddPass(),
5053
MoveActivationBeforeConcat(neutron_target_spec),
5154
ConvertDivToMulPass(),
55+
ConvertConv1dToConv2dPass(),
5256
]
5357

5458
if not qat_mode:

0 commit comments

Comments
 (0)