Skip to content

Commit 597af16

Browse files
committed
[Feat] Add chunk size
1 parent 457be0c commit 597af16

1 file changed

Lines changed: 48 additions & 14 deletions

File tree

ucm/integration/vllm/ucm_connector.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,12 @@ class RequestDispatchMeta:
6363

6464
class KVCacheLayout:
6565
def __init__(
66-
self, kvcaches, use_layerwise: bool, vllm_config: "VllmConfig"
66+
self, kvcaches, launch_config: dict, vllm_config: "VllmConfig"
6767
) -> None:
6868
# each row is a layer, each column is a tensor_size/ptr in the layer (e.g., k, v, rope, k_index)
6969
self.base_ptrs: np.ndarray # (n_layers, n_ptrs)
7070
self.tensor_size_lists: np.ndarray # (n_layers, n_tensor_sizes)
71-
self.use_layerwise = use_layerwise
71+
self.use_layerwise = launch_config.get("use_layerwise", False)
7272
self.vllm_config = vllm_config
7373
self.pp_size = self.vllm_config.parallel_config.pipeline_parallel_size
7474
self.num_hidden_layers = getattr(
@@ -246,7 +246,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
246246
self.enable_event_sync = self.launch_config.get("enable_event_sync", True)
247247
assert len(self.connector_configs) > 0, "no storage connector name in config."
248248

249-
self.chunk_size = self.block_size
249+
self.chunk_size = self.launch_config.get("chunk_size", self.block_size)
250+
assert (
251+
self.chunk_size % self.block_size == 0
252+
), "chunk_size must be divisible by block_size"
250253
self.blocks_per_chunk = self.chunk_size // self.block_size
251254

252255
if role == KVConnectorRole.SCHEDULER:
@@ -361,7 +364,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
361364
for i, tensor in enumerate(sample_kv_layer):
362365
logger.info(f"kv cache shape {i}: {tensor.shape}")
363366
self.kv_cache_layout = KVCacheLayout(
364-
self.kv_caches, self.use_layerwise, self._vllm_config
367+
self.kv_caches, self.launch_config, self._vllm_config
365368
)
366369
self.block_data_size = self.kv_cache_layout.block_size
367370
self.layer_name_to_id = self.kv_cache_layout.layer_name_to_id
@@ -395,10 +398,10 @@ def get_num_new_matched_tokens(
395398
num_computed_tokens: int,
396399
) -> tuple[int, bool]:
397400
assert num_computed_tokens % self.block_size == 0
398-
hbm_hit_block_num = num_computed_tokens // self.block_size
401+
hbm_hit_block_num = num_computed_tokens // self.chunk_size
399402

400403
ucm_block_ids = self.generate_hash(
401-
self.block_size, request.all_token_ids, self._seed
404+
self.chunk_size, request.all_token_ids, self._seed
402405
)
403406

404407
external_block_ids = ucm_block_ids[hbm_hit_block_num:]
@@ -422,12 +425,15 @@ def get_num_new_matched_tokens(
422425

423426
total_hit_block_num = hbm_hit_block_num + external_hit_blocks
424427

425-
external_hit_tokens = external_hit_blocks * self.block_size
428+
external_hit_tokens = 0
429+
if external_hit_blocks > 0:
430+
remainder = num_computed_tokens % self.chunk_size
431+
external_hit_tokens = external_hit_blocks * self.chunk_size - remainder
426432

427433
# When all the tokens are cached in ssd or hbm,
428434
# we need to recompute the last token. This if condition will be removed
429435
# once vLLM scheduler provides a better solution in the future.
430-
num_total_hit_tokens = total_hit_block_num * self.block_size
436+
num_total_hit_tokens = external_hit_tokens + num_computed_tokens
431437
if num_total_hit_tokens == request.num_tokens:
432438
external_hit_tokens -= 1
433439

@@ -474,13 +480,19 @@ def _generate_dispatch_meta(
474480
dump_ucm_block_ids, dump_vllm_block_ids = [], []
475481
if need_load:
476482
load_ucm_block_ids = ucm_block_ids[hbm_hit_block_num:total_hit_block_num]
477-
load_vllm_block_ids = vllm_block_ids[hbm_hit_block_num:total_hit_block_num]
483+
load_vllm_block_ids = vllm_block_ids[
484+
hbm_hit_block_num
485+
* self.blocks_per_chunk : total_hit_block_num
486+
* self.blocks_per_chunk
487+
]
478488

479489
if req_meta.token_processed < req_meta.num_token_ids:
480-
start_idx = req_meta.token_processed // self.block_size
481-
end_idx = (req_meta.token_processed + new_tokens) // self.block_size
490+
start_idx = req_meta.token_processed // self.chunk_size
491+
end_idx = (req_meta.token_processed + new_tokens) // self.chunk_size
482492
dump_ucm_block_ids = ucm_block_ids[start_idx:end_idx]
483-
dump_vllm_block_ids = req_meta.vllm_block_ids[start_idx:end_idx]
493+
dump_vllm_block_ids = req_meta.vllm_block_ids[
494+
start_idx * self.blocks_per_chunk : end_idx * self.blocks_per_chunk
495+
]
484496
req_meta.token_processed += new_tokens
485497

486498
return RequestDispatchMeta(
@@ -569,7 +581,10 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
569581
for i, ucm_block_id in enumerate(ucm_block_ids):
570582
ucm_block_ids[i] = self.request_hasher(ucm_block_id)
571583
total_ptrs = self.kv_cache_layout.extract_block_addrs(vllm_block_ids)
572-
total_ptrs = total_ptrs.reshape(total_ptrs.shape[0], -1)
584+
total_ptrs = total_ptrs.reshape(
585+
total_ptrs.shape[0] // self.blocks_per_chunk, -1
586+
)
587+
assert total_ptrs.shape[0] == len(ucm_block_ids)
573588
shard_indexs = [0] * len(ucm_block_ids)
574589
try:
575590
task = self.store.load_data(ucm_block_ids, shard_indexs, total_ptrs)
@@ -662,7 +677,10 @@ def wait_for_save(self) -> None:
662677

663678
if is_save:
664679
total_ptrs = self.kv_cache_layout.extract_block_addrs(total_vllm_block_ids)
665-
total_ptrs = total_ptrs.reshape(total_ptrs.shape[0], -1)
680+
total_ptrs = total_ptrs.reshape(
681+
total_ptrs.shape[0] // self.blocks_per_chunk, -1
682+
)
683+
assert total_ptrs.shape[0] == len(total_ucm_block_ids)
666684
shard_indexs = [0] * len(total_ucm_block_ids)
667685
try:
668686
event_handle = self._get_dump_event_handle()
@@ -777,6 +795,14 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
777795
total_ptrs = self.kv_cache_layout.extract_block_addrs(
778796
vllm_block_ids, layer_first=True
779797
)
798+
# (n_layers, num_blocks, n_ptrs) -> (n_layers, num_blocks//bpc, bpc*n_ptrs)
799+
n_layers, n_blocks, n_ptrs = total_ptrs.shape
800+
total_ptrs = total_ptrs.reshape(
801+
n_layers,
802+
n_blocks // self.blocks_per_chunk,
803+
self.blocks_per_chunk * n_ptrs,
804+
)
805+
assert total_ptrs.shape[1] == len(ucm_block_ids)
780806
self.request_data.append((request_id, ucm_block_ids, total_ptrs))
781807

782808
if self.need_load:
@@ -843,6 +869,14 @@ def save_kv_layer(
843869
self.dump_total_ptrs = self.kv_cache_layout.extract_block_addrs(
844870
total_vllm_block_ids, layer_first=True
845871
)
872+
# (n_layers, num_blocks, n_ptrs) -> (n_layers, num_blocks//bpc, bpc*n_ptrs)
873+
n_layers, n_blocks, n_ptrs = self.dump_total_ptrs.shape
874+
self.dump_total_ptrs = self.dump_total_ptrs.reshape(
875+
n_layers,
876+
n_blocks // self.blocks_per_chunk,
877+
self.blocks_per_chunk * n_ptrs,
878+
)
879+
assert self.dump_total_ptrs.shape[1] == len(total_ucm_block_ids)
846880
shard_indexs = [layer_id] * len(total_ucm_block_ids)
847881
try:
848882
layer_ptrs = np.ascontiguousarray(self.dump_total_ptrs[local_layer_id])

0 commit comments

Comments
 (0)