396 changes: 396 additions & 0 deletions backends/nxp/aten_passes/convert_1d_conv_to_2d.py

Large diffs are not rendered by default.
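The 396-line pass itself is not rendered here, but the core rewrite is mechanical. A minimal sketch of an aten-level Conv1d-to-Conv2d pass (hypothetical class name and simplified argument handling; the real `ConvertConv1dToConv2dPass` is considerably more thorough):

```python
import torch
from torch.fx.passes.infra.pass_base import PassBase, PassResult


class ConvertConv1dToConv2dSketch(PassBase):
    """Hypothetical sketch: rewrite aten.conv1d as unsqueeze -> aten.conv2d -> squeeze."""

    def call(self, gm: torch.fx.GraphModule) -> PassResult:
        modified = False
        for node in list(gm.graph.nodes):
            if node.op != "call_function" or node.target != torch.ops.aten.conv1d.default:
                continue

            # conv1d(input, weight, bias=None, stride=[1], padding=[0], dilation=[1], groups=1)
            args = list(node.args)
            args += [None, [1], [0], [1], 1][len(args) - 2 :]
            x, w, b, stride, padding, dilation, groups = args

            with gm.graph.inserting_before(node):
                # Insert a singleton spatial dimension into the input and the weights,
                x4 = gm.graph.call_function(torch.ops.aten.unsqueeze.default, (x, 2))
                w4 = gm.graph.call_function(torch.ops.aten.unsqueeze.default, (w, 2))
                # extend the 1D attributes to 2D and run a regular conv2d,
                conv = gm.graph.call_function(
                    torch.ops.aten.conv2d.default,
                    (x4, w4, b, [1] + list(stride), [0] + list(padding),
                     [1] + list(dilation), groups),
                )
                # then drop the singleton dimension again.
                out = gm.graph.call_function(torch.ops.aten.squeeze.dim, (conv, 2))

            node.replace_all_uses_with(out)
            gm.graph.erase_node(node)
            modified = True

        if modified:
            gm.graph.lint()
            gm.recompile()
        return PassResult(gm, modified)
```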

4 changes: 4 additions & 0 deletions backends/nxp/aten_passes/neutron_aten_pass_manager.py
@@ -7,6 +7,9 @@

import torch

from executorch.backends.nxp.aten_passes.convert_1d_conv_to_2d import (
ConvertConv1dToConv2dPass,
)
from executorch.backends.nxp.aten_passes.convert_div_to_mul import ConvertDivToMulPass
from executorch.backends.nxp.aten_passes.decompose_split_to_slices_pass import (
DecomposeSplitToSlicesPass,
@@ -49,6 +52,7 @@ def _get_default_passes(neutron_target_spec, qat_mode: bool = False) -> list[Pas
FuseLinearAndAddPass(),
MoveActivationBeforeConcat(neutron_target_spec),
ConvertDivToMulPass(),
ConvertConv1dToConv2dPass(),
]

if not qat_mode:
@@ -15,7 +15,6 @@
from executorch.backends.nxp.backend.ir.converter.conversion import (
aten_translator,
common,
translator,
)
from executorch.backends.nxp.backend.ir.converter.conversion.common import try_get_input
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
@@ -42,7 +41,6 @@
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
conv_2d_options,
depthwise_conv_2d_options,
reshape_options,
transpose_conv_options,
)
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
@@ -70,8 +68,9 @@ def _is_supported_on_target(
return False

if conv_params.transposed:
# TransposeConv1d is not supported on Neutron
if len(conv_params.dilation) == 1:
# TransposeConv2d with groups > 1 is not supported
# TODO: split into multiple convs with groups = 1
if conv_params.groups > 1:
return False
if not node_is_effectively_static_tensor(weights, parameters_mapping):
# Only supported if the weights are static, because TFLite `TransposeConv` uses permuted
@@ -187,99 +186,6 @@ def _get_convolution_arguments(
groups,
)

def _convert_1d_conv(
self, t_op: tflite_model.Operator, conv_params: ConvParameters
) -> list[tflite_model.Operator]:
"""Convert the 'Conv' operator with a 1D kernel to TFLite 'Conv2D'.
TFLite doesn't support 1D convolution, but this behaviour can be represented using
Reshape -> Conv2D -> Reshape.
The first reshape introduces a 4th dimension with size 1. The second Reshape removes the temporary dimension.
"""
# -- Calculate the shapes for equivalent 2D convolution --
conv_2d_input_shape = translator.nhc_dimensions_to_nhwc(
t_op.tmp_inputs[0].shape.vector
)
conv_2d_weight_shape = translator.nhc_dimensions_to_nhwc(
t_op.tmp_inputs[1].shape.vector
)
conv_2d_output_shape = translator.nhc_dimensions_to_nhwc(
t_op.tmp_outputs[0].shape.vector
)

# -- Generate tensors taking part in the conversion --
reshape1_input = t_op.tmp_inputs[0]

reshape1_output = self.builder.duplicate_tensor(
reshape1_input, name_suffix="_4D_"
)
reshape1_output.shape = tflite_model.Shape(conv_2d_input_shape)

reshape2_input = self.builder.duplicate_tensor(
t_op.tmp_outputs[0], name_suffix="_4D_"
)
reshape2_input.shape = tflite_model.Shape(conv_2d_output_shape)

reshape2_output = t_op.tmp_outputs[0]

pre_reshapes = []

# Extend the weights tensor to 4D
weights_tensor = t_op.tmp_inputs[1]
if tensor_has_data(weights_tensor):
# Do it statically
weights_tensor.shape = tflite_model.Shape(conv_2d_weight_shape)
weights_tensor.tmp_buffer.data = weights_tensor.tmp_buffer.data.reshape(
conv_2d_weight_shape
)

else:
# Add a Reshape before the weights tensor
new_weights_tensor = self.builder.duplicate_tensor(
weights_tensor, name_suffix="_4D_"
)
new_weights_tensor.shape = tflite_model.Shape(conv_2d_weight_shape)

weight_reshape = tflite_model.Operator(
builtin_options=reshape_options.Reshape(conv_2d_weight_shape)
)
weight_reshape.tmp_inputs = [weights_tensor]
weight_reshape.tmp_outputs = [new_weights_tensor]

pre_reshapes.append(weight_reshape)

# Save the new weights tensor, to assign it later.
weights_tensor = new_weights_tensor

# -- Create the new operators --
reshape1 = tflite_model.Operator(
builtin_options=reshape_options.Reshape(conv_2d_input_shape)
)
reshape1.tmp_inputs = [reshape1_input]
reshape1.tmp_outputs = [reshape1_output]
pre_reshapes.append(reshape1)

reshape2 = tflite_model.Operator(
builtin_options=reshape_options.Reshape(reshape2_output.shape.vector)
)
reshape2.tmp_inputs = [reshape2_input]
reshape2.tmp_outputs = [reshape2_output]

# Assign the new input and output of the Conv2D
t_op.tmp_inputs = [reshape1_output, weights_tensor] + t_op.tmp_inputs[
2:
] # Add bias as well, if present
t_op.tmp_outputs = [reshape2_input]

# Extend all Conv attributes to 2D
common.extend_1d_stride_to_2d(conv_params.stride)
common.extend_1d_dilation_to_2d(conv_params.dilation)
common.extend_1d_padding_to_2d(conv_params.padding)

# Convert the now 2D Conv
converted_conv_ops = self._convert_2d_conv(t_op, conv_params)

return pre_reshapes + converted_conv_ops + [reshape2]

# noinspection PyPep8Naming
def _convert_unpadded_2D(
self, t_op: tflite_model.Operator, conv_params: ConvParameters
@@ -523,9 +429,7 @@ def convert(self, node: Node):
)

rank = t_op.tmp_inputs[1].shape.len()
if rank == 3: # Conv1D
ops_to_add = self._convert_1d_conv(t_op, conv_params)
elif rank == 4: # Conv2D
if rank == 4: # Conv2D
ops_to_add = self._convert_2d_conv(t_op, conv_params)
else:
raise NotImplementedError(
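With `_convert_1d_conv` gone, the Reshape -> Conv2D -> Reshape trick it implemented now happens once, at the aten level. The underlying identity is easy to check in plain PyTorch (a standalone verification with arbitrary shapes, independent of this backend):

```python
import torch

x = torch.randn(1, 3, 16)  # (N, C_in, L)
w = torch.randn(8, 3, 5)   # (C_out, C_in, K)

y1 = torch.nn.functional.conv1d(x, w, padding=2)

# Same result as a 2D conv with a (1, K) kernel over a (N, C_in, 1, L) input.
y2 = torch.nn.functional.conv2d(
    x.unsqueeze(2), w.unsqueeze(2), padding=(0, 2)
).squeeze(2)

assert torch.allclose(y1, y2, atol=1e-5)
```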
6 changes: 3 additions & 3 deletions backends/nxp/quantizer/neutron_quantizer.py
@@ -23,7 +23,6 @@
BMMPattern,
CatPattern,
ClampPattern,
Conv1dPattern,
Conv2dPattern,
ConvTranspose2dPattern,
DropoutPattern,
@@ -266,9 +265,10 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False)
OpQuantizer(BMMPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(CatPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(ClampPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(Conv1dPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(Conv2dPattern(self, is_qat=is_qat), static_qconfig),
OpQuantizer(ConvTranspose2dPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(
ConvTranspose2dPattern(self, is_qat=is_qat), static_qconfig
),
OpQuantizer(DropoutPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(FlattenPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(HardTanhPattern(is_qat=is_qat), static_qconfig),
69 changes: 52 additions & 17 deletions backends/nxp/quantizer/patterns.py
@@ -7,10 +7,14 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import partial

import torch

from executorch.backends.nxp.quantizer.utils import get_bias_qparams
from executorch.backends.nxp.quantizer.utils import (
get_bias_qparams,
get_padded_bias_qparams,
)
from torch import fx
from torch._ops import OpOverload
from torch.fx import Node
@@ -482,16 +486,6 @@ def get_anchors(
)


class Conv1dPattern(ConvPattern):
def partition_types(self) -> list[OpOverload]:
return [torch.ops.aten.conv1d.default]


class ConvTranspose1dPattern(ConvPattern):
def partition_types(self) -> list[OpOverload]:
return [torch.ops.aten.conv_transpose1d.default]


class Conv2dPattern(ConvPattern):
def __init__(self, neutron_quantizer, is_qat: bool = False):
super().__init__(is_qat=is_qat)
@@ -572,6 +566,14 @@ def get_anchors(


class ConvTranspose2dPattern(QuantizationPattern):
def __init__(self, neutron_quantizer, is_qat: bool = False):
super().__init__(is_qat=is_qat)

self.neutron_quantizer = neutron_quantizer
self.neutron_target_info = (
self.neutron_quantizer.neutron_target_spec.neutron_target_info
)

def partition_types(self) -> list[OpOverload]:
return [torch.ops.aten.conv_transpose2d.input]

@@ -580,48 +582,81 @@
) -> PartitionAnchors:
conv_node = fused_partition[0].nodes[-1]

# When `groups` > 1, the per-channel weight qparams have shape (`out_channels` / `groups`),
# but the bias qparams have shape (`out_channels`) - not divided by `groups`.
# The weight qparams must therefore be expanded to match the bias shape.
groups = 1 if len(conv_node.args) < 7 else conv_node.args[6]
if groups > 1:
out_channels = conv_node.meta["val"].shape[1]
derive_qparams_fn = partial(
get_padded_bias_qparams, out_channels=out_channels
)

else:
derive_qparams_fn = get_bias_qparams

bias_quantization_qspec = DerivedQuantizationSpec(
derived_from=[
(conv_node.args[0], conv_node),
(conv_node.args[1], conv_node),
],
derive_qparams_fn=get_bias_qparams,
derive_qparams_fn=derive_qparams_fn,
dtype=torch.int32,
quant_min=-(2**31) + 1,
quant_max=2**31 - 1,
qscheme=torch.per_channel_symmetric,
ch_axis=0,
)

weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
w_ch_axis = 1
weight_observer_or_fake_quant_ctr = (
FakeQuantize.with_args(
observer=MovingAveragePerChannelMinMaxObserver, ch_axis=w_ch_axis
)
if self.is_qat
else PerChannelMinMaxObserver.with_args(ch_axis=w_ch_axis)
)
weight_quantization_spec = QuantizationSpec(
dtype=torch.int8,
observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr,
quant_min=-127,
quant_max=127,
qscheme=torch.per_channel_symmetric,
ch_axis=1,
ch_axis=w_ch_axis,
)

# Keep bias empty if not supplied
bias = []
if len(conv_node.args) > 2 and conv_node.args[2] is not None:
bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)]

output_specs = [(conv_node,)]
# If the following node is a fusable activation, quantize it together with the activation
output = [(conv_node,)]
if len(
conv_node.users
) == 1 and self.neutron_target_info.is_supported_fused_activation__aten(
activation := next(iter(conv_node.users))
):
activation_quantizer = self.neutron_quantizer.op_to_quantizer[
activation.target
]
activation_quantizer.annotate(gm)
output = []
activation.meta["quantization_annotation"].input_qspec_map = {}

# In order for QAT to be numerically correct, there should be no quantization between
# the convolution node and the batch norm node.
if self.is_qat:
conv_users = conv_node.users
possibly_bn = list(conv_users.keys())[0] if len(conv_users) == 1 else None
if possibly_bn and _is_batch_norm(possibly_bn):
output_specs = []
output = []

return PartitionAnchors(
inputs=[(conv_node, NodeArgsIdx(0))],
weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
biases=bias,
output=output_specs,
output=output,
)


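For context, the shape facts the updated `ConvTranspose2dPattern` relies on are standard PyTorch behavior, shown here with arbitrary sizes: transposed-conv weights are laid out as `(in_channels, out_channels // groups, kH, kW)`, which is why the weights are quantized per-channel on axis 1 and why a grouped conv yields fewer weight scales than bias channels:

```python
import torch

deconv = torch.nn.ConvTranspose2d(in_channels=4, out_channels=6, kernel_size=3, groups=2)

# Axis 1 of the weight holds out_channels // groups = 3 channels...
print(deconv.weight.shape)  # torch.Size([4, 3, 3, 3])
# ...while the bias is per output channel, hence the 3-vs-6 qparam mismatch
# that get_padded_bias_qparams compensates for.
print(deconv.bias.shape)    # torch.Size([6])
```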
26 changes: 26 additions & 0 deletions backends/nxp/quantizer/utils.py
@@ -73,6 +73,32 @@ def get_bias_qparams(
return bias_scale, bias_zero_point


def get_padded_bias_qparams(
obs_or_fqs: List[ObserverOrFakeQuantize],
Collaborator: Nit: List -> list to minimize imports. Same for Tuple below.

out_channels: int | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
Collaborator: The function name doesn't explain what the function does, so a docstring noting that it is specifically designed for transposed convolution would be useful. Or perhaps the function name could mention the transposed conv?

act_scale, _ = obs_or_fqs[0].calculate_qparams()
weight_scale, _ = obs_or_fqs[1].calculate_qparams()

# It may happen that `torch.ao` incorrectly sets the weight qparams, not matching bias qparams.
Collaborator: Are the weight qparams really set incorrectly by torchao? After our discussions, I thought that that's how it's supposed to work, and just the function which derives the bias qparams was wrong.

# If `out_channels` is given, ensure bias qparams are per-output-channel:
# So for example w = [w1, w2, w3] -> [w1, w2, w3, w1, w2, w3, ...]
if out_channels is not None:
weight_scale = weight_scale.flatten()
if weight_scale.numel() != out_channels:
if out_channels % weight_scale.numel() != 0:
raise RuntimeError(
"Cannot repeat weight qparams: `out_channels` is not divisible by their count."
)
weight_scale = weight_scale.repeat(out_channels // weight_scale.numel())

act_scale = act_scale.flatten()[0]

bias_scale = act_scale * weight_scale
bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int64)
return bias_scale, bias_zero_point


def get_aten_node_target_partitions(
graph: torch.fx.Graph,
wanted_original_aten_op: List[OpOverload],
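To make the repetition concrete, a small numeric sketch (made-up scales for a groups=2, out_channels=6 transposed conv) of what `get_padded_bias_qparams` computes:

```python
import torch

out_channels = 6
weight_scale = torch.tensor([0.1, 0.2, 0.3])  # out_channels // groups = 3 per-channel scales
act_scale = torch.tensor(0.5)

# Tile the 3 weight scales to 6 so the bias qparams are per output channel.
bias_scale = act_scale * weight_scale.repeat(out_channels // weight_scale.numel())
print(bias_scale)  # tensor([0.0500, 0.1000, 0.1500, 0.0500, 0.1000, 0.1500])
```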