Skip to content

Commit d418d54

Browse files
committed
Refactor UCMLite connector hash logging
1 parent c8f085b commit d418d54

1 file changed

Lines changed: 101 additions & 48 deletions

File tree

ucm/integration/vllm/ucm_connector.py

Lines changed: 101 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -1250,72 +1250,125 @@ def get_num_new_matched_tokens(
12501250
return expect_hit_block_num * self.block_size, False
12511251

12521252

1253-
class UCMLiteConnector(UCMDirectConnector):
1253+
class UCMLiteConnector(KVConnectorBase_V1):
12541254
def __init__(
12551255
self,
12561256
vllm_config,
12571257
role,
12581258
kv_cache_config: Optional["KVCacheConfig"] = None,
12591259
):
1260-
ucm_config = Config(vllm_config.kv_transfer_config)
1261-
launch_config = ucm_config.get_config()
1262-
enable_record_traces = launch_config.get("enable_record_traces", False)
1263-
persist_token_threshold = launch_config.get("persist_token_threshold", 0)
1264-
vllm_config.kv_transfer_config.kv_connector_extra_config = {
1265-
"ucm_connectors": [
1266-
{
1267-
"ucm_connector_name": "UcmPipelineStore",
1268-
"ucm_connector_config": {
1269-
"store_pipeline": "Fake",
1270-
"share_buffer_enable": True,
1271-
"buffer_number": 244032232,
1272-
},
1273-
}
1274-
],
1275-
"enable_record_traces": enable_record_traces,
1276-
"persist_token_threshold": persist_token_threshold,
1277-
"use_lite": True,
1278-
}
1279-
super().__init__(vllm_config, role, kv_cache_config)
1260+
self.block_size = vllm_config.cache_config.block_size
1261+
self.hash_block_size = self.block_size
1262+
self.requests_meta: dict[str, RequestMeta] = {}
12801263
self.total_block_nums = 0
1281-
self.total_hit_block_nums = 0
1264+
1265+
self.request_hasher = RequestHasher(vllm_config, 0)
1266+
self._seed = self.request_hasher("UCM_HASH_SEED")
1267+
1268+
super().__init__(vllm_config, role, kv_cache_config)
1269+
12821270
logger.info("Init UCMLiteConnector.")
12831271

12841272
def get_num_new_matched_tokens(self, request, num_computed_tokens):
1285-
super().get_num_new_matched_tokens(request, num_computed_tokens)
1286-
1287-
external_hit_blocks = 0
12881273
req_blocks_num = len(request.all_token_ids) // self.hash_block_size
12891274
if req_blocks_num < 1:
12901275
return 0, False
1291-
self.total_block_nums += req_blocks_num
1292-
if request.request_id in self.requests_meta:
1293-
request_meta = self.requests_meta[request.request_id]
1294-
external_hit_blocks = (
1295-
request_meta.total_hit_block_num - request_meta.hbm_hit_block_num
1276+
if request.request_id not in self.requests_meta:
1277+
hash_start = time.perf_counter()
1278+
ucm_block_ids = self.generate_hash(
1279+
self.hash_block_size, request.all_token_ids, self._seed
12961280
)
1297-
need_dump_blks = request_meta.ucm_block_ids[
1298-
request_meta.total_hit_block_num :
1299-
]
1300-
shard_indexs = [0] * len(need_dump_blks)
1301-
total_ptrs = [[0]] * len(need_dump_blks)
1302-
try:
1303-
task = self.store.dump_data(need_dump_blks, shard_indexs, total_ptrs)
1304-
self.store.wait(task)
1305-
except Exception as e:
1306-
logger.error(
1307-
f"request {request.request_id} wait dump task error. {type(e).__name__}: {e}"
1308-
)
1309-
self.requests_meta[request.request_id] = RequestMeta()
1281+
hash_end = time.perf_counter()
1282+
hash_time_ms = (hash_end - hash_start) * 1000.0
13101283

1311-
self.total_hit_block_nums += external_hit_blocks
1284+
# prepare hex ids and log the canonical info line
1285+
print_start = time.perf_counter()
1286+
hex_ucm_block_ids = [b.hex() for b in ucm_block_ids]
1287+
print_time_ms = (time.perf_counter() - print_start) * 1000.0
1288+
logger.info(
1289+
f"timestamp: {time.perf_counter()}, "
1290+
f"request_id: {request.request_id}, "
1291+
f"input_length: {request.num_tokens}, "
1292+
f"output_length: {request.max_tokens}, "
1293+
f"hash_time_ms: {hash_time_ms:.3f}, "
1294+
f"print_time_ms: {print_time_ms:.3f}, "
1295+
f"ucm_block_ids: {hex_ucm_block_ids}"
1296+
)
1297+
1298+
# store minimal RequestMeta for scheduler bookkeeping
1299+
self.requests_meta[request.request_id] = RequestMeta(
1300+
ucm_block_ids=ucm_block_ids,
1301+
hbm_hit_block_num=0,
1302+
total_hit_block_num=0,
1303+
num_token_ids=len(request.all_token_ids),
1304+
token_processed=0,
1305+
)
1306+
1307+
self.total_block_nums += req_blocks_num
13121308

1313-
logger.info(
1314-
f"req external hit rate: {(external_hit_blocks / req_blocks_num):.2f}, "
1315-
f"total external hit rate: {(self.total_hit_block_nums / self.total_block_nums):.2f}"
1316-
)
13171309
return 0, False
13181310

1311+
def update_state_after_alloc(
    self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
    """No-op hook: the arguments are accepted and ignored.

    Args:
        request: Unused.
        blocks: Unused.
        num_external_tokens: Unused.

    Returns:
        None.
    """
    return None
1315+
1316+
def build_connector_meta(
1317+
self, scheduler_output: SchedulerOutput
1318+
) -> KVConnectorMetadata:
1319+
for request_id in scheduler_output.finished_req_ids:
1320+
self.requests_meta.pop(request_id, None)
1321+
return UCMConnectorMetadata()
1322+
1323+
def request_finished(
1324+
self,
1325+
request: "Request",
1326+
block_ids: list[int],
1327+
) -> tuple[bool, dict[str, Any] | None]:
1328+
self.requests_meta.pop(request.request_id, None)
1329+
return False, None
1330+
1331+
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
1332+
pass
1333+
1334+
def wait_for_layer_load(self, layer_name: str) -> None:
    """No-op hook: returns immediately without using ``layer_name``."""
    return None
1336+
1337+
def save_kv_layer(
1338+
self,
1339+
layer_name: str,
1340+
kv_layer: torch.Tensor,
1341+
attn_metadata: "AttentionMetadata",
1342+
**kwargs: Any,
1343+
) -> None:
1344+
pass
1345+
1346+
def wait_for_save(self):
    """No-op hook: returns immediately."""
    return None
1348+
1349+
def generate_hash(
1350+
self,
1351+
block_size: int,
1352+
token_ids: List[int],
1353+
parent_block_hash_value: bytes,
1354+
) -> list[bytes]:
1355+
ret = []
1356+
for start in range(0, len(token_ids), block_size):
1357+
end = start + block_size
1358+
block_token_ids = token_ids[start:end]
1359+
# Do not hash the block if it is not full.
1360+
if len(block_token_ids) < block_size:
1361+
break
1362+
1363+
block_token_ids_tuple = tuple(block_token_ids)
1364+
hash_value = self.request_hasher(
1365+
(parent_block_hash_value, block_token_ids_tuple)
1366+
)
1367+
parent_block_hash_value = hash_value
1368+
ret.append(hash_value)
1369+
1370+
return ret
1371+
13191372

13201373
def layer_name_to_kv_cache_spec(
13211374
kv_cache_config: KVCacheConfig,

0 commit comments

Comments
 (0)