Skip to content

Commit 7e3969c

Browse files
authored
Move stop_runner() to JobTerminating pipeline (#3714)
* Move stop_runner() to JobTerminating pipeline
* Update tests
* Replace graceful_termination with graceful_termination_attempts
* Rebase migration
* Add status-specific processing interval
1 parent 83cad55 commit 7e3969c

File tree

9 files changed

+235
-48
lines changed

9 files changed

+235
-48
lines changed

src/dstack/_internal/server/background/pipeline_tasks/base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ async def heartbeat(self):
255255

256256
class Fetcher(Generic[ItemT], ABC):
257257
_DEFAULT_FETCH_DELAYS = [0.5, 1, 2, 5]
258+
"""Increasing fetch delays on empty fetches to avoid frequent selects on low-activity/low-resource servers."""
258259

259260
def __init__(
260261
self,
@@ -319,7 +320,15 @@ async def fetch(self, limit: int) -> list[ItemT]:
319320
pass
320321

321322
def _next_fetch_delay(self, empty_fetch_count: int) -> float:
322-
next_delay = self._fetch_delays[min(empty_fetch_count, len(self._fetch_delays) - 1)]
323+
effective_empty_fetch_count = empty_fetch_count
324+
if random.random() < 0.1:
325+
# Empty fetch count can be 0 not because there are no items in the DB,
326+
# but for other reasons such as waiting parent resource processing.
327+
# From time to time, force minimal next delay to avoid empty results due to rare fetches.
328+
effective_empty_fetch_count = 0
329+
next_delay = self._fetch_delays[
330+
min(effective_empty_fetch_count, len(self._fetch_delays) - 1)
331+
]
323332
jitter = random.random() * 0.4 - 0.2
324333
return next_delay * (1 + jitter)
325334

src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def __init__(
117117
workers_num: int = 20,
118118
queue_lower_limit_factor: float = 0.5,
119119
queue_upper_limit_factor: float = 2.0,
120-
min_processing_interval: timedelta = timedelta(seconds=10),
120+
min_processing_interval: timedelta = timedelta(seconds=5),
121121
lock_timeout: timedelta = timedelta(seconds=30),
122122
heartbeat_trigger: timedelta = timedelta(seconds=15),
123123
) -> None:
@@ -196,7 +196,19 @@ async def fetch(self, limit: int) -> list[JobRunningPipelineItem]:
196196
[JobStatus.PROVISIONING, JobStatus.PULLING, JobStatus.RUNNING]
197197
),
198198
RunModel.status.not_in([RunStatus.TERMINATING]),
199-
JobModel.last_processed_at <= now - self._min_processing_interval,
199+
or_(
200+
# Process provisioning and pulling jobs quicker for low-latency provisioning.
201+
# Active jobs processing can be less frequent to minimize contention with `RunPipeline`.
202+
and_(
203+
JobModel.status.in_([JobStatus.PROVISIONING, JobStatus.PULLING]),
204+
JobModel.last_processed_at <= now - self._min_processing_interval,
205+
),
206+
and_(
207+
JobModel.status.in_([JobStatus.RUNNING]),
208+
JobModel.last_processed_at
209+
<= now - self._min_processing_interval * 2,
210+
),
211+
),
200212
or_(
201213
and_(
202214
# Do not try to lock jobs if the run is waiting for the lock,

src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
get_job_provisioning_data,
6161
get_job_runtime_data,
6262
get_job_spec,
63+
stop_runner,
6364
)
6465
from dstack._internal.server.services.locking import get_locker
6566
from dstack._internal.server.services.logging import fmt
@@ -265,8 +266,10 @@ class _JobUpdateMap(ItemUpdateMap, total=False):
265266
termination_reason: Optional[JobTerminationReason]
266267
termination_reason_message: Optional[str]
267268
instance_id: Optional[uuid.UUID]
269+
graceful_termination_attempts: int
268270
volumes_detached_at: UpdateMapDateTime
269271
registered: bool
272+
remove_at: UpdateMapDateTime
270273

271274

272275
class _InstanceUpdateMap(ItemUpdateMap, total=False):
@@ -580,9 +583,11 @@ async def _process_terminating_job(
580583
instance_model: Optional[InstanceModel],
581584
) -> _ProcessResult:
582585
"""
583-
Stops the job: tells shim to stop the container, detaches the job from the instance,
584-
and detaches volumes from the instance.
585-
Graceful stop should already be done by the run terminating path.
586+
Terminates the job:
587+
1. tells the runner to stop the job's command
588+
2. tells the shim to stop the container
589+
3. detaches the job from the instance
590+
4. and detaches volumes from the instance.
586591
"""
587592
instance_update_map = None if instance_model is None else _InstanceUpdateMap()
588593
result = _ProcessResult(instance_update_map=instance_update_map)
@@ -592,6 +597,10 @@ async def _process_terminating_job(
592597
result.job_update_map["status"] = _get_job_termination_status(job_model)
593598
return result
594599

600+
if job_model.graceful_termination_attempts == 0 and job_model.remove_at is None:
601+
result.job_update_map = await _stop_job_gracefully(job_model, instance_model)
602+
return result
603+
595604
jrd = get_job_runtime_data(job_model)
596605
jpd = get_job_provisioning_data(job_model)
597606
if jpd is not None:
@@ -642,6 +651,20 @@ async def _process_terminating_job(
642651
return result
643652

644653

654+
async def _stop_job_gracefully(
655+
job_model: JobModel, instance_model: InstanceModel
656+
) -> _JobUpdateMap:
657+
"""
658+
Tells the runner to stop the job's command. Records the first graceful-stop attempt and
659+
sets `remove_at` so `_process_terminating_job()` stops the container on a later iteration.
660+
"""
661+
job_update_map = _JobUpdateMap()
662+
await stop_runner(job_model=job_model, instance_model=instance_model)
663+
job_update_map["graceful_termination_attempts"] = 1
664+
job_update_map["remove_at"] = get_current_datetime() + timedelta(seconds=10)
665+
return job_update_map
666+
667+
645668
async def _process_job_volumes_detaching(
646669
job_model: JobModel,
647670
instance_model: InstanceModel,

src/dstack/_internal/server/background/pipeline_tasks/runs/__init__.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(
5555
workers_num: int = 10,
5656
queue_lower_limit_factor: float = 0.5,
5757
queue_upper_limit_factor: float = 2.0,
58-
min_processing_interval: timedelta = timedelta(seconds=10),
58+
min_processing_interval: timedelta = timedelta(seconds=5),
5959
lock_timeout: timedelta = timedelta(seconds=30),
6060
heartbeat_trigger: timedelta = timedelta(seconds=15),
6161
) -> None:
@@ -164,7 +164,17 @@ async def fetch(self, limit: int) -> list[RunPipelineItem]:
164164
),
165165
),
166166
or_(
167-
RunModel.last_processed_at <= now - self._min_processing_interval,
167+
# Process submitted runs quicker for low-latency provisioning.
168+
# Active run processing can be less frequent to minimize contention with `JobRunningPipeline`.
169+
and_(
170+
RunModel.status == RunStatus.SUBMITTED,
171+
RunModel.last_processed_at <= now - self._min_processing_interval,
172+
),
173+
and_(
174+
RunModel.status != RunStatus.SUBMITTED,
175+
RunModel.last_processed_at
176+
<= now - self._min_processing_interval * 2,
177+
),
168178
RunModel.last_processed_at == RunModel.submitted_at,
169179
),
170180
or_(

src/dstack/_internal/server/background/pipeline_tasks/runs/terminating.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import uuid
22
from dataclasses import dataclass, field
3-
from datetime import datetime, timedelta
3+
from datetime import datetime
44
from typing import Optional
55

66
import httpx
@@ -17,10 +17,9 @@
1717
from dstack._internal.server.db import get_session_ctx
1818
from dstack._internal.server.services import events
1919
from dstack._internal.server.services.gateways import get_or_add_gateway_connection
20-
from dstack._internal.server.services.jobs import stop_runner
2120
from dstack._internal.server.services.logging import fmt
2221
from dstack._internal.server.services.runs import _get_next_triggered_at, get_run_spec
23-
from dstack._internal.utils.common import get_current_datetime, get_or_error
22+
from dstack._internal.utils.common import get_or_error
2423
from dstack._internal.utils.logging import get_logger
2524

2625
logger = get_logger(__name__)
@@ -35,7 +34,7 @@ class TerminatingRunUpdateMap(ItemUpdateMap, total=False):
3534
class TerminatingRunJobUpdateMap(ItemUpdateMap, total=False):
3635
status: JobStatus
3736
termination_reason: Optional[JobTerminationReason]
38-
remove_at: Optional[datetime]
37+
graceful_termination_attempts: int
3938

4039

4140
@dataclass
@@ -77,10 +76,6 @@ async def process_terminating_run(context: TerminatingContext) -> TerminatingRes
7776
JobTerminationReason.ABORTED_BY_USER,
7877
JobTerminationReason.DONE_BY_RUNNER,
7978
}:
80-
# Send a signal to stop the job gracefully.
81-
await stop_runner(
82-
job_model=job_model, instance_model=get_or_error(job_model.instance)
83-
)
8479
delayed_job_ids.append(job_model.id)
8580
continue
8681
regular_job_ids.append(job_model.id)
@@ -123,7 +118,7 @@ def _get_job_id_to_update_map(
123118
job_id_to_update_map[job_id] = TerminatingRunJobUpdateMap(
124119
status=JobStatus.TERMINATING,
125120
termination_reason=job_termination_reason,
126-
remove_at=get_current_datetime() + timedelta(seconds=15),
121+
graceful_termination_attempts=0,
127122
)
128123
return job_id_to_update_map
129124

src/dstack/_internal/server/migrations/versions/e9d81c97c042_add_jobmodel_graceful_termination_attempts.py (file path missing from extraction — reconstructed from the revision ID in the migration header; verify against the repository)

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Add JobModel.graceful_termination_attempts
2+
3+
Revision ID: e9d81c97c042
4+
Revises: 59e328ced74c
5+
Create Date: 2026-03-30 08:41:29.308250+00:00
6+
7+
"""
8+
9+
import sqlalchemy as sa
10+
from alembic import op
11+
12+
# revision identifiers, used by Alembic.
13+
revision = "e9d81c97c042"
14+
down_revision = "59e328ced74c"
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade() -> None:
20+
# ### commands auto generated by Alembic - please adjust! ###
21+
with op.batch_alter_table("jobs", schema=None) as batch_op:
22+
batch_op.add_column(
23+
sa.Column("graceful_termination_attempts", sa.Integer(), nullable=True)
24+
)
25+
26+
# ### end Alembic commands ###
27+
28+
29+
def downgrade() -> None:
30+
# ### commands auto generated by Alembic - please adjust! ###
31+
with op.batch_alter_table("jobs", schema=None) as batch_op:
32+
batch_op.drop_column("graceful_termination_attempts")
33+
34+
# ### end Alembic commands ###

src/dstack/_internal/server/models.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -497,8 +497,16 @@ class JobModel(PipelineModelMixin, BaseModel):
497497
runner_timestamp: Mapped[Optional[int]] = mapped_column(BigInteger)
498498
inactivity_secs: Mapped[Optional[int]] = mapped_column(Integer)
499499
"""`inactivity_secs` uses `0` for active jobs and `None` when inactivity is not applicable."""
500+
graceful_termination_attempts: Mapped[Optional[int]] = mapped_column(Integer)
501+
"""`graceful_termination_attempts` is used for terminating jobs.
502+
* `None` means graceful termination is not needed
503+
* `0` means it is needed but not attempted,
504+
* `>= 1` means at least one graceful stop attempt was sent.
505+
"""
500506
remove_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
501-
"""`remove_at` is used to ensure the instance is killed after the job is finished."""
507+
"""`remove_at` is used to ensure the container/instance is killed after the job is gracefully finished.
508+
Cannot kill the container/instance until `remove_at` is set.
509+
"""
502510
volumes_detached_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
503511
instance_assigned: Mapped[bool] = mapped_column(Boolean, default=False)
504512
"""`instance_assigned` shows whether instance assignment has already been attempted.

src/tests/_internal/server/background/pipeline_tasks/test_runs/test_termination.py

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
JobTerminatingPipeline,
2121
)
2222
from dstack._internal.server.background.pipeline_tasks.runs import RunPipeline, RunWorker
23+
from dstack._internal.server.background.pipeline_tasks.runs.terminating import (
24+
TerminatingResult,
25+
process_terminating_run,
26+
)
2327
from dstack._internal.server.testing.common import (
2428
create_fleet,
2529
create_instance,
@@ -84,32 +88,14 @@ async def test_transitions_running_jobs_to_terminating(
8488
)
8589
lock_run(run)
8690
await session.commit()
87-
item = run_to_pipeline_item(run)
88-
observed_job_lock = {}
89-
90-
async def record_stop_call(**kwargs) -> None:
91-
observed_job_lock["lock_token"] = kwargs["job_model"].lock_token
92-
observed_job_lock["lock_owner"] = kwargs["job_model"].lock_owner
93-
94-
with patch(
95-
"dstack._internal.server.background.pipeline_tasks.runs.terminating.stop_runner",
96-
new=AsyncMock(side_effect=record_stop_call),
97-
) as stop_runner:
98-
await worker.process(item)
99-
100-
assert stop_runner.await_count == 1
101-
stop_call = stop_runner.await_args
102-
assert stop_call is not None
103-
assert stop_call.kwargs["job_model"].id == job.id
104-
assert observed_job_lock["lock_token"] == item.lock_token
105-
assert observed_job_lock["lock_owner"] == RunPipeline.__name__
106-
assert stop_call.kwargs["instance_model"].id == instance.id
91+
await worker.process(run_to_pipeline_item(run))
10792

10893
await session.refresh(job)
10994
await session.refresh(run)
11095
assert job.status == JobStatus.TERMINATING
11196
assert job.termination_reason == JobTerminationReason.TERMINATED_BY_SERVER
112-
assert job.remove_at is not None
97+
assert job.graceful_termination_attempts == 0
98+
assert job.remove_at is None
11399
assert job.lock_token is None
114100
assert job.lock_expires_at is None
115101
assert job.lock_owner is None
@@ -154,19 +140,17 @@ async def test_updates_delayed_and_regular_jobs_separately(
154140
lock_run(run)
155141
await session.commit()
156142

157-
with patch(
158-
"dstack._internal.server.background.pipeline_tasks.runs.terminating.stop_runner",
159-
new=AsyncMock(),
160-
):
161-
await worker.process(run_to_pipeline_item(run))
143+
await worker.process(run_to_pipeline_item(run))
162144

163145
await session.refresh(delayed_job)
164146
await session.refresh(regular_job)
165147
assert delayed_job.status == JobStatus.TERMINATING
166148
assert delayed_job.termination_reason == JobTerminationReason.TERMINATED_BY_SERVER
167-
assert delayed_job.remove_at is not None
149+
assert delayed_job.graceful_termination_attempts == 0
150+
assert delayed_job.remove_at is None
168151
assert regular_job.status == JobStatus.TERMINATING
169152
assert regular_job.termination_reason == JobTerminationReason.TERMINATED_BY_SERVER
153+
assert regular_job.graceful_termination_attempts is None
170154
assert regular_job.remove_at is None
171155

172156
async def test_finishes_non_scheduled_run_when_all_jobs_are_finished(
@@ -273,14 +257,16 @@ async def test_noops_when_run_lock_changes_after_processing(
273257
await session.commit()
274258
item = run_to_pipeline_item(run)
275259
new_lock_token = uuid.uuid4()
260+
original_process_terminating_run = process_terminating_run
276261

277-
async def change_run_lock(**kwargs) -> None:
262+
async def change_run_lock(context) -> TerminatingResult:
278263
run.lock_token = new_lock_token
279264
run.lock_expires_at = get_current_datetime() + timedelta(minutes=1)
280265
await session.commit()
266+
return await original_process_terminating_run(context)
281267

282268
with patch(
283-
"dstack._internal.server.background.pipeline_tasks.runs.terminating.stop_runner",
269+
"dstack._internal.server.background.pipeline_tasks.runs.terminating.process_terminating_run",
284270
new=AsyncMock(side_effect=change_run_lock),
285271
):
286272
await worker.process(item)
@@ -289,7 +275,10 @@ async def change_run_lock(**kwargs) -> None:
289275
await session.refresh(job)
290276
assert run.status == RunStatus.TERMINATING
291277
assert run.lock_token == new_lock_token
278+
assert run.lock_owner == RunPipeline.__name__
292279
assert job.status == JobStatus.RUNNING
280+
assert job.graceful_termination_attempts is None
281+
assert job.remove_at is None
293282
assert job.lock_token is None
294283
assert job.lock_expires_at is None
295284
assert job.lock_owner is None

0 commit comments

Comments (0)