@@ -260,36 +260,30 @@ async def _process_submitted_job(
260260 context = await _load_submitted_job_context (session = session , job_model = job_model )
261261 logger .debug ("%s: provisioning has started" , fmt (context .job_model ))
262262
263- job_model = context .job_model
264- run_model = context .run_model
265- run = context .run
266- job = context .job
267- run_spec = run .run_spec
268-
269263 master_job_dependency = await _resolve_master_job_dependency (
270264 session = session ,
271- job_model = job_model ,
272- run = run ,
273- job = job ,
265+ job_model = context . job_model ,
266+ run = context . run ,
267+ job = context . job ,
274268 )
275269 if master_job_dependency is None :
276270 return
277271 master_job_provisioning_data = master_job_dependency .provisioning_data
278272
279273 if not await _resolve_fleet_dependency (
280274 session = session ,
281- job_model = job_model ,
282- run_model = run_model ,
283- job = job ,
275+ job_model = context . job_model ,
276+ run_model = context . run_model ,
277+ job = context . job ,
284278 ):
285279 return
286280
287281 prepared_job_volumes = await _prepare_job_volumes (
288282 session = session ,
289- job_model = job_model ,
283+ job_model = context . job_model ,
290284 project = context .project ,
291- run_spec = run_spec ,
292- job = job ,
285+ run_spec = context . run . run_spec ,
286+ job = context . job ,
293287 )
294288 if prepared_job_volumes is None :
295289 return
@@ -775,34 +769,53 @@ async def _finalize_submitted_job_processing(
775769 jobs_to_provision = provisioning_phase_result .jobs_to_provision ,
776770 )
777771
778- volume_models = prepared_job_volumes .volume_models
779- volumes_ids = sorted ([v .id for vs in volume_models for v in vs ])
772+ await _attach_job_volumes_if_needed (
773+ exit_stack = exit_stack ,
774+ session = session ,
775+ context = context ,
776+ prepared_job_volumes = prepared_job_volumes ,
777+ provisioning_phase_result = provisioning_phase_result ,
778+ )
779+ await session .commit ()
780+
781+
782+ async def _attach_job_volumes_if_needed (
783+ exit_stack : AsyncExitStack ,
784+ session : AsyncSession ,
785+ context : _SubmittedJobContext ,
786+ prepared_job_volumes : _PreparedJobVolumes ,
787+ provisioning_phase_result : _ProvisioningPhaseResult ,
788+ ) -> None :
780789 # TODO: Volume attachment for compute groups is not yet supported since
781790 # currently supported compute groups (e.g. Runpod) don't need explicit volume attachment.
782- if provisioning_phase_result .compute_group_model is None :
783- # Take lock to prevent attaching volumes that are to be deleted.
784- # If the volume was deleted before the lock, the volume will fail to attach and the job will fail.
785- # TODO: Lock instances for attaching volumes?
786- await session .execute (
787- select (VolumeModel )
788- .where (VolumeModel .id .in_ (volumes_ids ))
789- .options (joinedload (VolumeModel .user ).load_only (UserModel .name ))
790- .order_by (VolumeModel .id ) # take locks in order
791- .with_for_update (key_share = True , of = VolumeModel )
792- )
793- await exit_stack .enter_async_context (
794- get_locker (get_db ().dialect_name ).lock_ctx (VolumeModel .__tablename__ , volumes_ids )
795- )
796- if len (volume_models ) > 0 :
797- assert len (provisioning_phase_result .instance_models ) == 1
798- await _attach_volumes (
799- session = session ,
800- project = context .project ,
801- job_model = context .job_model ,
802- instance = provisioning_phase_result .instance_models [0 ],
803- volume_models = volume_models ,
804- )
805- await session .commit ()
791+ if provisioning_phase_result .compute_group_model is not None :
792+ return
793+
794+ volume_models = prepared_job_volumes .volume_models
795+ volumes_ids = sorted ([v .id for vs in volume_models for v in vs ])
796+ # Take lock to prevent attaching volumes that are to be deleted.
797+ # If the volume was deleted before the lock, the volume will fail to attach and the job will fail.
798+ # TODO: Lock instances for attaching volumes?
799+ await session .execute (
800+ select (VolumeModel )
801+ .where (VolumeModel .id .in_ (volumes_ids ))
802+ .options (joinedload (VolumeModel .user ).load_only (UserModel .name ))
803+ .order_by (VolumeModel .id ) # take locks in order
804+ .with_for_update (key_share = True , of = VolumeModel )
805+ )
806+ await exit_stack .enter_async_context (
807+ get_locker (get_db ().dialect_name ).lock_ctx (VolumeModel .__tablename__ , volumes_ids )
808+ )
809+ if len (volume_models ) == 0 :
810+ return
811+ assert len (provisioning_phase_result .instance_models ) == 1
812+ await _attach_volumes (
813+ session = session ,
814+ project = context .project ,
815+ job_model = context .job_model ,
816+ instance = provisioning_phase_result .instance_models [0 ],
817+ volume_models = volume_models ,
818+ )
806819
807820
808821async def _defer_submitted_job (
0 commit comments