diff --git a/docs/source/getting-started/quickstart_vllm.md b/docs/source/getting-started/quickstart_vllm.md index e3fefb549..e808e1772 100644 --- a/docs/source/getting-started/quickstart_vllm.md +++ b/docs/source/getting-started/quickstart_vllm.md @@ -2,7 +2,7 @@ This document describes how to install unified-cache-management with vllm on cuda platform. ## Prerequisites -- vllm >=0.9.1, device=cuda (vllm == 0.9.2 to use the Sparse Feature) +- vllm >=0.9.1, device=cuda (Sparse Feature is supported in vllm 0.9.2 and v0.11.0) ## Step 1: UCM Installation @@ -44,6 +44,7 @@ Download the pre-built `vllm/vllm-openai:v0.9.2` docker image and build unified- 1. Prepare vLLM Environment For the sake of environment isolation and simplicity, we recommend preparing the vLLM environment by pulling the official, pre-built vLLM Docker image. + > Note: v0.11.0 is newly supported (replace the tag with v0.11.0 if needed). ```bash docker pull vllm/vllm-openai:v0.9.2 @@ -87,6 +88,15 @@ Download the pre-built `vllm/vllm-openai:v0.9.2` docker image and build unified- ``` Apply the patch that matches your development needs: + #### vLLM 0.11.0 + + Note: v0.11.0 only requires the sparse attention patch. + + ```bash + git apply unified-cache-management/ucm/integration/vllm/patch/0.11.0/vllm-adapt-sparse.patch + ``` + + #### vLLM 0.9.2 - Full UCM integration (recommended): ```bash git apply unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch diff --git a/examples/offline_inference_kvcomphbm.py b/examples/offline_inference_gsa_on_device.py similarity index 98% rename from examples/offline_inference_kvcomphbm.py rename to examples/offline_inference_gsa_on_device.py index fc5d5142d..2de5b0a52 100644 --- a/examples/offline_inference_kvcomphbm.py +++ b/examples/offline_inference_gsa_on_device.py @@ -77,7 +77,7 @@ def build_llm_with_uc(module_path: str, name: str, model: str): }, } ], - "ucm_sparse_config": {"KvCompOnDevice": {}}, + "ucm_sparse_config": {"GSAOnDevice": {}}, }, ) diff --git a/ucm/integration/vllm/patch/0.11.0/vllm-adapt-sparse.patch b/ucm/integration/vllm/patch/0.11.0/vllm-adapt-sparse.patch new file mode 100644 index 000000000..0933c13f8 --- /dev/null +++ b/ucm/integration/vllm/patch/0.11.0/vllm-adapt-sparse.patch @@ -0,0 +1,865 @@ +From d886cedc2bf71d685dfe292102a26a5464b4f9c1 Mon Sep 17 00:00:00 2001 +From: AooooooA-C +Date: Thu, 8 Jan 2026 00:09:57 -0800 +Subject: [PATCH] apply sparse method patches + +--- + vllm/attention/layer.py | 65 +++++++++++++++- + vllm/model_executor/models/llama.py | 24 ++++++ + vllm/model_executor/models/qwen2.py | 24 ++++++ + vllm/v1/attention/backends/flash_attn.py | 11 +++ + vllm/v1/attention/backends/mla/common.py | 21 ++++++ + vllm/v1/attention/backends/mla/flashmla.py | 18 +++++ + vllm/v1/core/kv_cache_manager.py | 10 ++- + vllm/v1/core/kv_cache_utils.py | 14 ++++ + vllm/v1/core/sched/output.py | 3 + + vllm/v1/core/sched/scheduler.py | 36 ++++++++- + vllm/v1/worker/block_table.py | 13 ++++ + vllm/v1/worker/gpu_model_runner.py | 87 +++++++++++++++++++--- + vllm/v1/worker/gpu_worker.py | 3 + + 13 files changed, 315 insertions(+), 14 deletions(-) + +diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py +index 79879b680..06fa41a4d 100644 +--- a/vllm/attention/layer.py ++++ b/vllm/attention/layer.py +@@ -3,6 +3,7 @@ + """Attention layer.""" + from typing import List, Optional + ++import os + import torch + import torch.nn as nn + import torch.nn.functional as F +@@ -30,6 +31,8 @@ from vllm.model_executor.models.vision import get_vit_attn_backend 
+ from vllm.platforms import _Backend, current_platform + from vllm.utils import GiB_bytes, direct_register_custom_op + ++from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse ++ + logger = init_logger(__name__) + USE_XFORMERS_OPS = None + try: +@@ -571,9 +574,11 @@ def unified_attention( + attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] ++ query, key, value, _ = maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) + output = self.impl.forward(self, query, key, value, kv_cache, + attn_metadata) +- ++ ++ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + return output + +@@ -611,6 +616,16 @@ def unified_attention_with_output( + attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] ++ if not self.use_mla: ++ if attn_metadata is not None: ++ if os.getenv("VLLM_HASH_ATTENTION") == "1": ++ kv_cache, k_hash = kv_cache ++ else: ++ k_hash = None ++ query, key, value, output = maybe_execute_sparse_attention_begin( ++ query, key, value, layer_name, forward_context, output, k_hash=k_hash ++ ) ++ + self.impl.forward(self, + query, + key, +@@ -621,6 +636,9 @@ def unified_attention_with_output( + output_scale=output_scale, + output_block_scale=output_block_scale) + ++ if not self.use_mla: ++ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) ++ + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + + +@@ -643,3 +661,48 @@ direct_register_custom_op( + fake_impl=unified_attention_with_output_fake, + tags=tag_cudagraph_unsafe, + ) ++ ++def maybe_execute_sparse_attention_begin( ++ query: torch.Tensor, ++ key: torch.Tensor, ++ value: torch.Tensor, ++ layer_name: str, ++ forward_context: ForwardContext, ++ output: Optional[torch.Tensor] = None, ++ phase: Optional[str] = None, ++ k_hash: Optional[torch.Tensor] = None, ++ decode_ql_nope: Optional[torch.Tensor] = None, ++ decode_q_pe: Optional[torch.Tensor] = None, ++): ++ if not has_ucm_sparse(): ++ return query, key, value, output ++ ++ ucm_sparse = get_ucm_sparse() ++ ++ attn_metadata = forward_context.attn_metadata ++ if attn_metadata is None: ++ return query, key, value, output ++ ++ return ucm_sparse.attention_begin( ++ query, key, value, layer_name, forward_context, output, phase, k_hash, decode_ql_nope, decode_q_pe ++ ) ++ ++def maybe_execute_sparse_attention_finished( ++ query: torch.Tensor, ++ key: torch.Tensor, ++ value: torch.Tensor, ++ attn_output: torch.Tensor, ++ layer_name: str, ++ forward_context: ForwardContext, ++ phase: Optional[str] = None, ++): ++ if not has_ucm_sparse(): ++ return ++ ++ ucm_sparse = get_ucm_sparse() ++ ++ attn_metadata = forward_context.attn_metadata ++ if attn_metadata is None: ++ return ++ ++ ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context, phase) +\ No newline at end of file +diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py +index c7dd134ea..db64f258c 100644 +--- a/vllm/model_executor/models/llama.py ++++ b/vllm/model_executor/models/llama.py +@@ -56,6 +56,12 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + ++from ucm.sparse.state import( ++ 
maybe_execute_sparse_ffn_begin, ++ maybe_execute_sparse_ffn_finished, ++ maybe_execute_sparse_layer_begin, ++ maybe_execute_sparse_layer_finished, ++) + + class LlamaMLP(nn.Module): + +@@ -322,10 +328,19 @@ class LlamaDecoderLayer(nn.Module): + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + ++ hidden_states, residual = maybe_execute_sparse_ffn_begin( ++ hidden_states, residual ++ ) ++ + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) ++ ++ hidden_states, residual = maybe_execute_sparse_ffn_finished( ++ hidden_states, residual ++ ) ++ + return hidden_states, residual + + +@@ -400,10 +415,19 @@ class LlamaModel(nn.Module): + aux_hidden_states = [] + for idx, layer in enumerate( + islice(self.layers, self.start_layer, self.end_layer)): ++ ++ positions, hidden_states, residual = maybe_execute_sparse_layer_begin( ++ positions, hidden_states, residual ++ ) ++ + if idx in self.aux_hidden_state_layers: + aux_hidden_states.append(hidden_states + residual) + hidden_states, residual = layer(positions, hidden_states, residual) + ++ positions, hidden_states, residual = maybe_execute_sparse_layer_finished( ++ positions, hidden_states, residual ++ ) ++ + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, +diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py +index c536b0f60..dc9460057 100644 +--- a/vllm/model_executor/models/qwen2.py ++++ b/vllm/model_executor/models/qwen2.py +@@ -58,6 +58,12 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + ++from ucm.sparse.state import( ++ maybe_execute_sparse_ffn_begin, ++ maybe_execute_sparse_ffn_finished, ++ maybe_execute_sparse_layer_begin, ++ maybe_execute_sparse_layer_finished, ++) + + class Qwen2MLP(nn.Module): + +@@ -259,10 +265,19 @@ class Qwen2DecoderLayer(nn.Module): + hidden_states=hidden_states, + ) + ++ hidden_states, residual = maybe_execute_sparse_ffn_begin( ++ hidden_states, residual ++ ) ++ + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) ++ ++ hidden_states, residual = maybe_execute_sparse_ffn_finished( ++ hidden_states, residual ++ ) ++ + return hidden_states, residual + + +@@ -361,8 +376,17 @@ class Qwen2Model(nn.Module): + islice(self.layers, self.start_layer, self.end_layer)): + if idx in self.aux_hidden_state_layers: + aux_hidden_states.append(hidden_states + residual) ++ ++ positions, hidden_states, residual = maybe_execute_sparse_layer_begin( ++ positions, hidden_states, residual ++ ) ++ + hidden_states, residual = layer(positions, hidden_states, residual) + ++ positions, hidden_states, residual = maybe_execute_sparse_layer_finished( ++ positions, hidden_states, residual ++ ) ++ + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, +diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py +index f0770f744..717dc836e 100755 +--- a/vllm/v1/attention/backends/flash_attn.py ++++ b/vllm/v1/attention/backends/flash_attn.py +@@ -6,6 +6,7 @@ from typing import Optional + + import numpy as np + import torch ++import os + + from vllm import envs + from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, +@@ -31,6 +32,8 @@ from 
vllm.v1.attention.backends.utils import (AttentionCGSupport, + get_kv_cache_layout) + from vllm.v1.kv_cache_interface import AttentionSpec + ++from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse ++ + logger = init_logger(__name__) + + +@@ -236,6 +239,14 @@ class FlashAttentionMetadataBuilder( + seq_lens = common_attn_metadata.seq_lens + seq_lens_cpu = common_attn_metadata.seq_lens_cpu + block_table_tensor = common_attn_metadata.block_table_tensor ++ ++ if has_ucm_sparse(): ++ ucm_sparse = get_ucm_sparse() ++ if os.getenv("VLLM_HASH_ATTENTION") == "1": ++ decode_mask, topk_seq_lens = ucm_sparse.build_decode_attention_meta( ++ query_start_loc, seq_lens, block_table_tensor ++ ) ++ + slot_mapping = common_attn_metadata.slot_mapping + causal = common_attn_metadata.causal + +diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py +index 963f1c5ab..c5667aadc 100755 +--- a/vllm/v1/attention/backends/mla/common.py ++++ b/vllm/v1/attention/backends/mla/common.py +@@ -194,6 +194,7 @@ from typing import Generic, Optional, TypeVar, Union + + import torch + from tqdm import tqdm ++import os + + import vllm.envs as envs + from vllm import _custom_ops as ops +@@ -220,6 +221,10 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + split_decodes_and_prefills) + from vllm.v1.kv_cache_interface import AttentionSpec + ++from vllm.forward_context import ForwardContext, get_forward_context ++from vllm.attention.layer import (maybe_execute_sparse_attention_begin, ++ maybe_execute_sparse_attention_finished) ++ + try: + from vllm.vllm_flash_attn import flash_attn_varlen_func + is_vllm_fa = True +@@ -1640,6 +1645,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): + ) -> torch.Tensor: + assert output is not None, "Output tensor must be provided." 
+ ++ forward_context: ForwardContext = get_forward_context() ++ + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" +@@ -1689,6 +1696,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): + prefill_k_pe = k_pe[num_decode_tokens:] + prefill_k_c_normed = k_c_normed[num_decode_tokens:] + ++ if os.getenv("VLLM_HASH_ATTENTION") == "1": ++ kv_cache, k_hash = kv_cache ++ else: ++ k_hash = None ++ + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + ops.concat_and_cache_mla( +@@ -1704,10 +1716,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): + kv_cache = kv_cache.view(current_platform.fp8_dtype()) + + if has_prefill: ++ ++ prefill_q, k_c_normed, k_pe, output = maybe_execute_sparse_attention_begin(prefill_q, k_c_normed, k_pe, layer.layer_name, forward_context, output=output, phase="prefill", k_hash=k_hash) ++ + output[num_decode_tokens:] = self._forward_prefill( + prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, + attn_metadata, layer._k_scale) + ++ maybe_execute_sparse_attention_finished(prefill_q, prefill_k_c_normed, prefill_k_pe, output[num_decode_tokens:], layer.layer_name, forward_context, "prefill") ++ + if has_decode: + assert attn_metadata.decode is not None + decode_q_nope, decode_q_pe = decode_q.split( +@@ -1771,8 +1788,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): + decode_q = get_dcp_group().all_gather(decode_q, dim=1) + + # call decode attn ++ _, k_c_normed, k_pe, output = maybe_execute_sparse_attention_begin(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), k_c_normed, k_pe, layer.layer_name, forward_context, output=output, phase="decode", k_hash=k_hash, decode_ql_nope=decode_ql_nope, decode_q_pe=decode_q_pe) ++ + attn_out, lse = self._forward_decode(decode_q, kv_cache, + attn_metadata, layer) ++ ++ maybe_execute_sparse_attention_finished(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, output[:num_decode_tokens], layer.layer_name, forward_context, "decode") + + # recorect dcp attn_out with lse. 
+ if self.dcp_world_size > 1: +diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py +index 67c21f83c..43c26d333 100644 +--- a/vllm/v1/attention/backends/mla/flashmla.py ++++ b/vllm/v1/attention/backends/mla/flashmla.py +@@ -5,6 +5,7 @@ from dataclasses import dataclass + from typing import ClassVar, Optional, Union + + import torch ++import os + + from vllm.attention.backends.abstract import AttentionLayer, AttentionType + from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, +@@ -20,6 +21,8 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + from vllm.v1.attention.backends.utils import AttentionCGSupport + from vllm.v1.kv_cache_interface import AttentionSpec + ++from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse ++ + logger = init_logger(__name__) + + +@@ -46,6 +49,10 @@ class FlashMLABackend(MLACommonBackend): + class FlashMLADecodeMetadata(MLACommonDecodeMetadata): + tile_scheduler_metadata: torch.Tensor + num_splits: torch.Tensor ++ topk_seq_lens: torch.Tensor ++ topk_tile_scheduler_metadata: torch.Tensor ++ topk_num_splits: torch.Tensor ++ topk_block_table: torch.Tensor = None + + + @dataclass +@@ -96,6 +103,13 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): + self.num_q_heads, + 1, # MQA for the decode path + ) ++ topk_seq_lens = None ++ topk_tile_scheduler_metadata = None ++ topk_num_splits = None ++ if has_ucm_sparse(): ++ ucm_sparse = get_ucm_sparse() ++ if os.getenv("VLLM_HASH_ATTENTION") == "1": ++ topk_seq_lens, topk_tile_scheduler_metadata, topk_num_splits = ucm_sparse.build_decode_hash(seq_lens_device) + + # TODO: we can disambiguate between decode and mixed-prefill decode here + # so we can only use the persistent buffer if a cudagraph is actually +@@ -123,12 +137,16 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): + # it needs to monotonically increasing by 1) + self.cg_buf_num_splits[n:].fill_(num_splits[-1]) + num_splits = num_splits_view ++ topk_tile_scheduler_metadata, topk_num_splits = ucm_sparse.maybe_init_cudagraph_buffers_for_topk(n, tile_scheduler_metadata) + + return FlashMLADecodeMetadata( + block_table=block_table_tensor, + seq_lens=seq_lens_device, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, ++ topk_seq_lens=topk_seq_lens, ++ topk_tile_scheduler_metadata=topk_tile_scheduler_metadata, ++ topk_num_splits=topk_num_splits, + ) + + +diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py +index 401327f72..eb7acf00d 100644 +--- a/vllm/v1/core/kv_cache_manager.py ++++ b/vllm/v1/core/kv_cache_manager.py +@@ -2,7 +2,7 @@ + # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + from dataclasses import dataclass +-from typing import Literal, Optional, overload ++from typing import Literal, Optional, overload, Union + + from vllm.distributed.kv_events import KVCacheEvent + from vllm.logger import init_logger +@@ -12,6 +12,9 @@ from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.metrics.stats import PrefixCacheStats + from vllm.v1.request import Request, RequestStatus + ++from ucm.sparse.state import get_ucm_sparse ++from ucm.sparse.base import INVALID_SLOT ++ + logger = init_logger(__name__) + + +@@ -199,6 +202,7 @@ class KVCacheManager: + num_lookahead_tokens: int = 0, + delay_cache_blocks: bool = False, + num_encoder_tokens: int = 0, ++ num_slots_sparsed: Union[None, int] = None + ) -> Optional[KVCacheBlocks]: + """Add slots for a 
request with new tokens to append. + +@@ -238,6 +242,10 @@ class KVCacheManager: + if num_new_tokens == 0: + raise ValueError("num_new_tokens must be greater than 0") + ++ if num_slots_sparsed != INVALID_SLOT: ++ return get_ucm_sparse().allocate_slots(self, request, num_slots_sparsed) ++ ++ + if new_computed_blocks is not None: + new_computed_block_list = new_computed_blocks.blocks + else: +diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py +index 2ff1bb681..dca14b5d3 100644 +--- a/vllm/v1/core/kv_cache_utils.py ++++ b/vllm/v1/core/kv_cache_utils.py +@@ -1035,6 +1035,20 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig, + num_blocks = available_memory // kv_cache_groups[ + 0].kv_cache_spec.page_size_bytes + num_blocks = may_override_num_blocks(vllm_config, num_blocks) ++ ++ if os.getenv("VLLM_HASH_ATTENTION") == "1": ++ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE ++ ++ if vllm_config.cache_config.cache_dtype == 'auto': ++ dtype = vllm_config.model_config.dtype ++ else: ++ dtype = STR_DTYPE_TO_TORCH_DTYPE[vllm_config.cache_config.cache_dtype] ++ khash_scale = dtype.itemsize * 8 ++ new_num_blocks = num_blocks * khash_scale // (khash_scale + 1) ++ logger.info("[HASH_ATTN] reduce num_blocks from %d to %d to allocate khash_cache", ++ num_blocks, new_num_blocks) ++ num_blocks = new_num_blocks ++ + per_layer_specs = kv_cache_groups[0].kv_cache_spec.kv_cache_specs + kv_cache_tensors = [ + KVCacheTensor(size=per_layer_specs[layer_name].page_size_bytes * +diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py +index 209fc2a44..4c7617ca6 100644 +--- a/vllm/v1/core/sched/output.py ++++ b/vllm/v1/core/sched/output.py +@@ -164,3 +164,6 @@ class SchedulerOutput: + + # KV Cache Connector metadata. + kv_connector_metadata: Optional[KVConnectorMetadata] = None ++ ++ # modified slots by sparse algorithm ++ req_sparsed_slots: dict[str, int] = None +diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py +index 2b2cd63c2..e42ed67b1 100644 +--- a/vllm/v1/core/sched/scheduler.py ++++ b/vllm/v1/core/sched/scheduler.py +@@ -36,6 +36,10 @@ from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput + from vllm.v1.request import Request, RequestStatus + from vllm.v1.spec_decode.metrics import SpecDecodingStats + from vllm.v1.structured_output import StructuredOutputManager ++from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import MultiConnector ++from ucm.utils import Config ++from ucm.sparse.state import ensure_ucm_sparse_initialized, get_ucm_sparse ++from ucm.sparse.base import UcmSparseRole, INVALID_SLOT + + logger = init_logger(__name__) + +@@ -82,6 +86,7 @@ class Scheduler(SchedulerInterface): + # will have a corresponding KVConnector with Role=WORKER. + # KV Connector pushes/pull of remote KVs for P/D and offloading. 
+ self.connector = None ++ self.ucm_sparse = None + if self.vllm_config.kv_transfer_config is not None: + assert len(self.kv_cache_config.kv_cache_groups) == 1, ( + "Multiple KV cache groups are not currently supported " +@@ -91,6 +96,14 @@ class Scheduler(SchedulerInterface): + "with KV connectors") + self.connector = KVConnectorFactory.create_connector( + config=self.vllm_config, role=KVConnectorRole.SCHEDULER) ++ ++ # Initialize UCM Sparse if available ++ ucm_config = Config(self.vllm_config.kv_transfer_config) ++ ucm_sparse_config = ucm_config.get_config().get("ucm_sparse_config") ++ if ucm_sparse_config: ++ ensure_ucm_sparse_initialized(vllm_config, role=UcmSparseRole.SCHEDULER) ++ self.ucm_sparse = get_ucm_sparse() ++ logger.info("UCM Sparse initialized successfully: {}".format(self.ucm_sparse)) + + self.kv_event_publisher = EventPublisherFactory.create( + self.kv_events_config, +@@ -207,9 +220,15 @@ class Scheduler(SchedulerInterface): + + # First, schedule the RUNNING requests. + req_index = 0 ++ req_sparsed_slots: dict[str, int] = {} + while req_index < len(self.running) and token_budget > 0: + request = self.running[req_index] + ++ num_slots_sparsed = INVALID_SLOT ++ if self.ucm_sparse: ++ num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed(request) ++ req_sparsed_slots.update({request.request_id: num_slots_sparsed}) ++ + num_new_tokens = (request.num_tokens_with_spec + + request.num_output_placeholders - + request.num_computed_tokens) +@@ -255,7 +274,8 @@ class Scheduler(SchedulerInterface): + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, +- num_lookahead_tokens=self.num_lookahead_tokens) ++ num_lookahead_tokens=self.num_lookahead_tokens, ++ num_slots_sparsed=num_slots_sparsed) + if new_blocks is None: + # The request cannot be scheduled. + # Preempt the lowest-priority request. +@@ -339,6 +359,11 @@ class Scheduler(SchedulerInterface): + + request = self.waiting.peek_request() + ++ num_slots_sparsed = INVALID_SLOT ++ if self.ucm_sparse: ++ num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed(request) ++ req_sparsed_slots.update({request.request_id: num_slots_sparsed}) ++ + # KVTransfer: skip request if still waiting for remote kvs. + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + is_ready = self._update_waiting_for_remote_kv(request) +@@ -476,6 +501,7 @@ class Scheduler(SchedulerInterface): + num_lookahead_tokens=effective_lookahead_tokens, + delay_cache_blocks=load_kv_async, + num_encoder_tokens=num_encoder_tokens, ++ num_slots_sparsed=num_slots_sparsed + ) + + if new_blocks is None: +@@ -587,6 +613,7 @@ class Scheduler(SchedulerInterface): + scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, + num_common_prefix_blocks=num_common_prefix_blocks, ++ req_sparsed_slots=req_sparsed_slots, + # finished_req_ids is an existing state in the scheduler, + # instead of being newly scheduled in this step. 
+ # It contains the request IDs that are finished in between +@@ -1097,6 +1124,10 @@ class Scheduler(SchedulerInterface): + def add_request(self, request: Request) -> None: + self.waiting.add_request(request) + self.requests[request.request_id] = request ++ ++ if self.ucm_sparse: ++ self.ucm_sparse.request_begin(request.request_id, request.prompt_token_ids) ++ + if self.log_stats: + request.record_event(EngineCoreEventType.QUEUED) + +@@ -1147,6 +1178,9 @@ class Scheduler(SchedulerInterface): + def _free_request(self, request: Request) -> Optional[dict[str, Any]]: + assert request.is_finished() + ++ if self.ucm_sparse: ++ self.ucm_sparse.request_finished_in_scheduler(request.request_id) ++ + delay_free_blocks, kv_xfer_params = self._connector_finished(request) + self.encoder_cache_manager.free(request) + request_id = request.request_id +diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py +index 82b6d1b51..85cdcb7ac 100644 +--- a/vllm/v1/worker/block_table.py ++++ b/vllm/v1/worker/block_table.py +@@ -58,6 +58,15 @@ class BlockTable: + self.num_blocks_per_row[row_idx] += num_blocks + self.block_table.np[row_idx, start:start + num_blocks] = block_ids + ++ def reset_row( ++ self, ++ row_idx: int, ++ ) -> None: ++ self.num_blocks_per_row[row_idx] = 0 ++ self.block_table.gpu[row_idx].fill_(0) ++ self.block_table.cpu[row_idx].fill_(0) ++ self.block_table.np[row_idx].fill(0) ++ + def add_row(self, block_ids: list[int], row_idx: int) -> None: + self.num_blocks_per_row[row_idx] = 0 + self.append_row(block_ids, row_idx) +@@ -176,6 +185,10 @@ class MultiGroupBlockTable: + for i, block_table in enumerate(self.block_tables): + block_table.append_row(block_ids[i], row_idx) + ++ def reset_row(self, row_idx: int) -> None: ++ for i, block_table in enumerate(self.block_tables): ++ block_table.reset_row(row_idx) ++ + def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: + for i, block_table in enumerate(self.block_tables): + block_table.add_row(block_ids[i], row_idx) +diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py +index a438c7777..28f619c8f 100644 +--- a/vllm/v1/worker/gpu_model_runner.py ++++ b/vllm/v1/worker/gpu_model_runner.py +@@ -1,6 +1,7 @@ + # SPDX-License-Identifier: Apache-2.0 + # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + ++import os + import gc + import itertools + import time +@@ -114,6 +115,9 @@ from .utils import (AttentionGroup, MultiModalBudget, + gather_mm_placeholders, sanity_check_mm_encoder_outputs, + scatter_mm_placeholders) + ++from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse ++from ucm.sparse.base import INVALID_SLOT ++ + if TYPE_CHECKING: + from vllm.model_executor.model_loader.tensorizer import TensorizerConfig + from vllm.v1.core.sched.output import SchedulerOutput +@@ -535,6 +539,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): + """ + # Remove finished requests from the cached states. + for req_id in scheduler_output.finished_req_ids: ++ self.ucm_sparse_request_finished_in_worker(req_id) + self.requests.pop(req_id, None) + # Remove the finished requests from the persistent batch. + # NOTE(woosuk): There could be an edge case where finished_req_ids and +@@ -611,11 +616,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): + # Update the states of the running/resumed requests. 
+ is_last_rank = get_pp_group().is_last_rank + req_data = scheduler_output.scheduled_cached_reqs ++ req_sparsed_slots = scheduler_output.req_sparsed_slots + for i, req_id in enumerate(req_data.req_ids): + req_state = self.requests[req_id] + num_computed_tokens = req_data.num_computed_tokens[i] + new_block_ids = req_data.new_block_ids[i] + resumed_from_preemption = req_data.resumed_from_preemption[i] ++ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT + + # Update the cached states. + req_state.num_computed_tokens = num_computed_tokens +@@ -637,17 +644,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): + new_token_ids[-num_new_tokens:]) + + # Update the block IDs. +- if not resumed_from_preemption: ++ if resumed_from_preemption or is_sparsed_request: ++ # The request is resumed from preemption. ++ # Replace the existing block IDs with the new ones. ++ req_state.block_ids = new_block_ids ++ else: + if new_block_ids is not None: + # Append the new blocks to the existing block IDs. + for block_ids, new_ids in zip(req_state.block_ids, + new_block_ids): + block_ids.extend(new_ids) +- else: +- assert new_block_ids is not None +- # The request is resumed from preemption. +- # Replace the existing block IDs with the new ones. +- req_state.block_ids = new_block_ids + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: +@@ -660,6 +666,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): + # Update the persistent batch. + self.input_batch.num_computed_tokens_cpu[req_index] = ( + num_computed_tokens) ++ ++ if is_sparsed_request: ++ self.input_batch.block_table.reset_row(req_index) ++ + if new_block_ids is not None: + self.input_batch.block_table.append_row( + new_block_ids, req_index) +@@ -968,6 +978,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): + if self.uses_mrope: + self._calc_mrope_positions(scheduler_output) + ++ self.seq_lens.np[:num_reqs] = ( ++ self.input_batch.num_computed_tokens_cpu[:num_reqs] + ++ num_scheduled_tokens) ++ ++ # TODO: improve performance, no `positions_np.copy()` ++ sparsed_positions = positions_np.copy() ++ req_sparsed_slots = scheduler_output.req_sparsed_slots ++ for req_id in self.input_batch.req_id_to_index: ++ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT ++ req_index = self.input_batch.req_id_to_index[req_id] ++ offset = 0 if req_index == 0 else cu_num_tokens[req_index - 1] # TODO: support MTP ++ if is_sparsed_request: ++ sparsed_positions[offset] = req_sparsed_slots[req_id] - 1 ++ + # Get token indices. 
+ # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] +@@ -1031,7 +1055,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): + output_idx += num_sched + + self.input_batch.block_table.compute_slot_mapping( +- req_indices, positions_np) ++ req_indices, sparsed_positions) + self.input_batch.block_table.commit_slot_mapping( + total_num_scheduled_tokens) + +@@ -1057,9 +1081,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): + uniform_decode=uniform_decode, + vllm_config=self.vllm_config) + +- self.seq_lens.np[:num_reqs] = ( +- self.input_batch.num_computed_tokens_cpu[:num_reqs] + +- num_scheduled_tokens) ++ is_sparsed_np = np.zeros((num_reqs,), dtype=np.bool_) ++ for req_id in self.input_batch.req_id_to_index: ++ req_index = self.input_batch.req_id_to_index[req_id] ++ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT ++ if is_sparsed_request: ++ self.seq_lens.np[req_index] = req_sparsed_slots[req_id] ++ is_sparsed_np[req_index] = True ++ + # Fill unused with 0 for full cuda graph mode. + self.seq_lens.np[num_reqs:].fill(0) + self.seq_lens.copy_to_gpu() +@@ -1073,7 +1102,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): + + # Record the index of requests that should not be sampled, + # so that we could clear the sampled tokens before returning +- discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np ++ discard_requests_mask = (self.seq_lens.np[:num_reqs] < num_tokens_np) & (~is_sparsed_np) + discard_request_indices = np.nonzero(discard_requests_mask)[0] + self.num_discarded_requests = len(discard_request_indices) + self.discard_request_indices.np[:self.num_discarded_requests] = ( +@@ -1091,6 +1120,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): + non_blocking=True) + else: + # Common case (1D positions) ++ self.positions.cpu[:total_num_scheduled_tokens] = torch.from_numpy( ++ self.positions.np[:total_num_scheduled_tokens]) + self.positions.copy_to_gpu(total_num_scheduled_tokens) + + use_spec_decode = len( +@@ -2295,6 +2326,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): + ), record_function_or_nullcontext("Forward"), + self.maybe_get_kv_connector_output(scheduler_output) as + kv_connector_output): ++ ++ self.maybe_execute_ucm_sparse_begin(scheduler_output, attn_metadata) ++ + model_output = self.model( + input_ids=input_ids, + positions=positions, +@@ -2303,6 +2337,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): + **model_kwargs, + ) + ++ logits_indices = self.maybe_execute_ucm_sparse_finished(logits_indices) ++ + with record_function_or_nullcontext("Postprocess"): + if self.use_aux_hidden_state_outputs: + # True when EAGLE 3 is used. 
+@@ -2584,6 +2620,30 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): + ) + return draft_token_ids + ++ def maybe_execute_ucm_sparse_begin(self, scheduler_output: "SchedulerOutput", attn_metadata: CommonAttentionMetadata): ++ if not has_ucm_sparse(): ++ return ++ if has_kv_transfer_group(): ++ uc_connector = get_kv_transfer_group() ++ uc_setup_model = getattr(uc_connector, "setup_model", None) ++ if callable(uc_setup_model): ++ uc_setup_model(self.model) ++ ucm_sparse = get_ucm_sparse() ++ ucm_sparse.build_sparse_meta(scheduler_output, self.requests, self.input_batch, attn_metadata) ++ ucm_sparse.execute_begin(scheduler_output) ++ ++ def maybe_execute_ucm_sparse_finished(self, logits_indices): ++ if not has_ucm_sparse(): ++ return logits_indices ++ ucm_sparse = get_ucm_sparse() ++ return ucm_sparse.execute_finished(logits_indices) ++ ++ def ucm_sparse_request_finished_in_worker(self, request_id: str | int): ++ if not has_ucm_sparse(): ++ return ++ ucm_sparse = get_ucm_sparse() ++ ucm_sparse.request_finished_in_worker(request_id) ++ + def update_config(self, overrides: dict[str, Any]) -> None: + allowed_config_names = {"load_config", "model_config"} + for config_name, config_overrides in overrides.items(): +@@ -3928,6 +3988,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): + kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, + kv_cache_raw_tensors) + ++ if has_ucm_sparse(): ++ ucm_sparse = get_ucm_sparse() ++ if os.getenv("VLLM_HASH_ATTENTION") == "1": ++ ucm_sparse.initialize_kv_hash_cache_tensors(kv_caches, self.device) ++ + # Set up cross-layer KV cache sharing + for layer_name, target_layer_name in self.shared_kv_cache_layers.items( + ): +diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py +index 8c75e8914..c3e9e781f 100644 +--- a/vllm/v1/worker/gpu_worker.py ++++ b/vllm/v1/worker/gpu_worker.py +@@ -35,6 +35,8 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner + from vllm.v1.worker.utils import is_residual_scattered_for_sp + from vllm.v1.worker.worker_base import WorkerBase + ++from ucm.sparse.state import ensure_ucm_sparse_initialized ++ + logger = init_logger(__name__) + + if TYPE_CHECKING: +@@ -708,3 +710,4 @@ def init_worker_distributed_environment( + parallel_config.decode_context_parallel_size) + + ensure_kv_transfer_initialized(vllm_config) ++ ensure_ucm_sparse_initialized(vllm_config) +-- +2.34.1 +
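
For reference, the renamed example (`examples/offline_inference_gsa_on_device.py`) enables the sparse method purely through the `ucm_sparse_config` key shown in this diff. Below is a minimal, hedged offline-inference sketch of that wiring; the connector name, model, and the omitted storage-connector entries are assumptions not shown in this change — see the example file itself for the canonical setup.

```python
# Hedged sketch: enabling the UCM "GSAOnDevice" sparse method for offline inference.
# Only "ucm_sparse_config" is taken from this diff; the connector name and model
# below are assumptions, and the storage-connector entries used in
# examples/offline_inference_gsa_on_device.py are omitted for brevity.
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

ktc = KVTransferConfig(
    kv_connector="UnifiedCacheConnectorV1",  # assumed UCM connector name
    kv_role="kv_both",
    kv_connector_extra_config={
        # Key shown in this diff; selects the GSA-on-device sparse algorithm.
        "ucm_sparse_config": {"GSAOnDevice": {}},
    },
)

# Model choice is illustrative only.
llm = LLM(model="Qwen/Qwen2.5-7B-Instruct", kv_transfer_config=ktc)
outputs = llm.generate(
    ["Summarize the following long document ..."],
    SamplingParams(max_tokens=64),
)
for out in outputs:
    print(out.outputs[0].text)
```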