From 6249ff4bb2360ba1e8e052df6ac2e116c6c4b500 Mon Sep 17 00:00:00 2001 From: Clarence-1103 Date: Thu, 22 Jan 2026 02:00:14 -0800 Subject: [PATCH] [feat] add monkey patch for gsa on device v0.9.2 --- ucm/__init__.py | 16 + ucm/integration/vllm/__init__.py | 16 - .../vllm/patch/patch_funcs/v092/vllm_patch.py | 545 +++++++++++++++++- ucm/integration/vllm/ucm_connector.py | 4 +- 4 files changed, 536 insertions(+), 45 deletions(-) diff --git a/ucm/__init__.py b/ucm/__init__.py index e69de29bb..39aa7fb63 100644 --- a/ucm/__init__.py +++ b/ucm/__init__.py @@ -0,0 +1,16 @@ +# from ucm.integration.vllm.ucm_connector import UCMConnector + +# try: +# from ucm.integration.vllm.patch.apply_patch import ensure_patches_applied + +# ensure_patches_applied() +# except Exception as e: +# # Don't fail if patches can't be applied - might be running in environment without vLLM +# import warnings + +# warnings.warn( +# f"Failed to apply vLLM patches: {e}. " +# f"If you're using vLLM, ensure it's installed and patches are compatible." +# ) + +# __all__ = ["UCMConnector"] diff --git a/ucm/integration/vllm/__init__.py b/ucm/integration/vllm/__init__.py index 39aa7fb63..e69de29bb 100644 --- a/ucm/integration/vllm/__init__.py +++ b/ucm/integration/vllm/__init__.py @@ -1,16 +0,0 @@ -# from ucm.integration.vllm.ucm_connector import UCMConnector - -# try: -# from ucm.integration.vllm.patch.apply_patch import ensure_patches_applied - -# ensure_patches_applied() -# except Exception as e: -# # Don't fail if patches can't be applied - might be running in environment without vLLM -# import warnings - -# warnings.warn( -# f"Failed to apply vLLM patches: {e}. " -# f"If you're using vLLM, ensure it's installed and patches are compatible." -# ) - -# __all__ = ["UCMConnector"] diff --git a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_patch.py b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_patch.py index 6ce7589e0..3e727537f 100644 --- a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_patch.py +++ b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_patch.py @@ -43,7 +43,10 @@ def _apply_sparse_adapt() -> None: _patch_block_table() _patch_kv_cache_manager() _patch_shared_storage_connector() + _patch_kv_cache_utils() + _patch_flash_attention_metadata_builder() _patch_attention_layer() + _patch_flash_mla() _patch_mla_common() _patch_gpu_model_runner() _patch_gpu_worker() @@ -234,6 +237,9 @@ def maybe_execute_sparse_attention_begin( forward_context: ForwardContext, output: Optional[torch.Tensor] = None, phase: Optional[str] = None, + k_hash: Optional[torch.Tensor] = None, + decode_ql_nope: Optional[torch.Tensor] = None, + decode_q_pe: Optional[torch.Tensor] = None, ): if not has_ucm_sparse(): return query, key, value, output @@ -245,7 +251,16 @@ def maybe_execute_sparse_attention_begin( return query, key, value, output return ucm_sparse.attention_begin( - query, key, value, layer_name, forward_context, output, phase + query, + key, + value, + layer_name, + forward_context, + output, + phase, + k_hash, + decode_ql_nope, + decode_q_pe, ) def maybe_execute_sparse_attention_finished( @@ -327,9 +342,21 @@ def unified_attention_with_output_impl( self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] if not self.use_mla: - query, key, value, output = maybe_execute_sparse_attention_begin( - query, key, value, layer_name, forward_context, output - ) + if attn_metadata is not None: + if has_ucm_sparse() and os.getenv("VLLM_HASH_ATTENTION") == "1": + kv_cache, k_hash = kv_cache + 
else: + k_hash = None + query, _, _, _ = maybe_execute_sparse_attention_begin( + query, + key, + value, + layer_name, + forward_context, + output, + k_hash=k_hash, + ) + self.impl.forward( self, query, @@ -340,6 +367,7 @@ def unified_attention_with_output_impl( output=output, output_scale=output_scale, ) + if not self.use_mla: maybe_execute_sparse_attention_finished( query, key, value, output, layer_name, forward_context @@ -428,6 +456,8 @@ def _patch_mla_common() -> None: MLACommonMetadata, ) + from ucm.sparse.state import has_ucm_sparse + M = TypeVar("M", bound=MLACommonMetadata) def forward( @@ -481,6 +511,10 @@ def forward( prefill_k_pe = k_pe[num_decode_tokens:] prefill_k_c_normed = k_c_normed[num_decode_tokens:] + if has_ucm_sparse() and os.getenv("VLLM_HASH_ATTENTION") == "1": + kv_cache, k_hash = kv_cache + else: + k_hash = None # write the latent and rope to kv cache if kv_cache.numel() > 0: ops.concat_and_cache_mla( @@ -493,15 +527,15 @@ def forward( ) if has_prefill: - prefill_q, prefill_k_c_normed, prefill_k_pe, _ = ( - maybe_execute_sparse_attention_begin( - prefill_q, - prefill_k_c_normed, - prefill_k_pe, - layer.layer_name, - forward_context, - phase="prefill", - ) + prefill_q, _, _, _ = maybe_execute_sparse_attention_begin( + prefill_q, + k_c_normed, + k_pe, + layer.layer_name, + forward_context, + output=output, + phase="prefill", + k_hash=k_hash, ) output[num_decode_tokens:] = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, attn_metadata @@ -526,15 +560,17 @@ def forward( decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) # Convert from (N, B, L) to (B, N, L) decode_ql_nope = decode_ql_nope.transpose(0, 1) - _, decode_ql_nope, decode_q_pe, _ = ( - maybe_execute_sparse_attention_begin( - torch.cat([decode_ql_nope, decode_q_pe], dim=-1), - decode_ql_nope, - decode_q_pe, - layer.layer_name, - forward_context, - phase="decode", - ) + _, _, _, _ = maybe_execute_sparse_attention_begin( + torch.cat([decode_ql_nope, decode_q_pe], dim=-1), + k_c_normed, + k_pe, + layer.layer_name, + forward_context, + output=output, + phase="decode", + k_hash=k_hash, + decode_ql_nope=decode_ql_nope, + decode_q_pe=decode_q_pe, ) output[:num_decode_tokens] = self._forward_decode( decode_ql_nope, decode_q_pe, kv_cache, attn_metadata @@ -913,12 +949,12 @@ def patched_schedule(self) -> SchedulerOutput: ) # Get externally-cached tokens if using a KVConnector. - if self.connector is not None: - num_external_computed_tokens, load_kv_async = ( - self.connector.get_num_new_matched_tokens( - request, num_new_local_computed_tokens - ) - ) + # if self.connector is not None: + # num_external_computed_tokens, load_kv_async = ( + # self.connector.get_num_new_matched_tokens( + # request, num_new_local_computed_tokens + # ) + # ) # Total computed tokens (local + external). 
num_computed_tokens = ( @@ -1212,10 +1248,13 @@ def _patch_gpu_model_runner() -> None: from vllm.sequence import IntermediateTensors from vllm.utils import round_up from vllm.v1.attention.backends.utils import CommonAttentionMetadata + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput from vllm.v1.spec_decode.metadata import SpecDecodeMetadata + from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState + from vllm.v1.worker.utils import initialize_kv_cache_for_kv_sharing from ucm.sparse.base import INVALID_SLOT from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse @@ -2046,6 +2085,49 @@ def execute_model( GPUModelRunner.execute_model = execute_model + # Patch the GPUModelRunner's initialize_kv_cache_tensors method to add UCM sparse support. + def initialize_kv_cache_tensors( + self, kv_cache_config: KVCacheConfig + ) -> dict[str, torch.Tensor]: + """ + Initialize the memory buffer for KV cache. + + Args: + kv_cache_config: The KV cache config + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. + """ + # Initialize the memory buffer for KV cache + kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) + # Change the memory buffer to the desired shape + kv_caches = self._reshape_kv_cache_tensors( + kv_cache_config, kv_cache_raw_tensors + ) + + if has_ucm_sparse(): + ucm_sparse = get_ucm_sparse() + if os.getenv("VLLM_HASH_ATTENTION") == "1": + ucm_sparse.initialize_kv_hash_cache_tensors(kv_caches, self.device) + + # Setup `kv_cache_config` and `kv_caches` for models + # with cross-layer KV sharing + if self.shared_kv_cache_layers: + initialize_kv_cache_for_kv_sharing( + self.shared_kv_cache_layers, + kv_cache_config.kv_cache_groups, + kv_caches, + ) + + bind_kv_cache( + kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches, + ) + return kv_caches + + GPUModelRunner.initialize_kv_cache_tensors = initialize_kv_cache_tensors + except ImportError: logger.warning("Could not patch prepare inputs - module not found") @@ -2084,6 +2166,413 @@ def patched_init_worker_distributed_environment( logger.warning("Could not patch gpu worker - module not found") +# ==================== vllm/v1/attention/backends/flash_attn.py ==================== +def _patch_flash_attention_metadata_builder() -> None: + """Patch flash attention metadata builder to add UCM sparse support.""" + try: + import torch + from vllm.attention.utils.fa_utils import get_scheduler_metadata + from vllm.v1.attention.backends.flash_attn import ( + FlashAttentionMetadata, + FlashAttentionMetadataBuilder, + _get_sliding_window_configs, + ) + from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + make_local_attention_virtual_batches, + ) + + from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse + + def patched_build( + self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata + ) -> FlashAttentionMetadata: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + + max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + block_table = self.block_table + block_table_tensor = block_table.get_device_tensor()[:num_reqs] + + 
if has_ucm_sparse(): + ucm_sparse = get_ucm_sparse() + if os.getenv("VLLM_HASH_ATTENTION") == "1": + decode_mask, topk_seq_lens = ucm_sparse.build_decode_attention_meta( + query_start_loc, seq_lens, block_table_tensor + ) + + block_table.slot_mapping[:num_actual_tokens].copy_( + block_table.slot_mapping_cpu[:num_actual_tokens], non_blocking=True + ) + # Fill unused with -1. Needed for reshape_and_cache in full cuda graph + # mode. + block_table.slot_mapping[num_actual_tokens:].fill_(-1) + + slot_mapping = block_table.slot_mapping[:num_actual_tokens] + + if self.aot_sliding_window is None: + self.aot_sliding_window = (-1, -1) + # For the AOT scheduler we need the sliding window value to be + # constant for all layers to. We have to populate this on the first + # build() call so the layers are constructed (cannot populate) + # in __init__. + if self.aot_schedule: + sliding_window_configs = _get_sliding_window_configs( + self.runner.vllm_config + ) + if len(sliding_window_configs) == 1: + sliding_window_config = sliding_window_configs.pop() + if sliding_window_config is not None: + self.aot_sliding_window = sliding_window_config + elif len(sliding_window_configs) > 1: + self.aot_schedule = False + + def schedule( + batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal + ): + if self.aot_schedule: + return get_scheduler_metadata( + batch_size=batch_size, + max_seqlen_q=max_query_len, + max_seqlen_k=max_seq_len, + cache_seqlens=seqlens, + num_heads_q=self.num_heads_q, + num_heads_kv=self.num_heads_kv, + headdim=self.headdim, + page_size=self.block_size, + cu_seqlens_q=cu_query_lens, + causal=causal, + window_size=self.aot_sliding_window, + num_splits=self.max_num_splits, + ) + return None + + # for local attention + local_attn_metadata = None + if self.runner.attention_chunk_size is not None: + ( + seqlens_q_local_np, + virt_q_cu_seqlens_np, + virt_k_seqlens_np, + virt_block_table_tensor, + ) = make_local_attention_virtual_batches( + self.runner.attention_chunk_size, + self.runner.query_start_loc_np[: num_reqs + 1], + self.runner.seq_lens_np[:num_reqs], + block_table_tensor, + self.block_size, + ) + local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to( + self.runner.device, non_blocking=True + ) + local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to( + self.runner.device, non_blocking=True + ) + local_max_query_len = seqlens_q_local_np.max() + local_max_seq_len = virt_k_seqlens_np.max() + local_scheduler_metadata = schedule( + batch_size=local_query_start_loc.shape[0] - 1, + cu_query_lens=local_query_start_loc, + max_query_len=local_max_query_len, + seqlens=local_seqused_k, + max_seq_len=local_max_seq_len, + causal=True, + ) + + local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata( + local_query_start_loc=local_query_start_loc, + local_seqused_k=local_seqused_k, + local_block_table=virt_block_table_tensor, + local_max_query_len=local_max_query_len, + local_max_seq_len=local_max_seq_len, + local_scheduler_metadata=local_scheduler_metadata, + ) + + use_cascade = common_prefix_len > 0 + + if use_cascade: + cu_prefix_query_lens = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device=self.runner.device + ) + prefix_kv_lens = torch.tensor( + [common_prefix_len], dtype=torch.int32, device=self.runner.device + ) + suffix_kv_lens = self.runner.seq_lens_np[:num_reqs] - common_prefix_len + suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(self.runner.device) + prefix_scheduler_metadata = schedule( + batch_size=1, + 
cu_query_lens=cu_prefix_query_lens, + max_query_len=num_actual_tokens, + seqlens=prefix_kv_lens, + max_seq_len=common_prefix_len, + causal=False, + ) + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=suffix_kv_lens, + max_seq_len=max_seq_len - common_prefix_len, + causal=True, + ) + else: + cu_prefix_query_lens = None + prefix_kv_lens = None + suffix_kv_lens = None + prefix_scheduler_metadata = None + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=seq_lens, + max_seq_len=max_seq_len, + causal=True, + ) + + if self.use_full_cuda_graph: + assert scheduler_metadata is not None + n = scheduler_metadata.shape[0] + self.scheduler_metadata[:n] = scheduler_metadata + # NOTE(woosuk): We should zero out the rest of the scheduler + # metadata to guarantee the correctness. Otherwise, some thread + # blocks may use the invalid scheduler metadata and overwrite the + # output buffer. + self.scheduler_metadata[n:] = 0 + scheduler_metadata = self.scheduler_metadata[:n] + + max_num_splits = 0 + if ( + self.use_full_cuda_graph + and num_actual_tokens <= self.max_cudagraph_size + ): + # NOTE(woosuk): Setting num_splits > 1 may increase the memory + # usage, because the intermediate buffers of size [num_splits, + # num_heads, num_tokens, head_size] are allocated. Therefore, + # we only set num_splits when using cuda graphs. + max_num_splits = self.max_num_splits + + attn_metadata = FlashAttentionMetadata( + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=block_table_tensor, + slot_mapping=slot_mapping, + use_cascade=use_cascade, + common_prefix_len=common_prefix_len, + scheduler_metadata=scheduler_metadata, + cu_prefix_query_lens=cu_prefix_query_lens, + prefix_kv_lens=prefix_kv_lens, + suffix_kv_lens=suffix_kv_lens, + local_attn_metadata=local_attn_metadata, + prefix_scheduler_metadata=prefix_scheduler_metadata, + max_num_splits=max_num_splits, + ) + return attn_metadata + + FlashAttentionMetadataBuilder.build = patched_build + except ImportError: + logger.warning( + "Could not patch flash attention metadata builder - module not found" + ) + + +# ==================== vllm/v1/attention/backends/flashmla.py ==================== +def _patch_flash_mla() -> None: + """Patch flash mla to add vLLM support.""" + try: + from dataclasses import dataclass + + import torch + from vllm.attention.ops.flashmla import get_mla_metadata + from vllm.v1.attention.backends.mla import flashmla + from vllm.v1.attention.backends.mla.common import MLACommonDecodeMetadata + + from ucm.sparse.state import ( + get_ucm_sparse, + has_ucm_sparse, + ) + + @dataclass + class FlashMLADecodeMetadata(MLACommonDecodeMetadata): + tile_scheduler_metadata: torch.Tensor + num_splits: torch.Tensor + topk_seq_lens: torch.Tensor + topk_tile_scheduler_metadata: torch.Tensor + topk_num_splits: torch.Tensor + topk_block_table: torch.Tensor = None + + # Set module and qualname to make the class pickleable + # This ensures pickle can find the class when serializing + FlashMLADecodeMetadata.__module__ = flashmla.__name__ + FlashMLADecodeMetadata.__qualname__ = "FlashMLADecodeMetadata" + + flashmla.FlashMLADecodeMetadata = FlashMLADecodeMetadata + + def patched_build_decode( + self, block_table_tensor: torch.Tensor, seq_lens: torch.Tensor + ) -> FlashMLADecodeMetadata: + 
tile_scheduler_metadata, num_splits = get_mla_metadata( + seq_lens, + self.num_q_heads, + 1, # MQA for the decode path + ) + topk_seq_lens = None + topk_tile_scheduler_metadata = None + topk_num_splits = None + if has_ucm_sparse(): + ucm_sparse = get_ucm_sparse() + if os.getenv("VLLM_HASH_ATTENTION") == "1": + topk_seq_lens, topk_tile_scheduler_metadata, topk_num_splits = ( + ucm_sparse.build_decode_hash(seq_lens) + ) + + if self.runner.full_cuda_graph: + # First time around (CUDAGraph capture), allocate the static buffer + if self.cg_buf_tile_scheduler_metadata is None: + self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata + self.cg_buf_num_splits = num_splits + else: + assert self.cg_buf_num_splits is not None + + # Metadata per-SM, fixed size (#SMs, TileMetadataSize) + assert ( + self.cg_buf_tile_scheduler_metadata.size() + == tile_scheduler_metadata.size() + ) + self.cg_buf_tile_scheduler_metadata.copy_(tile_scheduler_metadata) + tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata + + # Num splits is per-batch, varying size (batch_size,) + n = num_splits.size(0) + # make sure static buffer is large enough + assert n <= self.cg_buf_num_splits.size(0) + num_splits_view = self.cg_buf_num_splits[:n] + num_splits_view.copy_(num_splits) + self.cg_buf_num_splits[n:].fill_(0) # fill the rest with 0s + num_splits = num_splits_view + topk_tile_scheduler_metadata, topk_num_splits = ( + ucm_sparse.maybe_init_cudagraph_buffers_for_topk( + n, tile_scheduler_metadata + ) + ) + + return FlashMLADecodeMetadata( + block_table=block_table_tensor, + seq_lens=seq_lens, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + topk_seq_lens=topk_seq_lens, + topk_tile_scheduler_metadata=topk_tile_scheduler_metadata, + topk_num_splits=topk_num_splits, + ) + + flashmla.FlashMLAMetadataBuilder._build_decode = patched_build_decode + + except ImportError: + logger.warning("Could not patch flash mla - module not found") + + +# ==================== vllm/v1/core/kv_cache_utils.py ==================== +def _patch_kv_cache_utils() -> None: + """Patch kv_cache_utils.py for vLLM.""" + try: + from vllm.config import VllmConfig + from vllm.v1.core import kv_cache_utils + from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheSpec, + KVCacheTensor, + SlidingWindowSpec, + ) + + def _patch_get_kv_cache_config_uniform_type( + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int, + ) -> KVCacheConfig: + """ + Generates the KV cache configuration for a model with one type of KV cache. + Divide the available memory equally among all layers. + + Args: + vllm_config: The global VllmConfig + kv_cache_spec: The kv cache spec of each attention layer in the model + available_memory: Memory available for KV cache in bytes. 
+
+            Returns:
+                The generated KVCacheConfig
+            """
+
+            page_size = kv_cache_utils.get_uniform_page_size(kv_cache_spec)
+            num_blocks = kv_cache_utils.get_num_blocks(
+                vllm_config, len(kv_cache_spec), available_memory, page_size
+            )
+
+            if os.getenv("VLLM_HASH_ATTENTION") == "1":
+                from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+
+                if vllm_config.cache_config.cache_dtype == "auto":
+                    dtype = vllm_config.model_config.dtype
+                else:
+                    dtype = STR_DTYPE_TO_TORCH_DTYPE[
+                        vllm_config.cache_config.cache_dtype
+                    ]
+                khash_scale = dtype.itemsize * 8
+                new_num_blocks = num_blocks * khash_scale // (khash_scale + 1)
+                logger.info(
+                    "[HASH_ATTN] reduce num_blocks from %d to %d to allocate khash_cache",
+                    num_blocks,
+                    new_num_blocks,
+                )
+                num_blocks = new_num_blocks
+
+            per_layer_size = page_size * num_blocks
+            # All layers have the same KV cache spec, so we create one kv cache group
+            # for all layers.
+            grouped_layer_names = [list(kv_cache_spec.keys())]
+
+            # Each layer uses a separate Tensor to store its KV cache.
+            kv_cache_tensors = [
+                KVCacheTensor(size=per_layer_size, shared_by=[layer_name])
+                for layer_name in kv_cache_spec
+            ]
+
+            kv_cache_config = KVCacheConfig(
+                num_blocks=num_blocks,
+                kv_cache_tensors=kv_cache_tensors,
+                kv_cache_groups=kv_cache_utils.create_kv_cache_group_specs(
+                    kv_cache_spec, grouped_layer_names
+                ),
+            )
+
+            num_tokens = num_blocks * vllm_config.cache_config.block_size
+            num_tokens_str = f"{num_tokens:,}"
+            logger.info("GPU KV cache size: %s tokens", num_tokens_str)
+            max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
+            max_concurrency = kv_cache_utils.get_max_concurrency_for_kv_cache_config(
+                vllm_config, kv_cache_config
+            )
+            logger.info(
+                "Maximum concurrency for %s tokens per request: %.2fx",
+                max_model_len_str,
+                max_concurrency,
+            )
+            return kv_cache_config
+
+        kv_cache_utils._get_kv_cache_config_uniform_type = (
+            _patch_get_kv_cache_config_uniform_type
+        )
+    except ImportError:
+        logger.warning("Could not patch kv_cache_utils - module not found")
+
+
 # ==================== vllm/model_executor/models/llama.py ====================
 def _patch_llama_model() -> None:
     """Patch gpu worker to add UCM sparse support."""
diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py
index 6253fc027..68318caf1 100644
--- a/ucm/integration/vllm/ucm_connector.py
+++ b/ucm/integration/vllm/ucm_connector.py
@@ -33,6 +33,8 @@
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.request import Request
 
+from ucm.sparse.state import has_ucm_sparse
+
 logger = init_logger(__name__)
 
 
@@ -225,7 +227,7 @@ def _generate_storage_backends(
         return backends
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
-        if os.getenv("VLLM_HASH_ATTENTION", "0") == "1":
+        if has_ucm_sparse() and os.getenv("VLLM_HASH_ATTENTION") == "1":
            for layer_name, value in kv_caches.items():
                 kv_cache, k_hash = value
                 self.kv_caches[layer_name] = kv_cache
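
A short illustration of the block-budget arithmetic introduced in the patched
_get_kv_cache_config_uniform_type above: when VLLM_HASH_ATTENTION=1, the usable
block count is scaled by khash_scale / (khash_scale + 1), with
khash_scale = dtype.itemsize * 8, so that the leftover budget can hold the
k_hash cache. This is a minimal sketch with made-up numbers, not part of the
patch; the dtype and block count are illustrative assumptions.

    import torch

    dtype = torch.bfloat16                    # assumed KV cache dtype
    num_blocks = 17_000                       # hypothetical pre-scaling block budget
    khash_scale = dtype.itemsize * 8          # 2 bytes * 8 = 16
    new_num_blocks = num_blocks * khash_scale // (khash_scale + 1)
    assert new_num_blocks == 16_000           # 16/17 of the budget stays KV cache

With the same flag set and UCM sparse active, register_kv_caches in
ucm_connector.py now expects each layer entry to be a (kv_cache, k_hash) pair,
e.g. {"model.layers.0.self_attn.attn": (kv_cache, k_hash)}, and keeps only the
kv_cache tensor in self.kv_caches; the layer name shown here is illustrative.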