Skip to content

Commit a018853

Browse files
authored
NXP backend: added support for aten.conv_transpose1d and refactored convolution_converter (#19004)
### Summary Added support for `aten.conv_transpose1d` by moving functionality from `convolution_converter` to the brand new `convert_1d_conv_to_2d` aten pass, and extending it. ### Test plan Tests can be manually run using `pytest -c /dev/null backends/nxp/tests/` cc @robert-kalmar @JakeStevens @digantdesai @MartinPavella
1 parent 19d3950 commit a018853

13 files changed

Lines changed: 985 additions & 370 deletions

File tree

backends/nxp/aten_passes/convert_1d_conv_to_2d.py

Lines changed: 395 additions & 0 deletions
Large diffs are not rendered by default.

backends/nxp/aten_passes/neutron_aten_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
import torch
99

10+
from executorch.backends.nxp.aten_passes.convert_1d_conv_to_2d import (
11+
ConvertConv1dToConv2dPass,
12+
)
1013
from executorch.backends.nxp.aten_passes.convert_div_to_mul import ConvertDivToMulPass
1114
from executorch.backends.nxp.aten_passes.decompose_split_to_slices_pass import (
1215
DecomposeSplitToSlicesPass,
@@ -49,6 +52,7 @@ def _get_default_passes(neutron_target_spec, qat_mode: bool = False) -> list[Pas
4952
FuseLinearAndAddPass(),
5053
MoveActivationBeforeConcat(neutron_target_spec),
5154
ConvertDivToMulPass(),
55+
ConvertConv1dToConv2dPass(),
5256
]
5357

5458
if not qat_mode:

backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py

Lines changed: 4 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from executorch.backends.nxp.backend.ir.converter.conversion import (
1616
aten_translator,
1717
common,
18-
translator,
1918
)
2019
from executorch.backends.nxp.backend.ir.converter.conversion.common import try_get_input
2120
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
@@ -42,7 +41,6 @@
4241
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
4342
conv_2d_options,
4443
depthwise_conv_2d_options,
45-
reshape_options,
4644
transpose_conv_options,
4745
)
4846
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
@@ -70,8 +68,9 @@ def _is_supported_on_target(
7068
return False
7169

7270
if conv_params.transposed:
73-
# TransposeConv1d is not supported on Neutron
74-
if len(conv_params.dilation) == 1:
71+
# TransposeConv2d with groups > 1 is not supported
72+
# TODO: split into multiple convs with groups = 1
73+
if conv_params.groups > 1:
7574
return False
7675
if not node_is_effectively_static_tensor(weights, parameters_mapping):
7776
# Only supported if the weights are static, because TFLite `TransposeConv` uses permuted
@@ -187,99 +186,6 @@ def _get_convolution_arguments(
187186
groups,
188187
)
189188

190-
def _convert_1d_conv(
191-
self, t_op: tflite_model.Operator, conv_params: ConvParameters
192-
) -> list[tflite_model.Operator]:
193-
"""Convert the 'Conv' operator with a 1D kernel to TFLite 'Conv2D'.
194-
TFLite doesn't support 1D convolution, but this behaviour can be represented using
195-
Reshape -> Conv2D -> Reshape.
196-
The first reshape introduces a 4th dimension with size 1. The second Reshape removes the temporary dimension.
197-
"""
198-
# -- Calculate the shapes for equivalent 2D convolution --
199-
conv_2d_input_shape = translator.nhc_dimensions_to_nhwc(
200-
t_op.tmp_inputs[0].shape.vector
201-
)
202-
conv_2d_weight_shape = translator.nhc_dimensions_to_nhwc(
203-
t_op.tmp_inputs[1].shape.vector
204-
)
205-
conv_2d_output_shape = translator.nhc_dimensions_to_nhwc(
206-
t_op.tmp_outputs[0].shape.vector
207-
)
208-
209-
# -- Generate tensors taking part in the conversion --
210-
reshape1_input = t_op.tmp_inputs[0]
211-
212-
reshape1_output = self.builder.duplicate_tensor(
213-
reshape1_input, name_suffix="_4D_"
214-
)
215-
reshape1_output.shape = tflite_model.Shape(conv_2d_input_shape)
216-
217-
reshape2_input = self.builder.duplicate_tensor(
218-
t_op.tmp_outputs[0], name_suffix="_4D_"
219-
)
220-
reshape2_input.shape = tflite_model.Shape(conv_2d_output_shape)
221-
222-
reshape2_output = t_op.tmp_outputs[0]
223-
224-
pre_reshapes = []
225-
226-
# Extend the weights tensor to 4D
227-
weights_tensor = t_op.tmp_inputs[1]
228-
if tensor_has_data(weights_tensor):
229-
# Do it statically
230-
weights_tensor.shape = tflite_model.Shape(conv_2d_weight_shape)
231-
weights_tensor.tmp_buffer.data = weights_tensor.tmp_buffer.data.reshape(
232-
conv_2d_weight_shape
233-
)
234-
235-
else:
236-
# Add a Reshape before the weights tensor
237-
new_weights_tensor = self.builder.duplicate_tensor(
238-
weights_tensor, name_suffix="_4D_"
239-
)
240-
new_weights_tensor.shape = tflite_model.Shape(conv_2d_weight_shape)
241-
242-
weight_reshape = tflite_model.Operator(
243-
builtin_options=reshape_options.Reshape(conv_2d_weight_shape)
244-
)
245-
weight_reshape.tmp_inputs = [weights_tensor]
246-
weight_reshape.tmp_outputs = [new_weights_tensor]
247-
248-
pre_reshapes.append(weight_reshape)
249-
250-
# Save the new weights tensor, to assign it later.
251-
weights_tensor = new_weights_tensor
252-
253-
# -- Create the new operators --
254-
reshape1 = tflite_model.Operator(
255-
builtin_options=reshape_options.Reshape(conv_2d_input_shape)
256-
)
257-
reshape1.tmp_inputs = [reshape1_input]
258-
reshape1.tmp_outputs = [reshape1_output]
259-
pre_reshapes.append(reshape1)
260-
261-
reshape2 = tflite_model.Operator(
262-
builtin_options=reshape_options.Reshape(reshape2_output.shape.vector)
263-
)
264-
reshape2.tmp_inputs = [reshape2_input]
265-
reshape2.tmp_outputs = [reshape2_output]
266-
267-
# Assign the new input and output of the Conv2D
268-
t_op.tmp_inputs = [reshape1_output, weights_tensor] + t_op.tmp_inputs[
269-
2:
270-
] # Add bias as well, if present
271-
t_op.tmp_outputs = [reshape2_input]
272-
273-
# Extend all Conv attributes to 2D
274-
common.extend_1d_stride_to_2d(conv_params.stride)
275-
common.extend_1d_dilation_to_2d(conv_params.dilation)
276-
common.extend_1d_padding_to_2d(conv_params.padding)
277-
278-
# Convert the now 2D Conv
279-
converted_conv_ops = self._convert_2d_conv(t_op, conv_params)
280-
281-
return pre_reshapes + converted_conv_ops + [reshape2]
282-
283189
# noinspection PyPep8Naming
284190
def _convert_unpadded_2D(
285191
self, t_op: tflite_model.Operator, conv_params: ConvParameters
@@ -523,9 +429,7 @@ def convert(self, node: Node):
523429
)
524430

525431
rank = t_op.tmp_inputs[1].shape.len()
526-
if rank == 3: # Conv1D
527-
ops_to_add = self._convert_1d_conv(t_op, conv_params)
528-
elif rank == 4: # Conv2D
432+
if rank == 4: # Conv2D
529433
ops_to_add = self._convert_2d_conv(t_op, conv_params)
530434
else:
531435
raise NotImplementedError(

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
BMMPattern,
2424
CatPattern,
2525
ClampPattern,
26-
Conv1dPattern,
2726
Conv2dPattern,
2827
ConvTranspose2dPattern,
2928
DropoutPattern,
@@ -266,9 +265,10 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False)
266265
OpQuantizer(BMMPattern(is_qat=is_qat), static_qconfig),
267266
OpQuantizer(CatPattern(is_qat=is_qat), static_qconfig),
268267
OpQuantizer(ClampPattern(is_qat=is_qat), static_qconfig),
269-
OpQuantizer(Conv1dPattern(is_qat=is_qat), static_qconfig),
270268
OpQuantizer(Conv2dPattern(self, is_qat=is_qat), static_qconfig),
271-
OpQuantizer(ConvTranspose2dPattern(is_qat=is_qat), static_qconfig),
269+
OpQuantizer(
270+
ConvTranspose2dPattern(self, is_qat=is_qat), static_qconfig
271+
),
272272
OpQuantizer(DropoutPattern(is_qat=is_qat), static_qconfig),
273273
OpQuantizer(FlattenPattern(is_qat=is_qat), static_qconfig),
274274
OpQuantizer(HardTanhPattern(is_qat=is_qat), static_qconfig),

backends/nxp/quantizer/patterns.py

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,14 @@
77

88
from abc import ABC, abstractmethod
99
from dataclasses import dataclass, field
10+
from functools import partial
1011

1112
import torch
1213

13-
from executorch.backends.nxp.quantizer.utils import get_bias_qparams
14+
from executorch.backends.nxp.quantizer.utils import (
15+
get_bias_qparams,
16+
get_bias_qparams_transp_conv,
17+
)
1418
from torch import fx
1519
from torch._ops import OpOverload
1620
from torch.fx import Node
@@ -482,16 +486,6 @@ def get_anchors(
482486
)
483487

484488

485-
class Conv1dPattern(ConvPattern):
486-
def partition_types(self) -> list[OpOverload]:
487-
return [torch.ops.aten.conv1d.default]
488-
489-
490-
class ConvTranspose1dPattern(ConvPattern):
491-
def partition_types(self) -> list[OpOverload]:
492-
return [torch.ops.aten.conv_transpose1d.default]
493-
494-
495489
class Conv2dPattern(ConvPattern):
496490
def __init__(self, neutron_quantizer, is_qat: bool = False):
497491
super().__init__(is_qat=is_qat)
@@ -587,6 +581,14 @@ def get_anchors(
587581

588582

589583
class ConvTranspose2dPattern(QuantizationPattern):
584+
def __init__(self, neutron_quantizer, is_qat: bool = False):
585+
super().__init__(is_qat=is_qat)
586+
587+
self.neutron_quantizer = neutron_quantizer
588+
self.neutron_target_info = (
589+
self.neutron_quantizer.neutron_target_spec.neutron_target_info
590+
)
591+
590592
def partition_types(self) -> list[OpOverload]:
591593
return [torch.ops.aten.conv_transpose2d.input]
592594

@@ -595,48 +597,81 @@ def get_anchors(
595597
) -> PartitionAnchors:
596598
conv_node = fused_partition[0].nodes[-1]
597599

600+
# When `groups` > 1, the per-channel weight qparams have shape (`out_channels` / `groups`),
601+
# but bias qparams have shape (`out_channels`) - not divided by `groups`.
602+
# So the weight qparams must be expanded to match the shape correctly.
603+
groups = 1 if len(conv_node.args) < 7 else conv_node.args[6]
604+
if groups > 1:
605+
out_channels = conv_node.meta["val"].shape[1]
606+
derive_qparams_fn = partial(
607+
get_bias_qparams_transp_conv, out_channels=out_channels
608+
)
609+
610+
else:
611+
derive_qparams_fn = get_bias_qparams
612+
598613
bias_quantization_qspec = DerivedQuantizationSpec(
599614
derived_from=[
600615
(conv_node.args[0], conv_node),
601616
(conv_node.args[1], conv_node),
602617
],
603-
derive_qparams_fn=get_bias_qparams,
618+
derive_qparams_fn=derive_qparams_fn,
604619
dtype=torch.int32,
605620
quant_min=-(2**31) + 1,
606621
quant_max=2**31 - 1,
607622
qscheme=torch.per_channel_symmetric,
608623
ch_axis=0,
609624
)
610625

611-
weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
626+
w_ch_axis = 1
627+
weight_observer_or_fake_quant_ctr = (
628+
FakeQuantize.with_args(
629+
observer=MovingAveragePerChannelMinMaxObserver, ch_axis=w_ch_axis
630+
)
631+
if self.is_qat
632+
else PerChannelMinMaxObserver.with_args(ch_axis=w_ch_axis)
633+
)
612634
weight_quantization_spec = QuantizationSpec(
613635
dtype=torch.int8,
614636
observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr,
615637
quant_min=-127,
616638
quant_max=127,
617639
qscheme=torch.per_channel_symmetric,
618-
ch_axis=1,
640+
ch_axis=w_ch_axis,
619641
)
620642

621643
# Keep bias empty if not supplied
622644
bias = []
623645
if len(conv_node.args) > 2 and conv_node.args[2] is not None:
624646
bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)]
625647

626-
output_specs = [(conv_node,)]
648+
# If the following node is a fusable activation, quantize together with activation
649+
output = [(conv_node,)]
650+
if len(
651+
conv_node.users
652+
) == 1 and self.neutron_target_info.is_supported_fused_activation__aten(
653+
activation := next(iter(conv_node.users))
654+
):
655+
activation_quantizer = self.neutron_quantizer.op_to_quantizer[
656+
activation.target
657+
]
658+
activation_quantizer.annotate(gm)
659+
output = []
660+
activation.meta["quantization_annotation"].input_qspec_map = {}
661+
627662
# In order for QAT to be numerically correct, there should be no quantization between
628663
# convolution node and batch norm node.
629664
if self.is_qat:
630665
conv_users = conv_node.users
631666
possibly_bn = list(conv_users.keys())[0] if len(conv_users) == 1 else None
632667
if possibly_bn and _is_batch_norm(possibly_bn):
633-
output_specs = []
668+
output = []
634669

635670
return PartitionAnchors(
636671
inputs=[(conv_node, NodeArgsIdx(0))],
637672
weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
638673
biases=bias,
639-
output=output_specs,
674+
output=output,
640675
)
641676

642677

0 commit comments

Comments
 (0)