-
Notifications
You must be signed in to change notification settings - Fork 958
Expand file tree
/
Copy pathIRToCode.py
More file actions
129 lines (105 loc) · 4.41 KB
/
Copy pathIRToCode.py
File metadata and controls
129 lines (105 loc) · 4.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import sys as _sys
import google.protobuf.text_format as text_format
from six import text_type as _text_type
def _require_dst_weight_path(args, message):
    """Raise ``ValueError(message)`` unless a destination weight path was given."""
    # Explicit check rather than `assert`, which is removed under `python -O`.
    if not args.dstWeightPath:
        raise ValueError(message)


def _ir_source(args):
    """Return the IR model path alone, or an (IR model, IR weight) pair when weights were given."""
    if args.IRWeightPath is None:
        return args.IRModelPath
    return (args.IRModelPath, args.IRWeightPath)


def _convert(args):
    """Generate destination-framework code (and optionally weights) from an IR model.

    Selects the emitter for ``args.dstFramework``, validates the path
    arguments that emitter needs (before importing it, so bad arguments fail
    fast), then calls ``emitter.run`` to write the output.

    Args:
        args: Parsed command-line namespace (see ``_get_parser``). Reads
            ``dstFramework``, ``IRModelPath``, ``IRWeightPath``,
            ``dstModelPath``, ``dstWeightPath`` and ``phase``.

    Returns:
        0 on success; ``_main`` uses it as the process exit code.

    Raises:
        ValueError: a weight path required by the chosen emitter is missing,
            or the destination framework is unknown.
        NotImplementedError: the destination framework has no emitter
            (``coreml``, ``caffe2``), or ONNX is requested without IR weights.
    """
    framework = args.dstFramework
    has_ir_weights = args.IRWeightPath is not None

    if framework == 'caffe':
        if has_ir_weights:
            _require_dst_weight_path(
                args, "Caffe emitter needs argument [dstWeightPath(dw)] when IR weights are given.")
        from mmdnn.conversion.caffe.caffe_emitter import CaffeEmitter
        emitter = CaffeEmitter(_ir_source(args))

    elif framework == 'keras':
        from mmdnn.conversion.keras.keras2_emitter import Keras2Emitter
        # NOTE(review): IRWeightPath is not passed to the Keras emitter here,
        # unlike the other emitters -- confirm whether that is intended.
        emitter = Keras2Emitter(args.IRModelPath)

    elif framework == 'tensorflow':
        from mmdnn.conversion.tensorflow.tensorflow_emitter import TensorflowEmitter
        # Without IR weights only the network architecture is converted.
        emitter = TensorflowEmitter(_ir_source(args))

    elif framework == 'cntk':
        from mmdnn.conversion.cntk.cntk_emitter import CntkEmitter
        emitter = CntkEmitter(_ir_source(args))

    elif framework == 'coreml':
        raise NotImplementedError("CoreML emitter is not finished yet.")

    elif framework == 'caffe2':
        # Accepted by the CLI parser, but no emitter exists for it here.
        raise NotImplementedError("Caffe2 emitter is not supported yet.")

    elif framework == 'pytorch':
        if not args.IRWeightPath:
            raise ValueError("PyTorch emitter needs argument [IRWeightPath(iw)].")
        _require_dst_weight_path(args, "Need to set a target weight filename.")
        from mmdnn.conversion.pytorch.pytorch_emitter import PytorchEmitter
        emitter = PytorchEmitter((args.IRModelPath, args.IRWeightPath))

    elif framework == 'mxnet':
        if has_ir_weights:
            _require_dst_weight_path(
                args,
                "MXNet emitter needs argument [dstWeightPath(dw)], like -dw mxnet_converted-0000.param")
        from mmdnn.conversion.mxnet.mxnet_emitter import MXNetEmitter
        if has_ir_weights:
            emitter = MXNetEmitter((args.IRModelPath, args.IRWeightPath, args.dstWeightPath))
        else:
            emitter = MXNetEmitter(args.IRModelPath)

    elif framework == 'onnx':
        if not has_ir_weights:
            raise NotImplementedError("ONNX emitter needs IR weight file")
        from mmdnn.conversion.onnx.onnx_emitter import OnnxEmitter
        emitter = OnnxEmitter(args.IRModelPath, args.IRWeightPath)

    elif framework == 'darknet':
        if has_ir_weights:
            _require_dst_weight_path(
                args, "Darknet emitter needs argument [dstWeightPath(dw)] when IR weights are given.")
        from mmdnn.conversion.darknet.darknet_emitter import DarknetEmitter
        emitter = DarknetEmitter(_ir_source(args))

    else:
        raise ValueError("Unsupported destination framework: {}".format(framework))

    emitter.run(args.dstModelPath, args.dstWeightPath, args.phase)
    return 0
def _get_parser():
    """Build the command-line parser for IR-to-code conversion.

    Returns:
        argparse.ArgumentParser: parser producing the namespace consumed by
        ``_convert``.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Convert IR model file formats to other format.')
    parser.add_argument(
        '--phase',
        type=_text_type,
        choices=['train', 'test'],
        default='test',
        help='Convert phase (train/test) for destination toolkits.')
    parser.add_argument(
        '--dstFramework', '-f',
        type=_text_type,
        choices=['caffe', 'caffe2', 'cntk', 'mxnet', 'keras', 'tensorflow', 'coreml', 'pytorch', 'onnx', 'darknet'],
        required=True,
        help='Format of the destination model to generate.')
    parser.add_argument(
        '--IRModelPath', '-n', '-in',
        type=_text_type,
        required=True,
        help='Path to the IR network structure file.')
    parser.add_argument(
        '--IRWeightPath', '-w', '-iw',
        type=_text_type,
        required=False,
        default=None,
        help='Path to the IR network weight file.')
    parser.add_argument(
        '--dstModelPath', '-d', '-o',
        type=_text_type,
        required=True,
        help='Path to save the destination model')
    # Destination weight file: not only MXNet uses it -- Caffe, Darknet and
    # PyTorch conversions also require it.
    parser.add_argument(
        '--dstWeightPath', '-dw', '-ow',
        type=_text_type,
        default=None,
        help='Path to save the destination weight (required by MXNet, Caffe and '
             'Darknet when IR weights are given, and always by PyTorch).')
    return parser
def _main():
    """Run the IR-to-code converter from the command line and exit with its status."""
    exit_code = _convert(_get_parser().parse_args())
    # sys.exit() treats a non-int argument as an error message and exits with
    # status 1, so force the converter's result to an int first.
    _sys.exit(int(exit_code))


if __name__ == '__main__':
    _main()