Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(self, model, subgraph, exp_tab, ctx):
"CONCATENATION": self.convert_concatenation,
"CONV_2D": functools.partial(self.convert_conv, conv_type="conv2d"),
"CONV_3D": self.convert_conv3d,
"CONV_3D_TRANSPOSE": self.convert_conv3d_transpose,
"COS": functools.partial(self._convert_unary_elemwise, relax_op=_op.cos),
"CUMSUM": self.convert_cumsum,
"DENSIFY": self.convert_densify,
Expand Down Expand Up @@ -2586,6 +2587,153 @@ def convert_conv3d(self, op):
out = self.convert_fused_activation_function(out, fused_activation_fn)
return out

def convert_conv3d_transpose(self, op):
    """Convert a TFLite CONV_3D_TRANSPOSE operator to Relax.

    TFLite orders the operator inputs as
    ``[0] output_shape, [1] weight, [2] data, [3] bias (optional)``.
    The data layout is NDHWC and the kernel layout is DHWOI
    (KD, KH, KW, out_channels, in_channels), which Relax's
    ``conv3d_transpose`` accepts directly, so no kernel transpose is needed.

    Parameters
    ----------
    op : tflite.Operator
        The TFLite CONV_3D_TRANSPOSE operator to convert.

    Returns
    -------
    out : relax.Expr
        The Relax expression computing the transposed 3D convolution
        (with bias and fused activation applied when present).

    Raises
    ------
    tvm.error.OpAttributeUnImplemented
        If the padding mode is neither VALID nor SAME.
    tvm.error.OpNotImplemented
        If the operator is quantized (not yet supported).
    """
    from tflite.BuiltinOptions import BuiltinOptions
    from tflite.Conv3DOptions import Conv3DOptions
    from tflite.Padding import Padding
    from tflite.TensorType import TensorType

    input_tensors = self.get_input_tensors(op)
    assert len(input_tensors) >= 3, "input tensors length should be >= 3"

    # TFLite CONV_3D_TRANSPOSE input order:
    # [0] output_shape, [1] weight, [2] data, [3] bias (optional)
    weight_tensor = input_tensors[1]
    input_tensor = input_tensors[2]
    input_tensor_idx = input_tensor.tensor_idx

    output_tensors = self.get_output_tensors(op)
    assert len(output_tensors) == 1, "output tensors length should be 1"
    output_tensor = output_tensors[0]

    assert op.BuiltinOptionsType() == BuiltinOptions.Conv3DOptions
    op_options = op.BuiltinOptions()
    conv3d_options = Conv3DOptions()
    conv3d_options.Init(op_options.Bytes, op_options.Pos)

    stride_d = conv3d_options.StrideD()
    stride_h = conv3d_options.StrideH()
    stride_w = conv3d_options.StrideW()
    dilation_d = conv3d_options.DilationDFactor()
    dilation_h = conv3d_options.DilationHFactor()
    dilation_w = conv3d_options.DilationWFactor()
    padding = conv3d_options.Padding()
    fused_activation_fn = conv3d_options.FusedActivationFunction()

    # Only the input channel count is needed below; the spatial dims do not
    # enter the SAME-padding computation for a transposed convolution.
    _, _, _, _, input_c = to_int_list(self.get_tensor_shape(input_tensor))

    # TFLite Conv3DTranspose kernel layout is DHWOI:
    # KD KH KW OC IC
    kernel_d, kernel_h, kernel_w, output_channels, in_channels = to_int_list(
        self.get_tensor_shape(weight_tensor)
    )

    # Effective kernel extent after dilation.
    dilated_kernel_d = dilation_d * (kernel_d - 1) + 1
    dilated_kernel_h = dilation_h * (kernel_h - 1) + 1
    dilated_kernel_w = dilation_w * (kernel_w - 1) + 1

    params = {
        "strides": [stride_d, stride_h, stride_w],
        "dilation": [dilation_d, dilation_h, dilation_w],
        "padding": [0, 0, 0, 0, 0, 0],
        "output_padding": [0, 0, 0],
        "data_layout": "NDHWC",
        "kernel_layout": "DHWOI",
    }

    if input_c != in_channels:
        assert input_c % in_channels == 0, (
            "Input channels is not divisible by kernel in_channels."
        )
        params["groups"] = int(input_c / in_channels)

    # weight tensor type should be INT8/UINT8 (quantization) or FLOAT32
    weight_tensor_type = weight_tensor.tensor.Type()
    assert weight_tensor_type in (
        TensorType.INT8,
        TensorType.UINT8,
        TensorType.FLOAT32,
    )
    weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)

    in_expr = self.get_expr(input_tensor_idx)

    # TFLite Conv3DTranspose kernel is already in DHWOI layout, no transpose needed.
    if self.has_expr(weight_tensor.tensor_idx):
        weight_expr = self.get_expr(weight_tensor.tensor_idx)
    else:
        if self.is_prefetched(weight_tensor.tensor_idx):
            weight_value = self.get_prefetched_node(weight_tensor.tensor_idx)
        else:
            weight_value = self.get_tensor_value(weight_tensor)

        weight_expr = self.exp_tab.new_const(
            weight_value, dtype=weight_tensor_type_str,
            source_name=weight_tensor.tensor.Name()
        )

    if padding == Padding.VALID:
        pass
    elif padding == Padding.SAME:
        # For transposed convolution with SAME padding:
        # target output_size = input_size * stride
        # total_pad = max(0, dilated_kernel - stride)
        for dim_kernel, dim_stride, label in [
            (dilated_kernel_d, stride_d, "D"),
            (dilated_kernel_h, stride_h, "H"),
            (dilated_kernel_w, stride_w, "W"),
        ]:
            total_pad = max(0, dim_kernel - dim_stride)
            pad_before = total_pad // 2
            pad_after = total_pad - pad_before
            idx = {"D": 0, "H": 1, "W": 2}[label]
            params["padding"][idx] = pad_before
            params["padding"][idx + 3] = pad_after

            # output_padding handles the case when stride > dilated_kernel
            output_pad = max(0, dim_stride - dim_kernel)
            params["output_padding"][idx] = output_pad
    else:
        raise tvm.error.OpAttributeUnImplemented(
            f"Padding format {padding} is not supported for operator Conv3DTranspose."
        )

    if input_tensor.qnn_params:
        raise tvm.error.OpNotImplemented(
            "Quantized Conv3DTranspose is not yet supported in the Relax frontend."
        )

    out = relax.op.nn.conv3d_transpose(in_expr, weight_expr, **params)

    # if we have bias (input_tensors[3])
    if len(input_tensors) >= 4:
        bias_tensor = input_tensors[3]
        if bias_tensor.tensor_idx != -1:
            bias_tensor_type = bias_tensor.tensor.Type()
            # bias tensor type should be INT32 (int8 qnn) or INT64 (int16 qnn) or FLOAT32
            assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32)
            bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
            if self.has_expr(bias_tensor.tensor_idx):
                bias_expr = self.get_expr(bias_tensor.tensor_idx)
            else:
                bias_expr = self.exp_tab.new_const(
                    self.get_tensor_value(bias_tensor),
                    dtype=bias_tensor_type_str,
                    source_name=bias_tensor.tensor.Name(),
                )
            # Use bias_add to target the channel axis explicitly (axis 4 in
            # NDHWC), consistent with convert_transpose_conv (2D) in this file.
            out = relax.op.nn.bias_add(out, bias_expr, axis=4)

    # Handle fused activation.
    if output_tensor.qnn_params:
        raise tvm.error.OpNotImplemented(
            "Quantized Conv3DTranspose is not yet supported in the Relax frontend."
        )

    out = self.convert_fused_activation_function(out, fused_activation_fn)
    return out

def convert_split(self, op):
"""split implementation."""

Expand Down
103 changes: 103 additions & 0 deletions tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1694,6 +1694,109 @@ def main(
verify(Conv3DModule, Expected)


def _make_conv3d_transpose_module(data_shape, kernel_shape, strides, padding):
    """Create a tf.Module class wrapping tf.nn.conv3d_transpose.

    data_shape is (N, D, H, W, C_in), kernel_shape is (KD, KH, KW, C_out, C_in)
    and strides is (1, sD, sH, sW, 1).  The expected output shape required by
    tf.nn.conv3d_transpose is computed per spatial dimension as
    (in - 1) * stride + kernel for "VALID" padding, or in * stride for "SAME".
    """
    spatial_sizes = []
    for axis in range(3):  # D, H, W
        extent = data_shape[1 + axis]
        stride = strides[1 + axis]
        if padding == "VALID":
            spatial_sizes.append((extent - 1) * stride + kernel_shape[axis])
        else:  # SAME
            spatial_sizes.append(extent * stride)
    computed_output_shape = [data_shape[0], *spatial_sizes, kernel_shape[3]]

    class Conv3DTransposeModule(tf.Module):
        @tf.function(
            input_signature=[
                tf.TensorSpec(shape=data_shape, dtype=tf.float32),
                tf.TensorSpec(shape=kernel_shape, dtype=tf.float32),
            ]
        )
        def func(self, data, kernel):
            return tf.nn.conv3d_transpose(
                input=data,
                filters=kernel,
                output_shape=computed_output_shape,
                strides=strides,
                padding=padding,
            )

    return Conv3DTransposeModule



def test_conv3d_transpose_valid():
    """CONV_3D_TRANSPOSE with VALID padding maps to zero padding in Relax.

    Each spatial dim grows from 8 to (8 - 1) * 1 + 3 = 10, and the channel
    count follows the kernel's OC axis (8), giving a (1, 10, 10, 10, 8) output.
    """
    Conv3DTransposeModule = _make_conv3d_transpose_module(
        (1, 8, 8, 8, 3), (3, 3, 3, 8, 3), (1, 1, 1, 1, 1), "VALID"
    )

    @I.ir_module
    class Expected:
        @R.function
        def main(
            data: R.Tensor((1, 8, 8, 8, 3), dtype="float32"),
            kernel: R.Tensor((3, 3, 3, 8, 3), dtype="float32"),
        ) -> R.Tensor((1, 10, 10, 10, 8), dtype="float32"):
            R.func_attr({"num_input": 2})
            with R.dataflow():
                # VALID padding: all six pad values are zero; kernel stays DHWOI.
                gv: R.Tensor((1, 10, 10, 10, 8), dtype="float32") = R.nn.conv3d_transpose(
                    data,
                    kernel,
                    strides=[1, 1, 1],
                    padding=[0, 0, 0, 0, 0, 0],
                    output_padding=[0, 0, 0],
                    dilation=[1, 1, 1],
                    groups=1,
                    data_layout="NDHWC",
                    kernel_layout="DHWOI",
                    out_layout="NDHWC",
                    out_dtype="void",
                )
                R.output(gv)
            return gv

    verify(Conv3DTransposeModule, Expected)


def test_conv3d_transpose_same():
    """CONV_3D_TRANSPOSE with SAME padding and unit strides.

    With stride 1 the output keeps the 8x8x8 spatial extent; the total pad per
    dim is dilated_kernel - stride = 3 - 1 = 2, split as 1 before / 1 after.
    """
    Conv3DTransposeModule = _make_conv3d_transpose_module(
        (1, 8, 8, 8, 3), (3, 3, 3, 8, 3), (1, 1, 1, 1, 1), "SAME"
    )

    @I.ir_module
    class Expected:
        @R.function
        def main(
            data: R.Tensor((1, 8, 8, 8, 3), dtype="float32"),
            kernel: R.Tensor((3, 3, 3, 8, 3), dtype="float32"),
        ) -> R.Tensor((1, 8, 8, 8, 8), dtype="float32"):
            R.func_attr({"num_input": 2})
            with R.dataflow():
                # SAME padding: 1 on each side of D, H and W (total pad 2 each).
                gv: R.Tensor((1, 8, 8, 8, 8), dtype="float32") = R.nn.conv3d_transpose(
                    data,
                    kernel,
                    strides=[1, 1, 1],
                    padding=[1, 1, 1, 1, 1, 1],
                    output_padding=[0, 0, 0],
                    dilation=[1, 1, 1],
                    groups=1,
                    data_layout="NDHWC",
                    kernel_layout="DHWOI",
                    out_layout="NDHWC",
                    out_dtype="void",
                )
                R.output(gv)
            return gv

    verify(Conv3DTransposeModule, Expected)


def _make_pool2d_module(pool, data_shape, ksize, data_format, strides, padding):
class Pool2DModule(tf.Module):
@tf.function(
Expand Down
Loading