diff --git a/examples/offline_inference_blend.py b/examples/offline_inference_blend.py index bdc2b211b..45f50a883 100644 --- a/examples/offline_inference_blend.py +++ b/examples/offline_inference_blend.py @@ -96,11 +96,10 @@ def build_llm_with_uc(module_path: str, name: str, model: str): "ucm_connector_name": "UcmNfsStore", "ucm_connector_config": { "storage_backends": data_dir, - "kv_block_size": 33554432, + "use_direct": False, }, } ], - "load_only_first_rank": False, "ucm_sparse_config": { "Blend": { "chunk_end_token_id": chunk_end_token_id, @@ -111,7 +110,6 @@ def build_llm_with_uc(module_path: str, name: str, model: str): }, } }, - "use_layerwise": True, }, ) @@ -261,10 +259,11 @@ def main(): print(f"Baseline generated text: {baseline_gen_text!r}") print(f"Baseline generated cost time: {baseline_time:.2f} seconds") - print(f"Blend generated text: {blend_gen_text!r}") - print(f"Blend generated cost time: {blend_time:.2f} seconds") print(f"Prefix Cache generated text: {pc_gen_text!r}") print(f"Prefix Cache generated cost time: {pc_time:.2f} seconds") + print(f"Blend generated text: {blend_gen_text!r}") + print(f"Blend generated cost time: {blend_time:.2f} seconds") + print(f"Question:{dataset_row['input']}") print(f"Golden answer:{dataset_row["answers"]}") diff --git a/ucm/integration/vllm/blend_connector.py b/ucm/integration/vllm/blend_connector.py index eaba1b381..828d5dc64 100644 --- a/ucm/integration/vllm/blend_connector.py +++ b/ucm/integration/vllm/blend_connector.py @@ -14,8 +14,6 @@ from ucm.integration.vllm.ucm_connector import ( RequestDispatchMeta, - RequestHasher, - RequestMeta, UCMConnectorMetadata, UCMDirectConnector, ) @@ -41,7 +39,7 @@ class ChunkMetaData: cached_start_position: int vllm_blk_ids: List[int] = field(default_factory=list) - chunk_blks_hash: List[str] = field(default_factory=list) + chunk_blks_hash: List[bytes] = field(default_factory=list) store_hits: List[bool] = field(default_factory=list) @property @@ -65,7 +63,7 @@ def hits_vllm_blk_ids(self) -> List[int]: return list(itertools.compress(self.vllm_blk_ids, self.store_hits)) @property - def hits_chunk_blks_hash(self) -> List[str]: + def hits_chunk_blks_hash(self) -> List[bytes]: return list(itertools.compress(self.chunk_blks_hash, self.store_hits)) def merge_chunk(self, temp_chunk_meta: Self) -> None: @@ -102,7 +100,7 @@ def is_prefix_cache(self) -> bool: @dataclass class BlendRequestMeta: - ucm_block_hashs: list[str] = field(default_factory=list) + ucm_block_hashs: list[bytes] = field(default_factory=list) # hbm pc is not supported hbm_hit_block_num: int = 0 # ucm pc is supported @@ -138,35 +136,15 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): else: raise "UCMBlendConnector init failed, please check your config" - self.ucm_chunk_end_hash: int = self.request_hasher("UCM_CHUNK_END_HASH") - self.ucm_chunk_continue_hash: int = self.request_hasher( - "UCM_CHUNK_CONTINUE_HASH" - ) self.requests_blend_meta: dict[str, BlendRequestMeta] = {} self.cos_sin_cache: torch.Tensor = None # if chunk cache hits less than min_blend_threshold, no need to cache blend self.min_blend_threshold = 16 - def _generate_hash( - self, block_size: int, token_ids: list[int], parent_block_hash_value: int - ) -> list[str]: - ret = [] - for start in range(0, len(token_ids), block_size): - end = start + block_size - block_token_ids = token_ids[start:end] - # Do not hash the block if it is not full. 
- if len(block_token_ids) < block_size: - break - - block_token_ids_tuple = tuple(block_token_ids) - hash_value = self.request_hasher( - (parent_block_hash_value, block_token_ids_tuple) - ) - parent_block_hash_value = hash_value - ret.append(str(hash_value)) - - return ret + # post process delta rope meta + self.delta_rope_vllm_ids: torch.Tensor = None + self.delta_rope_positions: torch.Tensor = None def _process_req(self, all_token_ids: List[int]): """ @@ -177,8 +155,8 @@ def _process_req(self, all_token_ids: List[int]): finally, if there are quite many chunk block-hits, we do cache blend to get TTFT-promot """ chunks_meta = [] - prefix_block_hashes = self._generate_hash( - self.block_size, all_token_ids, RequestHasher._SEED_HASH + prefix_block_hashes = self.generate_hash( + self.block_size, all_token_ids, self._seed ) if ( all_token_ids[-1] == self.chunk_end_token_id @@ -203,8 +181,8 @@ def _process_req(self, all_token_ids: List[int]): # but this will bring lots of modification to engine. if all_token_ids[end_token_idx] == self.chunk_end_token_id: chunk_token_ids = all_token_ids[start_token_dix : end_token_idx + 1] - chunk_blks_hash = self._generate_hash( - self.block_size, chunk_token_ids, RequestHasher._SEED_HASH + chunk_blks_hash = self.generate_hash( + self.block_size, chunk_token_ids, self._seed ) chunk_blks_len = end_blk_idx - start_blk_idx + 1 @@ -371,7 +349,7 @@ def get_num_new_matched_tokens( ) -> tuple[int, bool]: # current not support HBM prefix cache, cause the blended cached have a ground view of all chunks - # so they can not apply to other req + # so they can not be applied to other req assert num_computed_tokens == 0 all_token_ids = request.all_token_ids @@ -488,21 +466,34 @@ def build_connector_meta( return UCMBlendConnectorMetadata(requests_dispatch_meta) - def wait_for_layer_load(self, layer_name: str) -> None: - metadata = self._get_connector_metadata() - assert isinstance(metadata, UCMBlendConnectorMetadata) - + def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None: + """ + Blend need build post process meta for loaded kv cache + """ + super().bind_connector_metadata(connector_metadata) all_hits_vllm_ids = [] positions = [] - k_cache = self.kv_caches[layer_name][0] - for request_id, request in metadata.request_meta.items(): + for request_id, request in connector_metadata.request_meta.items(): for chunk_meta in request.chunks_meta: all_hits_vllm_ids.extend(chunk_meta.hits_vllm_blk_ids) positions.extend( [chunk_meta.position_offset] * len(chunk_meta.hits_vllm_blk_ids) ) if all_hits_vllm_ids: - vllm_ids = torch.tensor(all_hits_vllm_ids, device=k_cache.device) - positions = torch.tensor(positions, device=k_cache.device) - self._post_process_chunk_cache(k_cache, vllm_ids, positions) - pass + self.delta_rope_vllm_ids = torch.tensor( + all_hits_vllm_ids, device=self.device + ) + self.delta_rope_positions = torch.tensor(positions, device=self.device) + + def clear_connector_metadata(self) -> None: + """Clear the post process meta""" + super().clear_connector_metadata() + self.delta_rope_vllm_ids = None + self.delta_rope_positions = None + + def wait_for_layer_load(self, layer_name: str) -> None: + if self.delta_rope_vllm_ids is not None: + k_cache = self.kv_caches[layer_name][0] + self._post_process_chunk_cache( + k_cache, self.delta_rope_vllm_ids, self.delta_rope_positions + ) diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch index 705872d78..ad0546ef0 100644 --- 
a/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch +++ b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch @@ -1,26 +1,25 @@ -From 8cb493f9ece884cbc2ba71e367bed2b4116ae1b3 Mon Sep 17 00:00:00 2001 +From 857bc2830fb294d893ca5fac50eff48e10aaddf2 Mon Sep 17 00:00:00 2001 From: wenxinwang -Date: Tue, 23 Dec 2025 19:44:21 -0800 -Subject: [PATCH] kvcomp qwen deepseek +Date: Wed, 21 Jan 2026 00:53:52 -0800 +Subject: [PATCH] update for gsaondevice + sparse + cache blend --- - vllm/attention/layer.py | 63 ++++++++++++++++- + vllm/attention/layer.py | 63 +++++++++++++++- vllm/model_executor/models/llama.py | 21 +++++- - vllm/model_executor/models/qwen2.py | 23 ++++++- - vllm/v1/attention/backends/flash_attn.py | 7 ++ + vllm/model_executor/models/qwen2.py | 23 +++++- vllm/v1/attention/backends/mla/common.py | 15 +++- vllm/v1/attention/backends/mla/flashmla.py | 18 ++++- vllm/v1/core/kv_cache_manager.py | 7 +- vllm/v1/core/kv_cache_utils.py | 13 ++++ - vllm/v1/core/sched/output.py | 3 + - vllm/v1/core/sched/scheduler.py | 30 +++++++- + vllm/v1/core/sched/output.py | 7 +- + vllm/v1/core/sched/scheduler.py | 34 ++++++++- vllm/v1/worker/block_table.py | 13 ++++ - vllm/v1/worker/gpu_model_runner.py | 80 +++++++++++++++++++--- + vllm/v1/worker/gpu_model_runner.py | 87 +++++++++++++++++++--- vllm/v1/worker/gpu_worker.py | 2 + - 13 files changed, 275 insertions(+), 20 deletions(-) + 12 files changed, 281 insertions(+), 22 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py -index f0ad68b16..ba93960de 100644 +index f0ad68b16..5b2e0f04f 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -8,6 +8,7 @@ import torch.nn as nn @@ -61,7 +60,7 @@ index f0ad68b16..ba93960de 100644 + kv_cache, k_hash = kv_cache + else: + k_hash = None -+ query, _, _, _ = maybe_execute_sparse_attention_begin( ++ query, key, value, output = maybe_execute_sparse_attention_begin( + query, key, value, layer_name, forward_context, output, k_hash=k_hash + ) self.impl.forward(self, @@ -237,33 +236,8 @@ index 7ef9d248d..e35ab2fdc 100644 if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, -diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py -index fbc13c06c..2b2244949 100755 ---- a/vllm/v1/attention/backends/flash_attn.py -+++ b/vllm/v1/attention/backends/flash_attn.py -@@ -16,6 +16,8 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states - from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, - get_flash_attn_version, - is_flash_attn_varlen_func_available) -+from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse -+import os - - if is_flash_attn_varlen_func_available(): - from vllm.attention.utils.fa_utils import (flash_attn_varlen_func, -@@ -221,6 +223,11 @@ class FlashAttentionMetadataBuilder( - block_table = self.block_table - block_table_tensor = block_table.get_device_tensor()[:num_reqs] - -+ if has_ucm_sparse(): -+ ucm_sparse = get_ucm_sparse() -+ if os.getenv("VLLM_HASH_ATTENTION") == "1": -+ decode_mask, topk_seq_lens = ucm_sparse.build_decode_attention_meta(query_start_loc, seq_lens, block_table_tensor) -+ - block_table.slot_mapping[:num_actual_tokens].copy_( - block_table.slot_mapping_cpu[:num_actual_tokens], - non_blocking=True) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py -index f2aaf59a4..439bb9b14 100644 +index f2aaf59a4..205bdbe71 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ 
b/vllm/v1/attention/backends/mla/common.py @@ -200,6 +200,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, @@ -307,7 +281,7 @@ index f2aaf59a4..439bb9b14 100644 ) if has_prefill: -+ prefill_q, _, _, _ = maybe_execute_sparse_attention_begin(prefill_q, k_c_normed, k_pe, layer.layer_name, forward_context, output=output, phase="prefill", k_hash=k_hash) ++ prefill_q, k_c_normed, k_pe, output = maybe_execute_sparse_attention_begin(prefill_q, k_c_normed, k_pe, layer.layer_name, forward_context, output=output, phase="prefill", k_hash=k_hash) output[num_decode_tokens:] = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, attn_metadata) @@ -320,7 +294,7 @@ index f2aaf59a4..439bb9b14 100644 decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) # Convert from (N, B, L) to (B, N, L) decode_ql_nope = decode_ql_nope.transpose(0, 1) -+ _, _, _, _ = maybe_execute_sparse_attention_begin(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), k_c_normed, k_pe, layer.layer_name, forward_context, output=output, phase="decode", k_hash=k_hash, decode_ql_nope=decode_ql_nope, decode_q_pe=decode_q_pe) ++ _, k_c_normed, k_pe, output = maybe_execute_sparse_attention_begin(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), k_c_normed, k_pe, layer.layer_name, forward_context, output=output, phase="decode", k_hash=k_hash, decode_ql_nope=decode_ql_nope, decode_q_pe=decode_q_pe) output[:num_decode_tokens] = self._forward_decode( decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) @@ -454,18 +428,29 @@ index 2fbcb569e..40c199563 100644 # All layers have the same KV cache spec, so we create one kv cache group # for all layers. diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py -index d34f39327..141d750b3 100644 +index d34f39327..0f60ac77d 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py -@@ -155,3 +155,6 @@ class SchedulerOutput: +@@ -3,7 +3,7 @@ + + from __future__ import annotations + +-from dataclasses import dataclass ++from dataclasses import dataclass, field + from typing import TYPE_CHECKING, Optional + + if TYPE_CHECKING: +@@ -155,3 +155,8 @@ class SchedulerOutput: # KV Cache Connector metadata. kv_connector_metadata: Optional[KVConnectorMetadata] = None + + # modified slots by sparse algorithm + req_sparsed_slots: dict[str, int] = None ++ # The number of tokens computed externally for each request ++ num_external_computed_tokens_per_req: dict[str, int] = field(default_factory=dict) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py -index fe552db74..0d8a67eba 100644 +index fe552db74..7d98745c8 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -34,6 +34,10 @@ from vllm.v1.outputs import ModelRunnerOutput @@ -524,7 +509,14 @@ index fe552db74..0d8a67eba 100644 if new_blocks is None: # The request cannot be scheduled. # Preempt the lowest-priority request. -@@ -337,6 +355,10 @@ class Scheduler(SchedulerInterface): +@@ -331,12 +349,17 @@ class Scheduler(SchedulerInterface): + skipped_waiting_requests = create_request_queue(self.policy) + + # Next, schedule the WAITING requests. ++ num_external_computed_tokens_per_req: dict[str, int] = {} + if not preempted_reqs: + while self.waiting and token_budget > 0: + if len(self.running) == self.max_num_running_reqs: break request = self.waiting.peek_request() @@ -535,7 +527,16 @@ index fe552db74..0d8a67eba 100644 # KVTransfer: skip request if still waiting for remote kvs. 
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: -@@ -446,6 +468,7 @@ class Scheduler(SchedulerInterface): +@@ -387,7 +410,7 @@ class Scheduler(SchedulerInterface): + num_external_computed_tokens, load_kv_async = ( + self.connector.get_num_new_matched_tokens( + request, num_new_local_computed_tokens)) +- ++ num_external_computed_tokens_per_req.update({request.request_id: num_external_computed_tokens}) + # Total computed tokens (local + external). + num_computed_tokens = (num_new_local_computed_tokens + + num_external_computed_tokens) +@@ -446,6 +469,7 @@ class Scheduler(SchedulerInterface): new_computed_blocks, num_lookahead_tokens=self.num_lookahead_tokens, delay_cache_blocks=load_kv_async, @@ -543,15 +544,16 @@ index fe552db74..0d8a67eba 100644 ) if new_blocks is None: # The request cannot be scheduled. -@@ -559,6 +582,7 @@ class Scheduler(SchedulerInterface): +@@ -559,6 +583,8 @@ class Scheduler(SchedulerInterface): scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, scheduled_encoder_inputs=scheduled_encoder_inputs, num_common_prefix_blocks=num_common_prefix_blocks, + req_sparsed_slots=req_sparsed_slots, ++ num_external_computed_tokens_per_req = num_external_computed_tokens_per_req, # finished_req_ids is an existing state in the scheduler, # instead of being newly scheduled in this step. # It contains the request IDs that are finished in between -@@ -927,6 +951,8 @@ class Scheduler(SchedulerInterface): +@@ -927,6 +953,8 @@ class Scheduler(SchedulerInterface): def add_request(self, request: Request) -> None: self.waiting.add_request(request) self.requests[request.request_id] = request @@ -560,7 +562,7 @@ index fe552db74..0d8a67eba 100644 if self.log_stats: request.record_event(EngineCoreEventType.QUEUED) -@@ -976,6 +1002,8 @@ class Scheduler(SchedulerInterface): +@@ -976,6 +1004,8 @@ class Scheduler(SchedulerInterface): def _free_request(self, request: Request) -> Optional[dict[str, Any]]: assert request.is_finished() @@ -601,7 +603,7 @@ index 8f4e8d64c..f45e39f5c 100644 for i, block_table in enumerate(self.block_tables): block_table.add_row(block_ids[i], row_idx) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py -index 5a26e88db..6a39240d2 100644 +index 5a26e88db..41544a077 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -15,6 +15,7 @@ import torch.nn as nn @@ -622,15 +624,17 @@ index 5a26e88db..6a39240d2 100644 if TYPE_CHECKING: import xgrammar as xgr import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501 -@@ -365,6 +369,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -364,7 +368,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): + new/resumed/paused/finished request in the batch. """ # Remove finished requests from the cached states. ++ self.ucm_sparse_update_states(scheduler_output) for req_id in scheduler_output.finished_req_ids: + self.ucm_sparse_request_finished_in_worker(req_id) self.requests.pop(req_id, None) self.encoder_cache.pop(req_id, None) # Remove the finished requests from the persistent batch. -@@ -468,11 +473,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -468,11 +474,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Update the states of the running/resumed requests. is_last_rank = get_pp_group().is_last_rank req_data = scheduler_output.scheduled_cached_reqs @@ -644,7 +648,7 @@ index 5a26e88db..6a39240d2 100644 # Update the cached states. 
req_state.num_computed_tokens = num_computed_tokens -@@ -494,15 +501,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -494,15 +502,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): new_token_ids[-num_new_tokens:]) # Update the block IDs. @@ -666,7 +670,7 @@ index 5a26e88db..6a39240d2 100644 req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is None: -@@ -515,6 +522,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -515,6 +523,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) @@ -675,7 +679,7 @@ index 5a26e88db..6a39240d2 100644 self.input_batch.block_table.append_row(new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu -@@ -623,6 +632,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -623,6 +633,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): if self.uses_mrope: self._calc_mrope_positions(scheduler_output) @@ -695,7 +699,7 @@ index 5a26e88db..6a39240d2 100644 # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] -@@ -652,11 +674,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -652,11 +675,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): # block_size. block_table_indices = ( req_indices * block_table.max_num_blocks_per_req + @@ -709,7 +713,7 @@ index 5a26e88db..6a39240d2 100644 np.add( block_numbers * block_size, block_offsets, -@@ -666,9 +688,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -666,9 +689,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.query_start_loc_np[0] = 0 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens @@ -724,7 +728,7 @@ index 5a26e88db..6a39240d2 100644 # Copy the tensors to the GPU. 
self.input_ids[:total_num_scheduled_tokens].copy_( -@@ -680,6 +704,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -680,6 +705,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): non_blocking=True) else: # Common case (1D positions) @@ -733,7 +737,7 @@ index 5a26e88db..6a39240d2 100644 self.positions[:total_num_scheduled_tokens].copy_( self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) -@@ -1370,6 +1396,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1370,6 +1397,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): skip_cuda_graphs=skip_cuda_graphs, ): self.maybe_setup_kv_connector(scheduler_output) @@ -741,7 +745,7 @@ index 5a26e88db..6a39240d2 100644 model_output = self.model( input_ids=input_ids, -@@ -1379,6 +1406,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1379,6 +1407,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) self.maybe_wait_for_kv_save() @@ -750,7 +754,7 @@ index 5a26e88db..6a39240d2 100644 finished_sending, finished_recving = ( self.get_finished_kv_transfers(scheduler_output)) -@@ -1723,6 +1752,30 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1723,6 +1753,36 @@ class GPUModelRunner(LoRAModelRunnerMixin): if has_kv_transfer_group(): get_kv_transfer_group().wait_for_save() @@ -777,11 +781,17 @@ index 5a26e88db..6a39240d2 100644 + return + ucm_sparse = get_ucm_sparse() + ucm_sparse.request_finished_in_worker(request_id) ++ ++ def ucm_sparse_update_states(self, scheduler_output: "SchedulerOutput"): ++ if not has_ucm_sparse(): ++ return ++ ucm_sparse = get_ucm_sparse() ++ ucm_sparse.update_states(scheduler_output) + @staticmethod def get_finished_kv_transfers( scheduler_output: "SchedulerOutput", -@@ -2570,6 +2623,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -2570,6 +2630,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, kv_cache_raw_tensors) diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index 68318caf1..0ab9b2326 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -162,11 +162,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): # invlalid block ids due to load errors self._invalid_block_ids: set[int] = set() - def generate_hash(self, block_size: int, request: "Request") -> list[bytes]: - token_ids = request.all_token_ids - + def generate_hash( + self, block_size: int, token_ids: List[int], parent_block_hash_value: bytes + ) -> list[bytes]: ret = [] - parent_block_hash_value = self._seed for start in range(0, len(token_ids), block_size): end = start + block_size block_token_ids = token_ids[start:end] @@ -297,7 +296,9 @@ def get_num_new_matched_tokens( assert num_computed_tokens % self.block_size == 0 hbm_hit_block_num = num_computed_tokens // self.block_size - ucm_block_ids = self.generate_hash(self.block_size, request) + ucm_block_ids = self.generate_hash( + self.block_size, request.all_token_ids, self._seed + ) external_block_ids = ucm_block_ids[hbm_hit_block_num:] if not external_block_ids: diff --git a/ucm/sparse/base.py b/ucm/sparse/base.py index e2a6c002c..f0acca692 100644 --- a/ucm/sparse/base.py +++ b/ucm/sparse/base.py @@ -158,6 +158,13 @@ def attention_finished( """ pass + def update_states(self, scheduler_output: SchedulerOutput) -> None: + """ + This is called at the beginning of "ModelRunner->execute_model" function. + Update the cached states with the scheduler output. 
+ """ + pass + def ffn_begin( self, hidden_states: torch.Tensor, residual: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: diff --git a/ucm/sparse/blend/blend.py b/ucm/sparse/blend/blend.py index 32b975b6d..3b6843df9 100644 --- a/ucm/sparse/blend/blend.py +++ b/ucm/sparse/blend/blend.py @@ -1,13 +1,14 @@ import time +from contextlib import contextmanager from dataclasses import dataclass, field from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch +import torch.cuda.nvtx as nvtx from sympy import false from torch import Tensor from ucm.logger import init_logger -from ucm.store.dramstore.dramstore_connector import device logger = init_logger(__name__) @@ -25,6 +26,17 @@ from ucm.sparse.utils import round_up +@contextmanager +def nvtx_range(msg: str, enable: bool = True): + if enable: + nvtx.range_push(msg) + try: + yield + finally: + if enable: + nvtx.range_pop() + + def get_num_blks(num_tokens, block_size): return (num_tokens + block_size - 1) // block_size @@ -226,98 +238,102 @@ def attention_begin( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: attn = forward_context.no_compile_layers[layer_name] kv_cache = attn.kv_cache[forward_context.virtual_engine] - start_time = time.perf_counter() if layer_name in self.compute_meta.keys(): need_update = False self.blend_req_metas.reset_compute_mask() - # maybe we can use triton kernel for req_meta in self.blend_req_metas.requests: - req_idx = req_meta.req_idx - req_query_start = self.attn_metadata.query_start_loc[req_idx].item() - req_query_end = self.attn_metadata.query_start_loc[req_idx + 1].item() - - if not req_meta.need_blend: - self.blend_req_metas.compute_mask[ - req_query_start:req_query_end - ].fill_(True) - continue - req_chunk_end = req_query_start + req_meta.chunks_len - - # HBM prefix cache is not supported now - # UC store prefix cache can be fully reused for the first chunk - his_vllm_blk_ids = self.attn_metadata.block_table[req_idx][ - req_meta.prefix_blk_len : req_meta.prefix_blk_len - + req_meta.chunks_blk_len - ] - # only compute topk of chunk's hits block - chunk_hit_mask = self.blend_req_metas.chunk_blks_hit_mask[ - : len(req_meta.chunk_hit_mask) - ] - src = torch.as_tensor( - req_meta.chunk_hit_mask, - dtype=chunk_hit_mask.dtype, - device=chunk_hit_mask.device, - ) - chunk_hit_mask.copy_(src) - - his_vllm_blk_ids = his_vllm_blk_ids[chunk_hit_mask] - his_k = kv_cache[0, his_vllm_blk_ids] - candidate_len = req_meta.chunk_hit_blk_len * self.block_size - his_k = his_k.reshape(candidate_len, -1) - - req_key = key[req_query_start:req_chunk_end] - - # req_key does not contain prefix cache - golden_k = req_key.reshape( - req_meta.chunks_blk_len, self.block_size, -1 - )[chunk_hit_mask] - golden_k = golden_k.reshape(candidate_len, -1) - - diff_k = torch.sum((his_k - golden_k).abs(), dim=[1]) - topK_num = int(candidate_len * self.compute_meta[layer_name]["ratio"]) - - topK_indices = torch.topk(diff_k, k=topK_num).indices - - # get origin idx in req_key - topK_indices = self.mask_idx[: req_meta.chunks_blk_len][ - chunk_hit_mask - ].reshape(-1)[topK_indices] - - # update compute_mask - self.blend_req_metas.update_req_compute_mask( - req_query_start, - req_chunk_end, - req_query_end, - chunk_hit_mask, - topK_indices, - ) - - self.blend_req_metas.update_query_lens( - req_idx, candidate_len - topK_num - ) - need_update = True + with nvtx_range(f"prepare meta req, :{req_meta.req_idx}"): + req_idx = req_meta.req_idx + req_query_start = self.attn_metadata.query_start_loc[req_idx].item() + 
req_query_end = self.attn_metadata.query_start_loc[ + req_idx + 1 + ].item() + + if not req_meta.need_blend: + self.blend_req_metas.compute_mask[ + req_query_start:req_query_end + ].fill_(True) + continue + req_chunk_end = req_query_start + req_meta.chunks_len + + with nvtx_range(f"prepare data, req :{req_meta.req_idx}"): + # HBM prefix cache is not supported now + # UC store prefix cache can be fully reused for the first chunk + his_vllm_blk_ids = self.attn_metadata.block_table[req_idx][ + req_meta.prefix_blk_len : req_meta.prefix_blk_len + + req_meta.chunks_blk_len + ] + # only compute topk of chunk's hits block + chunk_hit_mask = self.blend_req_metas.chunk_blks_hit_mask[ + : len(req_meta.chunk_hit_mask) + ] + src = torch.as_tensor( + req_meta.chunk_hit_mask, + dtype=chunk_hit_mask.dtype, + device=chunk_hit_mask.device, + ) + chunk_hit_mask.copy_(src) + + his_vllm_blk_ids = his_vllm_blk_ids[chunk_hit_mask] + his_k = kv_cache[0, his_vllm_blk_ids] + candidate_len = req_meta.chunk_hit_blk_len * self.block_size + his_k = his_k.reshape(candidate_len, -1) + + req_key = key[req_query_start:req_chunk_end] + + # req_key does not contain prefix cache + golden_k = req_key.reshape( + req_meta.chunks_blk_len, self.block_size, -1 + )[chunk_hit_mask] + golden_k = golden_k.reshape(candidate_len, -1) + + with nvtx_range(f"calculate topK, req :{req_meta.req_idx}"): + diff_k = torch.sum((his_k - golden_k).abs(), dim=[1]) + topK_num = int( + candidate_len * self.compute_meta[layer_name]["ratio"] + ) + + topK_indices = torch.topk(diff_k, k=topK_num).indices + + # get origin idx in req_key + topK_indices = self.mask_idx[: req_meta.chunks_blk_len][ + chunk_hit_mask + ].reshape(-1)[topK_indices] + + with nvtx_range(f"update blend meta, req :{req_meta.req_idx}"): + # update compute_mask + self.blend_req_metas.update_req_compute_mask( + req_query_start, + req_chunk_end, + req_query_end, + chunk_hit_mask, + topK_indices, + ) + + self.blend_req_metas.update_query_lens( + req_idx, candidate_len - topK_num + ) + need_update = True if need_update: - logger.info( - f"[blend-attn] compute_mask time: {(time.perf_counter() - start_time) * 1000}ms" - ) - self.blend_req_metas.update_need_re_index(True) - self._update_attn_metadata() - - indexed_query = query[self.blend_req_metas.compute_mask] - indexed_key = key[self.blend_req_metas.compute_mask] - indexed_value = value[self.blend_req_metas.compute_mask] - indexed_output = None - if output is not None: - indexed_output = output[: self.blend_req_metas.compute_mask.sum()] - logger.info( - f"[blend-attn] compute_mask time + index time: {(time.perf_counter() - start_time) * 1000}ms" - ) - logger.info( - f"[blend-attn] reduce attn tokens from {len(self.blend_req_metas.compute_mask)} " - f"to {self.attn_metadata.num_actual_tokens}" - ) + with nvtx_range(f"update attn meta"): + self.blend_req_metas.update_need_re_index(True) + self._update_attn_metadata() + + indexed_query = query[self.blend_req_metas.compute_mask] + indexed_key = key[self.blend_req_metas.compute_mask] + indexed_value = value[self.blend_req_metas.compute_mask] + indexed_output = None + if output is not None: + indexed_output = output[ + : self.blend_req_metas.compute_mask.sum() + ] + + logger.info( + f"[blend-attn] reduce attn tokens from {len(self.blend_req_metas.compute_mask)} " + f"to {self.attn_metadata.num_actual_tokens}" + ) return indexed_query, indexed_key, indexed_value, indexed_output return query, key, value, output
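For reference, the refactor above removes the connector-local `_generate_hash` and routes both prefix and per-chunk hashing through the shared `generate_hash(block_size, token_ids, parent_hash)` on `UCMDirectConnector`, which now takes an explicit token list and seed and returns `list[bytes]`. The sketch below illustrates the chained block hashing that this interface implies; it is a minimal standalone approximation only — `generate_hash_sketch`, the sha256 digest, and the `b"seed"` value are stand-ins for the connector's real `RequestHasher` and `self._seed`, not the actual implementation.

    # Illustrative only: chained hashing over full token blocks, partial
    # trailing blocks are skipped (mirrors the loop structure in the diff).
    import hashlib
    from typing import List

    def generate_hash_sketch(
        block_size: int, token_ids: List[int], parent_hash: bytes
    ) -> List[bytes]:
        hashes: List[bytes] = []
        for start in range(0, len(token_ids), block_size):
            block = token_ids[start : start + block_size]
            # Do not hash a block that is not full.
            if len(block) < block_size:
                break
            h = hashlib.sha256(parent_hash + str(tuple(block)).encode("utf-8")).digest()
            hashes.append(h)
            parent_hash = h  # each block hash chains over all preceding blocks
        return hashes

    # Example: block_size=4 over 10 tokens yields 2 hashes; the 2-token tail is skipped.
    assert len(generate_hash_sketch(4, list(range(10)), b"seed")) == 2

Because `_process_req` calls `generate_hash` on each chunk's own token span starting from the seed, a chunk's block hashes do not depend on what precedes it in a particular request, which appears to be what allows blended chunk caches to be reused across requests.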