diff --git a/docs/docs/reference/environment-variables.md b/docs/docs/reference/environment-variables.md index 6cefada19..a4ac24bc7 100644 --- a/docs/docs/reference/environment-variables.md +++ b/docs/docs/reference/environment-variables.md @@ -112,6 +112,9 @@ For more details on the options below, refer to the [server deployment](../guide - `DSTACK_DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE`{ #DSTACK_DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE } – Request body size limit for services running with a gateway, in bytes. Defaults to 64 MiB. - `DSTACK_FORBID_SERVICES_WITHOUT_GATEWAY`{ #DSTACK_FORBID_SERVICES_WITHOUT_GATEWAY } – Forbids registering new services without a gateway if set to any value. - `DSTACK_SERVER_CODE_UPLOAD_LIMIT`{ #DSTACK_SERVER_CODE_UPLOAD_LIMIT } - The repo size limit when uploading diffs or local repos, in bytes. Set to 0 to disable size limits. Defaults to 2MiB. +- `DSTACK_SERVER_S3_BUCKET`{ #DSTACK_SERVER_S3_BUCKET } - The bucket that repo diffs will be uploaded to if set. If unset, diffs are uploaded to the database. +- `DSTACK_SERVER_S3_BUCKET_REGION`{ #DSTACK_SERVER_S3_BUCKET_REGION } - The region of the S3 Bucket. +- `DSTACK_SERVER_GCS_BUCKET`{ #DSTACK_SERVER_GCS_BUCKET } - The bucket that repo diffs will be uploaded to if set. If unset, diffs are uploaded to the database. ??? 
info "Internal environment variables" The following environment variables are intended for development purposes: diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 3bafc7eb0..5f4403053 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -128,7 +128,7 @@ async def lifespan(app: FastAPI): yes=UPDATE_DEFAULT_PROJECT, no=DO_NOT_UPDATE_DEFAULT_PROJECT, ) - if settings.SERVER_BUCKET is not None: + if settings.SERVER_S3_BUCKET is not None or settings.SERVER_GCS_BUCKET is not None: init_default_storage() scheduler = start_background_tasks() dstack_version = DSTACK_VERSION if DSTACK_VERSION else "(no version)" diff --git a/src/dstack/_internal/server/services/storage/__init__.py b/src/dstack/_internal/server/services/storage/__init__.py new file mode 100644 index 000000000..14b75c347 --- /dev/null +++ b/src/dstack/_internal/server/services/storage/__init__.py @@ -0,0 +1,38 @@ +from typing import Optional + +from dstack._internal.server import settings +from dstack._internal.server.services.storage.base import BaseStorage +from dstack._internal.server.services.storage.gcs import GCS_AVAILABLE, GCSStorage +from dstack._internal.server.services.storage.s3 import BOTO_AVAILABLE, S3Storage + +_default_storage = None + + +def init_default_storage(): + global _default_storage + if settings.SERVER_S3_BUCKET is None and settings.SERVER_GCS_BUCKET is None: + raise ValueError( + "Either settings.SERVER_S3_BUCKET or settings.SERVER_GCS_BUCKET must be set" + ) + if settings.SERVER_S3_BUCKET and settings.SERVER_GCS_BUCKET: + raise ValueError( + "Only one of settings.SERVER_S3_BUCKET or settings.SERVER_GCS_BUCKET can be set" + ) + + if settings.SERVER_S3_BUCKET: + if not BOTO_AVAILABLE: + raise ValueError("AWS dependencies are not installed") + _default_storage = S3Storage( + bucket=settings.SERVER_S3_BUCKET, + region=settings.SERVER_S3_BUCKET_REGION, + ) + elif settings.SERVER_GCS_BUCKET: + if not 
GCS_AVAILABLE: + raise ValueError("GCS dependencies are not installed") + _default_storage = GCSStorage( + bucket=settings.SERVER_GCS_BUCKET, + ) + + +def get_default_storage() -> Optional[BaseStorage]: + return _default_storage diff --git a/src/dstack/_internal/server/services/storage/base.py b/src/dstack/_internal/server/services/storage/base.py new file mode 100644 index 000000000..5864f061d --- /dev/null +++ b/src/dstack/_internal/server/services/storage/base.py @@ -0,0 +1,27 @@ +from abc import ABC, abstractmethod +from typing import Optional + + +class BaseStorage(ABC): + @abstractmethod + def upload_code( + self, + project_id: str, + repo_id: str, + code_hash: str, + blob: bytes, + ): + pass + + @abstractmethod + def get_code( + self, + project_id: str, + repo_id: str, + code_hash: str, + ) -> Optional[bytes]: + pass + + @staticmethod + def _get_code_key(project_id: str, repo_id: str, code_hash: str) -> str: + return f"data/projects/{project_id}/codes/{repo_id}/{code_hash}" diff --git a/src/dstack/_internal/server/services/storage/gcs.py b/src/dstack/_internal/server/services/storage/gcs.py new file mode 100644 index 000000000..1075aba01 --- /dev/null +++ b/src/dstack/_internal/server/services/storage/gcs.py @@ -0,0 +1,44 @@ +from typing import Optional + +from dstack._internal.server.services.storage.base import BaseStorage + +GCS_AVAILABLE = True +try: + from google.cloud import storage + from google.cloud.exceptions import NotFound +except ImportError: + GCS_AVAILABLE = False + + +class GCSStorage(BaseStorage): + def __init__( + self, + bucket: str, + ): + self._client = storage.Client() + self._bucket = self._client.bucket(bucket) + + def upload_code( + self, + project_id: str, + repo_id: str, + code_hash: str, + blob: bytes, + ): + blob_name = self._get_code_key(project_id, repo_id, code_hash) + blob_obj = self._bucket.blob(blob_name) + blob_obj.upload_from_string(blob) + + def get_code( + self, + project_id: str, + repo_id: str, + code_hash: str, + ) 
-> Optional[bytes]: + blob_name = self._get_code_key(project_id, repo_id, code_hash) + blob = self._bucket.blob(blob_name) + + try: + return blob.download_as_bytes() + except NotFound: + return None diff --git a/src/dstack/_internal/server/services/storage.py b/src/dstack/_internal/server/services/storage/s3.py similarity index 56% rename from src/dstack/_internal/server/services/storage.py rename to src/dstack/_internal/server/services/storage/s3.py index 728f0af7c..8c67f28c2 100644 --- a/src/dstack/_internal/server/services/storage.py +++ b/src/dstack/_internal/server/services/storage/s3.py @@ -1,6 +1,6 @@ from typing import Optional -from dstack._internal.server import settings +from dstack._internal.server.services.storage.base import BaseStorage BOTO_AVAILABLE = True try: @@ -10,7 +10,7 @@ BOTO_AVAILABLE = False -class S3Storage: +class S3Storage(BaseStorage): def __init__( self, bucket: str, @@ -29,7 +29,7 @@ def upload_code( ): self._client.put_object( Bucket=self.bucket, - Key=_get_code_key(project_id, repo_id, code_hash), + Key=self._get_code_key(project_id, repo_id, code_hash), Body=blob, ) @@ -42,33 +42,10 @@ def get_code( try: response = self._client.get_object( Bucket=self.bucket, - Key=_get_code_key(project_id, repo_id, code_hash), + Key=self._get_code_key(project_id, repo_id, code_hash), ) except botocore.exceptions.ClientError as e: if e.response["Error"]["Code"] == "NoSuchKey": return None raise e return response["Body"].read() - - - -def _get_code_key(project_id: str, repo_id: str, code_hash: str) -> str: - return f"data/projects/{project_id}/codes/{repo_id}/{code_hash}" - - - -_default_storage = None - - - -def init_default_storage(): - global _default_storage - if settings.SERVER_BUCKET is None: - raise ValueError("settings.SERVER_BUCKET not set") - if not BOTO_AVAILABLE: - raise ValueError("AWS dependencies are not installed") - _default_storage = S3Storage( - bucket=settings.SERVER_BUCKET, - region=settings.SERVER_BUCKET_REGION, - ) - - - -def 
get_default_storage() -> Optional[S3Storage]: - return _default_storage diff --git a/src/dstack/_internal/server/settings.py b/src/dstack/_internal/server/settings.py index 05d21f83e..5df67123b 100644 --- a/src/dstack/_internal/server/settings.py +++ b/src/dstack/_internal/server/settings.py @@ -37,9 +37,13 @@ SERVER_CONFIG_DISABLED = os.getenv("DSTACK_SERVER_CONFIG_DISABLED") is not None SERVER_CONFIG_ENABLED = not SERVER_CONFIG_DISABLED -# TODO: add s3/aws prefix -SERVER_BUCKET = os.getenv("DSTACK_SERVER_BUCKET") -SERVER_BUCKET_REGION = os.getenv("DSTACK_SERVER_BUCKET_REGION") +# TODO: remove deprecated DSTACK_SERVER_BUCKET and DSTACK_SERVER_BUCKET_REGION env var usage +SERVER_S3_BUCKET = os.getenv("DSTACK_SERVER_S3_BUCKET", os.getenv("DSTACK_SERVER_BUCKET")) +SERVER_S3_BUCKET_REGION = os.getenv( + "DSTACK_SERVER_S3_BUCKET_REGION", os.getenv("DSTACK_SERVER_BUCKET_REGION") +) + +SERVER_GCS_BUCKET = os.getenv("DSTACK_SERVER_GCS_BUCKET") SERVER_CLOUDWATCH_LOG_GROUP = os.getenv("DSTACK_SERVER_CLOUDWATCH_LOG_GROUP") SERVER_CLOUDWATCH_LOG_REGION = os.getenv("DSTACK_SERVER_CLOUDWATCH_LOG_REGION")