Skip to content

Commit 3158183

Browse files
committed
Allow multi-node tasks on idle shared instances
Fixes: #2650
1 parent 4b15d69 commit 3158183

4 files changed

Lines changed: 120 additions & 22 deletions

File tree

src/dstack/_internal/server/background/tasks/process_submitted_jobs.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -360,16 +360,16 @@ async def _assign_job_to_pool_instance(
360360
(instance, common_utils.get_or_error(get_instance_offer(instance)))
361361
for instance in nonshared_instances
362362
]
363-
if not multinode:
364-
shared_instances_with_offers = get_shared_pool_instances_with_offers(
365-
pool_instances=pool_instances,
366-
profile=profile,
367-
requirements=job.job_spec.requirements,
368-
idle_only=True,
369-
fleet_model=fleet_model,
370-
volumes=volumes,
371-
)
372-
instances_with_offers.extend(shared_instances_with_offers)
363+
shared_instances_with_offers = get_shared_pool_instances_with_offers(
364+
pool_instances=pool_instances,
365+
profile=profile,
366+
requirements=job.job_spec.requirements,
367+
idle_only=True,
368+
fleet_model=fleet_model,
369+
multinode=multinode,
370+
volumes=volumes,
371+
)
372+
instances_with_offers.extend(shared_instances_with_offers)
373373

374374
if len(instances_with_offers) == 0:
375375
return None
@@ -572,7 +572,7 @@ def _create_instance_model_for_job(
572572

573573

574574
def _prepare_job_runtime_data(offer: InstanceOfferWithAvailability) -> JobRuntimeData:
575-
if offer.total_blocks == 1:
575+
if offer.blocks == offer.total_blocks:
576576
if env_utils.get_bool("DSTACK_FORCE_BRIDGE_NETWORK"):
577577
network_mode = NetworkMode.BRIDGE
578578
else:

src/dstack/_internal/server/services/instances.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ def get_shared_pool_instances_with_offers(
235235
*,
236236
idle_only: bool = False,
237237
fleet_model: Optional[FleetModel] = None,
238+
multinode: bool = False,
238239
volumes: Optional[List[List[Volume]]] = None,
239240
) -> list[tuple[InstanceModel, InstanceOfferWithAvailability]]:
240241
instances_with_offers: list[tuple[InstanceModel, InstanceOfferWithAvailability]] = []
@@ -243,19 +244,22 @@ def get_shared_pool_instances_with_offers(
243244
pool_instances=pool_instances,
244245
profile=profile,
245246
fleet_model=fleet_model,
246-
multinode=False,
247+
multinode=multinode,
247248
volumes=volumes,
248249
shared=True,
249250
)
250251
for instance in filtered_instances:
251252
if idle_only and instance.status not in [InstanceStatus.IDLE, InstanceStatus.BUSY]:
252253
continue
254+
if multinode and instance.busy_blocks > 0:
255+
continue
253256
offer = get_instance_offer(instance)
254257
if offer is None:
255258
continue
256259
total_blocks = common_utils.get_or_error(instance.total_blocks)
257260
idle_blocks = total_blocks - instance.busy_blocks
258-
for blocks in range(1, total_blocks + 1):
261+
min_blocks = total_blocks if multinode else 1
262+
for blocks in range(min_blocks, total_blocks + 1):
259263
shared_offer = generate_shared_offer(offer, blocks, total_blocks)
260264
catalog_item = offer_to_catalog_item(shared_offer)
261265
if gpuhunt.matches(catalog_item, query_filter):

src/dstack/_internal/server/services/runs.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -721,15 +721,15 @@ async def _get_pool_offers(
721721
pool_instances = [i for i in pool_instances if i.id not in detaching_instances_ids]
722722
multinode = job.job_spec.jobs_per_replica > 1
723723

724-
if not multinode:
725-
shared_instances_with_offers = get_shared_pool_instances_with_offers(
726-
pool_instances=pool_instances,
727-
profile=run_spec.merged_profile,
728-
requirements=job.job_spec.requirements,
729-
volumes=volumes,
730-
)
731-
for _, offer in shared_instances_with_offers:
732-
pool_offers.append(offer)
724+
shared_instances_with_offers = get_shared_pool_instances_with_offers(
725+
pool_instances=pool_instances,
726+
profile=run_spec.merged_profile,
727+
requirements=job.job_spec.requirements,
728+
volumes=volumes,
729+
multinode=multinode,
730+
)
731+
for _, offer in shared_instances_with_offers:
732+
pool_offers.append(offer)
733733

734734
nonshared_instances = filter_pool_instances(
735735
pool_instances=pool_instances,

src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sqlalchemy.orm import joinedload
88

99
from dstack._internal.core.models.backends.base import BackendType
10+
from dstack._internal.core.models.configurations import TaskConfiguration
1011
from dstack._internal.core.models.instances import (
1112
InstanceAvailability,
1213
InstanceOfferWithAvailability,
@@ -536,6 +537,99 @@ async def test_assigns_job_to_shared_instance(self, test_db, session: AsyncSessi
536537
assert instance.total_blocks == 4
537538
assert instance.busy_blocks == 2
538539

540+
@pytest.mark.asyncio
541+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
542+
async def test_assigns_multi_node_job_to_shared_instance(self, test_db, session: AsyncSession):
543+
project = await create_project(session)
544+
user = await create_user(session)
545+
repo = await create_repo(
546+
session=session,
547+
project_id=project.id,
548+
)
549+
offer = get_instance_offer_with_availability(gpu_count=8, cpu_count=64, memory_gib=128)
550+
instance = await create_instance(
551+
session=session,
552+
project=project,
553+
status=InstanceStatus.IDLE,
554+
backend=BackendType.AWS,
555+
offer=offer,
556+
total_blocks=4,
557+
busy_blocks=0,
558+
)
559+
configuration = TaskConfiguration(image="debian", nodes=2)
560+
run_spec = get_run_spec(run_name="run", repo_id=repo.name, configuration=configuration)
561+
run = await create_run(
562+
session=session,
563+
run_name="run",
564+
project=project,
565+
repo=repo,
566+
user=user,
567+
run_spec=run_spec,
568+
)
569+
job = await create_job(
570+
session=session,
571+
run=run,
572+
instance_assigned=False,
573+
)
574+
await process_submitted_jobs()
575+
await session.refresh(job)
576+
await session.refresh(instance)
577+
res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
578+
job = res.unique().scalar_one()
579+
assert job.status == JobStatus.SUBMITTED
580+
assert job.instance_assigned
581+
assert job.instance is not None
582+
assert job.instance.id == instance.id
583+
assert instance.total_blocks == 4
584+
assert instance.busy_blocks == 4
585+
586+
@pytest.mark.asyncio
587+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
588+
async def test_cannot_assign_multi_node_job_to_partially_busy_shared_instance(
589+
self, test_db, session: AsyncSession
590+
):
591+
project = await create_project(session)
592+
user = await create_user(session)
593+
repo = await create_repo(
594+
session=session,
595+
project_id=project.id,
596+
)
597+
offer = get_instance_offer_with_availability(gpu_count=8, cpu_count=64, memory_gib=128)
598+
instance = await create_instance(
599+
session=session,
600+
project=project,
601+
status=InstanceStatus.IDLE,
602+
backend=BackendType.AWS,
603+
offer=offer,
604+
total_blocks=4,
605+
busy_blocks=1,
606+
)
607+
configuration = TaskConfiguration(image="debian", nodes=2)
608+
run_spec = get_run_spec(run_name="run", repo_id=repo.name, configuration=configuration)
609+
run = await create_run(
610+
session=session,
611+
run_name="run",
612+
project=project,
613+
repo=repo,
614+
user=user,
615+
run_spec=run_spec,
616+
)
617+
job = await create_job(
618+
session=session,
619+
run=run,
620+
instance_assigned=False,
621+
)
622+
await process_submitted_jobs()
623+
await session.refresh(job)
624+
await session.refresh(instance)
625+
res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
626+
job = res.unique().scalar_one()
627+
assert job.status == JobStatus.SUBMITTED
628+
assert job.instance_assigned
629+
assert job.instance is None
630+
assert instance.total_blocks == 4
631+
assert instance.busy_blocks == 1
632+
539633
@pytest.mark.asyncio
540634
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
541635
async def test_assigns_job_to_specific_fleet(self, test_db, session: AsyncSession):

0 commit comments

Comments (0)