
Commit 2fe5e14

Bihan Rana authored and committed
Support router as replica with pipelines
1 parent 2e03cd3 commit 2fe5e14

File tree

16 files changed: +830 -40 lines changed


src/dstack/_internal/core/models/configurations.py

Lines changed: 16 additions & 0 deletions
@@ -801,6 +801,12 @@ class ReplicaGroup(CoreModel):
         CommandsList,
         Field(description="The shell commands to run for replicas in this group"),
     ] = []
+    router: Annotated[
+        Optional[AnyServiceRouterConfig],
+        Field(
+            description="When set, replicas in this group run the in-service HTTP router (e.g. SGLang).",
+        ),
+    ] = None

     @validator("name")
     def validate_name(cls, v: Optional[str]) -> Optional[str]:
@@ -1032,6 +1038,16 @@ def validate_replica_groups_have_commands_or_image(cls, values):

         return values

+    @root_validator()
+    def validate_at_most_one_router_replica_group(cls, values):
+        replicas = values.get("replicas")
+        if not isinstance(replicas, list):
+            return values
+        router_groups = [g for g in replicas if g.router is not None]
+        if len(router_groups) > 1:
+            raise ValueError("At most one replica group may specify `router`.")
+        return values
+

 class ServiceConfigurationConfig(
     ProfileParamsConfig,
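For context, a standalone sketch (not part of this commit; names and the router config shape are illustrative) of the rule the new validate_at_most_one_router_replica_group root validator enforces: across a service's replica groups, at most one may declare `router`.

from typing import Any


def check_at_most_one_router_group(replica_groups: list[dict[str, Any]]) -> None:
    """Mirror of the commit's rule: only one replica group may run the in-service router."""
    router_groups = [g for g in replica_groups if g.get("router") is not None]
    if len(router_groups) > 1:
        raise ValueError("At most one replica group may specify `router`.")


# Accepted: one router group plus an ordinary worker group.
check_at_most_one_router_group(
    [
        {"name": "router", "router": {"type": "sglang"}},  # hypothetical router config shape
        {"name": "workers", "router": None},
    ]
)

# Rejected: two groups both declare a router.
try:
    check_at_most_one_router_group(
        [
            {"name": "a", "router": {"type": "sglang"}},
            {"name": "b", "router": {"type": "sglang"}},
        ]
    )
except ValueError as e:
    print(e)  # At most one replica group may specify `router`.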

src/dstack/_internal/proxy/gateway/routers/registry.py

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ async def register_replica(
         ssh_head_proxy=body.ssh_head_proxy,
         ssh_head_proxy_private_key=body.ssh_head_proxy_private_key,
         internal_ip=body.internal_ip,
+        is_router_replica=body.is_router_replica,
         repo=repo,
         nginx=nginx,
         service_conn_pool=service_conn_pool,

src/dstack/_internal/proxy/gateway/schemas/registry.py

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ class RegisterReplicaRequest(BaseModel):
     ssh_head_proxy: Optional[SSHConnectionParams]
     ssh_head_proxy_private_key: Optional[str]
     internal_ip: Optional[str] = None
+    is_router_replica: bool = False


 class RegisterEntrypointRequest(BaseModel):

src/dstack/_internal/proxy/gateway/services/registry.py

Lines changed: 9 additions & 0 deletions
@@ -141,6 +141,7 @@ async def register_replica(
     nginx: Nginx,
     service_conn_pool: ServiceConnectionPool,
     internal_ip: Optional[str] = None,
+    is_router_replica: bool = False,
 ) -> None:
     replica = models.Replica(
         id=replica_id,
@@ -152,6 +153,7 @@ async def register_replica(
         ssh_head_proxy=ssh_head_proxy,
         ssh_head_proxy_private_key=ssh_head_proxy_private_key,
         internal_ip=internal_ip,
+        is_router_replica=is_router_replica,
     )

     async with lock:
@@ -291,6 +293,13 @@ async def apply_service(
         )
         for replica, conn in replica_conns.items()
     ]
+    router_replicas = [r for r in service.replicas if r.is_router_replica]
+    if router_replicas:
+        replica_configs_for_nginx = [c for c in replica_configs if c.id == router_replicas[0].id]
+        service_config = await get_nginx_service_config(service, replica_configs_for_nginx)
+        await nginx.register(service_config, (await repo.get_config()).acme_settings)
+        return replica_failures
+
     service_config = await get_nginx_service_config(service, replica_configs)
     await nginx.register(service_config, (await repo.get_config()).acme_settings)
     return replica_failures
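The apply_service change above means that when a run has a router replica, nginx is configured with only that replica as its upstream and the router fans requests out to the worker replicas itself. A minimal sketch of that selection rule (simplified stand-in types, not the commit's code):

from dataclasses import dataclass


@dataclass
class ReplicaInfo:  # simplified stand-in for the proxy Replica model
    id: str
    is_router_replica: bool = False


def select_upstream_ids(replicas: list[ReplicaInfo]) -> list[str]:
    """Return the replica ids nginx should proxy to."""
    router_replicas = [r for r in replicas if r.is_router_replica]
    if router_replicas:
        # Only the router replica receives external traffic; it dispatches to workers itself.
        return [router_replicas[0].id]
    return [r.id for r in replicas]


print(select_upstream_ids([ReplicaInfo("w1"), ReplicaInfo("w2")]))        # ['w1', 'w2']
print(select_upstream_ids([ReplicaInfo("rt", True), ReplicaInfo("w1")]))  # ['rt']

The same preference appears in the direct-tunnel fallback in service_connection.py below: when nginx is unavailable, the gateway picks a replica connection from the router replicas if any exist, otherwise from all replicas.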

src/dstack/_internal/proxy/lib/models.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ class Replica(ImmutableModel):
     ssh_head_proxy: Optional[SSHConnectionParams] = None
     ssh_head_proxy_private_key: Optional[str] = None
     internal_ip: Optional[str] = None
+    is_router_replica: bool = False


 class IPAddressPartitioningKey(ImmutableModel):

src/dstack/_internal/proxy/lib/services/service_connection.py

Lines changed: 3 additions & 1 deletion
@@ -151,7 +151,9 @@ async def get_service_replica_client(
             timeout=HTTP_TIMEOUT,
         )
     # Nginx not available, forward directly to the tunnel
-    replica = random.choice(service.replicas)
+    router_replicas = [r for r in service.replicas if r.is_router_replica]
+    replicas_to_use = router_replicas if router_replicas else service.replicas
+    replica = random.choice(replicas_to_use)
     connection = await service_conn_pool.get(replica.id)
     if connection is None:
         project = await repo.get_project(service.project_name)

src/dstack/_internal/server/background/pipeline_tasks/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -16,6 +16,9 @@
     PlacementGroupPipeline,
 )
 from dstack._internal.server.background.pipeline_tasks.runs import RunPipeline
+from dstack._internal.server.background.pipeline_tasks.service_router_worker_sync import (
+    ServiceRouterWorkerSyncPipeline,
+)
 from dstack._internal.server.background.pipeline_tasks.volumes import VolumePipeline
 from dstack._internal.utils.logging import get_logger

@@ -36,6 +39,7 @@ def __init__(self) -> None:
             InstancePipeline(),
             PlacementGroupPipeline(),
             RunPipeline(),
+            ServiceRouterWorkerSyncPipeline(),
             VolumePipeline(),
         ]:
             self.register_pipeline(builtin_pipeline)
src/dstack/_internal/server/background/pipeline_tasks/service_router_worker_sync.py

Lines changed: 250 additions & 0 deletions
@@ -0,0 +1,250 @@
import asyncio
import uuid
from dataclasses import dataclass
from datetime import timedelta
from typing import Sequence

from sqlalchemy import delete, or_, select, update
from sqlalchemy.orm import load_only, selectinload

from dstack._internal.core.models.runs import RunStatus
from dstack._internal.server.background.pipeline_tasks.base import (
    Fetcher,
    Heartbeater,
    ItemUpdateMap,
    Pipeline,
    PipelineItem,
    Worker,
    log_lock_token_changed_after_processing,
    log_lock_token_mismatch,
    resolve_now_placeholders,
    set_processed_update_map_fields,
    set_unlock_update_map_fields,
)
from dstack._internal.server.db import get_db, get_session_ctx
from dstack._internal.server.models import (
    InstanceModel,
    JobModel,
    RunModel,
    ServiceRouterWorkerSyncModel,
)
from dstack._internal.server.services.locking import get_locker
from dstack._internal.server.services.router_worker_sync import (
    run_model_has_router_replica_group,
    sync_router_workers_for_run_model,
)
from dstack._internal.server.utils import sentry_utils
from dstack._internal.utils.common import get_current_datetime
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)


@dataclass
class ServiceRouterWorkerSyncPipelineItem(PipelineItem):
    run_id: uuid.UUID


class ServiceRouterWorkerSyncPipeline(Pipeline[ServiceRouterWorkerSyncPipelineItem]):
    def __init__(
        self,
        workers_num: int = 8,
        queue_lower_limit_factor: float = 0.5,
        queue_upper_limit_factor: float = 2.0,
        min_processing_interval: timedelta = timedelta(seconds=5),
        lock_timeout: timedelta = timedelta(seconds=25),
        heartbeat_trigger: timedelta = timedelta(seconds=10),
    ) -> None:
        super().__init__(
            workers_num=workers_num,
            queue_lower_limit_factor=queue_lower_limit_factor,
            queue_upper_limit_factor=queue_upper_limit_factor,
            min_processing_interval=min_processing_interval,
            lock_timeout=lock_timeout,
            heartbeat_trigger=heartbeat_trigger,
        )
        self.__heartbeater = Heartbeater[ServiceRouterWorkerSyncPipelineItem](
            model_type=ServiceRouterWorkerSyncModel,
            lock_timeout=self._lock_timeout,
            heartbeat_trigger=self._heartbeat_trigger,
        )
        self.__fetcher = ServiceRouterWorkerSyncFetcher(
            queue=self._queue,
            queue_desired_minsize=self._queue_desired_minsize,
            min_processing_interval=self._min_processing_interval,
            lock_timeout=self._lock_timeout,
            heartbeater=self.__heartbeater,
        )
        self.__workers = [
            ServiceRouterWorkerSyncWorker(queue=self._queue, heartbeater=self.__heartbeater)
            for _ in range(self._workers_num)
        ]

    @property
    def hint_fetch_model_name(self) -> str:
        return ServiceRouterWorkerSyncModel.__name__

    @property
    def _heartbeater(self) -> Heartbeater[ServiceRouterWorkerSyncPipelineItem]:
        return self.__heartbeater

    @property
    def _fetcher(self) -> Fetcher[ServiceRouterWorkerSyncPipelineItem]:
        return self.__fetcher

    @property
    def _workers(self) -> Sequence["ServiceRouterWorkerSyncWorker"]:
        return self.__workers


class ServiceRouterWorkerSyncFetcher(Fetcher[ServiceRouterWorkerSyncPipelineItem]):
    @sentry_utils.instrument_named_task("pipeline_tasks.ServiceRouterWorkerSyncFetcher.fetch")
    async def fetch(self, limit: int) -> list[ServiceRouterWorkerSyncPipelineItem]:
        sync_lock, _ = get_locker(get_db().dialect_name).get_lockset(
            ServiceRouterWorkerSyncModel.__tablename__
        )
        async with sync_lock:
            async with get_session_ctx() as session:
                now = get_current_datetime()
                res = await session.execute(
                    select(ServiceRouterWorkerSyncModel)
                    .join(RunModel, RunModel.id == ServiceRouterWorkerSyncModel.run_id)
                    .where(
                        RunModel.status == RunStatus.RUNNING,
                        or_(
                            ServiceRouterWorkerSyncModel.last_processed_at
                            <= now - self._min_processing_interval,
                            ServiceRouterWorkerSyncModel.last_processed_at
                            == ServiceRouterWorkerSyncModel.created_at,
                        ),
                        or_(
                            ServiceRouterWorkerSyncModel.lock_expires_at.is_(None),
                            ServiceRouterWorkerSyncModel.lock_expires_at < now,
                        ),
                    )
                    .order_by(ServiceRouterWorkerSyncModel.last_processed_at.asc())
                    .limit(limit)
                    .with_for_update(
                        skip_locked=True, key_share=True, of=ServiceRouterWorkerSyncModel
                    )
                    .options(
                        load_only(
                            ServiceRouterWorkerSyncModel.id,
                            ServiceRouterWorkerSyncModel.run_id,
                            ServiceRouterWorkerSyncModel.lock_token,
                            ServiceRouterWorkerSyncModel.lock_expires_at,
                        )
                    )
                )
                rows = list(res.scalars().all())
                lock_expires_at = get_current_datetime() + self._lock_timeout
                lock_token = uuid.uuid4()
                items: list[ServiceRouterWorkerSyncPipelineItem] = []
                for row in rows:
                    prev_lock_expired = row.lock_expires_at is not None
                    row.lock_expires_at = lock_expires_at
                    row.lock_token = lock_token
                    row.lock_owner = ServiceRouterWorkerSyncPipeline.__name__
                    items.append(
                        ServiceRouterWorkerSyncPipelineItem(
                            __tablename__=ServiceRouterWorkerSyncModel.__tablename__,
                            id=row.id,
                            lock_expires_at=lock_expires_at,
                            lock_token=lock_token,
                            prev_lock_expired=prev_lock_expired,
                            run_id=row.run_id,
                        )
                    )
                await session.commit()
        return items


class _SyncRowUpdateMap(ItemUpdateMap, total=False):
    pass


class ServiceRouterWorkerSyncWorker(Worker[ServiceRouterWorkerSyncPipelineItem]):
    def __init__(
        self,
        queue: asyncio.Queue[ServiceRouterWorkerSyncPipelineItem],
        heartbeater: Heartbeater[ServiceRouterWorkerSyncPipelineItem],
    ) -> None:
        super().__init__(
            queue=queue,
            heartbeater=heartbeater,
        )

    @sentry_utils.instrument_named_task("pipeline_tasks.ServiceRouterWorkerSyncWorker.process")
    async def process(self, item: ServiceRouterWorkerSyncPipelineItem) -> None:
        async with get_session_ctx() as session:
            res = await session.execute(
                select(ServiceRouterWorkerSyncModel)
                .where(
                    ServiceRouterWorkerSyncModel.id == item.id,
                    ServiceRouterWorkerSyncModel.lock_token == item.lock_token,
                )
                .options(selectinload(ServiceRouterWorkerSyncModel.run))
            )
            sync_row = res.unique().scalar_one_or_none()
            if sync_row is None:
                log_lock_token_mismatch(logger, item)
                return
            run_model = sync_row.run
            if run_model is None:
                await session.delete(sync_row)
                await session.commit()
                return
            if (
                run_model.deleted
                or run_model.status.is_finished()
                or run_model.status != RunStatus.RUNNING
                or not run_model_has_router_replica_group(run_model)
            ):
                await session.delete(sync_row)
                await session.commit()
                return

        async with get_session_ctx() as session:
            res = await session.execute(
                select(RunModel)
                .where(RunModel.id == item.run_id)
                .options(
                    selectinload(RunModel.project),
                    selectinload(RunModel.jobs).selectinload(JobModel.project),
                    selectinload(RunModel.jobs)
                    .selectinload(JobModel.instance)
                    .selectinload(InstanceModel.project),
                )
            )
            run_for_sync = res.unique().scalar_one_or_none()

        if run_for_sync is None:
            async with get_session_ctx() as session:
                await session.execute(
                    delete(ServiceRouterWorkerSyncModel).where(
                        ServiceRouterWorkerSyncModel.id == item.id,
                        ServiceRouterWorkerSyncModel.lock_token == item.lock_token,
                    )
                )
                await session.commit()
            return

        await sync_router_workers_for_run_model(run_for_sync)

        update_map: _SyncRowUpdateMap = {}
        set_processed_update_map_fields(update_map)
        set_unlock_update_map_fields(update_map)
        async with get_session_ctx() as session:
            now = get_current_datetime()
            resolve_now_placeholders(update_map, now=now)
            res2 = await session.execute(
                update(ServiceRouterWorkerSyncModel)
                .where(
                    ServiceRouterWorkerSyncModel.id == item.id,
                    ServiceRouterWorkerSyncModel.lock_token == item.lock_token,
                )
                .values(**update_map)
                .returning(ServiceRouterWorkerSyncModel.id)
            )
            if not list(res2.scalars().all()):
                log_lock_token_changed_after_processing(logger, item)
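The new pipeline reuses the lock-token pattern of the other dstack pipelines: the fetcher stamps each claimed row with a fresh lock_token and lock_expires_at, and every later write is guarded by a WHERE clause on both id and lock_token, so a row re-claimed by another fetcher after the lock expired cannot be overwritten. A condensed, in-memory sketch of that pattern (illustrative helpers, not the commit's code):

import uuid
from datetime import datetime, timedelta, timezone


def claim_rows(rows: list[dict], lock_timeout: timedelta = timedelta(seconds=25)) -> uuid.UUID:
    """Stamp rows with a fresh lock token; only the token holder may update them later."""
    token = uuid.uuid4()
    expires_at = datetime.now(timezone.utc) + lock_timeout
    for row in rows:
        row["lock_token"] = token
        row["lock_expires_at"] = expires_at
    return token


def guarded_update(row: dict, token: uuid.UUID, **changes) -> bool:
    """In-memory analogue of UPDATE ... WHERE id = :id AND lock_token = :token."""
    if row["lock_token"] != token:
        return False  # another fetcher re-claimed the row after our lock expired
    row.update(changes)
    return True


row = {"id": 1, "lock_token": None, "lock_expires_at": None, "last_processed_at": None}
token = claim_rows([row])
assert guarded_update(row, token, last_processed_at=datetime.now(timezone.utc))
assert not guarded_update(row, uuid.uuid4(), last_processed_at=None)  # stale token is rejected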
