9 changes: 4 additions & 5 deletions examples/offline_inference_blend.py
@@ -96,11 +96,10 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
"ucm_connector_name": "UcmNfsStore",
"ucm_connector_config": {
"storage_backends": data_dir,
"kv_block_size": 33554432,
"use_direct": False,
},
}
],
"load_only_first_rank": False,
"ucm_sparse_config": {
"Blend": {
"chunk_end_token_id": chunk_end_token_id,
@@ -111,7 +110,6 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
},
}
},
"use_layerwise": True,
},
)
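For reference, a rough sketch of the trimmed connector configuration this example ends up passing after the change above; only the keys visible in this hunk are taken from the diff, while the placeholder values, the "ucm_connectors" key name, and the surrounding dictionary shape are assumptions for illustration:

data_dir = "/path/to/nfs_store"  # placeholder; the example derives this path earlier in the script
chunk_end_token_id = 2  # placeholder; the example derives this from the tokenizer
ucm_config_sketch = {
    "ucm_connectors": [  # key name assumed; the diff only shows the list's closing bracket
        {
            "ucm_connector_name": "UcmNfsStore",
            "ucm_connector_config": {
                "storage_backends": data_dir,
            },
        }
    ],
    "load_only_first_rank": False,
    "ucm_sparse_config": {
        "Blend": {
            "chunk_end_token_id": chunk_end_token_id,
        },
    },
}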

@@ -261,10 +259,11 @@ def main():

print(f"Baseline generated text: {baseline_gen_text!r}")
print(f"Baseline generated cost time: {baseline_time:.2f} seconds")
print(f"Blend generated text: {blend_gen_text!r}")
print(f"Blend generated cost time: {blend_time:.2f} seconds")
print(f"Prefix Cache generated text: {pc_gen_text!r}")
print(f"Prefix Cache generated cost time: {pc_time:.2f} seconds")
print(f"Blend generated text: {blend_gen_text!r}")
print(f"Blend generated cost time: {blend_time:.2f} seconds")

print(f"Question:{dataset_row['input']}")
print(f"Golden answer:{dataset_row["answers"]}")

77 changes: 34 additions & 43 deletions ucm/integration/vllm/blend_connector.py
@@ -14,8 +14,6 @@

from ucm.integration.vllm.ucm_connector import (
RequestDispatchMeta,
RequestHasher,
RequestMeta,
UCMConnectorMetadata,
UCMDirectConnector,
)
@@ -41,7 +39,7 @@ class ChunkMetaData:
cached_start_position: int

vllm_blk_ids: List[int] = field(default_factory=list)
chunk_blks_hash: List[str] = field(default_factory=list)
chunk_blks_hash: List[bytes] = field(default_factory=list)
store_hits: List[bool] = field(default_factory=list)

@property
@@ -65,7 +63,7 @@ def hits_vllm_blk_ids(self) -> List[int]:
return list(itertools.compress(self.vllm_blk_ids, self.store_hits))

@property
def hits_chunk_blks_hash(self) -> List[str]:
def hits_chunk_blks_hash(self) -> List[bytes]:
return list(itertools.compress(self.chunk_blks_hash, self.store_hits))

def merge_chunk(self, temp_chunk_meta: Self) -> None:
@@ -102,7 +100,7 @@ def is_prefix_cache(self) -> bool:

@dataclass
class BlendRequestMeta:
ucm_block_hashs: list[str] = field(default_factory=list)
ucm_block_hashs: list[bytes] = field(default_factory=list)
# hbm pc is not supported
hbm_hit_block_num: int = 0
# ucm pc is supported
@@ -138,35 +136,15 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
else:
raise "UCMBlendConnector init failed, please check your config"

self.ucm_chunk_end_hash: int = self.request_hasher("UCM_CHUNK_END_HASH")
self.ucm_chunk_continue_hash: int = self.request_hasher(
"UCM_CHUNK_CONTINUE_HASH"
)
self.requests_blend_meta: dict[str, BlendRequestMeta] = {}
self.cos_sin_cache: torch.Tensor = None

# if the chunk cache hits fewer blocks than min_blend_threshold, there is no need to blend
self.min_blend_threshold = 16

def _generate_hash(
self, block_size: int, token_ids: list[int], parent_block_hash_value: int
) -> list[str]:
ret = []
for start in range(0, len(token_ids), block_size):
end = start + block_size
block_token_ids = token_ids[start:end]
# Do not hash the block if it is not full.
if len(block_token_ids) < block_size:
break

block_token_ids_tuple = tuple(block_token_ids)
hash_value = self.request_hasher(
(parent_block_hash_value, block_token_ids_tuple)
)
parent_block_hash_value = hash_value
ret.append(str(hash_value))

return ret
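The removed helper above shows the prefix-chained block hashing scheme that the inherited generate_hash is expected to cover: each full block's hash depends on its own tokens and on the previous block's hash, and a trailing partial block is never hashed. A standalone illustration of that chaining, using hashlib as a stand-in for the connector's actual hasher:

from hashlib import sha256
from typing import List


def chained_block_hashes(block_size: int, token_ids: List[int], seed: bytes) -> List[bytes]:
    # Illustration only, not the connector's RequestHasher: hash each full block
    # together with its parent hash so identical prefixes yield identical hashes.
    hashes: List[bytes] = []
    parent = seed
    for start in range(0, len(token_ids), block_size):
        block = token_ids[start : start + block_size]
        if len(block) < block_size:
            break  # do not hash a partial trailing block
        parent = sha256(parent + repr(tuple(block)).encode()).digest()
        hashes.append(parent)
    return hashes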
# post-process delta RoPE meta
self.delta_rope_vllm_ids: torch.Tensor = None
self.delta_rope_positions: torch.Tensor = None

def _process_req(self, all_token_ids: List[int]):
"""
@@ -177,8 +155,8 @@ def _process_req(self, all_token_ids: List[int]):
finally, if there are enough chunk block hits, we do cache blend to improve TTFT
"""
chunks_meta = []
prefix_block_hashes = self._generate_hash(
self.block_size, all_token_ids, RequestHasher._SEED_HASH
prefix_block_hashes = self.generate_hash(
self.block_size, all_token_ids, self._seed
)
if (
all_token_ids[-1] == self.chunk_end_token_id
@@ -203,8 +181,8 @@ def _process_req(self, all_token_ids: List[int]):
# but this would require lots of modifications to the engine.
if all_token_ids[end_token_idx] == self.chunk_end_token_id:
chunk_token_ids = all_token_ids[start_token_dix : end_token_idx + 1]
chunk_blks_hash = self._generate_hash(
self.block_size, chunk_token_ids, RequestHasher._SEED_HASH
chunk_blks_hash = self.generate_hash(
self.block_size, chunk_token_ids, self._seed
)

chunk_blks_len = end_blk_idx - start_blk_idx + 1
@@ -371,7 +349,7 @@ def get_num_new_matched_tokens(
) -> tuple[int, bool]:

# HBM prefix cache is currently not supported, because the blended caches have a global view of all chunks
# so they can not apply to other req
# so they can not be applied to other req
assert num_computed_tokens == 0
all_token_ids = request.all_token_ids

@@ -488,21 +466,34 @@ def build_connector_meta(

return UCMBlendConnectorMetadata(requests_dispatch_meta)

def wait_for_layer_load(self, layer_name: str) -> None:
metadata = self._get_connector_metadata()
assert isinstance(metadata, UCMBlendConnectorMetadata)

def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
"""
Blend needs to build post-process meta for the loaded KV cache
"""
super().bind_connector_metadata(connector_metadata)
all_hits_vllm_ids = []
positions = []
k_cache = self.kv_caches[layer_name][0]
for request_id, request in metadata.request_meta.items():
for request_id, request in connector_metadata.request_meta.items():
for chunk_meta in request.chunks_meta:
all_hits_vllm_ids.extend(chunk_meta.hits_vllm_blk_ids)
positions.extend(
[chunk_meta.position_offset] * len(chunk_meta.hits_vllm_blk_ids)
)
if all_hits_vllm_ids:
vllm_ids = torch.tensor(all_hits_vllm_ids, device=k_cache.device)
positions = torch.tensor(positions, device=k_cache.device)
self._post_process_chunk_cache(k_cache, vllm_ids, positions)
pass
self.delta_rope_vllm_ids = torch.tensor(
all_hits_vllm_ids, device=self.device
)
self.delta_rope_positions = torch.tensor(positions, device=self.device)

def clear_connector_metadata(self) -> None:
"""Clear the post process meta"""
super().clear_connector_metadata()
self.delta_rope_vllm_ids = None
self.delta_rope_positions = None

def wait_for_layer_load(self, layer_name: str) -> None:
if self.delta_rope_vllm_ids is not None:
k_cache = self.kv_caches[layer_name][0]
self._post_process_chunk_cache(
k_cache, self.delta_rope_vllm_ids, self.delta_rope_positions
)
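The delta_rope_vllm_ids / delta_rope_positions pairs built in bind_connector_metadata record, per hit block, the chunk's position_offset in the blended request, and wait_for_layer_load hands them to _post_process_chunk_cache for each layer's K cache. A rough sketch of what a delta-RoPE correction on one cached key block could look like, assuming standard rotary embeddings; this is an illustration of the idea, not the actual _post_process_chunk_cache implementation:

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard RoPE helper: swap the two halves of the head dim and negate one.
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)


def apply_delta_rope(k_block: torch.Tensor, delta_pos: int, inv_freq: torch.Tensor) -> torch.Tensor:
    # Re-rotate keys cached at an old offset by delta_pos positions so they are
    # valid at the block's new position in the blended request.
    angles = delta_pos * inv_freq  # [head_dim // 2]
    cos = torch.cat((angles.cos(), angles.cos()), dim=-1)  # [head_dim]
    sin = torch.cat((angles.sin(), angles.sin()), dim=-1)
    return k_block * cos + rotate_half(k_block) * sin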