Skip to content

Commit b6c170f

Browse files
committed
feat: configure bidirectional pd kv transfer
1 parent cd2f3da commit b6c170f

10 files changed

Lines changed: 323 additions & 200 deletions

File tree

packages/prime-rl-configs/src/prime_rl/configs/inference.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,69 @@ class MultiNodeInferenceDeploymentConfig(BaseInferenceDeploymentConfig):
172172
] = "consistent_hash"
173173

174174

175+
class NixlTransportConfig(BaseModel):
    """Configures NIXL KV transfer for disaggregated inference deployments.

    ``router_cache_ttl_seconds`` may be left unset; it is then derived from
    ``abort_timeout_seconds`` during validation. Whether derived or explicit,
    it must be strictly smaller than ``abort_timeout_seconds``.
    """

    # Reject unknown keys so typos in the config fail loudly.
    model_config = ConfigDict(extra="forbid")

    # Discriminator for the transport kind; only "nixl" is accepted.
    type: Literal["nixl"] = "nixl"

    enable_bidirectional: Annotated[
        bool,
        Field(
            description=(
                "Whether Prefill workers can pull Decode-side KV through NIXL for later requests "
                "in the same conversation."
            ),
        ),
    ] = False
    num_threads: Annotated[
        int,
        Field(ge=1, description="Number of NIXL connector threads."),
    ] = 1
    kv_recompute_threshold: Annotated[
        int,
        Field(
            ge=0,
            description=(
                "Minimum number of remote Decode-side KV tokens required before a Prefill worker pulls "
                "KV through NIXL instead of recomputing locally. Passed to NixlConnector extra config."
            ),
        ),
    ] = 64
    abort_timeout_seconds: Annotated[
        int,
        Field(
            gt=0,
            description=(
                "Seconds vLLM NIXL waits for the peer to fetch held KV blocks before aborting and freeing them. "
                "Exported as NIXL_ABORT_TIMEOUT and vLLM's VLLM_NIXL_ABORT_REQUEST_TIMEOUT."
            ),
        ),
    ] = 480
    router_cache_ttl_seconds: Annotated[
        int | None,
        Field(
            gt=0,
            description=(
                "Seconds vllm-router keeps Decode-side KV metadata for bidirectional P/D reuse. "
                "Defaults to 95% of abort_timeout_seconds."
            ),
        ),
    ] = None

    @model_validator(mode="after")
    def validate_router_cache_ttl(self):
        """Fill in the default router cache TTL and check it against the abort timeout.

        When ``router_cache_ttl_seconds`` is unset it becomes 95% of
        ``abort_timeout_seconds``, truncated to whole seconds.

        Raises:
            ValueError: If no positive default can be derived, or if the TTL
                (derived or explicit) is not strictly less than
                ``abort_timeout_seconds``.
        """
        if self.router_cache_ttl_seconds is None:
            derived_ttl = int(self.abort_timeout_seconds * 0.95)
            # The field's gt=0 constraint is not re-checked on assignment here, so
            # guard the derived value explicitly: truncation yields 0 when the abort
            # timeout is 1 second, and 0 is the value the Slurm templates pass to
            # vllm-router to disable Decode-side KV reuse.
            if derived_ttl < 1:
                raise ValueError(
                    "abort_timeout_seconds is too small to derive a positive router_cache_ttl_seconds "
                    f"({self.abort_timeout_seconds}); use at least 2"
                )
            self.router_cache_ttl_seconds = derived_ttl
        if self.router_cache_ttl_seconds >= self.abort_timeout_seconds:
            raise ValueError(
                "router_cache_ttl_seconds must be less than abort_timeout_seconds "
                f"({self.router_cache_ttl_seconds} >= {self.abort_timeout_seconds})"
            )
        return self
236+
237+
175238
class DisaggregatedInferenceDeploymentConfig(BaseInferenceDeploymentConfig):
176239
"""Configures a disaggregated prefill/decode inference deployment.
177240
@@ -211,6 +274,11 @@ class DisaggregatedInferenceDeploymentConfig(BaseInferenceDeploymentConfig):
211274
str, Field(description="Routing policy for the vllm-router (e.g. 'consistent_hash', 'round_robin').")
212275
] = "consistent_hash"
213276

277+
kv_transport_config: Annotated[
278+
NixlTransportConfig,
279+
Field(description="KV transport settings for disaggregated P/D deployments."),
280+
] = NixlTransportConfig()
281+
214282
prefill_env_overrides: Annotated[
215283
dict[str, str],
216284
Field(description="Extra environment variables exported only on prefill nodes."),

packages/prime-rl-configs/src/prime_rl/configs/rl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -770,8 +770,9 @@ def auto_setup_lora(self):
770770

771771
@model_validator(mode="after")
772772
def auto_setup_session_headers(self):
773-
"""Ensure X-Session-ID header is always set for sticky DP-aware routing at the inference router."""
773+
"""Ensure stable routing headers are set for inference routers."""
774774
self.orchestrator.client.extra_headers_from_state.setdefault("X-Session-ID", "example_id")
775+
self.orchestrator.client.extra_headers_from_state.setdefault("X-Conversation-ID", "trajectory_id")
775776
return self
776777

777778
@model_validator(mode="after")

pyproject.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ dependencies = [
1818
"torchaudio",
1919
"torchdata>=0.11.0",
2020
"transformers",
21-
"vllm>=0.20.2",
21+
"vllm==0.21.0",
2222
"wandb>=0.26.1",
2323
"ring-flash-attn>=0.1.8",
2424
"prime>=0.6.4",
@@ -176,6 +176,7 @@ override-dependencies = [
176176
[tool.uv.exclude-newer-package]
177177
# we want latest vllm, remove next patch
178178
vllm = false
179+
tokenspeed-mla = false
179180
flash_attn_3 = false
180181
# PrimeIntellect-published on PyPI (trusted publisher)
181182
prime = false
@@ -229,10 +230,10 @@ dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" }
229230
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" }
230231
flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "96bd151" }
231232
pydantic-config = { git = "https://github.com/samsja/pydantic_config.git", branch = "main" }
232-
vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.22/vllm_router-0.1.22-cp38-abi3-manylinux_2_28_x86_64.whl" }
233+
vllm-router = { git = "https://github.com/PrimeIntellect-ai/router.git", rev = "23af7bb" }
233234
vllm = [
234-
{ url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_x86_64.whl", marker = "platform_machine == 'x86_64'" },
235-
{ url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl", marker = "platform_machine == 'aarch64'" },
235+
{ url = "https://files.pythonhosted.org/packages/73/6d/9b78990c9fabc70c7731de6af246a420156dc019f66b48da7c86f509c132/vllm-0.21.0-1-cp38-abi3-manylinux_2_24_x86_64.whl", marker = "platform_machine == 'x86_64'" },
236+
{ url = "https://files.pythonhosted.org/packages/ac/58/564b64d17dde6dc31faae836f98313538c152edf88e2a4fb43b9d551a635/vllm-0.21.0-1-cp38-abi3-manylinux_2_24_aarch64.whl", marker = "platform_machine == 'aarch64'" },
236237
]
237238
deep-ep = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_ep-1.2.1+29d31c0-cp312-cp312-linux_x86_64.whl" }
238239
deep-gemm = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_gemm-2.5.0+891d57b-cp312-cp312-linux_x86_64.whl" }

skills/config/SKILL.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,29 @@ In TOML, an empty section header does the same:
172172
[ckpt] # enables checkpointing with defaults
173173
```
174174

175+
### Disaggregated inference
176+
177+
For `[deployment] type = "disaggregated"`, P/D NIXL transfer knobs live under `deployment.kv_transport_config`:
178+
179+
```toml
180+
[deployment]
181+
type = "disaggregated"
182+
183+
[deployment.kv_transport_config]
184+
type = "nixl"
185+
enable_bidirectional = true
186+
num_threads = 1
187+
kv_recompute_threshold = 64
188+
abort_timeout_seconds = 480
189+
router_cache_ttl_seconds = 456
190+
```
191+
192+
`enable_bidirectional` defaults to `false`. When it is false, the Slurm templates pass `--pd-kv-cache-ttl-secs 0` to vllm-router so Decode-side KV metadata is not reused.
193+
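`router_cache_ttl_seconds` can be omitted; it then defaults to 95% of `abort_timeout_seconds`, rounded down to whole seconds. Whether derived or set explicitly, it must be strictly lower than the abort timeout, otherwise config validation fails.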
194+
The Slurm templates export `abort_timeout_seconds` as both `NIXL_ABORT_TIMEOUT` and vLLM's `VLLM_NIXL_ABORT_REQUEST_TIMEOUT`.
195+
196+
P/D NIXL deployments need UCX 1.19 or newer for H200 CUDA buffer registration. The Slurm templates add `$PROJECT_DIR/third_party/ucx` to `LD_LIBRARY_PATH`.
197+
175198
## Key files
176199

177200
- `src/prime_rl/utils/config.py` — re-exports `BaseConfig` and `cli` from pydantic_config

src/prime_rl/entrypoints/inference.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def write_slurm_script(config: InferenceConfig, config_path: Path, script_path:
6060
decode_port=config.deployment.decode_port,
6161
router_port=config.deployment.router_port,
6262
router_policy=config.deployment.router_policy,
63+
kv_transport_config=config.deployment.kv_transport_config,
6364
data_parallel_rpc_port=config.data_parallel_rpc_port,
6465
use_deep_gemm=config.use_deep_gemm,
6566
prefill_env_overrides=config.deployment.prefill_env_overrides,

src/prime_rl/entrypoints/rl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,8 +437,10 @@ def write_slurm_script(config: RLConfig, config_dir: Path, script_path: Path) ->
437437
num_decode_replicas=infer_deploy.num_decode_replicas,
438438
gpus_per_node=config.deployment.gpus_per_node,
439439
router_port=infer_deploy.router_port,
440+
router_policy=infer_deploy.router_policy,
440441
prefill_port=infer_deploy.prefill_port,
441442
decode_port=infer_deploy.decode_port,
443+
kv_transport_config=infer_deploy.kv_transport_config,
442444
inference_tp=config.inference.parallel.tp,
443445
inference_data_parallel_rpc_port=config.inference.data_parallel_rpc_port,
444446
use_deep_gemm=config.inference.use_deep_gemm,

src/prime_rl/inference/patches.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -897,9 +897,9 @@ def monkey_patch_dp_engine_core_pause_resume_deadlock():
897897
- on resume, wake every DP rank and force an immediate global unfinished
898898
sync instead of waiting for the normal 32-step cadence
899899
900-
This keeps the upstream pause-side fix from
901-
https://github.com/vllm-project/vllm/pull/37024 and extends it with the
902-
resume-side wave-state fix.
900+
This also bypasses vLLM's two-phase DP pause implementation
901+
(https://github.com/vllm-project/vllm/pull/39366), which makes resume
902+
reject states that our weight-update flow can validly hit.
903903
"""
904904
from vllm.config import ParallelConfig
905905
from vllm.v1.core.sched.interface import PauseState
@@ -909,7 +909,8 @@ def monkey_patch_dp_engine_core_pause_resume_deadlock():
909909

910910
_base_add_request = EngineCore.add_request
911911
_base_handle_client_request = EngineCoreProc._handle_client_request
912-
_base_resume_scheduler = DPEngineCoreProc.resume_scheduler
912+
_base_pause_complete = EngineCoreProc._pause_complete
913+
_base_resume_scheduler = EngineCoreProc.resume_scheduler
913914

914915
def _patched_add_request(self, request: Request, request_wave: int = 0):
915916
_base_add_request(self, request, request_wave)
@@ -930,8 +931,15 @@ def _patched_handle_client_request(self, request_type, request):
930931
else:
931932
_base_handle_client_request(self, request_type, request)
932933

934+
def _patched_pause_complete(self) -> bool:
935+
self.pending_pause = False
936+
self.ignore_start_dp_wave = False
937+
return _base_pause_complete(self)
938+
933939
def _patched_resume_scheduler(self):
934940
was_paused = self.scheduler.pause_state != PauseState.UNPAUSED
941+
self.pending_pause = False
942+
self.ignore_start_dp_wave = False
935943
_base_resume_scheduler(self)
936944
if was_paused:
937945
self.engines_running = True
@@ -948,6 +956,7 @@ def _patched_has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
948956

949957
DPEngineCoreProc.add_request = _patched_add_request
950958
DPEngineCoreProc._handle_client_request = _patched_handle_client_request
959+
DPEngineCoreProc._pause_complete = _patched_pause_complete
951960
DPEngineCoreProc.resume_scheduler = _patched_resume_scheduler
952961
DPEngineCoreProc._has_global_unfinished_reqs = _patched_has_global_unfinished_reqs
953962

src/prime_rl/templates/inference.sbatch.j2

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ export PREFILL_PORT={{ prefill_port }}
3636
export DECODE_PORT={{ decode_port }}
3737
export ROUTER_PORT={{ router_port }}
3838
export RPC_PORT={{ data_parallel_rpc_port }}
39+
export NIXL_ABORT_TIMEOUT={{ kv_transport_config.abort_timeout_seconds }}
40+
export VLLM_NIXL_ABORT_REQUEST_TIMEOUT={{ kv_transport_config.abort_timeout_seconds }}
3941
{%- elif num_nodes > 1 %}
4042
export ROUTER_PORT={{ router_port }}
4143
export BACKEND_PORT={{ backend_port }}
@@ -171,15 +173,15 @@ srun bash -c '
171173
export VLLM_NIXL_SIDE_CHANNEL_PORT=5600
172174
173175
{%- if kv_offload %}
174-
PREFILL_KV_CFG='"'"'{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":1}},{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"cpu_bytes_to_use":{{ kv_offload_cpu_bytes }}}}]}}'"'"'
176+
PREFILL_KV_CFG='"'"'{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":{{ kv_transport_config.num_threads }},"bidirectional_kv_xfer":{{ "true" if kv_transport_config.enable_bidirectional else "false" }},"kv_recompute_threshold":{{ kv_transport_config.kv_recompute_threshold }}}},{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"cpu_bytes_to_use":{{ kv_offload_cpu_bytes }}}}]}}'"'"'
175177
{%- else %}
176-
PREFILL_KV_CFG='"'"'{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":1}}'"'"'
178+
PREFILL_KV_CFG='"'"'{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":{{ kv_transport_config.num_threads }},"bidirectional_kv_xfer":{{ "true" if kv_transport_config.enable_bidirectional else "false" }},"kv_recompute_threshold":{{ kv_transport_config.kv_recompute_threshold }}}}'"'"'
177179
{%- endif %}
178180
179181
{%- if kv_offload %}
180-
DECODE_KV_CFG='"'"'{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":1}},{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"cpu_bytes_to_use":{{ kv_offload_cpu_bytes }}}}]}}'"'"'
182+
DECODE_KV_CFG='"'"'{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":{{ kv_transport_config.num_threads }},"bidirectional_kv_xfer":{{ "true" if kv_transport_config.enable_bidirectional else "false" }},"kv_recompute_threshold":{{ kv_transport_config.kv_recompute_threshold }}}},{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"cpu_bytes_to_use":{{ kv_offload_cpu_bytes }}}}]}}'"'"'
181183
{%- else %}
182-
DECODE_KV_CFG='"'"'{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":1}}'"'"'
184+
DECODE_KV_CFG='"'"'{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":{{ kv_transport_config.num_threads }},"bidirectional_kv_xfer":{{ "true" if kv_transport_config.enable_bidirectional else "false" }},"kv_recompute_threshold":{{ kv_transport_config.kv_recompute_threshold }}}}'"'"'
183185
{%- endif %}
184186
185187
DECODE_COMPILE_CFG='"'"'{"cudagraph_mode":"FULL_DECODE_ONLY"}'"'"'
@@ -250,6 +252,7 @@ srun bash -c '
250252
--host 0.0.0.0 \
251253
--port $ROUTER_PORT \
252254
--intra-node-data-parallel-size {{ dp_per_node }} \
255+
--pd-kv-cache-ttl-secs {{ kv_transport_config.router_cache_ttl_seconds if kv_transport_config.enable_bidirectional else 0 }} \
253256
--worker-startup-timeout-secs 4200 \
254257
--log-level debug \
255258
>> $ROUTER_LOG 2>&1 &

src/prime_rl/templates/multi_node_rl.sbatch.j2

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ export NODES_PER_PREFILL_REPLICA=$((NUM_PREFILL_NODES / NUM_PREFILL_REPLICAS))
4141
export NODES_PER_DECODE_REPLICA=$((NUM_DECODE_NODES / NUM_DECODE_REPLICAS))
4242
export PREFILL_PORT={{ prefill_port }}
4343
export DECODE_PORT={{ decode_port }}
44+
export NIXL_ABORT_TIMEOUT={{ kv_transport_config.abort_timeout_seconds }}
45+
export VLLM_NIXL_ABORT_REQUEST_TIMEOUT={{ kv_transport_config.abort_timeout_seconds }}
4446
{%- else -%}
4547
export BACKEND_PORT={{ backend_port }}
4648
export INFERENCE_ENABLE_EXPERT_PARALLEL={{ "1" if inference_enable_expert_parallel else "0" }}
@@ -223,15 +225,15 @@ if [ "$SLURM_PROCID" -lt "$NUM_INFER_NODES" ]; then
223225
export VLLM_NIXL_SIDE_CHANNEL_PORT=5600
224226
225227
{%- if kv_offload %}
226-
PREFILL_KV_CFG='"'"'{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":1}},{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"cpu_bytes_to_use":{{ kv_offload_cpu_bytes }}}}]}}'"'"'
228+
PREFILL_KV_CFG='"'"'{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":{{ kv_transport_config.num_threads }},"bidirectional_kv_xfer":{{ "true" if kv_transport_config.enable_bidirectional else "false" }},"kv_recompute_threshold":{{ kv_transport_config.kv_recompute_threshold }}}},{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"cpu_bytes_to_use":{{ kv_offload_cpu_bytes }}}}]}}'"'"'
227229
{%- else %}
228-
PREFILL_KV_CFG='"'"'{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":1}}'"'"'
230+
PREFILL_KV_CFG='"'"'{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":{{ kv_transport_config.num_threads }},"bidirectional_kv_xfer":{{ "true" if kv_transport_config.enable_bidirectional else "false" }},"kv_recompute_threshold":{{ kv_transport_config.kv_recompute_threshold }}}}'"'"'
229231
{%- endif %}
230232
231233
{%- if kv_offload %}
232-
DECODE_KV_CFG='"'"'{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":1}},{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"cpu_bytes_to_use":{{ kv_offload_cpu_bytes }}}}]}}'"'"'
234+
DECODE_KV_CFG='"'"'{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":{{ kv_transport_config.num_threads }},"bidirectional_kv_xfer":{{ "true" if kv_transport_config.enable_bidirectional else "false" }},"kv_recompute_threshold":{{ kv_transport_config.kv_recompute_threshold }}}},{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"cpu_bytes_to_use":{{ kv_offload_cpu_bytes }}}}]}}'"'"'
233235
{%- else %}
234-
DECODE_KV_CFG='"'"'{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":1}}'"'"'
236+
DECODE_KV_CFG='"'"'{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":{{ kv_transport_config.num_threads }},"bidirectional_kv_xfer":{{ "true" if kv_transport_config.enable_bidirectional else "false" }},"kv_recompute_threshold":{{ kv_transport_config.kv_recompute_threshold }}}}'"'"'
235237
{%- endif %}
236238
DECODE_COMPILE_CFG='"'"'{"cudagraph_mode":"FULL_DECODE_ONLY"}'"'"'
237239
@@ -296,12 +298,13 @@ if [ "$SLURM_PROCID" -lt "$NUM_INFER_NODES" ]; then
296298
REPLICA_ROUTER_ARGS=$(echo "$ALL_ROUTER_ARGS" | cut -d"|" -f$((REPLICA_IDX + 1)))
297299
298300
vllm-router \
299-
--policy consistent_hash \
301+
--policy {{ router_policy }} \
300302
--vllm-pd-disaggregation \
301303
$REPLICA_ROUTER_ARGS \
302304
--host 0.0.0.0 \
303305
--port $ROUTER_PORT \
304306
--intra-node-data-parallel-size {{ dp_per_node }} \
307+
--pd-kv-cache-ttl-secs {{ kv_transport_config.router_cache_ttl_seconds if kv_transport_config.enable_bidirectional else 0 }} \
305308
--worker-startup-timeout-secs 4200 \
306309
>> $ROUTER_LOG 2>&1 &
307310
fi

0 commit comments

Comments
 (0)