diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index a6048a073..11a54ea72 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -6,7 +6,8 @@ import time from collections import defaultdict from dataclasses import dataclass, field -from typing import TYPE_CHECKING, List, Optional, Tuple +from enum import Enum +from typing import TYPE_CHECKING, Any, List, Optional, Tuple import numpy as np import torch @@ -16,12 +17,20 @@ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, + SupportsHMA, ) from vllm.distributed.parallel_state import get_tp_group, get_world_group from vllm.distributed.utils import get_pp_indices from vllm.model_executor.models.utils import extract_layer_index from vllm.platforms import current_platform +from vllm.utils.torch_utils import get_dtype_size from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + KVCacheConfig, + KVCacheSpec, + MambaSpec, +) from ucm.integration.vllm.device import create_device from ucm.logger import init_logger @@ -45,30 +54,78 @@ @dataclass class RequestMeta: ucm_block_ids: list[bytes] = field(default_factory=list) + mamba_block_ids: list[bytes] = field(default_factory=list) hbm_hit_block_num: int = 0 # local_computed_block + external_computed_block total_hit_block_num: int = 0 num_token_ids: int = 0 vllm_block_ids: list[int] = field(default_factory=list) token_processed: int = 0 + ucm_block_ids_by_group: list[list[bytes]] = field(default_factory=list) + vllm_block_ids_by_group: list[list[int]] = field(default_factory=list) + + +class KVCacheType(Enum): + ATTENTION = "attention" + MAMBA = "mamba" + UNKNOWN = "unknown" @dataclass class RequestDispatchMeta: - load_block_ids: tuple[ - list[bytes], list[int] - ] # [0] mean ucm_block_ids, [1] means vllm_block_ids - dump_block_ids: tuple[list[bytes], list[int]] + load_block_ids: dict[KVCacheType, tuple[list[bytes], list[int]]] + dump_block_ids: dict[KVCacheType, tuple[list[bytes], list[int]]] + + +@dataclass +class SingleLayout: + base_ptrs: np.ndarray + stride_lists: np.ndarray + tensor_size_lists: np.ndarray + + def extract_block_addrs( + self, vllm_block_ids: List[int], layer_first: bool = False + ) -> np.ndarray: + vllm_block_ids_np = np.array(vllm_block_ids, np.uint64) + if layer_first: + return ( + vllm_block_ids_np[None, :, None] * self.stride_lists[:, None, :] + + self.base_ptrs[:, None, :] + ) + return ( + vllm_block_ids_np[:, None, None] * self.stride_lists[None, :, :] + + self.base_ptrs[None, :, :] + ) + + def tensor_size_list(self, use_layerwise: bool) -> list[int]: + return ( + self.tensor_size_lists.reshape(-1).tolist() + if not use_layerwise + else self.tensor_size_lists[0].tolist() + ) + + def shard_size(self, use_layerwise: bool) -> int: + return int( + self.tensor_size_lists[0].sum() + if use_layerwise + else self.tensor_size_lists.sum() + ) + + def block_size(self, pp_size: int, num_hidden_layers: int) -> int: + if pp_size > 1: + return int(self.tensor_size_lists[0].sum() * num_hidden_layers) + return int(self.tensor_size_lists.sum()) class KVCacheLayout: def __init__( - self, kvcaches, use_layerwise: bool, vllm_config: "VllmConfig" + self, + vllm_config: "VllmConfig", + ucm_config: Config, + kv_cache_config: Optional["KVCacheConfig"] = None, ) -> None: - # each row is a layer, each column is a tensor_size/ptr in the layer (e.g., k, v, rope, k_index) - self.base_ptrs: np.ndarray # (n_layers, n_ptrs) - self.tensor_size_lists: np.ndarray # (n_layers, n_tensor_sizes) - self.use_layerwise = use_layerwise + self.use_layerwise = ucm_config.get_config().get("use_layerwise", False) + self.kv_cache_config = kv_cache_config self.vllm_config = vllm_config self.pp_size = self.vllm_config.parallel_config.pipeline_parallel_size self.num_hidden_layers = getattr( @@ -82,92 +139,222 @@ def __init__( self.local_num_hidden_layers = end - start if self.pp_size > 1 and self.num_hidden_layers <= 0: raise ValueError("num_hidden_layers must be > 0 when pp_size > 1") - self.layer_name_to_id = { - name: extract_layer_index(name) for name in kvcaches.keys() - } - self.first_layer_id = next(iter(self.layer_name_to_id.values())) - self._build_layout(kvcaches) + self.cache_block_size = self.vllm_config.cache_config.block_size + self.layouts: dict[KVCacheType, SingleLayout] = {} + self.layer_name_to_id: dict[str, int] = {} + self.first_layer_id: int = 0 + self.layer_ids: list[int] = [] + self.layer_name_to_group_id: dict[str, int] = {} + self.layer_name_to_kv_cache_type: dict[str, KVCacheType] = {} + self.layer_name_to_raw_tensor_idx: dict[str, int] = {} + self.group_ids_by_kv_cache_type: dict[KVCacheType, list[int]] = defaultdict( + list + ) + self.kv_cache_types: list[KVCacheType] = [KVCacheType.ATTENTION] + self.kernel_block_size_scale = 1 + if self.kv_cache_config is not None: + self._initialize_kv_cache_config() - def _build_layout(self, kvcaches): - raw_ptr_rows = [[] for _ in range(self.local_num_hidden_layers)] - stride_rows = [[] for _ in range(self.local_num_hidden_layers)] + @property + def is_hybrid(self) -> bool: + return len(self.kv_cache_types) > 1 - for layer_name, kv_layer in kvcaches.items(): - ptrs = [] - strides = [] - - def handle_tensor(t: torch.Tensor, size_dims): - ptrs.append(t[0].data_ptr()) - - stride = math.prod([t.shape[i] for i in size_dims]) * t.element_size() - strides.append(stride) - - if isinstance(kv_layer, torch.Tensor): - if kv_layer.dim() == 5: - # [2, num_blocks, block_size, num_head, head_dim] - handle_tensor(kv_layer[0], (-3, -2, -1)) - handle_tensor(kv_layer[1], (-3, -2, -1)) - elif kv_layer.dim() == 3: - # [num_blocks, block_size, head_dim] - handle_tensor(kv_layer, (-2, -1)) - else: - raise ValueError( - f"Unsupported kv cache tensor shape: {kv_layer.shape}" - ) - elif isinstance(kv_layer, Tuple): - # vllm_ascend >= 0.10.0, ([num_blocks, block_size, num_head, head_dim], ...) - for tensor in kv_layer: - handle_tensor(tensor, (-3, -2, -1)) + @property + def default_kv_cache_type(self) -> KVCacheType: + if KVCacheType.ATTENTION in self.kv_cache_types: + return KVCacheType.ATTENTION + return self.kv_cache_types[0] + + def _kv_cache_type_from_spec(self, kv_cache_spec: "KVCacheSpec") -> KVCacheType: + if isinstance(kv_cache_spec, MambaSpec): + return KVCacheType.MAMBA + if isinstance(kv_cache_spec, AttentionSpec): + return KVCacheType.ATTENTION + return KVCacheType.UNKNOWN + + def _initialize_kv_cache_config(self): + discovered_kv_cache_types: list[KVCacheType] = [] + for group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups + ): + kv_cache_type = self._kv_cache_type_from_spec( + kv_cache_group_spec.kv_cache_spec + ) + if ( + kv_cache_type not in discovered_kv_cache_types + and kv_cache_type != KVCacheType.UNKNOWN + ): + discovered_kv_cache_types.append(kv_cache_type) + if kv_cache_type != KVCacheType.UNKNOWN: + self.group_ids_by_kv_cache_type[kv_cache_type].append(group_id) + for layer_name in kv_cache_group_spec.layer_names: + self.layer_name_to_group_id[layer_name] = group_id + self.layer_name_to_kv_cache_type[layer_name] = kv_cache_type + if discovered_kv_cache_types: + self.kv_cache_types = discovered_kv_cache_types + + def initialize_kv_cache_layout(self, kvcaches): + if self.kv_cache_config is not None: + self._build_layout_with_kv_cache_config(kvcaches) + else: + self._build_layout(kvcaches) + self.layer_ids = list(sorted(set(self.layer_name_to_id.values()))) + self.first_layer_id = self.layer_ids[0] + + def _handle_kv_layer(self, kv_layer, kv_cache_type: KVCacheType): + ptrs = [] + strides = [] + tensor_sizes = [] + + def handle_kv_tensor(t: torch.Tensor): + ptrs.append(t.data_ptr()) + strides.append(t.stride(0) * t.element_size()) + tensor_size = ( + math.prod([t.shape[i] for i in range(1, t.dim())]) * t.element_size() + ) + if kv_cache_type == KVCacheType.ATTENTION: + kernel_block_size = t.shape[1] + self.kernel_block_size_scale = ( + self.cache_block_size // kernel_block_size + ) + tensor_size *= self.kernel_block_size_scale + tensor_sizes.append(tensor_size) + + if isinstance(kv_layer, torch.Tensor): + if kv_layer.dim() == 5 and kv_layer.shape[0] == 2: + # full attention packed KV: [2, num_blocks, ...] + handle_kv_tensor(kv_layer[0]) + handle_kv_tensor(kv_layer[1]) else: - raise TypeError(f"Unsupported kv cache type: {type(kv_layer)}") + handle_kv_tensor(kv_layer) + elif isinstance(kv_layer, (tuple, list)): + for tensor in kv_layer: + handle_kv_tensor(tensor) + else: + raise TypeError(f"Unsupported kv cache type: {type(kv_layer)}") + + return ptrs, strides, tensor_sizes + + def _create_layout( + self, + raw_ptr_rows: list[list[int]], + stride_rows: list[list[int]], + tensor_size_rows: list[list[int]], + ) -> SingleLayout: + return SingleLayout( + base_ptrs=np.asarray(raw_ptr_rows, dtype=np.uint64), + stride_lists=np.asarray(stride_rows, dtype=np.uint64), + tensor_size_lists=np.asarray(tensor_size_rows, dtype=np.uint64), + ) - local_layer_id = self.layer_name_to_id[layer_name] - self.first_layer_id - raw_ptr_rows[local_layer_id].extend(ptrs) - stride_rows[local_layer_id].extend(strides) + def _build_layout(self, kvcaches): + raw_ptr_rows = [] + stride_rows = [] + tensor_size_rows = [] + kv_cache_type = self.default_kv_cache_type - self.base_ptrs = np.asarray(raw_ptr_rows, dtype=np.uint64) - self.tensor_size_lists = np.asarray(stride_rows, dtype=np.uint64) + for layer_name, kv_layer in kvcaches.items(): + self.layer_name_to_id[layer_name] = extract_layer_index(layer_name) + ptrs, strides, tensor_sizes = self._handle_kv_layer(kv_layer, kv_cache_type) + raw_ptr_rows.append(ptrs) + stride_rows.append(strides) + tensor_size_rows.append(tensor_sizes) + self.layer_name_to_kv_cache_type.setdefault(layer_name, kv_cache_type) + raw_tensor_idx = len(raw_ptr_rows) - 1 + self.layer_name_to_raw_tensor_idx.setdefault(layer_name, raw_tensor_idx) + + self.layouts = { + kv_cache_type: self._create_layout( + raw_ptr_rows, stride_rows, tensor_size_rows + ) + } + self.kv_cache_types = [kv_cache_type] + layout = self.layouts[kv_cache_type] logger.info( - f"base_ptrs: {self.base_ptrs.shape}, tensor_size_lists: {self.tensor_size_lists.shape}" + f"layout[{kv_cache_type}]: base_ptrs {layout.base_ptrs.shape}, " + f"stride_lists {layout.stride_lists.shape}, " + f"tensor_size_lists {layout.tensor_size_lists.shape}" ) + def _build_layout_with_kv_cache_config(self, kvcaches): + raw_ptr_rows: dict[KVCacheType, list[list[int]]] = defaultdict(list) + stride_rows: dict[KVCacheType, list[list[int]]] = defaultdict(list) + tensor_size_rows: dict[KVCacheType, list[list[int]]] = defaultdict(list) + + for raw_tensor_idx, kv_cache_tensor in enumerate( + self.kv_cache_config.kv_cache_tensors + ): + kv_cache_type_to_layer_names: dict[KVCacheType, list[str]] = defaultdict( + list + ) + for layer_name in kv_cache_tensor.shared_by: + self.layer_name_to_id[layer_name] = extract_layer_index(layer_name) + self.layer_name_to_raw_tensor_idx[layer_name] = raw_tensor_idx + kv_cache_type = self.layer_name_to_kv_cache_type.get( + layer_name, KVCacheType.UNKNOWN + ) + kv_cache_type_to_layer_names[kv_cache_type].append(layer_name) + for ( + kv_cache_type, + shared_layer_names, + ) in kv_cache_type_to_layer_names.items(): + if kv_cache_type == KVCacheType.UNKNOWN: + continue + kv_layer = kvcaches.get(shared_layer_names[0]) + if kv_layer is None: + raise KeyError( + f"Layer {shared_layer_names[0]} referenced by kv_cache_config " + "was not found in registered KV caches." + ) + ptrs, strides, tensor_sizes = self._handle_kv_layer( + kv_layer, kv_cache_type + ) + raw_ptr_rows[kv_cache_type].append(ptrs) + stride_rows[kv_cache_type].append(strides) + tensor_size_rows[kv_cache_type].append(tensor_sizes) + + self.layouts = {} + for kv_cache_type in self.kv_cache_types: + self.layouts[kv_cache_type] = self._create_layout( + raw_ptr_rows[kv_cache_type], + stride_rows[kv_cache_type], + tensor_size_rows[kv_cache_type], + ) + + for kv_cache_type, layout in self.layouts.items(): + logger.info( + f"layout[{kv_cache_type}]: base_ptrs {layout.base_ptrs.shape}, " + f"stride_lists {layout.stride_lists.shape}, " + f"tensor_size_lists {layout.tensor_size_lists.shape}" + ) + + def get_layout(self, kv_cache_type: Optional[KVCacheType] = None) -> SingleLayout: + resolved_kv_cache_type = kv_cache_type or self.default_kv_cache_type + return self.layouts[resolved_kv_cache_type] + def extract_block_addrs( - self, vllm_block_ids: List[int], layer_first: bool = False + self, vllm_block_ids: List[int], kv_cache_type: Optional[KVCacheType] = None ) -> np.ndarray: - vllm_block_ids_np = np.array(vllm_block_ids, np.uint64) - if layer_first: - # (n_layers, num_blocks, n_ptrs) - return ( - self.tensor_size_lists[:, None, :] * vllm_block_ids_np[None, :, None] - + self.base_ptrs[:, None, :] - ) - return ( - vllm_block_ids_np[:, None, None] * self.tensor_size_lists[None, :, :] - + self.base_ptrs[None, :, :] - ) # (num_blocks, n_layers, n_ptrs) + resolved_kv_cache_type = kv_cache_type or self.default_kv_cache_type + if resolved_kv_cache_type == KVCacheType.ATTENTION: + vllm_block_ids = [ + block_id * self.kernel_block_size_scale for block_id in vllm_block_ids + ] + return self.get_layout(kv_cache_type).extract_block_addrs(vllm_block_ids) - @property - def tensor_size_list(self) -> list[int]: - return ( - self.tensor_size_lists.reshape(-1).tolist() - if not self.use_layerwise - else self.tensor_size_lists[0].tolist() - ) + def tensor_size_list( + self, kv_cache_type: Optional[KVCacheType] = None + ) -> list[int]: + return self.get_layout(kv_cache_type).tensor_size_list(self.use_layerwise) - @property - def shard_size(self) -> int: - return int( - self.tensor_size_lists.sum() - if not self.use_layerwise - else self.tensor_size_lists[0].sum() - ) + def shard_size(self, kv_cache_type: Optional[KVCacheType] = None) -> int: + return self.get_layout(kv_cache_type).shard_size(self.use_layerwise) - @property - def block_size(self) -> int: - if self.pp_size > 1: - return int(self.tensor_size_lists[0].sum() * self.num_hidden_layers) - return int(self.tensor_size_lists.sum()) + def block_size(self, kv_cache_type: Optional[KVCacheType] = None) -> int: + return self.get_layout(kv_cache_type).block_size( + self.pp_size, self.num_hidden_layers + ) @dataclass @@ -192,14 +379,21 @@ def __call__(self, input_data) -> bytes: return h.digest() -class UCMDirectConnector(KVConnectorBase_V1): +class UCMDirectConnector(KVConnectorBase_V1, SupportsHMA): """ This connector means synchronize: load -> forward -> save """ - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__( + vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config + ) self.use_layerwise = False self.kv_caches: dict[str, torch.Tensor] = {} self.local_rank = ( @@ -229,6 +423,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self.device = torch_dev.device(f"{dev_name}:{self.local_rank}") self.store: UcmKVStoreBaseV1 + self.stores: dict[KVCacheType, UcmKVStoreBaseV1] = {} self.rope_store: Optional[UcmKVStoreBaseV1] = None # save block info, avoid hash request twice, and track them until request finished @@ -243,12 +438,18 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self.chunk_size = self.block_size self.blocks_per_chunk = self.chunk_size // self.block_size + self.kv_cache_config = kv_cache_config + self.kv_cache_layout = KVCacheLayout( + self._vllm_config, + ucm_config, + getattr(self, "kv_cache_config", None), + ) if role == KVConnectorRole.SCHEDULER: self.request_hasher = RequestHasher(vllm_config, 0) self._seed = self.request_hasher("UCM_HASH_SEED") - # init scheduler-size connector - self.store = self._create_store(None) + self.stores = self._create_stores(self.kv_cache_layout) + self.store = self.stores[self.kv_cache_layout.default_kv_cache_type] else: self.request_hasher = RequestHasher( vllm_config, self.tp_rank % self.tp_size @@ -270,7 +471,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): f"metrics_config_path: {self.metrics_config}, set worker_id: {worker_id}" ) - # invlalid block ids due to load errors + # invalid block ids due to load errors self._invalid_block_ids: set[int] = set() def generate_hash( @@ -293,9 +494,37 @@ def generate_hash( return ret + def _generate_mamba_block_ids(self, attn_block_id: bytes) -> list[bytes]: + if not hasattr(self, "kv_cache_layout"): + return [] + mamba_group_ids = self.kv_cache_layout.group_ids_by_kv_cache_type.get( + KVCacheType.MAMBA, [] + ) + return [ + self.request_hasher((attn_block_id, f"group_{group_id}")) + for group_id in mamba_group_ids + ] + + def _validate_attn_hit_with_mamba( + self, attn_block_id: bytes + ) -> tuple[bool, list[bytes]]: + mamba_store = self.stores.get(KVCacheType.MAMBA) + if mamba_store is None: + return True, [] + + mamba_block_ids = self._generate_mamba_block_ids(attn_block_id) + if not mamba_block_ids: + return True, [] + + lookup_result = mamba_store.lookup(mamba_block_ids) + if all(lookup_result): + return True, mamba_block_ids + return False, [] + def _create_store( self, kv_cache_layout: Optional[KVCacheLayout], + kv_cache_type: KVCacheType, cpu_affinity_cores: Optional[list[int]] = None, ) -> UcmKVStoreBaseV1: if len(self.connector_configs) != 1: @@ -312,20 +541,41 @@ def _create_store( if "storage_backends" in config: backends = [path for path in config["storage_backends"].split(":")] config["storage_backends"] = backends - config["unique_id"] = f"{self.engine_id}" + config["unique_id"] = f"{self.engine_id}:{kv_cache_type.value}" if self._role == KVConnectorRole.WORKER: config["device_id"] = self.local_rank config["tensor_size_list"] = ( - kv_cache_layout.tensor_size_list * self.blocks_per_chunk + kv_cache_layout.tensor_size_list(kv_cache_type) * self.blocks_per_chunk + ) + config["shard_size"] = ( + kv_cache_layout.shard_size(kv_cache_type) * self.blocks_per_chunk + ) + config["block_size"] = ( + kv_cache_layout.block_size(kv_cache_type) * self.blocks_per_chunk ) - config["shard_size"] = kv_cache_layout.shard_size * self.blocks_per_chunk - config["block_size"] = kv_cache_layout.block_size * self.blocks_per_chunk config["local_rank_size"] = self.tp_size if self.is_mla else 1 if cpu_affinity_cores: config["cpu_affinity_cores"] = list(cpu_affinity_cores) logger.info(f"create {name} with config: {config}") return UcmConnectorFactoryV1.create_connector(name, config, module_path) + def _create_stores( + self, + kv_cache_layout: Optional[KVCacheLayout], + cpu_affinity_cores: Optional[list[int]] = None, + ) -> dict[KVCacheType, UcmKVStoreBaseV1]: + kv_cache_types = ( + kv_cache_layout.kv_cache_types + if kv_cache_layout is not None + else [KVCacheType.ATTENTION] + ) + return { + kv_cache_type: self._create_store( + kv_cache_layout, kv_cache_type, cpu_affinity_cores + ) + for kv_cache_type in kv_cache_types + } + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): if has_ucm_sparse() and os.getenv("VLLM_HASH_ATTENTION") == "1": for layer_name, value in kv_caches.items(): @@ -342,13 +592,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # vllm_ascend >= 0.10.0 uses Tuple for kvcaches for i, tensor in enumerate(sample_kv_layer): logger.info(f"kv cache shape {i}: {tensor.shape}") - self.kv_cache_layout = KVCacheLayout( - self.kv_caches, self.use_layerwise, self._vllm_config - ) - self.block_data_size = self.kv_cache_layout.block_size + self.kv_cache_layout.initialize_kv_cache_layout(self.kv_caches) + self.block_data_size = self.kv_cache_layout.block_size() self.layer_name_to_id = self.kv_cache_layout.layer_name_to_id - self.layer_ids = sorted(set(self.layer_name_to_id.values())) - self.first_layer_id = self.layer_ids[0] + self.layer_ids = self.kv_cache_layout.layer_ids + self.first_layer_id = self.kv_cache_layout.first_layer_id self.device = create_device() @@ -359,7 +607,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): else (None, None) ) - self.store = self._create_store(self.kv_cache_layout, store_cores) + self.stores: dict[KVCacheType, UcmKVStoreBaseV1] = self._create_stores( + self.kv_cache_layout, store_cores + ) + self.store = self.stores[self.kv_cache_layout.default_kv_cache_type] if worker_cores: try: @@ -386,8 +637,24 @@ def get_num_new_matched_tokens( external_block_ids = ucm_block_ids[hbm_hit_block_num:] if not external_block_ids: return 0, False + mamba_block_ids: list[bytes] = [] try: - external_hit_blocks = self.store.lookup_on_prefix(external_block_ids) + 1 + if self.kv_cache_layout.is_hybrid: + attn_store = self.stores.get(KVCacheType.ATTENTION, self.store) + external_hit_blocks = ( + attn_store.lookup_on_prefix(external_block_ids) + 1 + ) + else: + external_hit_blocks = ( + self.store.lookup_on_prefix(external_block_ids) + 1 + ) + if self.kv_cache_layout.is_hybrid and external_hit_blocks > 0: + last_hit_attn_block_id = external_block_ids[external_hit_blocks - 1] + mamba_hit, mamba_block_ids = self._validate_attn_hit_with_mamba( + last_hit_attn_block_id + ) + if not mamba_hit: + external_hit_blocks = 0 except RuntimeError as e: external_hit_blocks = 0 logger.error(f"request {request.request_id} look up error. {e}") @@ -415,6 +682,7 @@ def get_num_new_matched_tokens( self.requests_meta[request.request_id] = RequestMeta( ucm_block_ids=ucm_block_ids, + mamba_block_ids=mamba_block_ids, hbm_hit_block_num=hbm_hit_block_num, total_hit_block_num=total_hit_block_num, num_token_ids=len(request.all_token_ids), @@ -428,11 +696,30 @@ def update_state_after_alloc( ): pass + def _extend_vllm_block_ids_by_group( + self, req_meta: RequestMeta, vllm_block_ids_by_group: list[list[int]] + ) -> None: + normalized_block_ids = [ + list(group_block_ids) for group_block_ids in vllm_block_ids_by_group + ] + if not req_meta.vllm_block_ids_by_group: + req_meta.vllm_block_ids_by_group = [ + [] for _ in range(len(normalized_block_ids)) + ] + for target, source in zip( + req_meta.vllm_block_ids_by_group, normalized_block_ids + ): + target.extend(source) + + @staticmethod + def _nonzero_block_ids(block_ids: list[int]) -> list[int]: + return [block_id for block_id in block_ids if block_id != 0] + def _generate_dispatch_meta( self, req_meta: RequestMeta, new_tokens: int, - vllm_block_ids: list[int], + vllm_block_ids: list[list[int]], need_load: bool = True, ) -> RequestDispatchMeta: """ @@ -449,26 +736,90 @@ def _generate_dispatch_meta( hbm_hit_block_num = req_meta.hbm_hit_block_num total_hit_block_num = req_meta.total_hit_block_num - ucm_block_ids = req_meta.ucm_block_ids - req_meta.vllm_block_ids.extend(vllm_block_ids) + self._extend_vllm_block_ids_by_group(req_meta, vllm_block_ids) + + load_block_ids: dict[KVCacheType, tuple[list[bytes], list[int]]] = {} + dump_block_ids: dict[KVCacheType, tuple[list[bytes], list[int]]] = {} + + group_ids_by_kv_cache_type = self.kv_cache_layout.group_ids_by_kv_cache_type + attn_group_ids = group_ids_by_kv_cache_type.get(KVCacheType.ATTENTION, []) + mamba_group_ids = group_ids_by_kv_cache_type.get(KVCacheType.MAMBA, []) + assert ( + len(attn_group_ids) == 1 + ), "Current hybrid path expects exactly one attention group." + attn_group_id = attn_group_ids[0] + attn_vllm_block_ids = req_meta.vllm_block_ids_by_group[attn_group_id] - load_ucm_block_ids, load_vllm_block_ids = [], [] - dump_ucm_block_ids, dump_vllm_block_ids = [], [] if need_load: - load_ucm_block_ids = ucm_block_ids[hbm_hit_block_num:total_hit_block_num] - load_vllm_block_ids = vllm_block_ids[hbm_hit_block_num:total_hit_block_num] + # Attention uses the dense block timeline directly. + attn_load_ucm_block_ids = req_meta.ucm_block_ids[ + hbm_hit_block_num:total_hit_block_num + ] + attn_load_vllm_block_ids = attn_vllm_block_ids[ + hbm_hit_block_num:total_hit_block_num + ] + if attn_load_ucm_block_ids and attn_load_vllm_block_ids: + load_block_ids[KVCacheType.ATTENTION] = ( + list(attn_load_ucm_block_ids), + list(attn_load_vllm_block_ids), + ) + + if need_load and self.kv_cache_layout.is_hybrid and req_meta.mamba_block_ids: + mamba_load_pairs: list[tuple[bytes, int]] = [] + # For Mamba, the penultimate non-zero block is the load target. + for group_id, mamba_ucm_block_id in zip( + mamba_group_ids, req_meta.mamba_block_ids + ): + nonzero_vllm_block_ids = self._nonzero_block_ids( + req_meta.vllm_block_ids_by_group[group_id] + ) + if len(nonzero_vllm_block_ids) >= 2: + mamba_load_pairs.append( + (mamba_ucm_block_id, nonzero_vllm_block_ids[-1]) + ) + if mamba_load_pairs: + load_block_ids[KVCacheType.MAMBA] = ( + [ucm_block_id for ucm_block_id, _ in mamba_load_pairs], + [vllm_block_id for _, vllm_block_id in mamba_load_pairs], + ) if req_meta.token_processed < req_meta.num_token_ids: start_idx = req_meta.token_processed // self.block_size end_idx = (req_meta.token_processed + new_tokens) // self.block_size - dump_ucm_block_ids = ucm_block_ids[start_idx:end_idx] - dump_vllm_block_ids = req_meta.vllm_block_ids[start_idx:end_idx] + attn_dump_ucm_block_ids = req_meta.ucm_block_ids[start_idx:end_idx] + attn_dump_vllm_block_ids = attn_vllm_block_ids[start_idx:end_idx] + + if attn_dump_ucm_block_ids and attn_dump_vllm_block_ids: + # Attention dump follows the newly scheduled dense blocks. + dump_block_ids[KVCacheType.ATTENTION] = ( + list(attn_dump_ucm_block_ids), + list(attn_dump_vllm_block_ids), + ) + + if self.kv_cache_layout.is_hybrid: + mamba_dump_pairs: list[tuple[bytes, int]] = [] + last_attn_dump_block_id = attn_dump_ucm_block_ids[-1] + # Mamba dump keys are derived from the last attention dump key. + for group_id, mamba_ucm_block_id in zip( + mamba_group_ids, + self._generate_mamba_block_ids(last_attn_dump_block_id), + ): + nonzero_vllm_block_ids = self._nonzero_block_ids( + req_meta.vllm_block_ids_by_group[group_id] + ) + if nonzero_vllm_block_ids: + mamba_dump_pairs.append( + (mamba_ucm_block_id, nonzero_vllm_block_ids[-1]) + ) + if mamba_dump_pairs: + dump_block_ids[KVCacheType.MAMBA] = ( + [ucm_block_id for ucm_block_id, _ in mamba_dump_pairs], + [vllm_block_id for _, vllm_block_id in mamba_dump_pairs], + ) + req_meta.token_processed += new_tokens - return RequestDispatchMeta( - (load_ucm_block_ids, load_vllm_block_ids), - (dump_ucm_block_ids, dump_vllm_block_ids), - ) + return RequestDispatchMeta(load_block_ids, dump_block_ids) def build_connector_meta( self, scheduler_output: SchedulerOutput @@ -476,13 +827,13 @@ def build_connector_meta( requests_dispatch_meta = {} # for new request, we need to load and dump for request in scheduler_output.scheduled_new_reqs: - request_id, vllm_block_ids = request.req_id, request.block_ids[0] + request_id = request.req_id req_meta = self.requests_meta.get(request_id) if req_meta: requests_dispatch_meta[request_id] = self._generate_dispatch_meta( req_meta, scheduler_output.num_scheduled_tokens[request_id], - vllm_block_ids, + list(request.block_ids), ) # for cached request, there are 3 situation: @@ -495,9 +846,9 @@ def build_connector_meta( for i, request_id in enumerate(scheduled_cached_reqs.req_ids): req_meta = self.requests_meta.get(request_id) if req_meta: - new_block_ids = [] - if scheduled_cached_reqs.new_block_ids[i] != None: - new_block_ids = scheduled_cached_reqs.new_block_ids[i][0] + new_block_ids: list[list[int]] = [] + if scheduled_cached_reqs.new_block_ids[i] is not None: + new_block_ids = list(scheduled_cached_reqs.new_block_ids[i]) if hasattr(scheduled_cached_reqs, "resumed_from_preemption"): resumed_from_preemption = ( scheduled_cached_reqs.resumed_from_preemption[i] @@ -520,7 +871,7 @@ def build_connector_meta( requests_dispatch_meta[request_id] = self._generate_dispatch_meta( req_meta, scheduler_output.num_scheduled_tokens[request_id], - request.new_block_ids[0], + list(request.new_block_ids), request.resumed_from_preemption, ) @@ -534,45 +885,52 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: metadata = self._get_connector_metadata() assert isinstance(metadata, UCMConnectorMetadata) - request_to_task: dict[str, Task] = {} + pending_tasks: list[tuple["UcmKVStoreBaseV1", str, KVCacheType, Task]] = [] is_load = False num_loaded_block = 0 num_loaded_request = 0 load_start_time = time.perf_counter() * 1000 for request_id, request in metadata.request_meta.items(): - if len(request.load_block_ids[0]) == 0: + if not request.load_block_ids: continue is_load = True - num_loaded_block += len(request.load_block_ids[0]) - num_loaded_request += 1 - - ucm_block_ids, vllm_block_ids = request.load_block_ids - if self.tp_rank != 0 and not self.is_mla: - for i, ucm_block_id in enumerate(ucm_block_ids): - ucm_block_ids[i] = self.request_hasher(ucm_block_id) - total_ptrs = self.kv_cache_layout.extract_block_addrs(vllm_block_ids) - total_ptrs = total_ptrs.reshape(total_ptrs.shape[0], -1) - shard_indexs = [0] * len(ucm_block_ids) - try: - task = self.store.load_data(ucm_block_ids, shard_indexs, total_ptrs) - request_to_task[request_id] = task - except RuntimeError as e: - logger.error(f"request {request_id} submit load task error. {e}") - self._invalid_block_ids.update( - metadata.request_meta[request_id].load_block_ids[1] + for kv_cache_type, ( + load_ucm_ids, + load_vllm_ids, + ) in request.load_block_ids.items(): + store = self.stores.get(kv_cache_type, self.store) + if not load_ucm_ids or not load_vllm_ids: + continue + num_loaded_block += len(load_ucm_ids) + num_loaded_request += 1 + + if self.tp_rank != 0 and not self.is_mla: + for i, ucm_id in enumerate(load_ucm_ids): + load_ucm_ids[i] = self.request_hasher(ucm_id) + + total_ptrs = self.kv_cache_layout.extract_block_addrs( + load_vllm_ids, kv_cache_type ) - num_loaded_block -= len(request.load_block_ids[0]) - - for request_id, task in request_to_task.items(): + total_ptrs = total_ptrs.reshape(total_ptrs.shape[0], -1) + shard_idxs = [0] * len(load_ucm_ids) + try: + task = store.load_data(load_ucm_ids, shard_idxs, total_ptrs) + pending_tasks.append((store, request_id, kv_cache_type, task)) + except RuntimeError as e: + logger.error(f"request {request_id} submit load task error. {e}") + self._invalid_block_ids.update(load_vllm_ids) + num_loaded_block -= len(load_ucm_ids) + + for store, request_id, kv_cache_type, task in pending_tasks: try: - self.store.wait(task) + store.wait(task) except RuntimeError as e: logger.error(f"request {request_id} wait load task error. {e}") self._invalid_block_ids.update( - metadata.request_meta[request_id].load_block_ids[1] + metadata.request_meta[request_id].load_block_ids[kv_cache_type][1] ) num_loaded_block -= len( - metadata.request_meta[request_id].load_block_ids[0] + metadata.request_meta[request_id].load_block_ids[kv_cache_type][0] ) load_end_time = time.perf_counter() * 1000 @@ -623,43 +981,56 @@ def wait_for_save(self) -> None: metadata = self._get_connector_metadata() assert isinstance(metadata, UCMConnectorMetadata) - dump_tasks: List[Task] = [] + dump_tasks: list[tuple["UcmKVStoreBaseV1", Task]] = [] is_save = False num_saved_block = 0 num_saved_request = 0 - total_ucm_block_ids, total_vllm_block_ids = [], [] - for request_id, request in metadata.request_meta.items(): - if len(request.dump_block_ids[0]) == 0: + total_ucm_block_ids: dict[KVCacheType, list[bytes]] = defaultdict(list) + total_vllm_block_ids: dict[KVCacheType, list[int]] = defaultdict(list) + for _, request in metadata.request_meta.items(): + if not request.dump_block_ids: continue - is_save = True - num_saved_block += len(request.dump_block_ids[0]) - num_saved_request += 1 - - ucm_block_ids, vllm_block_ids = request.dump_block_ids - if self.tp_rank != 0: - for i, ucm_block_id in enumerate(ucm_block_ids): - ucm_block_ids[i] = self.request_hasher(ucm_block_id) - total_ucm_block_ids.extend(ucm_block_ids) - total_vllm_block_ids.extend(vllm_block_ids) + for kv_cache_type, ( + ucm_block_ids, + vllm_block_ids, + ) in request.dump_block_ids.items(): + if not ucm_block_ids or not vllm_block_ids: + continue + is_save = True + num_saved_block += len(ucm_block_ids) + num_saved_request += 1 + total_ucm_block_ids[kv_cache_type].extend(ucm_block_ids) + total_vllm_block_ids[kv_cache_type].extend(vllm_block_ids) if is_save: - total_ptrs = self.kv_cache_layout.extract_block_addrs(total_vllm_block_ids) - total_ptrs = total_ptrs.reshape(total_ptrs.shape[0], -1) - shard_indexs = [0] * len(total_ucm_block_ids) - try: - event_handle = self._get_dump_event_handle() - save_start_time = time.perf_counter() * 1000 - task = self.store.dump_data( - total_ucm_block_ids, shard_indexs, total_ptrs, event_handle + save_start_time = time.perf_counter() * 1000 + for kv_cache_type, ucm_block_ids in total_ucm_block_ids.items(): + store = self.stores.get(kv_cache_type, self.store) + ucm_ids = list(ucm_block_ids) + if self.tp_rank != 0: + for i, ucm_block_id in enumerate(ucm_ids): + ucm_ids[i] = self.request_hasher(ucm_block_id) + + vllm_block_ids = total_vllm_block_ids[kv_cache_type] + total_ptrs = self.kv_cache_layout.extract_block_addrs( + vllm_block_ids, kv_cache_type ) - dump_tasks.append(task) - except RuntimeError as e: - logger.error(f"dump kv cache failed. {e}") - return + total_ptrs = total_ptrs.reshape(total_ptrs.shape[0], -1) + shard_indexs = [0] * len(ucm_ids) + try: + event_handle = self._get_dump_event_handle() + task = store.dump_data( + ucm_ids, shard_indexs, total_ptrs, event_handle + ) + dump_tasks.append((store, task)) + except RuntimeError as e: + logger.error(f"dump kv cache failed. {e}") + return + if is_save and dump_tasks: try: - for task in dump_tasks: - self.store.wait(task) + for store, task in dump_tasks: + store.wait(task) save_end_time = time.perf_counter() * 1000 except RuntimeError as e: logger.error(f"wait for dump kv cache failed.{e}") @@ -697,6 +1068,13 @@ def get_block_ids_with_load_errors(self) -> set[int]: self._invalid_block_ids = set() return res + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], + ) -> tuple[bool, dict[str, Any] | None]: + return False, None + class UCMLayerWiseConnector(UCMDirectConnector): """ @@ -706,8 +1084,13 @@ class UCMLayerWiseConnector(UCMDirectConnector): load l2 -> forward l2 -> save l2 """ - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config, role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__(vllm_config, role, kv_cache_config) # {layer_id: {request_id: Task}} self.load_tasks: dict[int, dict[str, Task]] = defaultdict(dict) self.dump_tasks: dict[str, Task] = {} @@ -850,8 +1233,13 @@ def wait_for_save(self) -> None: class UCMCPConnector(UCMLayerWiseConnector): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config, role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__(vllm_config, role, kv_cache_config) self.use_layerwise = self.launch_config.get("use_layerwise", False) try: @@ -1079,8 +1467,13 @@ class UCMMockConnector(UCMDirectConnector): will reduce hit_tokens under the hit_ratio you set. """ - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config, role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__(vllm_config, role, kv_cache_config) self._hit_ratio = float(self.launch_config["hit_ratio"]) logger.info(f"hit_ratio: {self._hit_ratio}") @@ -1111,9 +1504,16 @@ def get_num_new_matched_tokens( return expect_hit_block_num * self.block_size, False -class UCMConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) +class UCMConnector(KVConnectorBase_V1, SupportsHMA): + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__( + vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config + ) self.connector: KVConnectorBase_V1 ucm_config = Config(vllm_config.kv_transfer_config) self.launch_config = ucm_config.get_config() @@ -1130,7 +1530,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): "Pipeline parallelism is not supported in UCMDirectConnector, please set use_layerwise=True." ) if self.launch_config is not None and "hit_ratio" in self.launch_config: - self.connector = UCMMockConnector(vllm_config, role) + self.connector = UCMMockConnector(vllm_config, role, kv_cache_config) elif ( hasattr(self._vllm_config.parallel_config, "prefill_context_parallel_size") and hasattr( @@ -1140,11 +1540,11 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): * self._vllm_config.parallel_config.decode_context_parallel_size > 1 ): - self.connector = UCMCPConnector(vllm_config, role) + self.connector = UCMCPConnector(vllm_config, role, kv_cache_config) elif use_layerwise: - self.connector = UCMLayerWiseConnector(vllm_config, role) + self.connector = UCMLayerWiseConnector(vllm_config, role, kv_cache_config) else: - self.connector = UCMDirectConnector(vllm_config, role) + self.connector = UCMDirectConnector(vllm_config, role, kv_cache_config) def get_num_new_matched_tokens( self, @@ -1296,3 +1696,10 @@ def get_block_ids_with_load_errors(self) -> set[int]: Empty set if no load errors occurred. """ return self.connector.get_block_ids_with_load_errors() + + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], + ) -> tuple[bool, dict[str, Any] | None]: + return self.connector.request_finished_all_groups(request, block_ids)