@@ -1250,72 +1250,125 @@ def get_num_new_matched_tokens(
12501250 return expect_hit_block_num * self .block_size , False
12511251
12521252
1253- class UCMLiteConnector (UCMDirectConnector ):
1253+ class UCMLiteConnector (KVConnectorBase_V1 ):
12541254 def __init__ (
12551255 self ,
12561256 vllm_config ,
12571257 role ,
12581258 kv_cache_config : Optional ["KVCacheConfig" ] = None ,
12591259 ):
1260- ucm_config = Config (vllm_config .kv_transfer_config )
1261- launch_config = ucm_config .get_config ()
1262- enable_record_traces = launch_config .get ("enable_record_traces" , False )
1263- persist_token_threshold = launch_config .get ("persist_token_threshold" , 0 )
1264- vllm_config .kv_transfer_config .kv_connector_extra_config = {
1265- "ucm_connectors" : [
1266- {
1267- "ucm_connector_name" : "UcmPipelineStore" ,
1268- "ucm_connector_config" : {
1269- "store_pipeline" : "Fake" ,
1270- "share_buffer_enable" : True ,
1271- "buffer_number" : 244032232 ,
1272- },
1273- }
1274- ],
1275- "enable_record_traces" : enable_record_traces ,
1276- "persist_token_threshold" : persist_token_threshold ,
1277- "use_lite" : True ,
1278- }
1279- super ().__init__ (vllm_config , role , kv_cache_config )
1260+ self .block_size = vllm_config .cache_config .block_size
1261+ self .hash_block_size = self .block_size
1262+ self .requests_meta : dict [str , RequestMeta ] = {}
12801263 self .total_block_nums = 0
1281- self .total_hit_block_nums = 0
1264+
1265+ self .request_hasher = RequestHasher (vllm_config , 0 )
1266+ self ._seed = self .request_hasher ("UCM_HASH_SEED" )
1267+
1268+ super ().__init__ (vllm_config , role , kv_cache_config )
1269+
12821270 logger .info ("Init UCMLiteConnector." )
12831271
12841272 def get_num_new_matched_tokens (self , request , num_computed_tokens ):
1285- super ().get_num_new_matched_tokens (request , num_computed_tokens )
1286-
1287- external_hit_blocks = 0
12881273 req_blocks_num = len (request .all_token_ids ) // self .hash_block_size
12891274 if req_blocks_num < 1 :
12901275 return 0 , False
1291- self .total_block_nums += req_blocks_num
1292- if request .request_id in self .requests_meta :
1293- request_meta = self .requests_meta [request .request_id ]
1294- external_hit_blocks = (
1295- request_meta .total_hit_block_num - request_meta .hbm_hit_block_num
1276+ if request .request_id not in self .requests_meta :
1277+ hash_start = time .perf_counter ()
1278+ ucm_block_ids = self .generate_hash (
1279+ self .hash_block_size , request .all_token_ids , self ._seed
12961280 )
1297- need_dump_blks = request_meta .ucm_block_ids [
1298- request_meta .total_hit_block_num :
1299- ]
1300- shard_indexs = [0 ] * len (need_dump_blks )
1301- total_ptrs = [[0 ]] * len (need_dump_blks )
1302- try :
1303- task = self .store .dump_data (need_dump_blks , shard_indexs , total_ptrs )
1304- self .store .wait (task )
1305- except Exception as e :
1306- logger .error (
1307- f"request { request .request_id } wait dump task error. { type (e ).__name__ } : { e } "
1308- )
1309- self .requests_meta [request .request_id ] = RequestMeta ()
1281+ hash_end = time .perf_counter ()
1282+ hash_time_ms = (hash_end - hash_start ) * 1000.0
13101283
1311- self .total_hit_block_nums += external_hit_blocks
1284+ # prepare hex ids and log the canonical info line
1285+ print_start = time .perf_counter ()
1286+ hex_ucm_block_ids = [b .hex () for b in ucm_block_ids ]
1287+ print_time_ms = (time .perf_counter () - print_start ) * 1000.0
1288+ logger .info (
1289+ f"timestamp: { time .perf_counter ()} , "
1290+ f"request_id: { request .request_id } , "
1291+ f"input_length: { request .num_tokens } , "
1292+ f"output_length: { request .max_tokens } , "
1293+ f"hash_time_ms: { hash_time_ms :.3f} , "
1294+ f"print_time_ms: { print_time_ms :.3f} , "
1295+ f"ucm_block_ids: { hex_ucm_block_ids } "
1296+ )
1297+
1298+ # store minimal RequestMeta for scheduler bookkeeping
1299+ self .requests_meta [request .request_id ] = RequestMeta (
1300+ ucm_block_ids = ucm_block_ids ,
1301+ hbm_hit_block_num = 0 ,
1302+ total_hit_block_num = 0 ,
1303+ num_token_ids = len (request .all_token_ids ),
1304+ token_processed = 0 ,
1305+ )
1306+
1307+ self .total_block_nums += req_blocks_num
13121308
1313- logger .info (
1314- f"req external hit rate: { (external_hit_blocks / req_blocks_num ):.2f} , "
1315- f"total external hit rate: { (self .total_hit_block_nums / self .total_block_nums ):.2f} "
1316- )
13171309 return 0 , False
13181310
1311+ def update_state_after_alloc (
1312+ self , request : "Request" , blocks : "KVCacheBlocks" , num_external_tokens : int
1313+ ):
1314+ pass
1315+
1316+ def build_connector_meta (
1317+ self , scheduler_output : SchedulerOutput
1318+ ) -> KVConnectorMetadata :
1319+ for request_id in scheduler_output .finished_req_ids :
1320+ self .requests_meta .pop (request_id , None )
1321+ return UCMConnectorMetadata ()
1322+
1323+ def request_finished (
1324+ self ,
1325+ request : "Request" ,
1326+ block_ids : list [int ],
1327+ ) -> tuple [bool , dict [str , Any ] | None ]:
1328+ self .requests_meta .pop (request .request_id , None )
1329+ return False , None
1330+
1331+ def start_load_kv (self , forward_context : "ForwardContext" , ** kwargs : Any ) -> None :
1332+ pass
1333+
1334+ def wait_for_layer_load (self , layer_name : str ) -> None :
1335+ pass
1336+
1337+ def save_kv_layer (
1338+ self ,
1339+ layer_name : str ,
1340+ kv_layer : torch .Tensor ,
1341+ attn_metadata : "AttentionMetadata" ,
1342+ ** kwargs : Any ,
1343+ ) -> None :
1344+ pass
1345+
1346+ def wait_for_save (self ):
1347+ pass
1348+
1349+ def generate_hash (
1350+ self ,
1351+ block_size : int ,
1352+ token_ids : List [int ],
1353+ parent_block_hash_value : bytes ,
1354+ ) -> list [bytes ]:
1355+ ret = []
1356+ for start in range (0 , len (token_ids ), block_size ):
1357+ end = start + block_size
1358+ block_token_ids = token_ids [start :end ]
1359+ # Do not hash the block if it is not full.
1360+ if len (block_token_ids ) < block_size :
1361+ break
1362+
1363+ block_token_ids_tuple = tuple (block_token_ids )
1364+ hash_value = self .request_hasher (
1365+ (parent_block_hash_value , block_token_ids_tuple )
1366+ )
1367+ parent_block_hash_value = hash_value
1368+ ret .append (hash_value )
1369+
1370+ return ret
1371+
13191372
13201373def layer_name_to_kv_cache_spec (
13211374 kv_cache_config : KVCacheConfig ,
0 commit comments