-
Notifications
You must be signed in to change notification settings - Fork 3.9k
[Relax][Frontend] Add TFLite Frontend Support for CONV_3D_TRANSPOSE #19530
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -133,6 +133,7 @@ def __init__(self, model, subgraph, exp_tab, ctx): | |||||
| "CONCATENATION": self.convert_concatenation, | ||||||
| "CONV_2D": functools.partial(self.convert_conv, conv_type="conv2d"), | ||||||
| "CONV_3D": self.convert_conv3d, | ||||||
| "CONV_3D_TRANSPOSE": self.convert_conv3d_transpose, | ||||||
| "COS": functools.partial(self._convert_unary_elemwise, relax_op=_op.cos), | ||||||
| "CUMSUM": self.convert_cumsum, | ||||||
| "DENSIFY": self.convert_densify, | ||||||
|
|
@@ -2586,6 +2587,153 @@ def convert_conv3d(self, op): | |||||
| out = self.convert_fused_activation_function(out, fused_activation_fn) | ||||||
| return out | ||||||
|
|
||||||
| def convert_conv3d_transpose(self, op): | ||||||
| """3D transposed convolution implementation.""" | ||||||
|
|
||||||
| from tflite.BuiltinOptions import BuiltinOptions | ||||||
| from tflite.Conv3DOptions import Conv3DOptions | ||||||
| from tflite.Padding import Padding | ||||||
| from tflite.TensorType import TensorType | ||||||
|
|
||||||
| input_tensors = self.get_input_tensors(op) | ||||||
| assert len(input_tensors) >= 3, "input tensors length should be >= 3" | ||||||
|
|
||||||
| # TFLite CONV_3D_TRANSPOSE input order: | ||||||
| # [0] output_shape, [1] weight, [2] data, [3] bias (optional) | ||||||
| weight_tensor = input_tensors[1] | ||||||
| input_tensor = input_tensors[2] | ||||||
| input_tensor_idx = input_tensor.tensor_idx | ||||||
|
|
||||||
| output_tensors = self.get_output_tensors(op) | ||||||
| assert len(output_tensors) == 1, "output tensors length should be 1" | ||||||
| output_tensor = output_tensors[0] | ||||||
|
|
||||||
| assert op.BuiltinOptionsType() == BuiltinOptions.Conv3DOptions | ||||||
| op_options = op.BuiltinOptions() | ||||||
| conv3d_options = Conv3DOptions() | ||||||
| conv3d_options.Init(op_options.Bytes, op_options.Pos) | ||||||
|
|
||||||
| stride_d = conv3d_options.StrideD() | ||||||
| stride_h = conv3d_options.StrideH() | ||||||
| stride_w = conv3d_options.StrideW() | ||||||
| dilation_d = conv3d_options.DilationDFactor() | ||||||
| dilation_h = conv3d_options.DilationHFactor() | ||||||
| dilation_w = conv3d_options.DilationWFactor() | ||||||
| padding = conv3d_options.Padding() | ||||||
| fused_activation_fn = conv3d_options.FusedActivationFunction() | ||||||
|
|
||||||
| _, input_d, input_h, input_w, input_c = to_int_list(self.get_tensor_shape(input_tensor)) | ||||||
|
|
||||||
| # TFLite Conv3DTranspose kernel layout is DHWOI: | ||||||
| # KD KH KW OC IC | ||||||
| kernel_d, kernel_h, kernel_w, output_channels, in_channels = to_int_list( | ||||||
| self.get_tensor_shape(weight_tensor) | ||||||
| ) | ||||||
|
|
||||||
| dilated_kernel_d = dilation_d * (kernel_d - 1) + 1 | ||||||
| dilated_kernel_h = dilation_h * (kernel_h - 1) + 1 | ||||||
| dilated_kernel_w = dilation_w * (kernel_w - 1) + 1 | ||||||
|
|
||||||
| params = { | ||||||
| "strides": [stride_d, stride_h, stride_w], | ||||||
| "dilation": [dilation_d, dilation_h, dilation_w], | ||||||
| "padding": [0, 0, 0, 0, 0, 0], | ||||||
| "output_padding": [0, 0, 0], | ||||||
| "data_layout": "NDHWC", | ||||||
| "kernel_layout": "DHWOI", | ||||||
| } | ||||||
|
|
||||||
| if input_c != in_channels: | ||||||
| assert input_c % in_channels == 0, ( | ||||||
| "Input channels is not divisible by kernel in_channels." | ||||||
| ) | ||||||
| params["groups"] = int(input_c / in_channels) | ||||||
|
|
||||||
| # weight tensor type should be INT8/UINT8 (quantization) or FLOAT32 | ||||||
| weight_tensor_type = weight_tensor.tensor.Type() | ||||||
| assert weight_tensor_type in ( | ||||||
| TensorType.INT8, | ||||||
| TensorType.UINT8, | ||||||
| TensorType.FLOAT32, | ||||||
| ) | ||||||
| weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type) | ||||||
|
|
||||||
| in_expr = self.get_expr(input_tensor_idx) | ||||||
|
|
||||||
| # TFLite Conv3DTranspose kernel is already in DHWOI layout, no transpose needed. | ||||||
| if self.has_expr(weight_tensor.tensor_idx): | ||||||
| weight_expr = self.get_expr(weight_tensor.tensor_idx) | ||||||
| else: | ||||||
| if self.is_prefetched(weight_tensor.tensor_idx): | ||||||
| weight_value = self.get_prefetched_node(weight_tensor.tensor_idx) | ||||||
| else: | ||||||
| weight_value = self.get_tensor_value(weight_tensor) | ||||||
|
|
||||||
| weight_expr = self.exp_tab.new_const( | ||||||
| weight_value, dtype=weight_tensor_type_str, | ||||||
| source_name=weight_tensor.tensor.Name() | ||||||
| ) | ||||||
|
|
||||||
| if padding == Padding.VALID: | ||||||
| pass | ||||||
| elif padding == Padding.SAME: | ||||||
| # For transposed convolution with SAME padding: | ||||||
| # target output_size = input_size * stride | ||||||
| # total_pad = max(0, dilated_kernel - stride) | ||||||
| for dim_kernel, dim_stride, label in [ | ||||||
| (dilated_kernel_d, stride_d, "D"), | ||||||
| (dilated_kernel_h, stride_h, "H"), | ||||||
| (dilated_kernel_w, stride_w, "W"), | ||||||
| ]: | ||||||
| total_pad = max(0, dim_kernel - dim_stride) | ||||||
| pad_before = total_pad // 2 | ||||||
| pad_after = total_pad - pad_before | ||||||
| idx = {"D": 0, "H": 1, "W": 2}[label] | ||||||
| params["padding"][idx] = pad_before | ||||||
| params["padding"][idx + 3] = pad_after | ||||||
|
|
||||||
| # output_padding handles the case when stride > dilated_kernel | ||||||
| output_pad = max(0, dim_stride - dim_kernel) | ||||||
| params["output_padding"][idx] = output_pad | ||||||
| else: | ||||||
| raise tvm.error.OpAttributeUnImplemented( | ||||||
| f"Padding format {padding} is not supported for operator Conv3DTranspose." | ||||||
| ) | ||||||
|
|
||||||
| if input_tensor.qnn_params: | ||||||
| raise tvm.error.OpNotImplemented( | ||||||
| "Quantized Conv3DTranspose is not yet supported in the Relax frontend." | ||||||
| ) | ||||||
|
|
||||||
| out = relax.op.nn.conv3d_transpose(in_expr, weight_expr, **params) | ||||||
|
|
||||||
| # if we have bias (input_tensors[3]) | ||||||
| if len(input_tensors) >= 4: | ||||||
| bias_tensor = input_tensors[3] | ||||||
| if bias_tensor.tensor_idx != -1: | ||||||
| bias_tensor_type = bias_tensor.tensor.Type() | ||||||
| # bias tensor type should be INT32 (int8 qnn) or INT64 (int16 qnn) or FLOAT32 | ||||||
| assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32) | ||||||
| bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type) | ||||||
| if self.has_expr(bias_tensor.tensor_idx): | ||||||
| bias_expr = self.get_expr(bias_tensor.tensor_idx) | ||||||
| else: | ||||||
| bias_expr = self.exp_tab.new_const( | ||||||
| self.get_tensor_value(bias_tensor), | ||||||
| dtype=bias_tensor_type_str, | ||||||
| source_name=bias_tensor.tensor.Name(), | ||||||
| ) | ||||||
| out = relax.op.add(out, bias_expr) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While
Suggested change
|
||||||
|
|
||||||
| # Handle fused activation. | ||||||
| if output_tensor.qnn_params: | ||||||
| raise tvm.error.OpNotImplemented( | ||||||
| "Quantized Conv3DTranspose is not yet supported in the Relax frontend." | ||||||
| ) | ||||||
|
|
||||||
| out = self.convert_fused_activation_function(out, fused_activation_fn) | ||||||
| return out | ||||||
|
|
||||||
| def convert_split(self, op): | ||||||
| """split implementation.""" | ||||||
|
|
||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
The variables `input_d`, `input_h`, and `input_w` are extracted from the input tensor shape but are not used in the subsequent logic, including the padding calculations. For clarity and to avoid unused-variable warnings, they can be replaced with underscores.