Skip to content

Commit 1a991b4

Browse files
author
niushengxiao
committed
feat: refactor kv buffer
1 parent 40d8fdc commit 1a991b4

File tree

17 files changed

+545
-165
lines changed

17 files changed

+545
-165
lines changed

lightllm/common/basemodel/attention/base_att.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,7 @@ def _find_layer_index(
4141
self, k: torch.Tensor, v: torch.Tensor, att_state: Union["BasePrefillAttState", "BaseDecodeAttState"]
4242
) -> int:
4343
kv_buffer = att_state.infer_state.mem_manager.kv_buffer
44-
layer_count = len(kv_buffer)
45-
find_dict = {kv_buffer[i].data_ptr(): i for i in range(layer_count)}
46-
key = min(k.data_ptr(), v.data_ptr())
47-
assert key in find_dict
48-
return find_dict[key]
44+
return kv_buffer.find_layer_index(k, v)
4945

5046

5147
@dataclass

lightllm/common/kv_cache_mem_manager/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from .kv_buffer.kv_buffer import KvBuffer
2+
from .kv_buffer.quant_kv_buffer import QuantKvBuffer, PPLInt4QuantKvBuffer, PPLInt8QuantKvBuffer
13
from .mem_manager import MemoryManager, ReadOnlyStaticsMemoryManager
24
from .ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
35
from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
@@ -6,8 +8,18 @@
68
from .fp8_per_token_group_quant_deepseek3_2mem_manager import FP8PerTokenGroupQuantDeepseek3_2MemoryManager
79
from .fp8_static_per_head_quant_mem_manager import FP8StaticPerHeadQuantMemManager
810
from .fp8_static_per_tensor_quant_mem_manager import FP8StaticPerTensorQuantMemManager
11+
from .kv_buffer.kv_buffer_adapter import KvBufferAdapter
12+
from .kv_buffer.hybrid_kv_buffer import HybridKvBuffer
13+
from .kv_buffer.hybrid_kv_buffer_adapter import HybridKvBufferAdapter
914

1015
__all__ = [
16+
"KvBuffer",
17+
"QuantKvBuffer",
18+
"PPLInt4QuantKvBuffer",
19+
"PPLInt8QuantKvBuffer",
20+
"HybridKvBuffer",
21+
"KvBufferAdapter",
22+
"HybridKvBufferAdapter",
1123
"MemoryManager",
1224
"ReadOnlyStaticsMemoryManager",
1325
"PPLINT4KVMemoryManager",

lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
import os
33
import torch.distributed as dist
44
from lightllm.server.pd_io_struct import KVMoveTask
5+
from .kv_buffer.kv_buffer import KvBuffer
56
from .mem_manager import MemoryManager
67
from typing import List, Union, Any
78
from lightllm.utils.log_utils import init_logger
89
from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
910
from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node
1011
from lightllm.distributed.pynccl import PyNcclCommunicator
11-
from lightllm.common.kv_trans_kernel.nixl_kv_trans import mla_page_io
1212

1313

1414
logger = init_logger(__name__)
@@ -45,7 +45,10 @@ def get_cell_size(self):
4545
return self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)
4646

4747
def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
48-
self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda")
48+
self.kv_buffer = KvBuffer(
49+
torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda"),
50+
head_num=head_num,
51+
)
4952

5053
def alloc_kv_move_buffer(self, max_req_total_len):
5154
self.kv_move_buffer = torch.empty(
@@ -77,11 +80,8 @@ def write_mem_to_page_kv_move_buffer(
7780
pin_mem_indexes.numpy()[:] = mem_indexes
7881
mem_indexes_gpu = pin_mem_indexes.cuda(non_blocking=True)
7982
dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)]
80-
mla_page_io(
81-
mem_indexes=mem_indexes_gpu,
82-
page_tensor=cur_page,
83-
kv_buffer=dp_mems[0].kv_buffer,
84-
mode="write",
83+
dp_mems[0].kv_buffer_adapter.write_to_page_buffer(
84+
mem_indexes=mem_indexes_gpu, page_tensor=cur_page, is_mla=True
8585
)
8686
return
8787

@@ -99,12 +99,7 @@ def read_page_kv_move_buffer_to_mem(
9999
mem_indexes_gpu = pin_mem_indexes.cuda(non_blocking=True)
100100
dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)]
101101
for mem in dp_mems:
102-
mla_page_io(
103-
mem_indexes=mem_indexes_gpu,
104-
page_tensor=cur_page,
105-
kv_buffer=mem.kv_buffer,
106-
mode="read",
107-
)
102+
mem.kv_buffer_adapter.read_from_page_buffer(mem_indexes=mem_indexes_gpu, page_tensor=cur_page, is_mla=True)
108103

109104
def send_to_decode_node(
110105
self,
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from .kv_buffer import KvBuffer
2+
from .quant_kv_buffer import QuantKvBuffer, PPLInt4QuantKvBuffer, PPLInt8QuantKvBuffer
3+
4+
__all__ = [
5+
"KvBuffer",
6+
"QuantKvBuffer",
7+
"PPLInt4QuantKvBuffer",
8+
"PPLInt8QuantKvBuffer",
9+
]
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from typing import Any, List, Optional
2+
3+
import torch
4+
5+
from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager
6+
7+
from .kv_buffer import KvBuffer
8+
9+
10+
class HybridKvBuffer(KvBuffer):
    """KV buffer for hybrid models mixing full-attention and linear-attention layers.

    Full-attention layers keep a regular per-layer kv tensor in ``buffers``
    (entries are ``None`` for layers without kv storage), while
    linear-attention layers keep their state in a ``MambaCacheManager``.
    """

    def __init__(
        self,
        buffers: List[Optional[torch.Tensor]],
        head_num: int,
        full_attention_interval: int,
        mamba_cache_size: int,
        linear_attn_layer_num: int,
        conv_state_dtype: torch.dtype,
        ssm_state_dtype: torch.dtype,
        conv_kernel_size: int,
        num_linear_k_heads: int,
        num_linear_v_heads: int,
        head_linear_k_dim: int,
        head_linear_v_dim: int,
    ):
        # KvBuffer.__init__ is not invoked: storage here is a list of
        # per-layer tensors (self._buffers) rather than one stacked tensor.
        self._buffers = buffers
        self._head_num = head_num
        self._full_attention_interval = full_attention_interval
        self.mamba_cache_manager = MambaCacheManager(
            size=mamba_cache_size,
            layer_num=linear_attn_layer_num,
            conv_state_dtype=conv_state_dtype,
            ssm_state_dtype=ssm_state_dtype,
            conv_kernel_size=conv_kernel_size,
            num_linear_k_heads=num_linear_k_heads,
            num_linear_v_heads=num_linear_v_heads,
            head_linear_k_dim=head_linear_k_dim,
            head_linear_v_dim=head_linear_v_dim,
        )

    def create_adapter(self):
        # Lazy import to break the circular dependency with the adapter module.
        from .hybrid_kv_buffer_adapter import HybridKvBufferAdapter

        return HybridKvBufferAdapter(self)

    def _checked_layer_buffer(self, layer_index: int) -> torch.Tensor:
        """Return the kv tensor of ``layer_index`` or raise if that layer has none."""
        layer_buffer = self._buffers[layer_index]
        if layer_buffer is None:
            raise RuntimeError(f"layer {layer_index} does not have kv cache storage")
        return layer_buffer

    def get_mamba_cache(self, layer_idx: int):
        # Subtracting the count of preceding full-attention layers maps the
        # global layer index onto the linear-attention layer index.
        # (assumes one full-attention layer every _full_attention_interval
        # layers -- TODO(review): confirm against the model config)
        layer_idx_in_linear = layer_idx - (layer_idx // self._full_attention_interval)
        return self.mamba_cache_manager.get_mamba_cache(layer_idx_in_linear)

    def __getitem__(self, item):
        # Index into the per-layer buffer list; may yield None for
        # layers without kv storage.
        return self._buffers[item]

    def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor) -> None:
        """Scatter ``kv`` rows into this layer's cache at ``mem_index``."""
        from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv

        destindex_copy_kv(kv, mem_index, self._checked_layer_buffer(layer_index))

    def get_att_input_params(self, layer_index: int) -> Any:
        """Return the (k, v) views of one layer's cache: k heads first, then v heads."""
        per_layer = self._checked_layer_buffer(layer_index)
        return per_layer[:, : self._head_num, :], per_layer[:, self._head_num :, :]

    def get_index_kv_buffer(self, index: Any) -> dict:
        """Export the selected rows of every layer; None for layers without storage."""
        selected = []
        for per_layer in self._buffers:
            selected.append(None if per_layer is None else per_layer[index])
        return {"kv_buffer": selected}

    def load_index_kv_buffer(self, index: Any, payload: dict) -> None:
        """Import previously exported rows; None payload entries are skipped."""
        for layer_index, layer_payload in enumerate(payload["kv_buffer"]):
            if layer_payload is not None:
                self._checked_layer_buffer(layer_index)[index].copy_(layer_payload)

    def get_device(self) -> int:
        # Report the device of the first layer that actually owns a tensor.
        for candidate in self._buffers:
            if candidate is not None:
                return candidate.get_device()
        raise RuntimeError("HybridKvBuffer does not contain any kv cache tensor")

    def find_layer_index(self, k: torch.Tensor, v: torch.Tensor) -> int:
        """Locate which layer the given k/v views belong to by base data pointer."""
        ptr_to_layer = {}
        for layer_index, per_layer in enumerate(self._buffers):
            if per_layer is not None:
                ptr_to_layer[per_layer.data_ptr()] = layer_index
        target_ptr = min(k.data_ptr(), v.data_ptr())
        assert target_ptr in ptr_to_layer
        return ptr_to_layer[target_ptr]
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from typing import Optional
2+
3+
import torch
4+
5+
from .hybrid_kv_buffer import HybridKvBuffer
6+
from .kv_buffer_adapter import KvBufferAdapter
7+
8+
9+
class HybridKvBufferAdapter(KvBufferAdapter):
    """Adapter for HybridKvBuffer that disables every paged/offload IO path.

    Every override below fails fast with NotImplementedError: none of the
    page-buffer, cpu-cache or cross-dp transfer flows exposed by the
    adapter interface are supported for hybrid kv buffers.
    """

    def __init__(self, kv_buffer: HybridKvBuffer):
        super().__init__(kv_buffer)

    # Not supported: paged kv write for hybrid buffers.
    def write_to_page_buffer(
        self, mem_indexes: torch.Tensor, page_tensor: torch.Tensor, tp_index: int, tp_world_size: int
    ) -> None:
        raise NotImplementedError(f"{self.__class__.__name__} does not support paged kv write")

    # Not supported: paged kv read for hybrid buffers.
    def read_from_page_buffer(
        self, mem_indexes: torch.Tensor, page_tensor: torch.Tensor, tp_index: int, tp_world_size: int
    ) -> None:
        raise NotImplementedError(f"{self.__class__.__name__} does not support paged kv read")

    # Not supported: MLA-layout paged kv write.
    def write_from_mla_page_buffer(self, mem_indexes: torch.Tensor, page_tensor: torch.Tensor) -> None:
        raise NotImplementedError(f"{self.__class__.__name__} does not support mla paged kv write")

    # Not supported: MLA-layout paged kv read.
    def read_from_mla_page_buffer(self, mem_indexes: torch.Tensor, page_tensor: torch.Tensor) -> None:
        raise NotImplementedError(f"{self.__class__.__name__} does not support mla paged kv read")

    # Not supported: loading kv pages back from the cpu-side cache.
    def load_from_cpu_cache(
        self,
        gpu_mem_indexes: torch.Tensor,
        cpu_kv_cache: torch.Tensor,
        cpu_kv_cache_scale: Optional[torch.Tensor],
        page_indexes: torch.Tensor,
        tp_index: int,
        tp_world_size: int,
        grid_num: int,
    ) -> None:
        raise NotImplementedError(f"{self.__class__.__name__} does not support cpu cache load")

    # Not supported: offloading kv pages to the cpu-side cache.
    def offload_to_cpu_cache(
        self,
        token_indexes: torch.Tensor,
        cpu_kv_cache: torch.Tensor,
        cpu_kv_cache_scale: Optional[torch.Tensor],
        page_indexes: torch.Tensor,
        page_readies: torch.Tensor,
        tp_index: int,
        tp_world_size: int,
        grid_num: int,
    ) -> None:
        raise NotImplementedError(f"{self.__class__.__name__} does not support cpu cache offload")

    # Not supported: copying kv entries across dp ranks.
    # NOTE(review): mem_managers is unannotated here -- presumably a list of
    # MemoryManager instances; confirm against the base adapter signature.
    def copy_kv_from_other_dp_ranks(
        self,
        mem_managers,
        move_token_indexes: torch.Tensor,
        token_dp_indexes: torch.Tensor,
        mem_indexes: torch.Tensor,
        dp_size_in_node: int,
        rank_in_dp: int,
    ) -> None:
        raise NotImplementedError(f"{self.__class__.__name__} does not support dp kv copy")
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from typing import Any, Optional
2+
3+
import torch
4+
5+
6+
class KvBuffer:
7+
"""KV cache 的数据封装类。
8+
9+
这个类的职责是管理 kv buffer 本身的存储与访问语义,关注点是
10+
"这块缓存里存了什么、怎么按层读写、怎么导入导出"。
11+
因此这里的方法应当主要围绕 kv buffer 自身的数据操作展开,
12+
不承载 page io、cpu cache、dp 传输这类业务流程逻辑。
13+
"""
14+
15+
def __init__(self, buffer: torch.Tensor, head_num: int):
16+
self._buffer = buffer
17+
self._head_num = head_num
18+
19+
def create_adapter(self):
20+
# 业务逻辑由 adapter 承接,KvBuffer 只负责提供底层存储对象。
21+
from .kv_buffer_adapter import KvBufferAdapter
22+
23+
return KvBufferAdapter(self)
24+
25+
def __getitem__(self, item):
26+
return self._buffer[item]
27+
28+
@property
29+
def shape(self):
30+
return self._buffer.shape
31+
32+
def get_storage_tensor(self) -> torch.Tensor:
33+
return self._buffer
34+
35+
def get_storage_data_ptr(self) -> int:
36+
return self._buffer.data_ptr()
37+
38+
def get_scale_buffer(self) -> Optional[torch.Tensor]:
39+
return None
40+
41+
def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor) -> None:
42+
from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv
43+
44+
destindex_copy_kv(kv, mem_index, self._buffer[layer_index])
45+
46+
def get_att_input_params(self, layer_index: int) -> Any:
47+
layer_buffer = self._buffer[layer_index]
48+
k = layer_buffer[:, : self._head_num, :]
49+
v = layer_buffer[:, self._head_num :, :]
50+
return k, v
51+
52+
def get_index_kv_buffer(self, index: Any) -> dict:
53+
return {"kv_buffer": self._buffer[:, index]}
54+
55+
def load_index_kv_buffer(self, index: Any, payload: dict) -> None:
56+
self._buffer[:, index].copy_(payload["kv_buffer"])
57+
58+
def get_device(self) -> int:
59+
return self._buffer.get_device()
60+
61+
def find_layer_index(self, k: torch.Tensor, v: torch.Tensor) -> int:
62+
key = min(k.data_ptr(), v.data_ptr())
63+
find_dict = {self._buffer[i].data_ptr(): i for i in range(len(self._buffer))}
64+
assert key in find_dict
65+
return find_dict[key]

0 commit comments

Comments
 (0)