Skip to content

Commit ee2d835

Browse files
committed
chore(storage): abc improvements, initial Archive work
1 parent b08c893 commit ee2d835

4 files changed

Lines changed: 123 additions & 3 deletions

File tree

api/utils/storage/BaseStorage.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import os
2+
import re
3+
import subprocess
4+
from abc import ABC, abstractmethod
5+
6+
7+
class BaseArchive(ABC):
    """Base class for archive handlers bound to a local archive path."""

    def __init__(self, path):
        """Remember `path`, the location of the archive file."""
        self.path = path

    def extract(self, dir=None, dry_run=False):
        """Placeholder extraction hook; concrete archive types override it.

        The parameters mirror the `TarZstdArchive.extract` override so every
        archive class can be called the same way. This base version ignores
        them, prints "TODO" and returns None.
        """
        print("TODO")

    def splitext(self):
        """Split the last two extensions off `self.path`.

        Returns:
            A `(base, ext, subext)` tuple where `ext` is the final extension
            and `subext` the one before it, both including the leading dot:
            "dir/file.tar.zst" -> ("dir/file", ".zst", ".tar").
        """
        base, ext = os.path.splitext(self.path)
        base, subext = os.path.splitext(base)
        return base, ext, subext
18+
19+
20+
class TarZstdArchive(BaseArchive):
    """Archive handler for zstd-compressed tarballs (`.tar.zst` / `.tar.zstd`)."""

    @staticmethod
    def test(path):
        """Return a regex match if `path` ends in `.tar.zst` or `.tar.zstd`, else None."""
        return re.search(r"\.tar\.zstd?$", path)

    def extract(self, dir=None, dry_run=False):
        """Unpack the archive into a new directory named after it.

        Args:
            dir: Parent directory for the extraction directory. When falsy,
                the directory containing the archive is used.
            dry_run: When true, only compute and return the target directory;
                nothing is created, extracted or deleted.

        Returns:
            The extraction directory: `dir` joined with the archive's file
            name stripped of its last two extensions.

        Raises:
            FileExistsError: The extraction directory already exists.
            subprocess.CalledProcessError: `tar` exited with a non-zero status.
        """
        if not dir:
            dir = os.path.dirname(self.path)
        base, _ext, _subext = self.splitext()
        # `base` still carries the archive's own directory part. Joining it
        # unchanged duplicates a relative directory ("d" + "d/file") and, for
        # an absolute path, makes os.path.join discard `dir` altogether, so
        # only the bare name is placed under `dir`.
        dir = os.path.join(dir, os.path.basename(base))

        if not dry_run:
            os.mkdir(dir)
            subprocess.run(
                [
                    "tar",
                    "--use-compress-program=unzstd",
                    "-C",
                    dir,
                    "-xvf",
                    self.path,
                ],
                check=True,
            )
            # The archive is deleted once it has been unpacked successfully.
            os.remove(self.path)

        return dir
47+
48+
49+
# Archive handler classes that Archive() probes, in list order, via their test().
archiveClasses = [TarZstdArchive]
50+
51+
52+
def Archive(path, **kwargs):
    """Build an archive handler for `path`, or return None if none applies.

    The classes in `archiveClasses` are probed in order through their
    `test(path)`; the first one that matches is instantiated with `path`
    and any extra keyword arguments.
    """
    handler_cls = next(
        (candidate for candidate in archiveClasses if candidate.test(path)),
        None,
    )
    if handler_cls is None:
        return None
    return handler_cls(path, **kwargs)
56+
57+
58+
class BaseStorage(ABC):
    """Abstract base for storage backends addressed by a URL.

    Subclasses must provide `test` and `download_file`.
    """

    @staticmethod
    @abstractmethod
    def test(url):
        """Return a regex match if `url` starts with http:// or https://, else None.

        Declared abstract, so every concrete backend has to define its own
        `test`.
        """
        return re.search(r"^https?://", url)

    def __init__(self, url, **kwargs):
        """Store `url`; extra keyword arguments are accepted and ignored here."""
        self.url = url

    def splitext(self):
        """Split the last two extensions off `self.url`.

        Returns:
            A `(base, ext, subext)` tuple where `ext` is the final extension
            and `subext` the one before it, both including the leading dot.
        """
        base, ext = os.path.splitext(self.url)
        base, subext = os.path.splitext(base)
        return base, ext, subext

    def get_filename(self):
        """Return the part of the URL after its last "/"."""
        return self.url.split("/").pop()

    @abstractmethod
    def download_file(self, dest):
        """Download the file to `dest`"""
        pass

    def download_and_extract(self, fname=None, dry_run=False):
        """
        Downloads the file, and if it's an archive, extract it too. Returns
        the filename if not, or directory name (fname without extension) if
        it was.

        Args:
            fname: Local file name to download to. When falsy, the name is
                taken from the URL via `get_filename()`.
            dry_run: When true, nothing is downloaded and the archive (if
                any) is asked to do a dry-run extraction, so only the
                resulting name is computed.
        """
        if not fname:
            fname = self.get_filename()

        # `dry_run` is an extraction option, not a constructor option: the
        # archive classes only accept the path when they are built.
        archive = Archive(fname)

        # TODO, streaming pipeline
        if not dry_run:
            self.download_file(fname)

        if archive:
            # None lets the archive pick the directory that holds `fname`.
            return archive.extract(None, dry_run=dry_run)
        return fname
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import unittest
2+
from . import Storage, S3Storage, HTTPStorage
3+
4+
5+
class BaseStorageTest(unittest.TestCase):
    """Checks for helpers shared by every storage backend."""

    def test_get_filename(self):
        """The file name is the last path segment of the URL."""
        url = "http://host.com/dir/file.tar.zst"
        filename = Storage(url).get_filename()
        self.assertEqual(filename, "file.tar.zst")
9+
10+
class Download_and_extract(unittest.TestCase):
    """Dry-run checks for `download_and_extract`."""

    def test_file_only(self):
        """A URL that is not an archive yields the plain file name."""
        storage = Storage("http://host.com/dir/file.bin")
        result = storage.download_and_extract(dry_run=True)
        self.assertEqual(result, "file.bin")

    def test_file_archive(self):
        """An archive URL yields the name of the extraction directory."""
        storage = Storage("http://host.com/dir/file.tar.zst")
        # extract() returns just the target directory, so the result is a
        # single string rather than a (dir, base, ext, subext) tuple.
        result = storage.download_and_extract(dry_run=True)
        self.assertEqual(result, "file")

api/utils/storage/HTTPStorage.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
import time
44
import requests
55
from tqdm import tqdm
6+
from .BaseStorage import BaseStorage
67

78

89
def get_now():
    """Return the current Unix time rounded to whole milliseconds."""
    millis = time.time() * 1000
    return round(millis)
1011

1112

12-
class HTTPStorage:
13+
class HTTPStorage(BaseStorage):
1314
@staticmethod
1415
def test(url):
1516
return re.search(r"^https?://", url)

api/utils/storage/S3Storage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
from tqdm import tqdm
77
from botocore.client import Config
8-
8+
from .BaseStorage import BaseStorage
99

1010
AWS_S3_ENDPOINT_URL = os.environ.get("AWS_S3_ENDPOINT_URL", None)
1111
AWS_S3_DEFAULT_BUCKET = os.environ.get("AWS_S3_DEFAULT_BUCKET", None)
@@ -19,7 +19,7 @@ def get_now():
1919
return round(time.time() * 1000)
2020

2121

22-
class S3Storage:
22+
class S3Storage(BaseStorage):
2323
def test(url):
2424
return re.search(r"^(https?\+)?s3://", url)
2525

0 commit comments

Comments
 (0)