Skip to content

Commit d144212

Browse files
committed
Interpolate registry_auth with secrets
1 parent a118308 commit d144212

7 files changed

Lines changed: 89 additions & 22 deletions

File tree

src/dstack/_internal/server/background/tasks/process_runs.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
run_model_to_run,
3535
scale_run_replicas,
3636
)
37+
from dstack._internal.server.services.secrets import get_project_secrets_mapping
3738
from dstack._internal.server.services.services import update_service_desired_replica_count
3839
from dstack._internal.utils import common
3940
from dstack._internal.utils.logging import get_logger
@@ -385,7 +386,11 @@ async def _handle_run_replicas(
385386
)
386387
return
387388

388-
await _update_jobs_to_new_deployment_in_place(run_model, run_spec)
389+
await _update_jobs_to_new_deployment_in_place(
390+
session=session,
391+
run_model=run_model,
392+
run_spec=run_spec,
393+
)
389394
if _has_out_of_date_replicas(run_model):
390395
non_terminated_replica_count = len(
391396
{j.replica_num for j in run_model.jobs if not j.status.is_finished()}
@@ -425,18 +430,25 @@ async def _handle_run_replicas(
425430
)
426431

427432

428-
async def _update_jobs_to_new_deployment_in_place(run_model: RunModel, run_spec: RunSpec) -> None:
433+
async def _update_jobs_to_new_deployment_in_place(
434+
session: AsyncSession, run_model: RunModel, run_spec: RunSpec
435+
) -> None:
429436
"""
430437
Bump deployment_num for jobs that do not require redeployment.
431438
"""
432-
439+
secrets = await get_project_secrets_mapping(
440+
session=session,
441+
project=run_model.project,
442+
)
433443
for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs):
434444
if all(j.status.is_finished() for j in job_models):
435445
continue
436446
if all(j.deployment_num == run_model.deployment_num for j in job_models):
437447
continue
448+
# FIXME: Handle getting image configuration errors or skip it.
438449
new_job_specs = await get_job_specs_from_run_spec(
439450
run_spec=run_spec,
451+
secrets=secrets,
440452
replica_num=replica_num,
441453
)
442454
assert len(new_job_specs) == len(job_models), (

src/dstack/_internal/server/services/jobs/__init__.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,23 @@
6464
logger = get_logger(__name__)
6565

6666

67-
async def get_jobs_from_run_spec(run_spec: RunSpec, replica_num: int) -> List[Job]:
67+
async def get_jobs_from_run_spec(
68+
run_spec: RunSpec, secrets: Dict[str, str], replica_num: int
69+
) -> List[Job]:
6870
return [
6971
Job(job_spec=s, job_submissions=[])
70-
for s in await get_job_specs_from_run_spec(run_spec, replica_num)
72+
for s in await get_job_specs_from_run_spec(
73+
run_spec=run_spec,
74+
secrets=secrets,
75+
replica_num=replica_num,
76+
)
7177
]
7278

7379

74-
async def get_job_specs_from_run_spec(run_spec: RunSpec, replica_num: int) -> List[JobSpec]:
75-
job_configurator = _get_job_configurator(run_spec)
80+
async def get_job_specs_from_run_spec(
81+
run_spec: RunSpec, secrets: Dict[str, str], replica_num: int
82+
) -> List[JobSpec]:
83+
job_configurator = _get_job_configurator(run_spec=run_spec, secrets=secrets)
7684
job_specs = await job_configurator.get_job_specs(replica_num=replica_num)
7785
return job_specs
7886

@@ -158,10 +166,10 @@ def delay_job_instance_termination(job_model: JobModel):
158166
job_model.remove_at = common.get_current_datetime() + timedelta(seconds=15)
159167

160168

161-
def _get_job_configurator(run_spec: RunSpec) -> JobConfigurator:
169+
def _get_job_configurator(run_spec: RunSpec, secrets: Dict[str, str]) -> JobConfigurator:
162170
configuration_type = RunConfigurationType(run_spec.configuration.type)
163171
configurator_class = _configuration_type_to_configurator_class_map[configuration_type]
164-
return configurator_class(run_spec)
172+
return configurator_class(run_spec=run_spec, secrets=secrets)
165173

166174

167175
_job_configurator_classes = [

src/dstack/_internal/server/services/jobs/configurators/base.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,13 @@ class JobConfigurator(ABC):
6868
# JobSSHKey should be shared for all jobs in a replica for inter-node communication.
6969
_job_ssh_key: Optional[JobSSHKey] = None
7070

71-
def __init__(self, run_spec: RunSpec):
71+
def __init__(
72+
self,
73+
run_spec: RunSpec,
74+
secrets: Optional[Dict[str, str]] = None,
75+
):
7276
self.run_spec = run_spec
77+
self.secrets = secrets or {}
7378

7479
async def get_job_specs(self, replica_num: int) -> List[JobSpec]:
7580
job_spec = await self._get_job_spec(replica_num=replica_num, job_num=0, jobs_per_replica=1)
@@ -98,10 +103,20 @@ def _ports(self) -> List[PortMapping]:
98103
async def _get_image_config(self) -> ImageConfig:
99104
if self._image_config is not None:
100105
return self._image_config
106+
interpolate = VariablesInterpolator({"secrets": self.secrets}).interpolate_or_error
107+
registry_auth = self.run_spec.configuration.registry_auth
108+
if registry_auth is not None:
109+
try:
110+
registry_auth = RegistryAuth(
111+
username=interpolate(registry_auth.username),
112+
password=interpolate(registry_auth.password),
113+
)
114+
except InterpolatorError as e:
115+
raise ServerClientError(e.args[0])
101116
image_config = await run_async(
102117
_get_image_config,
103118
self._image_name(),
104-
self.run_spec.configuration.registry_auth,
119+
registry_auth,
105120
)
106121
self._image_config = image_config
107122
return image_config

src/dstack/_internal/server/services/jobs/configurators/dev.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional
1+
from typing import Dict, List, Optional
22

33
from dstack._internal.core.errors import ServerClientError
44
from dstack._internal.core.models.configurations import PortMapping, RunConfigurationType
@@ -17,7 +17,7 @@
1717
class DevEnvironmentJobConfigurator(JobConfigurator):
1818
TYPE: RunConfigurationType = RunConfigurationType.DEV_ENVIRONMENT
1919

20-
def __init__(self, run_spec: RunSpec):
20+
def __init__(self, run_spec: RunSpec, secrets: Dict[str, str]):
2121
if run_spec.configuration.ide == "vscode":
2222
__class = VSCodeDesktop
2323
elif run_spec.configuration.ide == "cursor":
@@ -29,7 +29,7 @@ def __init__(self, run_spec: RunSpec):
2929
version=run_spec.configuration.version,
3030
extensions=["ms-python.python", "ms-toolsai.jupyter"],
3131
)
32-
super().__init__(run_spec)
32+
super().__init__(run_spec=run_spec, secrets=secrets)
3333

3434
def _shell_commands(self) -> List[str]:
3535
commands = self.ide.get_install_commands()

src/dstack/_internal/server/services/runs.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
from dstack._internal.server.services.plugins import apply_plugin_policies
8383
from dstack._internal.server.services.projects import list_project_models, list_user_project_models
8484
from dstack._internal.server.services.resources import set_resources_defaults
85+
from dstack._internal.server.services.secrets import get_project_secrets_mapping
8586
from dstack._internal.server.services.users import get_user_model_by_name
8687
from dstack._internal.utils.logging import get_logger
8788
from dstack._internal.utils.random_names import generate_name
@@ -311,7 +312,12 @@ async def get_plan(
311312
):
312313
action = ApplyAction.UPDATE
313314

314-
jobs = await get_jobs_from_run_spec(effective_run_spec, replica_num=0)
315+
secrets = await get_project_secrets_mapping(session=session, project=project)
316+
jobs = await get_jobs_from_run_spec(
317+
run_spec=effective_run_spec,
318+
secrets=secrets,
319+
replica_num=0,
320+
)
315321

316322
volumes = await get_job_configured_volumes(
317323
session=session,
@@ -462,6 +468,10 @@ async def submit_run(
462468
project=project,
463469
run_spec=run_spec,
464470
)
471+
secrets = await get_project_secrets_mapping(
472+
session=session,
473+
project=project,
474+
)
465475

466476
lock_namespace = f"run_names_{project.name}"
467477
if get_db().dialect_name == "sqlite":
@@ -513,7 +523,11 @@ async def submit_run(
513523
await services.register_service(session, run_model, run_spec)
514524

515525
for replica_num in range(replicas):
516-
jobs = await get_jobs_from_run_spec(run_spec, replica_num=replica_num)
526+
jobs = await get_jobs_from_run_spec(
527+
run_spec=run_spec,
528+
secrets=secrets,
529+
replica_num=replica_num,
530+
)
517531
for job in jobs:
518532
job_model = create_job_model_for_new_submission(
519533
run_model=run_model,
@@ -1068,10 +1082,20 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
10681082
await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False)
10691083
scheduled_replicas += 1
10701084

1085+
secrets = await get_project_secrets_mapping(
1086+
session=session,
1087+
project=run_model.project,
1088+
)
1089+
10711090
for replica_num in range(
10721091
len(active_replicas) + scheduled_replicas, len(active_replicas) + replicas_diff
10731092
):
1074-
jobs = await get_jobs_from_run_spec(run_spec, replica_num=replica_num)
1093+
# FIXME: Handle getting image configuration errors or skip it.
1094+
jobs = await get_jobs_from_run_spec(
1095+
run_spec=run_spec,
1096+
secrets=secrets,
1097+
replica_num=replica_num,
1098+
)
10751099
for job in jobs:
10761100
job_model = create_job_model_for_new_submission(
10771101
run_model=run_model,
@@ -1084,8 +1108,14 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
10841108
async def retry_run_replica_jobs(
10851109
session: AsyncSession, run_model: RunModel, latest_jobs: List[JobModel], *, only_failed: bool
10861110
):
1111+
# FIXME: Handle getting image configuration errors or skip it.
1112+
secrets = await get_project_secrets_mapping(
1113+
session=session,
1114+
project=run_model.project,
1115+
)
10871116
new_jobs = await get_jobs_from_run_spec(
1088-
RunSpec.__response__.parse_raw(run_model.run_spec),
1117+
run_spec=RunSpec.__response__.parse_raw(run_model.run_spec),
1118+
secrets=secrets,
10891119
replica_num=latest_jobs[0].replica_num,
10901120
)
10911121
assert len(new_jobs) == len(latest_jobs), (

src/dstack/_internal/server/testing/common.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,9 @@ async def create_job(
316316
if deployment_num is None:
317317
deployment_num = run.deployment_num
318318
run_spec = RunSpec.parse_raw(run.run_spec)
319-
job_spec = (await get_job_specs_from_run_spec(run_spec, replica_num=replica_num))[0]
319+
job_spec = (
320+
await get_job_specs_from_run_spec(run_spec=run_spec, secrets={}, replica_num=replica_num)
321+
)[0]
320322
job_spec.job_num = job_num
321323
job = JobModel(
322324
project_id=run.project_id,

src/tests/_internal/server/services/test_runs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List
1+
from typing import List, Union
22

33
import pytest
44
from pydantic import parse_obj_as
@@ -30,7 +30,7 @@ async def make_run(
3030
session: AsyncSession,
3131
replicas_statuses: List[JobStatus],
3232
status: RunStatus = RunStatus.RUNNING,
33-
replicas: str = 1,
33+
replicas: Union[str, int] = 1,
3434
) -> RunModel:
3535
project = await create_project(session=session)
3636
user = await create_user(session=session)
@@ -70,7 +70,7 @@ async def make_run(
7070
status=job_status,
7171
replica_num=replica_num,
7272
)
73-
await session.refresh(run)
73+
await session.refresh(run, attribute_names=["project", "jobs"])
7474
return run
7575

7676

0 commit comments

Comments
 (0)