Skip to content

Commit 3aae583

Browse files
authored
Add pd disaggregated inference (#3558)
* Initial PD disaggregation implementation Test2 Internal IP Test Add worker with internal_ip Check status and register Add Status Ready Log Add Prefill-Decode Add PD to dstack Test register worker without poll Add router config in service config Update remove worker Clean Up router code Clean Up Further Cleanup * Add pd disaggregation service * Move router configuration to service * Resolve major comments * Resolve Lint Error * Minor Update * Resolve Minor Comments * Update wheel url * Resolve backward incompatibility * Update RouterConfigs * Resolve Lint Error * Update gateway wheel * Minor Update --------- Co-authored-by: Bihan Rana
1 parent f7a977d commit 3aae583

File tree

20 files changed

+294
-90
lines changed

20 files changed

+294
-90
lines changed

docs/docs/reference/dstack.yml/gateway.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ The `gateway` configuration type allows creating and updating [gateways](../../c
1414

1515
=== "SGLang Model Gateway"
1616

17-
#SCHEMA# dstack._internal.core.models.routers.SGLangRouterConfig
17+
#SCHEMA# dstack._internal.core.models.routers.SGLangGatewayRouterConfig
1818
overrides:
1919
show_root_heading: false
2020
type:

gateway/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ dependencies = [
1515
]
1616

1717
[project.optional-dependencies]
18-
sglang = ["sglang-router==0.2.1"]
18+
sglang = ["sglang-router==0.3.2"]
1919

2020
[tool.setuptools.package-data]
2121
"dstack.gateway" = [

src/dstack/_internal/core/backends/base/compute.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
SSHKey,
4040
)
4141
from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData
42-
from dstack._internal.core.models.routers import AnyRouterConfig
42+
from dstack._internal.core.models.routers import AnyGatewayRouterConfig
4343
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
4444
from dstack._internal.core.models.volumes import (
4545
Volume,
@@ -924,7 +924,9 @@ def get_run_shim_script(
924924
]
925925

926926

927-
def get_gateway_user_data(authorized_key: str, router: Optional[AnyRouterConfig] = None) -> str:
927+
def get_gateway_user_data(
928+
authorized_key: str, router: Optional[AnyGatewayRouterConfig] = None
929+
) -> str:
928930
return get_cloud_config(
929931
package_update=True,
930932
packages=[
@@ -1036,7 +1038,7 @@ def get_latest_runner_build() -> Optional[str]:
10361038
return None
10371039

10381040

1039-
def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = None) -> str:
1041+
def get_dstack_gateway_wheel(build: str, router: Optional[AnyGatewayRouterConfig] = None) -> str:
10401042
channel = "release" if settings.DSTACK_RELEASE else "stgn"
10411043
base_url = f"https://dstack-gateway-downloads.s3.amazonaws.com/{channel}"
10421044
if build == "latest":
@@ -1049,7 +1051,7 @@ def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = Non
10491051
return f"dstack-gateway @ {wheel}"
10501052

10511053

1052-
def get_dstack_gateway_commands(router: Optional[AnyRouterConfig] = None) -> List[str]:
1054+
def get_dstack_gateway_commands(router: Optional[AnyGatewayRouterConfig] = None) -> List[str]:
10531055
build = get_dstack_runner_version() or "latest"
10541056
gateway_package = get_dstack_gateway_wheel(build, router)
10551057
return [

src/dstack/_internal/core/backends/kubernetes/compute.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
)
6767
from dstack._internal.core.models.placement import PlacementGroup
6868
from dstack._internal.core.models.resources import CPUSpec, GPUSpec
69-
from dstack._internal.core.models.routers import AnyRouterConfig
69+
from dstack._internal.core.models.routers import AnyGatewayRouterConfig
7070
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
7171
from dstack._internal.core.models.volumes import Volume
7272
from dstack._internal.utils.common import get_or_error
@@ -864,7 +864,7 @@ def _wait_for_load_balancer_address(
864864

865865

866866
def _get_gateway_commands(
867-
authorized_keys: List[str], router: Optional[AnyRouterConfig] = None
867+
authorized_keys: List[str], router: Optional[AnyGatewayRouterConfig] = None
868868
) -> List[str]:
869869
authorized_keys_content = "\n".join(authorized_keys).strip()
870870
gateway_commands = " && ".join(get_dstack_gateway_commands(router=router))

src/dstack/_internal/core/compatibility/gateways.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@ def _get_gateway_configuration_excludes(
3131
) -> IncludeExcludeDictType:
3232
configuration_excludes: IncludeExcludeDictType = {}
3333

34-
# Add excludes like this:
35-
#
36-
# if configuration.tags is None:
37-
# configuration_excludes["tags"] = True
34+
if configuration.router is None:
35+
configuration_excludes["router"] = True
3836

3937
return configuration_excludes

src/dstack/_internal/core/compatibility/runs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from dstack._internal.core.models.common import IncludeExcludeDictType, IncludeExcludeSetType
44
from dstack._internal.core.models.configurations import ServiceConfiguration
5+
from dstack._internal.core.models.routers import SGLangServiceRouterConfig
56
from dstack._internal.core.models.runs import (
67
DEFAULT_PROBE_UNTIL_READY,
78
DEFAULT_REPLICA_GROUP_NAME,
@@ -72,6 +73,12 @@ def get_run_spec_excludes(run_spec: RunSpec) -> IncludeExcludeDictType:
7273
# Servers prior to 0.20.8 do not support probes=None
7374
configuration_excludes["probes"] = True
7475

76+
router = run_spec.configuration.router
77+
if router is None:
78+
configuration_excludes["router"] = True
79+
elif isinstance(router, SGLangServiceRouterConfig) and router.pd_disaggregation is False:
80+
configuration_excludes["router"] = {"pd_disaggregation": True}
81+
7582
if configuration_excludes:
7683
spec_excludes["configuration"] = configuration_excludes
7784
if profile_excludes:

src/dstack/_internal/core/models/configurations.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
parse_off_duration,
2929
)
3030
from dstack._internal.core.models.resources import Range, ResourcesSpec
31+
from dstack._internal.core.models.routers import AnyServiceRouterConfig
3132
from dstack._internal.core.models.services import AnyModel, OpenAIChatModel
3233
from dstack._internal.core.models.unix import UnixUser
3334
from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point
@@ -887,6 +888,14 @@ class ServiceConfigurationParams(CoreModel):
887888
)
888889
),
889890
] = None
891+
router: Annotated[
892+
Optional[AnyServiceRouterConfig],
893+
Field(
894+
description=(
895+
"Router configuration for the service. Requires a gateway with matching router enabled. "
896+
),
897+
),
898+
] = None
890899

891900
@validator("port")
892901
def convert_port(cls, v) -> PortMapping:

src/dstack/_internal/core/models/gateways.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from dstack._internal.core.models.backends.base import BackendType
1010
from dstack._internal.core.models.common import CoreModel
11-
from dstack._internal.core.models.routers import AnyRouterConfig
11+
from dstack._internal.core.models.routers import AnyGatewayRouterConfig
1212
from dstack._internal.utils.tags import tags_validator
1313

1414

@@ -63,8 +63,13 @@ class GatewayConfiguration(CoreModel):
6363
),
6464
] = None
6565
router: Annotated[
66-
Optional[AnyRouterConfig],
67-
Field(description="The router configuration"),
66+
Optional[AnyGatewayRouterConfig],
67+
Field(
68+
description=(
69+
"The router configuration for this gateway. "
70+
"E.g. `{ type: sglang, policy: round_robin }`."
71+
),
72+
),
6873
] = None
6974
domain: Annotated[
7075
Optional[str], Field(description="The gateway domain, e.g. `example.com`")
@@ -134,7 +139,7 @@ class GatewayComputeConfiguration(CoreModel):
134139
ssh_key_pub: str
135140
certificate: Optional[AnyGatewayCertificate] = None
136141
tags: Optional[Dict[str, str]] = None
137-
router: Optional[AnyRouterConfig] = None
142+
router: Optional[AnyGatewayRouterConfig] = None
138143

139144

140145
class GatewayProvisioningData(CoreModel):

src/dstack/_internal/core/models/routers.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,37 @@ class RouterType(str, Enum):
1111
SGLANG = "sglang"
1212

1313

14-
class SGLangRouterConfig(CoreModel):
14+
class SGLangGatewayRouterConfig(CoreModel):
15+
"""Gateway-level router configuration. type and policy only. pd_disaggregation is service-level."""
16+
17+
type: Annotated[
18+
Literal["sglang"],
19+
Field(description="The router type enabled on this gateway."),
20+
] = "sglang"
21+
policy: Annotated[
22+
Literal["random", "round_robin", "cache_aware", "power_of_two"],
23+
Field(
24+
description=(
25+
"The routing policy. Deprecated: prefer setting policy in the service's router config. "
26+
"Options: `random`, `round_robin`, `cache_aware`, `power_of_two`"
27+
),
28+
),
29+
] = "cache_aware"
30+
31+
32+
class SGLangServiceRouterConfig(CoreModel):
1533
type: Annotated[Literal["sglang"], Field(description="The router type")] = "sglang"
1634
policy: Annotated[
1735
Literal["random", "round_robin", "cache_aware", "power_of_two"],
1836
Field(
1937
description="The routing policy. Options: `random`, `round_robin`, `cache_aware`, `power_of_two`"
2038
),
2139
] = "cache_aware"
40+
pd_disaggregation: Annotated[
41+
bool,
42+
Field(description="Enable PD disaggregation mode for the SGLang router"),
43+
] = False
2244

2345

24-
AnyRouterConfig = SGLangRouterConfig
46+
AnyServiceRouterConfig = SGLangServiceRouterConfig
47+
AnyGatewayRouterConfig = SGLangGatewayRouterConfig

src/dstack/_internal/proxy/gateway/routers/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ async def register_replica(
8080
ssh_proxy=body.ssh_proxy,
8181
ssh_head_proxy=body.ssh_head_proxy,
8282
ssh_head_proxy_private_key=body.ssh_head_proxy_private_key,
83+
internal_ip=body.internal_ip,
8384
repo=repo,
8485
nginx=nginx,
8586
service_conn_pool=service_conn_pool,

0 commit comments

Comments
 (0)