Skip to content

Commit bafd2d9

Browse files
Bihan Rana authored and committed
Resolve Merge Conflict
1 parent bd42dfc commit bafd2d9

File tree

16 files changed

+839
-40
lines changed

16 files changed

+839
-40
lines changed

src/dstack/_internal/core/models/configurations.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,12 @@ class ReplicaGroup(CoreModel):
801801
CommandsList,
802802
Field(description="The shell commands to run for replicas in this group"),
803803
] = []
804+
router: Annotated[
805+
Optional[AnyServiceRouterConfig],
806+
Field(
807+
description="When set, replicas in this group run the in-service HTTP router (e.g. SGLang).",
808+
),
809+
] = None
804810

805811
@validator("name")
806812
def validate_name(cls, v: Optional[str]) -> Optional[str]:
@@ -1032,6 +1038,16 @@ def validate_replica_groups_have_commands_or_image(cls, values):
10321038

10331039
return values
10341040

1041+
@root_validator()
def validate_at_most_one_router_replica_group(cls, values):
    """Ensure that no more than one replica group declares a `router` config."""
    replicas = values.get("replicas")
    # Only validate when replicas parsed into a list; other shapes are
    # rejected by earlier field validation.
    if isinstance(replicas, list):
        router_count = sum(1 for group in replicas if group.router is not None)
        if router_count > 1:
            raise ValueError("At most one replica group may specify `router`.")
    return values
1050+
10351051

10361052
class ServiceConfigurationConfig(
10371053
ProfileParamsConfig,

src/dstack/_internal/proxy/gateway/routers/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ async def register_replica(
8282
ssh_head_proxy=body.ssh_head_proxy,
8383
ssh_head_proxy_private_key=body.ssh_head_proxy_private_key,
8484
internal_ip=body.internal_ip,
85+
is_router_replica=body.is_router_replica,
8586
repo=repo,
8687
nginx=nginx,
8788
service_conn_pool=service_conn_pool,

src/dstack/_internal/proxy/gateway/schemas/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class RegisterReplicaRequest(BaseModel):
5858
ssh_head_proxy: Optional[SSHConnectionParams]
5959
ssh_head_proxy_private_key: Optional[str]
6060
internal_ip: Optional[str] = None
61+
is_router_replica: bool = False
6162

6263

6364
class RegisterEntrypointRequest(BaseModel):

src/dstack/_internal/proxy/gateway/services/registry.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ async def register_replica(
141141
nginx: Nginx,
142142
service_conn_pool: ServiceConnectionPool,
143143
internal_ip: Optional[str] = None,
144+
is_router_replica: bool = False,
144145
) -> None:
145146
replica = models.Replica(
146147
id=replica_id,
@@ -152,6 +153,7 @@ async def register_replica(
152153
ssh_head_proxy=ssh_head_proxy,
153154
ssh_head_proxy_private_key=ssh_head_proxy_private_key,
154155
internal_ip=internal_ip,
156+
is_router_replica=is_router_replica,
155157
)
156158

157159
async with lock:
@@ -291,6 +293,13 @@ async def apply_service(
291293
)
292294
for replica, conn in replica_conns.items()
293295
]
296+
router_replicas = [r for r in service.replicas if r.is_router_replica]
297+
if router_replicas:
298+
replica_configs_for_nginx = [c for c in replica_configs if c.id == router_replicas[0].id]
299+
service_config = await get_nginx_service_config(service, replica_configs_for_nginx)
300+
await nginx.register(service_config, (await repo.get_config()).acme_settings)
301+
return replica_failures
302+
294303
service_config = await get_nginx_service_config(service, replica_configs)
295304
await nginx.register(service_config, (await repo.get_config()).acme_settings)
296305
return replica_failures

src/dstack/_internal/proxy/lib/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class Replica(ImmutableModel):
3030
ssh_head_proxy: Optional[SSHConnectionParams] = None
3131
ssh_head_proxy_private_key: Optional[str] = None
3232
internal_ip: Optional[str] = None
33+
is_router_replica: bool = False
3334

3435

3536
class IPAddressPartitioningKey(ImmutableModel):

src/dstack/_internal/proxy/lib/services/service_connection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,9 @@ async def get_service_replica_client(
151151
timeout=HTTP_TIMEOUT,
152152
)
153153
# Nginx not available, forward directly to the tunnel
154-
replica = random.choice(service.replicas)
154+
router_replicas = [r for r in service.replicas if r.is_router_replica]
155+
replicas_to_use = router_replicas if router_replicas else service.replicas
156+
replica = random.choice(replicas_to_use)
155157
connection = await service_conn_pool.get(replica.id)
156158
if connection is None:
157159
project = await repo.get_project(service.project_name)

src/dstack/_internal/server/background/pipeline_tasks/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
PlacementGroupPipeline,
1717
)
1818
from dstack._internal.server.background.pipeline_tasks.runs import RunPipeline
19+
from dstack._internal.server.background.pipeline_tasks.service_router_worker_sync import (
20+
ServiceRouterWorkerSyncPipeline,
21+
)
1922
from dstack._internal.server.background.pipeline_tasks.volumes import VolumePipeline
2023
from dstack._internal.utils.logging import get_logger
2124

@@ -36,6 +39,7 @@ def __init__(self) -> None:
3639
InstancePipeline(pipeline_hinter=self._hinter),
3740
PlacementGroupPipeline(pipeline_hinter=self._hinter),
3841
RunPipeline(pipeline_hinter=self._hinter),
42+
ServiceRouterWorkerSyncPipeline(pipeline_hinter=self._hinter),
3943
VolumePipeline(pipeline_hinter=self._hinter),
4044
]:
4145
self.register_pipeline(builtin_pipeline)
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
import asyncio
2+
import uuid
3+
from dataclasses import dataclass
4+
from datetime import timedelta
5+
from typing import Sequence
6+
7+
from sqlalchemy import delete, or_, select, update
8+
from sqlalchemy.orm import load_only, selectinload
9+
10+
from dstack._internal.core.models.runs import RunStatus
11+
from dstack._internal.server.background.pipeline_tasks.base import (
12+
Fetcher,
13+
Heartbeater,
14+
ItemUpdateMap,
15+
Pipeline,
16+
PipelineItem,
17+
Worker,
18+
log_lock_token_changed_after_processing,
19+
log_lock_token_mismatch,
20+
resolve_now_placeholders,
21+
set_processed_update_map_fields,
22+
set_unlock_update_map_fields,
23+
)
24+
from dstack._internal.server.db import get_db, get_session_ctx
25+
from dstack._internal.server.models import (
26+
InstanceModel,
27+
JobModel,
28+
RunModel,
29+
ServiceRouterWorkerSyncModel,
30+
)
31+
from dstack._internal.server.services.locking import get_locker
32+
from dstack._internal.server.services.pipelines import PipelineHinterProtocol
33+
from dstack._internal.server.services.router_worker_sync import (
34+
run_model_has_router_replica_group,
35+
sync_router_workers_for_run_model,
36+
)
37+
from dstack._internal.server.utils import sentry_utils
38+
from dstack._internal.utils.common import get_current_datetime
39+
from dstack._internal.utils.logging import get_logger
40+
41+
logger = get_logger(__name__)
42+
43+
44+
@dataclass
class ServiceRouterWorkerSyncPipelineItem(PipelineItem):
    """Queue item identifying one router-worker sync task.

    Carries the locked sync-row metadata from `PipelineItem` plus the run
    whose router workers should be synced.
    """

    run_id: uuid.UUID  # ID of the RunModel row this sync task targets
47+
48+
49+
class ServiceRouterWorkerSyncPipeline(Pipeline[ServiceRouterWorkerSyncPipelineItem]):
    """Background pipeline that keeps in-service router workers in sync.

    Wires together one fetcher (claims due `ServiceRouterWorkerSyncModel`
    rows), one heartbeater (extends row locks while items are queued), and
    `workers_num` workers (perform the actual sync per run).
    """

    def __init__(
        self,
        workers_num: int = 8,
        queue_lower_limit_factor: float = 0.5,
        queue_upper_limit_factor: float = 2.0,
        min_processing_interval: timedelta = timedelta(seconds=5),
        lock_timeout: timedelta = timedelta(seconds=25),
        heartbeat_trigger: timedelta = timedelta(seconds=10),
        *,
        pipeline_hinter: PipelineHinterProtocol,
    ) -> None:
        # Base class stores queue sizing / timing config under the
        # `self._...` names the components below read back.
        super().__init__(
            workers_num=workers_num,
            queue_lower_limit_factor=queue_lower_limit_factor,
            queue_upper_limit_factor=queue_upper_limit_factor,
            min_processing_interval=min_processing_interval,
            lock_timeout=lock_timeout,
            heartbeat_trigger=heartbeat_trigger,
        )
        self.__heartbeater = Heartbeater[ServiceRouterWorkerSyncPipelineItem](
            model_type=ServiceRouterWorkerSyncModel,
            lock_timeout=self._lock_timeout,
            heartbeat_trigger=self._heartbeat_trigger,
        )
        self.__fetcher = ServiceRouterWorkerSyncFetcher(
            queue=self._queue,
            queue_desired_minsize=self._queue_desired_minsize,
            min_processing_interval=self._min_processing_interval,
            lock_timeout=self._lock_timeout,
            heartbeater=self.__heartbeater,
        )
        # All workers share the one queue and heartbeater.
        self.__workers = [
            ServiceRouterWorkerSyncWorker(
                queue=self._queue,
                heartbeater=self.__heartbeater,
                pipeline_hinter=pipeline_hinter,
            )
            for _ in range(self._workers_num)
        ]

    @property
    def hint_fetch_model_name(self) -> str:
        # Model name used by the pipeline hinter to wake this pipeline's fetcher.
        return ServiceRouterWorkerSyncModel.__name__

    @property
    def _heartbeater(self) -> Heartbeater[ServiceRouterWorkerSyncPipelineItem]:
        return self.__heartbeater

    @property
    def _fetcher(self) -> Fetcher[ServiceRouterWorkerSyncPipelineItem]:
        return self.__fetcher

    @property
    def _workers(self) -> Sequence["ServiceRouterWorkerSyncWorker"]:
        return self.__workers
105+
106+
107+
class ServiceRouterWorkerSyncFetcher(Fetcher[ServiceRouterWorkerSyncPipelineItem]):
    """Claims due sync rows from the DB and turns them into pipeline items."""

    @sentry_utils.instrument_named_task("pipeline_tasks.ServiceRouterWorkerSyncFetcher.fetch")
    async def fetch(self, limit: int) -> list[ServiceRouterWorkerSyncPipelineItem]:
        """Lock and return up to `limit` sync rows that are due for processing.

        A row is due when its run is RUNNING, it was either never processed
        (last_processed_at == created_at) or processed longer than
        `min_processing_interval` ago, and it is not currently locked
        (lock_expires_at is NULL or in the past). Claimed rows get a fresh
        shared lock token and expiry before being returned.
        """
        # Process-local lock serializing fetches on this table; the
        # SELECT ... FOR UPDATE below handles cross-process contention.
        sync_lock, _ = get_locker(get_db().dialect_name).get_lockset(
            ServiceRouterWorkerSyncModel.__tablename__
        )
        async with sync_lock:
            async with get_session_ctx() as session:
                now = get_current_datetime()
                res = await session.execute(
                    select(ServiceRouterWorkerSyncModel)
                    .join(RunModel, RunModel.id == ServiceRouterWorkerSyncModel.run_id)
                    .where(
                        RunModel.status == RunStatus.RUNNING,
                        or_(
                            ServiceRouterWorkerSyncModel.last_processed_at
                            <= now - self._min_processing_interval,
                            # Never-processed rows: last_processed_at is
                            # initialized to created_at.
                            ServiceRouterWorkerSyncModel.last_processed_at
                            == ServiceRouterWorkerSyncModel.created_at,
                        ),
                        or_(
                            ServiceRouterWorkerSyncModel.lock_expires_at.is_(None),
                            ServiceRouterWorkerSyncModel.lock_expires_at < now,
                        ),
                    )
                    # Oldest-processed first, so no run starves.
                    .order_by(ServiceRouterWorkerSyncModel.last_processed_at.asc())
                    .limit(limit)
                    # skip_locked lets concurrent replicas fetch disjoint rows;
                    # lock only the sync table, not the joined RunModel.
                    .with_for_update(
                        skip_locked=True, key_share=True, of=ServiceRouterWorkerSyncModel
                    )
                    .options(
                        load_only(
                            ServiceRouterWorkerSyncModel.id,
                            ServiceRouterWorkerSyncModel.run_id,
                            ServiceRouterWorkerSyncModel.lock_token,
                            ServiceRouterWorkerSyncModel.lock_expires_at,
                        )
                    )
                )
                rows = list(res.scalars().all())
                # One token/expiry shared by the whole claimed batch.
                lock_expires_at = get_current_datetime() + self._lock_timeout
                lock_token = uuid.uuid4()
                items: list[ServiceRouterWorkerSyncPipelineItem] = []
                for row in rows:
                    # Non-NULL lock_expires_at here means a previous lock
                    # existed and (per the WHERE clause) has expired.
                    prev_lock_expired = row.lock_expires_at is not None
                    row.lock_expires_at = lock_expires_at
                    row.lock_token = lock_token
                    row.lock_owner = ServiceRouterWorkerSyncPipeline.__name__
                    items.append(
                        ServiceRouterWorkerSyncPipelineItem(
                            __tablename__=ServiceRouterWorkerSyncModel.__tablename__,
                            id=row.id,
                            lock_expires_at=lock_expires_at,
                            lock_token=lock_token,
                            prev_lock_expired=prev_lock_expired,
                            run_id=row.run_id,
                        )
                    )
                # Persist the new locks before handing items to workers.
                await session.commit()
        return items
167+
168+
169+
class _SyncRowUpdateMap(ItemUpdateMap, total=False):
    # No extra columns beyond the processed/unlock fields that
    # set_processed_update_map_fields / set_unlock_update_map_fields populate.
    pass
171+
172+
173+
class ServiceRouterWorkerSyncWorker(Worker[ServiceRouterWorkerSyncPipelineItem]):
    """Processes one claimed sync row: re-validates it under its lock token,
    syncs the run's router workers, then marks the row processed and unlocks it.
    """

    def __init__(
        self,
        queue: asyncio.Queue[ServiceRouterWorkerSyncPipelineItem],
        heartbeater: Heartbeater[ServiceRouterWorkerSyncPipelineItem],
        pipeline_hinter: PipelineHinterProtocol,
    ) -> None:
        super().__init__(
            queue=queue,
            heartbeater=heartbeater,
            pipeline_hinter=pipeline_hinter,
        )

    @sentry_utils.instrument_named_task("pipeline_tasks.ServiceRouterWorkerSyncWorker.process")
    async def process(self, item: ServiceRouterWorkerSyncPipelineItem) -> None:
        """Sync router workers for the run referenced by `item`.

        Deletes the sync row when its run is gone, deleted, not RUNNING, or no
        longer has a router replica group. Every DB write is guarded by the
        lock token so a competing fetcher that re-claimed the row after our
        lock expired cannot be clobbered.
        """
        # Phase 1: re-read the row under our lock token and drop it if the
        # run no longer needs syncing.
        async with get_session_ctx() as session:
            res = await session.execute(
                select(ServiceRouterWorkerSyncModel)
                .where(
                    ServiceRouterWorkerSyncModel.id == item.id,
                    ServiceRouterWorkerSyncModel.lock_token == item.lock_token,
                )
                .options(selectinload(ServiceRouterWorkerSyncModel.run))
            )
            sync_row = res.unique().scalar_one_or_none()
            if sync_row is None:
                # Row deleted, or re-locked by another fetcher after our
                # lock expired.
                log_lock_token_mismatch(logger, item)
                return
            run_model = sync_row.run
            if run_model is None:
                await session.delete(sync_row)
                await session.commit()
                return
            if (
                run_model.deleted
                # A finished status is never RUNNING, so the single
                # inequality check also covers finished runs (the original
                # `status.is_finished()` clause was redundant).
                or run_model.status != RunStatus.RUNNING
                or not run_model_has_router_replica_group(run_model)
            ):
                await session.delete(sync_row)
                await session.commit()
                return

        # Phase 2: load the run with the relationships the sync needs.
        async with get_session_ctx() as session:
            res = await session.execute(
                select(RunModel)
                .where(RunModel.id == item.run_id)
                .options(
                    selectinload(RunModel.project),
                    selectinload(RunModel.jobs).selectinload(JobModel.project),
                    selectinload(RunModel.jobs)
                    .selectinload(JobModel.instance)
                    .selectinload(InstanceModel.project),
                )
            )
            run_for_sync = res.unique().scalar_one_or_none()

        if run_for_sync is None:
            # The run disappeared between phases; remove the orphaned sync
            # row (still guarded by the lock token).
            async with get_session_ctx() as session:
                await session.execute(
                    delete(ServiceRouterWorkerSyncModel).where(
                        ServiceRouterWorkerSyncModel.id == item.id,
                        ServiceRouterWorkerSyncModel.lock_token == item.lock_token,
                    )
                )
                await session.commit()
            return

        await sync_router_workers_for_run_model(run_for_sync)

        # Phase 3: mark the row processed and release the lock — but only if
        # we still hold the lock token.
        update_map: _SyncRowUpdateMap = {}
        set_processed_update_map_fields(update_map)
        set_unlock_update_map_fields(update_map)
        async with get_session_ctx() as session:
            now = get_current_datetime()
            resolve_now_placeholders(update_map, now=now)
            # NOTE(review): no explicit commit here, unlike the phases above —
            # presumably get_session_ctx commits on clean exit; confirm.
            res2 = await session.execute(
                update(ServiceRouterWorkerSyncModel)
                .where(
                    ServiceRouterWorkerSyncModel.id == item.id,
                    ServiceRouterWorkerSyncModel.lock_token == item.lock_token,
                )
                .values(**update_map)
                .returning(ServiceRouterWorkerSyncModel.id)
            )
            # Empty RETURNING set means someone took over the row mid-sync.
            if not list(res2.scalars().all()):
                log_lock_token_changed_after_processing(logger, item)

0 commit comments

Comments
 (0)