Skip to content

Commit 3f72398

Browse files
authored
Make Configurator generic (#3013)
1 parent 2fbc251 commit 3f72398

File tree

18 files changed

+210
-134
lines changed

18 files changed

+210
-134
lines changed

src/dstack/_internal/core/backends/aws/configurator.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from dstack._internal.core.backends.aws import auth, compute, resources
88
from dstack._internal.core.backends.aws.backend import AWSBackend
99
from dstack._internal.core.backends.aws.models import (
10-
AnyAWSBackendConfig,
1110
AWSAccessKeyCreds,
1211
AWSBackendConfig,
1312
AWSBackendConfigWithCreds,
@@ -52,7 +51,12 @@
5251
MAIN_REGION = "us-east-1"
5352

5453

55-
class AWSConfigurator(Configurator):
54+
class AWSConfigurator(
55+
Configurator[
56+
AWSBackendConfig,
57+
AWSBackendConfigWithCreds,
58+
]
59+
):
5660
TYPE = BackendType.AWS
5761
BACKEND_CLASS = AWSBackend
5862

@@ -87,12 +91,12 @@ def create_backend(
8791
auth=AWSCreds.parse_obj(config.creds).json(),
8892
)
8993

90-
def get_backend_config(
91-
self, record: BackendRecord, include_creds: bool
92-
) -> AnyAWSBackendConfig:
94+
def get_backend_config_with_creds(self, record: BackendRecord) -> AWSBackendConfigWithCreds:
95+
config = self._get_config(record)
96+
return AWSBackendConfigWithCreds.__response__.parse_obj(config)
97+
98+
def get_backend_config_without_creds(self, record: BackendRecord) -> AWSBackendConfig:
9399
config = self._get_config(record)
94-
if include_creds:
95-
return AWSBackendConfigWithCreds.__response__.parse_obj(config)
96100
return AWSBackendConfig.__response__.parse_obj(config)
97101

98102
def get_backend(self, record: BackendRecord) -> AWSBackend:

src/dstack/_internal/core/backends/azure/configurator.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from dstack._internal.core.backends.azure import utils as azure_utils
2525
from dstack._internal.core.backends.azure.backend import AzureBackend
2626
from dstack._internal.core.backends.azure.models import (
27-
AnyAzureBackendConfig,
2827
AzureBackendConfig,
2928
AzureBackendConfigWithCreds,
3029
AzureClientCreds,
@@ -71,7 +70,12 @@
7170
MAIN_LOCATION = "eastus"
7271

7372

74-
class AzureConfigurator(Configurator):
73+
class AzureConfigurator(
74+
Configurator[
75+
AzureBackendConfig,
76+
AzureBackendConfigWithCreds,
77+
]
78+
):
7579
TYPE = BackendType.AZURE
7680
BACKEND_CLASS = AzureBackend
7781

@@ -130,12 +134,12 @@ def create_backend(
130134
auth=AzureCreds.parse_obj(config.creds).__root__.json(),
131135
)
132136

133-
def get_backend_config(
134-
self, record: BackendRecord, include_creds: bool
135-
) -> AnyAzureBackendConfig:
137+
def get_backend_config_with_creds(self, record: BackendRecord) -> AzureBackendConfigWithCreds:
138+
config = self._get_config(record)
139+
return AzureBackendConfigWithCreds.__response__.parse_obj(config)
140+
141+
def get_backend_config_without_creds(self, record: BackendRecord) -> AzureBackendConfig:
136142
config = self._get_config(record)
137-
if include_creds:
138-
return AzureBackendConfigWithCreds.__response__.parse_obj(config)
139143
return AzureBackendConfig.__response__.parse_obj(config)
140144

141145
def get_backend(self, record: BackendRecord) -> AzureBackend:

src/dstack/_internal/core/backends/base/configurator.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any, ClassVar, List, Literal, Optional, overload
2+
from typing import Any, ClassVar, Generic, List, Optional, TypeVar
33
from uuid import UUID
44

55
from dstack._internal.core.backends.base.backend import Backend
66
from dstack._internal.core.backends.models import (
7-
AnyBackendConfig,
87
AnyBackendConfigWithCreds,
98
AnyBackendConfigWithoutCreds,
109
)
@@ -16,6 +15,11 @@
1615
# We'll introduce our own base limit that can be customized per backend if required.
1716
TAGS_MAX_NUM = 25
1817

18+
BackendConfigWithoutCredsT = TypeVar(
19+
"BackendConfigWithoutCredsT", bound=AnyBackendConfigWithoutCreds
20+
)
21+
BackendConfigWithCredsT = TypeVar("BackendConfigWithCredsT", bound=AnyBackendConfigWithCreds)
22+
1923

2024
class BackendRecord(CoreModel):
2125
"""
@@ -40,7 +44,7 @@ class StoredBackendRecord(BackendRecord):
4044
backend_id: UUID
4145

4246

43-
class Configurator(ABC):
47+
class Configurator(ABC, Generic[BackendConfigWithoutCredsT, BackendConfigWithCredsT]):
4448
"""
4549
`Configurator` is responsible for configuring backends
4650
and initializing `Backend` instances from backend configs.
@@ -53,7 +57,7 @@ class Configurator(ABC):
5357
BACKEND_CLASS: ClassVar[type[Backend]]
5458

5559
@abstractmethod
56-
def validate_config(self, config: AnyBackendConfigWithCreds, default_creds_enabled: bool):
60+
def validate_config(self, config: BackendConfigWithCredsT, default_creds_enabled: bool):
5761
"""
5862
Validates backend config including backend creds and other parameters.
5963
Raises `ServerClientError` or its subclass if config is invalid.
@@ -62,9 +66,7 @@ def validate_config(self, config: AnyBackendConfigWithCreds, default_creds_enabl
6266
pass
6367

6468
@abstractmethod
65-
def create_backend(
66-
self, project_name: str, config: AnyBackendConfigWithCreds
67-
) -> BackendRecord:
69+
def create_backend(self, project_name: str, config: BackendConfigWithCredsT) -> BackendRecord:
6870
"""
6971
Sets up backend given backend config and returns
7072
text-encoded config and creds to be stored in the DB.
@@ -78,26 +80,23 @@ def create_backend(
7880
"""
7981
pass
8082

81-
@overload
82-
def get_backend_config(
83-
self, record: StoredBackendRecord, include_creds: Literal[False]
84-
) -> AnyBackendConfigWithoutCreds:
85-
pass
86-
87-
@overload
88-
def get_backend_config(
89-
self, record: StoredBackendRecord, include_creds: Literal[True]
90-
) -> AnyBackendConfigWithCreds:
83+
@abstractmethod
84+
def get_backend_config_with_creds(
85+
self, record: StoredBackendRecord
86+
) -> BackendConfigWithCredsT:
87+
"""
88+
Constructs `BackendConfig` with credentials included.
89+
Used internally and when project admins need to see backend's creds.
90+
"""
9191
pass
9292

9393
@abstractmethod
94-
def get_backend_config(
95-
self, record: StoredBackendRecord, include_creds: bool
96-
) -> AnyBackendConfig:
94+
def get_backend_config_without_creds(
95+
self, record: StoredBackendRecord
96+
) -> BackendConfigWithoutCredsT:
9797
"""
98-
Constructs `BackendConfig` to be returned in API responses.
99-
Project admins may need to see backend's creds. In this case `include_creds` will be `True`.
100-
Otherwise, no sensitive information should be included.
98+
Constructs `BackendConfig` without sensitive information.
99+
Used for API responses where creds should not be exposed.
101100
"""
102101
pass
103102

src/dstack/_internal/core/backends/cloudrift/configurator.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from dstack._internal.core.backends.cloudrift.api_client import RiftClient
99
from dstack._internal.core.backends.cloudrift.backend import CloudRiftBackend
1010
from dstack._internal.core.backends.cloudrift.models import (
11-
AnyCloudRiftBackendConfig,
1211
AnyCloudRiftCreds,
1312
CloudRiftBackendConfig,
1413
CloudRiftBackendConfigWithCreds,
@@ -21,7 +20,12 @@
2120
)
2221

2322

24-
class CloudRiftConfigurator(Configurator):
23+
class CloudRiftConfigurator(
24+
Configurator[
25+
CloudRiftBackendConfig,
26+
CloudRiftBackendConfigWithCreds,
27+
]
28+
):
2529
TYPE = BackendType.CLOUDRIFT
2630
BACKEND_CLASS = CloudRiftBackend
2731

@@ -40,12 +44,14 @@ def create_backend(
4044
auth=CloudRiftCreds.parse_obj(config.creds).json(),
4145
)
4246

43-
def get_backend_config(
44-
self, record: BackendRecord, include_creds: bool
45-
) -> AnyCloudRiftBackendConfig:
47+
def get_backend_config_with_creds(
48+
self, record: BackendRecord
49+
) -> CloudRiftBackendConfigWithCreds:
50+
config = self._get_config(record)
51+
return CloudRiftBackendConfigWithCreds.__response__.parse_obj(config)
52+
53+
def get_backend_config_without_creds(self, record: BackendRecord) -> CloudRiftBackendConfig:
4654
config = self._get_config(record)
47-
if include_creds:
48-
return CloudRiftBackendConfigWithCreds.__response__.parse_obj(config)
4955
return CloudRiftBackendConfig.__response__.parse_obj(config)
5056

5157
def get_backend(self, record: BackendRecord) -> CloudRiftBackend:

src/dstack/_internal/core/backends/cudo/configurator.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from dstack._internal.core.backends.cudo import api_client
99
from dstack._internal.core.backends.cudo.backend import CudoBackend
1010
from dstack._internal.core.backends.cudo.models import (
11-
AnyCudoBackendConfig,
1211
CudoBackendConfig,
1312
CudoBackendConfigWithCreds,
1413
CudoConfig,
@@ -18,7 +17,12 @@
1817
from dstack._internal.core.models.backends.base import BackendType
1918

2019

21-
class CudoConfigurator(Configurator):
20+
class CudoConfigurator(
21+
Configurator[
22+
CudoBackendConfig,
23+
CudoBackendConfigWithCreds,
24+
]
25+
):
2226
TYPE = BackendType.CUDO
2327
BACKEND_CLASS = CudoBackend
2428

@@ -35,12 +39,12 @@ def create_backend(
3539
auth=CudoCreds.parse_obj(config.creds).json(),
3640
)
3741

38-
def get_backend_config(
39-
self, record: BackendRecord, include_creds: bool
40-
) -> AnyCudoBackendConfig:
42+
def get_backend_config_with_creds(self, record: BackendRecord) -> CudoBackendConfigWithCreds:
43+
config = self._get_config(record)
44+
return CudoBackendConfigWithCreds.__response__.parse_obj(config)
45+
46+
def get_backend_config_without_creds(self, record: BackendRecord) -> CudoBackendConfig:
4147
config = self._get_config(record)
42-
if include_creds:
43-
return CudoBackendConfigWithCreds.__response__.parse_obj(config)
4448
return CudoBackendConfig.__response__.parse_obj(config)
4549

4650
def get_backend(self, record: BackendRecord) -> CudoBackend:

src/dstack/_internal/core/backends/datacrunch/configurator.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
)
1111
from dstack._internal.core.backends.datacrunch.backend import DataCrunchBackend
1212
from dstack._internal.core.backends.datacrunch.models import (
13-
AnyDataCrunchBackendConfig,
1413
DataCrunchBackendConfig,
1514
DataCrunchBackendConfigWithCreds,
1615
DataCrunchConfig,
@@ -22,7 +21,12 @@
2221
)
2322

2423

25-
class DataCrunchConfigurator(Configurator):
24+
class DataCrunchConfigurator(
25+
Configurator[
26+
DataCrunchBackendConfig,
27+
DataCrunchBackendConfigWithCreds,
28+
]
29+
):
2630
TYPE = BackendType.DATACRUNCH
2731
BACKEND_CLASS = DataCrunchBackend
2832

@@ -41,12 +45,14 @@ def create_backend(
4145
auth=DataCrunchCreds.parse_obj(config.creds).json(),
4246
)
4347

44-
def get_backend_config(
45-
self, record: BackendRecord, include_creds: bool
46-
) -> AnyDataCrunchBackendConfig:
48+
def get_backend_config_with_creds(
49+
self, record: BackendRecord
50+
) -> DataCrunchBackendConfigWithCreds:
51+
config = self._get_config(record)
52+
return DataCrunchBackendConfigWithCreds.__response__.parse_obj(config)
53+
54+
def get_backend_config_without_creds(self, record: BackendRecord) -> DataCrunchBackendConfig:
4755
config = self._get_config(record)
48-
if include_creds:
49-
return DataCrunchBackendConfigWithCreds.__response__.parse_obj(config)
5056
return DataCrunchBackendConfig.__response__.parse_obj(config)
5157

5258
def get_backend(self, record: BackendRecord) -> DataCrunchBackend:

src/dstack/_internal/core/backends/gcp/configurator.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from dstack._internal.core.backends.gcp import auth, resources
1212
from dstack._internal.core.backends.gcp.backend import GCPBackend
1313
from dstack._internal.core.backends.gcp.models import (
14-
AnyGCPBackendConfig,
1514
GCPBackendConfig,
1615
GCPBackendConfigWithCreds,
1716
GCPConfig,
@@ -109,7 +108,12 @@
109108
MAIN_REGION = "us-east1"
110109

111110

112-
class GCPConfigurator(Configurator):
111+
class GCPConfigurator(
112+
Configurator[
113+
GCPBackendConfig,
114+
GCPBackendConfigWithCreds,
115+
]
116+
):
113117
TYPE = BackendType.GCP
114118
BACKEND_CLASS = GCPBackend
115119

@@ -147,12 +151,12 @@ def create_backend(
147151
auth=GCPCreds.parse_obj(config.creds).json(),
148152
)
149153

150-
def get_backend_config(
151-
self, record: BackendRecord, include_creds: bool
152-
) -> AnyGCPBackendConfig:
154+
def get_backend_config_with_creds(self, record: BackendRecord) -> GCPBackendConfigWithCreds:
155+
config = self._get_config(record)
156+
return GCPBackendConfigWithCreds.__response__.parse_obj(config)
157+
158+
def get_backend_config_without_creds(self, record: BackendRecord) -> GCPBackendConfig:
153159
config = self._get_config(record)
154-
if include_creds:
155-
return GCPBackendConfigWithCreds.__response__.parse_obj(config)
156160
return GCPBackendConfig.__response__.parse_obj(config)
157161

158162
def get_backend(self, record: BackendRecord) -> GCPBackend:

src/dstack/_internal/core/backends/hotaisle/configurator.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from dstack._internal.core.backends.hotaisle.api_client import HotAisleAPIClient
88
from dstack._internal.core.backends.hotaisle.backend import HotAisleBackend
99
from dstack._internal.core.backends.hotaisle.models import (
10-
AnyHotAisleBackendConfig,
1110
AnyHotAisleCreds,
1211
HotAisleBackendConfig,
1312
HotAisleBackendConfigWithCreds,
@@ -20,7 +19,12 @@
2019
)
2120

2221

23-
class HotAisleConfigurator(Configurator):
22+
class HotAisleConfigurator(
23+
Configurator[
24+
HotAisleBackendConfig,
25+
HotAisleBackendConfigWithCreds,
26+
]
27+
):
2428
TYPE = BackendType.HOTAISLE
2529
BACKEND_CLASS = HotAisleBackend
2630

@@ -37,12 +41,14 @@ def create_backend(
3741
auth=HotAisleCreds.parse_obj(config.creds).json(),
3842
)
3943

40-
def get_backend_config(
41-
self, record: BackendRecord, include_creds: bool
42-
) -> AnyHotAisleBackendConfig:
44+
def get_backend_config_with_creds(
45+
self, record: BackendRecord
46+
) -> HotAisleBackendConfigWithCreds:
47+
config = self._get_config(record)
48+
return HotAisleBackendConfigWithCreds.__response__.parse_obj(config)
49+
50+
def get_backend_config_without_creds(self, record: BackendRecord) -> HotAisleBackendConfig:
4351
config = self._get_config(record)
44-
if include_creds:
45-
return HotAisleBackendConfigWithCreds.__response__.parse_obj(config)
4652
return HotAisleBackendConfig.__response__.parse_obj(config)
4753

4854
def get_backend(self, record: BackendRecord) -> HotAisleBackend:

0 commit comments

Comments
 (0)