Skip to content

Commit dee4d75

Browse files
committed
feat(router): Add NIXL-based disaggregated prefill routing support
Signed-off-by: Yiqi Xue <xuey666@gmail.com>
1 parent dda6ea8 commit dee4d75

6 files changed

Lines changed: 722 additions & 45 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,8 @@ dependencies = [
3131
"opentelemetry-exporter-otlp>=1.28.0",
3232
"h11>=0.16.0", # fix critical vulnerability GHSA-vqfr-h8mv-ghfj
3333
"httpcore>=1.0.8", # required for h11>=0.16.0
34+
"pyzmq>=27.0.0",
35+
"msgspec>=0.19.0",
3436
]
3537

3638
[project.scripts]

src/vllm_router/app.py

Lines changed: 51 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@
3333
from vllm_router.routers.main_router import main_router
3434
from vllm_router.routers.metrics_router import metrics_router
3535
from vllm_router.routers.routing_logic import (
36+
DisaggregatedPrefillRouter,
3637
cleanup_routing_logic,
3738
get_routing_logic,
3839
initialize_routing_logic,
@@ -48,6 +49,7 @@
4849
from vllm_router.services.request_service.rewriter import (
4950
get_request_rewriter,
5051
)
52+
from vllm_router.services.request_service.zmq_proxy import NixlConfig, ZmqProxy
5153
from vllm_router.stats.engine_stats import (
5254
get_engine_stats_scraper,
5355
initialize_engine_stats_scraper,
@@ -97,6 +99,7 @@
9799
@asynccontextmanager
98100
async def lifespan(app: FastAPI):
99101
app.state.aiohttp_client_wrapper.start()
102+
100103
if hasattr(app.state, "batch_processor"):
101104
await app.state.batch_processor.initialize()
102105

@@ -111,7 +114,29 @@ async def lifespan(app: FastAPI):
111114
if hasattr(service_discovery, "initialize_client_sessions"):
112115
await service_discovery.initialize_client_sessions()
113116

114-
yield
117+
use_nixl = (
118+
isinstance(app.state.router, DisaggregatedPrefillRouter)
119+
and hasattr(app.state, "nixl_config")
120+
and app.state.nixl_config is not None
121+
)
122+
if use_nixl:
123+
logger.info(
124+
"Starting ZMQ task because the routing logic is"
125+
" RoutingLogic.DISAGGREGATED_PREFILL and nixl_proxy_host is configured"
126+
)
127+
nixl_config = app.state.nixl_config
128+
app.state.zmq_proxy = ZmqProxy(
129+
finished_req_ttl=nixl_config.finished_req_ttl,
130+
cleanup_interval=nixl_config.cleanup_interval,
131+
)
132+
await app.state.zmq_proxy.start(nixl_config.proxy_host, nixl_config.proxy_port)
133+
134+
yield
135+
136+
await app.state.zmq_proxy.stop()
137+
else:
138+
yield
139+
115140
await app.state.aiohttp_client_wrapper.stop()
116141

117142
# Close the threaded-components
@@ -211,8 +236,16 @@ def initialize_all(app: FastAPI, args):
211236
namespace=args.k8s_namespace,
212237
port=args.k8s_port,
213238
label_selector=args.k8s_label_selector,
214-
prefill_model_labels=args.prefill_model_labels,
215-
decode_model_labels=args.decode_model_labels,
239+
prefill_model_labels=(
240+
parse_comma_separated_args(args.prefill_model_labels)
241+
if args.prefill_model_labels
242+
else None
243+
),
244+
decode_model_labels=(
245+
parse_comma_separated_args(args.decode_model_labels)
246+
if args.decode_model_labels
247+
else None
248+
),
216249
watcher_timeout_seconds=args.k8s_watcher_timeout_seconds,
217250
health_check_timeout_seconds=args.backend_health_check_timeout_seconds,
218251
)
@@ -325,6 +358,21 @@ def initialize_all(app: FastAPI, args):
325358
app.state.router = get_routing_logic()
326359
app.state.request_rewriter = get_request_rewriter()
327360

361+
# Build NixlConfig if disaggregated prefill with NIXL proxy is configured.
362+
if (
363+
hasattr(args, "nixl_proxy_host")
364+
and args.nixl_proxy_host is not None
365+
):
366+
app.state.nixl_config = NixlConfig(
367+
proxy_host=args.nixl_proxy_host,
368+
proxy_port=args.nixl_proxy_port,
369+
peer_host=args.nixl_peer_host,
370+
peer_init_port=args.nixl_peer_init_port,
371+
peer_alloc_port=args.nixl_peer_alloc_port,
372+
finished_req_ttl=args.nixl_finished_req_ttl,
373+
cleanup_interval=args.nixl_cleanup_interval,
374+
)
375+
328376

329377
app = FastAPI(lifespan=lifespan)
330378
app.include_router(main_router)

src/vllm_router/parsers/parser.py

Lines changed: 48 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -461,6 +461,54 @@ def parse_args():
461461
help="Timeout for LMCache worker (seconds)",
462462
)
463463

464+
parser.add_argument(
465+
"--nixl-peer-host",
466+
type=str,
467+
help="The hostname or IP address of the NIXL peer service. Only use for DisaggregatedPrefillRouter.",
468+
)
469+
parser.add_argument(
470+
"--nixl-peer-init-port",
471+
type=int,
472+
default=7300,
473+
help="The initialization port for the NIXL peer service. Only use for DisaggregatedPrefillRouter.",
474+
)
475+
parser.add_argument(
476+
"--nixl-peer-alloc-port",
477+
type=int,
478+
default=7400,
479+
help="The allocation port for the NIXL peer service. Only use for DisaggregatedPrefillRouter.",
480+
)
481+
parser.add_argument(
482+
"--nixl-proxy-host",
483+
type=str,
484+
help="The hostname or IP address for the NIXL proxy server. Only use for DisaggregatedPrefillRouter.",
485+
)
486+
parser.add_argument(
487+
"--nixl-proxy-port",
488+
type=int,
489+
default=7500,
490+
help="The port for the NIXL proxy server. Only use for DisaggregatedPrefillRouter.",
491+
)
492+
parser.add_argument(
493+
"--nixl-finished-req-ttl",
494+
type=float,
495+
default=120.0,
496+
help=(
497+
"Seconds to retain a KV-ready entry in the ZMQ proxy before "
498+
"evicting it. Must be at least as long as the worst-case decode "
499+
"latency for a single request. Defaults to 120 s."
500+
),
501+
)
502+
parser.add_argument(
503+
"--nixl-cleanup-interval",
504+
type=float,
505+
default=60.0,
506+
help=(
507+
"How often (seconds) the ZMQ proxy background task scans for "
508+
"stale KV-ready entries. Defaults to 60 s."
509+
),
510+
)
511+
464512
args = parser.parse_args()
465513
args = load_initial_config_from_config_file_if_required(parser, args)
466514

src/vllm_router/service_discovery.py

Lines changed: 110 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -330,7 +330,7 @@ def get_endpoint_info(self) -> List[EndpointInfo]:
330330

331331
async def initialize_client_sessions(self) -> None:
332332
"""
333-
Initialize aiohttp ClientSession objects for prefill and decode endpoints.
333+
Initialize aiohttp client sessions for prefill and decode endpoints.
334334
This must be called from an async context during app startup.
335335
"""
336336
if (
@@ -739,18 +739,22 @@ def _add_engine(
739739
# Store model information in the endpoint info
740740
self.available_engines[engine_name].model_info = model_info
741741

742-
if self.event_loop_ready.is_set() and self.event_loop is not None:
743-
try:
742+
# Initialize client sessions only if event_loop is available
743+
try:
744+
if hasattr(self.app.state, "event_loop") and self.app.state.event_loop:
744745
fut = asyncio.run_coroutine_threadsafe(
745-
self.initialize_client_sessions(),
746-
self.event_loop,
746+
self.initialize_client_sessions(), self.app.state.event_loop
747747
)
748748
fut.result()
749-
except Exception as e:
750-
logger.error(f"Error initializing client sessions: {e}")
751-
else:
752-
logger.debug(
753-
"Event loop not ready; deferring client session initialization"
749+
logger.info("Client sessions initialized successfully in _add_engine")
750+
else:
751+
# Event loop not ready yet, client sessions will be initialized in lifespan
752+
logger.debug(
753+
"Event loop not ready in _add_engine, client sessions will be initialized later"
754+
)
755+
except Exception as e:
756+
logger.error(
757+
f"Error initializing client sessions in _add_engine: {e}", exc_info=True
754758
)
755759

756760
# Track all models we've ever seen
@@ -833,35 +837,63 @@ def close(self):
833837

834838
async def initialize_client_sessions(self) -> None:
835839
"""
836-
Initialize aiohttp ClientSession objects for prefill and decode endpoints.
840+
Initialize aiohttp client sessions for prefill and decode endpoints.
837841
This must be called from an async context during app startup.
838842
"""
843+
logger.info(
844+
f"initialize_client_sessions called. prefill_model_labels={self.prefill_model_labels}, decode_model_labels={self.decode_model_labels}"
845+
)
839846
if (
840847
self.prefill_model_labels is not None
841848
and self.decode_model_labels is not None
842849
):
843850
endpoint_infos = self.get_endpoint_info()
851+
logger.info(f"Got {len(endpoint_infos)} endpoints")
844852
for endpoint_info in endpoint_infos:
853+
logger.info(
854+
f"Checking endpoint: url={endpoint_info.url}, model_label={endpoint_info.model_label}"
855+
)
845856
if endpoint_info.model_label in self.prefill_model_labels:
846857
if (
847858
hasattr(self.app.state, "prefill_client")
848859
and self.app.state.prefill_client is not None
849860
):
850-
await self.app.state.prefill_client.close()
851-
self.app.state.prefill_client = aiohttp.ClientSession(
852-
base_url=endpoint_info.url,
853-
timeout=aiohttp.ClientTimeout(total=None),
854-
)
861+
# Session already initialised; skip to avoid disrupting
862+
# in-flight requests. xPyD (multiple prefill nodes) is
863+
# not supported in this PR — only the first discovered
864+
# prefill endpoint is used.
865+
logger.debug(
866+
f"prefill_client already set, skipping {endpoint_info.url}"
867+
)
868+
else:
869+
self.app.state.prefill_client = aiohttp.ClientSession(
870+
base_url=endpoint_info.url,
871+
timeout=aiohttp.ClientTimeout(total=None),
872+
)
873+
logger.info(
874+
f"Created prefill_client for {endpoint_info.url} with timeout=None"
875+
)
876+
855877
elif endpoint_info.model_label in self.decode_model_labels:
856878
if (
857879
hasattr(self.app.state, "decode_client")
858880
and self.app.state.decode_client is not None
859881
):
860-
await self.app.state.decode_client.close()
861-
self.app.state.decode_client = aiohttp.ClientSession(
862-
base_url=endpoint_info.url,
863-
timeout=aiohttp.ClientTimeout(total=None),
864-
)
882+
logger.debug(
883+
f"decode_client already set, skipping {endpoint_info.url}"
884+
)
885+
else:
886+
self.app.state.decode_client = aiohttp.ClientSession(
887+
base_url=endpoint_info.url,
888+
timeout=aiohttp.ClientTimeout(total=None),
889+
)
890+
logger.info(
891+
f"Created decode_client for {endpoint_info.url} with timeout=None"
892+
)
893+
else:
894+
logger.warning(
895+
"prefill_model_labels or decode_model_labels is None, skipping client session initialization"
896+
)
865897

866898
def has_ever_seen_model(self, model_name: str) -> bool:
867899
"""Check if we've ever seen this model, even if currently scaled to zero."""
@@ -1195,6 +1227,21 @@ def _add_engine(self, engine_name: str, model_names: List[str], model_label: str
11951227
# Store model information in the endpoint info
11961228
self.available_engines[engine_name].model_info = model_info
11971229

1230+
try:
1231+
# Only initialize client sessions if event_loop is available
1232+
if hasattr(self.app.state, "event_loop") and self.app.state.event_loop:
1233+
fut = asyncio.run_coroutine_threadsafe(
1234+
self.initialize_client_sessions(), self.app.state.event_loop
1235+
)
1236+
fut.result()
1237+
else:
1238+
# Event loop not ready yet, client sessions will be initialized in lifespan
1239+
logger.debug(
1240+
"Event loop not ready, client sessions will be initialized later"
1241+
)
1242+
except Exception as e:
1243+
logger.error(f"Error initializing client sessions: {e}")
1244+
11981245
def _delete_engine(self, engine_name: str):
11991246
logger.info(f"Serving engine {engine_name} is deleted")
12001247
with self.available_engines_lock:
@@ -1270,25 +1317,58 @@ def close(self):
12701317

12711318
async def initialize_client_sessions(self) -> None:
12721319
"""
1273-
Initialize aiohttp ClientSession objects for prefill and decode endpoints.
1320+
Initialize aiohttp client sessions for prefill and decode endpoints.
12741321
This must be called from an async context during app startup.
12751322
"""
1323+
logger.info(
1324+
f"K8sServiceNameServiceDiscovery.initialize_client_sessions called. prefill_model_labels={self.prefill_model_labels}, decode_model_labels={self.decode_model_labels}"
1325+
)
12761326
if (
12771327
self.prefill_model_labels is not None
12781328
and self.decode_model_labels is not None
12791329
):
12801330
endpoint_infos = self.get_endpoint_info()
1331+
logger.info(f"Got {len(endpoint_infos)} endpoints")
12811332
for endpoint_info in endpoint_infos:
1333+
logger.info(
1334+
f"Checking endpoint: url={endpoint_info.url}, model_label={endpoint_info.model_label}"
1335+
)
12821336
if endpoint_info.model_label in self.prefill_model_labels:
1283-
self.app.state.prefill_client = aiohttp.ClientSession(
1284-
base_url=endpoint_info.url,
1285-
timeout=aiohttp.ClientTimeout(total=None),
1286-
)
1337+
if (
1338+
hasattr(self.app.state, "prefill_client")
1339+
and self.app.state.prefill_client is not None
1340+
):
1341+
logger.debug(
1342+
f"prefill_client already set, skipping {endpoint_info.url}"
1343+
)
1344+
else:
1345+
self.app.state.prefill_client = aiohttp.ClientSession(
1346+
base_url=endpoint_info.url,
1347+
timeout=aiohttp.ClientTimeout(total=None),
1348+
)
1349+
logger.info(
1350+
f"Created prefill_client for {endpoint_info.url} with timeout=None"
1351+
)
12871352
elif endpoint_info.model_label in self.decode_model_labels:
1288-
self.app.state.decode_client = aiohttp.ClientSession(
1289-
base_url=endpoint_info.url,
1290-
timeout=aiohttp.ClientTimeout(total=None),
1291-
)
1353+
if (
1354+
hasattr(self.app.state, "decode_client")
1355+
and self.app.state.decode_client is not None
1356+
):
1357+
logger.debug(
1358+
f"decode_client already set, skipping {endpoint_info.url}"
1359+
)
1360+
else:
1361+
self.app.state.decode_client = aiohttp.ClientSession(
1362+
base_url=endpoint_info.url,
1363+
timeout=aiohttp.ClientTimeout(total=None),
1364+
)
1365+
logger.info(
1366+
f"Created decode_client for {endpoint_info.url} with timeout=None"
1367+
)
1368+
else:
1369+
logger.warning(
1370+
"K8sServiceNameServiceDiscovery: prefill_model_labels or decode_model_labels is None, skipping client session initialization"
1371+
)
12921372

12931373

12941374
def _create_service_discovery(

0 commit comments

Comments
 (0)