Skip to content

Commit 73e97a6

Browse files
committed
feat: added aten pass to convert conv1d to conv2d
1 parent 063f9c9 commit 73e97a6

13 files changed

Lines changed: 916 additions & 392 deletions

File tree

backends/nxp/aten_passes/convert_1d_conv_to_2d.py

Lines changed: 396 additions & 0 deletions
Large diffs are not rendered by default.

backends/nxp/aten_passes/neutron_aten_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -7,6 +7,9 @@
77

88
import torch
99

10+
from executorch.backends.nxp.aten_passes.convert_1d_conv_to_2d import (
11+
ConvertConv1dToConv2dPass,
12+
)
1013
from executorch.backends.nxp.aten_passes.convert_div_to_mul import ConvertDivToMulPass
1114
from executorch.backends.nxp.aten_passes.decompose_split_to_slices_pass import (
1215
DecomposeSplitToSlicesPass,
@@ -49,6 +52,7 @@ def _get_default_passes(neutron_target_spec, qat_mode: bool = False) -> list[Pas
4952
FuseLinearAndAddPass(),
5053
MoveActivationBeforeConcat(neutron_target_spec),
5154
ConvertDivToMulPass(),
55+
ConvertConv1dToConv2dPass(),
5256
]
5357

5458
if not qat_mode:

backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py

Lines changed: 4 additions & 100 deletions
Original file line number · Diff line number · Diff line change
@@ -15,7 +15,6 @@
1515
from executorch.backends.nxp.backend.ir.converter.conversion import (
1616
aten_translator,
1717
common,
18-
translator,
1918
)
2019
from executorch.backends.nxp.backend.ir.converter.conversion.common import try_get_input
2120
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
@@ -42,7 +41,6 @@
4241
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
4342
conv_2d_options,
4443
depthwise_conv_2d_options,
45-
reshape_options,
4644
transpose_conv_options,
4745
)
4846
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
@@ -70,8 +68,9 @@ def _is_supported_on_target(
7068
return False
7169

7270
if conv_params.transposed:
73-
# TransposeConv1d is not supported on Neutron
74-
if len(conv_params.dilation) == 1:
71+
# TransposeConv2d with groups > 1 is not supported
72+
# TODO: split into multiple convs with groups = 1
73+
if conv_params.groups > 1:
7574
return False
7675
if not node_is_effectively_static_tensor(weights, parameters_mapping):
7776
# Only supported if the weights are static, because TFLite `TransposeConv` uses permuted
@@ -187,99 +186,6 @@ def _get_convolution_arguments(
187186
groups,
188187
)
189188

190-
def _convert_1d_conv(
191-
self, t_op: tflite_model.Operator, conv_params: ConvParameters
192-
) -> list[tflite_model.Operator]:
193-
"""Convert the 'Conv' operator with a 1D kernel to TFLite 'Conv2D'.
194-
TFLite doesn't support 1D convolution, but this behaviour can be represented using
195-
Reshape -> Conv2D -> Reshape.
196-
The first reshape introduces a 4th dimension with size 1. The second Reshape removes the temporary dimension.
197-
"""
198-
# -- Calculate the shapes for equivalent 2D convolution --
199-
conv_2d_input_shape = translator.nhc_dimensions_to_nhwc(
200-
t_op.tmp_inputs[0].shape.vector
201-
)
202-
conv_2d_weight_shape = translator.nhc_dimensions_to_nhwc(
203-
t_op.tmp_inputs[1].shape.vector
204-
)
205-
conv_2d_output_shape = translator.nhc_dimensions_to_nhwc(
206-
t_op.tmp_outputs[0].shape.vector
207-
)
208-
209-
# -- Generate tensors taking part in the conversion --
210-
reshape1_input = t_op.tmp_inputs[0]
211-
212-
reshape1_output = self.builder.duplicate_tensor(
213-
reshape1_input, name_suffix="_4D_"
214-
)
215-
reshape1_output.shape = tflite_model.Shape(conv_2d_input_shape)
216-
217-
reshape2_input = self.builder.duplicate_tensor(
218-
t_op.tmp_outputs[0], name_suffix="_4D_"
219-
)
220-
reshape2_input.shape = tflite_model.Shape(conv_2d_output_shape)
221-
222-
reshape2_output = t_op.tmp_outputs[0]
223-
224-
pre_reshapes = []
225-
226-
# Extend the weights tensor to 4D
227-
weights_tensor = t_op.tmp_inputs[1]
228-
if tensor_has_data(weights_tensor):
229-
# Do it statically
230-
weights_tensor.shape = tflite_model.Shape(conv_2d_weight_shape)
231-
weights_tensor.tmp_buffer.data = weights_tensor.tmp_buffer.data.reshape(
232-
conv_2d_weight_shape
233-
)
234-
235-
else:
236-
# Add a Reshape before the weights tensor
237-
new_weights_tensor = self.builder.duplicate_tensor(
238-
weights_tensor, name_suffix="_4D_"
239-
)
240-
new_weights_tensor.shape = tflite_model.Shape(conv_2d_weight_shape)
241-
242-
weight_reshape = tflite_model.Operator(
243-
builtin_options=reshape_options.Reshape(conv_2d_weight_shape)
244-
)
245-
weight_reshape.tmp_inputs = [weights_tensor]
246-
weight_reshape.tmp_outputs = [new_weights_tensor]
247-
248-
pre_reshapes.append(weight_reshape)
249-
250-
# Save the new weights tensor, to assign it later.
251-
weights_tensor = new_weights_tensor
252-
253-
# -- Create the new operators --
254-
reshape1 = tflite_model.Operator(
255-
builtin_options=reshape_options.Reshape(conv_2d_input_shape)
256-
)
257-
reshape1.tmp_inputs = [reshape1_input]
258-
reshape1.tmp_outputs = [reshape1_output]
259-
pre_reshapes.append(reshape1)
260-
261-
reshape2 = tflite_model.Operator(
262-
builtin_options=reshape_options.Reshape(reshape2_output.shape.vector)
263-
)
264-
reshape2.tmp_inputs = [reshape2_input]
265-
reshape2.tmp_outputs = [reshape2_output]
266-
267-
# Assign the new input and output of the Conv2D
268-
t_op.tmp_inputs = [reshape1_output, weights_tensor] + t_op.tmp_inputs[
269-
2:
270-
] # Add bias as well, if present
271-
t_op.tmp_outputs = [reshape2_input]
272-
273-
# Extend all Conv attributes to 2D
274-
common.extend_1d_stride_to_2d(conv_params.stride)
275-
common.extend_1d_dilation_to_2d(conv_params.dilation)
276-
common.extend_1d_padding_to_2d(conv_params.padding)
277-
278-
# Convert the now 2D Conv
279-
converted_conv_ops = self._convert_2d_conv(t_op, conv_params)
280-
281-
return pre_reshapes + converted_conv_ops + [reshape2]
282-
283189
# noinspection PyPep8Naming
284190
def _convert_unpadded_2D(
285191
self, t_op: tflite_model.Operator, conv_params: ConvParameters
@@ -523,9 +429,7 @@ def convert(self, node: Node):
523429
)
524430

525431
rank = t_op.tmp_inputs[1].shape.len()
526-
if rank == 3: # Conv1D
527-
ops_to_add = self._convert_1d_conv(t_op, conv_params)
528-
elif rank == 4: # Conv2D
432+
if rank == 4: # Conv2D
529433
ops_to_add = self._convert_2d_conv(t_op, conv_params)
530434
else:
531435
raise NotImplementedError(

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 3 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -23,7 +23,6 @@
2323
BMMPattern,
2424
CatPattern,
2525
ClampPattern,
26-
Conv1dPattern,
2726
Conv2dPattern,
2827
ConvTranspose2dPattern,
2928
DropoutPattern,
@@ -266,9 +265,10 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False)
266265
OpQuantizer(BMMPattern(is_qat=is_qat), static_qconfig),
267266
OpQuantizer(CatPattern(is_qat=is_qat), static_qconfig),
268267
OpQuantizer(ClampPattern(is_qat=is_qat), static_qconfig),
269-
OpQuantizer(Conv1dPattern(is_qat=is_qat), static_qconfig),
270268
OpQuantizer(Conv2dPattern(self, is_qat=is_qat), static_qconfig),
271-
OpQuantizer(ConvTranspose2dPattern(is_qat=is_qat), static_qconfig),
269+
OpQuantizer(
270+
ConvTranspose2dPattern(self, is_qat=is_qat), static_qconfig
271+
),
272272
OpQuantizer(DropoutPattern(is_qat=is_qat), static_qconfig),
273273
OpQuantizer(FlattenPattern(is_qat=is_qat), static_qconfig),
274274
OpQuantizer(HardTanhPattern(is_qat=is_qat), static_qconfig),

backends/nxp/quantizer/patterns.py

Lines changed: 52 additions & 17 deletions
Original file line number · Diff line number · Diff line change
@@ -7,10 +7,14 @@
77

88
from abc import ABC, abstractmethod
99
from dataclasses import dataclass, field
10+
from functools import partial
1011

1112
import torch
1213

13-
from executorch.backends.nxp.quantizer.utils import get_bias_qparams
14+
from executorch.backends.nxp.quantizer.utils import (
15+
get_bias_qparams,
16+
get_padded_bias_qparams,
17+
)
1418
from torch import fx
1519
from torch._ops import OpOverload
1620
from torch.fx import Node
@@ -482,16 +486,6 @@ def get_anchors(
482486
)
483487

484488

485-
class Conv1dPattern(ConvPattern):
486-
def partition_types(self) -> list[OpOverload]:
487-
return [torch.ops.aten.conv1d.default]
488-
489-
490-
class ConvTranspose1dPattern(ConvPattern):
491-
def partition_types(self) -> list[OpOverload]:
492-
return [torch.ops.aten.conv_transpose1d.default]
493-
494-
495489
class Conv2dPattern(ConvPattern):
496490
def __init__(self, neutron_quantizer, is_qat: bool = False):
497491
super().__init__(is_qat=is_qat)
@@ -572,6 +566,14 @@ def get_anchors(
572566

573567

574568
class ConvTranspose2dPattern(QuantizationPattern):
569+
def __init__(self, neutron_quantizer, is_qat: bool = False):
570+
super().__init__(is_qat=is_qat)
571+
572+
self.neutron_quantizer = neutron_quantizer
573+
self.neutron_target_info = (
574+
self.neutron_quantizer.neutron_target_spec.neutron_target_info
575+
)
576+
575577
def partition_types(self) -> list[OpOverload]:
576578
return [torch.ops.aten.conv_transpose2d.input]
577579

@@ -580,48 +582,81 @@ def get_anchors(
580582
) -> PartitionAnchors:
581583
conv_node = fused_partition[0].nodes[-1]
582584

585+
# When `groups` > 1, the per-channel weight qparams have shape (`out_channels` / `groups`),
586+
# but bias qparams have shape (`out_channels`) - not divided by `groups`.
587+
# So the weight qparams must be expanded to match the shape correctly.
588+
groups = 1 if len(conv_node.args) < 7 else conv_node.args[6]
589+
if groups > 1:
590+
out_channels = conv_node.meta["val"].shape[1]
591+
derive_qparams_fn = partial(
592+
get_padded_bias_qparams, out_channels=out_channels
593+
)
594+
595+
else:
596+
derive_qparams_fn = get_bias_qparams
597+
583598
bias_quantization_qspec = DerivedQuantizationSpec(
584599
derived_from=[
585600
(conv_node.args[0], conv_node),
586601
(conv_node.args[1], conv_node),
587602
],
588-
derive_qparams_fn=get_bias_qparams,
603+
derive_qparams_fn=derive_qparams_fn,
589604
dtype=torch.int32,
590605
quant_min=-(2**31) + 1,
591606
quant_max=2**31 - 1,
592607
qscheme=torch.per_channel_symmetric,
593608
ch_axis=0,
594609
)
595610

596-
weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
611+
w_ch_axis = 1
612+
weight_observer_or_fake_quant_ctr = (
613+
FakeQuantize.with_args(
614+
observer=MovingAveragePerChannelMinMaxObserver, ch_axis=w_ch_axis
615+
)
616+
if self.is_qat
617+
else PerChannelMinMaxObserver.with_args(ch_axis=w_ch_axis)
618+
)
597619
weight_quantization_spec = QuantizationSpec(
598620
dtype=torch.int8,
599621
observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr,
600622
quant_min=-127,
601623
quant_max=127,
602624
qscheme=torch.per_channel_symmetric,
603-
ch_axis=1,
625+
ch_axis=w_ch_axis,
604626
)
605627

606628
# Keep bias empty if not supplied
607629
bias = []
608630
if len(conv_node.args) > 2 and conv_node.args[2] is not None:
609631
bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)]
610632

611-
output_specs = [(conv_node,)]
633+
# If the following node is a fusable activation, quantize together with activation
634+
output = [(conv_node,)]
635+
if len(
636+
conv_node.users
637+
) == 1 and self.neutron_target_info.is_supported_fused_activation__aten(
638+
activation := next(iter(conv_node.users))
639+
):
640+
activation_quantizer = self.neutron_quantizer.op_to_quantizer[
641+
activation.target
642+
]
643+
activation_quantizer.annotate(gm)
644+
output = []
645+
activation.meta["quantization_annotation"].input_qspec_map = {}
646+
612647
# In order for QAT to be numerically correct, there should be no quantization between
613648
# convolution node and batch norm node.
614649
if self.is_qat:
615650
conv_users = conv_node.users
616651
possibly_bn = list(conv_users.keys())[0] if len(conv_users) == 1 else None
617652
if possibly_bn and _is_batch_norm(possibly_bn):
618-
output_specs = []
653+
output = []
619654

620655
return PartitionAnchors(
621656
inputs=[(conv_node, NodeArgsIdx(0))],
622657
weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
623658
biases=bias,
624-
output=output_specs,
659+
output=output,
625660
)
626661

627662

backends/nxp/quantizer/utils.py

Lines changed: 26 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -73,6 +73,32 @@ def get_bias_qparams(
7373
return bias_scale, bias_zero_point
7474

7575

76+
def get_padded_bias_qparams(
77+
obs_or_fqs: List[ObserverOrFakeQuantize],
78+
out_channels: int | None = None,
79+
) -> Tuple[torch.Tensor, torch.Tensor]:
80+
act_scale, _ = obs_or_fqs[0].calculate_qparams()
81+
weight_scale, _ = obs_or_fqs[1].calculate_qparams()
82+
83+
# It may happen that `torch.ao` incorrectly sets the weight qparams, not matching bias qparams.
84+
# If `out_channels` is given, ensure bias qparams are per-output-channel:
85+
# So for example w = [w1, w2, w3] -> [w1, w2, w3, w1, w2, w3, ...]
86+
if out_channels is not None:
87+
weight_scale = weight_scale.flatten()
88+
if weight_scale.numel() != out_channels:
89+
if out_channels % weight_scale.numel() != 0:
90+
raise RuntimeError(
91+
"Weight qparams cannot be repeated if not divisible by `out_channels`."
92+
)
93+
weight_scale = weight_scale.repeat(out_channels // weight_scale.numel())
94+
95+
act_scale = act_scale.flatten()[0]
96+
97+
bias_scale = act_scale * weight_scale
98+
bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int64)
99+
return bias_scale, bias_zero_point
100+
101+
76102
def get_aten_node_target_partitions(
77103
graph: torch.fx.Graph,
78104
wanted_original_aten_op: List[OpOverload],

0 commit comments

Comments (0)