-
Notifications
You must be signed in to change notification settings - Fork 958
Expand file tree
/
Copy pathdump_code.py
More file actions
86 lines (69 loc) · 2.72 KB
/
Copy pathdump_code.py
File metadata and controls
86 lines (69 loc) · 2.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import sys as _sys
from six import text_type as _text_type
import sys
import imp
import os.path
def dump_code(framework, network_filepath, weight_filepath, dump_filepath, dump_tag):
    """Load a generated network module and save it with a framework saver.

    Args:
        framework: Name of the target framework. Selects the saver module
            ``mmdnn.conversion.<framework>.saver``.
        network_filepath: Path to the generated network code, with or without
            a trailing ``.py``. It is passed to ``save_model`` with the
            ``.py`` suffix stripped.
        weight_filepath: Weight file path, forwarded to ``save_model``.
        dump_filepath: Output path, forwarded to ``save_model``.
        dump_tag: Forwarded to ``save_model`` only when ``framework`` is
            ``'tensorflow'``; ignored for every other framework.

    Returns:
        0 after ``save_model`` returns.

    Raises:
        NotImplementedError: If no saver exists for ``framework``. This is
            raised before the network file is executed and before
            ``sys.path`` is modified.
    """
    import importlib

    supported = ('caffe', 'cntk', 'keras', 'mxnet', 'pytorch', 'darknet',
                 'tensorflow', 'onnx')
    # Validate up front. Previously an unsupported framework still executed
    # the network file and prepended to sys.path before failing.
    if framework not in supported:
        raise NotImplementedError("{} saver is not finished yet.".format(framework))

    if network_filepath.endswith('.py'):
        network_filepath = network_filepath[:-3]
    # Put the network file's directory first on the import path. Presumably
    # this lets the generated code import modules that sit next to it.
    sys.path.insert(0, os.path.dirname(os.path.abspath(network_filepath)))
    # NOTE(review): `imp` was removed in Python 3.12. importlib.util is the
    # replacement once Python 2 support (six) is dropped.
    MainModel = imp.load_source('MainModel', network_filepath + '.py')

    # Same saver modules the original if/elif chain imported, in the same order.
    saver = importlib.import_module('mmdnn.conversion.{}.saver'.format(framework))
    if framework == 'tensorflow':
        # Only the tensorflow saver takes the extra dump_tag argument.
        saver.save_model(MainModel, network_filepath, weight_filepath, dump_filepath, dump_tag)
    else:
        saver.save_model(MainModel, network_filepath, weight_filepath, dump_filepath)
    return 0
def _get_parser():
    """Build the command-line parser for the model-dumping tool.

    Returns:
        An ``argparse.ArgumentParser`` defining ``--framework``,
        ``--inputNetwork``, ``--inputWeight``, ``--outputModel`` and
        ``--dump_tag``.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Dump the model code into target model.')
    parser.add_argument(
        '-f', '--framework', type=_text_type,
        choices=["caffe", "cntk", "mxnet", "keras", "tensorflow", 'pytorch', 'onnx', 'darknet'],
        required=True,
        # The option is required and picks the saver, so there is no
        # auto-detection and it does not describe a source format.
        help='Target framework to save the model into.'
    )
    parser.add_argument(
        '-in', '--inputNetwork',
        type=_text_type,
        required=True,
        help='Path to the model network architecture file.')
    parser.add_argument(
        '-iw', '--inputWeight',
        type=_text_type,
        required=True,
        help='Path to the model network weight file.')
    parser.add_argument(
        '-o', '-om', '--outputModel',
        type=_text_type,
        required=True,
        help='Path to save the target model.')
    parser.add_argument(
        '--dump_tag',
        type=_text_type,
        default=None,
        help='Tensorflow model dump type (only used when --framework is tensorflow).',
        choices=['SERVING', 'TRAINING'])
    return parser
def _main():
    """Run the command-line tool.

    Parses the command line, calls ``dump_code`` with the parsed options, and
    exits the process with ``dump_code``'s return value as the status code.
    """
    options = _get_parser().parse_args()
    status = dump_code(
        options.framework,
        options.inputNetwork,
        options.inputWeight,
        options.outputModel,
        options.dump_tag,
    )
    _sys.exit(int(status))


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not on import.
    _main()