Skip to content

Commit 780e7ca

Browse files
committed
Support provisioning in empty fleets
1 parent 12463ea commit 780e7ca

5 files changed

Lines changed: 293 additions & 154 deletions

File tree

src/dstack/_internal/server/background/tasks/process_submitted_jobs.py

Lines changed: 121 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import asyncio
2+
import itertools
3+
import math
24
import uuid
35
from datetime import datetime, timedelta
46
from typing import List, Optional, Tuple
57

6-
from sqlalchemy import select
8+
from sqlalchemy import and_, or_, select
79
from sqlalchemy.ext.asyncio import AsyncSession
8-
from sqlalchemy.orm import joinedload, load_only, selectinload
10+
from sqlalchemy.orm import contains_eager, joinedload, load_only, selectinload
911

1012
from dstack._internal.core.backends.base.backend import Backend
1113
from dstack._internal.core.backends.base.compute import ComputeWithVolumeSupport
@@ -51,6 +53,7 @@
5153
from dstack._internal.server.services.backends import get_project_backend_by_type_or_error
5254
from dstack._internal.server.services.fleets import (
5355
fleet_model_to_fleet,
56+
get_fleet_spec,
5457
)
5558
from dstack._internal.server.services.instances import (
5659
filter_pool_instances,
@@ -158,7 +161,10 @@ async def _process_next_submitted_job():
158161
async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
159162
# Refetch to load related attributes.
160163
res = await session.execute(
161-
select(JobModel).where(JobModel.id == job_model.id).options(joinedload(JobModel.instance))
164+
select(JobModel)
165+
.where(JobModel.id == job_model.id)
166+
.options(joinedload(JobModel.instance))
167+
.options(joinedload(JobModel.fleet).joinedload(FleetModel.instances))
162168
)
163169
job_model = res.unique().scalar_one()
164170
res = await session.execute(
@@ -177,6 +183,12 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
177183
profile = run_spec.merged_profile
178184
job = find_job(run.jobs, job_model.replica_num, job_model.job_num)
179185

186+
# Master job chooses fleet for the run.
187+
# Due to two-step processing, it's saved to job_model.fleet.
188+
# Other jobs just inherit fleet from run_model.fleet.
189+
# If master job chooses no fleet, the new fleet will be created.
190+
fleet_model = run_model.fleet or job_model.fleet
191+
180192
master_job = find_job(run.jobs, job_model.replica_num, 0)
181193
master_job_provisioning_data = None
182194
if job.job_spec.job_num != 0:
@@ -224,19 +236,36 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
224236
# Then, the job runs on the assigned instance or a new instance is provisioned.
225237
# This is needed to avoid holding instances lock for a long time.
226238
if not job_model.instance_assigned:
227-
# Try assigning an existing instance
239+
fleet_filters = [
240+
FleetModel.project_id == project.id,
241+
FleetModel.deleted == False,
242+
]
243+
if run_model.fleet is not None:
244+
fleet_filters.append(FleetModel.id == run_model.fleet_id)
245+
if run_spec.configuration.fleets is not None:
246+
fleet_filters.append(FleetModel.name.in_(run_spec.configuration.fleets))
228247
res = await session.execute(
229-
select(InstanceModel)
248+
select(FleetModel)
249+
.outerjoin(FleetModel.instances)
250+
.where(*fleet_filters)
230251
.where(
231-
InstanceModel.project_id == project.id,
232-
InstanceModel.deleted == False,
233-
InstanceModel.total_blocks > InstanceModel.busy_blocks,
252+
or_(
253+
InstanceModel.id.is_(None),
254+
and_(
255+
InstanceModel.deleted == False,
256+
InstanceModel.total_blocks > InstanceModel.busy_blocks,
257+
),
258+
)
234259
)
260+
.options(contains_eager(FleetModel.instances))
235261
.order_by(InstanceModel.id) # take locks in order
236-
.with_for_update(key_share=True)
262+
.with_for_update(key_share=True, of=InstanceModel)
263+
)
264+
fleet_models = list(res.unique().scalars().all())
265+
fleets_ids = sorted([f.id for f in fleet_models])
266+
instances_ids = sorted(
267+
itertools.chain.from_iterable([i.id for i in f.instances] for f in fleet_models)
237268
)
238-
pool_instances = list(res.unique().scalars().all())
239-
instances_ids = sorted([i.id for i in pool_instances])
240269
if get_db().dialect_name == "sqlite":
241270
# Start new transaction to see committed changes after lock
242271
await session.commit()
@@ -248,30 +277,77 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
248277
detaching_instances_ids = await get_instances_ids_with_detaching_volumes(session)
249278
# Refetch after lock
250279
res = await session.execute(
251-
select(InstanceModel)
280+
select(FleetModel)
281+
.outerjoin(FleetModel.instances)
252282
.where(
253-
InstanceModel.id.not_in(detaching_instances_ids),
254-
InstanceModel.id.in_(instances_ids),
255-
InstanceModel.deleted == False,
256-
InstanceModel.total_blocks > InstanceModel.busy_blocks,
283+
FleetModel.id.in_(fleets_ids),
284+
*fleet_filters,
285+
)
286+
.where(
287+
or_(
288+
InstanceModel.id.is_(None),
289+
and_(
290+
InstanceModel.id.not_in(detaching_instances_ids),
291+
InstanceModel.id.in_(instances_ids),
292+
InstanceModel.deleted == False,
293+
InstanceModel.total_blocks > InstanceModel.busy_blocks,
294+
),
295+
)
257296
)
258-
.options(joinedload(InstanceModel.fleet))
297+
.options(contains_eager(FleetModel.instances))
259298
.execution_options(populate_existing=True)
260299
)
261-
pool_instances = list(res.unique().scalars().all())
262-
instance = await _assign_job_to_pool_instance(
300+
fleet_models = list(res.unique().scalars().all())
301+
fleet_instances_with_offers = []
302+
for candidate_fleet_model in fleet_models:
303+
fleet_instances_with_offers = await _get_fleet_instances_with_offers(
304+
fleet_model=candidate_fleet_model,
305+
run_spec=run_spec,
306+
job=job,
307+
master_job_provisioning_data=master_job_provisioning_data,
308+
volumes=volumes,
309+
)
310+
if run_model.fleet_id is not None:
311+
# Using the first fleet that was already chosen by the master job.
312+
fleet_model = candidate_fleet_model
313+
break
314+
# Looking for an eligible fleet for the run.
315+
# TODO: Pick optimal fleet instead of the first eligible one.
316+
fleet_spec = get_fleet_spec(candidate_fleet_model)
317+
fleet_capacity = len(
318+
[o for o in fleet_instances_with_offers if o[1].availability.is_available()]
319+
)
320+
if fleet_spec.configuration.nodes is not None:
321+
if fleet_spec.configuration.nodes.max is None:
322+
fleet_capacity = math.inf
323+
else:
324+
# FIXME: Multiple service jobs can be provisioned on one instance with blocks.
325+
# Current capacity calculation does not take future provisioned blocks into account.
326+
# It may be impossible to do since we cannot be sure which instance will be provisioned.
327+
fleet_capacity += fleet_spec.configuration.nodes.max - len(
328+
candidate_fleet_model.instances
329+
)
330+
instances_required = 1
331+
if run_spec.configuration.type == "task":
332+
instances_required = run_spec.configuration.nodes
333+
elif (
334+
run_spec.configuration.type == "service"
335+
and run_spec.configuration.replicas.min is not None
336+
):
337+
instances_required = run_spec.configuration.replicas.min
338+
if fleet_capacity >= instances_required:
339+
# TODO: Ensure we use the chosen fleet when there are no instance assigned.
340+
fleet_model = candidate_fleet_model
341+
break
342+
instance = await _assign_job_to_fleet_instance(
263343
session=session,
264-
pool_instances=pool_instances,
265-
run_spec=run_spec,
344+
instances_with_offers=fleet_instances_with_offers,
266345
job_model=job_model,
267-
job=job,
268-
fleet_model=run_model.fleet,
269-
master_job_provisioning_data=master_job_provisioning_data,
270-
volumes=volumes,
271346
)
347+
job_model.fleet = fleet_model
272348
job_model.instance_assigned = True
273349
job_model.last_processed_at = common_utils.get_current_datetime()
274-
if len(pool_instances) > 0:
350+
if len(instances_ids) > 0:
275351
await session.commit()
276352
return
277353
# If no instances were locked, we can proceed in the same transaction.
@@ -298,7 +374,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
298374
# Create a new cloud instance
299375
run_job_result = await _run_job_on_new_instance(
300376
project=project,
301-
fleet_model=run_model.fleet,
377+
fleet_model=fleet_model,
302378
job_model=job_model,
303379
run=run,
304380
job=job,
@@ -319,11 +395,11 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
319395
job_provisioning_data, offer = run_job_result
320396
job_model.job_provisioning_data = job_provisioning_data.json()
321397
job_model.status = JobStatus.PROVISIONING
322-
fleet_model = _get_or_create_fleet_model_for_job(
323-
project=project,
324-
run_model=run_model,
325-
run=run,
326-
)
398+
if fleet_model is None:
399+
fleet_model = _create_fleet_model_for_job(
400+
project=project,
401+
run=run,
402+
)
327403
instance_num = await _get_next_instance_num(
328404
session=session,
329405
fleet_model=fleet_model,
@@ -377,16 +453,14 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
377453
await session.commit()
378454

379455

380-
async def _assign_job_to_pool_instance(
381-
session: AsyncSession,
382-
pool_instances: List[InstanceModel],
456+
async def _get_fleet_instances_with_offers(
457+
fleet_model: FleetModel,
383458
run_spec: RunSpec,
384-
job_model: JobModel,
385459
job: Job,
386-
fleet_model: Optional[FleetModel],
387460
master_job_provisioning_data: Optional[JobProvisioningData] = None,
388461
volumes: Optional[List[List[Volume]]] = None,
389-
) -> Optional[InstanceModel]:
462+
) -> list[tuple[InstanceModel, InstanceOfferWithAvailability]]:
463+
pool_instances = fleet_model.instances
390464
instances_with_offers: list[tuple[InstanceModel, InstanceOfferWithAvailability]]
391465
profile = run_spec.merged_profile
392466
multinode = job.job_spec.jobs_per_replica > 1
@@ -415,7 +489,15 @@ async def _assign_job_to_pool_instance(
415489
volumes=volumes,
416490
)
417491
instances_with_offers.extend(shared_instances_with_offers)
492+
instances_with_offers.sort(key=lambda instance_with_offer: instance_with_offer[0].price or 0)
493+
return instances_with_offers
494+
418495

496+
async def _assign_job_to_fleet_instance(
497+
session: AsyncSession,
498+
instances_with_offers: list[tuple[InstanceModel, InstanceOfferWithAvailability]],
499+
job_model: JobModel,
500+
) -> Optional[InstanceModel]:
419501
if len(instances_with_offers) == 0:
420502
return None
421503

@@ -543,13 +625,10 @@ def _check_can_create_new_instance_in_fleet(fleet: Fleet) -> bool:
543625
return True
544626

545627

546-
def _get_or_create_fleet_model_for_job(
628+
def _create_fleet_model_for_job(
547629
project: ProjectModel,
548-
run_model: RunModel,
549630
run: Run,
550631
) -> FleetModel:
551-
if run_model.fleet is not None:
552-
return run_model.fleet
553632
placement = InstanceGroupPlacement.ANY
554633
if run.run_spec.configuration.type == "task" and run.run_spec.configuration.nodes > 1:
555634
placement = InstanceGroupPlacement.CLUSTER
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""Add JobModel.fleet
2+
3+
Revision ID: ecd3cfc5c86e
4+
Revises: 728b1488b1b4
5+
Create Date: 2025-08-08 17:51:27.267140
6+
7+
"""
8+
9+
import sqlalchemy as sa
10+
import sqlalchemy_utils
11+
from alembic import op
12+
13+
# revision identifiers, used by Alembic.
14+
revision = "ecd3cfc5c86e"
15+
down_revision = "728b1488b1b4"
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade() -> None:
21+
# ### commands auto generated by Alembic - please adjust! ###
22+
with op.batch_alter_table("jobs", schema=None) as batch_op:
23+
batch_op.add_column(
24+
sa.Column(
25+
"fleet_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True
26+
)
27+
)
28+
batch_op.create_foreign_key(
29+
batch_op.f("fk_jobs_fleet_id_fleets"), "fleets", ["fleet_id"], ["id"]
30+
)
31+
32+
# ### end Alembic commands ###
33+
34+
35+
def downgrade() -> None:
36+
# ### commands auto generated by Alembic - please adjust! ###
37+
with op.batch_alter_table("jobs", schema=None) as batch_op:
38+
batch_op.drop_constraint(batch_op.f("fk_jobs_fleet_id_fleets"), type_="foreignkey")
39+
batch_op.drop_column("fleet_id")
40+
41+
# ### end Alembic commands ###

src/dstack/_internal/server/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,10 +390,18 @@ class JobModel(BaseModel):
390390
id: Mapped[uuid.UUID] = mapped_column(
391391
UUIDType(binary=False), primary_key=True, default=uuid.uuid4
392392
)
393+
393394
project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE"))
394395
project: Mapped["ProjectModel"] = relationship()
396+
395397
run_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("runs.id", ondelete="CASCADE"))
396398
run: Mapped["RunModel"] = relationship()
399+
400+
# Jobs need to reference fleets because we may choose an optimal fleet for a master job
401+
# but not yet create an instance for it.
402+
fleet_id: Mapped[Optional[uuid.UUID]] = mapped_column(ForeignKey("fleets.id"))
403+
fleet: Mapped[Optional["FleetModel"]] = relationship(back_populates="jobs")
404+
397405
run_name: Mapped[str] = mapped_column(String(100))
398406
job_num: Mapped[int] = mapped_column(Integer)
399407
job_name: Mapped[str] = mapped_column(String(100))
@@ -537,6 +545,7 @@ class FleetModel(BaseModel):
537545
spec: Mapped[str] = mapped_column(Text)
538546

539547
runs: Mapped[List["RunModel"]] = relationship(back_populates="fleet")
548+
jobs: Mapped[List["JobModel"]] = relationship(back_populates="fleet")
540549
instances: Mapped[List["InstanceModel"]] = relationship(back_populates="fleet")
541550

542551

src/dstack/_internal/server/testing/common.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ async def create_run(
330330
async def create_job(
331331
session: AsyncSession,
332332
run: RunModel,
333+
fleet: Optional[FleetModel] = None,
333334
submission_num: int = 0,
334335
status: JobStatus = JobStatus.SUBMITTED,
335336
submitted_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
@@ -353,6 +354,7 @@ async def create_job(
353354
job_spec.job_num = job_num
354355
job = JobModel(
355356
project_id=run.project_id,
357+
fleet=fleet,
356358
run_id=run.id,
357359
run_name=run.run_name,
358360
job_num=job_num,
@@ -733,6 +735,7 @@ def get_instance_offer_with_availability(
733735
availability_zones: Optional[List[str]] = None,
734736
price: float = 1.0,
735737
instance_type: str = "instance",
738+
availability: InstanceAvailability = InstanceAvailability.AVAILABLE,
736739
):
737740
gpus = [
738741
Gpu(
@@ -756,7 +759,7 @@ def get_instance_offer_with_availability(
756759
),
757760
region=region,
758761
price=price,
759-
availability=InstanceAvailability.AVAILABLE,
762+
availability=availability,
760763
availability_zones=availability_zones,
761764
blocks=blocks,
762765
total_blocks=total_blocks,

0 commit comments

Comments
 (0)