Skip to content

Commit fa001e5

Browse files
authored
Disable autoflush (#3553)
* Fix _get_next_instance_num relying on autoflush
* Fix TestSwitchInstanceStatus relying on autoflush
* Fix long write transaction when cleaning up placement groups
1 parent 2fd45cc commit fa001e5

File tree

3 files changed

+16
-23
lines changed

3 files changed

+16
-23
lines changed

src/dstack/_internal/server/background/tasks/process_submitted_jobs.py

Lines changed: 12 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,6 @@
110110
get_fleet_placement_group_models,
111111
get_placement_group_model_for_job,
112112
placement_group_model_to_placement_group_optional,
113-
schedule_fleet_placement_groups_deletion,
114113
)
115114
from dstack._internal.server.services.runs import (
116115
run_model_to_run,
@@ -481,17 +480,15 @@ async def _process_submitted_job(
481480
logger.info("%s: provisioned %s new instance(s)", fmt(job_model), len(provisioned_jobs))
482481
provisioned_job_models = _get_job_models_for_jobs(run_model.jobs, provisioned_jobs)
483482
instance = None # Instance for attaching volumes in case of single job provisioned
483+
# FIXME: Fleet is not locked which may lead to duplicate instance_num.
484+
# This is currently hard to fix without locking the fleet for entire provisioning duration.
485+
# Processing should be done in multiple steps so that
486+
# InstanceModel is created before provisioning.
487+
taken_instance_nums = await _get_taken_instance_nums(session, fleet_model)
484488
for provisioned_job_model, jpd in zip(provisioned_job_models, jpds):
485489
provisioned_job_model.job_provisioning_data = jpd.json()
486490
switch_job_status(session, provisioned_job_model, JobStatus.PROVISIONING)
487-
# FIXME: Fleet is not locked which may lead to duplicate instance_num.
488-
# This is currently hard to fix without locking the fleet for entire provisioning duration.
489-
# Processing should be done in multiple steps so that
490-
# InstanceModel is created before provisioning.
491-
instance_num = await _get_next_instance_num(
492-
session=session,
493-
fleet_model=fleet_model,
494-
)
491+
instance_num = get_next_instance_num(taken_instance_nums)
495492
instance = _create_instance_model_for_job(
496493
project=project,
497494
fleet_model=fleet_model,
@@ -502,6 +499,7 @@ async def _process_submitted_job(
502499
instance_num=instance_num,
503500
profile=effective_profile,
504501
)
502+
taken_instance_nums.add(instance_num)
505503
provisioned_job_model.job_runtime_data = _prepare_job_runtime_data(
506504
offer, multinode
507505
).json()
@@ -847,15 +845,9 @@ async def _run_jobs_on_new_instances(
847845
finally:
848846
if fleet_model is not None and len(fleet_model.instances) == 0:
849847
# Clean up placement groups that did not end up being used.
850-
# Flush to update still uncommitted placement groups.
851-
await session.flush()
852-
await schedule_fleet_placement_groups_deletion(
853-
session=session,
854-
fleet_id=fleet_model.id,
855-
except_placement_group_ids=(
856-
[placement_group_model.id] if placement_group_model is not None else []
857-
),
858-
)
848+
for pg in placement_group_models:
849+
if placement_group_model is None or pg.id != placement_group_model.id:
850+
pg.fleet_deleted = True
859851
return None
860852

861853

@@ -906,15 +898,14 @@ async def _create_fleet_model_for_job(
906898
return fleet_model
907899

908900

909-
async def _get_next_instance_num(session: AsyncSession, fleet_model: FleetModel) -> int:
901+
async def _get_taken_instance_nums(session: AsyncSession, fleet_model: FleetModel) -> set[int]:
910902
res = await session.execute(
911903
select(InstanceModel.instance_num).where(
912904
InstanceModel.fleet_id == fleet_model.id,
913905
InstanceModel.deleted.is_(False),
914906
)
915907
)
916-
taken_instance_nums = set(res.scalars().all())
917-
return get_next_instance_num(taken_instance_nums)
908+
return set(res.scalars().all())
918909

919910

920911
def _create_instance_model_for_job(

src/dstack/_internal/server/db.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,8 @@ def __init__(self, url: str, engine: Optional[AsyncEngine] = None):
3333
self.session_maker = async_sessionmaker(
3434
bind=self.engine, # type: ignore[assignment]
3535
expire_on_commit=False,
36+
# Disable autoflush to avoid accidental long write transactions on SQLite.
37+
autoflush=False,
3638
class_=AsyncSession,
3739
)
3840

src/tests/_internal/server/services/test_instances.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ async def test_includes_termination_reason_in_event_messages_only_once(
4040
instance.termination_reason_message = "Some err"
4141
instances_services.switch_instance_status(session, instance, InstanceStatus.TERMINATING)
4242
instances_services.switch_instance_status(session, instance, InstanceStatus.TERMINATED)
43-
43+
await session.commit()
4444
events = await list_events(session)
4545
assert len(events) == 2
4646
assert {e.message for e in events} == {
@@ -61,7 +61,7 @@ async def test_includes_termination_reason_in_event_message_when_switching_direc
6161
instance.termination_reason = InstanceTerminationReason.ERROR
6262
instance.termination_reason_message = "Some err"
6363
instances_services.switch_instance_status(session, instance, InstanceStatus.TERMINATED)
64-
64+
await session.commit()
6565
events = await list_events(session)
6666
assert len(events) == 1
6767
assert events[0].message == (

0 commit comments

Comments
 (0)