Merged
2 changes: 1 addition & 1 deletion src/dstack/_internal/core/models/configurations.py
@@ -440,7 +440,7 @@ def convert_replicas(cls, v: Any) -> Range[int]:
        raise ValueError("The minimum number of replicas must be greater than or equal to 0")
    if v.max < v.min:
        raise ValueError(
-            "The maximum number of replicas must be greater than or equal to the minium number of replicas"
+            "The maximum number of replicas must be greater than or equal to the minimum number of replicas"
        )
    return v

7 changes: 6 additions & 1 deletion src/dstack/_internal/core/models/fleets.py
@@ -20,6 +20,7 @@
    parse_idle_duration,
)
from dstack._internal.core.models.resources import Range, ResourcesSpec
+from dstack._internal.utils.common import list_enum_values_for_annotation
from dstack._internal.utils.json_schema import add_extra_schema_types
from dstack._internal.utils.tags import tags_validator

@@ -207,7 +208,11 @@ class InstanceGroupParams(CoreModel):
    spot_policy: Annotated[
        Optional[SpotPolicy],
        Field(
-            description="The policy for provisioning spot or on-demand instances: `spot`, `on-demand`, or `auto`"
+            description=(
+                "The policy for provisioning spot or on-demand instances:"
+                f" {list_enum_values_for_annotation(SpotPolicy)}."
+                f" Defaults to `{SpotPolicy.ONDEMAND.value}`"
+            )
        ),
    ] = None
    retry: Annotated[
46 changes: 43 additions & 3 deletions src/dstack/_internal/core/models/profiles.py
@@ -6,6 +6,7 @@

from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel, Duration
+from dstack._internal.utils.common import list_enum_values_for_annotation
from dstack._internal.utils.tags import tags_validator

DEFAULT_RETRY_DURATION = 3600
@@ -32,6 +33,17 @@ class TerminationPolicy(str, Enum):
    DESTROY_AFTER_IDLE = "destroy-after-idle"


class StartupOrder(str, Enum):
    ANY = "any"
    MASTER_FIRST = "master-first"
    WORKERS_FIRST = "workers-first"
(inline review thread)
Collaborator: (nit) I'm not sure about calling non-master nodes "workers", because the master node is also a "worker" - it performs the same work other nodes do. I can suggest using "secondary" (secondary-first) or avoiding names altogether (master-last), although we might still need a name to use in the code.
Collaborator (author): master/worker is standard terminology used for PyTorch, MPI, etc. Let's not reinvent.

class StopCriteria(str, Enum):
    ALL_DONE = "all-done"
    MASTER_DONE = "master-done"

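For orientation, a minimal sketch of how a multi-node run might set these two new options. The StartupOrder/StopCriteria fields are per this PR; the TaskConfiguration kwargs shown are illustrative assumptions, not taken from the diff:

from dstack._internal.core.models.configurations import TaskConfiguration
from dstack._internal.core.models.profiles import StartupOrder, StopCriteria

# A 2-node task: worker jobs start before the master job, and the run is
# considered finished as soon as the master job is done.
task = TaskConfiguration(
    commands=["python train.py"],
    nodes=2,
    startup_order=StartupOrder.WORKERS_FIRST,
    stop_criteria=StopCriteria.MASTER_DONE,
)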

@overload
def parse_duration(v: None) -> None: ...

@@ -102,7 +114,7 @@ class ProfileRetry(CoreModel):
        Field(
            description=(
                "The list of events that should be handled with retry."
-                " Supported events are `no-capacity`, `interruption`, and `error`."
+                f" Supported events are {list_enum_values_for_annotation(RetryEvent)}."
                " Omit to retry on all events"
            )
        ),
@@ -190,7 +202,11 @@ class ProfileParams(CoreModel):
    spot_policy: Annotated[
        Optional[SpotPolicy],
        Field(
-            description="The policy for provisioning spot or on-demand instances: `spot`, `on-demand`, or `auto`. Defaults to `on-demand`"
+            description=(
+                "The policy for provisioning spot or on-demand instances:"
+                f" {list_enum_values_for_annotation(SpotPolicy)}."
+                f" Defaults to `{SpotPolicy.ONDEMAND.value}`"
+            )
        ),
    ] = None
    retry: Annotated[
@@ -225,7 +241,11 @@ class ProfileParams(CoreModel):
    creation_policy: Annotated[
        Optional[CreationPolicy],
        Field(
-            description="The policy for using instances from fleets. Defaults to `reuse-or-create`"
+            description=(
+                "The policy for using instances from fleets:"
+                f" {list_enum_values_for_annotation(CreationPolicy)}."
+                f" Defaults to `{CreationPolicy.REUSE_OR_CREATE.value}`"
+            )
        ),
    ] = None
    idle_duration: Annotated[
@@ -241,6 +261,26 @@
        Optional[UtilizationPolicy],
        Field(description="Run termination policy based on utilization"),
    ] = None
    startup_order: Annotated[
        Optional[StartupOrder],
        Field(
            description=(
                f"The order in which master and worker jobs are started:"
                f" {list_enum_values_for_annotation(StartupOrder)}."
                f" Defaults to `{StartupOrder.ANY.value}`"
            )
        ),
    ] = None
    stop_criteria: Annotated[
        Optional[StopCriteria],
        Field(
            description=(
                "The criteria determining when a multi-node run should be considered finished:"
                f" {list_enum_values_for_annotation(StopCriteria)}."
                f" Defaults to `{StopCriteria.ALL_DONE.value}`"
            )
        ),
    ] = None
    fleets: Annotated[
        Optional[list[str]], Field(description="The fleets considered for reuse")
    ] = None
@@ -18,6 +18,7 @@
    SSHConnectionParams,
)
from dstack._internal.core.models.metrics import Metric
+from dstack._internal.core.models.profiles import StartupOrder
from dstack._internal.core.models.repos import RemoteRepoCreds
from dstack._internal.core.models.runs import (
    ClusterInfo,
@@ -184,18 +185,10 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
    if job_provisioning_data.hostname is None:
        await _wait_for_instance_provisioning_data(job_model=job_model)
    else:
-        # Wait until all other jobs in the replica have IPs assigned.
-        # This is needed to ensure cluster_info has all IPs set.
-        for other_job in run.jobs:
-            if (
-                other_job.job_spec.replica_num == job.job_spec.replica_num
-                and other_job.job_submissions[-1].status == JobStatus.PROVISIONING
-                and other_job.job_submissions[-1].job_provisioning_data is not None
-                and other_job.job_submissions[-1].job_provisioning_data.hostname is None
-            ):
-                job_model.last_processed_at = common_utils.get_current_datetime()
-                await session.commit()
-                return
+        if _should_wait_for_other_nodes(run, job, job_model):
+            job_model.last_processed_at = common_utils.get_current_datetime()
+            await session.commit()
+            return

        # fails are acceptable until timeout is exceeded
        if job_provisioning_data.dockerized:
@@ -406,6 +399,48 @@ async def _wait_for_instance_provisioning_data(job_model: JobModel):
    job_model.job_provisioning_data = job_model.instance.job_provisioning_data


def _should_wait_for_other_nodes(run: Run, job: Job, job_model: JobModel) -> bool:
    for other_job in run.jobs:
        if (
            other_job.job_spec.replica_num == job.job_spec.replica_num
            and other_job.job_submissions[-1].status == JobStatus.PROVISIONING
            and other_job.job_submissions[-1].job_provisioning_data is not None
            and other_job.job_submissions[-1].job_provisioning_data.hostname is None
        ):
            logger.debug(
                "%s: waiting for other job to have IP assigned",
                fmt(job_model),
            )
            return True
    master_job = find_job(run.jobs, job.job_spec.replica_num, 0)
    if (
        job.job_spec.job_num != 0
        and run.run_spec.merged_profile.startup_order == StartupOrder.MASTER_FIRST
        and master_job.job_submissions[-1].status != JobStatus.RUNNING
    ):
        logger.debug(
            "%s: waiting for master job to become running",
            fmt(job_model),
        )
        return True
    if (
        job.job_spec.job_num == 0
        and run.run_spec.merged_profile.startup_order == StartupOrder.WORKERS_FIRST
    ):
        for other_job in run.jobs:
            if (
                other_job.job_spec.replica_num == job.job_spec.replica_num
                and other_job.job_spec.job_num != job.job_spec.job_num
                and other_job.job_submissions[-1].status != JobStatus.RUNNING
            ):
                logger.debug(
                    "%s: waiting for worker job to become running",
                    fmt(job_model),
                )
                return True
    return False
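
To make the gating above easier to follow, here is a condensed sketch of the same decision; the boolean parameters are my shorthand for the checks the helper performs against job submissions, not part of the diff's API:

def should_wait(startup_order, job_num, all_ips_assigned, master_running, all_workers_running):
    # Every job first waits until all peers in its replica have IPs assigned,
    # so cluster_info can be fully populated.
    if not all_ips_assigned:
        return True
    # master-first: worker jobs (job_num != 0) hold until the master is running.
    if startup_order == StartupOrder.MASTER_FIRST and job_num != 0:
        return not master_running
    # workers-first: the master job (job_num == 0) holds until all workers are running.
    if startup_order == StartupOrder.WORKERS_FIRST and job_num == 0:
        return not all_workers_running
    return False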


@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1)
def _process_provisioning_with_shim(
    ports: Dict[int, int],
15 changes: 14 additions & 1 deletion src/dstack/_internal/server/background/tasks/process_runs.py
@@ -10,7 +10,7 @@
import dstack._internal.server.services.gateways as gateways
import dstack._internal.server.services.services.autoscalers as autoscalers
from dstack._internal.core.errors import ServerError
-from dstack._internal.core.models.profiles import RetryEvent
+from dstack._internal.core.models.profiles import RetryEvent, StopCriteria
from dstack._internal.core.models.runs import (
    Job,
    JobStatus,
@@ -313,6 +313,10 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
            termination_reason = RunTerminationReason.RETRY_LIMIT_EXCEEDED
        else:
            raise ValueError(f"Unexpected termination reason {run_termination_reasons}")
+    elif _should_stop_on_master_done(run):
+        new_status = RunStatus.TERMINATING
+        # ALL_JOBS_DONE is used for all DONE reasons including master-done
+        termination_reason = RunTerminationReason.ALL_JOBS_DONE
(jvstme marked this conversation as resolved)
    elif RunStatus.RUNNING in run_statuses:
        new_status = RunStatus.RUNNING
    elif RunStatus.PROVISIONING in run_statuses:
@@ -434,3 +438,12 @@ def _can_retry_single_job(run_spec: RunSpec) -> bool:
    # We could make partial retry in some multi-node cases.
    # E.g. restarting a worker node, independent jobs.
    return False


def _should_stop_on_master_done(run: Run) -> bool:
    if run.run_spec.merged_profile.stop_criteria != StopCriteria.MASTER_DONE:
        return False
    for job in run.jobs:
        if job.job_spec.job_num == 0 and job.job_submissions[-1].status == JobStatus.DONE:
(inline review thread)
jvstme (Jun 2, 2025): (nit) Can also check for termination_reason == JobTerminationReason.DONE_BY_RUNNER to terminate the run faster, without waiting for the terminating -> done master job transition. See line 241.
Collaborator (author): I'm not sure if we want to terminate the run before the master is really done.
            return True
    return False
3 changes: 2 additions & 1 deletion src/dstack/_internal/server/testing/common.py
@@ -308,12 +308,13 @@ async def create_job(
) -> JobModel:
    run_spec = RunSpec.parse_raw(run.run_spec)
    job_spec = (await get_job_specs_from_run_spec(run_spec, replica_num=replica_num))[0]
+    job_spec.job_num = job_num
    job = JobModel(
        project_id=run.project_id,
        run_id=run.id,
        run_name=run.run_name,
        job_num=job_num,
-        job_name=run.run_name + f"-0-{replica_num}",
+        job_name=run.run_name + f"-{job_num}-{replica_num}",
        replica_num=replica_num,
        submission_num=submission_num,
        submitted_at=submitted_at,
4 changes: 4 additions & 0 deletions src/dstack/_internal/utils/common.py
@@ -314,3 +314,7 @@ def make_proxy_url(server_url: str, proxy_url: str) -> str:
        path=concat_url_path(server.path, proxy.path),
    )
    return proxy.geturl()
+
+
+def list_enum_values_for_annotation(enum_class: type[enum.Enum]) -> str:
+    return ", ".join(f"`{e.value}`" for e in enum_class)
4 changes: 4 additions & 0 deletions src/dstack/api/server/_fleets.py
@@ -126,6 +126,10 @@ def _get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[Dict]:
        configuration_excludes["tags"] = True
    if profile.tags is None:
        profile_excludes.add("tags")
+    if profile.startup_order is None:
+        profile_excludes.add("startup_order")
+    if profile.stop_criteria is None:
+        profile_excludes.add("stop_criteria")
    if configuration_excludes:
        spec_excludes["configuration"] = configuration_excludes
    if profile_excludes:
8 changes: 8 additions & 0 deletions src/dstack/api/server/_runs.py
@@ -188,6 +188,14 @@ def _get_run_spec_excludes(run_spec: RunSpec) -> Optional[Dict]:
        configuration_excludes["shell"] = True
    if configuration.priority is None:
        configuration_excludes["priority"] = True
+    if configuration.startup_order is None:
+        configuration_excludes["startup_order"] = True
+    if profile is not None and profile.startup_order is None:
+        profile_excludes.add("startup_order")
+    if configuration.stop_criteria is None:
+        configuration_excludes["stop_criteria"] = True
+    if profile is not None and profile.stop_criteria is None:
+        profile_excludes.add("stop_criteria")

    if configuration_excludes:
        spec_excludes["configuration"] = configuration_excludes
@@ -12,7 +12,7 @@
from dstack._internal.core.models.common import NetworkMode
from dstack._internal.core.models.configurations import DevEnvironmentConfiguration
from dstack._internal.core.models.instances import InstanceStatus
-from dstack._internal.core.models.profiles import UtilizationPolicy
+from dstack._internal.core.models.profiles import StartupOrder, UtilizationPolicy
from dstack._internal.core.models.runs import (
    JobRuntimeData,
    JobStatus,
@@ -805,3 +805,76 @@ async def test_gpu_utilization(
        else:
            assert job.termination_reason is None
            assert job.termination_reason_message is None

    @pytest.mark.asyncio
    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
    async def test_master_job_waits_for_workers(self, test_db, session: AsyncSession):
        project = await create_project(session=session)
        user = await create_user(session=session)
        repo = await create_repo(
            session=session,
            project_id=project.id,
        )
        run_spec = get_run_spec(
            run_name="test-run",
            repo_id=repo.name,
        )
        run_spec.configuration.startup_order = StartupOrder.WORKERS_FIRST
        run = await create_run(
            session=session,
            project=project,
            repo=repo,
            user=user,
            run_spec=run_spec,
        )
        instance1 = await create_instance(
            session=session,
            project=project,
            status=InstanceStatus.BUSY,
        )
        instance2 = await create_instance(
            session=session,
            project=project,
            status=InstanceStatus.BUSY,
        )
        job_provisioning_data = get_job_provisioning_data(dockerized=False)
        master_job = await create_job(
            session=session,
            run=run,
            status=JobStatus.PROVISIONING,
            job_provisioning_data=job_provisioning_data,
            instance_assigned=True,
            instance=instance1,
            job_num=0,
            last_processed_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
        )
        worker_job = await create_job(
            session=session,
            run=run,
            status=JobStatus.PROVISIONING,
            job_provisioning_data=job_provisioning_data,
            instance_assigned=True,
            instance=instance2,
            job_num=1,
            last_processed_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc),
        )
        await process_running_jobs()
        await session.refresh(master_job)
        assert master_job.status == JobStatus.PROVISIONING
        worker_job.status = JobStatus.RUNNING
        # To guarantee master_job is processed next
        master_job.last_processed_at = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc)
        await session.commit()
        with (
            patch("dstack._internal.server.services.runner.ssh.SSHTunnel"),
            patch(
                "dstack._internal.server.services.runner.client.RunnerClient"
            ) as RunnerClientMock,
        ):
            runner_client_mock = RunnerClientMock.return_value
            runner_client_mock.healthcheck.return_value = HealthcheckResponse(
                service="dstack-runner", version="0.0.1.dev2"
            )
            await process_running_jobs()
            await session.refresh(master_job)
            assert master_job.status == JobStatus.RUNNING
4 changes: 4 additions & 0 deletions src/tests/_internal/server/routers/test_fleets.py
@@ -364,6 +364,8 @@ async def test_creates_fleet(self, test_db, session: AsyncSession, client: Async
            "creation_policy": None,
            "idle_duration": None,
            "utilization_policy": None,
+            "startup_order": None,
+            "stop_criteria": None,
            "name": "",
            "default": False,
            "reservation": None,
@@ -485,6 +487,8 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A
            "creation_policy": None,
            "idle_duration": None,
            "utilization_policy": None,
+            "startup_order": None,
+            "stop_criteria": None,
            "name": "",
            "default": False,
            "reservation": None,