Skip to content

Commit 6000d75

Browse files
authored
[opt] adapt cache blend for store and sparse's new version (#664)
1 parent 3ccbd66 commit 6000d75

6 files changed

Lines changed: 224 additions & 200 deletions

File tree

examples/offline_inference_blend.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,10 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
9696
"ucm_connector_name": "UcmNfsStore",
9797
"ucm_connector_config": {
9898
"storage_backends": data_dir,
99-
"kv_block_size": 33554432,
99+
"use_direct": False,
100100
},
101101
}
102102
],
103-
"load_only_first_rank": False,
104103
"ucm_sparse_config": {
105104
"Blend": {
106105
"chunk_end_token_id": chunk_end_token_id,
@@ -111,7 +110,6 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
111110
},
112111
}
113112
},
114-
"use_layerwise": True,
115113
},
116114
)
117115

@@ -261,10 +259,11 @@ def main():
261259

262260
print(f"Baseline generated text: {baseline_gen_text!r}")
263261
print(f"Baseline generated cost time: {baseline_time:.2f} seconds")
264-
print(f"Blend generated text: {blend_gen_text!r}")
265-
print(f"Blend generated cost time: {blend_time:.2f} seconds")
266262
print(f"Prefix Cache generated text: {pc_gen_text!r}")
267263
print(f"Prefix Cache generated cost time: {pc_time:.2f} seconds")
264+
print(f"Blend generated text: {blend_gen_text!r}")
265+
print(f"Blend generated cost time: {blend_time:.2f} seconds")
266+
268267
print(f"Question:{dataset_row['input']}")
269268
print(f"Golden answer:{dataset_row['answers']}")
270269

ucm/integration/vllm/blend_connector.py

Lines changed: 34 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414

1515
from ucm.integration.vllm.ucm_connector import (
1616
RequestDispatchMeta,
17-
RequestHasher,
18-
RequestMeta,
1917
UCMConnectorMetadata,
2018
UCMDirectConnector,
2119
)
@@ -41,7 +39,7 @@ class ChunkMetaData:
4139
cached_start_position: int
4240

4341
vllm_blk_ids: List[int] = field(default_factory=list)
44-
chunk_blks_hash: List[str] = field(default_factory=list)
42+
chunk_blks_hash: List[bytes] = field(default_factory=list)
4543
store_hits: List[bool] = field(default_factory=list)
4644

4745
@property
@@ -65,7 +63,7 @@ def hits_vllm_blk_ids(self) -> List[int]:
6563
return list(itertools.compress(self.vllm_blk_ids, self.store_hits))
6664

6765
@property
68-
def hits_chunk_blks_hash(self) -> List[str]:
66+
def hits_chunk_blks_hash(self) -> List[bytes]:
6967
return list(itertools.compress(self.chunk_blks_hash, self.store_hits))
7068

7169
def merge_chunk(self, temp_chunk_meta: Self) -> None:
@@ -102,7 +100,7 @@ def is_prefix_cache(self) -> bool:
102100

103101
@dataclass
104102
class BlendRequestMeta:
105-
ucm_block_hashs: list[str] = field(default_factory=list)
103+
ucm_block_hashs: list[bytes] = field(default_factory=list)
106104
# hbm pc is not supported
107105
hbm_hit_block_num: int = 0
108106
# ucm pc is supported
@@ -138,35 +136,15 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
138136
else:
139137
raise RuntimeError("UCMBlendConnector init failed, please check your config")
140138

141-
self.ucm_chunk_end_hash: int = self.request_hasher("UCM_CHUNK_END_HASH")
142-
self.ucm_chunk_continue_hash: int = self.request_hasher(
143-
"UCM_CHUNK_CONTINUE_HASH"
144-
)
145139
self.requests_blend_meta: dict[str, BlendRequestMeta] = {}
146140
self.cos_sin_cache: torch.Tensor = None
147141

148142
# if chunk cache hits less than min_blend_threshold, no need to cache blend
149143
self.min_blend_threshold = 16
150144

151-
def _generate_hash(
152-
self, block_size: int, token_ids: list[int], parent_block_hash_value: int
153-
) -> list[str]:
154-
ret = []
155-
for start in range(0, len(token_ids), block_size):
156-
end = start + block_size
157-
block_token_ids = token_ids[start:end]
158-
# Do not hash the block if it is not full.
159-
if len(block_token_ids) < block_size:
160-
break
161-
162-
block_token_ids_tuple = tuple(block_token_ids)
163-
hash_value = self.request_hasher(
164-
(parent_block_hash_value, block_token_ids_tuple)
165-
)
166-
parent_block_hash_value = hash_value
167-
ret.append(str(hash_value))
168-
169-
return ret
145+
# post process delta rope meta
146+
self.delta_rope_vllm_ids: torch.Tensor = None
147+
self.delta_rope_positions: torch.Tensor = None
170148

171149
def _process_req(self, all_token_ids: List[int]):
172150
"""
@@ -177,8 +155,8 @@ def _process_req(self, all_token_ids: List[int]):
177155
finally, if there are enough chunk block-hits, we do cache blend to get a TTFT improvement
178156
"""
179157
chunks_meta = []
180-
prefix_block_hashes = self._generate_hash(
181-
self.block_size, all_token_ids, RequestHasher._SEED_HASH
158+
prefix_block_hashes = self.generate_hash(
159+
self.block_size, all_token_ids, self._seed
182160
)
183161
if (
184162
all_token_ids[-1] == self.chunk_end_token_id
@@ -203,8 +181,8 @@ def _process_req(self, all_token_ids: List[int]):
203181
# but this will bring lots of modification to engine.
204182
if all_token_ids[end_token_idx] == self.chunk_end_token_id:
205183
chunk_token_ids = all_token_ids[start_token_dix : end_token_idx + 1]
206-
chunk_blks_hash = self._generate_hash(
207-
self.block_size, chunk_token_ids, RequestHasher._SEED_HASH
184+
chunk_blks_hash = self.generate_hash(
185+
self.block_size, chunk_token_ids, self._seed
208186
)
209187

210188
chunk_blks_len = end_blk_idx - start_blk_idx + 1
@@ -371,7 +349,7 @@ def get_num_new_matched_tokens(
371349
) -> tuple[int, bool]:
372350

373351
# currently does not support HBM prefix cache, because the blended cache has a global view of all chunks
374-
# so they can not apply to other req
352+
# so they can not be applied to other req
375353
assert num_computed_tokens == 0
376354
all_token_ids = request.all_token_ids
377355

@@ -488,21 +466,34 @@ def build_connector_meta(
488466

489467
return UCMBlendConnectorMetadata(requests_dispatch_meta)
490468

491-
def wait_for_layer_load(self, layer_name: str) -> None:
492-
metadata = self._get_connector_metadata()
493-
assert isinstance(metadata, UCMBlendConnectorMetadata)
494-
469+
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
470+
"""
471+
Blend need build post process meta for loaded kv cache
472+
"""
473+
super().bind_connector_metadata(connector_metadata)
495474
all_hits_vllm_ids = []
496475
positions = []
497-
k_cache = self.kv_caches[layer_name][0]
498-
for request_id, request in metadata.request_meta.items():
476+
for request_id, request in connector_metadata.request_meta.items():
499477
for chunk_meta in request.chunks_meta:
500478
all_hits_vllm_ids.extend(chunk_meta.hits_vllm_blk_ids)
501479
positions.extend(
502480
[chunk_meta.position_offset] * len(chunk_meta.hits_vllm_blk_ids)
503481
)
504482
if all_hits_vllm_ids:
505-
vllm_ids = torch.tensor(all_hits_vllm_ids, device=k_cache.device)
506-
positions = torch.tensor(positions, device=k_cache.device)
507-
self._post_process_chunk_cache(k_cache, vllm_ids, positions)
508-
pass
483+
self.delta_rope_vllm_ids = torch.tensor(
484+
all_hits_vllm_ids, device=self.device
485+
)
486+
self.delta_rope_positions = torch.tensor(positions, device=self.device)
487+
488+
def clear_connector_metadata(self) -> None:
489+
"""Clear the post process meta"""
490+
super().clear_connector_metadata()
491+
self.delta_rope_vllm_ids = None
492+
self.delta_rope_positions = None
493+
494+
def wait_for_layer_load(self, layer_name: str) -> None:
495+
if self.delta_rope_vllm_ids is not None:
496+
k_cache = self.kv_caches[layer_name][0]
497+
self._post_process_chunk_cache(
498+
k_cache, self.delta_rope_vllm_ids, self.delta_rope_positions
499+
)

0 commit comments

Comments
 (0)