-
-
Notifications
You must be signed in to change notification settings - Fork 223
Expand file tree
/
Copy pathconfig.py
More file actions
279 lines (235 loc) · 10.4 KB
/
config.py
File metadata and controls
279 lines (235 loc) · 10.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
from typing import List, Optional
import yaml
from pydantic import Field, ValidationError
from sqlalchemy.ext.asyncio import AsyncSession
from typing_extensions import Annotated
import dstack._internal.core.backends.configurators
from dstack._internal.core.backends.models import (
AnyBackendConfigWithCreds,
AnyBackendFileConfigWithCreds,
BackendInfoYAML,
)
from dstack._internal.core.errors import (
BackendNotAvailable,
ResourceNotExistsError,
ServerClientError,
)
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel
from dstack._internal.server import settings
from dstack._internal.server.models import ProjectModel, UserModel
from dstack._internal.server.services import backends as backends_services
from dstack._internal.server.services import encryption as encryption_services
from dstack._internal.server.services import projects as projects_services
from dstack._internal.server.services.backends.handlers import delete_backends_safe
from dstack._internal.server.services.encryption import AnyEncryptionKeyConfig
from dstack._internal.server.services.permissions import (
DefaultPermissions,
set_default_permissions,
)
from dstack._internal.server.services.plugins import load_plugins
from dstack._internal.utils.logging import get_logger
# Module-level logger used for warnings/info emitted while loading and applying the config.
logger = get_logger(__name__)
def seq_representer(dumper, sequence):
    """Represent a list in flow style only when it contains no nested collections.

    By default, PyYAML chooses the style of a collection depending on whether it
    has nested collections: a collection with nested collections gets block
    style, otherwise flow style. We want mappings always displayed in block
    style but lists of plain scalars in flow style (e.g. ``[a, b]``), so this
    custom representer is meant to be registered for ``list`` via
    ``yaml.add_representer``.

    Args:
        dumper: The PyYAML dumper performing the serialization.
        sequence: The list being serialized.

    Returns:
        Whatever ``dumper.represent_sequence`` returns (a YAML sequence node).
    """
    # Inspect every element, not only the first one: a list such as
    # ["a", {"b": 1}] holds a nested mapping and must not be forced into flow
    # style. An empty list yields all(...) == True, i.e. flow style ("[]").
    flow_style = all(isinstance(item, (str, int)) for item in sequence)
    return dumper.represent_sequence("tag:yaml.org,2002:seq", sequence, flow_style)
# Register seq_representer for every ``list``. No Dumper argument is passed, so
# (per PyYAML's add_representer default) this modifies the default yaml.Dumper,
# affecting every yaml.dump call in the process, not only those in this module.
yaml.add_representer(list, seq_representer)
# A backend entry as written in the server config file. The union is
# discriminated by the ``type`` field, so pydantic selects the concrete
# backend file-config model from the value of ``type``.
BackendFileConfigWithCreds = Annotated[
    AnyBackendFileConfigWithCreds, Field(..., discriminator="type")
]
# One project entry of server/config.yml: the project name plus the backends to
# configure for it.
class ProjectConfig(CoreModel):
    name: Annotated[str, Field(description="The name of the project")]
    # None is handled like an empty list when the config is applied. When the
    # config is applied, any available backend type not listed here is deleted
    # from the project.
    backends: Annotated[
        Optional[List[BackendFileConfigWithCreds]], Field(description="The list of backends")
    ] = None
# A single encryption key config, discriminated by its ``type`` field.
EncryptionKeyConfig = Annotated[AnyEncryptionKeyConfig, Field(..., discriminator="type")]
# The ``encryption`` section of server/config.yml.
class EncryptionConfig(CoreModel):
    # Passed as-is to encryption_services.init_encryption_keys when the config is applied.
    keys: Annotated[List[EncryptionKeyConfig], Field(description="The encryption keys")]
# Top-level schema of server/config.yml.
class ServerConfig(CoreModel):
    # Each project is created if missing and its backends are synced on apply.
    projects: Annotated[List[ProjectConfig], Field(description="The list of projects")]
    # When set, its keys are passed to encryption_services.init_encryption_keys.
    encryption: Annotated[
        Optional[EncryptionConfig], Field(description="The encryption config")
    ] = None
    # When set, applied via set_default_permissions.
    default_permissions: Annotated[
        Optional[DefaultPermissions], Field(description="The default user permissions")
    ] = None
    # Plugin names passed to load_plugins; None is passed as an empty list.
    plugins: Annotated[
        Optional[List[str]], Field(description="The server-side plugins to enable")
    ] = None
class ServerConfigManager:
    """Loads, initializes, and applies the server's config.yml.

    The parsed config is kept in ``self.config``. It stays ``None`` until
    ``load_config()`` or ``init_config()`` produces a config.
    """

    # Class-level default so methods that read ``self.config`` (such as
    # apply_encryption/apply_config) see None instead of raising AttributeError
    # when called before load_config()/init_config().
    config: Optional[ServerConfig] = None

    def load_config(self) -> bool:
        """Read and parse ``settings.SERVER_CONFIG_FILE_PATH`` into ``self.config``.

        Returns:
            True if the file was opened and parsed, False if it could not be opened.
        """
        self.config = self._load_config()
        return self.config is not None

    async def init_config(self, session: AsyncSession):
        """
        Initializes the default server/config.yml.

        The config is built from the existing default project
        (``settings.DEFAULT_PROJECT_NAME``) and its backends, then written to
        disk. If the default project does not exist, ``self.config`` becomes
        None and nothing is written.
        """
        self.config = await self._init_config(session)
        if self.config is not None:
            self._save_config(self.config)

    async def sync_config(self, session: AsyncSession):
        # Disable config.yml sync for https://github.com/dstackai/dstack/issues/815.
        return

    async def apply_encryption(self):
        """Initialize encryption keys from the loaded config, if it has an ``encryption`` section."""
        if self.config is None:
            logger.info("No server/config.yml. Skipping encryption configuration.")
            return
        if self.config.encryption is not None:
            encryption_services.init_encryption_keys(self.config.encryption.keys)

    async def apply_config(self, session: AsyncSession, owner: UserModel):
        """Apply default permissions, every project config, and plugins.

        Args:
            session: Database session used for project/backend changes.
            owner: User set as owner of projects that have to be created.

        Raises:
            ValueError: If no config has been loaded or initialized.
        """
        if self.config is None:
            raise ValueError("Config is not loaded")
        if self.config.default_permissions is not None:
            set_default_permissions(self.config.default_permissions)
        for project_config in self.config.projects:
            await self._apply_project_config(
                session=session, owner=owner, project_config=project_config
            )
        load_plugins(enabled_plugins=self.config.plugins or [])

    async def _apply_project_config(
        self,
        session: AsyncSession,
        owner: UserModel,
        project_config: ProjectConfig,
    ):
        """Create the project if needed and sync its backends with ``project_config``.

        Listed backends are created or updated when their stored config differs.
        Every other available backend type is deleted from the project. Failures
        on individual backends are logged and skipped.
        """
        project = await projects_services.get_project_model_by_name(
            session=session,
            project_name=project_config.name,
        )
        if not project:
            await projects_services.create_project_model(
                session=session, owner=owner, project_name=project_config.name
            )
            project = await projects_services.get_project_model_by_name_or_error(
                session=session, project_name=project_config.name
            )
        backends_to_delete = set(
            dstack._internal.core.backends.configurators.list_available_backend_types()
        )
        for backend_file_config in project_config.backends or []:
            backend_config = file_config_to_config(backend_file_config)
            backend_type = BackendType(backend_config.type)
            # Backends listed in the config are kept; whatever remains in the
            # set after the loop is deleted.
            backends_to_delete.discard(backend_type)
            try:
                current_backend_config = await backends_services.get_backend_config(
                    project=project,
                    backend_type=backend_type,
                )
            except BackendNotAvailable:
                logger.warning(
                    "Backend %s not available and won't be configured."
                    " Check that backend dependencies are installed.",
                    backend_type.value,
                )
                continue
            if backend_config == current_backend_config:
                continue
            backend_exists = any(backend_type == b.type for b in project.backends)
            try:
                # current_backend_config may be None if the backend exists
                # but its config is invalid (e.g. cannot be decrypted).
                # Update the backend in this case.
                if current_backend_config is None and not backend_exists:
                    await backends_services.create_backend(
                        session=session, project=project, config=backend_config
                    )
                else:
                    await backends_services.update_backend(
                        session=session, project=project, config=backend_config
                    )
            except Exception as e:
                # Best effort: a failing backend is logged and the loop moves on
                # to the remaining backends.
                logger.warning("Failed to configure backend %s: %s", backend_config.type, e)
        await delete_backends_safe(
            session=session,
            project=project,
            backends_types=list(backends_to_delete),
            error=False,
        )

    async def _init_config(self, session: AsyncSession) -> Optional[ServerConfig]:
        """Build a ServerConfig from the default project's stored backend configs.

        Returns:
            The config, or None if the default project does not exist.
        """
        project = await projects_services.get_project_model_by_name(
            session=session,
            project_name=settings.DEFAULT_PROJECT_NAME,
        )
        if project is None:
            return None
        # Force project reload to reflect updates when syncing
        await session.refresh(project)
        available_backend_types = (
            dstack._internal.core.backends.configurators.list_available_backend_types()
        )
        backends = []
        for backend_type in available_backend_types:
            backend_config = await backends_services.get_backend_config(
                project=project, backend_type=backend_type
            )
            if backend_config is not None:
                backends.append(backend_config)
        return ServerConfig(
            projects=[ProjectConfig(name=settings.DEFAULT_PROJECT_NAME, backends=backends)],
            encryption=EncryptionConfig(keys=[]),
            default_permissions=None,
        )

    def _load_config(self) -> Optional[ServerConfig]:
        """Read and parse the config file.

        Returns:
            The parsed config, or None if the file cannot be opened.

        Raises:
            yaml.YAMLError: If the file is not valid YAML.
            pydantic.ValidationError: If the content does not match ServerConfig.
        """
        try:
            with open(settings.SERVER_CONFIG_FILE_PATH, encoding="utf-8") as f:
                content = f.read()
        except OSError:
            return None
        # NOTE(review): FullLoader is kept for this local server-side file; switch
        # to yaml.safe_load if its content can come from untrusted sources.
        config_dict = yaml.load(content, yaml.FullLoader)
        return ServerConfig.parse_obj(config_dict)

    def _save_config(self, config: ServerConfig):
        """Write ``config`` as YAML to ``settings.SERVER_CONFIG_FILE_PATH``, replacing the file."""
        # Serialize before opening: opening in write mode truncates the file, so
        # a serialization error must not leave an empty config file behind.
        content = config_to_yaml(config)
        with open(settings.SERVER_CONFIG_FILE_PATH, "w", encoding="utf-8") as f:
            f.write(content)
async def get_backend_config_yaml(
    project: ProjectModel, backend_type: BackendType
) -> BackendInfoYAML:
    """Return the stored config of one backend of ``project`` rendered as YAML.

    Raises:
        ResourceNotExistsError: If the project has no config for ``backend_type``.
    """
    stored_config = await backends_services.get_backend_config(
        project=project, backend_type=backend_type
    )
    if stored_config is None:
        raise ResourceNotExistsError()
    return BackendInfoYAML(name=backend_type, config_yaml=config_to_yaml(stored_config))
async def create_backend_config_yaml(
    session: AsyncSession,
    project: ProjectModel,
    config_yaml: str,
):
    """Parse ``config_yaml`` and create the described backend in ``project``.

    Raises:
        ServerClientError: If ``config_yaml`` cannot be parsed or validated.
    """
    await backends_services.create_backend(
        session=session,
        project=project,
        config=config_yaml_to_backend_config(config_yaml),
    )
async def update_backend_config_yaml(
    session: AsyncSession,
    project: ProjectModel,
    config_yaml: str,
):
    """Parse ``config_yaml`` and update the described backend in ``project``.

    Raises:
        ServerClientError: If ``config_yaml`` cannot be parsed or validated.
    """
    await backends_services.update_backend(
        session=session,
        project=project,
        config=config_yaml_to_backend_config(config_yaml),
    )
class _BackendConfigWithCreds(CoreModel):
    """
    Model for parsing API and file YAML configs.
    """

    # Custom root model: parse_obj() validates the whole input against the
    # backend-config union, choosing the concrete model from the ``type`` field.
    # Callers read the validated config from ``.__root__``.
    __root__: Annotated[AnyBackendConfigWithCreds, Field(..., discriminator="type")]
def config_yaml_to_backend_config(config_yaml: str) -> AnyBackendConfigWithCreds:
    """Parse and validate a backend config given as YAML text.

    Args:
        config_yaml: YAML document describing a single backend config.

    Returns:
        The validated backend config, resolved by its ``type`` field.

    Raises:
        ServerClientError: If the text is not valid YAML or does not match any
            backend config model.
    """
    try:
        # safe_load: the text is supplied by callers such as the create/update
        # backend handlers (presumably API requests), and PyYAML recommends
        # SafeLoader for untrusted input; plain config data parses identically.
        config_dict = yaml.safe_load(config_yaml)
    except yaml.YAMLError as e:
        raise ServerClientError("Error parsing YAML") from e
    try:
        backend_config = _BackendConfigWithCreds.parse_obj(config_dict).__root__
    except ValidationError as e:
        raise ServerClientError(str(e)) from e
    return backend_config
def file_config_to_config(file_config: AnyBackendFileConfigWithCreds) -> AnyBackendConfigWithCreds:
    """Convert a backend file config into the corresponding backend config.

    The file config is dumped to a plain dict and re-validated through the
    ``type``-discriminated backend-config union.
    """
    return _BackendConfigWithCreds.parse_obj(file_config.dict()).__root__
def config_to_yaml(config: CoreModel) -> str:
    """Serialize a model to a YAML string.

    Fields whose value is None are omitted, and keys keep the model's field
    order instead of being sorted alphabetically.
    """
    plain = config.dict(exclude_none=True)
    return yaml.dump(plain, sort_keys=False)