55import re
66import tarfile
77from pathlib import Path
8- from typing import Union
8+ from typing import TypedDict , Set , Union
99
1010import requests
1111from tqdm import tqdm
1212
1313InputType = Union [str , Path ]
1414
1515
16+ class UnzipResult (TypedDict ):
17+ model_dir : Path
18+ my_files : Set [str ]
19+
20+
1621class DownloadModelError (Exception ):
1722 pass
1823
@@ -52,7 +57,7 @@ def download_file(url: str, save_dir: InputType) -> Path:
5257
5358 save_path = Path (save_dir ) / Path (url ).name
5459 with tqdm (
55- total = total_size_in_bytes , unit = "iB" , unit_scale = True , desc = "Downloading"
60+ total = total_size_in_bytes , unit = "iB" , unit_scale = True , desc = "Downloading"
5661 ) as pb :
5762 with open (save_path , "wb" ) as file :
5863 for data in response .iter_content (block_size ):
@@ -61,7 +66,7 @@ def download_file(url: str, save_dir: InputType) -> Path:
6166 return save_path
6267
6368
64- def unzip_file (file_path : str , save_dir : InputType , is_del_raw : bool = True ) -> Path :
69+ def unzip_file (file_path : str , save_dir : InputType , is_del_raw : bool = True ) -> UnzipResult :
6570 """解压下载得到的tar模型文件,会自动解压到save_dir下以file_path命名的目录下
6671
6772 Args:
@@ -70,15 +75,17 @@ def unzip_file(file_path: str, save_dir: InputType, is_del_raw: bool = True) ->
7075 is_del_raw (bool, optional): 是否删除原文件. Defaults to True.
7176
7277 Returns:
73- Path: 解压后模型保存路径
78+ dict:
79+ - model_dir (Path): 解压后模型保存路径
80+ - my_files (set): 解压得到的文件名集合
7481 """
7582 model_dir = Path (save_dir ) / Path (file_path ).stem
7683 mkdir (model_dir )
7784
7885 tar_file_name_list = [".pdiparams" , ".pdiparams.info" , ".pdmodel" , ".json" ]
7986 my_files = set ()
8087 with tarfile .open (file_path , "r" ) as tarObj :
81-
88+
8289 for member in tarObj .getmembers ():
8390 filename = None
8491
@@ -122,4 +129,4 @@ def is_http_url(s: InputType) -> bool:
122129
123130 if regex .match (str (s )):
124131 return True
125- return False
132+ return False
0 commit comments