Skip to content

Commit bef5bcb

Browse files
author
duccuong.le
committed
refine comments and add type hints
1 parent e0e82b3 commit bef5bcb

2 files changed

Lines changed: 16 additions & 7 deletions

File tree

paddleocr_convert/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,9 @@ def convert_to_onnx(self, unzip_result: object, save_onnx_path: str) -> None:
8484
"""借助 :code:`paddle2onnx` 工具转换模型为onnx格式
8585
8686
Args:
87-
model_dir (str): 保存paddle格式模型所在目录
87+
unzip_result (object):
88+
- model_dir (Path): 解压后模型保存路径。
89+
- my_files (set): 解压得到的文件名集合。
8890
save_onnx_path (str): 保存的onnx全路径
8991
"""
9092
model_dir = unzip_result.get("model_dir", "")

paddleocr_convert/utils.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,19 @@
55
import re
66
import tarfile
77
from pathlib import Path
8-
from typing import Union
8+
from typing import TypedDict, Set, Union
99

1010
import requests
1111
from tqdm import tqdm
1212

1313
InputType = Union[str, Path]
1414

1515

16+
class UnzipResult(TypedDict):
17+
model_dir: Path
18+
my_files: Set[str]
19+
20+
1621
class DownloadModelError(Exception):
1722
pass
1823

@@ -52,7 +57,7 @@ def download_file(url: str, save_dir: InputType) -> Path:
5257

5358
save_path = Path(save_dir) / Path(url).name
5459
with tqdm(
55-
total=total_size_in_bytes, unit="iB", unit_scale=True, desc="Downloading"
60+
total=total_size_in_bytes, unit="iB", unit_scale=True, desc="Downloading"
5661
) as pb:
5762
with open(save_path, "wb") as file:
5863
for data in response.iter_content(block_size):
@@ -61,7 +66,7 @@ def download_file(url: str, save_dir: InputType) -> Path:
6166
return save_path
6267

6368

64-
def unzip_file(file_path: str, save_dir: InputType, is_del_raw: bool = True) -> Path:
69+
def unzip_file(file_path: str, save_dir: InputType, is_del_raw: bool = True) -> UnzipResult:
6570
"""解压下载得到的tar模型文件,会自动解压到save_dir下以file_path命名的目录下
6671
6772
Args:
@@ -70,15 +75,17 @@ def unzip_file(file_path: str, save_dir: InputType, is_del_raw: bool = True) ->
7075
is_del_raw (bool, optional): 是否删除原文件. Defaults to True.
7176
7277
Returns:
73-
Path: 解压后模型保存路径
78+
dict:
79+
- model_dir (Path): 解压后模型保存路径
80+
- my_files (set): 解压得到的文件名集合
7481
"""
7582
model_dir = Path(save_dir) / Path(file_path).stem
7683
mkdir(model_dir)
7784

7885
tar_file_name_list = [".pdiparams", ".pdiparams.info", ".pdmodel", ".json"]
7986
my_files = set()
8087
with tarfile.open(file_path, "r") as tarObj:
81-
88+
8289
for member in tarObj.getmembers():
8390
filename = None
8491

@@ -122,4 +129,4 @@ def is_http_url(s: InputType) -> bool:
122129

123130
if regex.match(str(s)):
124131
return True
125-
return False
132+
return False

0 commit comments

Comments
 (0)