Skip to content

Commit 9c1feac

Browse files
feat(tflite): support bias use int64 dtype
1 parent 6e19d20 commit 9c1feac

5 files changed

Lines changed: 45 additions & 1 deletion

File tree

bin/convert

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ complete_str = """_mgeconvert(){
9494
return
9595
;;
9696
tracedmodule_to_tflite)
97-
words="-i --input -o --output --input_data_type --input_scales --input_zero_points --require_quantize --param_fake_quant --quantize_file_path --graph_name --mtk --end_point --outspec --remove_relu --prefer_same_pad_mode"
97+
words="-i --input -o --output --input_data_type --input_scales --input_zero_points --require_quantize --param_fake_quant --quantize_file_path --graph_name --mtk --end_point --outspec --remove_relu --prefer_same_pad_mode --use_int64_bias"
9898
COMPREPLY=( $(compgen -W "$words" -- $word) )
9999
return
100100
;;
@@ -397,6 +397,7 @@ def init(subparsers):
397397
outspec=args.outspec,
398398
remove_relu=args.remove_relu,
399399
prefer_same_pad_mode=args.prefer_same_pad_mode,
400+
use_int64_bias=args.use_int64_bias,
400401
)
401402
else:
402403
mgeconvert.mge_to_tflite(
@@ -406,6 +407,7 @@ def init(subparsers):
406407
mtk=args.mtk,
407408
outspec=args.outspec,
408409
prefer_same_pad_mode=args.prefer_same_pad_mode,
410+
use_int64_bias=args.use_int64_bias,
409411
)
410412

411413
def tflite_parser(subparsers):
@@ -488,6 +490,12 @@ def init(subparsers):
488490
help="whether prefer to use SAME pad mode for conv op",
489491
)
490492

493+
p.add_argument(
494+
"--use_int64_bias",
495+
action="store_true",
496+
help="whether use int64 as dtype of bias",
497+
)
498+
491499
tflite_parser(subparsers)
492500

493501

mgeconvert/backend/ir_to_tflite/tflite_op.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ def get_shape_param(
233233
"int8": TensorType.INT8,
234234
"int16": TensorType.INT16,
235235
"int32": TensorType.INT32,
236+
"int64": TensorType.INT64,
236237
"qint8_narrow": TensorType.INT8,
237238
}
238239

mgeconvert/converter_ir/ir_transform.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
TanHOpr,
4949
TransposeOpr,
5050
TrueDivOpr,
51+
_ConvOpr,
5152
_PoolOpr,
5253
)
5354
from .ir_tensor import AxisOrder, IRTensor
@@ -121,6 +122,7 @@ class TransformerRule(Enum):
121122
TRANSPOSE_LINEAR_WEIGHT_TO_NHWC = 133
122123
# force fc with no trans for megengine
123124
FC_NO_TRANS = 134
125+
BIAS_ASTYPE_INT64 = 135
124126

125127

126128
def cmp_rules(a, b):
@@ -1537,3 +1539,29 @@ def trans_tensor(tensor):
15371539
if opr.transpose_b and tensor_b.owner_opr == None:
15381540
opr.transpose_b = False
15391541
trans_tensor(tensor_b)
1542+
1543+
1544+
@_register_tranformation_rule(TransformerRule.BIAS_ASTYPE_INT64)
1545+
def _bias_astype_int64(net: IRGraph):
1546+
for opr in net.all_oprs:
1547+
if not isinstance(opr, (MatMulOpr, _ConvOpr)):
1548+
continue
1549+
bias = None
1550+
if isinstance(opr, MatMulOpr) and len(opr.inp_tensors) == 3:
1551+
bias = opr.inp_tensors[2]
1552+
elif isinstance(opr, Deconv2dOpr) and len(opr.inp_tensors) > 3:
1553+
if (
1554+
opr.inp_tensors[0].shape == [4] and len(opr.inp_tensors) == 4
1555+
): # shape as input
1556+
bias = opr.inp_tensors[-1]
1557+
if len(opr.inp_tensors) == 4:
1558+
bias = opr.inp_tensors[-1]
1559+
elif isinstance(opr, (Conv2dOpr, ConvRelu2dOpr)) and len(opr.inp_tensors) == 3:
1560+
bias = opr.inp_tensors[-1]
1561+
if bias is not None and bias.scale is not None:
1562+
bias.set_qparams(
1563+
scale=bias.scale,
1564+
zero_point=bias.zero_point,
1565+
q_dtype="int64",
1566+
np_dtype="int64",
1567+
)

mgeconvert/converters/mge_to_tflite.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def mge_to_tflite(
2424
disable_nhwc=False,
2525
outspec=None,
2626
prefer_same_pad_mode=False,
27+
use_int64_bias=False,
2728
):
2829
"""
2930
Convert megengine model to TFLite,
@@ -70,6 +71,9 @@ def mge_to_tflite(
7071
transformer_options.append(TransformerRule.DECONV_ADD_ZERO_BIAS,)
7172
transformer_options.append(TransformerRule.FUSE_FOR_DECONV_BIAS,)
7273

74+
if use_int64_bias:
75+
transformer_options.append(TransformerRule.BIAS_ASTYPE_INT64)
76+
7377
transformer = IRTransform(transformer_options)
7478
transformed_irgraph = transformer.transform(irgraph)
7579

mgeconvert/converters/tm_to_tflite.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def tracedmodule_to_tflite(
3939
remove_relu=False,
4040
prefer_same_pad_mode=False,
4141
disable_nhwc=False,
42+
use_int64_bias=False,
4243
):
4344
"""
4445
Convert traced model to TFLite,
@@ -89,6 +90,8 @@ def tracedmodule_to_tflite(
8990
transformer_options.append(TransformerRule.DECONV_ADD_ZERO_BIAS,)
9091
if remove_relu:
9192
transformer_options.append(TransformerRule.REMOVE_TFLITE_RELU,)
93+
if use_int64_bias:
94+
transformer_options.append(TransformerRule.BIAS_ASTYPE_INT64)
9295

9396
transformer = IRTransform(transformer_options)
9497
transformed_irgraph = transformer.transform(irgraph)

0 commit comments

Comments
 (0)