Skip to content

Commit 8d9bf63

Browse files
committed
feat: enable bidirectional pd kv transfer
1 parent 8f6524e commit 8d9bf63

6 files changed

Lines changed: 242 additions & 21 deletions

File tree

py_src/vllm_router/router.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ class Router:
7676
If not specified, uses the main policy. Default: None
7777
decode_policy: Specific load balancing policy for decode nodes (PD mode only).
7878
If not specified, uses the main policy. Default: None
79+
pd_kv_cache_ttl_secs: TTL in seconds for Decode-side KV metadata cached for
80+
bidirectional vLLM P/D transfer. Default: 0
7981
request_id_headers: List of HTTP headers to check for request IDs. If not specified,
8082
uses common defaults: ['x-request-id', 'x-correlation-id', 'x-trace-id', 'request-id'].
8183
Example: ['x-my-request-id', 'x-custom-trace-id']. Default: None

py_src/vllm_router/router_args.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class RouterArgs:
2121
default_factory=list
2222
) # List of (url, bootstrap_port)
2323
decode_urls: List[str] = dataclasses.field(default_factory=list)
24+
pd_kv_cache_ttl_secs: int = 0
2425

2526
# Routing policy
2627
policy: str = "cache_aware"
@@ -201,6 +202,12 @@ def add_cli_args(
201202
metavar=("URL",),
202203
help="Decode server URL. Can be specified multiple times.",
203204
)
205+
parser.add_argument(
206+
f"--{prefix}pd-kv-cache-ttl-secs",
207+
type=int,
208+
default=RouterArgs.pd_kv_cache_ttl_secs,
209+
help="TTL in seconds for Decode-side KV metadata cached for bidirectional vLLM P/D transfer.",
210+
)
204211
parser.add_argument(
205212
f"--{prefix}worker-startup-timeout-secs",
206213
type=int,

src/config/types.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,19 @@ pub struct RouterConfig {
8484
/// Profiling timeout in seconds (for vLLM profiling endpoints)
8585
#[serde(default = "default_profile_timeout_secs")]
8686
pub profile_timeout_secs: u64,
87+
/// TTL for Decode-side KV metadata cached by vLLM P/D router.
88+
#[serde(default = "default_pd_kv_cache_ttl_secs")]
89+
pub pd_kv_cache_ttl_secs: u64,
8790
}
8891

8992
fn default_profile_timeout_secs() -> u64 {
9093
10
9194
}
9295

96+
fn default_pd_kv_cache_ttl_secs() -> u64 {
97+
0
98+
}
99+
93100
fn default_history_backend() -> HistoryBackend {
94101
HistoryBackend::Memory
95102
}
@@ -491,6 +498,7 @@ impl Default for RouterConfig {
491498
history_backend: default_history_backend(),
492499
enable_profiling: false,
493500
profile_timeout_secs: default_profile_timeout_secs(),
501+
pd_kv_cache_ttl_secs: default_pd_kv_cache_ttl_secs(),
494502
}
495503
}
496504
}
@@ -1063,6 +1071,7 @@ mod tests {
10631071
history_backend: default_history_backend(),
10641072
enable_profiling: false,
10651073
profile_timeout_secs: default_profile_timeout_secs(),
1074+
pd_kv_cache_ttl_secs: default_pd_kv_cache_ttl_secs(),
10661075
};
10671076

10681077
assert!(config.mode.is_pd_mode());
@@ -1131,6 +1140,7 @@ mod tests {
11311140
history_backend: default_history_backend(),
11321141
enable_profiling: false,
11331142
profile_timeout_secs: default_profile_timeout_secs(),
1143+
pd_kv_cache_ttl_secs: default_pd_kv_cache_ttl_secs(),
11341144
};
11351145

11361146
assert!(!config.mode.is_pd_mode());
@@ -1195,6 +1205,7 @@ mod tests {
11951205
history_backend: default_history_backend(),
11961206
enable_profiling: false,
11971207
profile_timeout_secs: default_profile_timeout_secs(),
1208+
pd_kv_cache_ttl_secs: default_pd_kv_cache_ttl_secs(),
11981209
};
11991210

12001211
assert!(config.has_service_discovery());

src/lib.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ struct Router {
104104
// OpenTelemetry tracing
105105
enable_trace: bool,
106106
otlp_traces_endpoint: Option<String>,
107+
// vLLM P/D Decode -> Prefill KV metadata cache
108+
pd_kv_cache_ttl_secs: u64,
107109
}
108110

109111
impl Router {
@@ -253,6 +255,7 @@ impl Router {
253255
history_backend: config::HistoryBackend::Memory,
254256
enable_profiling: false, // Profiling disabled in Python binding by default
255257
profile_timeout_secs: 10, // Default profiling timeout
258+
pd_kv_cache_ttl_secs: self.pd_kv_cache_ttl_secs,
256259
})
257260
}
258261
}
@@ -327,6 +330,8 @@ impl Router {
327330
// Tracing defaults
328331
enable_trace = false,
329332
otlp_traces_endpoint = None,
333+
// vLLM P/D defaults
334+
pd_kv_cache_ttl_secs = 0,
330335
))]
331336
#[allow(clippy::too_many_arguments)]
332337
fn new(
@@ -390,6 +395,7 @@ impl Router {
390395
tokenizer_path: Option<String>,
391396
enable_trace: bool,
392397
otlp_traces_endpoint: Option<String>,
398+
pd_kv_cache_ttl_secs: u64,
393399
) -> PyResult<Self> {
394400
// Determine connection mode from worker URLs
395401
let mut all_urls = worker_urls.clone();
@@ -470,6 +476,7 @@ impl Router {
470476
tokenizer_path,
471477
enable_trace,
472478
otlp_traces_endpoint,
479+
pd_kv_cache_ttl_secs,
473480
})
474481
}
475482

src/main.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@ struct CliArgs {
134134
#[arg(long, default_value_t = false)]
135135
vllm_pd_disaggregation: bool,
136136

137+
/// TTL in seconds for Decode-side KV metadata cached for bidirectional vLLM P/D transfer
138+
#[arg(long, default_value_t = 0)]
139+
pd_kv_cache_ttl_secs: u64,
140+
137141
/// ZMQ service discovery address for vLLM P2P NCCL coordination (e.g., "0.0.0.0:30001")
138142
/// Required for --vllm-pd-disaggregation mode. Workers register their HTTP and ZMQ addresses here.
139143
#[arg(long)]
@@ -680,6 +684,7 @@ impl CliArgs {
680684
},
681685
enable_profiling: self.profile,
682686
profile_timeout_secs: 10, // Default profiling timeout
687+
pd_kv_cache_ttl_secs: self.pd_kv_cache_ttl_secs,
683688
})
684689
}
685690

0 commit comments

Comments
 (0)