Skip to content

Commit bb7434a

Browse files
committed
feat(download): async, status; download.py: use download_and_extract
1 parent 77e9078 commit bb7434a

4 files changed

Lines changed: 44 additions & 32 deletions

File tree

api/download.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -70,27 +70,13 @@ async def download_model(
7070
if not filename:
7171
filename = normalized_model_id + ".tar.zst"
7272
model_file = os.path.join(MODELS_DIR, filename)
73-
storage = Storage(model_url, default_path=normalized_model_id + ".tar.zst")
73+
storage = Storage(
74+
model_url, default_path=normalized_model_id + ".tar.zst", status=status
75+
)
7476
exists = storage.file_exists()
7577
if exists:
76-
storage.download_file(model_file)
77-
# os.mkdir(id)
78-
# Path(id).mkdir(parents=True, exist_ok=False)
7978
model_dir = os.path.join(MODELS_DIR, normalized_model_id)
80-
os.mkdir(model_dir)
81-
subprocess.run(
82-
[
83-
"tar",
84-
"--use-compress-program=unzstd",
85-
"-C",
86-
model_dir,
87-
"-xvf",
88-
model_file,
89-
],
90-
check=True,
91-
)
92-
subprocess.run(["ls", "-l"])
93-
os.remove(model_file)
79+
await asyncio.to_thread(storage.download_and_extract, model_file, model_dir)
9480
else:
9581
if checkpoint_url:
9682
download_checkpoint(checkpoint_url)

api/utils/storage/BaseStorage.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,13 @@
55

66

77
class BaseArchive(ABC):
8-
def __init__(self, path):
8+
def __init__(self, path, status=None):
99
self.path = path
10+
self.status = status
11+
12+
def updateStatus(self, type, progress):
13+
if hasattr(self, "status"):
14+
self.status.update(type, progress)
1015

1116
def extract(self):
1217
print("TODO")
@@ -23,6 +28,7 @@ def test(path):
2328
return re.search(r"\.tar\.zstd?$", path)
2429

2530
def extract(self, dir, dry_run=False):
31+
self.updateStatus("extract", 0)
2632
if not dir:
2733
dir = os.path.dirname(self.path)
2834
base, ext, subext = self.splitext()
@@ -41,8 +47,10 @@ def extract(self, dir, dry_run=False):
4147
],
4248
check=True,
4349
)
50+
subprocess.run(["ls", "-l"])
4451
os.remove(self.path)
4552

53+
self.updateStatus("extract", 1)
4654
return dir # , base, ext, subext
4755

4856

@@ -63,6 +71,11 @@ def test(url):
6371

6472
def __init__(self, url, **kwargs):
6573
self.url = url
74+
self.status = kwargs.get("status", None)
75+
76+
def updateStatus(self, type, progress):
77+
if hasattr(self, "status"):
78+
self.status.update(type, progress)
6679

6780
def splitext(self):
6881
base, ext = os.path.splitext(self.url)
@@ -77,7 +90,7 @@ def download_file(self, dest):
7790
"""Download the file to `dest`"""
7891
pass
7992

80-
def download_and_extract(self, fname, dry_run=False):
93+
def download_and_extract(self, fname, dir=None, dry_run=False):
8194
"""
8295
Downloads the file, and if it's an archive, extract it too. Returns
8396
the filename if not, or directory name (fname without extension) if
@@ -86,12 +99,11 @@ def download_and_extract(self, fname, dry_run=False):
8699
if not fname:
87100
fname = self.get_filename()
88101

89-
dir = None
90-
archive = Archive(fname, dry_run=dry_run)
102+
archive = Archive(fname, status=self.status)
91103
if archive:
92104
# TODO, streaming pipeline
93105
self.download_file(fname)
94-
return archive.extract()
106+
return archive.extract(dir)
95107
else:
96108
self.download_file(fname)
97109
return fname

api/utils/storage/HTTPStorage.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def test(url):
1616
return re.search(r"^https?://", url)
1717

1818
def __init__(self, url, **kwargs):
19-
self.url = url
19+
super().__init__(url, **kwargs)
2020

2121
def upload_file(self, source, dest):
2222
raise RuntimeError("HTTP PUT not implemented yet")
@@ -37,6 +37,9 @@ def download_file(self, fname):
3737
unit_scale=True,
3838
unit_divisor=1024,
3939
) as bar:
40+
total_written = 0
4041
for data in resp.iter_content(chunk_size=1024):
4142
size = file.write(data)
4243
bar.update(size)
44+
total_written += size
45+
self.updateStatus("download", total_written / total)

api/utils/storage/S3Storage.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def test(url):
2424
return re.search(r"^(https?\+)?s3://", url)
2525

2626
def __init__(self, url, **kwargs):
27-
self.url = url
27+
super().__init__(url, **kwargs)
2828

2929
if url.startswith("s3://"):
3030
url = "https://" + url[5:]
@@ -90,10 +90,16 @@ def upload_file(self, source, dest):
9090
upload_start = get_now()
9191
file_size = os.stat(source).st_size
9292
with tqdm(total=file_size, unit="B", unit_scale=True, desc="Uploading") as bar:
93+
total_transferred = 0
94+
95+
def callback(bytes_transferred):
96+
nonlocal total_transferred
97+
bar.update(bytes_transferred),
98+
total_transferred += bytes_transferred
99+
self.updateStatus("upload", total_transferred / file_size)
100+
93101
result = self.bucket().upload_file(
94-
Filename=source,
95-
Key=dest,
96-
Callback=lambda bytes_transferred: bar.update(bytes_transferred),
102+
Filename=source, Key=dest, Callback=callback
97103
)
98104
print(result)
99105
upload_total = get_now() - upload_start
@@ -110,10 +116,15 @@ def download_file(self, dest):
110116
with tqdm(
111117
total=object.content_length, unit="B", unit_scale=True, desc="Downloading"
112118
) as bar:
113-
object.download_file(
114-
Filename=dest,
115-
Callback=lambda bytes_transffered: bar.update(bytes_transffered),
116-
)
119+
total_transferred = 0
120+
121+
def callback(bytes_transferred):
122+
nonlocal total_transferred
123+
bar.update(bytes_transferred),
124+
total_transferred += bytes_transferred
125+
self.updateStatus("download", total_transferred / object.content_length)
126+
127+
object.download_file(Filename=dest, Callback=callback)
117128

118129
def file_exists(self):
119130
# res = self.s3client().list_objects_v2(

0 commit comments

Comments
 (0)