88 try_get_tensor_constant_from_node ,
99)
1010from executorch .backends .nxp .backend .graph_utils import is_batch_norm
11+ from executorch .backends .nxp .backend .neutron_target_spec import NeutronTargetSpec
1112from torch ._subclasses import FakeTensor , FakeTensorMode
1213from torch .ao .quantization .fx .utils import get_new_attr_name_with_prefix
1314from torch .export .unflatten import _assign_attr , _AttrKind
@@ -26,9 +27,10 @@ class ConvertConv1dToConv2dPass(PassBase):
2627 r"""
2728 The NXP backend supports only 2D convolutions. Rewrite 1D convolutions into an equivalent 2D form by
2829 inserting a singleton spatial dimension and then remove it again.
29- If batch norm is present after the convolution, it is also converted from 1D to 2D.
30+ If batch norm and/or a fusable activation (as defined by the NeutronTargetSpec) follow the convolution,
31+ they are also kept in 2D (before the squeeze) so the partitioner can fuse them with the convolution.
3032
31- Without batch norm:
33+ Without batch norm or activation :
3234
3335 x W x W
3436 [N, C1, H] [I/O, I/O, k] [N, C1, H] [I/O, I/O, 1, k]
@@ -90,8 +92,86 @@ class ConvertConv1dToConv2dPass(PassBase):
9092 ▼ ▼
9193 [N, C3, H] [N, C3, H]
9294 y y
95+
96+ With activation (e.g. relu):
97+
98+ x W x W
99+ [N, C1, H] [I/O, I/O, k] [N, C1, H] [I/O, I/O, 1, k]
100+ │ │ │ │
101+ │ │ ┌─────────▼──────────┐ │
102+ │ │ │ unsqueeze(x, -2) │ │
103+ │ │ └─────────▼──────────┘ │
104+ │ │ │ │
105+ │ │ [N, C1, 1, H] │
106+ │ │ │ │
107+ └────────┐ ┌────────┘ └──────────┐ ┌──────────┘
108+ │ │ │ │
109+ ┌────────▼───────▼───────┐ ┌────────▼─────▼────────┐
110+ │ convolution ◄──B [O] replace │ convolution ◄──B [O]
111+ │ (1D/transposed 1D) │ ────────────────► │ (2D/transposed 2D) │
112+ └────────────┬───────────┘ with └───────────┬───────────┘
113+ │ │
114+ [N, C2, H] [N, C2, 1, H]
115+ │ │
116+ ┌───────▼───────┐ ┌───────▼───────┐
117+ │ relu │ │ relu │
118+ └───────┬───────┘ └───────┬───────┘
119+ │ │
120+ │ [N, C2, 1, H]
121+ │ │
122+ │ ┌───────▼────────┐
123+ │ │ squeeze(-2) │
124+ │ └───────┬────────┘
125+ │ │
126+ ▼ ▼
127+ [N, C2, H] [N, C2, H]
128+ y y
129+
130+ With batch norm and activation:
131+
132+ x W x W
133+ [N, C1, H] [I/O, I/O, k] [N, C1, H] [I/O, I/O, 1, k]
134+ │ │ │ │
135+ │ │ ┌─────────▼──────────┐ │
136+ │ │ │ unsqueeze(x, -2) │ │
137+ │ │ └─────────▼──────────┘ │
138+ │ │ │ │
139+ │ │ [N, C1, 1, H] │
140+ │ │ │ │
141+ └────────┐ ┌────────┘ └──────────┐ ┌──────────┘
142+ │ │ │ │
143+ ┌────────▼───────▼───────┐ ┌────────▼─────▼────────┐
144+ │ convolution ◄──B [O] replace │ convolution ◄──B [O]
145+ │ (1D/transposed 1D) │ ────────────────► │ (2D/transposed 2D) │
146+ └────────────┬───────────┘ with └───────────┬───────────┘
147+ │ │
148+ [N, C2, H] [N, C2, 1, H]
149+ │ │
150+ ┌───────▼───────┐ ┌───────▼───────┐
151+ │ batch_norm │ │ batch_norm │
152+ │ (1D) │ │ (2D) │
153+ └───────┬───────┘ └───────┬───────┘
154+ │ │
155+ [N, C3, H] [N, C3, 1, H]
156+ │ │
157+ ┌───────▼───────┐ ┌───────▼───────┐
158+ │ relu │ │ relu │
159+ └───────┬───────┘ └───────┬───────┘
160+ │ │
161+ │ [N, C3, 1, H]
162+ │ │
163+ │ ┌───────▼────────┐
164+ │ │ squeeze(-2) │
165+ │ └───────┬────────┘
166+ │ │
167+ ▼ ▼
168+ [N, C3, H] [N, C3, H]
169+ y y
93170 """
94171
172+ def __init__ (self , neutron_target_spec : NeutronTargetSpec ):
173+ self .neutron_target_spec = neutron_target_spec
174+
95175 @staticmethod
96176 def _is_conv_1d (node : Node ) -> bool :
97177 return node .target == torch .ops .aten .conv1d .default
@@ -357,35 +437,43 @@ def call(self, graph_module: GraphModule) -> PassResult:
357437 )
358438
359439 old_1d_conv_users = list (old_1d_node .users .keys ())
440+ last_4d_node = new_2d_node
441+ node_to_replace = old_1d_node
442+ nodes_to_erase = []
443+
360444 if len (old_1d_conv_users ) == 1 and is_batch_norm (old_1d_conv_users [0 ]):
361445 bn_1d_node = old_1d_conv_users [0 ]
362-
363- # also convert batch_norm 1d to 2d
364- with self .graph_module .graph .inserting_after (new_2d_node ):
446+ with self .graph_module .graph .inserting_after (last_4d_node ):
365447 bn_2d_args = (new_2d_node ,) + bn_1d_node .args [1 :]
366448 bn_2d_node = self ._create_batch_norm_2d_node (* bn_2d_args )
367-
368- with self .graph_module .graph .inserting_after (bn_2d_node ):
369- squeeze_target = torch .ops .aten .squeeze .dim
370-
371- out_sq_args = (bn_2d_node , - 2 )
372- out_sq_node = self ._create_sq_or_unsq_node (
373- squeeze_target , * out_sq_args
374- )
375-
376- bn_1d_node .replace_all_uses_with (out_sq_node )
377- self .graph_module .graph .erase_node (bn_1d_node )
378-
379- else :
380- with self .graph_module .graph .inserting_after (new_2d_node ):
381- squeeze_target = torch .ops .aten .squeeze .dim
382-
383- out_sq_args = (new_2d_node , - 2 )
384- out_sq_node = self ._create_sq_or_unsq_node (
385- squeeze_target , * out_sq_args
449+ last_4d_node = bn_2d_node
450+ node_to_replace = bn_1d_node
451+ nodes_to_erase .append (bn_1d_node )
452+ old_1d_conv_users = list (bn_1d_node .users .keys ())
453+
454+ if len (
455+ old_1d_conv_users
456+ ) == 1 and self .neutron_target_spec .neutron_target_info .is_supported_fused_activation__aten (
457+ old_1d_conv_users [0 ]
458+ ):
459+ act_1d_node = old_1d_conv_users [0 ]
460+ with self .graph_module .graph .inserting_after (last_4d_node ):
461+ act_2d_args = (last_4d_node ,) + act_1d_node .args [1 :]
462+ act_2d_node = self ._create_sq_or_unsq_node (
463+ act_1d_node .target , * act_2d_args
386464 )
387-
388- old_1d_node .replace_all_uses_with (out_sq_node )
465+ last_4d_node = act_2d_node
466+ node_to_replace = act_1d_node
467+ nodes_to_erase .append (act_1d_node )
468+
469+ with self .graph_module .graph .inserting_after (last_4d_node ):
470+ squeeze_target = torch .ops .aten .squeeze .dim
471+ out_sq_args = (last_4d_node , - 2 )
472+ out_sq_node = self ._create_sq_or_unsq_node (squeeze_target , * out_sq_args )
473+
474+ node_to_replace .replace_all_uses_with (out_sq_node )
475+ for n in reversed (nodes_to_erase ):
476+ self .graph_module .graph .erase_node (n )
389477
390478 graph_module .graph .erase_node (old_1d_node )
391479 made_changes = True
0 commit comments