Skip to content

Commit aca17e3

Browse files
committed
Update the test tools to match the latest connector and store APIs (k_store / register_kv_caches, UcmPcStoreV1 / UcmKVStoreBaseV1)
1 parent 4907197 commit aca17e3

2 files changed

Lines changed: 61 additions & 53 deletions

File tree

test/test_ucm_connector_save_load.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -254,7 +254,13 @@ def run_once(
254254
load_block_ids=([], []),
255255
dump_block_ids=(dump_hashes, dump_vllm_block_ids),
256256
)
257-
connector.connector.kv_caches = kv_caches
257+
258+
if (
259+
not hasattr(connector.connector, "k_store")
260+
or connector.connector.k_store is None
261+
):
262+
connector.connector.register_kv_caches(kv_caches)
263+
258264
connector.bind_connector_metadata(metadata)
259265

260266
total_bytes = compute_total_bytes(kv_caches, batch_size, is_mla)
@@ -267,7 +273,7 @@ def run_once(
267273

268274
write_bw = (total_bytes / (1024**3)) / write_time if write_time > 0 else 0.0
269275

270-
lookup = connector.connector.store.lookup(dump_hashes)
276+
lookup = connector.connector.k_store.lookup(dump_hashes)
271277
if not all(lookup):
272278
raise RuntimeError("Found missing cache blocks before load test.")
273279

@@ -277,7 +283,7 @@ def run_once(
277283
load_block_ids=(dump_hashes, load_vllm_block_ids),
278284
dump_block_ids=([], []),
279285
)
280-
connector.connector.kv_caches = kv_caches
286+
281287
connector.bind_connector_metadata(load_metadata)
282288

283289
forward_context = build_forward_context(kv_caches, is_mla)
@@ -375,6 +381,8 @@ def broadcast(self, tensor, src):
375381
mla,
376382
)
377383

384+
connector.connector.register_kv_caches(kv_caches)
385+
378386
w_sizes, w_times, w_bws = [], [], []
379387
r_sizes, r_times, r_bws = [], [], []
380388

ucm/store/test/e2e/nfsstore_embed_fetch.py

Lines changed: 50 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,9 @@
3232
import torch
3333

3434
from ucm.store.nfsstore.nfsstore_connector import UcmNfsStore
35+
from ucm.store.pcstore.pcstore_connector_v1 import UcmPcStoreV1
3536
from ucm.store.ucmstore import UcmKVStoreBase
37+
from ucm.store.ucmstore_v1 import UcmKVStoreBaseV1
3638

3739

3840
def setup(
@@ -42,7 +44,7 @@ def setup(
4244
io_size,
4345
transferStreamNumber,
4446
transferIoDirect,
45-
) -> UcmKVStoreBase:
47+
) -> UcmKVStoreBaseV1:
4648
config = {
4749
"storage_backends": storage_backends,
4850
"kv_block_size": block_size,
@@ -51,8 +53,9 @@ def setup(
5153
"io_size": io_size,
5254
"transferStreamNumber": transferStreamNumber,
5355
"transferIoDirect": transferIoDirect,
56+
"unique_id": secrets.token_hex(8),
5457
}
55-
return UcmNfsStore(config)
58+
return UcmPcStoreV1(config)
5659

5760

5861
def make_aligned_tensor(shape, dtype, device, alignment=4096):
@@ -79,64 +82,59 @@ def make_aligned_tensor(shape, dtype, device, alignment=4096):
7982
def make_buffers(
8083
block_number, device_id, batch_size, head_dim, block_len, block_layer, num_head, kv
8184
):
82-
hashes = [secrets.token_hex(16) for _ in range(block_number)]
83-
kv_caches = {}
84-
for i in range(block_layer):
85-
kv_caches[i] = make_aligned_tensor(
85+
hashes = [secrets.token_bytes(16) for _ in range(block_number)]
86+
kvcaches = {}
87+
for layer_id in range(block_layer):
88+
kvcaches[layer_id] = make_aligned_tensor(
8689
[kv, block_number, block_len, num_head, head_dim],
87-
dtype=torch.float16,
90+
dtype=torch.bfloat16,
8891
device=f"cuda:{device_id}",
8992
)
90-
return hashes, kv_caches
93+
kvcaches[layer_id].random_()
94+
return hashes, kvcaches
9195

9296

93-
def store_all_hashes(hashes: List[str]):
97+
def store_all_hashes(hashes: List[bytes]):
9498
file_path = os.path.join(os.path.dirname(__file__), "kvcache_block_hashes.txt")
9599
with open(file_path, "w", encoding="utf-8") as f:
96100
for h in hashes:
97-
f.write(h + "\n")
101+
f.write(h.hex() + "\n")
98102

99103

100-
def load_hashes_from_file() -> List[str]:
104+
def load_hashes_from_file() -> List[bytes]:
101105
file_path = os.path.join(os.path.dirname(__file__), "kvcache_block_hashes.txt")
102106
if not os.path.exists(file_path):
103107
return []
104108
with open(file_path, "r", encoding="utf-8") as f:
105-
return [line.strip() for line in f.readlines()]
109+
return [bytes.fromhex(line.strip()) for line in f.readlines()]
106110

107111

108112
def embed(
109-
store: UcmKVStoreBase,
110-
hashes: List[str],
113+
store: UcmKVStoreBaseV1,
114+
hashes: List[bytes],
111115
kvcaches: Dict[int, torch.Tensor],
112116
mla: bool,
113117
):
114118
start_time = time.perf_counter()
115119

116-
total_block_ids, total_offsets, total_tensors = [], [], []
120+
total_tensors = []
117121
total_size = 0
118122

119123
for i, hash_val in enumerate(hashes):
120-
offset = 0
124+
tensors = []
121125
for layer_id, kv_layer in kvcaches.items():
122-
k_tensor = kv_layer[0][i] # kv=1
123-
total_tensors.append(k_tensor)
124-
total_block_ids.append(hash_val)
125-
total_offsets.append(offset)
126+
k_tensor = kv_layer[0][i].contiguous()
127+
tensors.append(k_tensor)
126128
sz = k_tensor.numel() * k_tensor.element_size()
127-
offset += sz
128129
total_size += sz
129130

130131
if not mla:
131-
v_tensor = kv_layer[1][i]
132-
total_tensors.append(v_tensor)
133-
total_block_ids.append(hash_val)
134-
total_offsets.append(offset)
132+
v_tensor = kv_layer[1][i].contiguous()
133+
tensors.append(v_tensor)
135134
sz = v_tensor.numel() * v_tensor.element_size()
136-
offset += sz
137135
total_size += sz
138-
139-
task = store.dump(total_block_ids, total_offsets, total_tensors)
136+
total_tensors.append(tensors)
137+
task = store.dump(hashes, [], total_tensors)
140138
store.wait(task)
141139

142140
elapsed_time = time.perf_counter() - start_time
@@ -151,8 +149,8 @@ def embed(
151149

152150

153151
def fetch(
154-
store: UcmKVStoreBase,
155-
hashes: List[str],
152+
store: UcmKVStoreBaseV1,
153+
hashes: List[bytes],
156154
kvcaches: Dict[int, torch.Tensor],
157155
mla: bool,
158156
):
@@ -162,32 +160,33 @@ def fetch(
162160
for f in founds:
163161
assert f, "Cache block miss detected"
164162

165-
block_ids, offsets, tensors = [], [], []
163+
totoal_tensors = []
166164
total_size = 0
167165

168166
for i, hash_val in enumerate(hashes):
169-
offset = 0
167+
tensors = []
170168
for layer_id, kv_layer in kvcaches.items():
171-
k_tensor = kv_layer[0][i] # kv=1
172-
block_ids.append(hash_val)
173-
offsets.append(offset)
169+
k_tensor = kv_layer[0][i].contiguous()
174170
tensors.append(k_tensor)
175171
sz = k_tensor.numel() * k_tensor.element_size()
176-
offset += sz
177172
total_size += sz
178173

179174
if not mla:
180-
v_tensor = kv_layer[1][i]
181-
block_ids.append(hash_val)
182-
offsets.append(offset)
175+
v_tensor = kv_layer[1][i].contiguous()
183176
tensors.append(v_tensor)
184177
sz = v_tensor.numel() * v_tensor.element_size()
185-
offset += sz
186178
total_size += sz
179+
totoal_tensors.append(tensors)
187180

188-
task = store.load(block_ids, offsets, tensors)
189-
ret = store.wait(task)
190-
assert ret == 0, "Load operation failed"
181+
task = store.load(hashes, [], totoal_tensors)
182+
try:
183+
ret = store.wait(task)
184+
if ret is None:
185+
ret = 0
186+
except RuntimeError as e:
187+
print(f"Load operation failed with error: {e}")
188+
raise
189+
assert ret == 0, f"Load operation failed with return code: {ret}"
191190

192191
elapsed_time = time.perf_counter() - start_time
193192
throughput_gbps = (total_size / (1024**3)) / elapsed_time if elapsed_time > 0 else 0
@@ -226,6 +225,10 @@ def run(
226225
block_dim = head_size * num_head
227226
io_size = block_dim * block_len * block_elem_size
228227
block_size = io_size * block_layer
228+
229+
if not mla:
230+
block_size = block_size * 2
231+
229232
batch_size = int(num_tokens / block_len)
230233
real_blocks = batch_size + 10
231234

@@ -257,16 +260,13 @@ def run(
257260
kv,
258261
)
259262

260-
results = store.create(hashes[:batch_size])
261-
assert sum(results) == 0, "Create operation failed"
262-
263263
w_size, w_time, w_bw = embed(
264264
store,
265265
hashes[:batch_size],
266266
kvcaches,
267267
mla,
268268
)
269-
store.commit(hashes[:batch_size], True)
269+
time.sleep(1)
270270

271271
if r == 0:
272272
store_all_hashes(hashes[:batch_size])
@@ -349,10 +349,10 @@ def run(
349349
try:
350350
result = run(
351351
storage_backends=".",
352-
device_id=1,
353-
repeat=1,
352+
device_id=6,
353+
repeat=2,
354354
num_head=1,
355-
block_len=128,
355+
block_len=64,
356356
transferStreamNumber=32,
357357
num_tokens=4096,
358358
block_layer=61,

0 commit comments

Comments
 (0)