Skip to content

Commit 1b372bd

Browse files
authored
Merge pull request #440 from InfiniTensor/issue/340-pd
feat: support PD disaggregation
2 parents 05585e9 + dc9f6df commit 1b372bd

29 files changed

Lines changed: 3863 additions & 369 deletions

csrc/engine/infer_engine.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,4 +196,22 @@ void InferEngine::reset_cache(const cache::CacheConfig *new_config) {
196196
this->compile();
197197
}
198198

199+
std::vector<std::vector<infinicore::Tensor>> InferEngine::get_kv_cache() {
200+
std::vector<std::vector<infinicore::Tensor>> kv_cache_list;
201+
if (workers_.empty()) {
202+
throw std::runtime_error("InferEngine::get_cache_vec: no workers");
203+
}
204+
205+
kv_cache_list.reserve(workers_.size());
206+
for (auto &worker : workers_) {
207+
kv_cache_list.push_back(std::move(worker->get_kv_cache()));
208+
}
209+
210+
for (auto &worker : workers_) {
211+
worker->wait();
212+
}
213+
214+
return kv_cache_list;
215+
}
216+
199217
} // namespace infinilm::engine

csrc/engine/infer_engine.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ class InferEngine {
4949

5050
void reset_cache(const cache::CacheConfig *new_config);
5151

52+
std::vector<std::vector<infinicore::Tensor>> get_kv_cache();
53+
5254
~InferEngine();
5355

5456
const distributed::DistConfig &get_dist_config() const;

csrc/engine/rank_worker.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,22 @@ void RankWorker::reset_cache(const cache::CacheConfig *new_config) {
208208
cv_.notify_all();
209209
}
210210

211+
//------------------------------------------------------
212+
// get kv cache
213+
//------------------------------------------------------
214+
std::vector<infinicore::Tensor> RankWorker::get_kv_cache() {
215+
std::unique_lock<std::mutex> lk(mutex_);
216+
cv_.wait(lk, [&] { return init_done_ || should_exit_; });
217+
218+
if (should_exit_) {
219+
throw std::runtime_error("RankWorker stopped; cannot get_cache_vec");
220+
}
221+
222+
ASSERT(forward_context_.kv_cache_vec.size() > 0 && "RankWorker::get_kv_cache(): kv_cache_vec is empty");
223+
224+
return forward_context_.kv_cache_vec;
225+
}
226+
211227
//------------------------------------------------------
212228
// close -- request shutdown and join thread
213229
//------------------------------------------------------

csrc/engine/rank_worker.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ class RankWorker {
9898
// Reset the internal cache with a new configuration
9999
void reset_cache(const cache::CacheConfig *new_config);
100100

101+
std::vector<infinicore::Tensor> get_kv_cache();
102+
101103
// Compile the model graph if enabled.
102104
void compile();
103105

csrc/pybind11/engine/engine.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,17 @@ inline void bind_infer_engine(py::module &m) {
102102
.def("process_weights_after_loading", &InferEngine::process_weights_after_loading, "Process the weights after loading on all workers (e.g., for quantization)")
103103
.def(
104104
"forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output {
105+
// IMPORTANT: Release the GIL before calling forward() to allow other Python threads
106+
// to run concurrently during inference (which may block for a long time).
107+
// Do NOT remove this — without it, the GIL is held throughout inference and will
108+
// deadlock or stall any other Python thread (e.g., request handling, scheduling).
105109
py::gil_scoped_release release;
106110
return self.forward(input);
107111
},
108112
"Run inference on all ranks with arbitrary arguments")
109113
.def(
110114
"reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
115+
.def("get_kv_cache", &InferEngine::get_kv_cache, "Get per-rank kv cache list")
111116
.def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr<cache::CacheConfig> {
112117
auto cfg = self.get_cache_config();
113118
return cfg ? std::shared_ptr<cache::CacheConfig>(cfg->unique_copy()) : nullptr; })

python/infinilm/base_config.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ def __init__(self):
9999
self.port = self.args.port
100100
self.endpoint = self.args.endpoint
101101
self.ignore_eos = self.args.ignore_eos
102+
# PD separation (KV transfer)
103+
self.kv_transfer_config = self.args.kv_transfer_config
102104

103105
# Multimodal parameters
104106
self.image = self.args.image
@@ -268,6 +270,19 @@ def _add_common_args(self):
268270
help="image path for multimodal models",
269271
)
270272

273+
# ---- PD separation arguments ----
274+
self.parser.add_argument(
275+
"--kv-transfer-config",
276+
type=str,
277+
default=None,
278+
help=(
279+
"JSON object for KVTransferConfig. Allowed keys only: "
280+
"kv_connector, engine_id, kv_role, kv_connector_extra_config (omit any for defaults). "
281+
"Example: "
282+
'\'{"kv_connector":"MooncakeConnector","kv_role":"kv_consumer"}\''
283+
),
284+
)
285+
271286
def get_device_str(self, device):
272287
"""Convert device name to backend string (cuda/cpu/musa/mlu)"""
273288
DEVICE_STR_MAP = {

python/infinilm/config/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .engine_config import EngineConfig
2+
from .kv_transfer import KVTransferConfig
3+
4+
__all__ = ["EngineConfig", "KVTransferConfig"]
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from dataclasses import dataclass
2+
from typing import Optional
3+
from infinilm.config.kv_transfer import KVTransferConfig
4+
5+
6+
@dataclass
7+
class EngineConfig:
8+
"""Configuration for LLM Engine.
9+
10+
Attributes:
11+
model_path: Path to the model directory.
12+
device: Device type string ('cpu', 'cuda', 'mlu', etc.).
13+
dtype: Data type string ('float16', 'bfloat16', 'float32').
14+
tensor_parallel_size: Number of devices for tensor parallelism.
15+
cache_type: Cache type ('paged' or 'static').
16+
max_batch_size: Maximum batch size for inference (only for paged cache).
17+
max_tokens: Default maximum tokens to generate.
18+
num_blocks: Number of KV cache blocks (only for paged cache).
19+
block_size: Size of each KV cache block (only for paged cache).
20+
max_cache_len: Maximum sequence length (only for static cache).
21+
temperature: Default sampling temperature.
22+
top_p: Default top-p sampling parameter.
23+
top_k: Default top-k sampling parameter.
24+
enable_graph: Whether to enable graph compiling.
25+
attn_backend: Attention backend to use ('default', 'flash-attn').
26+
skip_load: Whether to skip loading model weights (for testing).
27+
"""
28+
29+
model_path: str
30+
device: str = "cuda"
31+
dtype: str = "float16"
32+
tensor_parallel_size: int = 1
33+
cache_type: str = "paged" # "paged" or "static"
34+
max_batch_size: int = 16
35+
max_tokens: int = 4096
36+
num_blocks: int = 512
37+
block_size: int = 256
38+
max_cache_len: int = 4096
39+
temperature: float = 1.0
40+
top_p: float = 0.8
41+
top_k: int = 1
42+
enable_graph: bool = False
43+
attn_backend: str = "default"
44+
skip_load: bool = False
45+
kv_transfer_config: Optional[KVTransferConfig] = None
46+
47+
def __post_init__(self) -> None:
48+
if (
49+
self.kv_transfer_config is not None
50+
and self.kv_transfer_config.kv_connector
51+
and self.cache_type != "paged"
52+
):
53+
raise ValueError("kv_transfer_config requires cache_type='paged'")
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
# SPDX-FileCopyrightText: Copyright 2026 InfiniLM Contributors
4+
5+
import uuid
6+
from dataclasses import dataclass, field
7+
from typing import Optional
8+
import os
9+
10+
KV_ROLE_CHOICES = frozenset({"kv_producer", "kv_consumer"})
11+
12+
13+
@dataclass
14+
class KVTransferConfig:
15+
"""Configuration for KV cache transfer in prefill/decode (P/D) separation.
16+
17+
Attributes:
18+
kv_connector: Name of the KV connector to use (e.g. 'MooncakeConnector').
19+
None disables KV transfer.
20+
kv_role: Role of this node: 'kv_producer' (prefill) or 'kv_consumer' (decode).
21+
engine_id: Unique identifier for this engine instance used in KV transfers.
22+
Auto-generated (UUID) if not provided.
23+
kv_connector_extra_config: Extra configuration dict passed to the connector backend.
24+
"""
25+
26+
kv_connector: Optional[str] = None
27+
kv_role: Optional[str] = None
28+
engine_id: Optional[str] = None
29+
kv_connector_extra_config: Optional[dict] = field(default_factory=dict)
30+
31+
def __post_init__(self) -> None:
32+
if self.kv_connector is not None and self.kv_role is None:
33+
raise ValueError("Please specify kv_role when kv_connector is set.")
34+
35+
if self.kv_role is not None and self.kv_role not in KV_ROLE_CHOICES:
36+
raise ValueError(
37+
f"Unsupported kv_role: {self.kv_role!r}. "
38+
f"Supported roles are {sorted(KV_ROLE_CHOICES)}"
39+
)
40+
41+
if self.engine_id is None:
42+
self.engine_id = f"{self.kv_role}_" + str(uuid.uuid4())
43+
44+
self.kv_connector_extra_config = dict(self.kv_connector_extra_config or {})
45+
self.kv_connector_extra_config.setdefault("mooncake_protocol", "rdma")
46+
47+
allowed_extra_config_keys = frozenset({"mooncake_protocol", "num_workers"})
48+
unknown_keys = set(self.kv_connector_extra_config.keys()) - allowed_extra_config_keys
49+
if unknown_keys:
50+
raise ValueError(
51+
f"Unsupported kv_connector_extra_config keys: {sorted(unknown_keys)}. "
52+
f"Supported keys are {sorted(allowed_extra_config_keys)}"
53+
)
54+
55+
mooncake_protocol = self.kv_connector_extra_config["mooncake_protocol"]
56+
if mooncake_protocol not in ["tcp", "rdma"]:
57+
raise ValueError(f"only support tcp or rdma, but got {mooncake_protocol}")
58+
59+
if mooncake_protocol == "tcp":
60+
# NOTE: force use tcp for Mooncake
61+
os.environ["MC_FORCE_TCP"] = "true"

python/infinilm/infer_engine.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,3 +389,18 @@ def load_state_dict(self, state_dict, strict=None):
389389

390390
def process_weights_after_loading(self):
391391
super().process_weights_after_loading()
392+
393+
def get_kv_cache(self) -> list[list[infinicore.Tensor]]:
394+
"""
395+
get per-rank kv cache.
396+
"""
397+
kv_cache_list = super().get_kv_cache()
398+
infinicore.sync_device()
399+
400+
result = []
401+
for rank_idx, kv_caches_per_rank in enumerate(kv_cache_list):
402+
result_rank = []
403+
for layer_idx, layer_kv in enumerate(kv_caches_per_rank):
404+
result_rank.append(infinicore.Tensor(layer_kv))
405+
result.append(result_rank)
406+
return result

0 commit comments

Comments
 (0)