1414
1515from ucm .integration .vllm .ucm_connector import (
1616 RequestDispatchMeta ,
17- RequestHasher ,
18- RequestMeta ,
1917 UCMConnectorMetadata ,
2018 UCMDirectConnector ,
2119)
@@ -41,7 +39,7 @@ class ChunkMetaData:
4139 cached_start_position : int
4240
4341 vllm_blk_ids : List [int ] = field (default_factory = list )
44- chunk_blks_hash : List [str ] = field (default_factory = list )
42+ chunk_blks_hash : List [bytes ] = field (default_factory = list )
4543 store_hits : List [bool ] = field (default_factory = list )
4644
4745 @property
@@ -65,7 +63,7 @@ def hits_vllm_blk_ids(self) -> List[int]:
6563 return list (itertools .compress (self .vllm_blk_ids , self .store_hits ))
6664
6765 @property
68- def hits_chunk_blks_hash (self ) -> List [str ]:
66+ def hits_chunk_blks_hash (self ) -> List [bytes ]:
6967 return list (itertools .compress (self .chunk_blks_hash , self .store_hits ))
7068
7169 def merge_chunk (self , temp_chunk_meta : Self ) -> None :
@@ -102,7 +100,7 @@ def is_prefix_cache(self) -> bool:
102100
103101@dataclass
104102class BlendRequestMeta :
105- ucm_block_hashs : list [str ] = field (default_factory = list )
103+ ucm_block_hashs : list [bytes ] = field (default_factory = list )
106104 # hbm pc is not supported
107105 hbm_hit_block_num : int = 0
108106 # ucm pc is supported
@@ -138,35 +136,15 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
138136 else :
139137 raise "UCMBlendConnector init failed, please check your config"
140138
141- self .ucm_chunk_end_hash : int = self .request_hasher ("UCM_CHUNK_END_HASH" )
142- self .ucm_chunk_continue_hash : int = self .request_hasher (
143- "UCM_CHUNK_CONTINUE_HASH"
144- )
145139 self .requests_blend_meta : dict [str , BlendRequestMeta ] = {}
146140 self .cos_sin_cache : torch .Tensor = None
147141
148142 # if the chunk cache hits are fewer than min_blend_threshold, there is no need to cache blend
149143 self .min_blend_threshold = 16
150144
151- def _generate_hash (
152- self , block_size : int , token_ids : list [int ], parent_block_hash_value : int
153- ) -> list [str ]:
154- ret = []
155- for start in range (0 , len (token_ids ), block_size ):
156- end = start + block_size
157- block_token_ids = token_ids [start :end ]
158- # Do not hash the block if it is not full.
159- if len (block_token_ids ) < block_size :
160- break
161-
162- block_token_ids_tuple = tuple (block_token_ids )
163- hash_value = self .request_hasher (
164- (parent_block_hash_value , block_token_ids_tuple )
165- )
166- parent_block_hash_value = hash_value
167- ret .append (str (hash_value ))
168-
169- return ret
145+ # post process delta rope meta
146+ self .delta_rope_vllm_ids : torch .Tensor = None
147+ self .delta_rope_positions : torch .Tensor = None
170148
171149 def _process_req (self , all_token_ids : List [int ]):
172150 """
@@ -177,8 +155,8 @@ def _process_req(self, all_token_ids: List[int]):
177155 finally, if there are many chunk block hits, we do cache blend to improve TTFT
178156 """
179157 chunks_meta = []
180- prefix_block_hashes = self ._generate_hash (
181- self .block_size , all_token_ids , RequestHasher . _SEED_HASH
158+ prefix_block_hashes = self .generate_hash (
159+ self .block_size , all_token_ids , self . _seed
182160 )
183161 if (
184162 all_token_ids [- 1 ] == self .chunk_end_token_id
@@ -203,8 +181,8 @@ def _process_req(self, all_token_ids: List[int]):
203181 # but this will bring lots of modification to engine.
204182 if all_token_ids [end_token_idx ] == self .chunk_end_token_id :
205183 chunk_token_ids = all_token_ids [start_token_dix : end_token_idx + 1 ]
206- chunk_blks_hash = self ._generate_hash (
207- self .block_size , chunk_token_ids , RequestHasher . _SEED_HASH
184+ chunk_blks_hash = self .generate_hash (
185+ self .block_size , chunk_token_ids , self . _seed
208186 )
209187
210188 chunk_blks_len = end_blk_idx - start_blk_idx + 1
@@ -371,7 +349,7 @@ def get_num_new_matched_tokens(
371349 ) -> tuple [int , bool ]:
372350
373351 # HBM prefix cache is currently not supported, because the blended caches have a ground view of all chunks
374- # so they can not apply to other req
352+ # so they can not be applied to other req
375353 assert num_computed_tokens == 0
376354 all_token_ids = request .all_token_ids
377355
@@ -488,21 +466,34 @@ def build_connector_meta(
488466
489467 return UCMBlendConnectorMetadata (requests_dispatch_meta )
490468
491- def wait_for_layer_load (self , layer_name : str ) -> None :
492- metadata = self ._get_connector_metadata ()
493- assert isinstance (metadata , UCMBlendConnectorMetadata )
494-
469+ def bind_connector_metadata (self , connector_metadata : KVConnectorMetadata ) -> None :
470+ """
471+ Blend needs to build post-process meta for the loaded kv cache
472+ """
473+ super ().bind_connector_metadata (connector_metadata )
495474 all_hits_vllm_ids = []
496475 positions = []
497- k_cache = self .kv_caches [layer_name ][0 ]
498- for request_id , request in metadata .request_meta .items ():
476+ for request_id , request in connector_metadata .request_meta .items ():
499477 for chunk_meta in request .chunks_meta :
500478 all_hits_vllm_ids .extend (chunk_meta .hits_vllm_blk_ids )
501479 positions .extend (
502480 [chunk_meta .position_offset ] * len (chunk_meta .hits_vllm_blk_ids )
503481 )
504482 if all_hits_vllm_ids :
505- vllm_ids = torch .tensor (all_hits_vllm_ids , device = k_cache .device )
506- positions = torch .tensor (positions , device = k_cache .device )
507- self ._post_process_chunk_cache (k_cache , vllm_ids , positions )
508- pass
483+ self .delta_rope_vllm_ids = torch .tensor (
484+ all_hits_vllm_ids , device = self .device
485+ )
486+ self .delta_rope_positions = torch .tensor (positions , device = self .device )
487+
488+ def clear_connector_metadata (self ) -> None :
489+ """Clear the post process meta"""
490+ super ().clear_connector_metadata ()
491+ self .delta_rope_vllm_ids = None
492+ self .delta_rope_positions = None
493+
494+ def wait_for_layer_load (self , layer_name : str ) -> None :
495+ if self .delta_rope_vllm_ids is not None :
496+ k_cache = self .kv_caches [layer_name ][0 ]
497+ self ._post_process_chunk_cache (
498+ k_cache , self .delta_rope_vllm_ids , self .delta_rope_positions
499+ )
0 commit comments