Skip to content

Commit c852a73

Browse files
committed
Optimize process_instances queries
1 parent 50d1487 commit c852a73

3 files changed

Lines changed: 139 additions & 122 deletions

File tree

src/dstack/_internal/server/background/tasks/process_instances.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pydantic import ValidationError
1010
from sqlalchemy import select
1111
from sqlalchemy.ext.asyncio import AsyncSession
12-
from sqlalchemy.orm import joinedload, lazyload
12+
from sqlalchemy.orm import joinedload
1313

1414
from dstack._internal import settings
1515
from dstack._internal.core.backends.base.compute import (
@@ -78,6 +78,7 @@
7878
from dstack._internal.server.models import (
7979
FleetModel,
8080
InstanceModel,
81+
JobModel,
8182
PlacementGroupModel,
8283
ProjectModel,
8384
)
@@ -106,6 +107,7 @@
106107
from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
107108
from dstack._internal.utils.common import (
108109
get_current_datetime,
110+
get_or_error,
109111
run_async,
110112
)
111113
from dstack._internal.utils.logging import get_logger
@@ -154,7 +156,8 @@ async def _process_next_instance():
154156
InstanceModel.last_processed_at
155157
< get_current_datetime() - MIN_PROCESSING_INTERVAL,
156158
)
157-
.options(lazyload(InstanceModel.jobs))
159+
.options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status))
160+
.options(joinedload(InstanceModel.project).load_only(ProjectModel.ssh_private_key))
158161
.order_by(InstanceModel.last_processed_at.asc())
159162
.limit(1)
160163
.with_for_update(skip_locked=True, key_share=True)
@@ -171,23 +174,22 @@ async def _process_next_instance():
171174

172175

173176
async def _process_instance(session: AsyncSession, instance: InstanceModel):
174-
# Refetch to load related attributes.
175-
# joinedload produces LEFT OUTER JOIN that can't be used with FOR UPDATE.
176-
res = await session.execute(
177-
select(InstanceModel)
178-
.where(InstanceModel.id == instance.id)
179-
.options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends))
180-
.options(joinedload(InstanceModel.jobs))
181-
.options(joinedload(InstanceModel.fleet).joinedload(FleetModel.instances))
182-
.execution_options(populate_existing=True)
183-
)
184-
instance = res.unique().scalar_one()
185-
if (
186-
instance.status == InstanceStatus.IDLE
187-
and instance.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE
188-
and not instance.jobs
177+
if instance.status in (
178+
InstanceStatus.PENDING,
179+
InstanceStatus.TERMINATING,
189180
):
190-
await _mark_terminating_if_idle_duration_expired(instance)
181+
# Refetch to load related attributes.
182+
# Load related attributes only for statuses that always need them.
183+
res = await session.execute(
184+
select(InstanceModel)
185+
.where(InstanceModel.id == instance.id)
186+
.options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends))
187+
.options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status))
188+
.options(joinedload(InstanceModel.fleet).joinedload(FleetModel.instances))
189+
.execution_options(populate_existing=True)
190+
)
191+
instance = res.unique().scalar_one()
192+
191193
if instance.status == InstanceStatus.PENDING:
192194
if instance.remote_connection_info is not None:
193195
await _add_remote(instance)
@@ -201,15 +203,23 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
201203
InstanceStatus.IDLE,
202204
InstanceStatus.BUSY,
203205
):
204-
await _check_instance(instance)
206+
idle_duration_expired = _check_and_mark_terminating_if_idle_duration_expired(instance)
207+
if not idle_duration_expired:
208+
await _check_instance(session, instance)
205209
elif instance.status == InstanceStatus.TERMINATING:
206210
await _terminate(instance)
207211

208212
instance.last_processed_at = get_current_datetime()
209213
await session.commit()
210214

211215

212-
async def _mark_terminating_if_idle_duration_expired(instance: InstanceModel):
216+
def _check_and_mark_terminating_if_idle_duration_expired(instance: InstanceModel):
217+
if not (
218+
instance.status == InstanceStatus.IDLE
219+
and instance.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE
220+
and not instance.jobs
221+
):
222+
return False
213223
idle_duration = _get_instance_idle_duration(instance)
214224
idle_seconds = instance.termination_idle_time
215225
delta = datetime.timedelta(seconds=idle_seconds)
@@ -225,6 +235,8 @@ async def _mark_terminating_if_idle_duration_expired(instance: InstanceModel):
225235
"instance_status": instance.status.value,
226236
},
227237
)
238+
return True
239+
return False
228240

229241

230242
async def _add_remote(instance: InstanceModel) -> None:
@@ -703,7 +715,7 @@ def _mark_terminated(instance: InstanceModel, termination_reason: str) -> None:
703715
)
704716

705717

706-
async def _check_instance(instance: InstanceModel) -> None:
718+
async def _check_instance(session: AsyncSession, instance: InstanceModel) -> None:
707719
if (
708720
instance.status == InstanceStatus.BUSY
709721
and instance.jobs
@@ -722,12 +734,16 @@ async def _check_instance(instance: InstanceModel) -> None:
722734
)
723735
return
724736

725-
job_provisioning_data = JobProvisioningData.__response__.parse_raw(
726-
instance.job_provisioning_data
727-
)
737+
job_provisioning_data = get_or_error(get_instance_provisioning_data(instance))
728738
if job_provisioning_data.hostname is None:
739+
res = await session.execute(
740+
select(ProjectModel)
741+
.where(ProjectModel.id == instance.project_id)
742+
.options(joinedload(ProjectModel.backends))
743+
)
744+
project = res.scalar_one()
729745
await _wait_for_instance_provisioning_data(
730-
project=instance.project,
746+
project=project,
731747
instance=instance,
732748
job_provisioning_data=job_provisioning_data,
733749
)

src/dstack/_internal/server/background/tasks/process_running_jobs.py

Lines changed: 95 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
143143
res = await session.execute(
144144
select(RunModel)
145145
.where(RunModel.id == job_model.run_id)
146-
.options(joinedload(RunModel.project).joinedload(ProjectModel.backends))
146+
.options(joinedload(RunModel.project))
147147
.options(joinedload(RunModel.user))
148148
.options(joinedload(RunModel.repo))
149149
.options(joinedload(RunModel.jobs))
@@ -163,22 +163,18 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
163163

164164
job = find_job(run.jobs, job_model.replica_num, job_model.job_num)
165165

166-
# Wait until all other jobs in the replica are provisioned
167-
for other_job in run.jobs:
168-
if (
169-
other_job.job_spec.replica_num == job.job_spec.replica_num
170-
and other_job.job_submissions[-1].status == JobStatus.SUBMITTED
171-
):
172-
job_model.last_processed_at = common_utils.get_current_datetime()
173-
await session.commit()
174-
return
175-
176-
server_ssh_private_keys = get_instance_ssh_private_keys(
177-
common_utils.get_or_error(job_model.instance)
178-
)
179-
180166
initial_status = job_model.status
181167
if initial_status in [JobStatus.PROVISIONING, JobStatus.PULLING]:
168+
# Wait until all other jobs in the replica are provisioned
169+
for other_job in run.jobs:
170+
if (
171+
other_job.job_spec.replica_num == job.job_spec.replica_num
172+
and other_job.job_submissions[-1].status == JobStatus.SUBMITTED
173+
):
174+
job_model.last_processed_at = common_utils.get_current_datetime()
175+
await session.commit()
176+
return
177+
182178
cluster_info = _get_cluster_info(
183179
jobs=run.jobs,
184180
replica_num=job.job_spec.replica_num,
@@ -210,94 +206,98 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
210206
job_model.last_processed_at = common_utils.get_current_datetime()
211207
return
212208

209+
server_ssh_private_keys = get_instance_ssh_private_keys(
210+
common_utils.get_or_error(job_model.instance)
211+
)
212+
213213
if initial_status == JobStatus.PROVISIONING:
214214
if job_provisioning_data.hostname is None:
215215
await _wait_for_instance_provisioning_data(job_model=job_model)
216+
job_model.last_processed_at = common_utils.get_current_datetime()
217+
await session.commit()
218+
return
219+
if _should_wait_for_other_nodes(run, job, job_model):
220+
job_model.last_processed_at = common_utils.get_current_datetime()
221+
await session.commit()
222+
return
223+
224+
# fails are acceptable until timeout is exceeded
225+
if job_provisioning_data.dockerized:
226+
logger.debug(
227+
"%s: process provisioning job with shim, age=%s",
228+
fmt(job_model),
229+
job_submission.age,
230+
)
231+
ssh_user = job_provisioning_data.username
232+
user_ssh_key = run.run_spec.ssh_key_pub.strip()
233+
public_keys = [project.ssh_public_key.strip(), user_ssh_key]
234+
if job_provisioning_data.backend == BackendType.LOCAL:
235+
# No need to update ~/.ssh/authorized_keys when running shim locally
236+
user_ssh_key = ""
237+
success = await common_utils.run_async(
238+
_process_provisioning_with_shim,
239+
server_ssh_private_keys,
240+
job_provisioning_data,
241+
None,
242+
run,
243+
job_model,
244+
job_provisioning_data,
245+
volumes,
246+
job.job_spec.registry_auth,
247+
public_keys,
248+
ssh_user,
249+
user_ssh_key,
250+
)
216251
else:
217-
if _should_wait_for_other_nodes(run, job, job_model):
218-
job_model.last_processed_at = common_utils.get_current_datetime()
219-
await session.commit()
220-
return
252+
logger.debug(
253+
"%s: process provisioning job without shim, age=%s",
254+
fmt(job_model),
255+
job_submission.age,
256+
)
257+
# FIXME: downloading file archives and code here is a waste of time if
258+
# the runner is not ready yet
259+
file_archives = await _get_job_file_archives(
260+
session=session,
261+
archive_mappings=job.job_spec.file_archives,
262+
user=run_model.user,
263+
)
264+
code = await _get_job_code(
265+
session=session,
266+
project=project,
267+
repo=repo_model,
268+
code_hash=_get_repo_code_hash(run, job),
269+
)
221270

222-
# fails are acceptable until timeout is exceeded
223-
if job_provisioning_data.dockerized:
224-
logger.debug(
225-
"%s: process provisioning job with shim, age=%s",
226-
fmt(job_model),
227-
job_submission.age,
228-
)
229-
ssh_user = job_provisioning_data.username
230-
user_ssh_key = run.run_spec.ssh_key_pub.strip()
231-
public_keys = [project.ssh_public_key.strip(), user_ssh_key]
232-
if job_provisioning_data.backend == BackendType.LOCAL:
233-
# No need to update ~/.ssh/authorized_keys when running shim locally
234-
user_ssh_key = ""
235-
success = await common_utils.run_async(
236-
_process_provisioning_with_shim,
237-
server_ssh_private_keys,
238-
job_provisioning_data,
239-
None,
240-
run,
241-
job_model,
242-
job_provisioning_data,
243-
volumes,
244-
job.job_spec.registry_auth,
245-
public_keys,
246-
ssh_user,
247-
user_ssh_key,
248-
)
249-
else:
250-
logger.debug(
251-
"%s: process provisioning job without shim, age=%s",
271+
success = await common_utils.run_async(
272+
_submit_job_to_runner,
273+
server_ssh_private_keys,
274+
job_provisioning_data,
275+
None,
276+
run,
277+
job_model,
278+
job,
279+
cluster_info,
280+
code,
281+
file_archives,
282+
secrets,
283+
repo_creds,
284+
success_if_not_available=False,
285+
)
286+
287+
if not success:
288+
# check timeout
289+
if job_submission.age > get_provisioning_timeout(
290+
backend_type=job_provisioning_data.get_base_backend(),
291+
instance_type_name=job_provisioning_data.instance_type.name,
292+
):
293+
logger.warning(
294+
"%s: failed because runner has not become available in time, age=%s",
252295
fmt(job_model),
253296
job_submission.age,
254297
)
255-
# FIXME: downloading file archives and code here is a waste of time if
256-
# the runner is not ready yet
257-
file_archives = await _get_job_file_archives(
258-
session=session,
259-
archive_mappings=job.job_spec.file_archives,
260-
user=run_model.user,
261-
)
262-
code = await _get_job_code(
263-
session=session,
264-
project=project,
265-
repo=repo_model,
266-
code_hash=_get_repo_code_hash(run, job),
267-
)
268-
269-
success = await common_utils.run_async(
270-
_submit_job_to_runner,
271-
server_ssh_private_keys,
272-
job_provisioning_data,
273-
None,
274-
run,
275-
job_model,
276-
job,
277-
cluster_info,
278-
code,
279-
file_archives,
280-
secrets,
281-
repo_creds,
282-
success_if_not_available=False,
283-
)
284-
285-
if not success:
286-
# check timeout
287-
if job_submission.age > get_provisioning_timeout(
288-
backend_type=job_provisioning_data.get_base_backend(),
289-
instance_type_name=job_provisioning_data.instance_type.name,
290-
):
291-
logger.warning(
292-
"%s: failed because runner has not become available in time, age=%s",
293-
fmt(job_model),
294-
job_submission.age,
295-
)
296-
job_model.status = JobStatus.TERMINATING
297-
job_model.termination_reason = (
298-
JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED
299-
)
300-
# instance will be emptied by process_terminating_jobs
298+
job_model.status = JobStatus.TERMINATING
299+
job_model.termination_reason = JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED
300+
# instance will be emptied by process_terminating_jobs
301301

302302
else: # fails are not acceptable
303303
if initial_status == JobStatus.PULLING:

src/tests/_internal/server/background/tasks/test_process_instances.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ async def test_check_shim_stop_termination_deadline(self, test_db, session: Asyn
224224

225225
@pytest.mark.asyncio
226226
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
227-
async def test_check_shim_terminate_instance_by_dedaline(self, test_db, session: AsyncSession):
227+
async def test_check_shim_terminate_instance_by_deadline(self, test_db, session: AsyncSession):
228228
project = await create_project(session=session)
229229
instance = await create_instance(
230230
session=session,
@@ -306,6 +306,7 @@ async def test_check_shim_process_ureachable_state(
306306
) as healthcheck:
307307
healthcheck.return_value = HealthStatus(healthy=True, reason="OK")
308308
await process_instances()
309+
healthcheck.assert_called()
309310

310311
await session.refresh(instance)
311312

@@ -329,7 +330,7 @@ async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):
329330
await process_instances()
330331
await session.refresh(instance)
331332
assert instance is not None
332-
assert instance.status == InstanceStatus.TERMINATED
333+
assert instance.status == InstanceStatus.TERMINATING
333334
assert instance.termination_reason == "Idle timeout"
334335

335336

0 commit comments

Comments
 (0)