Skip to content

Commit 1feffdb

Browse files
committed
feat: support s3
1 parent eb05e61 commit 1feffdb

6 files changed

Lines changed: 215 additions & 2 deletions

File tree

.env.example

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,9 @@ OSS_ACCESS_KEY_ID=
55
OSS_ACCESS_KEY_SECRET=
66
ENDPOINT=
77
BUCKET=
8+
9+
# S3
10+
BUCKET=
11+
AWS_ACCESS_KEY_ID=
12+
AWS_SECRET_ACCESS_KEY=
13+
AWS_DEFAULT_REGION=

README.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,31 @@ MINIO_ACCESS_KEY=
8383
MINIO_SECRET_KEY=
8484
```
8585

86+
### [S3](https://aws.amazon.com/s3/)
87+
88+
Usage:
89+
90+
```python
91+
client = StoreFactory.new_client(
92+
provider="S3", bucket=<bucket>
93+
)
94+
95+
# Use endpoint when accessing S3 via a PrivateLink interface endpoint.
96+
# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-example-privatelink.html
97+
client = StoreFactory.new_client(
98+
provider="S3", bucket=<bucket>, endpoint=<endpoint>
99+
)
100+
```
101+
102+
Required environment variables:
103+
104+
```yaml
105+
AWS_ACCESS_KEY_ID=
106+
AWS_SECRET_ACCESS_KEY=
107+
# If a region is not specified, the bucket is created in the S3 default region (us-east-1).
108+
AWS_DEFAULT_REGION=
109+
```
110+
86111
## Development
87112

88113
Once you want to run the integration tests, you should have a `.env` file locally, similar to the `.env.example`.

omnistore/objstore/constant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
OBJECT_STORE_OSS = "OSS"
22
OBJECT_STORE_MINIO = "MINIO"
3+
OBJECT_STORE_S3 = "S3"

omnistore/objstore/objstore_factory.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
from omnistore.objstore.aliyun_oss import OSS
2-
from omnistore.objstore.constant import OBJECT_STORE_OSS, OBJECT_STORE_MINIO
2+
from omnistore.objstore.constant import OBJECT_STORE_OSS, OBJECT_STORE_MINIO, OBJECT_STORE_S3
33
from omnistore.objstore.minio import MinIO
4+
from omnistore.objstore.s3 import S3
45
from omnistore.store import Store
56

67

78
class StoreFactory:
89
ObjStores = {
910
OBJECT_STORE_OSS: OSS,
1011
OBJECT_STORE_MINIO: MinIO,
12+
OBJECT_STORE_S3: S3,
1113
}
1214

1315
@classmethod
14-
def new_client(cls, provider: str, endpoint: str, bucket: str) -> Store:
16+
def new_client(cls, provider: str, endpoint: str = None, bucket: str = None) -> Store:
1517
objstore = cls.ObjStores[provider]
1618
if not objstore:
1719
raise KeyError(f"Unknown object store provider {provider}")

omnistore/objstore/s3.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import io
2+
import os
3+
from pathlib import Path
4+
5+
import boto3
6+
from botocore.exceptions import ClientError
7+
8+
from omnistore.objstore.objstore import ObjStore
9+
10+
11+
class S3(ObjStore):
12+
def __init__(self, bucket: str, endpoint: str = None):
13+
"""
14+
Construct a new client to communicate with the AWS S3 provider.
15+
16+
AWS credentials are expected to be provided via environment variables:
17+
- AWS_ACCESS_KEY_ID
18+
- AWS_SECRET_ACCESS_KEY
19+
- AWS_DEFAULT_REGION
20+
"""
21+
region = os.environ.get("AWS_DEFAULT_REGION")
22+
23+
# If a region is not specified, the bucket is created in the S3 default region (us-east-1).
24+
# If the user explicitly provides an endpoint_url, the region is not used.
25+
kwargs = {}
26+
if endpoint:
27+
kwargs['endpoint_url'] = endpoint
28+
if region:
29+
kwargs['region_name'] = region
30+
31+
self.client = boto3.client('s3', **kwargs)
32+
self.resource = boto3.resource('s3', **kwargs)
33+
self.bucket_name = bucket
34+
35+
# Make sure the bucket exists
36+
try:
37+
self.client.head_bucket(Bucket=bucket)
38+
except ClientError as e:
39+
# If bucket doesn't exist, create it
40+
if e.response['Error']['Code'] == '404':
41+
kwargs = {}
42+
# For non us-east-1 region, we need to specify the LocationConstraint parameter when creating the bucket
43+
if region:
44+
kwargs['CreateBucketConfiguration'] = {
45+
"LocationConstraint": region
46+
}
47+
self.client.create_bucket(Bucket=bucket, **kwargs)
48+
else:
49+
raise e
50+
51+
def create_dir(self, dirname: str):
52+
if not dirname.endswith("/"):
53+
dirname += "/"
54+
empty_stream = io.BytesIO(b"")
55+
self.client.put_object(Bucket=self.bucket_name, Key=dirname, Body=empty_stream)
56+
57+
def delete_dir(self, dirname: str):
58+
if not dirname.endswith("/"):
59+
dirname += "/"
60+
61+
bucket = self.resource.Bucket(self.bucket_name)
62+
bucket.objects.filter(Prefix=dirname).delete()
63+
64+
def upload(self, src: str, dest: str):
65+
self.client.upload_file(src, self.bucket_name, dest)
66+
67+
def upload_dir(self, src_dir: str, dest_dir: str):
68+
for file in Path(src_dir).rglob("*"):
69+
if file.is_file():
70+
dest_path = f"{dest_dir}/{file.relative_to(src_dir)}"
71+
self.upload(str(file), dest_path)
72+
elif file.is_dir():
73+
self.create_dir(f"{dest_dir}/{file.relative_to(src_dir)}/")
74+
75+
def download(self, src: str, dest: str):
76+
self.client.download_file(self.bucket_name, src, dest)
77+
78+
def download_dir(self, src_dir: str, dest_dir: str):
79+
if not src_dir.endswith("/"):
80+
src_dir += "/"
81+
path = Path(dest_dir)
82+
if not path.exists():
83+
path.mkdir(parents=True)
84+
85+
paginator = self.client.get_paginator('list_objects_v2')
86+
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=src_dir)
87+
88+
for page in pages:
89+
if 'Contents' not in page:
90+
continue
91+
92+
for obj in page['Contents']:
93+
key = obj['Key']
94+
if key.endswith('/'): # Skip directories
95+
continue
96+
97+
file_path = Path(dest_dir, Path(key).relative_to(src_dir))
98+
if not file_path.parent.exists():
99+
file_path.parent.mkdir(parents=True, exist_ok=True)
100+
101+
self.download(key, str(file_path))
102+
103+
def delete(self, filename: str):
104+
self.client.delete_object(Bucket=self.bucket_name, Key=filename)
105+
106+
def exists(self, filename: str):
107+
try:
108+
self.client.head_object(Bucket=self.bucket_name, Key=filename)
109+
return True
110+
except ClientError as e:
111+
if e.response['Error']['Code'] == '404':
112+
return False
113+
else:
114+
raise e
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import os
2+
import shutil
3+
4+
import pytest
5+
from dotenv import load_dotenv
6+
7+
from omnistore.objstore import StoreFactory
8+
from omnistore.objstore.constant import OBJECT_STORE_S3
9+
10+
load_dotenv()
11+
12+
class TestS3:
13+
@pytest.fixture(scope="module", autouse=True)
14+
def setup_and_teardown(self):
15+
print("Setting up the test environment.")
16+
try:
17+
os.makedirs("./test-tmp", exist_ok=True)
18+
except Exception as e:
19+
print(f"An error occurred: {e}")
20+
21+
yield
22+
23+
print("Tearing down the test environment.")
24+
shutil.rmtree("./test-tmp")
25+
26+
def test_upload_and_download_files(self):
27+
bucket = os.getenv("BUCKET")
28+
29+
client = StoreFactory.new_client(
30+
provider=OBJECT_STORE_S3, bucket=bucket
31+
)
32+
assert False == client.exists("foo.txt")
33+
34+
with open("./test-tmp/foo.txt", "w") as file:
35+
file.write("test")
36+
37+
client.upload("./test-tmp/foo.txt", "foo.txt")
38+
assert True == client.exists("foo.txt")
39+
40+
client.download("foo.txt", "./test-tmp/bar.txt")
41+
assert True == os.path.exists("./test-tmp/bar.txt")
42+
43+
client.delete("foo.txt")
44+
assert False == client.exists("foo.txt")
45+
46+
def test_upload_and_download_dir(self):
47+
bucket = os.getenv("BUCKET")
48+
49+
client = StoreFactory.new_client(
50+
provider=OBJECT_STORE_S3, bucket=bucket
51+
)
52+
assert False == client.exists("/test/foo.txt")
53+
54+
os.makedirs("./test-tmp/test/111", exist_ok=True)
55+
with open("./test-tmp/test/111/foo.txt", "w") as file:
56+
file.write("test")
57+
58+
client.upload_dir("./test-tmp/test", "test")
59+
assert True == client.exists("test/111/foo.txt")
60+
61+
client.download_dir("test", "./test-tmp/test1")
62+
assert True == os.path.exists("./test-tmp/test1/111/foo.txt")
63+
64+
client.delete_dir("test")
65+
assert False == client.exists("test/foo.txt")

0 commit comments

Comments
 (0)