Skip to content

Commit 2ddae6e

Browse files
authored
Implement startup_order and stop_criteria (#2714)
* Implement startup_order
* Implement stop_criteria
* Fix tests
* Test master job waiting for workers
* Exclude new fields for client backward compatibility
* Add reference for startup_order and stop_criteria
* Fix hardcoded enum values in reference
* Minor fixes
1 parent 918a921 commit 2ddae6e

12 files changed

Lines changed: 215 additions & 20 deletions

File tree

src/dstack/_internal/core/models/configurations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ def convert_replicas(cls, v: Any) -> Range[int]:
440440
raise ValueError("The minimum number of replicas must be greater than or equal to 0")
441441
if v.max < v.min:
442442
raise ValueError(
443-
"The maximum number of replicas must be greater than or equal to the minium number of replicas"
443+
"The maximum number of replicas must be greater than or equal to the minimum number of replicas"
444444
)
445445
return v
446446

src/dstack/_internal/core/models/fleets.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
parse_idle_duration,
2121
)
2222
from dstack._internal.core.models.resources import Range, ResourcesSpec
23+
from dstack._internal.utils.common import list_enum_values_for_annotation
2324
from dstack._internal.utils.json_schema import add_extra_schema_types
2425
from dstack._internal.utils.tags import tags_validator
2526

@@ -207,7 +208,11 @@ class InstanceGroupParams(CoreModel):
207208
spot_policy: Annotated[
208209
Optional[SpotPolicy],
209210
Field(
210-
description="The policy for provisioning spot or on-demand instances: `spot`, `on-demand`, or `auto`"
211+
description=(
212+
"The policy for provisioning spot or on-demand instances:"
213+
f" {list_enum_values_for_annotation(SpotPolicy)}."
214+
f" Defaults to `{SpotPolicy.ONDEMAND.value}`"
215+
)
211216
),
212217
] = None
213218
retry: Annotated[

src/dstack/_internal/core/models/profiles.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from dstack._internal.core.models.backends.base import BackendType
88
from dstack._internal.core.models.common import CoreModel, Duration
9+
from dstack._internal.utils.common import list_enum_values_for_annotation
910
from dstack._internal.utils.tags import tags_validator
1011

1112
DEFAULT_RETRY_DURATION = 3600
@@ -32,6 +33,17 @@ class TerminationPolicy(str, Enum):
3233
DESTROY_AFTER_IDLE = "destroy-after-idle"
3334

3435

36+
class StartupOrder(str, Enum):
37+
ANY = "any"
38+
MASTER_FIRST = "master-first"
39+
WORKERS_FIRST = "workers-first"
40+
41+
42+
class StopCriteria(str, Enum):
43+
ALL_DONE = "all-done"
44+
MASTER_DONE = "master-done"
45+
46+
3547
@overload
3648
def parse_duration(v: None) -> None: ...
3749

@@ -102,7 +114,7 @@ class ProfileRetry(CoreModel):
102114
Field(
103115
description=(
104116
"The list of events that should be handled with retry."
105-
" Supported events are `no-capacity`, `interruption`, and `error`."
117+
f" Supported events are {list_enum_values_for_annotation(RetryEvent)}."
106118
" Omit to retry on all events"
107119
)
108120
),
@@ -190,7 +202,11 @@ class ProfileParams(CoreModel):
190202
spot_policy: Annotated[
191203
Optional[SpotPolicy],
192204
Field(
193-
description="The policy for provisioning spot or on-demand instances: `spot`, `on-demand`, or `auto`. Defaults to `on-demand`"
205+
description=(
206+
"The policy for provisioning spot or on-demand instances:"
207+
f" {list_enum_values_for_annotation(SpotPolicy)}."
208+
f" Defaults to `{SpotPolicy.ONDEMAND.value}`"
209+
)
194210
),
195211
] = None
196212
retry: Annotated[
@@ -225,7 +241,11 @@ class ProfileParams(CoreModel):
225241
creation_policy: Annotated[
226242
Optional[CreationPolicy],
227243
Field(
228-
description="The policy for using instances from fleets. Defaults to `reuse-or-create`"
244+
description=(
245+
"The policy for using instances from fleets:"
246+
f" {list_enum_values_for_annotation(CreationPolicy)}."
247+
f" Defaults to `{CreationPolicy.REUSE_OR_CREATE.value}`"
248+
)
229249
),
230250
] = None
231251
idle_duration: Annotated[
@@ -241,6 +261,26 @@ class ProfileParams(CoreModel):
241261
Optional[UtilizationPolicy],
242262
Field(description="Run termination policy based on utilization"),
243263
] = None
264+
startup_order: Annotated[
265+
Optional[StartupOrder],
266+
Field(
267+
description=(
268+
f"The order in which master and workers jobs are started:"
269+
f" {list_enum_values_for_annotation(StartupOrder)}."
270+
f" Defaults to `{StartupOrder.ANY.value}`"
271+
)
272+
),
273+
] = None
274+
stop_criteria: Annotated[
275+
Optional[StopCriteria],
276+
Field(
277+
description=(
278+
"The criteria determining when a multi-node run should be considered finished:"
279+
f" {list_enum_values_for_annotation(StopCriteria)}."
280+
f" Defaults to `{StopCriteria.ALL_DONE.value}`"
281+
)
282+
),
283+
] = None
244284
fleets: Annotated[
245285
Optional[list[str]], Field(description="The fleets considered for reuse")
246286
] = None

src/dstack/_internal/server/background/tasks/process_running_jobs.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
SSHConnectionParams,
1919
)
2020
from dstack._internal.core.models.metrics import Metric
21+
from dstack._internal.core.models.profiles import StartupOrder
2122
from dstack._internal.core.models.repos import RemoteRepoCreds
2223
from dstack._internal.core.models.runs import (
2324
ClusterInfo,
@@ -184,18 +185,10 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
184185
if job_provisioning_data.hostname is None:
185186
await _wait_for_instance_provisioning_data(job_model=job_model)
186187
else:
187-
# Wait until all other jobs in the replica have IPs assigned.
188-
# This is needed to ensure cluster_info has all IPs set.
189-
for other_job in run.jobs:
190-
if (
191-
other_job.job_spec.replica_num == job.job_spec.replica_num
192-
and other_job.job_submissions[-1].status == JobStatus.PROVISIONING
193-
and other_job.job_submissions[-1].job_provisioning_data is not None
194-
and other_job.job_submissions[-1].job_provisioning_data.hostname is None
195-
):
196-
job_model.last_processed_at = common_utils.get_current_datetime()
197-
await session.commit()
198-
return
188+
if _should_wait_for_other_nodes(run, job, job_model):
189+
job_model.last_processed_at = common_utils.get_current_datetime()
190+
await session.commit()
191+
return
199192

200193
# fails are acceptable until timeout is exceeded
201194
if job_provisioning_data.dockerized:
@@ -406,6 +399,48 @@ async def _wait_for_instance_provisioning_data(job_model: JobModel):
406399
job_model.job_provisioning_data = job_model.instance.job_provisioning_data
407400

408401

402+
def _should_wait_for_other_nodes(run: Run, job: Job, job_model: JobModel) -> bool:
403+
for other_job in run.jobs:
404+
if (
405+
other_job.job_spec.replica_num == job.job_spec.replica_num
406+
and other_job.job_submissions[-1].status == JobStatus.PROVISIONING
407+
and other_job.job_submissions[-1].job_provisioning_data is not None
408+
and other_job.job_submissions[-1].job_provisioning_data.hostname is None
409+
):
410+
logger.debug(
411+
"%s: waiting for other job to have IP assigned",
412+
fmt(job_model),
413+
)
414+
return True
415+
master_job = find_job(run.jobs, job.job_spec.replica_num, 0)
416+
if (
417+
job.job_spec.job_num != 0
418+
and run.run_spec.merged_profile.startup_order == StartupOrder.MASTER_FIRST
419+
and master_job.job_submissions[-1].status != JobStatus.RUNNING
420+
):
421+
logger.debug(
422+
"%s: waiting for master job to become running",
423+
fmt(job_model),
424+
)
425+
return True
426+
if (
427+
job.job_spec.job_num == 0
428+
and run.run_spec.merged_profile.startup_order == StartupOrder.WORKERS_FIRST
429+
):
430+
for other_job in run.jobs:
431+
if (
432+
other_job.job_spec.replica_num == job.job_spec.replica_num
433+
and other_job.job_spec.job_num != job.job_spec.job_num
434+
and other_job.job_submissions[-1].status != JobStatus.RUNNING
435+
):
436+
logger.debug(
437+
"%s: waiting for worker job to become running",
438+
fmt(job_model),
439+
)
440+
return True
441+
return False
442+
443+
409444
@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1)
410445
def _process_provisioning_with_shim(
411446
ports: Dict[int, int],

src/dstack/_internal/server/background/tasks/process_runs.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import dstack._internal.server.services.gateways as gateways
1111
import dstack._internal.server.services.services.autoscalers as autoscalers
1212
from dstack._internal.core.errors import ServerError
13-
from dstack._internal.core.models.profiles import RetryEvent
13+
from dstack._internal.core.models.profiles import RetryEvent, StopCriteria
1414
from dstack._internal.core.models.runs import (
1515
Job,
1616
JobStatus,
@@ -313,6 +313,10 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
313313
termination_reason = RunTerminationReason.RETRY_LIMIT_EXCEEDED
314314
else:
315315
raise ValueError(f"Unexpected termination reason {run_termination_reasons}")
316+
elif _should_stop_on_master_done(run):
317+
new_status = RunStatus.TERMINATING
318+
# ALL_JOBS_DONE is used for all DONE reasons including master-done
319+
termination_reason = RunTerminationReason.ALL_JOBS_DONE
316320
elif RunStatus.RUNNING in run_statuses:
317321
new_status = RunStatus.RUNNING
318322
elif RunStatus.PROVISIONING in run_statuses:
@@ -434,3 +438,12 @@ def _can_retry_single_job(run_spec: RunSpec) -> bool:
434438
# We could make partial retry in some multi-node cases.
435439
# E.g. restarting a worker node, independent jobs.
436440
return False
441+
442+
443+
def _should_stop_on_master_done(run: Run) -> bool:
444+
if run.run_spec.merged_profile.stop_criteria != StopCriteria.MASTER_DONE:
445+
return False
446+
for job in run.jobs:
447+
if job.job_spec.job_num == 0 and job.job_submissions[-1].status == JobStatus.DONE:
448+
return True
449+
return False

src/dstack/_internal/server/testing/common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,12 +308,13 @@ async def create_job(
308308
) -> JobModel:
309309
run_spec = RunSpec.parse_raw(run.run_spec)
310310
job_spec = (await get_job_specs_from_run_spec(run_spec, replica_num=replica_num))[0]
311+
job_spec.job_num = job_num
311312
job = JobModel(
312313
project_id=run.project_id,
313314
run_id=run.id,
314315
run_name=run.run_name,
315316
job_num=job_num,
316-
job_name=run.run_name + f"-0-{replica_num}",
317+
job_name=run.run_name + f"-{job_num}-{replica_num}",
317318
replica_num=replica_num,
318319
submission_num=submission_num,
319320
submitted_at=submitted_at,

src/dstack/_internal/utils/common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,3 +314,7 @@ def make_proxy_url(server_url: str, proxy_url: str) -> str:
314314
path=concat_url_path(server.path, proxy.path),
315315
)
316316
return proxy.geturl()
317+
318+
319+
def list_enum_values_for_annotation(enum_class: type[enum.Enum]) -> str:
320+
return ", ".join(f"`{e.value}`" for e in enum_class)

src/dstack/api/server/_fleets.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,10 @@ def _get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[Dict]:
126126
configuration_excludes["tags"] = True
127127
if profile.tags is None:
128128
profile_excludes.add("tags")
129+
if profile.startup_order is None:
130+
profile_excludes.add("startup_order")
131+
if profile.stop_criteria is None:
132+
profile_excludes.add("stop_criteria")
129133
if configuration_excludes:
130134
spec_excludes["configuration"] = configuration_excludes
131135
if profile_excludes:

src/dstack/api/server/_runs.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,14 @@ def _get_run_spec_excludes(run_spec: RunSpec) -> Optional[Dict]:
188188
configuration_excludes["shell"] = True
189189
if configuration.priority is None:
190190
configuration_excludes["priority"] = True
191+
if configuration.startup_order is None:
192+
configuration_excludes["startup_order"] = True
193+
if profile is not None and profile.startup_order is None:
194+
profile_excludes.add("startup_order")
195+
if configuration.stop_criteria is None:
196+
configuration_excludes["stop_criteria"] = True
197+
if profile is not None and profile.stop_criteria is None:
198+
profile_excludes.add("stop_criteria")
191199

192200
if configuration_excludes:
193201
spec_excludes["configuration"] = configuration_excludes

src/tests/_internal/server/background/tasks/test_process_running_jobs.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from dstack._internal.core.models.common import NetworkMode
1313
from dstack._internal.core.models.configurations import DevEnvironmentConfiguration
1414
from dstack._internal.core.models.instances import InstanceStatus
15-
from dstack._internal.core.models.profiles import UtilizationPolicy
15+
from dstack._internal.core.models.profiles import StartupOrder, UtilizationPolicy
1616
from dstack._internal.core.models.runs import (
1717
JobRuntimeData,
1818
JobStatus,
@@ -805,3 +805,76 @@ async def test_gpu_utilization(
805805
else:
806806
assert job.termination_reason is None
807807
assert job.termination_reason_message is None
808+
809+
@pytest.mark.asyncio
810+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
811+
async def test_master_job_waits_for_workers(self, test_db, session: AsyncSession):
812+
project = await create_project(session=session)
813+
user = await create_user(session=session)
814+
repo = await create_repo(
815+
session=session,
816+
project_id=project.id,
817+
)
818+
run_spec = get_run_spec(
819+
run_name="test-run",
820+
repo_id=repo.name,
821+
)
822+
run_spec.configuration.startup_order = StartupOrder.WORKERS_FIRST
823+
run = await create_run(
824+
session=session,
825+
project=project,
826+
repo=repo,
827+
user=user,
828+
run_spec=run_spec,
829+
)
830+
instance1 = await create_instance(
831+
session=session,
832+
project=project,
833+
status=InstanceStatus.BUSY,
834+
)
835+
instance2 = await create_instance(
836+
session=session,
837+
project=project,
838+
status=InstanceStatus.BUSY,
839+
)
840+
job_provisioning_data = get_job_provisioning_data(dockerized=False)
841+
master_job = await create_job(
842+
session=session,
843+
run=run,
844+
status=JobStatus.PROVISIONING,
845+
job_provisioning_data=job_provisioning_data,
846+
instance_assigned=True,
847+
instance=instance1,
848+
job_num=0,
849+
last_processed_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
850+
)
851+
worker_job = await create_job(
852+
session=session,
853+
run=run,
854+
status=JobStatus.PROVISIONING,
855+
job_provisioning_data=job_provisioning_data,
856+
instance_assigned=True,
857+
instance=instance2,
858+
job_num=1,
859+
last_processed_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc),
860+
)
861+
await process_running_jobs()
862+
await session.refresh(master_job)
863+
assert master_job.status == JobStatus.PROVISIONING
864+
worker_job.status = JobStatus.RUNNING
865+
# To guarantee master_job is processed next
866+
master_job.last_processed_at = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc)
867+
await session.commit()
868+
with (
869+
patch("dstack._internal.server.services.runner.ssh.SSHTunnel"),
870+
patch(
871+
"dstack._internal.server.services.runner.client.RunnerClient"
872+
) as RunnerClientMock,
873+
):
874+
runner_client_mock = RunnerClientMock.return_value
875+
runner_client_mock.healthcheck.return_value = HealthcheckResponse(
876+
service="dstack-runner", version="0.0.1.dev2"
877+
)
878+
await process_running_jobs()
879+
await session.refresh(master_job)
880+
assert master_job.status == JobStatus.RUNNING

0 commit comments

Comments
 (0)