diff --git a/docs/source/user-guide/metrics/metrics.md b/docs/source/user-guide/metrics/metrics.md index 3cf15ee4e..9409b754d 100644 --- a/docs/source/user-guide/metrics/metrics.md +++ b/docs/source/user-guide/metrics/metrics.md @@ -223,6 +223,10 @@ The default metrics configuration contains the following UCM metrics. | `ucm:posix_h2s_bytes_total` | Total bytes transferred from host buffer to Posix storage. | | `ucm:load_bytes_total` | Total bytes loaded through the UCM connector. | | `ucm:save_bytes_total` | Total bytes saved through the UCM connector. | +| `ucm:dump_event_reshape_cache_direct_used_total` | Dump submissions synchronized with `attn_metadata.reshape_cache_event`. | +| `ucm:dump_event_reshape_cache_layer_used_total` | Dump submissions synchronized with `attn_metadata[layer_name].reshape_cache_event`. | +| `ucm:dump_event_current_stream_used_total` | Dump submissions synchronized with a UCM-recorded current stream event fallback. | +| `ucm:dump_event_sync_fallback_used_total` | Dump submissions that fell back to device synchronization because no event handle was available. | ### Gauges diff --git a/docs/source/user-guide/prefix-cache/pipeline_store.md b/docs/source/user-guide/prefix-cache/pipeline_store.md index 202a5f31c..12b7d046f 100644 --- a/docs/source/user-guide/prefix-cache/pipeline_store.md +++ b/docs/source/user-guide/prefix-cache/pipeline_store.md @@ -216,6 +216,7 @@ ucm_connectors: use_gdr: false enable_event_sync: true use_layerwise: true +enable_reshape_cache_event_sync: false enable_record_traces: false use_lite: false persist_token_threshold: 0 @@ -320,4 +321,4 @@ This log indicates that the **Posix Store** has received a **load or dump task** ```text [UC][D] Posix task({task_id},{operation},{subtask_number},{size}) finished, cost {time}ms. [PID,TID] ``` -This log indicates that a load or dump task in the **Posix Store** has completed, along with its execution time in **in ms**. 
\ No newline at end of file +This log indicates that a load or dump task in the **Posix Store** has completed, along with its execution time in **ms**. diff --git a/examples/metrics/metrics_configs.yaml b/examples/metrics/metrics_configs.yaml index cf5e7a68e..d5dcb36dd 100644 --- a/examples/metrics/metrics_configs.yaml +++ b/examples/metrics/metrics_configs.yaml @@ -63,6 +63,14 @@ counter: documentation: "Total bytes loaded through the UCM connector (summed across all start_load_kv calls)" - name: "save_bytes_total" documentation: "Total bytes saved through the UCM connector (summed across all wait_for_save calls)" + - name: "dump_event_reshape_cache_direct_used_total" + documentation: "Number of dump submissions synchronized with attn_metadata.reshape_cache_event" + - name: "dump_event_reshape_cache_layer_used_total" + documentation: "Number of dump submissions synchronized with attn_metadata[layer_name].reshape_cache_event" + - name: "dump_event_current_stream_used_total" + documentation: "Number of dump submissions synchronized with a UCM-recorded current stream event fallback" + - name: "dump_event_sync_fallback_used_total" + documentation: "Number of dump submissions that fell back to device synchronization because no event handle was available" # Gauge metrics configuration gauge: diff --git a/examples/ucm_config_example.yaml b/examples/ucm_config_example.yaml index e9e1a6cbb..49ba4a36e 100644 --- a/examples/ucm_config_example.yaml +++ b/examples/ucm_config_example.yaml @@ -26,6 +26,9 @@ ucm_connectors: # When you use UcmNfsStore, you should set enable_event_sync to false. enable_event_sync: true +# Use vLLM-Ascend reshape cache events to start D2H immediately after KV cache is ready. +# Enable for better dump performance. +enable_reshape_cache_event_sync: false # Enable UCM metrics so they can be monitored online via Grafana and Prometheus. 
# metrics_config_path: "/workspace/unified-cache-management/examples/metrics/metrics_configs.yaml" diff --git a/ucm/integration/vllm/device.py b/ucm/integration/vllm/device.py index 1a29328c5..4c3980245 100644 --- a/ucm/integration/vllm/device.py +++ b/ucm/integration/vllm/device.py @@ -13,7 +13,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from itertools import accumulate -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from vllm.platforms import current_platform @@ -26,12 +26,16 @@ class Device(ABC): def __init__(self): self.events = {} + self.borrowed_events = [] @abstractmethod def get_event_handle(self) -> int: """Return event handle for stream sync. 0 means no event (use synchronize instead).""" pass + def get_event_handle_from_event(self, event: Any) -> int: + return 0 + @abstractmethod def synchronize(self): pass @@ -114,11 +118,23 @@ def get_event_handle(self) -> int: logger.error(f"get cuda event handle failed. {e}") return 0 + def get_event_handle_from_event(self, event: Any) -> int: + try: + handle = getattr(event, "cuda_event", None) + if handle is None or int(handle) == 0: + return 0 + self.borrowed_events.append(event) + return int(handle) + except Exception as e: + logger.error(f"get cuda event handle from existing event failed. {e}") + return 0 + def synchronize(self): torch.cuda.current_stream().synchronize() def destroy_event_handles(self): self.events.clear() + self.borrowed_events.clear() def destroy_event_handle(self, handle: int): self.events.pop(handle, None) @@ -248,6 +264,22 @@ def get_event_handle(self) -> int: logger.error(f"get npu event handle failed. 
{e}") return 0 + def get_event_handle_from_event(self, event: Any) -> int: + try: + handle = getattr(event, "npu_event", None) + if handle is None: + parameter = getattr(event, "_as_parameter_", None) + if callable(parameter): + parameter = parameter() + handle = getattr(parameter, "value", None) + if handle is None or int(handle) == 0: + return 0 + self.borrowed_events.append(event) + return int(handle) + except Exception as e: + logger.error(f"get npu event handle from existing event failed. {e}") + return 0 + def synchronize(self): torch.npu.current_stream().synchronize() @@ -260,6 +292,7 @@ def destroy_event_handles(self): except Exception as e: logger.error(f"destroy npu event failed. {e}") self.events.clear() + self.borrowed_events.clear() def destroy_event_handle(self, handle: int): import acl diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index e7868084d..9d9affdd8 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -323,6 +323,7 @@ def __init__( self.tp_rank = self._vllm_config.parallel_config.rank self.block_size = self._vllm_config.cache_config.block_size self.is_mla = self._vllm_config.model_config.is_deepseek_mla + self.use_mla = getattr(self._vllm_config.model_config, "use_mla", self.is_mla) self.num_layers = self._vllm_config.model_config.get_num_layers( self._vllm_config.parallel_config ) @@ -363,6 +364,9 @@ def __init__( self.launch_config = ucm_config.get_config() self.connector_configs = self.launch_config.get("ucm_connectors", []) self.enable_event_sync = self.launch_config.get("enable_event_sync", True) + self.enable_reshape_cache_event_sync = self.launch_config.get( + "enable_reshape_cache_event_sync", False + ) self.enable_record_traces = self.launch_config.get( "enable_record_traces", False ) @@ -826,14 +830,61 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: def wait_for_layer_load(self, layer_name: str) -> None: pass - 
def _get_dump_event_handle(self) -> int: + def _get_reshape_cache_event( + self, layer_name: Optional[str], attn_metadata: Optional["AttentionMetadata"] + ) -> tuple[Optional[Any], str]: + if attn_metadata is None: + return None, "none" + + if layer_name: + try: + layer_metadata = attn_metadata[layer_name] + event = getattr(layer_metadata, "reshape_cache_event", None) + if event is not None: + return event, "layer" + except (KeyError, TypeError, AttributeError): + pass + + event = getattr(attn_metadata, "reshape_cache_event", None) + if event is not None: + return event, "direct" + return None, "direct_missing" + + def _get_dump_event_handle( + self, + layer_name: Optional[str] = None, + attn_metadata: Optional["AttentionMetadata"] = None, + ) -> int: if not self.enable_event_sync: self.device.synchronize() return 0 + if self.enable_reshape_cache_event_sync: + reshape_cache_event, event_source = self._get_reshape_cache_event( + layer_name, attn_metadata + ) + event_handle = ( + self.device.get_event_handle_from_event(reshape_cache_event) + if reshape_cache_event is not None + else 0 + ) + if event_handle != 0: + if event_source == "direct": + ucmmetrics.update_stats( + "dump_event_reshape_cache_direct_used_total", 1.0 + ) + elif event_source == "layer": + ucmmetrics.update_stats( + "dump_event_reshape_cache_layer_used_total", 1.0 + ) + return event_handle + event_handle = self.device.get_event_handle() if event_handle == 0: self.device.synchronize() + ucmmetrics.update_stats("dump_event_sync_fallback_used_total", 1.0) + else: + ucmmetrics.update_stats("dump_event_current_stream_used_total", 1.0) return event_handle def save_kv_layer( @@ -1140,7 +1191,7 @@ def save_kv_layer( shard_indexs = [layer_id] * len(total_ucm_block_ids) try: layer_ptrs = np.ascontiguousarray(self.dump_total_ptrs[local_layer_id]) - event_handle = self._get_dump_event_handle() + event_handle = self._get_dump_event_handle(layer_name, attn_metadata) task = self.store.dump_data( 
total_ucm_block_ids, shard_indexs, layer_ptrs, event_handle )