@@ -33,11 +33,11 @@ def __call__(
3333
3434 model_path = self .get_file_path (model_path , save_dir )
3535
36- model_dir = unzip_file (model_path , save_dir , is_del_raw = is_del_raw )
36+ unzip_result = unzip_file (model_path , save_dir , is_del_raw = is_del_raw )
3737
38- save_onnx_path = model_dir / f"{ Path (model_path ).stem } .onnx"
38+ save_onnx_path = unzip_result . get ( ' model_dir' , '' ) / f"{ Path (model_path ).stem } .onnx"
3939 try :
40- self .convert_to_onnx (model_dir , save_onnx_path )
40+ self .convert_to_onnx (unzip_result , save_onnx_path )
4141 except ConvertError as e :
4242 raise e
4343
@@ -80,16 +80,20 @@ def get_file_path(file_path: str, save_dir: InputType) -> str:
8080 raise FileExistsError (f"{ file_path } does not exist." )
8181 return file_path
8282
83- def convert_to_onnx (self , model_dir : str , save_onnx_path : str ) -> None :
83+ def convert_to_onnx (self , unzip_result : object , save_onnx_path : str ) -> None :
8484 """借助 :code:`paddle2onnx` 工具转换模型为onnx格式
8585
8686 Args:
8787 model_dir (str): 保存paddle格式模型所在目录
8888 save_onnx_path (str): 保存的onnx全路径
8989 """
90+ model_dir = unzip_result .get ("model_dir" , "" )
91+ my_files = unzip_result .get ("my_files" , {})
92+ model_filename = 'inference.json' if "inference.json" in my_files else 'inference.pdmodel'
93+
9094 shell_str = (
9195 f"paddle2onnx --model_dir { model_dir } "
92- "--model_filename inference.pdmodel "
96+ f "--model_filename { model_filename } "
9397 "--params_filename inference.pdiparams "
9498 f"--opset_version { self .opset } "
9599 f"--save_file { save_onnx_path } "
0 commit comments