@@ -63,12 +63,12 @@ class RequestDispatchMeta:
6363
6464class KVCacheLayout :
6565 def __init__ (
66- self , kvcaches , use_layerwise : bool , vllm_config : "VllmConfig"
66+ self , kvcaches , launch_config : dict , vllm_config : "VllmConfig"
6767 ) -> None :
6868 # each row is a layer, each column is a tensor_size/ptr in the layer (e.g., k, v, rope, k_index)
6969 self .base_ptrs : np .ndarray # (n_layers, n_ptrs)
7070 self .tensor_size_lists : np .ndarray # (n_layers, n_tensor_sizes)
71- self .use_layerwise = use_layerwise
66+ self .use_layerwise = launch_config . get ( "use_layerwise" , False )
7272 self .vllm_config = vllm_config
7373 self .pp_size = self .vllm_config .parallel_config .pipeline_parallel_size
7474 self .num_hidden_layers = getattr (
@@ -246,7 +246,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
246246 self .enable_event_sync = self .launch_config .get ("enable_event_sync" , True )
247247 assert len (self .connector_configs ) > 0 , "no storage connector name in config."
248248
249- self .chunk_size = self .block_size
249+ self .chunk_size = self .launch_config .get ("chunk_size" , self .block_size )
250+ assert (
251+ self .chunk_size % self .block_size == 0
252+ ), "chunk_size must be divisible by block_size"
250253 self .blocks_per_chunk = self .chunk_size // self .block_size
251254
252255 if role == KVConnectorRole .SCHEDULER :
@@ -361,7 +364,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
361364 for i , tensor in enumerate (sample_kv_layer ):
362365 logger .info (f"kv cache shape { i } : { tensor .shape } " )
363366 self .kv_cache_layout = KVCacheLayout (
364- self .kv_caches , self .use_layerwise , self ._vllm_config
367+ self .kv_caches , self .launch_config , self ._vllm_config
365368 )
366369 self .block_data_size = self .kv_cache_layout .block_size
367370 self .layer_name_to_id = self .kv_cache_layout .layer_name_to_id
@@ -395,10 +398,10 @@ def get_num_new_matched_tokens(
395398 num_computed_tokens : int ,
396399 ) -> tuple [int , bool ]:
397400 assert num_computed_tokens % self .block_size == 0
398- hbm_hit_block_num = num_computed_tokens // self .block_size
401+ hbm_hit_block_num = num_computed_tokens // self .chunk_size
399402
400403 ucm_block_ids = self .generate_hash (
401- self .block_size , request .all_token_ids , self ._seed
404+ self .chunk_size , request .all_token_ids , self ._seed
402405 )
403406
404407 external_block_ids = ucm_block_ids [hbm_hit_block_num :]
@@ -422,12 +425,15 @@ def get_num_new_matched_tokens(
422425
423426 total_hit_block_num = hbm_hit_block_num + external_hit_blocks
424427
425- external_hit_tokens = external_hit_blocks * self .block_size
428+ external_hit_tokens = 0
429+ if external_hit_blocks > 0 :
430+ remainder = num_computed_tokens % self .chunk_size
431+ external_hit_tokens = external_hit_blocks * self .chunk_size - remainder
426432
427433 # When all the tokens are cached in ssd or hbm,
428434 # we need to recompute the last token. This if condition will be removed
429435 # once vLLM scheduler provides a better solution in the future.
430- num_total_hit_tokens = total_hit_block_num * self . block_size
436+ num_total_hit_tokens = external_hit_tokens + num_computed_tokens
431437 if num_total_hit_tokens == request .num_tokens :
432438 external_hit_tokens -= 1
433439
@@ -474,13 +480,19 @@ def _generate_dispatch_meta(
474480 dump_ucm_block_ids , dump_vllm_block_ids = [], []
475481 if need_load :
476482 load_ucm_block_ids = ucm_block_ids [hbm_hit_block_num :total_hit_block_num ]
477- load_vllm_block_ids = vllm_block_ids [hbm_hit_block_num :total_hit_block_num ]
483+ load_vllm_block_ids = vllm_block_ids [
484+ hbm_hit_block_num
485+ * self .blocks_per_chunk : total_hit_block_num
486+ * self .blocks_per_chunk
487+ ]
478488
479489 if req_meta .token_processed < req_meta .num_token_ids :
480- start_idx = req_meta .token_processed // self .block_size
481- end_idx = (req_meta .token_processed + new_tokens ) // self .block_size
490+ start_idx = req_meta .token_processed // self .chunk_size
491+ end_idx = (req_meta .token_processed + new_tokens ) // self .chunk_size
482492 dump_ucm_block_ids = ucm_block_ids [start_idx :end_idx ]
483- dump_vllm_block_ids = req_meta .vllm_block_ids [start_idx :end_idx ]
493+ dump_vllm_block_ids = req_meta .vllm_block_ids [
494+ start_idx * self .blocks_per_chunk : end_idx * self .blocks_per_chunk
495+ ]
484496 req_meta .token_processed += new_tokens
485497
486498 return RequestDispatchMeta (
@@ -569,7 +581,10 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
569581 for i , ucm_block_id in enumerate (ucm_block_ids ):
570582 ucm_block_ids [i ] = self .request_hasher (ucm_block_id )
571583 total_ptrs = self .kv_cache_layout .extract_block_addrs (vllm_block_ids )
572- total_ptrs = total_ptrs .reshape (total_ptrs .shape [0 ], - 1 )
584+ total_ptrs = total_ptrs .reshape (
585+ total_ptrs .shape [0 ] // self .blocks_per_chunk , - 1
586+ )
587+ assert total_ptrs .shape [0 ] == len (ucm_block_ids )
573588 shard_indexs = [0 ] * len (ucm_block_ids )
574589 try :
575590 task = self .store .load_data (ucm_block_ids , shard_indexs , total_ptrs )
@@ -662,7 +677,10 @@ def wait_for_save(self) -> None:
662677
663678 if is_save :
664679 total_ptrs = self .kv_cache_layout .extract_block_addrs (total_vllm_block_ids )
665- total_ptrs = total_ptrs .reshape (total_ptrs .shape [0 ], - 1 )
680+ total_ptrs = total_ptrs .reshape (
681+ total_ptrs .shape [0 ] // self .blocks_per_chunk , - 1
682+ )
683+ assert total_ptrs .shape [0 ] == len (total_ucm_block_ids )
666684 shard_indexs = [0 ] * len (total_ucm_block_ids )
667685 try :
668686 event_handle = self ._get_dump_event_handle ()
@@ -777,6 +795,14 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
777795 total_ptrs = self .kv_cache_layout .extract_block_addrs (
778796 vllm_block_ids , layer_first = True
779797 )
798+ # (n_layers, num_blocks, n_ptrs) -> (n_layers, num_blocks//bpc, bpc*n_ptrs)
799+ n_layers , n_blocks , n_ptrs = total_ptrs .shape
800+ total_ptrs = total_ptrs .reshape (
801+ n_layers ,
802+ n_blocks // self .blocks_per_chunk ,
803+ self .blocks_per_chunk * n_ptrs ,
804+ )
805+ assert total_ptrs .shape [1 ] == len (ucm_block_ids )
780806 self .request_data .append ((request_id , ucm_block_ids , total_ptrs ))
781807
782808 if self .need_load :
@@ -843,6 +869,14 @@ def save_kv_layer(
843869 self .dump_total_ptrs = self .kv_cache_layout .extract_block_addrs (
844870 total_vllm_block_ids , layer_first = True
845871 )
872+ # (n_layers, num_blocks, n_ptrs) -> (n_layers, num_blocks//bpc, bpc*n_ptrs)
873+ n_layers , n_blocks , n_ptrs = self .dump_total_ptrs .shape
874+ self .dump_total_ptrs = self .dump_total_ptrs .reshape (
875+ n_layers ,
876+ n_blocks // self .blocks_per_chunk ,
877+ self .blocks_per_chunk * n_ptrs ,
878+ )
879+ assert self .dump_total_ptrs .shape [1 ] == len (total_ucm_block_ids )
846880 shard_indexs = [layer_id ] * len (total_ucm_block_ids )
847881 try :
848882 layer_ptrs = np .ascontiguousarray (self .dump_total_ptrs [local_layer_id ])
0 commit comments