1919# pylint: disable=no-value-for-parameter, unused-variable
2020# pylint: disable=unexpected-keyword-arg, unused-import, too-many-function-args
2121# ruff: noqa: RUF005
22- # F821: remaining _qnn references (requantize, conv2d, dense, concat,
23- # conv2d_transpose, and detection-postprocess dequantize) are in
24- # not-yet-covered code paths and will be resolved as quantized op support
25- # advances. _expr references will be resolved when vision ops are added.
26- # ruff: noqa: F821
2722"""Tensorflow lite frontend."""
2823
2924import functools
@@ -792,12 +787,15 @@ def convert_reshape(self, op):
792787 "TFLite reshape requires input and output scale and zero points to be equal"
793788 )
794789
795- out = relax .op .reshape (in_expr , shape = relax .ShapeExpr (target_shape ))
796790 if input_tensor .qnn_params and input_tensor_type_str == "uint8" :
797791 output_tensor = output_tensors [0 ]
798792 if not self .has_same_qnn_params (input_tensor , output_tensor ):
793+ in_f32 = self .dequantize (in_expr , input_tensor )
794+ out = relax .op .reshape (in_f32 , shape = relax .ShapeExpr (target_shape ))
799795 out = self .quantize (out , output_tensor )
796+ return out
800797
798+ out = relax .op .reshape (in_expr , shape = relax .ShapeExpr (target_shape ))
801799 return out
802800
803801 def _convert_resize (self , method , op ):
@@ -1265,18 +1263,12 @@ def convert_concatenation(self, op):
12651263 if not input_tensors [0 ].qnn_params :
12661264 out = relax .op .concat (in_exprs , axis = concatenation_axis )
12671265 else :
1268- input_scales = [input_tensor . qnn_params [ "scale" ] for input_tensor in input_tensors ]
1269- input_zero_points = [
1270- input_tensor . qnn_params [ "zero_point" ] for input_tensor in input_tensors
1266+ in_f32s = [
1267+ self . dequantize ( expr , tensor )
1268+ for expr , tensor in zip ( in_exprs , input_tensors )
12711269 ]
1272- out = _qnn .op .concat (
1273- in_exprs ,
1274- input_scales = input_scales ,
1275- input_zero_points = input_zero_points ,
1276- output_scale = output_tensor .qnn_params ["scale" ],
1277- output_zero_point = output_tensor .qnn_params ["zero_point" ],
1278- axis = concatenation_axis ,
1279- )
1270+ out = relax .op .concat (in_f32s , axis = concatenation_axis )
1271+ out = self .quantize (out , output_tensor )
12801272
12811273 # Handle fused activations
12821274 if output_tensor .qnn_params :
@@ -2518,20 +2510,24 @@ def convert_fully_connected(self, op):
25182510 )
25192511
25202512 weight_expr = self .get_tensor_expr (weight_tensor )
2521- weight_shape = weight_expr .struct_info .shape
25222513 weight_expr = relax .op .permute_dims (weight_expr , [1 , 0 ])
25232514
25242515 if input_tensor .qnn_params :
2525- out = _qnn .op .dense (
2526- in_expr ,
2516+ # Dequantize input and weight (OC remapped from axis 0 to 1)
2517+ in_f32 = self .dequantize (in_expr , input_tensor )
2518+ weight_axis = weight_tensor .qnn_params ["axis" ]
2519+ if weight_axis != 0 :
2520+ raise tvm .error .OpAttributeInvalid (
2521+ f"FC weight QuantizedDimension() must be 0 (output-channel "
2522+ f"axis in [OC,IC] layout), got { weight_axis } "
2523+ )
2524+ w_f32 = relax .op .dequantize (
25272525 weight_expr ,
2528- input_zero_point = input_tensor .qnn_params ["zero_point" ],
2529- kernel_zero_point = weight_tensor .qnn_params ["zero_point" ],
2530- input_scale = input_tensor .qnn_params ["scale" ],
2531- kernel_scale = weight_tensor .qnn_params ["scale" ],
2532- units = weight_shape [0 ],
2533- out_dtype = "int64" if output_tensor_type_str == "int16" else "int32" ,
2526+ scale = weight_tensor .qnn_params ["scale" ],
2527+ zero_point = weight_tensor .qnn_params ["zero_point" ],
2528+ axis = 1 ,
25342529 )
2530+ out = relax .op .matmul (in_f32 , w_f32 )
25352531 else :
25362532 out = relax .op .matmul (in_expr , weight_expr )
25372533
@@ -2555,27 +2551,27 @@ def convert_fully_connected(self, op):
25552551 dtype = bias_tensor_type_str ,
25562552 source_name = bias_tensor .tensor .Name (),
25572553 )
2554+ if bias_tensor .qnn_params :
2555+ bias_expr = self .dequantize (bias_expr , bias_tensor )
2556+ elif input_tensor .qnn_params and bias_tensor_type in (
2557+ TensorType .INT32 ,
2558+ TensorType .INT64 ,
2559+ ):
2560+ bias_scale = relax .op .multiply (
2561+ input_tensor .qnn_params ["scale" ],
2562+ weight_tensor .qnn_params ["scale" ],
2563+ )
2564+ bias_expr = relax .op .dequantize (
2565+ bias_expr ,
2566+ scale = bias_scale ,
2567+ zero_point = relax .const (0 , "int32" ),
2568+ axis = 0 ,
2569+ )
25582570 out = relax .op .add (out , bias_expr )
25592571
2560- # Finally if the dense is quantized. Add a requantize at the end .
2572+ # Finally if the dense is quantized. Quantize the output .
25612573 if output_tensor .qnn_params :
2562- data_scale = input_tensor .qnn_params ["scale" ]
2563- weight_scale = weight_tensor .qnn_params ["scale" ]
2564- data_scale_val = get_scalar_from_constant (data_scale )
2565- weight_scale_val = get_scalar_from_constant (weight_scale )
2566- new_input_scale_val = data_scale_val * weight_scale_val
2567- new_input_scale = relax .const (new_input_scale_val , "float32" )
2568- new_input_zero_point = relax .const (0 , "int32" )
2569-
2570- # Requantize
2571- out = _qnn .op .requantize (
2572- out ,
2573- input_scale = new_input_scale ,
2574- input_zero_point = new_input_zero_point ,
2575- output_scale = output_tensor .qnn_params ["scale" ],
2576- output_zero_point = output_tensor .qnn_params ["zero_point" ],
2577- out_dtype = output_tensor_type_str ,
2578- )
2574+ out = self .quantize (out , output_tensor )
25792575
25802576 # Call activation function
25812577 output_scale_val = get_scalar_from_constant (output_tensor .qnn_params ["scale" ])
@@ -2794,7 +2790,19 @@ def convert_conv(self, op, conv_type):
27942790 # After transpose to HWIO: [KH, KW, IC, OC]
27952791 # QuantizedDimension() == 0 (OC in original) → axis 3 in HWIO.
27962792 weight_axis = weight_tensor .qnn_params ["axis" ]
2797- if not is_depthwise_conv :
2793+ if is_depthwise_conv :
2794+ if weight_axis != 0 :
2795+ raise tvm .error .OpNotImplemented (
2796+ "Per-channel quantized depthwise convolution is not supported "
2797+ "because the channel axis changes semantics after the "
2798+ "[1,KH,KW,C*M] → [KH,KW,C,M] reshape."
2799+ )
2800+ else :
2801+ if weight_axis != 0 :
2802+ raise tvm .error .OpAttributeInvalid (
2803+ f"Conv2D weight QuantizedDimension() must be 0 (output-channel "
2804+ f"axis in [OC,KH,KW,IC] layout), got { weight_axis } "
2805+ )
27982806 weight_axis = 3
27992807 w_f32 = relax .op .dequantize (
28002808 weight_expr ,
@@ -2836,7 +2844,10 @@ def convert_conv(self, op, conv_type):
28362844 ):
28372845 bias_expr = relax .op .dequantize (
28382846 bias_expr ,
2839- scale = input_tensor .qnn_params ["scale" ],
2847+ scale = relax .op .multiply (
2848+ input_tensor .qnn_params ["scale" ],
2849+ weight_tensor .qnn_params ["scale" ],
2850+ ),
28402851 zero_point = relax .const (0 , "int32" ),
28412852 axis = 0 ,
28422853 )
@@ -4328,25 +4339,27 @@ def convert_transpose_conv(self, op):
43284339 padding = (0 , 0 , 0 , 0 )
43294340
43304341 if input_tensor .qnn_params :
4331- input_zero_point = input_tensor .qnn_params ["zero_point" ]
4332- kernel_zero_point = weights_tensor .qnn_params ["zero_point" ]
4333- input_scale = input_tensor .qnn_params ["scale" ]
4334- kernel_scale = weights_tensor .qnn_params ["scale" ]
4335- out_dtype = "int64" if output_tensor_type_str == "int16" else "int32"
4336- out = _qnn .op .conv2d_transpose (
4337- in_expr ,
4342+ in_f32 = self .dequantize (in_expr , input_tensor )
4343+ weight_axis = weights_tensor .qnn_params ["axis" ]
4344+ if weight_axis != 0 :
4345+ raise tvm .error .OpAttributeInvalid (
4346+ f"TransposeConv weight QuantizedDimension() must be 0 "
4347+ f"(output-channel axis in OHWI layout), got { weight_axis } "
4348+ )
4349+ w_f32 = relax .op .dequantize (
43384350 weight_expr_iohw ,
4339- input_zero_point ,
4340- kernel_zero_point ,
4341- input_scale ,
4342- kernel_scale ,
4351+ scale = weights_tensor .qnn_params ["scale" ],
4352+ zero_point = weights_tensor .qnn_params ["zero_point" ],
4353+ axis = 1 ,
4354+ )
4355+ out = relax .op .nn .conv2d_transpose (
4356+ in_f32 ,
4357+ w_f32 ,
43434358 strides = (stride_h , stride_w ),
43444359 padding = padding ,
4345- channels = int (out_channels ),
4346- kernel_size = (int (kernel_h ), int (kernel_w )),
43474360 data_layout = "NHWC" ,
43484361 kernel_layout = "IOHW" ,
4349- out_dtype = out_dtype ,
4362+ out_dtype = "float32" ,
43504363 )
43514364 else :
43524365 out = relax .op .nn .conv2d_transpose (
@@ -4378,34 +4391,26 @@ def convert_transpose_conv(self, op):
43784391 dtype = bias_tensor_type_str ,
43794392 source_name = bias_tensor .tensor .Name (),
43804393 )
4381- channel_axis = 3
4382- out = relax .op .nn .bias_add (out , bias_expr , axis = channel_axis )
4394+ if bias_tensor .qnn_params :
4395+ bias_expr = self .dequantize (bias_expr , bias_tensor )
4396+ elif input_tensor .qnn_params and bias_tensor_type in (
4397+ TensorType .INT32 ,
4398+ TensorType .INT64 ,
4399+ ):
4400+ bias_scale = relax .op .multiply (
4401+ input_tensor .qnn_params ["scale" ],
4402+ weights_tensor .qnn_params ["scale" ],
4403+ )
4404+ bias_expr = relax .op .dequantize (
4405+ bias_expr ,
4406+ scale = bias_scale ,
4407+ zero_point = relax .const (0 , "int32" ),
4408+ axis = 0 ,
4409+ )
4410+ out = relax .op .add (out , bias_expr )
43834411
43844412 if output_tensor .qnn_params :
4385- # Calculate the intermediate scale and zero point of the int32 output.
4386- data_scale = input_tensor .qnn_params ["scale" ]
4387- data_scale_val = get_scalar_from_constant (data_scale )
4388-
4389- weight_scale = weights_tensor .qnn_params ["scale" ]
4390- # If weight scale is scalar, it is per-tensor quantization
4391- if isinstance (weight_scale , float ):
4392- weight_scale_val = get_scalar_from_constant (weight_scale )
4393- else :
4394- weight_scale_val = get_tensor_from_constant (weight_scale )
4395-
4396- new_input_scale_val = data_scale_val * weight_scale_val
4397- new_input_scale = relax .const (new_input_scale_val , "float32" )
4398- new_input_zero_point = relax .const (0 , "int32" )
4399-
4400- out = _qnn .op .requantize (
4401- out ,
4402- input_scale = new_input_scale ,
4403- input_zero_point = new_input_zero_point ,
4404- output_scale = output_tensor .qnn_params ["scale" ],
4405- output_zero_point = output_tensor .qnn_params ["zero_point" ],
4406- out_dtype = output_tensor_type_str ,
4407- axis = 3 ,
4408- )
4413+ out = self .quantize (out , output_tensor )
44094414 return out
44104415
44114416 def convert_quantize (self , op ):
@@ -4420,7 +4425,6 @@ def convert_quantize(self, op):
44204425 output_tensors = self .get_output_tensors (op )
44214426 assert len (output_tensors ) == 1 , "output tensors length should be 1"
44224427 output_tensor = output_tensors [0 ]
4423- output_tensor_type_str = self .get_tensor_type_str (output_tensor .tensor .Type ())
44244428
44254429 # The output must be quantized
44264430 assert output_tensor .qnn_params
@@ -4429,9 +4433,8 @@ def convert_quantize(self, op):
44294433 if input_tensor_type_str == "float32" :
44304434 out = self .quantize (in_expr , output_tensor )
44314435 else :
4432- raise tvm .error .OpNotImplemented (
4433- "TFLite QUANTIZE acting as requantize is not supported yet"
4434- )
4436+ in_f32 = self .dequantize (in_expr , input_tensor )
4437+ out = self .quantize (in_f32 , output_tensor )
44354438 return out
44364439
44374440 def convert_dequantize (self , op ):
@@ -4580,23 +4583,11 @@ def convert_detection_postprocess(self, op):
45804583 )
45814584
45824585 if inputs [0 ].qnn_params :
4583- loc_prob = _qnn .op .dequantize (
4584- data = loc_prob ,
4585- input_scale = inputs [0 ].qnn_params ["scale" ],
4586- input_zero_point = inputs [0 ].qnn_params ["zero_point" ],
4587- )
4586+ loc_prob = self .dequantize (loc_prob , inputs [0 ])
45884587 if inputs [1 ].qnn_params :
4589- cls_pred = _qnn .op .dequantize (
4590- data = cls_pred ,
4591- input_scale = inputs [1 ].qnn_params ["scale" ],
4592- input_zero_point = inputs [1 ].qnn_params ["zero_point" ],
4593- )
4588+ cls_pred = self .dequantize (cls_pred , inputs [1 ])
45944589 if inputs [2 ].qnn_params :
4595- anchor_expr = _qnn .op .dequantize (
4596- data = anchor_expr ,
4597- input_scale = inputs [2 ].qnn_params ["scale" ],
4598- input_zero_point = inputs [2 ].qnn_params ["zero_point" ],
4599- )
4590+ anchor_expr = self .dequantize (anchor_expr , inputs [2 ])
46004591
46014592 # loc_prob coords are in yxhw format
46024593 # need to convert to xywh
0 commit comments