Skip to content

Commit 0594ecc

Browse files
authored
Merge pull request #20 from tvone/feature/update-json
feature/update-json - Updated json file support
2 parents: e818e31 + bef5bcb; commit: 0594ecc

2 files changed

Lines changed: 29 additions & 13 deletions

File tree

paddleocr_convert/main.py

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -33,11 +33,11 @@ def __call__(
3333

3434
model_path = self.get_file_path(model_path, save_dir)
3535

36-
model_dir = unzip_file(model_path, save_dir, is_del_raw=is_del_raw)
36+
unzip_result = unzip_file(model_path, save_dir, is_del_raw=is_del_raw)
3737

38-
save_onnx_path = model_dir / f"{Path(model_path).stem}.onnx"
38+
save_onnx_path = unzip_result.get('model_dir', '') / f"{Path(model_path).stem}.onnx"
3939
try:
40-
self.convert_to_onnx(model_dir, save_onnx_path)
40+
self.convert_to_onnx(unzip_result, save_onnx_path)
4141
except ConvertError as e:
4242
raise e
4343

@@ -80,16 +80,22 @@ def get_file_path(file_path: str, save_dir: InputType) -> str:
8080
raise FileExistsError(f"{file_path} does not exist.")
8181
return file_path
8282

83-
def convert_to_onnx(self, model_dir: str, save_onnx_path: str) -> None:
83+
def convert_to_onnx(self, unzip_result: object, save_onnx_path: str) -> None:
8484
"""借助 :code:`paddle2onnx` 工具转换模型为onnx格式
8585
8686
Args:
87-
model_dir (str): 保存paddle格式模型所在目录
87+
unzip_result (object):
88+
- model_dir (Path): 解压后模型保存路径。
89+
- my_files (set): 解压得到的文件名集合。
8890
save_onnx_path (str): 保存的onnx全路径
8991
"""
92+
model_dir = unzip_result.get("model_dir", "")
93+
my_files = unzip_result.get("my_files", {})
94+
model_filename = 'inference.json' if "inference.json" in my_files else 'inference.pdmodel'
95+
9096
shell_str = (
9197
f"paddle2onnx --model_dir {model_dir} "
92-
"--model_filename inference.pdmodel "
98+
f"--model_filename {model_filename} "
9399
"--params_filename inference.pdiparams "
94100
f"--opset_version {self.opset} "
95101
f"--save_file {save_onnx_path}"

paddleocr_convert/utils.py

Lines changed: 17 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -5,14 +5,19 @@
55
import re
66
import tarfile
77
from pathlib import Path
8-
from typing import Union
8+
from typing import TypedDict, Set, Union
99

1010
import requests
1111
from tqdm import tqdm
1212

1313
InputType = Union[str, Path]
1414

1515

16+
class UnzipResult(TypedDict):
17+
model_dir: Path
18+
my_files: Set[str]
19+
20+
1621
class DownloadModelError(Exception):
1722
pass
1823

@@ -52,7 +57,7 @@ def download_file(url: str, save_dir: InputType) -> Path:
5257

5358
save_path = Path(save_dir) / Path(url).name
5459
with tqdm(
55-
total=total_size_in_bytes, unit="iB", unit_scale=True, desc="Downloading"
60+
total=total_size_in_bytes, unit="iB", unit_scale=True, desc="Downloading"
5661
) as pb:
5762
with open(save_path, "wb") as file:
5863
for data in response.iter_content(block_size):
@@ -61,7 +66,7 @@ def download_file(url: str, save_dir: InputType) -> Path:
6166
return save_path
6267

6368

64-
def unzip_file(file_path: str, save_dir: InputType, is_del_raw: bool = True) -> Path:
69+
def unzip_file(file_path: str, save_dir: InputType, is_del_raw: bool = True) -> UnzipResult:
6570
"""解压下载得到的tar模型文件,会自动解压到save_dir下以file_path命名的目录下
6671
6772
Args:
@@ -70,19 +75,24 @@ def unzip_file(file_path: str, save_dir: InputType, is_del_raw: bool = True) ->
7075
is_del_raw (bool, optional): 是否删除原文件. Defaults to True.
7176
7277
Returns:
73-
Path: 解压后模型保存路径
78+
dict:
79+
- model_dir (Path): 解压后模型保存路径
80+
- my_files (set): 解压得到的文件名集合
7481
"""
7582
model_dir = Path(save_dir) / Path(file_path).stem
7683
mkdir(model_dir)
7784

78-
tar_file_name_list = [".pdiparams", ".pdiparams.info", ".pdmodel"]
85+
tar_file_name_list = [".pdiparams", ".pdiparams.info", ".pdmodel", ".json"]
86+
my_files = set()
7987
with tarfile.open(file_path, "r") as tarObj:
88+
8089
for member in tarObj.getmembers():
8190
filename = None
8291

8392
for tar_file_name in tar_file_name_list:
8493
if member.name.endswith(tar_file_name):
8594
filename = "inference" + tar_file_name
95+
my_files.add(filename)
8696

8797
if filename is None:
8898
continue
@@ -94,7 +104,7 @@ def unzip_file(file_path: str, save_dir: InputType, is_del_raw: bool = True) ->
94104
if is_del_raw:
95105
Path(file_path).unlink()
96106
print(f"The {file_path} has been deleted.")
97-
return model_dir
107+
return {"model_dir": model_dir, "my_files": my_files}
98108

99109

100110
def is_http_url(s: InputType) -> bool:
@@ -119,4 +129,4 @@ def is_http_url(s: InputType) -> bool:
119129

120130
if regex.match(str(s)):
121131
return True
122-
return False
132+
return False

0 commit comments

Comments
 (0)