Skip to content

Commit d56fca1

Browse files
JakeStevensfacebook-github-bot
authored andcommitted
Fix broken ConvBNReLu from new Convert1DConvTo2D pass
Summary: The pass checked for a batch norm following the conv to avoid breaking fusion with a squeeze. However, it did not support Conv -> Batch Norm -> ReLu OR Conv -> ReLU This commit adds that support, along with other supported activation Reviewed By: rascani Differential Revision: D105017469
1 parent 1992bdd commit d56fca1

5 files changed

Lines changed: 286 additions & 34 deletions

File tree

backends/nxp/aten_passes/convert_1d_conv_to_2d.py

Lines changed: 114 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
try_get_tensor_constant_from_node,
99
)
1010
from executorch.backends.nxp.backend.graph_utils import is_batch_norm
11+
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
1112
from torch._subclasses import FakeTensor, FakeTensorMode
1213
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
1314
from torch.export.unflatten import _assign_attr, _AttrKind
@@ -26,9 +27,10 @@ class ConvertConv1dToConv2dPass(PassBase):
2627
r"""
2728
The NXP backend supports only 2D convolutions. Rewrite 1D convolutions into an equivalent 2D form by
2829
inserting a singleton spatial dimension and then remove it again.
29-
If batch norm is present after the convolution, it is also converted from 1D to 2D.
30+
If batch norm and/or a fusable activation (as defined by the NeutronTargetSpec) follow the convolution,
31+
they are also kept in 2D (before the squeeze) so the partitioner can fuse them with the convolution.
3032
31-
Without batch norm:
33+
Without batch norm or activation:
3234
3335
x W x W
3436
[N, C1, H] [I/O, I/O, k] [N, C1, H] [I/O, I/O, 1, k]
@@ -90,8 +92,86 @@ class ConvertConv1dToConv2dPass(PassBase):
9092
▼ ▼
9193
[N, C3, H] [N, C3, H]
9294
y y
95+
96+
With activation (e.g. relu):
97+
98+
x W x W
99+
[N, C1, H] [I/O, I/O, k] [N, C1, H] [I/O, I/O, 1, k]
100+
│ │ │ │
101+
│ │ ┌─────────▼──────────┐ │
102+
│ │ │ unsqueeze(x, -2) │ │
103+
│ │ └─────────▼──────────┘ │
104+
│ │ │ │
105+
│ │ [N, C1, 1, H] │
106+
│ │ │ │
107+
└────────┐ ┌────────┘ └──────────┐ ┌──────────┘
108+
│ │ │ │
109+
┌────────▼───────▼───────┐ ┌────────▼─────▼────────┐
110+
│ convolution ◄──B [O] replace │ convolution ◄──B [O]
111+
│ (1D/transposed 1D) │ ────────────────► │ (2D/transposed 2D) │
112+
└────────────┬───────────┘ with └───────────┬───────────┘
113+
│ │
114+
[N, C2, H] [N, C2, 1, H]
115+
│ │
116+
┌───────▼───────┐ ┌───────▼───────┐
117+
│ relu │ │ relu │
118+
└───────┬───────┘ └───────┬───────┘
119+
│ │
120+
│ [N, C2, 1, H]
121+
│ │
122+
│ ┌───────▼────────┐
123+
│ │ squeeze(-2) │
124+
│ └───────┬────────┘
125+
│ │
126+
▼ ▼
127+
[N, C2, H] [N, C2, H]
128+
y y
129+
130+
With batch norm and activation:
131+
132+
x W x W
133+
[N, C1, H] [I/O, I/O, k] [N, C1, H] [I/O, I/O, 1, k]
134+
│ │ │ │
135+
│ │ ┌─────────▼──────────┐ │
136+
│ │ │ unsqueeze(x, -2) │ │
137+
│ │ └─────────▼──────────┘ │
138+
│ │ │ │
139+
│ │ [N, C1, 1, H] │
140+
│ │ │ │
141+
└────────┐ ┌────────┘ └──────────┐ ┌──────────┘
142+
│ │ │ │
143+
┌────────▼───────▼───────┐ ┌────────▼─────▼────────┐
144+
│ convolution ◄──B [O] replace │ convolution ◄──B [O]
145+
│ (1D/transposed 1D) │ ────────────────► │ (2D/transposed 2D) │
146+
└────────────┬───────────┘ with └───────────┬───────────┘
147+
│ │
148+
[N, C2, H] [N, C2, 1, H]
149+
│ │
150+
┌───────▼───────┐ ┌───────▼───────┐
151+
│ batch_norm │ │ batch_norm │
152+
│ (1D) │ │ (2D) │
153+
└───────┬───────┘ └───────┬───────┘
154+
│ │
155+
[N, C3, H] [N, C3, 1, H]
156+
│ │
157+
┌───────▼───────┐ ┌───────▼───────┐
158+
│ relu │ │ relu │
159+
└───────┬───────┘ └───────┬───────┘
160+
│ │
161+
│ [N, C3, 1, H]
162+
│ │
163+
│ ┌───────▼────────┐
164+
│ │ squeeze(-2) │
165+
│ └───────┬────────┘
166+
│ │
167+
▼ ▼
168+
[N, C3, H] [N, C3, H]
169+
y y
93170
"""
94171

172+
def __init__(self, neutron_target_spec: NeutronTargetSpec):
173+
self.neutron_target_spec = neutron_target_spec
174+
95175
@staticmethod
96176
def _is_conv_1d(node: Node) -> bool:
97177
return node.target == torch.ops.aten.conv1d.default
@@ -357,35 +437,43 @@ def call(self, graph_module: GraphModule) -> PassResult:
357437
)
358438

359439
old_1d_conv_users = list(old_1d_node.users.keys())
440+
last_4d_node = new_2d_node
441+
node_to_replace = old_1d_node
442+
nodes_to_erase = []
443+
360444
if len(old_1d_conv_users) == 1 and is_batch_norm(old_1d_conv_users[0]):
361445
bn_1d_node = old_1d_conv_users[0]
362-
363-
# also convert batch_norm 1d to 2d
364-
with self.graph_module.graph.inserting_after(new_2d_node):
446+
with self.graph_module.graph.inserting_after(last_4d_node):
365447
bn_2d_args = (new_2d_node,) + bn_1d_node.args[1:]
366448
bn_2d_node = self._create_batch_norm_2d_node(*bn_2d_args)
367-
368-
with self.graph_module.graph.inserting_after(bn_2d_node):
369-
squeeze_target = torch.ops.aten.squeeze.dim
370-
371-
out_sq_args = (bn_2d_node, -2)
372-
out_sq_node = self._create_sq_or_unsq_node(
373-
squeeze_target, *out_sq_args
374-
)
375-
376-
bn_1d_node.replace_all_uses_with(out_sq_node)
377-
self.graph_module.graph.erase_node(bn_1d_node)
378-
379-
else:
380-
with self.graph_module.graph.inserting_after(new_2d_node):
381-
squeeze_target = torch.ops.aten.squeeze.dim
382-
383-
out_sq_args = (new_2d_node, -2)
384-
out_sq_node = self._create_sq_or_unsq_node(
385-
squeeze_target, *out_sq_args
449+
last_4d_node = bn_2d_node
450+
node_to_replace = bn_1d_node
451+
nodes_to_erase.append(bn_1d_node)
452+
old_1d_conv_users = list(bn_1d_node.users.keys())
453+
454+
if len(
455+
old_1d_conv_users
456+
) == 1 and self.neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten(
457+
old_1d_conv_users[0]
458+
):
459+
act_1d_node = old_1d_conv_users[0]
460+
with self.graph_module.graph.inserting_after(last_4d_node):
461+
act_2d_args = (last_4d_node,) + act_1d_node.args[1:]
462+
act_2d_node = self._create_sq_or_unsq_node(
463+
act_1d_node.target, *act_2d_args
386464
)
387-
388-
old_1d_node.replace_all_uses_with(out_sq_node)
465+
last_4d_node = act_2d_node
466+
node_to_replace = act_1d_node
467+
nodes_to_erase.append(act_1d_node)
468+
469+
with self.graph_module.graph.inserting_after(last_4d_node):
470+
squeeze_target = torch.ops.aten.squeeze.dim
471+
out_sq_args = (last_4d_node, -2)
472+
out_sq_node = self._create_sq_or_unsq_node(squeeze_target, *out_sq_args)
473+
474+
node_to_replace.replace_all_uses_with(out_sq_node)
475+
for n in reversed(nodes_to_erase):
476+
self.graph_module.graph.erase_node(n)
389477

390478
graph_module.graph.erase_node(old_1d_node)
391479
made_changes = True

backends/nxp/aten_passes/neutron_aten_pass_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def _get_default_passes(neutron_target_spec, qat_mode: bool = False) -> list[Pas
5252
FuseLinearAndAddPass(),
5353
MoveActivationBeforeConcat(neutron_target_spec),
5454
ConvertDivToMulPass(),
55-
ConvertConv1dToConv2dPass(),
55+
ConvertConv1dToConv2dPass(neutron_target_spec),
5656
]
5757

5858
if not qat_mode:

backends/nxp/tests/BUCK

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,26 @@ fbcode_target(_kind = python_pytest,
9292
],
9393
)
9494

95+
fbcode_target(_kind = python_pytest,
96+
name = "test_convert_1d_conv_to_2d",
97+
srcs = [
98+
"test_convert_1d_conv_to_2d.py",
99+
],
100+
env = {
101+
"PYTEST_ADDOPTS": "--ignore-glob=*full_pipeline* -k 'not full_pipeline'",
102+
},
103+
deps = [
104+
"//caffe2:torch",
105+
"//executorch/backends/nxp:aten_passes",
106+
"//executorch/backends/nxp:neutron_backend",
107+
":executorch_pipeline",
108+
":models",
109+
"fbsource//third-party/pypi/numpy:numpy",
110+
"fbsource//third-party/pypi/pytest:pytest",
111+
"fbsource//third-party/pypi/pytest-mock:pytest-mock", # @manual
112+
],
113+
)
114+
95115
fbcode_target(_kind = python_pytest,
96116
name = "test_integration",
97117
srcs = [

backends/nxp/tests/generic_tests/test_split_group_convolution.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ def test_split_group_convolution__1d(
161161
# `ConvertConv1dToConv2dPass` is needed to convert `conv1d` to `conv2d`.
162162
# The 1d variant is not supported.
163163
modified_module = NeutronAtenPassManager(
164-
neutron_target_spec, [SplitGroupConvolution(), ConvertConv1dToConv2dPass()]
164+
neutron_target_spec,
165+
[SplitGroupConvolution(), ConvertConv1dToConv2dPass(neutron_target_spec)],
165166
)(graph_module).graph_module
166167

167168
# Verify that the behavior has not changed.

0 commit comments

Comments
 (0)