Skip to content

Commit 2c25222

Browse files
committed
Update the test tool code to match the latest connector API.
1 parent aca17e3 commit 2c25222

2 files changed

Lines changed: 63 additions & 39 deletions

File tree

test/test_ucm_connector_save_load.py

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@
5757
UCMConnectorMetadata,
5858
)
5959
from ucm.logger import init_logger
60+
from ucm.store.factory_v1 import UcmConnectorFactoryV1
61+
from ucm.store.ucmstore_v1 import UcmKVStoreBaseV1
6062

6163
logger = init_logger(__name__)
6264

@@ -91,7 +93,7 @@ def make_buffers(
9193
is_mla: bool,
9294
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
9395
logger.info(f"Allocating buffers: blocks={block_number}, batch_size={batch_size}")
94-
hashes = [secrets.token_hex(16) for _ in range(block_number)]
96+
hashes = [secrets.token_bytes(16) for _ in range(block_number)]
9597
device = f"cuda:{device_id}"
9698
kv_caches: Dict[str, torch.Tensor] = {}
9799

@@ -123,8 +125,8 @@ def build_vllm_config(
123125
tp_size: int,
124126
connector_name: str,
125127
storage_backends: str,
126-
transfer_stream_number: int,
127-
use_direct: bool,
128+
stream_number: int,
129+
io_direct: bool,
128130
) -> VllmConfig:
129131
cache_config = CacheConfig(
130132
block_size=block_size,
@@ -189,8 +191,8 @@ def build_vllm_config(
189191
"ucm_connector_name": connector_name,
190192
"ucm_connector_config": {
191193
"storage_backends": storage_backends,
192-
"use_direct": use_direct,
193-
"stream_number": transfer_stream_number,
194+
"io_direct": io_direct,
195+
"stream_number": stream_number,
194196
"local_rank_size": 1,
195197
},
196198
}
@@ -241,6 +243,7 @@ def compute_total_bytes(
241243

242244
def run_once(
243245
connector: UCMConnector,
246+
scheduler: UcmKVStoreBaseV1,
244247
kv_caches: Dict[str, torch.Tensor],
245248
hashes: List[str],
246249
batch_size: int,
@@ -255,12 +258,8 @@ def run_once(
255258
dump_block_ids=(dump_hashes, dump_vllm_block_ids),
256259
)
257260

258-
if (
259-
not hasattr(connector.connector, "k_store")
260-
or connector.connector.k_store is None
261-
):
261+
if not hasattr(connector.connector, "store") or connector.connector.store is None:
262262
connector.connector.register_kv_caches(kv_caches)
263-
264263
connector.bind_connector_metadata(metadata)
265264

266265
total_bytes = compute_total_bytes(kv_caches, batch_size, is_mla)
@@ -273,7 +272,7 @@ def run_once(
273272

274273
write_bw = (total_bytes / (1024**3)) / write_time if write_time > 0 else 0.0
275274

276-
lookup = connector.connector.k_store.lookup(dump_hashes)
275+
lookup = scheduler.lookup(dump_hashes)
277276
if not all(lookup):
278277
raise RuntimeError("Found missing cache blocks before load test.")
279278

@@ -322,8 +321,8 @@ def run_test(
322321
ucm_connector_name: str,
323322
total_tp_size: int,
324323
model_path: str,
325-
transfer_stream_number: int,
326-
use_direct: bool,
324+
stream_number: int,
325+
io_direct: bool,
327326
) -> Tuple[float, float, float, float, float, float]:
328327
block_dim = head_size * num_head
329328
io_size = block_dim * block_len * block_elem_size
@@ -341,8 +340,8 @@ def run_test(
341340
tp_size=total_tp_size,
342341
connector_name=ucm_connector_name,
343342
storage_backends=storage_backends,
344-
transfer_stream_number=transfer_stream_number,
345-
use_direct=use_direct,
343+
stream_number=stream_number,
344+
io_direct=io_direct,
346345
)
347346

348347
dummy_world_group = type("DummyWorldGroup", (), {"local_rank": 0})()
@@ -383,6 +382,19 @@ def broadcast(self, tensor, src):
383382

384383
connector.connector.register_kv_caches(kv_caches)
385384

385+
scheduler_config = {
386+
"storage_backends": storage_backends,
387+
"block_size": block_size,
388+
"device_id": -1, # device_id=-1 means transferEnable=false
389+
"tensor_size": io_size,
390+
"stream_number": stream_number,
391+
"io_direct": io_direct,
392+
"unique_id": secrets.token_hex(8),
393+
}
394+
scheduler = UcmConnectorFactoryV1.create_connector(
395+
ucm_connector_name, scheduler_config
396+
)
397+
386398
w_sizes, w_times, w_bws = [], [], []
387399
r_sizes, r_times, r_bws = [], [], []
388400

@@ -393,10 +405,10 @@ def broadcast(self, tensor, src):
393405
round_hashes = hashes[start_hash_idx:end_hash_idx]
394406

395407
if len(round_hashes) < batch_size:
396-
round_hashes = [secrets.token_hex(16) for _ in range(batch_size)]
408+
round_hashes = [secrets.token_bytes(16) for _ in range(batch_size)]
397409

398410
(w_size, w_time, w_bw), (r_size, r_time, r_bw) = run_once(
399-
connector, kv_caches, round_hashes, batch_size, mla
411+
connector, scheduler, kv_caches, round_hashes, batch_size, mla
400412
)
401413

402414
if round_idx != 0:
@@ -459,7 +471,7 @@ def main():
459471
num_tokens_list = [2048, 4096, 8192, 16384, 32768]
460472
ucm_connector_name = "UcmNfsStore"
461473
model_path = "/home/models/QwQ-32B"
462-
transfer_stream_numbers = [32, 64, 128]
474+
stream_numbers = [32, 64, 128]
463475
os.environ["UC_LOGGER_LEVEL"] = "debug"
464476

465477
print("1. Model Selection:")
@@ -470,8 +482,8 @@ def main():
470482
print("\n2. IoDirect Transfer:")
471483
print(" 1 - Disable IoDirect (default)")
472484
print(" 2 - Enable IoDirect")
473-
use_direct = get_user_input("Please select Direct IO mode", "1")
474-
use_direct = False if use_direct == "1" else True
485+
io_direct = get_user_input("Please select Direct IO mode", "1")
486+
io_direct = False if io_direct == "1" else True
475487

476488
if mla:
477489
block_lens = [64]
@@ -523,7 +535,7 @@ def main():
523535

524536
for num_head in num_head_list:
525537
for block_len in block_lens:
526-
for transfer_stream_number in transfer_stream_numbers:
538+
for stream_number in stream_numbers:
527539
block_dim = head_size * num_head
528540
io_size = block_dim * block_len * block_elem_size
529541

@@ -556,8 +568,8 @@ def main():
556568
ucm_connector_name,
557569
total_tp_size,
558570
model_path,
559-
transfer_stream_number,
560-
use_direct,
571+
stream_number,
572+
io_direct,
561573
),
562574
)
563575

@@ -587,7 +599,7 @@ def main():
587599
kv,
588600
num_head,
589601
block_len,
590-
transfer_stream_number,
602+
stream_number,
591603
io_count,
592604
io_size,
593605
f"{avg_w_size:.4f}",

ucm/store/test/e2e/nfsstore_embed_fetch.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,16 @@ def setup(
4242
block_size,
4343
device_id,
4444
io_size,
45-
transferStreamNumber,
46-
transferIoDirect,
45+
stream_number,
46+
io_direct,
4747
) -> UcmKVStoreBaseV1:
4848
config = {
4949
"storage_backends": storage_backends,
50-
"kv_block_size": block_size,
51-
"role": "worker",
52-
"device": device_id,
53-
"io_size": io_size,
54-
"transferStreamNumber": transferStreamNumber,
55-
"transferIoDirect": transferIoDirect,
50+
"block_size": block_size,
51+
"device_id": device_id,
52+
"tensor_size": io_size,
53+
"stream_number": stream_number,
54+
"io_direct": io_direct,
5655
"unique_id": secrets.token_hex(8),
5756
}
5857
return UcmPcStoreV1(config)
@@ -150,13 +149,14 @@ def embed(
150149

151150
def fetch(
152151
store: UcmKVStoreBaseV1,
152+
scheduler: UcmKVStoreBaseV1,
153153
hashes: List[bytes],
154154
kvcaches: Dict[int, torch.Tensor],
155155
mla: bool,
156156
):
157157
start_time = time.perf_counter()
158158

159-
founds = store.lookup(hashes)
159+
founds = scheduler.lookup(hashes)
160160
for f in founds:
161161
assert f, "Cache block miss detected"
162162

@@ -179,6 +179,7 @@ def fetch(
179179
totoal_tensors.append(tensors)
180180

181181
task = store.load(hashes, [], totoal_tensors)
182+
182183
try:
183184
ret = store.wait(task)
184185
if ret is None:
@@ -205,14 +206,14 @@ def run(
205206
repeat: int,
206207
num_head: int,
207208
block_len: int,
208-
transferStreamNumber: int,
209+
stream_number: int,
209210
num_tokens: int,
210211
block_layer: int,
211212
head_size: int,
212213
block_elem_size: int,
213214
kv: int,
214215
mla: bool,
215-
transferIoDirect: bool,
216+
io_direct: bool,
216217
operation_mode: str = "both", # "write_only", "read_only", or "both"
217218
) -> Tuple[float, float, float, float, float, float]:
218219
"""
@@ -241,8 +242,17 @@ def run(
241242
block_size,
242243
device_id,
243244
io_size,
244-
transferStreamNumber,
245-
transferIoDirect,
245+
stream_number,
246+
io_direct,
247+
)
248+
249+
scheduler = setup(
250+
storage_backends,
251+
block_size,
252+
-1, # device_id=-1 means transferEnable=false
253+
io_size,
254+
stream_number,
255+
io_direct,
246256
)
247257

248258
for r in range(repeat):
@@ -302,13 +312,15 @@ def run(
302312

303313
r_size, r_time, r_bw = fetch(
304314
store,
315+
scheduler,
305316
saved_hashes[:batch_size],
306317
kvcaches,
307318
mla,
308319
)
309320
else:
310321
r_size, r_time, r_bw = fetch(
311322
store,
323+
scheduler,
312324
hashes[:batch_size],
313325
kvcaches,
314326
mla,
@@ -353,14 +365,14 @@ def run(
353365
repeat=2,
354366
num_head=1,
355367
block_len=64,
356-
transferStreamNumber=32,
368+
stream_number=32,
357369
num_tokens=4096,
358370
block_layer=61,
359371
head_size=576,
360372
block_elem_size=2,
361373
kv=1,
362374
mla=True,
363-
transferIoDirect=False,
375+
io_direct=False,
364376
operation_mode="both",
365377
)
366378

0 commit comments

Comments (0)