From bafd2d9abcaaf82206508d5be62b8158f5228d81 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Wed, 1 Apr 2026 13:01:14 +0545 Subject: [PATCH 1/3] Resolve Merge Conflict --- .../_internal/core/models/configurations.py | 16 + .../proxy/gateway/routers/registry.py | 1 + .../proxy/gateway/schemas/registry.py | 1 + .../proxy/gateway/services/registry.py | 9 + src/dstack/_internal/proxy/lib/models.py | 1 + .../proxy/lib/services/service_connection.py | 4 +- .../background/pipeline_tasks/__init__.py | 4 + .../service_router_worker_sync.py | 259 +++++++++++++ .../background/scheduled_tasks/probes.py | 44 +-- ...a91b2c3d_add_service_router_worker_sync.py | 66 ++++ src/dstack/_internal/server/models.py | 29 ++ .../server/services/gateways/client.py | 6 + .../services/job_replica_http_client.py | 49 +++ .../_internal/server/services/proxy/repo.py | 8 + .../server/services/router_worker_sync.py | 345 ++++++++++++++++++ .../server/services/runs/__init__.py | 37 ++ 16 files changed, 839 insertions(+), 40 deletions(-) create mode 100644 src/dstack/_internal/server/background/pipeline_tasks/service_router_worker_sync.py create mode 100644 src/dstack/_internal/server/migrations/versions/2026/03_29_1200_e7f4a91b2c3d_add_service_router_worker_sync.py create mode 100644 src/dstack/_internal/server/services/job_replica_http_client.py create mode 100644 src/dstack/_internal/server/services/router_worker_sync.py diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 3d2d30683..26fa8420d 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -801,6 +801,12 @@ class ReplicaGroup(CoreModel): CommandsList, Field(description="The shell commands to run for replicas in this group"), ] = [] + router: Annotated[ + Optional[AnyServiceRouterConfig], + Field( + description="When set, replicas in this group run the in-service HTTP router (e.g. SGLang).", + ), + ] = None @validator("name") def validate_name(cls, v: Optional[str]) -> Optional[str]: @@ -1032,6 +1038,16 @@ def validate_replica_groups_have_commands_or_image(cls, values): return values + @root_validator() + def validate_at_most_one_router_replica_group(cls, values): + replicas = values.get("replicas") + if not isinstance(replicas, list): + return values + router_groups = [g for g in replicas if g.router is not None] + if len(router_groups) > 1: + raise ValueError("At most one replica group may specify `router`.") + return values + class ServiceConfigurationConfig( ProfileParamsConfig, diff --git a/src/dstack/_internal/proxy/gateway/routers/registry.py b/src/dstack/_internal/proxy/gateway/routers/registry.py index 61283e908..f0e24f0b2 100644 --- a/src/dstack/_internal/proxy/gateway/routers/registry.py +++ b/src/dstack/_internal/proxy/gateway/routers/registry.py @@ -82,6 +82,7 @@ async def register_replica( ssh_head_proxy=body.ssh_head_proxy, ssh_head_proxy_private_key=body.ssh_head_proxy_private_key, internal_ip=body.internal_ip, + is_router_replica=body.is_router_replica, repo=repo, nginx=nginx, service_conn_pool=service_conn_pool, diff --git a/src/dstack/_internal/proxy/gateway/schemas/registry.py b/src/dstack/_internal/proxy/gateway/schemas/registry.py index 33001cf25..e7f9df81f 100644 --- a/src/dstack/_internal/proxy/gateway/schemas/registry.py +++ b/src/dstack/_internal/proxy/gateway/schemas/registry.py @@ -58,6 +58,7 @@ class RegisterReplicaRequest(BaseModel): ssh_head_proxy: Optional[SSHConnectionParams] ssh_head_proxy_private_key: Optional[str] internal_ip: Optional[str] = None + is_router_replica: bool = False class RegisterEntrypointRequest(BaseModel): diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index 919c05c0f..f09071259 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -141,6 +141,7 @@ async def register_replica( nginx: Nginx, service_conn_pool: ServiceConnectionPool, internal_ip: Optional[str] = None, + is_router_replica: bool = False, ) -> None: replica = models.Replica( id=replica_id, @@ -152,6 +153,7 @@ async def register_replica( ssh_head_proxy=ssh_head_proxy, ssh_head_proxy_private_key=ssh_head_proxy_private_key, internal_ip=internal_ip, + is_router_replica=is_router_replica, ) async with lock: @@ -291,6 +293,13 @@ async def apply_service( ) for replica, conn in replica_conns.items() ] + router_replicas = [r for r in service.replicas if r.is_router_replica] + if router_replicas: + replica_configs_for_nginx = [c for c in replica_configs if c.id == router_replicas[0].id] + service_config = await get_nginx_service_config(service, replica_configs_for_nginx) + await nginx.register(service_config, (await repo.get_config()).acme_settings) + return replica_failures + service_config = await get_nginx_service_config(service, replica_configs) await nginx.register(service_config, (await repo.get_config()).acme_settings) return replica_failures diff --git a/src/dstack/_internal/proxy/lib/models.py b/src/dstack/_internal/proxy/lib/models.py index d15e4b7ef..da10461a5 100644 --- a/src/dstack/_internal/proxy/lib/models.py +++ b/src/dstack/_internal/proxy/lib/models.py @@ -30,6 +30,7 @@ class Replica(ImmutableModel): ssh_head_proxy: Optional[SSHConnectionParams] = None ssh_head_proxy_private_key: Optional[str] = None internal_ip: Optional[str] = None + is_router_replica: bool = False class IPAddressPartitioningKey(ImmutableModel): diff --git a/src/dstack/_internal/proxy/lib/services/service_connection.py b/src/dstack/_internal/proxy/lib/services/service_connection.py index 37bdc5083..70bca0ee4 100644 --- a/src/dstack/_internal/proxy/lib/services/service_connection.py +++ b/src/dstack/_internal/proxy/lib/services/service_connection.py @@ -151,7 +151,9 @@ async def get_service_replica_client( timeout=HTTP_TIMEOUT, ) # Nginx not available, forward directly to the tunnel - replica = random.choice(service.replicas) + router_replicas = [r for r in service.replicas if r.is_router_replica] + replicas_to_use = router_replicas if router_replicas else service.replicas + replica = random.choice(replicas_to_use) connection = await service_conn_pool.get(replica.id) if connection is None: project = await repo.get_project(service.project_name) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py index ca12c95ad..3bb31fe5b 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py @@ -16,6 +16,9 @@ PlacementGroupPipeline, ) from dstack._internal.server.background.pipeline_tasks.runs import RunPipeline +from dstack._internal.server.background.pipeline_tasks.service_router_worker_sync import ( + ServiceRouterWorkerSyncPipeline, +) from dstack._internal.server.background.pipeline_tasks.volumes import VolumePipeline from dstack._internal.utils.logging import get_logger @@ -36,6 +39,7 @@ def __init__(self) -> None: InstancePipeline(pipeline_hinter=self._hinter), PlacementGroupPipeline(pipeline_hinter=self._hinter), RunPipeline(pipeline_hinter=self._hinter), + ServiceRouterWorkerSyncPipeline(pipeline_hinter=self._hinter), VolumePipeline(pipeline_hinter=self._hinter), ]: self.register_pipeline(builtin_pipeline) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/service_router_worker_sync.py b/src/dstack/_internal/server/background/pipeline_tasks/service_router_worker_sync.py new file mode 100644 index 000000000..11a55ac84 --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/service_router_worker_sync.py @@ -0,0 +1,259 @@ +import asyncio +import uuid +from dataclasses import dataclass +from datetime import timedelta +from typing import Sequence + +from sqlalchemy import delete, or_, select, update +from sqlalchemy.orm import load_only, selectinload + +from dstack._internal.core.models.runs import RunStatus +from dstack._internal.server.background.pipeline_tasks.base import ( + Fetcher, + Heartbeater, + ItemUpdateMap, + Pipeline, + PipelineItem, + Worker, + log_lock_token_changed_after_processing, + log_lock_token_mismatch, + resolve_now_placeholders, + set_processed_update_map_fields, + set_unlock_update_map_fields, +) +from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.models import ( + InstanceModel, + JobModel, + RunModel, + ServiceRouterWorkerSyncModel, +) +from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.services.pipelines import PipelineHinterProtocol +from dstack._internal.server.services.router_worker_sync import ( + run_model_has_router_replica_group, + sync_router_workers_for_run_model, +) +from dstack._internal.server.utils import sentry_utils +from dstack._internal.utils.common import get_current_datetime +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +@dataclass +class ServiceRouterWorkerSyncPipelineItem(PipelineItem): + run_id: uuid.UUID + + +class ServiceRouterWorkerSyncPipeline(Pipeline[ServiceRouterWorkerSyncPipelineItem]): + def __init__( + self, + workers_num: int = 8, + queue_lower_limit_factor: float = 0.5, + queue_upper_limit_factor: float = 2.0, + min_processing_interval: timedelta = timedelta(seconds=5), + lock_timeout: timedelta = timedelta(seconds=25), + heartbeat_trigger: timedelta = timedelta(seconds=10), + *, + pipeline_hinter: PipelineHinterProtocol, + ) -> None: + super().__init__( + workers_num=workers_num, + queue_lower_limit_factor=queue_lower_limit_factor, + queue_upper_limit_factor=queue_upper_limit_factor, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeat_trigger=heartbeat_trigger, + ) + self.__heartbeater = Heartbeater[ServiceRouterWorkerSyncPipelineItem]( + model_type=ServiceRouterWorkerSyncModel, + lock_timeout=self._lock_timeout, + heartbeat_trigger=self._heartbeat_trigger, + ) + self.__fetcher = ServiceRouterWorkerSyncFetcher( + queue=self._queue, + queue_desired_minsize=self._queue_desired_minsize, + min_processing_interval=self._min_processing_interval, + lock_timeout=self._lock_timeout, + heartbeater=self.__heartbeater, + ) + self.__workers = [ + ServiceRouterWorkerSyncWorker( + queue=self._queue, + heartbeater=self.__heartbeater, + pipeline_hinter=pipeline_hinter, + ) + for _ in range(self._workers_num) + ] + + @property + def hint_fetch_model_name(self) -> str: + return ServiceRouterWorkerSyncModel.__name__ + + @property + def _heartbeater(self) -> Heartbeater[ServiceRouterWorkerSyncPipelineItem]: + return self.__heartbeater + + @property + def _fetcher(self) -> Fetcher[ServiceRouterWorkerSyncPipelineItem]: + return self.__fetcher + + @property + def _workers(self) -> Sequence["ServiceRouterWorkerSyncWorker"]: + return self.__workers + + +class ServiceRouterWorkerSyncFetcher(Fetcher[ServiceRouterWorkerSyncPipelineItem]): + @sentry_utils.instrument_named_task("pipeline_tasks.ServiceRouterWorkerSyncFetcher.fetch") + async def fetch(self, limit: int) -> list[ServiceRouterWorkerSyncPipelineItem]: + sync_lock, _ = get_locker(get_db().dialect_name).get_lockset( + ServiceRouterWorkerSyncModel.__tablename__ + ) + async with sync_lock: + async with get_session_ctx() as session: + now = get_current_datetime() + res = await session.execute( + select(ServiceRouterWorkerSyncModel) + .join(RunModel, RunModel.id == ServiceRouterWorkerSyncModel.run_id) + .where( + RunModel.status == RunStatus.RUNNING, + or_( + ServiceRouterWorkerSyncModel.last_processed_at + <= now - self._min_processing_interval, + ServiceRouterWorkerSyncModel.last_processed_at + == ServiceRouterWorkerSyncModel.created_at, + ), + or_( + ServiceRouterWorkerSyncModel.lock_expires_at.is_(None), + ServiceRouterWorkerSyncModel.lock_expires_at < now, + ), + ) + .order_by(ServiceRouterWorkerSyncModel.last_processed_at.asc()) + .limit(limit) + .with_for_update( + skip_locked=True, key_share=True, of=ServiceRouterWorkerSyncModel + ) + .options( + load_only( + ServiceRouterWorkerSyncModel.id, + ServiceRouterWorkerSyncModel.run_id, + ServiceRouterWorkerSyncModel.lock_token, + ServiceRouterWorkerSyncModel.lock_expires_at, + ) + ) + ) + rows = list(res.scalars().all()) + lock_expires_at = get_current_datetime() + self._lock_timeout + lock_token = uuid.uuid4() + items: list[ServiceRouterWorkerSyncPipelineItem] = [] + for row in rows: + prev_lock_expired = row.lock_expires_at is not None + row.lock_expires_at = lock_expires_at + row.lock_token = lock_token + row.lock_owner = ServiceRouterWorkerSyncPipeline.__name__ + items.append( + ServiceRouterWorkerSyncPipelineItem( + __tablename__=ServiceRouterWorkerSyncModel.__tablename__, + id=row.id, + lock_expires_at=lock_expires_at, + lock_token=lock_token, + prev_lock_expired=prev_lock_expired, + run_id=row.run_id, + ) + ) + await session.commit() + return items + + +class _SyncRowUpdateMap(ItemUpdateMap, total=False): + pass + + +class ServiceRouterWorkerSyncWorker(Worker[ServiceRouterWorkerSyncPipelineItem]): + def __init__( + self, + queue: asyncio.Queue[ServiceRouterWorkerSyncPipelineItem], + heartbeater: Heartbeater[ServiceRouterWorkerSyncPipelineItem], + pipeline_hinter: PipelineHinterProtocol, + ) -> None: + super().__init__( + queue=queue, + heartbeater=heartbeater, + pipeline_hinter=pipeline_hinter, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.ServiceRouterWorkerSyncWorker.process") + async def process(self, item: ServiceRouterWorkerSyncPipelineItem) -> None: + async with get_session_ctx() as session: + res = await session.execute( + select(ServiceRouterWorkerSyncModel) + .where( + ServiceRouterWorkerSyncModel.id == item.id, + ServiceRouterWorkerSyncModel.lock_token == item.lock_token, + ) + .options(selectinload(ServiceRouterWorkerSyncModel.run)) + ) + sync_row = res.unique().scalar_one_or_none() + if sync_row is None: + log_lock_token_mismatch(logger, item) + return + run_model = sync_row.run + if run_model is None: + await session.delete(sync_row) + await session.commit() + return + if ( + run_model.deleted + or run_model.status.is_finished() + or run_model.status != RunStatus.RUNNING + or not run_model_has_router_replica_group(run_model) + ): + await session.delete(sync_row) + await session.commit() + return + + async with get_session_ctx() as session: + res = await session.execute( + select(RunModel) + .where(RunModel.id == item.run_id) + .options( + selectinload(RunModel.project), + selectinload(RunModel.jobs).selectinload(JobModel.project), + selectinload(RunModel.jobs) + .selectinload(JobModel.instance) + .selectinload(InstanceModel.project), + ) + ) + run_for_sync = res.unique().scalar_one_or_none() + + if run_for_sync is None: + async with get_session_ctx() as session: + await session.execute( + delete(ServiceRouterWorkerSyncModel).where( + ServiceRouterWorkerSyncModel.id == item.id, + ServiceRouterWorkerSyncModel.lock_token == item.lock_token, + ) + ) + await session.commit() + return + + await sync_router_workers_for_run_model(run_for_sync) + + update_map: _SyncRowUpdateMap = {} + set_processed_update_map_fields(update_map) + set_unlock_update_map_fields(update_map) + async with get_session_ctx() as session: + now = get_current_datetime() + resolve_now_placeholders(update_map, now=now) + res2 = await session.execute( + update(ServiceRouterWorkerSyncModel) + .where( + ServiceRouterWorkerSyncModel.id == item.id, + ServiceRouterWorkerSyncModel.lock_token == item.lock_token, + ) + .values(**update_map) + .returning(ServiceRouterWorkerSyncModel.id) + ) + if not list(res2.scalars().all()): + log_lock_token_changed_after_processing(logger, item) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/probes.py b/src/dstack/_internal/server/background/scheduled_tasks/probes.py index ee5b5c9b4..8815e5e27 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/probes.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/probes.py @@ -1,36 +1,27 @@ -from collections.abc import AsyncGenerator -from contextlib import asynccontextmanager from datetime import timedelta from functools import partial -from pathlib import Path -from tempfile import TemporaryDirectory import httpx from apscheduler.schedulers.asyncio import AsyncIOScheduler -from httpx import AsyncClient, AsyncHTTPTransport from sqlalchemy import select, update from sqlalchemy.orm import joinedload from dstack._internal.core.errors import SSHError from dstack._internal.core.models.runs import JobStatus, ProbeSpec -from dstack._internal.core.services.ssh.tunnel import ( - SSH_DEFAULT_OPTIONS, - IPSocket, - SocketPair, - UnixSocket, -) from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import InstanceModel, JobModel, ProbeModel +from dstack._internal.server.services.job_replica_http_client import ( + SSH_CONNECT_TIMEOUT, + _get_service_replica_client, +) from dstack._internal.server.services.jobs import get_job_spec from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.logging import fmt -from dstack._internal.server.services.ssh import container_ssh_tunnel -from dstack._internal.utils.common import get_current_datetime, get_or_error +from dstack._internal.utils.common import get_current_datetime from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) BATCH_SIZE = 100 -SSH_CONNECT_TIMEOUT = timedelta(seconds=10) PROCESSING_OVERHEAD_TIMEOUT = timedelta(minutes=1) PROBES_SCHEDULER = AsyncIOScheduler() @@ -141,28 +132,3 @@ def _get_probe_async_processing_timeout(probe_spec: ProbeSpec) -> timedelta: + SSH_CONNECT_TIMEOUT + PROCESSING_OVERHEAD_TIMEOUT # slow db queries and other unforeseen conditions ) - - -@asynccontextmanager -async def _get_service_replica_client(job: JobModel) -> AsyncGenerator[AsyncClient, None]: - options = { - **SSH_DEFAULT_OPTIONS, - "ConnectTimeout": str(int(SSH_CONNECT_TIMEOUT.total_seconds())), - } - job_spec = get_job_spec(job) - with TemporaryDirectory() as temp_dir: - app_socket_path = (Path(temp_dir) / "replica.sock").absolute() - async with container_ssh_tunnel( - job=job, - forwarded_sockets=[ - SocketPair( - remote=IPSocket("localhost", get_or_error(job_spec.service_port)), - local=UnixSocket(app_socket_path), - ), - ], - options=options, - ): - async with AsyncClient( - transport=AsyncHTTPTransport(uds=str(app_socket_path)) - ) as client: - yield client diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_29_1200_e7f4a91b2c3d_add_service_router_worker_sync.py b/src/dstack/_internal/server/migrations/versions/2026/03_29_1200_e7f4a91b2c3d_add_service_router_worker_sync.py new file mode 100644 index 000000000..61a8a7d99 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_29_1200_e7f4a91b2c3d_add_service_router_worker_sync.py @@ -0,0 +1,66 @@ +"""Add service_router_worker_sync for router-worker reconcile pipeline. + +Revision ID: e7f4a91b2c3d +Revises: e9d81c97c042 +Create Date: 2026-03-29 12:00:00.000000+00:00 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +revision = "e7f4a91b2c3d" +down_revision = "e9d81c97c042" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "service_router_worker_sync", + sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column("run_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column("created_at", dstack._internal.server.models.NaiveDateTime(), nullable=False), + sa.Column( + "last_processed_at", dstack._internal.server.models.NaiveDateTime(), nullable=False + ), + sa.Column( + "lock_expires_at", dstack._internal.server.models.NaiveDateTime(), nullable=True + ), + sa.Column("lock_token", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True), + sa.Column("lock_owner", sa.String(length=100), nullable=True), + sa.ForeignKeyConstraint( + ["run_id"], + ["runs.id"], + name=op.f("fk_service_router_worker_sync_run_id_runs"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_service_router_worker_sync")), + sa.UniqueConstraint("run_id", name=op.f("uq_service_router_worker_sync_run_id")), + ) + op.create_index( + op.f("ix_service_router_worker_sync_pipeline_fetch_q"), + "service_router_worker_sync", + [sa.literal_column("last_processed_at ASC")], + unique=False, + ) + op.create_index( + op.f("ix_service_router_worker_sync_run_id"), + "service_router_worker_sync", + ["run_id"], + unique=True, + ) + + +def downgrade() -> None: + op.drop_index( + op.f("ix_service_router_worker_sync_run_id"), table_name="service_router_worker_sync" + ) + op.drop_index( + op.f("ix_service_router_worker_sync_pipeline_fetch_q"), + table_name="service_router_worker_sync", + ) + op.drop_table("service_router_worker_sync") diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 45ab93096..a75f27a85 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -443,6 +443,10 @@ class RunModel(PipelineModelMixin, BaseModel): ) gateway: Mapped[Optional["GatewayModel"]] = relationship() + service_router_worker_sync: Mapped[Optional["ServiceRouterWorkerSyncModel"]] = relationship( + back_populates="run", uselist=False + ) + __table_args__ = ( Index("ix_submitted_at_id", submitted_at.desc(), id), Index( @@ -454,6 +458,31 @@ class RunModel(PipelineModelMixin, BaseModel): ) +class ServiceRouterWorkerSyncModel(PipelineModelMixin, BaseModel): + """ + Row processed by ServiceRouterWorkerSyncPipeline: sync router /workers with worker replicas. + At most one per run that uses replica-group routers. + """ + + __tablename__ = "service_router_worker_sync" + + id: Mapped[uuid.UUID] = mapped_column( + UUIDType(binary=False), primary_key=True, default=uuid.uuid4 + ) + run_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("runs.id", ondelete="CASCADE"), unique=True, index=True + ) + run: Mapped["RunModel"] = relationship(back_populates="service_router_worker_sync") + created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime) + last_processed_at: Mapped[datetime] = mapped_column( + NaiveDateTime, default=get_current_datetime + ) + + __table_args__ = ( + Index("ix_service_router_worker_sync_pipeline_fetch_q", last_processed_at.asc()), + ) + + class JobModel(PipelineModelMixin, BaseModel): __tablename__ = "jobs" diff --git a/src/dstack/_internal/server/services/gateways/client.py b/src/dstack/_internal/server/services/gateways/client.py index d83891c0b..0f42b2ec6 100644 --- a/src/dstack/_internal/server/services/gateways/client.py +++ b/src/dstack/_internal/server/services/gateways/client.py @@ -90,11 +90,17 @@ async def register_replica( ssh_head_proxy_private_key: Optional[str], ): assert run.run_spec.configuration.type == "service" + config = run.run_spec.configuration + router_group = next((g for g in config.replica_groups if g.router is not None), None) + is_router_replica = ( + router_group is not None and job_spec.replica_group == router_group.name + ) payload = { "job_id": job_submission.id.hex, "app_port": get_service_port(job_spec, run.run_spec.configuration), "ssh_head_proxy": ssh_head_proxy.dict() if ssh_head_proxy is not None else None, "ssh_head_proxy_private_key": ssh_head_proxy_private_key, + "is_router_replica": is_router_replica, } jpd = job_submission.job_provisioning_data assert jpd is not None diff --git a/src/dstack/_internal/server/services/job_replica_http_client.py b/src/dstack/_internal/server/services/job_replica_http_client.py new file mode 100644 index 000000000..02a2f1e69 --- /dev/null +++ b/src/dstack/_internal/server/services/job_replica_http_client.py @@ -0,0 +1,49 @@ +"""SSH-tunneled async HTTP client to a job's service port (same path as probes).""" + +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from datetime import timedelta +from pathlib import Path +from tempfile import TemporaryDirectory + +from httpx import AsyncClient, AsyncHTTPTransport + +from dstack._internal.core.services.ssh.tunnel import ( + SSH_DEFAULT_OPTIONS, + IPSocket, + SocketPair, + UnixSocket, +) +from dstack._internal.server.models import JobModel +from dstack._internal.server.services.jobs import get_job_spec +from dstack._internal.server.services.ssh import container_ssh_tunnel +from dstack._internal.utils.common import get_or_error + +SSH_CONNECT_TIMEOUT = timedelta(seconds=10) + + +@asynccontextmanager +async def _get_service_replica_client( + job: JobModel, +) -> AsyncGenerator[AsyncClient, None]: + options = { + **SSH_DEFAULT_OPTIONS, + "ConnectTimeout": str(int(SSH_CONNECT_TIMEOUT.total_seconds())), + } + job_spec = get_job_spec(job) + with TemporaryDirectory() as temp_dir: + app_socket_path = (Path(temp_dir) / "replica.sock").absolute() + async with container_ssh_tunnel( + job=job, + forwarded_sockets=[ + SocketPair( + remote=IPSocket("localhost", get_or_error(job_spec.service_port)), + local=UnixSocket(app_socket_path), + ), + ], + options=options, + ): + async with AsyncClient( + transport=AsyncHTTPTransport(uds=str(app_socket_path)) + ) as client: + yield client diff --git a/src/dstack/_internal/server/services/proxy/repo.py b/src/dstack/_internal/server/services/proxy/repo.py index a454b74ba..8d7048776 100644 --- a/src/dstack/_internal/server/services/proxy/repo.py +++ b/src/dstack/_internal/server/services/proxy/repo.py @@ -109,6 +109,13 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic ssh_head_proxy = rci.ssh_proxy ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private job_spec = get_job_spec(job) + router_group = next( + (g for g in run_spec.configuration.replica_groups if g.router is not None), + None, + ) + is_router_replica = ( + router_group is not None and job_spec.replica_group == router_group.name + ) replica = Replica( id=job.id.hex, app_port=get_service_port(job_spec, run_spec.configuration), @@ -119,6 +126,7 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic ssh_head_proxy=ssh_head_proxy, ssh_head_proxy_private_key=ssh_head_proxy_private_key, internal_ip=jpd.internal_ip, + is_router_replica=is_router_replica, ) replicas.append(replica) return Service( diff --git a/src/dstack/_internal/server/services/router_worker_sync.py b/src/dstack/_internal/server/services/router_worker_sync.py new file mode 100644 index 000000000..afa24f353 --- /dev/null +++ b/src/dstack/_internal/server/services/router_worker_sync.py @@ -0,0 +1,345 @@ +"""Reconcile SGLang router /workers with dstack's registered worker replicas (async, SSH-tunneled).""" + +import json +from typing import Any, Dict, List, Literal, Optional, TypedDict +from urllib.parse import urlsplit, urlunsplit + +from httpx import AsyncClient, Response +from typing_extensions import NotRequired + +from dstack._internal.core.models.configurations import ServiceConfiguration +from dstack._internal.core.models.runs import JobStatus, RunSpec, get_service_port +from dstack._internal.server.models import JobModel, RunModel +from dstack._internal.server.services.job_replica_http_client import ( + _get_service_replica_client, +) +from dstack._internal.server.services.jobs import get_job_provisioning_data, get_job_spec +from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.runs import run_spec_has_router_replica_group +from dstack._internal.server.services.runs.replicas import ( + is_replica_registered, + job_belongs_to_group, +) +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + +_ROUTER_HTTP = "http://dstack" +_ROUTER_HTTP_TIMEOUT = 10.0 +_MAX_SERVER_INFO_RESPONSE_BYTES = 256 * 1024 +_MAX_WORKERS_RESPONSE_BYTES = 2 * 1024 * 1024 +_MAX_WORKERS_COMMAND_ACK_BYTES = 64 * 1024 +_MAX_WORKERS_LIST_ITEMS = 8192 + + +class _ResponseTooLargeError(Exception): + pass + + +async def _stream_response_body_bytes(resp: Response, max_bytes: int) -> bytes: + buf = bytearray() + async for chunk in resp.aiter_bytes(): + buf.extend(chunk) + if len(buf) > max_bytes: + raise _ResponseTooLargeError() + return bytes(buf) + + +async def _request_json_limited( + client: AsyncClient, + method: str, + url: str, + *, + max_response_bytes: int, + ok_statuses: set[int], + json_body: Optional[dict] = None, + timeout: float = _ROUTER_HTTP_TIMEOUT, +) -> Any: + kwargs: dict[str, Any] = {"timeout": timeout} + if json_body is not None: + kwargs["json"] = json_body + endpoint = f"{method} {url}" + async with client.stream(method, url, **kwargs) as resp: + if resp.status_code not in ok_statuses: + logger.warning( + "router_http unexpected status endpoint=%s status_code=%s expected=%s", + endpoint, + resp.status_code, + sorted(ok_statuses), + ) + return None + cl = resp.headers.get("content-length") + if cl is not None: + try: + if int(cl) > max_response_bytes: + raise _ResponseTooLargeError() + except ValueError: + pass + raw = await _stream_response_body_bytes(resp, max_response_bytes) + try: + return json.loads(raw) + except json.JSONDecodeError: + logger.warning("router_http JSON parse failed endpoint=%s", endpoint) + return None + + +class _WorkerPayloadResult(TypedDict): + status: Literal["ready", "not_ready"] + payload: Optional[Dict[str, Any]] + + +class _TargetWorker(TypedDict): + url: str + worker_type: str + bootstrap_port: NotRequired[Optional[int]] + + +def run_model_has_router_replica_group(run_model: RunModel) -> bool: + run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) + return run_spec_has_router_replica_group(run_spec) + + +def _get_router_job(run_model: RunModel, router_group) -> JobModel | None: + group_name = router_group.name + assert group_name is not None, "Replica group name is set by validation" + router_jobs = [ + j + for j in run_model.jobs + if job_belongs_to_group(j, group_name) and j.status == JobStatus.RUNNING + ] + if not router_jobs or not is_replica_registered(router_jobs): + return None + return router_jobs[0] + + +def _normalize_worker_url(url: str) -> str: + url = url.strip() + parts = urlsplit(url) + path = (parts.path or "").rstrip("/") + return urlunsplit((parts.scheme, parts.netloc, path, parts.query, parts.fragment)) + + +async def _get_router_workers(client: AsyncClient) -> List[dict]: + try: + data = await _request_json_limited( + client, + "GET", + f"{_ROUTER_HTTP}/workers", + max_response_bytes=_MAX_WORKERS_RESPONSE_BYTES, + ok_statuses={200}, + ) + if not isinstance(data, dict): + return [] + workers = data.get("workers", []) + if not isinstance(workers, list): + return [] + if len(workers) > _MAX_WORKERS_LIST_ITEMS: + logger.warning( + "Router /workers list exceeds %s items, truncating", + _MAX_WORKERS_LIST_ITEMS, + ) + workers = workers[:_MAX_WORKERS_LIST_ITEMS] + return [w for w in workers if isinstance(w, dict)] + except _ResponseTooLargeError: + logger.warning("Router /workers response exceeded size limit") + except Exception: + logger.exception("Error getting router /workers") + return [] + + +async def _add_worker_to_router( + client: AsyncClient, + url: str, + worker_type: str = "regular", + bootstrap_port: Optional[int] = None, +) -> bool: + try: + payload: dict = {"url": url, "worker_type": worker_type} + if bootstrap_port is not None: + payload["bootstrap_port"] = bootstrap_port + body = await _request_json_limited( + client, + "POST", + f"{_ROUTER_HTTP}/workers", + max_response_bytes=_MAX_WORKERS_COMMAND_ACK_BYTES, + ok_statuses={202}, + json_body=payload, + ) + return isinstance(body, dict) and body.get("status") == "accepted" + except _ResponseTooLargeError: + logger.warning("Router add-worker response exceeded size limit for %s", url) + return False + except Exception: + logger.exception("Error adding worker %s", url) + return False + + +async def _remove_worker_from_router(client: AsyncClient, worker_url: str) -> bool: + try: + current = await _get_router_workers(client) + worker_id = None + for w in current: + u = w.get("url") + if u and isinstance(u, str) and u.rstrip("/") == worker_url.rstrip("/"): + wid = w.get("id") + if wid and isinstance(wid, str): + worker_id = wid + break + if not worker_id: + logger.error("No worker id found for url %s", worker_url) + return False + body = await _request_json_limited( + client, + "DELETE", + f"{_ROUTER_HTTP}/workers/{worker_id}", + max_response_bytes=_MAX_WORKERS_COMMAND_ACK_BYTES, + ok_statuses={202}, + ) + return isinstance(body, dict) and body.get("status") == "accepted" + except _ResponseTooLargeError: + logger.warning("Router remove-worker response exceeded size limit for %s", worker_url) + return False + except Exception: + logger.exception("Error removing worker %s", worker_url) + return False + + +async def _update_workers_in_router_replica( + client: AsyncClient, + target_workers: List[_TargetWorker], +) -> None: + current = await _get_router_workers(client) + current_urls: set[str] = set() + for w in current: + u = w.get("url") + if isinstance(u, str) and u: + current_urls.add(_normalize_worker_url(u)) + target_by_norm = {_normalize_worker_url(t["url"]): t for t in target_workers} + target_urls = set(target_by_norm.keys()) + to_add = sorted(target_urls - current_urls) + to_remove = sorted(current_urls - target_urls) + for norm_url in to_add: + tw = target_by_norm[norm_url] + ok = await _add_worker_to_router( + client, + tw["url"], + tw["worker_type"], + tw.get("bootstrap_port"), + ) + if not ok: + logger.warning("Failed to add worker %s, continuing with others", tw["url"]) + for url in to_remove: + ok = await _remove_worker_from_router(client, url) + if not ok: + logger.warning("Failed to remove worker %s, continuing with others", url) + + +async def _get_worker_payload(job_model: JobModel, worker_url: str) -> _WorkerPayloadResult: + try: + async with _get_service_replica_client(job_model) as client: + data = await _request_json_limited( + client, + "GET", + f"{_ROUTER_HTTP}/server_info", + max_response_bytes=_MAX_SERVER_INFO_RESPONSE_BYTES, + ok_statuses={200}, + ) + if isinstance(data, dict): + if data.get("status") != "ready": + return {"status": "not_ready", "payload": None} + mode = data.get("disaggregation_mode", "") + if mode == "prefill": + bootstrap_port = data.get("disaggregation_bootstrap_port") + return { + "status": "ready", + "payload": { + "url": worker_url, + "worker_type": "prefill", + "bootstrap_port": bootstrap_port, + }, + } + if mode == "decode": + return { + "status": "ready", + "payload": {"url": worker_url, "worker_type": "decode"}, + } + return { + "status": "ready", + "payload": {"url": worker_url, "worker_type": "regular"}, + } + except _ResponseTooLargeError: + logger.debug("server_info response too large for worker %s", worker_url) + except Exception as e: + logger.debug("Could not fetch server_info for worker %s: %r", worker_url, e) + return {"status": "not_ready", "payload": None} + + +async def _build_target_workers( + run_model: RunModel, + run_spec: RunSpec, + replica_groups: List, +) -> List[_TargetWorker]: + payloads: List[_TargetWorker] = [] + config = run_spec.configuration + if not isinstance(config, ServiceConfiguration): + return payloads + + for group in replica_groups: + if group.router is not None: + continue + assert group.name is not None, "Replica group name is set by validation" + group_name = group.name + for job in run_model.jobs: + if not job_belongs_to_group(job, group_name): + continue + if job.status != JobStatus.RUNNING: + continue + if not is_replica_registered([job]): + continue + jpd = get_job_provisioning_data(job) + if jpd is None: + continue + ip = jpd.internal_ip or jpd.hostname + if not ip: + continue + job_spec = get_job_spec(job) + port = get_service_port(job_spec, config) + worker_url = f"http://{ip}:{port}" + result = await _get_worker_payload(job, worker_url) + if result["status"] == "ready" and result["payload"]: + p = result["payload"] + entry: _TargetWorker = { + "url": p["url"], + "worker_type": p.get("worker_type", "regular"), + } + if p.get("bootstrap_port") is not None: + entry["bootstrap_port"] = p["bootstrap_port"] + payloads.append(entry) + elif result["status"] == "not_ready": + logger.debug("Worker %s not ready", worker_url) + return payloads + + +async def sync_router_workers_for_run_model(run_model: RunModel) -> None: + run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) + config = run_spec.configuration + if not isinstance(config, ServiceConfiguration): + return + replica_groups = config.replica_groups + router_group = next((g for g in replica_groups if g.router is not None), None) + if router_group is None: + return + + target_workers = await _build_target_workers(run_model, run_spec, replica_groups) + router_job = _get_router_job(run_model, router_group) + if router_job is None: + return + try: + async with _get_service_replica_client(router_job) as client: + await _update_workers_in_router_replica(client, target_workers) + except Exception as e: + logger.warning( + "%s: failed to sync workers with router: %r", + fmt(router_job), + e, + ) diff --git a/src/dstack/_internal/server/services/runs/__init__.py b/src/dstack/_internal/server/services/runs/__init__.py index 9f687756f..c640a75ee 100644 --- a/src/dstack/_internal/server/services/runs/__init__.py +++ b/src/dstack/_internal/server/services/runs/__init__.py @@ -18,6 +18,7 @@ ServerClientError, ) from dstack._internal.core.models.common import ApplyAction +from dstack._internal.core.models.configurations import ServiceConfiguration from dstack._internal.core.models.profiles import ( RetryEvent, ) @@ -47,6 +48,7 @@ ProjectModel, RepoModel, RunModel, + ServiceRouterWorkerSyncModel, UserModel, ) from dstack._internal.server.services import events, services @@ -93,6 +95,40 @@ } +def run_spec_has_router_replica_group(run_spec: RunSpec) -> bool: + if run_spec.configuration.type != "service": + return False + cfg = run_spec.configuration + if not isinstance(cfg, ServiceConfiguration): + return False + return any(g.router is not None for g in cfg.replica_groups) + + +async def ensure_service_router_worker_sync_row( + session: AsyncSession, + run_model: RunModel, + run_spec: RunSpec, +) -> None: + if not run_spec_has_router_replica_group(run_spec): + return + res = await session.execute( + select(ServiceRouterWorkerSyncModel.id).where( + ServiceRouterWorkerSyncModel.run_id == run_model.id + ) + ) + if res.scalar_one_or_none() is not None: + return + now = common_utils.get_current_datetime() + session.add( + ServiceRouterWorkerSyncModel( + id=uuid.uuid4(), + run_id=run_model.id, + created_at=now, + last_processed_at=now, + ) + ) + + def switch_run_status( session: AsyncSession, run_model: RunModel, @@ -621,6 +657,7 @@ async def submit_run( ], ) global_replica_num += 1 + await ensure_service_router_worker_sync_row(session, run_model, run_spec) else: for replica_num in range(initial_replicas): jobs = await get_jobs_from_run_spec( From dfec63e7acb852f08368551e447c5bdf6c0193ba Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Wed, 1 Apr 2026 13:29:15 +0545 Subject: [PATCH 2/3] Resolve pyright test --- src/dstack/_internal/server/services/router_worker_sync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/services/router_worker_sync.py b/src/dstack/_internal/server/services/router_worker_sync.py index afa24f353..ad3071ad8 100644 --- a/src/dstack/_internal/server/services/router_worker_sync.py +++ b/src/dstack/_internal/server/services/router_worker_sync.py @@ -99,7 +99,7 @@ def run_model_has_router_replica_group(run_model: RunModel) -> bool: return run_spec_has_router_replica_group(run_spec) -def _get_router_job(run_model: RunModel, router_group) -> JobModel | None: +def _get_router_job(run_model: RunModel, router_group) -> Optional[JobModel]: group_name = router_group.name assert group_name is not None, "Replica group name is set by validation" router_jobs = [ From e155d17a7edcf13cf2a930acd70cec52f1974311 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Wed, 1 Apr 2026 14:08:13 +0545 Subject: [PATCH 3/3] Resolve tests --- .../03_29_1200_e7f4a91b2c3d_add_service_router_worker_sync.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_29_1200_e7f4a91b2c3d_add_service_router_worker_sync.py b/src/dstack/_internal/server/migrations/versions/2026/03_29_1200_e7f4a91b2c3d_add_service_router_worker_sync.py index 61a8a7d99..32ba5e838 100644 --- a/src/dstack/_internal/server/migrations/versions/2026/03_29_1200_e7f4a91b2c3d_add_service_router_worker_sync.py +++ b/src/dstack/_internal/server/migrations/versions/2026/03_29_1200_e7f4a91b2c3d_add_service_router_worker_sync.py @@ -39,7 +39,6 @@ def upgrade() -> None: ondelete="CASCADE", ), sa.PrimaryKeyConstraint("id", name=op.f("pk_service_router_worker_sync")), - sa.UniqueConstraint("run_id", name=op.f("uq_service_router_worker_sync_run_id")), ) op.create_index( op.f("ix_service_router_worker_sync_pipeline_fetch_q"),