Skip to content

Commit bd42dfc

Browse files
authored
Add pipelines optimizations (#3719)
* Skip sibling jobs loading for running jobs
* Use status-specific processing intervals for instances
* Optimize run_model loading for submitted jobs
* Introduce *_STATUSES_WITH_MIN_PROCESSING_INTERVAL
* Use state-specific processing intervals for terminating jobs
* Pass pipeline hinter to pipelines
1 parent 3277143 commit bd42dfc

File tree

22 files changed

+447
-87
lines changed

22 files changed

+447
-87
lines changed

src/dstack/_internal/server/background/pipeline_tasks/__init__.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,16 @@ def __init__(self) -> None:
2727
self._pipelines: list[Pipeline] = []
2828
self._hinter = PipelineHinter()
2929
for builtin_pipeline in [
30-
ComputeGroupPipeline(),
31-
FleetPipeline(),
32-
GatewayPipeline(),
33-
JobSubmittedPipeline(),
34-
JobRunningPipeline(),
35-
JobTerminatingPipeline(),
36-
InstancePipeline(),
37-
PlacementGroupPipeline(),
38-
RunPipeline(),
39-
VolumePipeline(),
30+
ComputeGroupPipeline(pipeline_hinter=self._hinter),
31+
FleetPipeline(pipeline_hinter=self._hinter),
32+
GatewayPipeline(pipeline_hinter=self._hinter),
33+
JobSubmittedPipeline(pipeline_hinter=self._hinter),
34+
JobRunningPipeline(pipeline_hinter=self._hinter),
35+
JobTerminatingPipeline(pipeline_hinter=self._hinter),
36+
InstancePipeline(pipeline_hinter=self._hinter),
37+
PlacementGroupPipeline(pipeline_hinter=self._hinter),
38+
RunPipeline(pipeline_hinter=self._hinter),
39+
VolumePipeline(pipeline_hinter=self._hinter),
4040
]:
4141
self.register_pipeline(builtin_pipeline)
4242

src/dstack/_internal/server/background/pipeline_tasks/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from sqlalchemy.orm import Mapped
2525

2626
from dstack._internal.server.db import get_session_ctx
27+
from dstack._internal.server.services.pipelines import PipelineHinterProtocol
2728
from dstack._internal.utils.common import get_current_datetime
2829
from dstack._internal.utils.logging import get_logger
2930

@@ -338,9 +339,11 @@ def __init__(
338339
self,
339340
queue: asyncio.Queue[ItemT],
340341
heartbeater: Heartbeater[ItemT],
342+
pipeline_hinter: PipelineHinterProtocol,
341343
) -> None:
342344
self._queue = queue
343345
self._heartbeater = heartbeater
346+
self._pipeline_hinter = pipeline_hinter
344347
self._running = False
345348

346349
async def start(self):

src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from dstack._internal.server.services.compute_groups import compute_group_model_to_compute_group
3333
from dstack._internal.server.services.instances import emit_instance_status_change_event
3434
from dstack._internal.server.services.locking import get_locker
35+
from dstack._internal.server.services.pipelines import PipelineHinterProtocol
3536
from dstack._internal.server.utils import sentry_utils
3637
from dstack._internal.utils.common import get_current_datetime, run_async
3738
from dstack._internal.utils.logging import get_logger
@@ -51,6 +52,8 @@ def __init__(
5152
min_processing_interval: timedelta = timedelta(seconds=15),
5253
lock_timeout: timedelta = timedelta(seconds=30),
5354
heartbeat_trigger: timedelta = timedelta(seconds=15),
55+
*,
56+
pipeline_hinter: PipelineHinterProtocol,
5457
) -> None:
5558
super().__init__(
5659
workers_num=workers_num,
@@ -73,7 +76,11 @@ def __init__(
7376
heartbeater=self._heartbeater,
7477
)
7578
self.__workers = [
76-
ComputeGroupWorker(queue=self._queue, heartbeater=self._heartbeater)
79+
ComputeGroupWorker(
80+
queue=self._queue,
81+
heartbeater=self._heartbeater,
82+
pipeline_hinter=pipeline_hinter,
83+
)
7784
for _ in range(self._workers_num)
7885
]
7986

@@ -173,10 +180,12 @@ def __init__(
173180
self,
174181
queue: asyncio.Queue[PipelineItem],
175182
heartbeater: Heartbeater[PipelineItem],
183+
pipeline_hinter: PipelineHinterProtocol,
176184
) -> None:
177185
super().__init__(
178186
queue=queue,
179187
heartbeater=heartbeater,
188+
pipeline_hinter=pipeline_hinter,
180189
)
181190

182191
@sentry_utils.instrument_named_task("pipeline_tasks.ComputeGroupWorker.process")

src/dstack/_internal/server/background/pipeline_tasks/fleets.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
is_fleet_in_use,
5050
)
5151
from dstack._internal.server.services.locking import get_locker
52+
from dstack._internal.server.services.pipelines import PipelineHinterProtocol
5253
from dstack._internal.server.utils import sentry_utils
5354
from dstack._internal.utils.common import get_current_datetime
5455
from dstack._internal.utils.logging import get_logger
@@ -65,6 +66,8 @@ def __init__(
6566
min_processing_interval: timedelta = timedelta(seconds=30),
6667
lock_timeout: timedelta = timedelta(seconds=20),
6768
heartbeat_trigger: timedelta = timedelta(seconds=10),
69+
*,
70+
pipeline_hinter: PipelineHinterProtocol,
6871
) -> None:
6972
super().__init__(
7073
workers_num=workers_num,
@@ -87,7 +90,11 @@ def __init__(
8790
heartbeater=self._heartbeater,
8891
)
8992
self.__workers = [
90-
FleetWorker(queue=self._queue, heartbeater=self._heartbeater)
93+
FleetWorker(
94+
queue=self._queue,
95+
heartbeater=self._heartbeater,
96+
pipeline_hinter=pipeline_hinter,
97+
)
9198
for _ in range(self._workers_num)
9299
]
93100

@@ -188,10 +195,12 @@ def __init__(
188195
self,
189196
queue: asyncio.Queue[PipelineItem],
190197
heartbeater: Heartbeater[PipelineItem],
198+
pipeline_hinter: PipelineHinterProtocol,
191199
) -> None:
192200
super().__init__(
193201
queue=queue,
194202
heartbeater=heartbeater,
203+
pipeline_hinter=pipeline_hinter,
195204
)
196205

197206
@sentry_utils.instrument_named_task("pipeline_tasks.FleetWorker.process")

src/dstack/_internal/server/background/pipeline_tasks/gateways.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from dstack._internal.server.services.gateways.pool import gateway_connections_pool
3939
from dstack._internal.server.services.locking import get_locker
4040
from dstack._internal.server.services.logging import fmt
41+
from dstack._internal.server.services.pipelines import PipelineHinterProtocol
4142
from dstack._internal.server.utils import sentry_utils
4243
from dstack._internal.utils.common import get_current_datetime, run_async
4344
from dstack._internal.utils.logging import get_logger
@@ -60,6 +61,8 @@ def __init__(
6061
min_processing_interval: timedelta = timedelta(seconds=15),
6162
lock_timeout: timedelta = timedelta(seconds=30),
6263
heartbeat_trigger: timedelta = timedelta(seconds=15),
64+
*,
65+
pipeline_hinter: PipelineHinterProtocol,
6366
) -> None:
6467
super().__init__(
6568
workers_num=workers_num,
@@ -82,7 +85,11 @@ def __init__(
8285
heartbeater=self._heartbeater,
8386
)
8487
self.__workers = [
85-
GatewayWorker(queue=self._queue, heartbeater=self._heartbeater)
88+
GatewayWorker(
89+
queue=self._queue,
90+
heartbeater=self._heartbeater,
91+
pipeline_hinter=pipeline_hinter,
92+
)
8693
for _ in range(self._workers_num)
8794
]
8895

@@ -192,10 +199,12 @@ def __init__(
192199
self,
193200
queue: asyncio.Queue[GatewayPipelineItem],
194201
heartbeater: Heartbeater[GatewayPipelineItem],
202+
pipeline_hinter: PipelineHinterProtocol,
195203
) -> None:
196204
super().__init__(
197205
queue=queue,
198206
heartbeater=heartbeater,
207+
pipeline_hinter=pipeline_hinter,
199208
)
200209

201210
@sentry_utils.instrument_named_task("pipeline_tasks.GatewayWorker.process")

src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
is_ssh_instance,
5353
)
5454
from dstack._internal.server.services.locking import get_locker
55+
from dstack._internal.server.services.pipelines import PipelineHinterProtocol
5556
from dstack._internal.server.services.placement import (
5657
schedule_fleet_placement_groups_deletion,
5758
)
@@ -62,6 +63,13 @@
6263
logger = get_logger(__name__)
6364

6465

66+
INSTANCE_STATUSES_WITH_MIN_PROCESSING_INTERVAL = [
67+
InstanceStatus.PENDING,
68+
InstanceStatus.PROVISIONING,
69+
InstanceStatus.TERMINATING,
70+
]
71+
72+
6573
@dataclass
6674
class InstancePipelineItem(PipelineItem):
6775
status: InstanceStatus
@@ -73,9 +81,11 @@ def __init__(
7381
workers_num: int = 20,
7482
queue_lower_limit_factor: float = 0.5,
7583
queue_upper_limit_factor: float = 2.0,
76-
min_processing_interval: timedelta = timedelta(seconds=15),
84+
min_processing_interval: timedelta = timedelta(seconds=7),
7785
lock_timeout: timedelta = timedelta(seconds=30),
7886
heartbeat_trigger: timedelta = timedelta(seconds=15),
87+
*,
88+
pipeline_hinter: PipelineHinterProtocol,
7989
) -> None:
8090
super().__init__(
8191
workers_num=workers_num,
@@ -98,7 +108,11 @@ def __init__(
98108
heartbeater=self._heartbeater,
99109
)
100110
self.__workers = [
101-
InstanceWorker(queue=self._queue, heartbeater=self._heartbeater)
111+
InstanceWorker(
112+
queue=self._queue,
113+
heartbeater=self._heartbeater,
114+
pipeline_hinter=pipeline_hinter,
115+
)
102116
for _ in range(self._workers_num)
103117
]
104118

@@ -167,7 +181,24 @@ async def fetch(self, limit: int) -> list[InstancePipelineItem]:
167181
),
168182
InstanceModel.deleted == False,
169183
or_(
170-
InstanceModel.last_processed_at <= now - self._min_processing_interval,
184+
# Process fast-moving instances (pending, provisioning, terminating)
185+
# at base interval for low-latency state transitions.
186+
# All other statuses (e.g. idle, busy) use twice the base interval
187+
# since they only need periodic health checks.
188+
and_(
189+
InstanceModel.status.in_(
190+
INSTANCE_STATUSES_WITH_MIN_PROCESSING_INTERVAL
191+
),
192+
InstanceModel.last_processed_at
193+
<= now - self._min_processing_interval,
194+
),
195+
and_(
196+
InstanceModel.status.not_in(
197+
INSTANCE_STATUSES_WITH_MIN_PROCESSING_INTERVAL
198+
),
199+
InstanceModel.last_processed_at
200+
<= now - self._min_processing_interval * 2,
201+
),
171202
InstanceModel.last_processed_at == InstanceModel.created_at,
172203
),
173204
or_(
@@ -228,10 +259,12 @@ def __init__(
228259
self,
229260
queue: asyncio.Queue[InstancePipelineItem],
230261
heartbeater: Heartbeater[InstancePipelineItem],
262+
pipeline_hinter: PipelineHinterProtocol,
231263
) -> None:
232264
super().__init__(
233265
queue=queue,
234266
heartbeater=heartbeater,
267+
pipeline_hinter=pipeline_hinter,
235268
)
236269

237270
@sentry_utils.instrument_named_task("pipeline_tasks.InstanceWorker.process")

0 commit comments

Comments
 (0)