Skip to content

Commit 4de3aaa

Browse files
committed
[Relax][Frontend][TFLite] Complete QDQ conversion for remaining quantized ops
Replace the last _qnn.op.* references in the TFLite frontend with the DQ → float op → Q pattern, eliminating all references to the non-existent _qnn module. convert_fully_connected: - _qnn.op.dense → DQ input + DQ weight (axis remap OC 0→1) + matmul - _qnn.op.requantize + activation → self.quantize + activation - INT32/INT64 bias dequantized with input_scale × weight_scale convert_concatenation: - _qnn.op.concat → DQ each input → float concat → quantize → activation convert_transpose_conv: - _qnn.op.conv2d_transpose → DQ input + DQ weight (axis remap OHWI→IOHW, OC axis 0→1) + float conv2d_transpose - _qnn.op.requantize → self.quantize - INT32/INT64 bias dequantized (previously missing — added in review fix) convert_detection_postprocess: - 3× _qnn.op.dequantize → self.dequantize convert_reshape (uint8 path): - Requantize on integer tensor → DQ → reshape → Q Depthwise Conv2D: - Explicit OpNotImplemented for per-channel depthwise (axis semantics change after [1,KH,KW,C*M] → [KH,KW,C,M] reshape) Cleanup: - Removed now-unnecessary F821 noqa comment (zero _qnn / _expr refs) - Removed unused locals (weight_shape, output_tensor_type_str) All _qnn.op.* references eliminated. 386 tests pass, ruff clean. Closes #19534.
1 parent d71cab5 commit 4de3aaa

2 files changed

Lines changed: 585 additions & 109 deletions

File tree

python/tvm/relax/frontend/tflite/tflite_frontend.py

Lines changed: 80 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,6 @@
1919
# pylint: disable=no-value-for-parameter, unused-variable
2020
# pylint: disable=unexpected-keyword-arg, unused-import, too-many-function-args
2121
# ruff: noqa: RUF005
22-
# F821: remaining _qnn references (requantize, conv2d, dense, concat,
23-
# conv2d_transpose, and detection-postprocess dequantize) are in
24-
# not-yet-covered code paths and will be resolved as quantized op support
25-
# advances. _expr references will be resolved when vision ops are added.
26-
# ruff: noqa: F821
2722
"""Tensorflow lite frontend."""
2823

2924
import functools
@@ -792,12 +787,15 @@ def convert_reshape(self, op):
792787
"TFLite reshape requires input and output scale and zero points to be equal"
793788
)
794789

795-
out = relax.op.reshape(in_expr, shape=relax.ShapeExpr(target_shape))
796790
if input_tensor.qnn_params and input_tensor_type_str == "uint8":
797791
output_tensor = output_tensors[0]
798792
if not self.has_same_qnn_params(input_tensor, output_tensor):
793+
in_f32 = self.dequantize(in_expr, input_tensor)
794+
out = relax.op.reshape(in_f32, shape=relax.ShapeExpr(target_shape))
799795
out = self.quantize(out, output_tensor)
796+
return out
800797

798+
out = relax.op.reshape(in_expr, shape=relax.ShapeExpr(target_shape))
801799
return out
802800

803801
def _convert_resize(self, method, op):
@@ -1265,18 +1263,12 @@ def convert_concatenation(self, op):
12651263
if not input_tensors[0].qnn_params:
12661264
out = relax.op.concat(in_exprs, axis=concatenation_axis)
12671265
else:
1268-
input_scales = [input_tensor.qnn_params["scale"] for input_tensor in input_tensors]
1269-
input_zero_points = [
1270-
input_tensor.qnn_params["zero_point"] for input_tensor in input_tensors
1266+
in_f32s = [
1267+
self.dequantize(expr, tensor)
1268+
for expr, tensor in zip(in_exprs, input_tensors)
12711269
]
1272-
out = _qnn.op.concat(
1273-
in_exprs,
1274-
input_scales=input_scales,
1275-
input_zero_points=input_zero_points,
1276-
output_scale=output_tensor.qnn_params["scale"],
1277-
output_zero_point=output_tensor.qnn_params["zero_point"],
1278-
axis=concatenation_axis,
1279-
)
1270+
out = relax.op.concat(in_f32s, axis=concatenation_axis)
1271+
out = self.quantize(out, output_tensor)
12801272

12811273
# Handle fused activations
12821274
if output_tensor.qnn_params:
@@ -2518,20 +2510,18 @@ def convert_fully_connected(self, op):
25182510
)
25192511

25202512
weight_expr = self.get_tensor_expr(weight_tensor)
2521-
weight_shape = weight_expr.struct_info.shape
25222513
weight_expr = relax.op.permute_dims(weight_expr, [1, 0])
25232514

25242515
if input_tensor.qnn_params:
2525-
out = _qnn.op.dense(
2526-
in_expr,
2516+
# Dequantize input and weight (OC remapped from axis 0 to 1)
2517+
in_f32 = self.dequantize(in_expr, input_tensor)
2518+
w_f32 = relax.op.dequantize(
25272519
weight_expr,
2528-
input_zero_point=input_tensor.qnn_params["zero_point"],
2529-
kernel_zero_point=weight_tensor.qnn_params["zero_point"],
2530-
input_scale=input_tensor.qnn_params["scale"],
2531-
kernel_scale=weight_tensor.qnn_params["scale"],
2532-
units=weight_shape[0],
2533-
out_dtype="int64" if output_tensor_type_str == "int16" else "int32",
2520+
scale=weight_tensor.qnn_params["scale"],
2521+
zero_point=weight_tensor.qnn_params["zero_point"],
2522+
axis=1,
25342523
)
2524+
out = relax.op.matmul(in_f32, w_f32)
25352525
else:
25362526
out = relax.op.matmul(in_expr, weight_expr)
25372527

@@ -2555,27 +2545,27 @@ def convert_fully_connected(self, op):
25552545
dtype=bias_tensor_type_str,
25562546
source_name=bias_tensor.tensor.Name(),
25572547
)
2548+
if bias_tensor.qnn_params:
2549+
bias_expr = self.dequantize(bias_expr, bias_tensor)
2550+
elif input_tensor.qnn_params and bias_tensor_type in (
2551+
TensorType.INT32,
2552+
TensorType.INT64,
2553+
):
2554+
bias_scale_val = (
2555+
get_scalar_from_constant(input_tensor.qnn_params["scale"])
2556+
* get_scalar_from_constant(weight_tensor.qnn_params["scale"])
2557+
)
2558+
bias_expr = relax.op.dequantize(
2559+
bias_expr,
2560+
scale=relax.const(bias_scale_val, "float32"),
2561+
zero_point=relax.const(0, "int32"),
2562+
axis=0,
2563+
)
25582564
out = relax.op.add(out, bias_expr)
25592565

2560-
# Finally if the dense is quantized. Add a requantize at the end.
2566+
# Finally if the dense is quantized. Quantize the output.
25612567
if output_tensor.qnn_params:
2562-
data_scale = input_tensor.qnn_params["scale"]
2563-
weight_scale = weight_tensor.qnn_params["scale"]
2564-
data_scale_val = get_scalar_from_constant(data_scale)
2565-
weight_scale_val = get_scalar_from_constant(weight_scale)
2566-
new_input_scale_val = data_scale_val * weight_scale_val
2567-
new_input_scale = relax.const(new_input_scale_val, "float32")
2568-
new_input_zero_point = relax.const(0, "int32")
2569-
2570-
# Requantize
2571-
out = _qnn.op.requantize(
2572-
out,
2573-
input_scale=new_input_scale,
2574-
input_zero_point=new_input_zero_point,
2575-
output_scale=output_tensor.qnn_params["scale"],
2576-
output_zero_point=output_tensor.qnn_params["zero_point"],
2577-
out_dtype=output_tensor_type_str,
2578-
)
2568+
out = self.quantize(out, output_tensor)
25792569

25802570
# Call activation function
25812571
output_scale_val = get_scalar_from_constant(output_tensor.qnn_params["scale"])
@@ -2794,7 +2784,14 @@ def convert_conv(self, op, conv_type):
27942784
# After transpose to HWIO: [KH, KW, IC, OC]
27952785
# QuantizedDimension() == 0 (OC in original) → axis 3 in HWIO.
27962786
weight_axis = weight_tensor.qnn_params["axis"]
2797-
if not is_depthwise_conv:
2787+
if is_depthwise_conv:
2788+
if weight_axis != 0:
2789+
raise tvm.error.OpNotImplemented(
2790+
"Per-channel quantized depthwise convolution is not supported "
2791+
"because the channel axis changes semantics after the "
2792+
"[1,KH,KW,C*M] → [KH,KW,C,M] reshape."
2793+
)
2794+
else:
27982795
weight_axis = 3
27992796
w_f32 = relax.op.dequantize(
28002797
weight_expr,
@@ -2836,7 +2833,11 @@ def convert_conv(self, op, conv_type):
28362833
):
28372834
bias_expr = relax.op.dequantize(
28382835
bias_expr,
2839-
scale=input_tensor.qnn_params["scale"],
2836+
scale=relax.const(
2837+
get_scalar_from_constant(input_tensor.qnn_params["scale"])
2838+
* get_scalar_from_constant(weight_tensor.qnn_params["scale"]),
2839+
"float32",
2840+
),
28402841
zero_point=relax.const(0, "int32"),
28412842
axis=0,
28422843
)
@@ -4328,25 +4329,21 @@ def convert_transpose_conv(self, op):
43284329
padding = (0, 0, 0, 0)
43294330

43304331
if input_tensor.qnn_params:
4331-
input_zero_point = input_tensor.qnn_params["zero_point"]
4332-
kernel_zero_point = weights_tensor.qnn_params["zero_point"]
4333-
input_scale = input_tensor.qnn_params["scale"]
4334-
kernel_scale = weights_tensor.qnn_params["scale"]
4335-
out_dtype = "int64" if output_tensor_type_str == "int16" else "int32"
4336-
out = _qnn.op.conv2d_transpose(
4337-
in_expr,
4332+
in_f32 = self.dequantize(in_expr, input_tensor)
4333+
w_f32 = relax.op.dequantize(
43384334
weight_expr_iohw,
4339-
input_zero_point,
4340-
kernel_zero_point,
4341-
input_scale,
4342-
kernel_scale,
4335+
scale=weights_tensor.qnn_params["scale"],
4336+
zero_point=weights_tensor.qnn_params["zero_point"],
4337+
axis=1,
4338+
)
4339+
out = relax.op.nn.conv2d_transpose(
4340+
in_f32,
4341+
w_f32,
43434342
strides=(stride_h, stride_w),
43444343
padding=padding,
4345-
channels=int(out_channels),
4346-
kernel_size=(int(kernel_h), int(kernel_w)),
43474344
data_layout="NHWC",
43484345
kernel_layout="IOHW",
4349-
out_dtype=out_dtype,
4346+
out_dtype="float32",
43504347
)
43514348
else:
43524349
out = relax.op.nn.conv2d_transpose(
@@ -4378,34 +4375,26 @@ def convert_transpose_conv(self, op):
43784375
dtype=bias_tensor_type_str,
43794376
source_name=bias_tensor.tensor.Name(),
43804377
)
4381-
channel_axis = 3
4382-
out = relax.op.nn.bias_add(out, bias_expr, axis=channel_axis)
4378+
if bias_tensor.qnn_params:
4379+
bias_expr = self.dequantize(bias_expr, bias_tensor)
4380+
elif input_tensor.qnn_params and bias_tensor_type in (
4381+
TensorType.INT32,
4382+
TensorType.INT64,
4383+
):
4384+
bias_scale_val = (
4385+
get_scalar_from_constant(input_tensor.qnn_params["scale"])
4386+
* get_scalar_from_constant(weights_tensor.qnn_params["scale"])
4387+
)
4388+
bias_expr = relax.op.dequantize(
4389+
bias_expr,
4390+
scale=relax.const(bias_scale_val, "float32"),
4391+
zero_point=relax.const(0, "int32"),
4392+
axis=0,
4393+
)
4394+
out = relax.op.add(out, bias_expr)
43834395

43844396
if output_tensor.qnn_params:
4385-
# Calculate the intermediate scale and zero point of the int32 output.
4386-
data_scale = input_tensor.qnn_params["scale"]
4387-
data_scale_val = get_scalar_from_constant(data_scale)
4388-
4389-
weight_scale = weights_tensor.qnn_params["scale"]
4390-
# If weight scale is scalar, it is per-tensor quantization
4391-
if isinstance(weight_scale, float):
4392-
weight_scale_val = get_scalar_from_constant(weight_scale)
4393-
else:
4394-
weight_scale_val = get_tensor_from_constant(weight_scale)
4395-
4396-
new_input_scale_val = data_scale_val * weight_scale_val
4397-
new_input_scale = relax.const(new_input_scale_val, "float32")
4398-
new_input_zero_point = relax.const(0, "int32")
4399-
4400-
out = _qnn.op.requantize(
4401-
out,
4402-
input_scale=new_input_scale,
4403-
input_zero_point=new_input_zero_point,
4404-
output_scale=output_tensor.qnn_params["scale"],
4405-
output_zero_point=output_tensor.qnn_params["zero_point"],
4406-
out_dtype=output_tensor_type_str,
4407-
axis=3,
4408-
)
4397+
out = self.quantize(out, output_tensor)
44094398
return out
44104399

44114400
def convert_quantize(self, op):
@@ -4420,7 +4409,6 @@ def convert_quantize(self, op):
44204409
output_tensors = self.get_output_tensors(op)
44214410
assert len(output_tensors) == 1, "output tensors length should be 1"
44224411
output_tensor = output_tensors[0]
4423-
output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
44244412

44254413
# The output must be quantized
44264414
assert output_tensor.qnn_params
@@ -4429,9 +4417,8 @@ def convert_quantize(self, op):
44294417
if input_tensor_type_str == "float32":
44304418
out = self.quantize(in_expr, output_tensor)
44314419
else:
4432-
raise tvm.error.OpNotImplemented(
4433-
"TFLite QUANTIZE acting as requantize is not supported yet"
4434-
)
4420+
in_f32 = self.dequantize(in_expr, input_tensor)
4421+
out = self.quantize(in_f32, output_tensor)
44354422
return out
44364423

44374424
def convert_dequantize(self, op):
@@ -4580,23 +4567,11 @@ def convert_detection_postprocess(self, op):
45804567
)
45814568

45824569
if inputs[0].qnn_params:
4583-
loc_prob = _qnn.op.dequantize(
4584-
data=loc_prob,
4585-
input_scale=inputs[0].qnn_params["scale"],
4586-
input_zero_point=inputs[0].qnn_params["zero_point"],
4587-
)
4570+
loc_prob = self.dequantize(loc_prob, inputs[0])
45884571
if inputs[1].qnn_params:
4589-
cls_pred = _qnn.op.dequantize(
4590-
data=cls_pred,
4591-
input_scale=inputs[1].qnn_params["scale"],
4592-
input_zero_point=inputs[1].qnn_params["zero_point"],
4593-
)
4572+
cls_pred = self.dequantize(cls_pred, inputs[1])
45944573
if inputs[2].qnn_params:
4595-
anchor_expr = _qnn.op.dequantize(
4596-
data=anchor_expr,
4597-
input_scale=inputs[2].qnn_params["scale"],
4598-
input_zero_point=inputs[2].qnn_params["zero_point"],
4599-
)
4574+
anchor_expr = self.dequantize(anchor_expr, inputs[2])
46004575

46014576
# loc_prob coords are in yxhw format
46024577
# need to convert to xywh

0 commit comments

Comments
 (0)