-
Notifications
You must be signed in to change notification settings - Fork 88
[Feat] Add opt-in reshape cache event sync #1025
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,7 +13,7 @@ | |
| from abc import ABC, abstractmethod | ||
| from dataclasses import dataclass | ||
| from itertools import accumulate | ||
| from typing import Dict, List, Optional, Tuple, Union | ||
| from typing import Any, Dict, List, Optional, Tuple, Union | ||
|
|
||
| import torch | ||
| from vllm.platforms import current_platform | ||
|
|
@@ -26,12 +26,16 @@ | |
| class Device(ABC): | ||
| def __init__(self): | ||
| self.events = {} | ||
| self.borrowed_events = [] | ||
|
|
||
| @abstractmethod | ||
| def get_event_handle(self) -> int: | ||
| """Return event handle for stream sync. 0 means no event (use synchronize instead).""" | ||
| pass | ||
|
|
||
| def get_event_handle_from_event(self, event: Any) -> int: | ||
| return 0 | ||
|
|
||
| @abstractmethod | ||
| def synchronize(self): | ||
| pass | ||
|
|
@@ -114,11 +118,23 @@ def get_event_handle(self) -> int: | |
| logger.error(f"get cuda event handle failed. {e}") | ||
| return 0 | ||
|
|
||
| def get_event_handle_from_event(self, event: Any) -> int: | ||
| try: | ||
| handle = getattr(event, "cuda_event", None) | ||
| if handle is None or int(handle) == 0: | ||
| return 0 | ||
| self.borrowed_events.append(event) | ||
| return int(handle) | ||
| except Exception as e: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Using a broad `except Exception` here can hide unrelated errors. Recommended specific exceptions: `except (AttributeError, TypeError, ValueError) as e:` This applies to both CudaDevice (line 128) and NpuDevice (line 279) implementations. |
||
| logger.error(f"get cuda event handle from existing event failed. {e}") | ||
| return 0 | ||
|
|
||
| def synchronize(self): | ||
| torch.cuda.current_stream().synchronize() | ||
|
|
||
| def destroy_event_handles(self): | ||
| self.events.clear() | ||
| self.borrowed_events.clear() | ||
|
|
||
| def destroy_event_handle(self, handle: int): | ||
|
Contributor
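There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The `destroy_event_handle` method only removes the handle from `self.events`; it does not remove anything from `borrowed_events`. If individual event handles are destroyed (rather than batch clearing via `destroy_event_handles`), the borrowed event objects stay referenced in `borrowed_events` until the next batch clear. Consider either: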
|
||
| self.events.pop(handle, None) | ||
|
|
@@ -248,6 +264,22 @@ def get_event_handle(self) -> int: | |
| logger.error(f"get npu event handle failed. {e}") | ||
| return 0 | ||
|
|
||
| def get_event_handle_from_event(self, event: Any) -> int: | ||
| try: | ||
| handle = getattr(event, "npu_event", None) | ||
| if handle is None: | ||
| parameter = getattr(event, "_as_parameter_", None) | ||
| if callable(parameter): | ||
| parameter = parameter() | ||
| handle = getattr(parameter, "value", None) | ||
| if handle is None or int(handle) == 0: | ||
| return 0 | ||
| self.borrowed_events.append(event) | ||
| return int(handle) | ||
| except Exception as e: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same issue as CudaDevice - using a broad `except Exception`. |
||
| logger.error(f"get npu event handle from existing event failed. {e}") | ||
| return 0 | ||
|
|
||
| def synchronize(self): | ||
| torch.npu.current_stream().synchronize() | ||
|
|
||
|
|
@@ -260,6 +292,7 @@ def destroy_event_handles(self): | |
| except Exception as e: | ||
| logger.error(f"destroy npu event failed. {e}") | ||
| self.events.clear() | ||
| self.borrowed_events.clear() | ||
|
|
||
| def destroy_event_handle(self, handle: int): | ||
| import acl | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -323,6 +323,7 @@ def __init__( | |
| self.tp_rank = self._vllm_config.parallel_config.rank | ||
| self.block_size = self._vllm_config.cache_config.block_size | ||
| self.is_mla = self._vllm_config.model_config.is_deepseek_mla | ||
| self.use_mla = getattr(self._vllm_config.model_config, "use_mla", self.is_mla) | ||
|
Contributor
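There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 💡 Suggestion - Unused variable. The variable `self.use_mla` is assigned here but is not read anywhere else in this diff. Note: `self.is_mla`, set on the previous line, is also used as the `getattr` fallback. |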
||
| self.num_layers = self._vllm_config.model_config.get_num_layers( | ||
| self._vllm_config.parallel_config | ||
| ) | ||
|
|
@@ -363,6 +364,9 @@ def __init__( | |
| self.launch_config = ucm_config.get_config() | ||
| self.connector_configs = self.launch_config.get("ucm_connectors", []) | ||
| self.enable_event_sync = self.launch_config.get("enable_event_sync", True) | ||
| self.enable_reshape_cache_event_sync = self.launch_config.get( | ||
| "enable_reshape_cache_event_sync", False | ||
| ) | ||
| self.enable_record_traces = self.launch_config.get( | ||
| "enable_record_traces", False | ||
| ) | ||
|
|
@@ -826,14 +830,61 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: | |
| def wait_for_layer_load(self, layer_name: str) -> None: | ||
| pass | ||
|
|
||
| def _get_dump_event_handle(self) -> int: | ||
| def _get_reshape_cache_event( | ||
| self, layer_name: Optional[str], attn_metadata: Optional["AttentionMetadata"] | ||
| ) -> tuple[Optional[Any], str]: | ||
| if attn_metadata is None: | ||
| return None, "none" | ||
|
|
||
| if layer_name: | ||
| try: | ||
| layer_metadata = attn_metadata[layer_name] | ||
| event = getattr(layer_metadata, "reshape_cache_event", None) | ||
| if event is not None: | ||
| return event, "layer" | ||
| except (KeyError, TypeError, AttributeError): | ||
| pass | ||
|
|
||
| event = getattr(attn_metadata, "reshape_cache_event", None) | ||
| if event is not None: | ||
| return event, "direct" | ||
| return None, "direct_missing" | ||
|
Comment on lines
+833
to
+851
Contributor
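There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. When `_get_reshape_cache_event` finds no event (it returns `None` with source `"none"` or `"direct_missing"`), no metric is updated. This creates an observability gap - operators cannot distinguish between: reshape cache event sync being disabled, and reshape cache event sync being enabled but no event being available.
Consider adding metrics like: a counter for the `"none"` and `"direct_missing"` outcomes.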
|
||
|
|
||
| def _get_dump_event_handle( | ||
| self, | ||
| layer_name: Optional[str] = None, | ||
| attn_metadata: Optional["AttentionMetadata"] = None, | ||
| ) -> int: | ||
| if not self.enable_event_sync: | ||
| self.device.synchronize() | ||
| return 0 | ||
|
|
||
| if self.enable_reshape_cache_event_sync: | ||
| reshape_cache_event, event_source = self._get_reshape_cache_event( | ||
| layer_name, attn_metadata | ||
| ) | ||
| event_handle = ( | ||
| self.device.get_event_handle_from_event(reshape_cache_event) | ||
| if reshape_cache_event is not None | ||
| else 0 | ||
| ) | ||
| if event_handle != 0: | ||
| if event_source == "direct": | ||
| ucmmetrics.update_stats( | ||
| "dump_event_reshape_cache_direct_used_total", 1.0 | ||
| ) | ||
| elif event_source == "layer": | ||
| ucmmetrics.update_stats( | ||
| "dump_event_reshape_cache_layer_used_total", 1.0 | ||
| ) | ||
| return event_handle | ||
|
Comment on lines
+862
to
+880
Contributor
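There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. When `get_event_handle_from_event` returns 0 for an existing reshape cache event, the code falls through to `get_event_handle` and records `dump_event_current_stream_used_total` or `dump_event_sync_fallback_used_total`. The metric doesn't reflect that we attempted reshape cache sync first but failed at handle extraction. Consider: recording a separate counter for this fall-through case.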
|
||
|
|
||
| event_handle = self.device.get_event_handle() | ||
| if event_handle == 0: | ||
| self.device.synchronize() | ||
| ucmmetrics.update_stats("dump_event_sync_fallback_used_total", 1.0) | ||
| else: | ||
| ucmmetrics.update_stats("dump_event_current_stream_used_total", 1.0) | ||
| return event_handle | ||
|
|
||
| def save_kv_layer( | ||
|
|
@@ -1140,7 +1191,7 @@ def save_kv_layer( | |
| shard_indexs = [layer_id] * len(total_ucm_block_ids) | ||
| try: | ||
| layer_ptrs = np.ascontiguousarray(self.dump_total_ptrs[local_layer_id]) | ||
| event_handle = self._get_dump_event_handle() | ||
| event_handle = self._get_dump_event_handle(layer_name, attn_metadata) | ||
| task = self.store.dump_data( | ||
| total_ucm_block_ids, shard_indexs, layer_ptrs, event_handle | ||
| ) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Suggestion - Thread safety consideration
The `borrowed_events` list is accessed without synchronization. If this device is used in a multi-threaded context (e.g., multiple workers sharing the same device instance), there could be race conditions when: multiple threads call `get_event_handle_from_event` concurrently (append). Consider adding: a lock around `borrowed_events` operations.