Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,12 @@ class ReplicaGroup(CoreModel):
CommandsList,
Field(description="The shell commands to run for replicas in this group"),
] = []
router: Annotated[
Optional[AnyServiceRouterConfig],
Field(
description="When set, replicas in this group run the in-service HTTP router (e.g. SGLang).",
),
] = None

@validator("name")
def validate_name(cls, v: Optional[str]) -> Optional[str]:
Expand Down Expand Up @@ -1032,6 +1038,16 @@ def validate_replica_groups_have_commands_or_image(cls, values):

return values

@root_validator()
def validate_at_most_one_router_replica_group(cls, values):
    """Reject configurations where more than one replica group sets `router`."""
    groups = values.get("replicas")
    # Skip when replicas is absent or not a list (e.g. earlier validation failed).
    if not isinstance(groups, list):
        return values
    with_router = sum(1 for group in groups if group.router is not None)
    if with_router > 1:
        raise ValueError("At most one replica group may specify `router`.")
    return values


class ServiceConfigurationConfig(
ProfileParamsConfig,
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/gateway/routers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ async def register_replica(
ssh_head_proxy=body.ssh_head_proxy,
ssh_head_proxy_private_key=body.ssh_head_proxy_private_key,
internal_ip=body.internal_ip,
is_router_replica=body.is_router_replica,
repo=repo,
nginx=nginx,
service_conn_pool=service_conn_pool,
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/gateway/schemas/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class RegisterReplicaRequest(BaseModel):
ssh_head_proxy: Optional[SSHConnectionParams]
ssh_head_proxy_private_key: Optional[str]
internal_ip: Optional[str] = None
is_router_replica: bool = False


class RegisterEntrypointRequest(BaseModel):
Expand Down
9 changes: 9 additions & 0 deletions src/dstack/_internal/proxy/gateway/services/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ async def register_replica(
nginx: Nginx,
service_conn_pool: ServiceConnectionPool,
internal_ip: Optional[str] = None,
is_router_replica: bool = False,
) -> None:
replica = models.Replica(
id=replica_id,
Expand All @@ -152,6 +153,7 @@ async def register_replica(
ssh_head_proxy=ssh_head_proxy,
ssh_head_proxy_private_key=ssh_head_proxy_private_key,
internal_ip=internal_ip,
is_router_replica=is_router_replica,
)

async with lock:
Expand Down Expand Up @@ -291,6 +293,13 @@ async def apply_service(
)
for replica, conn in replica_conns.items()
]
router_replicas = [r for r in service.replicas if r.is_router_replica]
if router_replicas:
replica_configs_for_nginx = [c for c in replica_configs if c.id == router_replicas[0].id]
service_config = await get_nginx_service_config(service, replica_configs_for_nginx)
await nginx.register(service_config, (await repo.get_config()).acme_settings)
return replica_failures

service_config = await get_nginx_service_config(service, replica_configs)
await nginx.register(service_config, (await repo.get_config()).acme_settings)
return replica_failures
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/lib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class Replica(ImmutableModel):
ssh_head_proxy: Optional[SSHConnectionParams] = None
ssh_head_proxy_private_key: Optional[str] = None
internal_ip: Optional[str] = None
is_router_replica: bool = False


class IPAddressPartitioningKey(ImmutableModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ async def get_service_replica_client(
timeout=HTTP_TIMEOUT,
)
# Nginx not available, forward directly to the tunnel
replica = random.choice(service.replicas)
router_replicas = [r for r in service.replicas if r.is_router_replica]
replicas_to_use = router_replicas if router_replicas else service.replicas
replica = random.choice(replicas_to_use)
connection = await service_conn_pool.get(replica.id)
if connection is None:
project = await repo.get_project(service.project_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
PlacementGroupPipeline,
)
from dstack._internal.server.background.pipeline_tasks.runs import RunPipeline
from dstack._internal.server.background.pipeline_tasks.service_router_worker_sync import (
ServiceRouterWorkerSyncPipeline,
)
from dstack._internal.server.background.pipeline_tasks.volumes import VolumePipeline
from dstack._internal.utils.logging import get_logger

Expand All @@ -36,6 +39,7 @@ def __init__(self) -> None:
InstancePipeline(pipeline_hinter=self._hinter),
PlacementGroupPipeline(pipeline_hinter=self._hinter),
RunPipeline(pipeline_hinter=self._hinter),
ServiceRouterWorkerSyncPipeline(pipeline_hinter=self._hinter),
VolumePipeline(pipeline_hinter=self._hinter),
]:
self.register_pipeline(builtin_pipeline)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
import asyncio
import uuid
from dataclasses import dataclass
from datetime import timedelta
from typing import Sequence

from sqlalchemy import delete, or_, select, update
from sqlalchemy.orm import load_only, selectinload

from dstack._internal.core.models.runs import RunStatus
from dstack._internal.server.background.pipeline_tasks.base import (
Fetcher,
Heartbeater,
ItemUpdateMap,
Pipeline,
PipelineItem,
Worker,
log_lock_token_changed_after_processing,
log_lock_token_mismatch,
resolve_now_placeholders,
set_processed_update_map_fields,
set_unlock_update_map_fields,
)
from dstack._internal.server.db import get_db, get_session_ctx
from dstack._internal.server.models import (
InstanceModel,
JobModel,
RunModel,
ServiceRouterWorkerSyncModel,
)
from dstack._internal.server.services.locking import get_locker
from dstack._internal.server.services.pipelines import PipelineHinterProtocol
from dstack._internal.server.services.router_worker_sync import (
run_model_has_router_replica_group,
sync_router_workers_for_run_model,
)
from dstack._internal.server.utils import sentry_utils
from dstack._internal.utils.common import get_current_datetime
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)


@dataclass
class ServiceRouterWorkerSyncPipelineItem(PipelineItem):
    # Queue item for the router-worker sync pipeline. Extends the base
    # PipelineItem (id/lock fields) with the run the sync row belongs to.
    run_id: uuid.UUID


class ServiceRouterWorkerSyncPipeline(Pipeline[ServiceRouterWorkerSyncPipelineItem]):
    """Background pipeline that syncs service router workers for running runs.

    Standard fetcher/heartbeater/worker wiring over
    ``ServiceRouterWorkerSyncModel`` rows. Tuning parameters are forwarded
    unchanged to the ``Pipeline`` base class; see
    ``pipeline_tasks.base`` for their exact semantics.
    """

    def __init__(
        self,
        workers_num: int = 8,
        queue_lower_limit_factor: float = 0.5,
        queue_upper_limit_factor: float = 2.0,
        min_processing_interval: timedelta = timedelta(seconds=5),
        lock_timeout: timedelta = timedelta(seconds=25),
        heartbeat_trigger: timedelta = timedelta(seconds=10),
        *,
        pipeline_hinter: PipelineHinterProtocol,
    ) -> None:
        super().__init__(
            workers_num=workers_num,
            queue_lower_limit_factor=queue_lower_limit_factor,
            queue_upper_limit_factor=queue_upper_limit_factor,
            min_processing_interval=min_processing_interval,
            lock_timeout=lock_timeout,
            heartbeat_trigger=heartbeat_trigger,
        )
        # One heartbeater shared by the fetcher and all workers
        # (presumably it extends row locks while items are queued/processed —
        # see Heartbeater in pipeline_tasks.base).
        self.__heartbeater = Heartbeater[ServiceRouterWorkerSyncPipelineItem](
            model_type=ServiceRouterWorkerSyncModel,
            lock_timeout=self._lock_timeout,
            heartbeat_trigger=self._heartbeat_trigger,
        )
        self.__fetcher = ServiceRouterWorkerSyncFetcher(
            queue=self._queue,
            queue_desired_minsize=self._queue_desired_minsize,
            min_processing_interval=self._min_processing_interval,
            lock_timeout=self._lock_timeout,
            heartbeater=self.__heartbeater,
        )
        # All workers consume from the same queue concurrently.
        self.__workers = [
            ServiceRouterWorkerSyncWorker(
                queue=self._queue,
                heartbeater=self.__heartbeater,
                pipeline_hinter=pipeline_hinter,
            )
            for _ in range(self._workers_num)
        ]

    @property
    def hint_fetch_model_name(self) -> str:
        # Model name the pipeline hinter keys on for this pipeline
        # (NOTE(review): exact hinting semantics live in PipelineHinterProtocol).
        return ServiceRouterWorkerSyncModel.__name__

    @property
    def _heartbeater(self) -> Heartbeater[ServiceRouterWorkerSyncPipelineItem]:
        return self.__heartbeater

    @property
    def _fetcher(self) -> Fetcher[ServiceRouterWorkerSyncPipelineItem]:
        return self.__fetcher

    @property
    def _workers(self) -> Sequence["ServiceRouterWorkerSyncWorker"]:
        return self.__workers


class ServiceRouterWorkerSyncFetcher(Fetcher[ServiceRouterWorkerSyncPipelineItem]):
    """Fetches due, unlocked sync rows and locks them for this pipeline."""

    @sentry_utils.instrument_named_task("pipeline_tasks.ServiceRouterWorkerSyncFetcher.fetch")
    async def fetch(self, limit: int) -> list[ServiceRouterWorkerSyncPipelineItem]:
        """Return up to ``limit`` locked items, oldest-processed first.

        Candidate rows must belong to a RUNNING run and be both *due*
        (last processed longer than ``min_processing_interval`` ago, or
        never processed — detected via last_processed_at == created_at)
        and *unlocked* (no lock, or lock already expired).
        """
        # Process-level lock around the fetch; the DB-level FOR UPDATE
        # SKIP LOCKED below additionally guards against other server replicas.
        sync_lock, _ = get_locker(get_db().dialect_name).get_lockset(
            ServiceRouterWorkerSyncModel.__tablename__
        )
        async with sync_lock:
            async with get_session_ctx() as session:
                now = get_current_datetime()
                res = await session.execute(
                    select(ServiceRouterWorkerSyncModel)
                    .join(RunModel, RunModel.id == ServiceRouterWorkerSyncModel.run_id)
                    .where(
                        RunModel.status == RunStatus.RUNNING,
                        # Due: interval elapsed, or never processed yet.
                        or_(
                            ServiceRouterWorkerSyncModel.last_processed_at
                            <= now - self._min_processing_interval,
                            ServiceRouterWorkerSyncModel.last_processed_at
                            == ServiceRouterWorkerSyncModel.created_at,
                        ),
                        # Unlocked: never locked, or previous lock expired.
                        or_(
                            ServiceRouterWorkerSyncModel.lock_expires_at.is_(None),
                            ServiceRouterWorkerSyncModel.lock_expires_at < now,
                        ),
                    )
                    .order_by(ServiceRouterWorkerSyncModel.last_processed_at.asc())
                    .limit(limit)
                    # Lock only the sync rows (not the joined RunModel rows);
                    # skip rows locked by concurrent transactions.
                    .with_for_update(
                        skip_locked=True, key_share=True, of=ServiceRouterWorkerSyncModel
                    )
                    .options(
                        load_only(
                            ServiceRouterWorkerSyncModel.id,
                            ServiceRouterWorkerSyncModel.run_id,
                            ServiceRouterWorkerSyncModel.lock_token,
                            ServiceRouterWorkerSyncModel.lock_expires_at,
                        )
                    )
                )
                rows = list(res.scalars().all())
                # One shared token/expiry for the whole batch.
                lock_expires_at = get_current_datetime() + self._lock_timeout
                lock_token = uuid.uuid4()
                items: list[ServiceRouterWorkerSyncPipelineItem] = []
                for row in rows:
                    # The WHERE clause only admits NULL or already-expired
                    # locks, so a non-NULL value here means a previous lock
                    # expired without being released.
                    prev_lock_expired = row.lock_expires_at is not None
                    row.lock_expires_at = lock_expires_at
                    row.lock_token = lock_token
                    row.lock_owner = ServiceRouterWorkerSyncPipeline.__name__
                    items.append(
                        ServiceRouterWorkerSyncPipelineItem(
                            __tablename__=ServiceRouterWorkerSyncModel.__tablename__,
                            id=row.id,
                            lock_expires_at=lock_expires_at,
                            lock_token=lock_token,
                            prev_lock_expired=prev_lock_expired,
                            run_id=row.run_id,
                        )
                    )
                # Persist the new locks before handing items to workers.
                await session.commit()
                return items


class _SyncRowUpdateMap(ItemUpdateMap, total=False):
    # Typed alias for this pipeline's row-update map. Adds no fields of its
    # own: the worker only writes the base processed/unlock fields via the
    # shared set_processed_update_map_fields / set_unlock_update_map_fields
    # helpers.
    pass


class ServiceRouterWorkerSyncWorker(Worker[ServiceRouterWorkerSyncPipelineItem]):
    """Processes locked sync rows: re-validates the run, syncs its router
    workers, then marks the row processed and releases the lock."""

    def __init__(
        self,
        queue: asyncio.Queue[ServiceRouterWorkerSyncPipelineItem],
        heartbeater: Heartbeater[ServiceRouterWorkerSyncPipelineItem],
        pipeline_hinter: PipelineHinterProtocol,
    ) -> None:
        super().__init__(
            queue=queue,
            heartbeater=heartbeater,
            pipeline_hinter=pipeline_hinter,
        )

    @sentry_utils.instrument_named_task("pipeline_tasks.ServiceRouterWorkerSyncWorker.process")
    async def process(self, item: ServiceRouterWorkerSyncPipelineItem) -> None:
        """Process one locked sync row.

        Deletes the row when the run is gone, no longer RUNNING, or no
        longer has a router replica group; otherwise reloads the run with
        its jobs/instances and calls ``sync_router_workers_for_run_model``.
        Finally stamps ``last_processed_at`` and releases the lock — all
        writes are guarded by ``item.lock_token`` so a row re-locked by
        another fetcher (after our lock expired) is never touched.
        """
        async with get_session_ctx() as session:
            # Re-read under our lock token; None means the token changed
            # (lock expired and was re-acquired elsewhere) or the row is gone.
            res = await session.execute(
                select(ServiceRouterWorkerSyncModel)
                .where(
                    ServiceRouterWorkerSyncModel.id == item.id,
                    ServiceRouterWorkerSyncModel.lock_token == item.lock_token,
                )
                .options(selectinload(ServiceRouterWorkerSyncModel.run))
            )
            sync_row = res.unique().scalar_one_or_none()
            if sync_row is None:
                log_lock_token_mismatch(logger, item)
                return
            run_model = sync_row.run
            if run_model is None:
                # Orphaned sync row — the run was removed.
                await session.delete(sync_row)
                await session.commit()
                return
            # NOTE: status.is_finished() is subsumed by status != RUNNING,
            # kept for explicitness.
            if (
                run_model.deleted
                or run_model.status.is_finished()
                or run_model.status != RunStatus.RUNNING
                or not run_model_has_router_replica_group(run_model)
            ):
                # The run no longer needs router-worker syncing.
                await session.delete(sync_row)
                await session.commit()
                return

        # Reload the run with everything the sync helper needs
        # (projects, jobs, and each job's instance).
        async with get_session_ctx() as session:
            res = await session.execute(
                select(RunModel)
                .where(RunModel.id == item.run_id)
                .options(
                    selectinload(RunModel.project),
                    selectinload(RunModel.jobs).selectinload(JobModel.project),
                    selectinload(RunModel.jobs)
                    .selectinload(JobModel.instance)
                    .selectinload(InstanceModel.project),
                )
            )
            run_for_sync = res.unique().scalar_one_or_none()

        if run_for_sync is None:
            # Run disappeared between the two reads — drop the row
            # (still guarded by our lock token).
            async with get_session_ctx() as session:
                await session.execute(
                    delete(ServiceRouterWorkerSyncModel).where(
                        ServiceRouterWorkerSyncModel.id == item.id,
                        ServiceRouterWorkerSyncModel.lock_token == item.lock_token,
                    )
                )
                await session.commit()
            return

        await sync_router_workers_for_run_model(run_for_sync)

        # Mark processed and release the lock in one guarded UPDATE.
        update_map: _SyncRowUpdateMap = {}
        set_processed_update_map_fields(update_map)
        set_unlock_update_map_fields(update_map)
        async with get_session_ctx() as session:
            now = get_current_datetime()
            resolve_now_placeholders(update_map, now=now)
            res2 = await session.execute(
                update(ServiceRouterWorkerSyncModel)
                .where(
                    ServiceRouterWorkerSyncModel.id == item.id,
                    ServiceRouterWorkerSyncModel.lock_token == item.lock_token,
                )
                .values(**update_map)
                .returning(ServiceRouterWorkerSyncModel.id)
            )
            updated_ids = list(res2.scalars().all())
            # FIX: every other session block in this module commits
            # explicitly; this one previously did not, so the unlock /
            # last_processed_at update could be discarded when the session
            # closed, leaving the row locked until lock_expires_at.
            await session.commit()
            if not updated_ids:
                log_lock_token_changed_after_processing(logger, item)
Loading
Loading