55import re
66import tarfile
77from pathlib import Path
8- from typing import Union
8+ from typing import TypedDict , Set , Union
99
1010import requests
1111from tqdm import tqdm
1212
1313InputType = Union [str , Path ]
1414
1515
16+ class UnzipResult (TypedDict ):
17+ model_dir : Path
18+ my_files : Set [str ]
19+
20+
1621class DownloadModelError (Exception ):
1722 pass
1823
@@ -52,7 +57,7 @@ def download_file(url: str, save_dir: InputType) -> Path:
5257
5358 save_path = Path (save_dir ) / Path (url ).name
5459 with tqdm (
55- total = total_size_in_bytes , unit = "iB" , unit_scale = True , desc = "Downloading"
60+ total = total_size_in_bytes , unit = "iB" , unit_scale = True , desc = "Downloading"
5661 ) as pb :
5762 with open (save_path , "wb" ) as file :
5863 for data in response .iter_content (block_size ):
@@ -61,7 +66,7 @@ def download_file(url: str, save_dir: InputType) -> Path:
6166 return save_path
6267
6368
64- def unzip_file (file_path : str , save_dir : InputType , is_del_raw : bool = True ) -> Path :
69+ def unzip_file (file_path : str , save_dir : InputType , is_del_raw : bool = True ) -> UnzipResult :
6570 """解压下载得到的tar模型文件,会自动解压到save_dir下以file_path命名的目录下
6671
6772 Args:
@@ -70,19 +75,24 @@ def unzip_file(file_path: str, save_dir: InputType, is_del_raw: bool = True) ->
7075 is_del_raw (bool, optional): 是否删除原文件. Defaults to True.
7176
7277 Returns:
73- Path: 解压后模型保存路径
78+ dict:
79+ - model_dir (Path): 解压后模型保存路径
80+ - my_files (set): 解压得到的文件名集合
7481 """
7582 model_dir = Path (save_dir ) / Path (file_path ).stem
7683 mkdir (model_dir )
7784
78- tar_file_name_list = [".pdiparams" , ".pdiparams.info" , ".pdmodel" ]
85+ tar_file_name_list = [".pdiparams" , ".pdiparams.info" , ".pdmodel" , ".json" ]
86+ my_files = set ()
7987 with tarfile .open (file_path , "r" ) as tarObj :
88+
8089 for member in tarObj .getmembers ():
8190 filename = None
8291
8392 for tar_file_name in tar_file_name_list :
8493 if member .name .endswith (tar_file_name ):
8594 filename = "inference" + tar_file_name
95+ my_files .add (filename )
8696
8797 if filename is None :
8898 continue
@@ -94,7 +104,7 @@ def unzip_file(file_path: str, save_dir: InputType, is_del_raw: bool = True) ->
94104 if is_del_raw :
95105 Path (file_path ).unlink ()
96106 print (f"The { file_path } has been deleted." )
97- return model_dir
107+ return { " model_dir" : model_dir , "my_files" : my_files }
98108
99109
100110def is_http_url (s : InputType ) -> bool :
@@ -119,4 +129,4 @@ def is_http_url(s: InputType) -> bool:
119129
120130 if regex .match (str (s )):
121131 return True
122- return False
132+ return False
0 commit comments