Skip to content

Commit 2dbd44d

Browse files
committed
Extract submitted-jobs volume attachment
1 parent ed98726 commit 2dbd44d

1 file changed

Lines changed: 54 additions & 41 deletions

File tree

src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py

Lines changed: 54 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -260,36 +260,30 @@ async def _process_submitted_job(
260260
context = await _load_submitted_job_context(session=session, job_model=job_model)
261261
logger.debug("%s: provisioning has started", fmt(context.job_model))
262262

263-
job_model = context.job_model
264-
run_model = context.run_model
265-
run = context.run
266-
job = context.job
267-
run_spec = run.run_spec
268-
269263
master_job_dependency = await _resolve_master_job_dependency(
270264
session=session,
271-
job_model=job_model,
272-
run=run,
273-
job=job,
265+
job_model=context.job_model,
266+
run=context.run,
267+
job=context.job,
274268
)
275269
if master_job_dependency is None:
276270
return
277271
master_job_provisioning_data = master_job_dependency.provisioning_data
278272

279273
if not await _resolve_fleet_dependency(
280274
session=session,
281-
job_model=job_model,
282-
run_model=run_model,
283-
job=job,
275+
job_model=context.job_model,
276+
run_model=context.run_model,
277+
job=context.job,
284278
):
285279
return
286280

287281
prepared_job_volumes = await _prepare_job_volumes(
288282
session=session,
289-
job_model=job_model,
283+
job_model=context.job_model,
290284
project=context.project,
291-
run_spec=run_spec,
292-
job=job,
285+
run_spec=context.run.run_spec,
286+
job=context.job,
293287
)
294288
if prepared_job_volumes is None:
295289
return
@@ -775,34 +769,53 @@ async def _finalize_submitted_job_processing(
775769
jobs_to_provision=provisioning_phase_result.jobs_to_provision,
776770
)
777771

778-
volume_models = prepared_job_volumes.volume_models
779-
volumes_ids = sorted([v.id for vs in volume_models for v in vs])
772+
await _attach_job_volumes_if_needed(
773+
exit_stack=exit_stack,
774+
session=session,
775+
context=context,
776+
prepared_job_volumes=prepared_job_volumes,
777+
provisioning_phase_result=provisioning_phase_result,
778+
)
779+
await session.commit()
780+
781+
782+
async def _attach_job_volumes_if_needed(
783+
exit_stack: AsyncExitStack,
784+
session: AsyncSession,
785+
context: _SubmittedJobContext,
786+
prepared_job_volumes: _PreparedJobVolumes,
787+
provisioning_phase_result: _ProvisioningPhaseResult,
788+
) -> None:
780789
# TODO: Volume attachment for compute groups is not yet supported since
781790
# currently supported compute groups (e.g. Runpod) don't need explicit volume attachment.
782-
if provisioning_phase_result.compute_group_model is None:
783-
# Take lock to prevent attaching volumes that are to be deleted.
784-
# If the volume was deleted before the lock, the volume will fail to attach and the job will fail.
785-
# TODO: Lock instances for attaching volumes?
786-
await session.execute(
787-
select(VolumeModel)
788-
.where(VolumeModel.id.in_(volumes_ids))
789-
.options(joinedload(VolumeModel.user).load_only(UserModel.name))
790-
.order_by(VolumeModel.id) # take locks in order
791-
.with_for_update(key_share=True, of=VolumeModel)
792-
)
793-
await exit_stack.enter_async_context(
794-
get_locker(get_db().dialect_name).lock_ctx(VolumeModel.__tablename__, volumes_ids)
795-
)
796-
if len(volume_models) > 0:
797-
assert len(provisioning_phase_result.instance_models) == 1
798-
await _attach_volumes(
799-
session=session,
800-
project=context.project,
801-
job_model=context.job_model,
802-
instance=provisioning_phase_result.instance_models[0],
803-
volume_models=volume_models,
804-
)
805-
await session.commit()
791+
if provisioning_phase_result.compute_group_model is not None:
792+
return
793+
794+
volume_models = prepared_job_volumes.volume_models
795+
volumes_ids = sorted([v.id for vs in volume_models for v in vs])
796+
# Take lock to prevent attaching volumes that are to be deleted.
797+
# If the volume was deleted before the lock, the volume will fail to attach and the job will fail.
798+
# TODO: Lock instances for attaching volumes?
799+
await session.execute(
800+
select(VolumeModel)
801+
.where(VolumeModel.id.in_(volumes_ids))
802+
.options(joinedload(VolumeModel.user).load_only(UserModel.name))
803+
.order_by(VolumeModel.id) # take locks in order
804+
.with_for_update(key_share=True, of=VolumeModel)
805+
)
806+
await exit_stack.enter_async_context(
807+
get_locker(get_db().dialect_name).lock_ctx(VolumeModel.__tablename__, volumes_ids)
808+
)
809+
if len(volume_models) == 0:
810+
return
811+
assert len(provisioning_phase_result.instance_models) == 1
812+
await _attach_volumes(
813+
session=session,
814+
project=context.project,
815+
job_model=context.job_model,
816+
instance=provisioning_phase_result.instance_models[0],
817+
volume_models=volume_models,
818+
)
806819

807820

808821
async def _defer_submitted_job(

0 commit comments

Comments
 (0)