-
Notifications
You must be signed in to change notification settings - Fork 318
Expand file tree
/
Copy pathdisk_cache_worker.py
More file actions
192 lines (163 loc) · 7.4 KB
/
disk_cache_worker.py
File metadata and controls
192 lines (163 loc) · 7.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import os
import tempfile
import time
import math
from dataclasses import dataclass
from typing import List, Optional
import torch
from lightllm.utils.envs_utils import get_disk_cache_index_prefix, get_unique_server_name
from lightllm.utils.log_utils import init_logger
from .cpu_cache_client import CpuKvCacheClient
logger = init_logger(__name__)
# light_mem supplies the native disk-cache backend (PyLocalCacheService and the
# PyState task-state enum). It is a hard requirement for this module, so fail
# fast at import time with installation guidance rather than erroring later.
try:
    from light_mem import PyLocalCacheService, PyState
except ImportError as e:
    logger.error(
        "Failed to import LightMem library. Please install it first.\n"
        "You can install it by building from source: https://github.com/ModelTC/LightMem"
    )
    # Re-raise with an explicit message, chaining the original ImportError.
    raise ImportError("LightMem library is required for disk cache functionality") from e
@dataclass
class _PagePayload:
    """Pairs a CPU KV-cache page index with the hash key that addresses it on disk.

    Built in ``DiskCacheWorker._gather_offload_payloads`` from the CPU cache
    client's page items; consumed by ``_persist_pages_to_disk``.
    """

    # Index of the page in the CPU KV cache (used to build the kv_page_indexer
    # tensor passed to the cache service).
    index: int
    # Hash key taken from the page item; passed to the cache service as one of
    # the `hash_128s` identifying the page's content in the disk cache index.
    hash_key: int
class DiskCacheWorker:
    """Background worker that offloads CPU KV cache pages to disk via light_mem.

    Wraps a ``PyLocalCacheService`` instance that persists pages of the CPU KV
    cache tensor to a local cache file (optionally indexed through Redis), and
    exposes helpers to query and load pages back from disk.
    """

    def __init__(
        self,
        disk_cache_storage_size: float,
        cpu_cache_client: CpuKvCacheClient,
        disk_cache_dir: Optional[str] = None,
        redis_endpoint: str = "",
        num_node_in_disk_cache: int = 1,
    ):
        """Initialize the disk cache service backing store.

        Args:
            disk_cache_storage_size: Disk cache capacity in GiB; must be > 0.
            cpu_cache_client: Client owning the CPU KV cache tensor and the
                page bookkeeping (lock, page items, deref operations).
            disk_cache_dir: Directory for the cache file; defaults to a
                per-server-name directory under the system temp dir.
            redis_endpoint: Optional Redis index endpoint; when set, sharding
                scales with ``num_node_in_disk_cache``.
            num_node_in_disk_cache: Number of nodes sharing the disk cache
                index; must be >= 1.

        Raises:
            ValueError: If ``num_node_in_disk_cache`` is not >= 1.
        """
        self.cpu_cache_client = cpu_cache_client
        # NOTE(review): set but never read within this class — presumably
        # consumed elsewhere; confirm before removing.
        self._pages_all_idle = False
        assert disk_cache_storage_size > 0
        # GiB -> bytes.
        storage_size = int(disk_cache_storage_size * (1024 ** 3))
        # num_shard is tied to KVCACHE_MAX_BLOCK_SIZE; the values below assume
        # the default KVCACHE_MAX_BLOCK_SIZE of 64MB.
        if num_node_in_disk_cache <= 0:
            raise ValueError(f"num_node_in_disk_cache must be >= 1, got {num_node_in_disk_cache}")
        num_shard = 64 * num_node_in_disk_cache if redis_endpoint else 64
        num_worker = 48
        # When reads and writes run concurrently: 16 threads are reserved for
        # writing so that up to 32 remain available for reading.
        max_concurrent_write_tasks = 16
        cache_dir = disk_cache_dir
        if not cache_dir:
            cache_dir = os.path.join(tempfile.gettempdir(), f"lightllm_disk_cache_{get_unique_server_name()}")
        os.makedirs(cache_dir, exist_ok=True)
        cache_file = os.path.join(cache_dir, "cache_file")
        self.max_concurrent_write_tasks = max_concurrent_write_tasks
        # Flatten the CPU KV cache to a page-major uint8 view for the service.
        self._page_major_tensor = self._prepare_tensor(cpu_cache_client.cpu_kv_cache_tensor)
        self.service = PyLocalCacheService(
            kvcache_tensor=self._page_major_tensor,
            file=cache_file,
            storage_size=storage_size,
            num_shard=num_shard,
            num_worker=num_worker,
            index_endpoint=redis_endpoint,
            index_prefix=get_disk_cache_index_prefix(),
            bandwidth_log=True,
        )
        logger.info(
            "disk cache worker initialized: dir=%s size_bytes=%d shards=%d workers=%d pages_per_block=%d",
            cache_dir,
            storage_size,
            num_shard,
            num_worker,
            self.service._n,
        )

    def _prepare_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        """Flatten all non-page dims and reinterpret the storage as raw bytes.

        Keeps dim 0 (pages) and views the rest as uint8, which is the layout
        the cache service expects.
        """
        return tensor.flatten(1).view(dtype=torch.uint8)

    def run(self) -> None:
        """Poll forever: every 100ms, persist any page groups flagged for offload.

        Never returns; intended to run as the worker's main loop.
        """
        while True:
            time.sleep(0.1)
            payload_groups = self._gather_offload_payloads()
            if not payload_groups:
                continue
            for payloads in payload_groups:
                if not payloads:
                    continue
                self._persist_pages_to_disk(payloads)

    def _gather_offload_payloads(self) -> List[List[_PagePayload]]:
        """Collect pages marked for offloading, grouped as the client grouped them.

        Returns:
            One list of ``_PagePayload`` per group returned by
            ``get_pages_to_offloading``; empty list when nothing is pending.
        """
        # Only the index retrieval is done under the client lock.
        self.cpu_cache_client.lock.acquire_sleep1ms()
        grouped_indexes = self.cpu_cache_client.get_pages_to_offloading()
        self.cpu_cache_client.lock.release()
        payload_groups: List[List[_PagePayload]] = []
        if not grouped_indexes:
            return payload_groups
        # NOTE(review): page_items is read after the lock is released — assumed
        # safe because the returned pages are held (deref'ed later); confirm.
        page_items = self.cpu_cache_client.page_items.linked_items
        for group in grouped_indexes:
            payloads: List[_PagePayload] = []
            for page_index in group:
                page_item = page_items[page_index]
                payloads.append(_PagePayload(index=page_index, hash_key=int(page_item.hash_key)))
            payload_groups.append(payloads)
        return payload_groups

    # Write page data to disk.
    def _persist_pages_to_disk(self, payloads: List[_PagePayload]) -> None:
        """Write the given pages to the disk cache and release their CPU refs.

        Pages already present on disk are dereferenced immediately; the rest
        are dereferenced once the write task reports the data is safe.
        """
        if not payloads:
            return
        page_indexes = [payload.index for payload in payloads]
        hashs = [payload.hash_key for payload in payloads]
        if not page_indexes:
            return
        kv_indexer = torch.tensor(page_indexes, dtype=torch.int32, device="cpu")
        query_result = self.service.query(hashs)
        if not all(query_result):
            # At least one page is missing from disk: throttle write concurrency
            # so that read operations keep enough worker threads.
            while (
                self.service.active_threads("r") and self.service.active_threads("w") >= self.max_concurrent_write_tasks
            ):
                time.sleep(0.001)
            task = self.service.create(hash_128s=hashs, kv_page_indexer=kv_indexer, mode="w")
            # Immediately release pages that are already in the disk cache.
            if task.page_already_list:
                self.cpu_cache_client.lock.acquire_sleep1ms()
                self.cpu_cache_client.deref_pages(page_list=task.page_already_list)
                self.cpu_cache_client.lock.release()
            # Waiting ends as soon as the data is safe; no need to wait for the
            # full write to complete.
            while not task.data_safe():
                time.sleep(0.001)
            # Release the remaining pages that actually needed writing.
            remining_indexes = list(set(page_indexes) - set(task.page_already_list))
            if remining_indexes:
                self.cpu_cache_client.lock.acquire_sleep1ms()
                self.cpu_cache_client.deref_pages(page_list=remining_indexes)
                self.cpu_cache_client.lock.release()
        else:
            # Everything is already on disk: just release all CPU page refs.
            self.cpu_cache_client.lock.acquire_sleep1ms()
            self.cpu_cache_client.deref_pages(page_list=page_indexes)
            self.cpu_cache_client.lock.release()

    def query_loadable_pages(self, hashs: List[int], start_pos: int) -> int:
        """Return the length of the longest prefix, starting at ``start_pos``,
        that can be loaded from the disk cache.

        Returns:
            loadable_len: Number of positions loadable from ``start_pos``;
            0 when ``hashs`` is empty or ``start_pos`` is out of range.
        """
        if not hashs or start_pos < 0 or start_pos >= len(hashs):
            return 0
        # NOTE(review): query() appears to return one hit-flag per block of
        # `n` pages (indexing below is by block) — confirm against light_mem.
        query_result = self.service.query(hashs)
        n = self.service._n  # pages per disk block
        start_block = start_pos // n
        try:
            # First block at/after start_pos that is missing from disk.
            first_false_idx = start_block + query_result[start_block:].index(False)
        except ValueError:
            # No missing block: everything through the end is loadable.
            return len(hashs) - start_pos
        first_missing_pos = first_false_idx * n
        # Clamp: start_pos may sit inside an already-missing block.
        return max(0, first_missing_pos - start_pos)

    # Load page data from disk into (CPU) memory.
    def load_pages(self, hashs: List[int], page_indexes: List[int], start_pos: int = 0) -> bool:
        """Synchronously read pages from disk into the CPU KV cache pages.

        Args:
            hashs: Hash keys of the pages to load; must match ``page_indexes``
                in length.
            page_indexes: Destination CPU page indexes.
            start_pos: Offset at which loading starts (forwarded to the service).

        Returns:
            True only when every page of the read task finished successfully;
            False on invalid arguments or any non-finished state.
        """
        if not hashs or not page_indexes or len(hashs) != len(page_indexes):
            return False
        if start_pos < 0 or start_pos >= len(hashs):
            return False
        # Disabled for now: optionally skip loads while writes are in flight.
        # if self.service.active_threads("w") > 0:
        #     logger.warning("disk cache worker is busy writing, skip load_pages")
        #     return False
        kv_indexer = torch.tensor(page_indexes, dtype=torch.int32, device="cpu")
        task = self.service.create(hash_128s=hashs, kv_page_indexer=kv_indexer, mode="r", start_pos=start_pos)
        # Busy-wait until the read task completes.
        while not task.ready():
            time.sleep(0.001)
        return all(state == PyState.Finished for state in task.state())