Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional

from dstack._internal.core.backends.amddevcloud.backend import AMDDevCloudBackend
from dstack._internal.core.backends.base.configurator import BackendRecord
from dstack._internal.core.backends.digitalocean_base.api_client import DigitalOceanAPIClient
from dstack._internal.core.backends.digitalocean_base.backend import BaseDigitalOceanBackend
from dstack._internal.core.backends.digitalocean_base.configurator import (
Expand All @@ -17,7 +18,7 @@ class AMDDevCloudConfigurator(BaseDigitalOceanConfigurator):
BACKEND_CLASS = AMDDevCloudBackend
API_URL = "https://api-amd.digitalocean.com"

def get_backend(self, record) -> BaseDigitalOceanBackend:
def get_backend(self, record: BackendRecord) -> BaseDigitalOceanBackend:
config = self._get_config(record)
return AMDDevCloudBackend(config=config, api_url=self.API_URL)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
)
from dstack._internal.core.backends.digitalocean_base.backend import BaseDigitalOceanBackend
from dstack._internal.core.backends.digitalocean_base.models import (
AnyBaseDigitalOceanBackendConfig,
AnyBaseDigitalOceanCreds,
BaseDigitalOceanBackendConfig,
BaseDigitalOceanBackendConfigWithCreds,
Expand All @@ -33,16 +32,20 @@ def create_backend(
auth=BaseDigitalOceanCreds.parse_obj(config.creds).json(),
)

def get_backend_config(
self, record: BackendRecord, include_creds: bool
) -> AnyBaseDigitalOceanBackendConfig:
def get_backend_config_with_creds(
self, record: BackendRecord
) -> BaseDigitalOceanBackendConfigWithCreds:
config = self._get_config(record)
return BaseDigitalOceanBackendConfigWithCreds.__response__.parse_obj(config)

def get_backend_config_without_creds(
self, record: BackendRecord
) -> BaseDigitalOceanBackendConfig:
config = self._get_config(record)
if include_creds:
return BaseDigitalOceanBackendConfigWithCreds.__response__.parse_obj(config)
return BaseDigitalOceanBackendConfig.__response__.parse_obj(config)

def get_backend(self, record: BackendRecord) -> BaseDigitalOceanBackend:
pass
raise NotImplementedError("Subclasses must implement get_backend")

def _get_config(self, record: BackendRecord) -> BaseDigitalOceanConfig:
return BaseDigitalOceanConfig.__response__(
Expand Down
Loading