Skip to content

Commit e0e82b3

Browse files
author
duccuong.le
committed
feature/update-json - Updated json file support
1 parent e818e31 commit e0e82b3

2 files changed

Lines changed: 14 additions & 7 deletions

File tree

paddleocr_convert/main.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ def __call__(
3333

3434
model_path = self.get_file_path(model_path, save_dir)
3535

36-
model_dir = unzip_file(model_path, save_dir, is_del_raw=is_del_raw)
36+
unzip_result = unzip_file(model_path, save_dir, is_del_raw=is_del_raw)
3737

38-
save_onnx_path = model_dir / f"{Path(model_path).stem}.onnx"
38+
save_onnx_path = unzip_result.get('model_dir', '') / f"{Path(model_path).stem}.onnx"
3939
try:
40-
self.convert_to_onnx(model_dir, save_onnx_path)
40+
self.convert_to_onnx(unzip_result, save_onnx_path)
4141
except ConvertError as e:
4242
raise e
4343

@@ -80,16 +80,20 @@ def get_file_path(file_path: str, save_dir: InputType) -> str:
8080
raise FileExistsError(f"{file_path} does not exist.")
8181
return file_path
8282

83-
def convert_to_onnx(self, model_dir: str, save_onnx_path: str) -> None:
83+
def convert_to_onnx(self, unzip_result: object, save_onnx_path: str) -> None:
8484
"""借助 :code:`paddle2onnx` 工具转换模型为onnx格式
8585
8686
Args:
8787
model_dir (str): 保存paddle格式模型所在目录
8888
save_onnx_path (str): 保存的onnx全路径
8989
"""
90+
model_dir = unzip_result.get("model_dir", "")
91+
my_files = unzip_result.get("my_files", {})
92+
model_filename = 'inference.json' if "inference.json" in my_files else 'inference.pdmodel'
93+
9094
shell_str = (
9195
f"paddle2onnx --model_dir {model_dir} "
92-
"--model_filename inference.pdmodel "
96+
f"--model_filename {model_filename} "
9397
"--params_filename inference.pdiparams "
9498
f"--opset_version {self.opset} "
9599
f"--save_file {save_onnx_path}"

paddleocr_convert/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,17 @@ def unzip_file(file_path: str, save_dir: InputType, is_del_raw: bool = True) ->
7575
model_dir = Path(save_dir) / Path(file_path).stem
7676
mkdir(model_dir)
7777

78-
tar_file_name_list = [".pdiparams", ".pdiparams.info", ".pdmodel"]
78+
tar_file_name_list = [".pdiparams", ".pdiparams.info", ".pdmodel", ".json"]
79+
my_files = set()
7980
with tarfile.open(file_path, "r") as tarObj:
81+
8082
for member in tarObj.getmembers():
8183
filename = None
8284

8385
for tar_file_name in tar_file_name_list:
8486
if member.name.endswith(tar_file_name):
8587
filename = "inference" + tar_file_name
88+
my_files.add(filename)
8689

8790
if filename is None:
8891
continue
@@ -94,7 +97,7 @@ def unzip_file(file_path: str, save_dir: InputType, is_del_raw: bool = True) ->
9497
if is_del_raw:
9598
Path(file_path).unlink()
9699
print(f"The {file_path} has been deleted.")
97-
return model_dir
100+
return {"model_dir": model_dir, "my_files": my_files}
98101

99102

100103
def is_http_url(s: InputType) -> bool:

0 commit comments

Comments
 (0)