Skip to content

Commit c5a6716

Browse files
Bihan Rana
authored and committed
Resolve review comments
1 parent 8fe01e5 commit c5a6716

File tree

10 files changed

+75
-52
lines changed

10 files changed

+75
-52
lines changed

src/dstack/_internal/proxy/gateway/routers/registry.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ async def register_replica(
8282
ssh_head_proxy=body.ssh_head_proxy,
8383
ssh_head_proxy_private_key=body.ssh_head_proxy_private_key,
8484
internal_ip=body.internal_ip,
85-
is_router_replica=body.is_router_replica,
8685
repo=repo,
8786
nginx=nginx,
8887
service_conn_pool=service_conn_pool,

src/dstack/_internal/proxy/gateway/schemas/registry.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ class RegisterReplicaRequest(BaseModel):
5858
ssh_head_proxy: Optional[SSHConnectionParams]
5959
ssh_head_proxy_private_key: Optional[str]
6060
internal_ip: Optional[str] = None
61-
is_router_replica: bool = False
6261

6362

6463
class RegisterEntrypointRequest(BaseModel):

src/dstack/_internal/proxy/gateway/services/registry.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,6 @@ async def register_replica(
141141
nginx: Nginx,
142142
service_conn_pool: ServiceConnectionPool,
143143
internal_ip: Optional[str] = None,
144-
is_router_replica: bool = False,
145144
) -> None:
146145
replica = models.Replica(
147146
id=replica_id,
@@ -153,7 +152,6 @@ async def register_replica(
153152
ssh_head_proxy=ssh_head_proxy,
154153
ssh_head_proxy_private_key=ssh_head_proxy_private_key,
155154
internal_ip=internal_ip,
156-
is_router_replica=is_router_replica,
157155
)
158156

159157
async with lock:

src/dstack/_internal/proxy/lib/models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ class Replica(ImmutableModel):
3030
ssh_head_proxy: Optional[SSHConnectionParams] = None
3131
ssh_head_proxy_private_key: Optional[str] = None
3232
internal_ip: Optional[str] = None
33-
is_router_replica: bool = False
3433

3534

3635
class IPAddressPartitioningKey(ImmutableModel):

src/dstack/_internal/proxy/lib/services/service_connection.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,7 @@ async def get_service_replica_client(
151151
timeout=HTTP_TIMEOUT,
152152
)
153153
# Nginx not available, forward directly to the tunnel
154-
router_replicas = [r for r in service.replicas if r.is_router_replica]
155-
replicas_to_use = router_replicas if router_replicas else service.replicas
156-
replica = random.choice(replicas_to_use)
154+
replica = random.choice(service.replicas)
157155
connection = await service_conn_pool.get(replica.id)
158156
if connection is None:
159157
project = await repo.get_project(service.project_name)

src/dstack/_internal/server/background/pipeline_tasks/service_router_worker_sync.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from sqlalchemy import or_, select, update
88
from sqlalchemy.orm import joinedload, load_only, selectinload
9-
from sqlalchemy.sql import false
9+
from sqlalchemy.sql import false, true
1010

1111
from dstack._internal.core.models.runs import JobStatus, RunStatus
1212
from dstack._internal.server.background.pipeline_tasks.base import (
@@ -120,7 +120,13 @@ async def fetch(self, limit: int) -> list[ServiceRouterWorkerSyncPipelineItem]:
120120
.join(RunModel, RunModel.id == ServiceRouterWorkerSyncModel.run_id)
121121
.where(
122122
ServiceRouterWorkerSyncModel.deleted == false(),
123-
RunModel.status == RunStatus.RUNNING,
123+
# Fetch RUNNING runs for normal processing, and finished/deleted runs so
124+
# the worker can mark their sync rows deleted.
125+
or_(
126+
RunModel.status == RunStatus.RUNNING,
127+
RunModel.status.in_(RunStatus.finished_statuses()),
128+
RunModel.deleted == true(),
129+
),
124130
or_(
125131
ServiceRouterWorkerSyncModel.last_processed_at
126132
<= now - self._min_processing_interval,
@@ -170,7 +176,7 @@ async def fetch(self, limit: int) -> list[ServiceRouterWorkerSyncPipelineItem]:
170176

171177

172178
class _SyncRowUpdateMap(ItemUpdateMap, total=False):
173-
pass
179+
deleted: bool
174180

175181

176182
class ServiceRouterWorkerSyncWorker(Worker[ServiceRouterWorkerSyncPipelineItem]):
@@ -224,6 +230,7 @@ async def process(self, item: ServiceRouterWorkerSyncPipelineItem) -> None:
224230
selectinload(
225231
RunModel.jobs.and_(
226232
JobModel.status == JobStatus.RUNNING,
233+
JobModel.registered == True,
227234
)
228235
)
229236
.load_only(
@@ -248,20 +255,23 @@ async def process(self, item: ServiceRouterWorkerSyncPipelineItem) -> None:
248255
run_for_sync = res.unique().scalar_one_or_none()
249256

250257
if run_for_sync is None:
258+
cleanup_update_map: _SyncRowUpdateMap = {"deleted": True}
259+
set_processed_update_map_fields(cleanup_update_map)
260+
set_unlock_update_map_fields(cleanup_update_map)
251261
async with get_session_ctx() as session:
252-
await session.execute(
262+
now = get_current_datetime()
263+
resolve_now_placeholders(cleanup_update_map, now=now)
264+
res2 = await session.execute(
253265
update(ServiceRouterWorkerSyncModel)
254266
.where(
255267
ServiceRouterWorkerSyncModel.id == item.id,
256268
ServiceRouterWorkerSyncModel.lock_token == item.lock_token,
257269
)
258-
.values(
259-
deleted=True,
260-
lock_expires_at=None,
261-
lock_token=None,
262-
lock_owner=None,
263-
)
270+
.values(**cleanup_update_map)
271+
.returning(ServiceRouterWorkerSyncModel.id)
264272
)
273+
if not list(res2.scalars().all()):
274+
log_lock_token_changed_after_processing(logger, item)
265275
await session.commit()
266276
return
267277

src/dstack/_internal/server/services/gateways/client.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,17 +90,11 @@ async def register_replica(
9090
ssh_head_proxy_private_key: Optional[str],
9191
):
9292
assert run.run_spec.configuration.type == "service"
93-
config = run.run_spec.configuration
94-
router_group = next((g for g in config.replica_groups if g.router is not None), None)
95-
is_router_replica = (
96-
router_group is not None and job_spec.replica_group == router_group.name
97-
)
9893
payload = {
9994
"job_id": job_submission.id.hex,
10095
"app_port": get_service_port(job_spec, run.run_spec.configuration),
10196
"ssh_head_proxy": ssh_head_proxy.dict() if ssh_head_proxy is not None else None,
10297
"ssh_head_proxy_private_key": ssh_head_proxy_private_key,
103-
"is_router_replica": is_router_replica,
10498
}
10599
jpd = job_submission.job_provisioning_data
106100
assert jpd is not None

src/dstack/_internal/server/services/proxy/repo.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
7171
run_spec = get_run_spec(run)
7272
if not isinstance(run_spec.configuration, ServiceConfiguration):
7373
return None
74+
router_group = next(
75+
(g for g in run_spec.configuration.replica_groups if g.router is not None),
76+
None,
77+
)
78+
router = run_spec.configuration.router
7479
replicas = []
7580
for job in jobs:
7681
jpd: JobProvisioningData = JobProvisioningData.__response__.parse_raw(
@@ -109,13 +114,10 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
109114
ssh_head_proxy = rci.ssh_proxy
110115
ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private
111116
job_spec = get_job_spec(job)
112-
router_group = next(
113-
(g for g in run_spec.configuration.replica_groups if g.router is not None),
114-
None,
115-
)
116-
is_router_replica = (
117-
router_group is not None and job_spec.replica_group == router_group.name
118-
)
117+
if router_group is not None and job_spec.replica_group != router_group.name:
118+
# Strict router-only: when a router is configured, the proxy should only be aware
119+
# of router replicas.
120+
continue
119121
replica = Replica(
120122
id=job.id.hex,
121123
app_port=get_service_port(job_spec, run_spec.configuration),
@@ -126,7 +128,6 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
126128
ssh_head_proxy=ssh_head_proxy,
127129
ssh_head_proxy_private_key=ssh_head_proxy_private_key,
128130
internal_ip=jpd.internal_ip,
129-
is_router_replica=is_router_replica,
130131
)
131132
replicas.append(replica)
132133
return Service(
@@ -138,6 +139,7 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
138139
client_max_body_size=DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE,
139140
strip_prefix=run_spec.configuration.strip_prefix,
140141
replicas=tuple(replicas),
142+
router=router,
141143
)
142144

143145
async def list_models(self, project_name: str) -> List[ChatModel]:

src/dstack/_internal/server/services/proxy/services/service_proxy.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from fastapi import status
66
from starlette.requests import ClientDisconnect
77

8+
from dstack._internal.core.models.routers import RouterType
89
from dstack._internal.proxy.lib.deps import ProxyAuthContext
910
from dstack._internal.proxy.lib.errors import ProxyError
1011
from dstack._internal.proxy.lib.repo import BaseProxyRepo
@@ -36,11 +37,16 @@ async def proxy(
3637
if service.auth:
3738
await auth.enforce()
3839

39-
client = await get_service_replica_client(service, repo, service_conn_pool)
40-
4140
if not service.strip_prefix:
4241
path = concat_url_path(request.scope.get("root_path", "/"), request.url.path)
4342

43+
if service.router is not None and service.router.type == RouterType.SGLANG:
44+
path_for_match = path if path.startswith("/") else f"/{path}"
45+
if not _is_whitelisted_path(path_for_match, _SGLANG_WHITELISTED_PATHS):
46+
raise ProxyError("Path is not allowed for this service", status.HTTP_404_NOT_FOUND)
47+
48+
client = await get_service_replica_client(service, repo, service_conn_pool)
49+
4450
try:
4551
upstream_request = await build_upstream_request(request, path, client)
4652
except ClientDisconnect:
@@ -68,6 +74,23 @@ async def proxy(
6874
)
6975

7076

77+
_SGLANG_WHITELISTED_PATHS = (
78+
"/generate",
79+
"/v1/",
80+
"/chat/completions",
81+
)
82+
83+
84+
def _is_whitelisted_path(path: str, whitelisted_paths: tuple[str, ...]) -> bool:
85+
for allowed in whitelisted_paths:
86+
if allowed.endswith("/"):
87+
if path.startswith(allowed):
88+
return True
89+
elif path == allowed:
90+
return True
91+
return False
92+
93+
7194
async def stream_response(response: httpx.Response) -> AsyncGenerator[bytes, None]:
7295
try:
7396
async for chunk in response.aiter_raw():

src/dstack/_internal/server/services/runs/router_worker_sync.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -176,20 +176,10 @@ async def _add_worker_to_router(
176176
return False
177177

178178

179-
async def _remove_worker_from_router(client: AsyncClient, worker_url: str) -> bool:
179+
async def _remove_worker_from_router_by_id(
180+
client: AsyncClient, worker_id: str, *, worker_url: str
181+
) -> bool:
180182
try:
181-
current = await _get_router_workers(client)
182-
worker_id = None
183-
for w in current:
184-
u = w.get("url")
185-
if u and isinstance(u, str) and u.rstrip("/") == worker_url.rstrip("/"):
186-
wid = w.get("id")
187-
if wid and isinstance(wid, str):
188-
worker_id = wid
189-
break
190-
if not worker_id:
191-
logger.error("No worker id found for url %s", worker_url)
192-
return False
193183
body = await _request_json_limited(
194184
client,
195185
"DELETE",
@@ -212,10 +202,16 @@ async def _update_workers_in_router_replica(
212202
) -> None:
213203
current = await _get_router_workers(client)
214204
current_urls: set[str] = set()
205+
current_ids_by_norm_url: dict[str, str] = {}
215206
for w in current:
216207
u = w.get("url")
217-
if isinstance(u, str) and u:
218-
current_urls.add(_normalize_worker_url(u))
208+
if not isinstance(u, str) or not u:
209+
continue
210+
norm_u = _normalize_worker_url(u)
211+
current_urls.add(norm_u)
212+
wid = w.get("id")
213+
if isinstance(wid, str) and wid:
214+
current_ids_by_norm_url[norm_u] = wid
219215
target_by_norm = {_normalize_worker_url(t["url"]): t for t in target_workers}
220216
target_urls = set(target_by_norm.keys())
221217
to_add = sorted(target_urls - current_urls)
@@ -231,7 +227,12 @@ async def _update_workers_in_router_replica(
231227
if not ok:
232228
logger.warning("Failed to add worker %s, continuing with others", tw["url"])
233229
for url in to_remove:
234-
ok = await _remove_worker_from_router(client, url)
230+
wid = current_ids_by_norm_url.get(url)
231+
if not wid:
232+
logger.error("No worker id found for url %s", url)
233+
ok = False
234+
else:
235+
ok = await _remove_worker_from_router_by_id(client, wid, worker_url=url)
235236
if not ok:
236237
logger.warning("Failed to remove worker %s, continuing with others", url)
237238

@@ -270,9 +271,9 @@ async def _get_worker_payload(job_model: JobModel, worker_url: str) -> _WorkerPa
270271
"payload": {"url": worker_url, "worker_type": "regular"},
271272
}
272273
except _ResponseTooLargeError:
273-
logger.debug("server_info response too large for worker %s", worker_url)
274+
logger.warning("server_info response too large for worker %s", worker_url)
274275
except Exception as e:
275-
logger.debug("Could not fetch server_info for worker %s: %r", worker_url, e)
276+
logger.exception("Could not fetch server_info for worker %s: %r", worker_url, e)
276277
return {"status": "not_ready", "payload": None}
277278

278279

0 commit comments

Comments (0)