Commit 08e92d3
[Relax][Frontend][TFLite] Complete QDQ conversion for remaining quantized ops
Replace the last _qnn.op.* references in the TFLite frontend with the
DQ → float op → Q pattern, eliminating all references to the
non-existent _qnn module.

convert_fully_connected:
- _qnn.op.dense → DQ input + DQ weight (axis remap OC 0→1) + matmul
- _qnn.op.requantize + activation → self.quantize + activation
- INT32/INT64 bias dequantized with input_scale × weight_scale

convert_concatenation:
- _qnn.op.concat → DQ each input → float concat → quantize → activation

convert_transpose_conv:
- _qnn.op.conv2d_transpose → DQ input + DQ weight (axis remap OHWI→IOHW,
  OC axis 0→1) + float conv2d_transpose
- _qnn.op.requantize → self.quantize
- INT32/INT64 bias dequantized (previously missing; added in review fix)

convert_detection_postprocess:
- 3× _qnn.op.dequantize → self.dequantize

convert_reshape (uint8 path):
- Requantize on integer tensor → DQ → reshape → Q

Depthwise Conv2D:
- Explicit OpNotImplemented for per-channel depthwise (axis semantics
  change after [1,KH,KW,C*M] → [KH,KW,C,M] reshape)

Cleanup:
- Removed now-unnecessary F821 noqa comment (zero _qnn / _expr refs)
- Removed unused locals (weight_shape, output_tensor_type_str)

All _qnn.op.* references eliminated. 386 tests pass, ruff clean.

Closes #19534.
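The DQ → float op → Q pattern above can be pictured with a minimal NumPy sketch. The `dq`/`q` helpers and all values here are hypothetical illustrations; the frontend itself routes through `self.dequantize`/`self.quantize` and `relax.op.*`:

```python
import numpy as np

def dq(q_vals, scale, zero_point):
    # Affine dequantize: real = (q - zero_point) * scale
    return (q_vals.astype(np.int32) - zero_point).astype(np.float32) * scale

def q(real, scale, zero_point, dtype=np.uint8):
    # Affine quantize: q = round(real / scale) + zero_point, clamped to dtype range
    info = np.iinfo(dtype)
    vals = np.round(real / scale) + zero_point
    return np.clip(vals, info.min, info.max).astype(dtype)

# Quantized concat, expressed as DQ -> float concat -> Q:
a = np.array([10, 20, 30], dtype=np.uint8)  # scale 0.5, zero point 10
b = np.array([5, 15, 25], dtype=np.uint8)   # scale 0.25, zero point 0
out = q(np.concatenate([dq(a, 0.5, 10), dq(b, 0.25, 0)]), scale=0.5, zero_point=0)
print(out)  # each input is rescaled into the output's quantization params
```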
1 parent d71cab5 commit 08e92d3

2 files changed: 607 additions & 109 deletions

File shown: python/tvm/relax/frontend/tflite/tflite_frontend.py
Lines changed: 96 additions & 105 deletions
@@ -19,11 +19,6 @@
 # pylint: disable=no-value-for-parameter, unused-variable
 # pylint: disable=unexpected-keyword-arg, unused-import, too-many-function-args
 # ruff: noqa: RUF005
-# F821: remaining _qnn references (requantize, conv2d, dense, concat,
-# conv2d_transpose, and detection-postprocess dequantize) are in
-# not-yet-covered code paths and will be resolved as quantized op support
-# advances. _expr references will be resolved when vision ops are added.
-# ruff: noqa: F821
 """Tensorflow lite frontend."""
 
 import functools
@@ -792,12 +787,15 @@ def convert_reshape(self, op):
                 "TFLite reshape requires input and output scale and zero points to be equal"
             )
 
-        out = relax.op.reshape(in_expr, shape=relax.ShapeExpr(target_shape))
         if input_tensor.qnn_params and input_tensor_type_str == "uint8":
             output_tensor = output_tensors[0]
             if not self.has_same_qnn_params(input_tensor, output_tensor):
+                in_f32 = self.dequantize(in_expr, input_tensor)
+                out = relax.op.reshape(in_f32, shape=relax.ShapeExpr(target_shape))
                 out = self.quantize(out, output_tensor)
+                return out
 
+        out = relax.op.reshape(in_expr, shape=relax.ShapeExpr(target_shape))
         return out
 
     def _convert_resize(self, method, op):
@@ -1265,18 +1263,12 @@ def convert_concatenation(self, op):
         if not input_tensors[0].qnn_params:
             out = relax.op.concat(in_exprs, axis=concatenation_axis)
         else:
-            input_scales = [input_tensor.qnn_params["scale"] for input_tensor in input_tensors]
-            input_zero_points = [
-                input_tensor.qnn_params["zero_point"] for input_tensor in input_tensors
+            in_f32s = [
+                self.dequantize(expr, tensor)
+                for expr, tensor in zip(in_exprs, input_tensors)
             ]
-            out = _qnn.op.concat(
-                in_exprs,
-                input_scales=input_scales,
-                input_zero_points=input_zero_points,
-                output_scale=output_tensor.qnn_params["scale"],
-                output_zero_point=output_tensor.qnn_params["zero_point"],
-                axis=concatenation_axis,
-            )
+            out = relax.op.concat(in_f32s, axis=concatenation_axis)
+            out = self.quantize(out, output_tensor)
 
         # Handle fused activations
         if output_tensor.qnn_params:
@@ -2518,20 +2510,24 @@ def convert_fully_connected(self, op):
             )
 
         weight_expr = self.get_tensor_expr(weight_tensor)
-        weight_shape = weight_expr.struct_info.shape
         weight_expr = relax.op.permute_dims(weight_expr, [1, 0])
 
         if input_tensor.qnn_params:
-            out = _qnn.op.dense(
-                in_expr,
+            # Dequantize input and weight (OC remapped from axis 0 to 1)
+            in_f32 = self.dequantize(in_expr, input_tensor)
+            weight_axis = weight_tensor.qnn_params["axis"]
+            if weight_axis != 0:
+                raise tvm.error.OpAttributeInvalid(
+                    f"FC weight QuantizedDimension() must be 0 (output-channel "
+                    f"axis in [OC,IC] layout), got {weight_axis}"
+                )
+            w_f32 = relax.op.dequantize(
                 weight_expr,
-                input_zero_point=input_tensor.qnn_params["zero_point"],
-                kernel_zero_point=weight_tensor.qnn_params["zero_point"],
-                input_scale=input_tensor.qnn_params["scale"],
-                kernel_scale=weight_tensor.qnn_params["scale"],
-                units=weight_shape[0],
-                out_dtype="int64" if output_tensor_type_str == "int16" else "int32",
+                scale=weight_tensor.qnn_params["scale"],
+                zero_point=weight_tensor.qnn_params["zero_point"],
+                axis=1,
             )
+            out = relax.op.matmul(in_f32, w_f32)
         else:
             out = relax.op.matmul(in_expr, weight_expr)
 
@@ -2555,27 +2551,27 @@ def convert_fully_connected(self, op):
                 dtype=bias_tensor_type_str,
                 source_name=bias_tensor.tensor.Name(),
             )
+            if bias_tensor.qnn_params:
+                bias_expr = self.dequantize(bias_expr, bias_tensor)
+            elif input_tensor.qnn_params and bias_tensor_type in (
+                TensorType.INT32,
+                TensorType.INT64,
+            ):
+                bias_scale = relax.op.multiply(
+                    input_tensor.qnn_params["scale"],
+                    weight_tensor.qnn_params["scale"],
+                )
+                bias_expr = relax.op.dequantize(
+                    bias_expr,
+                    scale=bias_scale,
+                    zero_point=relax.const(0, "int32"),
+                    axis=0,
+                )
             out = relax.op.add(out, bias_expr)
 
-        # Finally if the dense is quantized. Add a requantize at the end.
+        # Finally if the dense is quantized. Quantize the output.
         if output_tensor.qnn_params:
-            data_scale = input_tensor.qnn_params["scale"]
-            weight_scale = weight_tensor.qnn_params["scale"]
-            data_scale_val = get_scalar_from_constant(data_scale)
-            weight_scale_val = get_scalar_from_constant(weight_scale)
-            new_input_scale_val = data_scale_val * weight_scale_val
-            new_input_scale = relax.const(new_input_scale_val, "float32")
-            new_input_zero_point = relax.const(0, "int32")
-
-            # Requantize
-            out = _qnn.op.requantize(
-                out,
-                input_scale=new_input_scale,
-                input_zero_point=new_input_zero_point,
-                output_scale=output_tensor.qnn_params["scale"],
-                output_zero_point=output_tensor.qnn_params["zero_point"],
-                out_dtype=output_tensor_type_str,
-            )
+            out = self.quantize(out, output_tensor)
 
         # Call activation function
         output_scale_val = get_scalar_from_constant(output_tensor.qnn_params["scale"])
@@ -2794,7 +2790,19 @@ def convert_conv(self, op, conv_type):
             # After transpose to HWIO: [KH, KW, IC, OC]
             # QuantizedDimension() == 0 (OC in original) → axis 3 in HWIO.
             weight_axis = weight_tensor.qnn_params["axis"]
-            if not is_depthwise_conv:
+            if is_depthwise_conv:
+                if weight_axis != 0:
+                    raise tvm.error.OpNotImplemented(
+                        "Per-channel quantized depthwise convolution is not supported "
+                        "because the channel axis changes semantics after the "
+                        "[1,KH,KW,C*M] → [KH,KW,C,M] reshape."
+                    )
+            else:
+                if weight_axis != 0:
+                    raise tvm.error.OpAttributeInvalid(
+                        f"Conv2D weight QuantizedDimension() must be 0 (output-channel "
+                        f"axis in [OC,KH,KW,IC] layout), got {weight_axis}"
+                    )
                 weight_axis = 3
             w_f32 = relax.op.dequantize(
                 weight_expr,
@@ -2836,7 +2844,10 @@ def convert_conv(self, op, conv_type):
             ):
                 bias_expr = relax.op.dequantize(
                     bias_expr,
-                    scale=input_tensor.qnn_params["scale"],
+                    scale=relax.op.multiply(
+                        input_tensor.qnn_params["scale"],
+                        weight_tensor.qnn_params["scale"],
+                    ),
                     zero_point=relax.const(0, "int32"),
                     axis=0,
                 )
@@ -4328,25 +4339,27 @@ def convert_transpose_conv(self, op):
             padding = (0, 0, 0, 0)
 
         if input_tensor.qnn_params:
-            input_zero_point = input_tensor.qnn_params["zero_point"]
-            kernel_zero_point = weights_tensor.qnn_params["zero_point"]
-            input_scale = input_tensor.qnn_params["scale"]
-            kernel_scale = weights_tensor.qnn_params["scale"]
-            out_dtype = "int64" if output_tensor_type_str == "int16" else "int32"
-            out = _qnn.op.conv2d_transpose(
-                in_expr,
+            in_f32 = self.dequantize(in_expr, input_tensor)
+            weight_axis = weights_tensor.qnn_params["axis"]
+            if weight_axis != 0:
+                raise tvm.error.OpAttributeInvalid(
+                    f"TransposeConv weight QuantizedDimension() must be 0 "
+                    f"(output-channel axis in OHWI layout), got {weight_axis}"
+                )
+            w_f32 = relax.op.dequantize(
                 weight_expr_iohw,
-                input_zero_point,
-                kernel_zero_point,
-                input_scale,
-                kernel_scale,
+                scale=weights_tensor.qnn_params["scale"],
+                zero_point=weights_tensor.qnn_params["zero_point"],
+                axis=1,
+            )
+            out = relax.op.nn.conv2d_transpose(
+                in_f32,
+                w_f32,
                 strides=(stride_h, stride_w),
                 padding=padding,
-                channels=int(out_channels),
-                kernel_size=(int(kernel_h), int(kernel_w)),
                 data_layout="NHWC",
                 kernel_layout="IOHW",
-                out_dtype=out_dtype,
+                out_dtype="float32",
             )
         else:
             out = relax.op.nn.conv2d_transpose(
@@ -4378,34 +4391,26 @@ def convert_transpose_conv(self, op):
                 dtype=bias_tensor_type_str,
                 source_name=bias_tensor.tensor.Name(),
             )
-            channel_axis = 3
-            out = relax.op.nn.bias_add(out, bias_expr, axis=channel_axis)
+            if bias_tensor.qnn_params:
+                bias_expr = self.dequantize(bias_expr, bias_tensor)
+            elif input_tensor.qnn_params and bias_tensor_type in (
+                TensorType.INT32,
+                TensorType.INT64,
+            ):
+                bias_scale = relax.op.multiply(
+                    input_tensor.qnn_params["scale"],
+                    weights_tensor.qnn_params["scale"],
+                )
+                bias_expr = relax.op.dequantize(
+                    bias_expr,
+                    scale=bias_scale,
+                    zero_point=relax.const(0, "int32"),
+                    axis=0,
+                )
+            out = relax.op.add(out, bias_expr)
 
         if output_tensor.qnn_params:
-            # Calculate the intermediate scale and zero point of the int32 output.
-            data_scale = input_tensor.qnn_params["scale"]
-            data_scale_val = get_scalar_from_constant(data_scale)
-
-            weight_scale = weights_tensor.qnn_params["scale"]
-            # If weight scale is scalar, it is per-tensor quantization
-            if isinstance(weight_scale, float):
-                weight_scale_val = get_scalar_from_constant(weight_scale)
-            else:
-                weight_scale_val = get_tensor_from_constant(weight_scale)
-
-            new_input_scale_val = data_scale_val * weight_scale_val
-            new_input_scale = relax.const(new_input_scale_val, "float32")
-            new_input_zero_point = relax.const(0, "int32")
-
-            out = _qnn.op.requantize(
-                out,
-                input_scale=new_input_scale,
-                input_zero_point=new_input_zero_point,
-                output_scale=output_tensor.qnn_params["scale"],
-                output_zero_point=output_tensor.qnn_params["zero_point"],
-                out_dtype=output_tensor_type_str,
-                axis=3,
-            )
+            out = self.quantize(out, output_tensor)
         return out
 
     def convert_quantize(self, op):
@@ -4420,7 +4425,6 @@ def convert_quantize(self, op):
         output_tensors = self.get_output_tensors(op)
         assert len(output_tensors) == 1, "output tensors length should be 1"
         output_tensor = output_tensors[0]
-        output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
 
         # The output must be quantized
         assert output_tensor.qnn_params
@@ -4429,9 +4433,8 @@ def convert_quantize(self, op):
         if input_tensor_type_str == "float32":
             out = self.quantize(in_expr, output_tensor)
         else:
-            raise tvm.error.OpNotImplemented(
-                "TFLite QUANTIZE acting as requantize is not supported yet"
-            )
+            in_f32 = self.dequantize(in_expr, input_tensor)
+            out = self.quantize(in_f32, output_tensor)
         return out
 
     def convert_dequantize(self, op):
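The new else branch treats QUANTIZE-acting-as-requantize as a dequantize/quantize round trip. A minimal NumPy sketch of that idea, with a hypothetical helper and made-up parameter values (not the frontend API):

```python
import numpy as np

def requantize(q_vals, s_in, zp_in, s_out, zp_out, dtype=np.uint8):
    # Requantize via float: DQ with the input params, then Q with the output params.
    real = (q_vals.astype(np.int32) - zp_in) * s_in
    info = np.iinfo(dtype)
    return np.clip(np.round(real / s_out) + zp_out, info.min, info.max).astype(dtype)

x = np.array([0, 128, 255], dtype=np.uint8)
print(requantize(x, s_in=0.1, zp_in=128, s_out=0.2, zp_out=0))  # [0, 0, 64]
```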
@@ -4580,23 +4583,11 @@ def convert_detection_postprocess(self, op):
         )
 
         if inputs[0].qnn_params:
-            loc_prob = _qnn.op.dequantize(
-                data=loc_prob,
-                input_scale=inputs[0].qnn_params["scale"],
-                input_zero_point=inputs[0].qnn_params["zero_point"],
-            )
+            loc_prob = self.dequantize(loc_prob, inputs[0])
         if inputs[1].qnn_params:
-            cls_pred = _qnn.op.dequantize(
-                data=cls_pred,
-                input_scale=inputs[1].qnn_params["scale"],
-                input_zero_point=inputs[1].qnn_params["zero_point"],
-            )
+            cls_pred = self.dequantize(cls_pred, inputs[1])
         if inputs[2].qnn_params:
-            anchor_expr = _qnn.op.dequantize(
-                data=anchor_expr,
-                input_scale=inputs[2].qnn_params["scale"],
-                input_zero_point=inputs[2].qnn_params["zero_point"],
-            )
+            anchor_expr = self.dequantize(anchor_expr, inputs[2])
 
         # loc_prob coords are in yxhw format
         # need to convert to xywh
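For reference, the bias handling added in this commit relies on a standard identity in the TFLite quantization scheme: the integer accumulator of a quantized matmul/conv represents real values on the scale input_scale × weight_scale, so the INT32/INT64 bias is stored on that product scale with zero point 0 and dequantizes with it. A small numeric sketch, with made-up scales:

```python
# acc = sum((q_in - zp_in) * (q_w - zp_w)) represents reals on scale s_in * s_w,
# so TFLite stores bias as round(bias_f32 / (s_in * s_w)) with zero point 0.
s_in, s_w = 0.05, 0.02
bias_f32 = 1.25
bias_q = round(bias_f32 / (s_in * s_w))  # stored INT32 value: 1250
recovered = bias_q * (s_in * s_w)        # dequantize: 1.25
print(bias_q, recovered)
```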
