diff --git a/src/dstack/_internal/core/backends/amddevcloud/configurator.py b/src/dstack/_internal/core/backends/amddevcloud/configurator.py index 1f464fc549..2f00f359eb 100644 --- a/src/dstack/_internal/core/backends/amddevcloud/configurator.py +++ b/src/dstack/_internal/core/backends/amddevcloud/configurator.py @@ -1,6 +1,7 @@ from typing import Optional from dstack._internal.core.backends.amddevcloud.backend import AMDDevCloudBackend +from dstack._internal.core.backends.base.configurator import BackendRecord from dstack._internal.core.backends.digitalocean_base.api_client import DigitalOceanAPIClient from dstack._internal.core.backends.digitalocean_base.backend import BaseDigitalOceanBackend from dstack._internal.core.backends.digitalocean_base.configurator import ( @@ -17,7 +18,7 @@ class AMDDevCloudConfigurator(BaseDigitalOceanConfigurator): BACKEND_CLASS = AMDDevCloudBackend API_URL = "https://api-amd.digitalocean.com" - def get_backend(self, record) -> BaseDigitalOceanBackend: + def get_backend(self, record: BackendRecord) -> BaseDigitalOceanBackend: config = self._get_config(record) return AMDDevCloudBackend(config=config, api_url=self.API_URL) diff --git a/src/dstack/_internal/core/backends/digitalocean_base/configurator.py b/src/dstack/_internal/core/backends/digitalocean_base/configurator.py index b57559f1ae..f44c5d2d0f 100644 --- a/src/dstack/_internal/core/backends/digitalocean_base/configurator.py +++ b/src/dstack/_internal/core/backends/digitalocean_base/configurator.py @@ -7,7 +7,6 @@ ) from dstack._internal.core.backends.digitalocean_base.backend import BaseDigitalOceanBackend from dstack._internal.core.backends.digitalocean_base.models import ( - AnyBaseDigitalOceanBackendConfig, AnyBaseDigitalOceanCreds, BaseDigitalOceanBackendConfig, BaseDigitalOceanBackendConfigWithCreds, @@ -33,16 +32,20 @@ def create_backend( auth=BaseDigitalOceanCreds.parse_obj(config.creds).json(), ) - def get_backend_config( - self, record: BackendRecord, include_creds: bool - ) -> AnyBaseDigitalOceanBackendConfig: + def get_backend_config_with_creds( + self, record: BackendRecord + ) -> BaseDigitalOceanBackendConfigWithCreds: + config = self._get_config(record) + return BaseDigitalOceanBackendConfigWithCreds.__response__.parse_obj(config) + + def get_backend_config_without_creds( + self, record: BackendRecord + ) -> BaseDigitalOceanBackendConfig: config = self._get_config(record) - if include_creds: - return BaseDigitalOceanBackendConfigWithCreds.__response__.parse_obj(config) return BaseDigitalOceanBackendConfig.__response__.parse_obj(config) def get_backend(self, record: BackendRecord) -> BaseDigitalOceanBackend: - pass + raise NotImplementedError("Subclasses must implement get_backend") def _get_config(self, record: BackendRecord) -> BaseDigitalOceanConfig: return BaseDigitalOceanConfig.__response__(