diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index f6b73a61e..bc6ba3235 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -440,7 +440,7 @@ def convert_replicas(cls, v: Any) -> Range[int]: raise ValueError("The minimum number of replicas must be greater than or equal to 0") if v.max < v.min: raise ValueError( - "The maximum number of replicas must be greater than or equal to the minium number of replicas" + "The maximum number of replicas must be greater than or equal to the minimum number of replicas" ) return v diff --git a/src/dstack/_internal/core/models/fleets.py b/src/dstack/_internal/core/models/fleets.py index 0e5580309..6cf970a95 100644 --- a/src/dstack/_internal/core/models/fleets.py +++ b/src/dstack/_internal/core/models/fleets.py @@ -20,6 +20,7 @@ parse_idle_duration, ) from dstack._internal.core.models.resources import Range, ResourcesSpec +from dstack._internal.utils.common import list_enum_values_for_annotation from dstack._internal.utils.json_schema import add_extra_schema_types from dstack._internal.utils.tags import tags_validator @@ -207,7 +208,11 @@ class InstanceGroupParams(CoreModel): spot_policy: Annotated[ Optional[SpotPolicy], Field( - description="The policy for provisioning spot or on-demand instances: `spot`, `on-demand`, or `auto`" + description=( + "The policy for provisioning spot or on-demand instances:" + f" {list_enum_values_for_annotation(SpotPolicy)}." + f" Defaults to `{SpotPolicy.ONDEMAND.value}`" + ) ), ] = None retry: Annotated[ diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py index fa5296547..62997ce4e 100644 --- a/src/dstack/_internal/core/models/profiles.py +++ b/src/dstack/_internal/core/models/profiles.py @@ -6,6 +6,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import CoreModel, Duration +from dstack._internal.utils.common import list_enum_values_for_annotation from dstack._internal.utils.tags import tags_validator DEFAULT_RETRY_DURATION = 3600 @@ -32,6 +33,17 @@ class TerminationPolicy(str, Enum): DESTROY_AFTER_IDLE = "destroy-after-idle" +class StartupOrder(str, Enum): + ANY = "any" + MASTER_FIRST = "master-first" + WORKERS_FIRST = "workers-first" + + +class StopCriteria(str, Enum): + ALL_DONE = "all-done" + MASTER_DONE = "master-done" + + @overload def parse_duration(v: None) -> None: ... @@ -102,7 +114,7 @@ class ProfileRetry(CoreModel): Field( description=( "The list of events that should be handled with retry." - " Supported events are `no-capacity`, `interruption`, and `error`." + f" Supported events are {list_enum_values_for_annotation(RetryEvent)}." " Omit to retry on all events" ) ), @@ -190,7 +202,11 @@ class ProfileParams(CoreModel): spot_policy: Annotated[ Optional[SpotPolicy], Field( - description="The policy for provisioning spot or on-demand instances: `spot`, `on-demand`, or `auto`. Defaults to `on-demand`" + description=( + "The policy for provisioning spot or on-demand instances:" + f" {list_enum_values_for_annotation(SpotPolicy)}." + f" Defaults to `{SpotPolicy.ONDEMAND.value}`" + ) ), ] = None retry: Annotated[ @@ -225,7 +241,11 @@ class ProfileParams(CoreModel): creation_policy: Annotated[ Optional[CreationPolicy], Field( - description="The policy for using instances from fleets. Defaults to `reuse-or-create`" + description=( + "The policy for using instances from fleets:" + f" {list_enum_values_for_annotation(CreationPolicy)}." + f" Defaults to `{CreationPolicy.REUSE_OR_CREATE.value}`" + ) ), ] = None idle_duration: Annotated[ @@ -241,6 +261,26 @@ class ProfileParams(CoreModel): Optional[UtilizationPolicy], Field(description="Run termination policy based on utilization"), ] = None + startup_order: Annotated[ + Optional[StartupOrder], + Field( + description=( + f"The order in which master and workers jobs are started:" + f" {list_enum_values_for_annotation(StartupOrder)}." + f" Defaults to `{StartupOrder.ANY.value}`" + ) + ), + ] = None + stop_criteria: Annotated[ + Optional[StopCriteria], + Field( + description=( + "The criteria determining when a multi-node run should be considered finished:" + f" {list_enum_values_for_annotation(StopCriteria)}." + f" Defaults to `{StopCriteria.ALL_DONE.value}`" + ) + ), + ] = None fleets: Annotated[ Optional[list[str]], Field(description="The fleets considered for reuse") ] = None diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index ea048383c..e05f98fd2 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -18,6 +18,7 @@ SSHConnectionParams, ) from dstack._internal.core.models.metrics import Metric +from dstack._internal.core.models.profiles import StartupOrder from dstack._internal.core.models.repos import RemoteRepoCreds from dstack._internal.core.models.runs import ( ClusterInfo, @@ -184,18 +185,10 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): if job_provisioning_data.hostname is None: await _wait_for_instance_provisioning_data(job_model=job_model) else: - # Wait until all other jobs in the replica have IPs assigned. - # This is needed to ensure cluster_info has all IPs set. - for other_job in run.jobs: - if ( - other_job.job_spec.replica_num == job.job_spec.replica_num - and other_job.job_submissions[-1].status == JobStatus.PROVISIONING - and other_job.job_submissions[-1].job_provisioning_data is not None - and other_job.job_submissions[-1].job_provisioning_data.hostname is None - ): - job_model.last_processed_at = common_utils.get_current_datetime() - await session.commit() - return + if _should_wait_for_other_nodes(run, job, job_model): + job_model.last_processed_at = common_utils.get_current_datetime() + await session.commit() + return # fails are acceptable until timeout is exceeded if job_provisioning_data.dockerized: @@ -406,6 +399,48 @@ async def _wait_for_instance_provisioning_data(job_model: JobModel): job_model.job_provisioning_data = job_model.instance.job_provisioning_data +def _should_wait_for_other_nodes(run: Run, job: Job, job_model: JobModel) -> bool: + for other_job in run.jobs: + if ( + other_job.job_spec.replica_num == job.job_spec.replica_num + and other_job.job_submissions[-1].status == JobStatus.PROVISIONING + and other_job.job_submissions[-1].job_provisioning_data is not None + and other_job.job_submissions[-1].job_provisioning_data.hostname is None + ): + logger.debug( + "%s: waiting for other job to have IP assigned", + fmt(job_model), + ) + return True + master_job = find_job(run.jobs, job.job_spec.replica_num, 0) + if ( + job.job_spec.job_num != 0 + and run.run_spec.merged_profile.startup_order == StartupOrder.MASTER_FIRST + and master_job.job_submissions[-1].status != JobStatus.RUNNING + ): + logger.debug( + "%s: waiting for master job to become running", + fmt(job_model), + ) + return True + if ( + job.job_spec.job_num == 0 + and run.run_spec.merged_profile.startup_order == StartupOrder.WORKERS_FIRST + ): + for other_job in run.jobs: + if ( + other_job.job_spec.replica_num == job.job_spec.replica_num + and other_job.job_spec.job_num != job.job_spec.job_num + and other_job.job_submissions[-1].status != JobStatus.RUNNING + ): + logger.debug( + "%s: waiting for worker job to become running", + fmt(job_model), + ) + return True + return False + + @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1) def _process_provisioning_with_shim( ports: Dict[int, int], diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py index 910166376..547a4cd5a 100644 --- a/src/dstack/_internal/server/background/tasks/process_runs.py +++ b/src/dstack/_internal/server/background/tasks/process_runs.py @@ -10,7 +10,7 @@ import dstack._internal.server.services.gateways as gateways import dstack._internal.server.services.services.autoscalers as autoscalers from dstack._internal.core.errors import ServerError -from dstack._internal.core.models.profiles import RetryEvent +from dstack._internal.core.models.profiles import RetryEvent, StopCriteria from dstack._internal.core.models.runs import ( Job, JobStatus, @@ -313,6 +313,10 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel): termination_reason = RunTerminationReason.RETRY_LIMIT_EXCEEDED else: raise ValueError(f"Unexpected termination reason {run_termination_reasons}") + elif _should_stop_on_master_done(run): + new_status = RunStatus.TERMINATING + # ALL_JOBS_DONE is used for all DONE reasons including master-done + termination_reason = RunTerminationReason.ALL_JOBS_DONE elif RunStatus.RUNNING in run_statuses: new_status = RunStatus.RUNNING elif RunStatus.PROVISIONING in run_statuses: @@ -434,3 +438,12 @@ def _can_retry_single_job(run_spec: RunSpec) -> bool: # We could make partial retry in some multi-node cases. # E.g. restarting a worker node, independent jobs. return False + + +def _should_stop_on_master_done(run: Run) -> bool: + if run.run_spec.merged_profile.stop_criteria != StopCriteria.MASTER_DONE: + return False + for job in run.jobs: + if job.job_spec.job_num == 0 and job.job_submissions[-1].status == JobStatus.DONE: + return True + return False diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index d50ac0115..52526394a 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -308,12 +308,13 @@ async def create_job( ) -> JobModel: run_spec = RunSpec.parse_raw(run.run_spec) job_spec = (await get_job_specs_from_run_spec(run_spec, replica_num=replica_num))[0] + job_spec.job_num = job_num job = JobModel( project_id=run.project_id, run_id=run.id, run_name=run.run_name, job_num=job_num, - job_name=run.run_name + f"-0-{replica_num}", + job_name=run.run_name + f"-{job_num}-{replica_num}", replica_num=replica_num, submission_num=submission_num, submitted_at=submitted_at, diff --git a/src/dstack/_internal/utils/common.py b/src/dstack/_internal/utils/common.py index 33c05d3b4..683207802 100644 --- a/src/dstack/_internal/utils/common.py +++ b/src/dstack/_internal/utils/common.py @@ -314,3 +314,7 @@ def make_proxy_url(server_url: str, proxy_url: str) -> str: path=concat_url_path(server.path, proxy.path), ) return proxy.geturl() + + +def list_enum_values_for_annotation(enum_class: type[enum.Enum]) -> str: + return ", ".join(f"`{e.value}`" for e in enum_class) diff --git a/src/dstack/api/server/_fleets.py b/src/dstack/api/server/_fleets.py index f08f24b5b..8c1bf8b0c 100644 --- a/src/dstack/api/server/_fleets.py +++ b/src/dstack/api/server/_fleets.py @@ -126,6 +126,10 @@ def _get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[Dict]: configuration_excludes["tags"] = True if profile.tags is None: profile_excludes.add("tags") + if profile.startup_order is None: + profile_excludes.add("startup_order") + if profile.stop_criteria is None: + profile_excludes.add("stop_criteria") if configuration_excludes: spec_excludes["configuration"] = configuration_excludes if profile_excludes: diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py index d7fd2f4d0..22f994cc4 100644 --- a/src/dstack/api/server/_runs.py +++ b/src/dstack/api/server/_runs.py @@ -188,6 +188,14 @@ def _get_run_spec_excludes(run_spec: RunSpec) -> Optional[Dict]: configuration_excludes["shell"] = True if configuration.priority is None: configuration_excludes["priority"] = True + if configuration.startup_order is None: + configuration_excludes["startup_order"] = True + if profile is not None and profile.startup_order is None: + profile_excludes.add("startup_order") + if configuration.stop_criteria is None: + configuration_excludes["stop_criteria"] = True + if profile is not None and profile.stop_criteria is None: + profile_excludes.add("stop_criteria") if configuration_excludes: spec_excludes["configuration"] = configuration_excludes diff --git a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py index ed36d1af8..7bd91ad78 100644 --- a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py @@ -12,7 +12,7 @@ from dstack._internal.core.models.common import NetworkMode from dstack._internal.core.models.configurations import DevEnvironmentConfiguration from dstack._internal.core.models.instances import InstanceStatus -from dstack._internal.core.models.profiles import UtilizationPolicy +from dstack._internal.core.models.profiles import StartupOrder, UtilizationPolicy from dstack._internal.core.models.runs import ( JobRuntimeData, JobStatus, @@ -805,3 +805,76 @@ async def test_gpu_utilization( else: assert job.termination_reason is None assert job.termination_reason_message is None + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_master_job_waits_for_workers(self, test_db, session: AsyncSession): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo( + session=session, + project_id=project.id, + ) + run_spec = get_run_spec( + run_name="test-run", + repo_id=repo.name, + ) + run_spec.configuration.startup_order = StartupOrder.WORKERS_FIRST + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=run_spec, + ) + instance1 = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + instance2 = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job_provisioning_data = get_job_provisioning_data(dockerized=False) + master_job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + job_provisioning_data=job_provisioning_data, + instance_assigned=True, + instance=instance1, + job_num=0, + last_processed_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + ) + worker_job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + job_provisioning_data=job_provisioning_data, + instance_assigned=True, + instance=instance2, + job_num=1, + last_processed_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc), + ) + await process_running_jobs() + await session.refresh(master_job) + assert master_job.status == JobStatus.PROVISIONING + worker_job.status = JobStatus.RUNNING + # To guarantee master_job is processed next + master_job.last_processed_at = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel"), + patch( + "dstack._internal.server.services.runner.client.RunnerClient" + ) as RunnerClientMock, + ): + runner_client_mock = RunnerClientMock.return_value + runner_client_mock.healthcheck.return_value = HealthcheckResponse( + service="dstack-runner", version="0.0.1.dev2" + ) + await process_running_jobs() + await session.refresh(master_job) + assert master_job.status == JobStatus.RUNNING diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index 39176efab..bac18a5e0 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -364,6 +364,8 @@ async def test_creates_fleet(self, test_db, session: AsyncSession, client: Async "creation_policy": None, "idle_duration": None, "utilization_policy": None, + "startup_order": None, + "stop_criteria": None, "name": "", "default": False, "reservation": None, @@ -485,6 +487,8 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A "creation_policy": None, "idle_duration": None, "utilization_policy": None, + "startup_order": None, + "stop_criteria": None, "name": "", "default": False, "reservation": None, diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 760bbbb66..7fc35661a 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -121,6 +121,8 @@ def get_dev_env_run_plan_dict( "spot_policy": "spot", "idle_duration": None, "utilization_policy": None, + "startup_order": None, + "stop_criteria": None, "reservation": None, "fleets": None, "tags": None, @@ -142,6 +144,8 @@ def get_dev_env_run_plan_dict( "spot_policy": "spot", "idle_duration": None, "utilization_policy": None, + "startup_order": None, + "stop_criteria": None, "reservation": None, "fleets": None, "tags": None, @@ -285,6 +289,8 @@ def get_dev_env_run_dict( "spot_policy": "spot", "idle_duration": None, "utilization_policy": None, + "startup_order": None, + "stop_criteria": None, "reservation": None, "fleets": None, "tags": None, @@ -306,6 +312,8 @@ def get_dev_env_run_dict( "spot_policy": "spot", "idle_duration": None, "utilization_policy": None, + "startup_order": None, + "stop_criteria": None, "reservation": None, "fleets": None, "tags": None,