Skip to content

Commit d83b854

Browse files
authored
[opt]Add event to sync (#768)
## Purpose This PR adds a config switch `enable_event_sync` to control whether dump tasks wait on a prerequisite compute event. ## Modifications Add `enable_event_sync` plumbing in `ucm_connector.py`. Pass the event handle to `dump_data` only when sync is enabled; otherwise pass 0. Add `CopyStream::WaitEvent(void* event)` to apply the wait on all internal copy streams. Update `DumpQueue` to issue the prerequisite wait once via `CopyStream`, ensuring all streams are covered. Make `Trans::Stream::WaitEvent` pure virtual and implement it in SimuStream (no-op). ## Test <img width="1309" height="759" alt="image" src="https://github.com/user-attachments/assets/f9fd0590-154f-4399-b64a-3592fde511df" /> <img width="1345" height="825" alt="image" src="https://github.com/user-attachments/assets/ae79846a-05e5-4c85-9d63-776027ebca61" /> <img width="933" height="356" alt="image" src="https://github.com/user-attachments/assets/65f356b6-367e-4454-81e7-3f2386a46c6c" />
1 parent 2e97b73 commit d83b854

18 files changed

Lines changed: 196 additions & 18 deletions

File tree

examples/ucm_config_example.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ ucm_connectors:
1414
storage_backends: "/mnt/test"
1515
io_direct: false
1616

17+
enable_event_sync: false
1718
# Enable UCM metrics so they can be monitored online via Grafana and Prometheus.
1819
# metrics_config_path: "/workspace/unified-cache-management/examples/metrics/metrics_configs.yaml"
1920

ucm/integration/vllm/device.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Event-based sync between Python compute stream and C++ cache stream.
4+
5+
When dump_data is called, the cache's C++ stream does D2H from device memory.
6+
We must ensure the Python compute stream has finished writing KVCache before
7+
the cache reads. Event sync: record event on compute stream, pass to C++,
8+
cache stream waits for event before D2H. This avoids blocking the CPU.
9+
"""
10+
from abc import ABC, abstractmethod
11+
from typing import Optional
12+
13+
import torch
14+
from vllm.platforms import current_platform
15+
16+
from ucm.logger import init_logger
17+
18+
logger = init_logger(__name__)
19+
20+
21+
class Device(ABC):
    """Abstract per-platform helper for event-based stream synchronization.

    Subclasses record an event on the platform's current compute stream and
    hand its raw handle to the C++ cache layer, which makes its copy stream
    wait on the event before reading device memory. Every event created is
    tracked in ``self.events`` so the handle stays alive until
    ``destroy_event_handles`` releases it.
    """

    def __init__(self):
        # Events recorded via get_event_handle; kept alive until destroyed.
        self.events = []

    @abstractmethod
    def get_event_handle(self) -> int:
        """Return an event handle for stream sync.

        A return value of 0 means no event could be produced and the caller
        should fall back to a blocking ``synchronize`` instead.
        """

    @abstractmethod
    def synchronize(self):
        """Block the host until the current compute stream has finished."""

    @abstractmethod
    def destroy_event_handles(self):
        """Release every event recorded so far."""
37+
38+
39+
class CudaDevice(Device):
    """CUDA implementation: records torch.cuda events on the compute stream."""

    def __init__(self):
        super().__init__()

    def get_event_handle(self) -> int:
        """Record an event on the current CUDA stream and return the raw
        ``cudaEvent_t`` as an int; 0 on any failure (caller should fall back
        to a blocking synchronize)."""
        try:
            cuda_event = torch.cuda.Event(enable_timing=False)
            cuda_event.record(torch.cuda.current_stream())
            handle = int(cuda_event.cuda_event)
            if not handle:
                return 0
            # Keep the Python event object alive until destroy_event_handles();
            # otherwise the underlying cudaEvent_t could be freed while the
            # C++ copy stream is still waiting on it.
            self.events.append(cuda_event)
            return handle
        except Exception as e:
            logger.error(f"get cuda event handle failed. {e}")
            return 0

    def synchronize(self):
        """Block the host until the current CUDA stream drains."""
        torch.cuda.current_stream().synchronize()

    def destroy_event_handles(self):
        """Drop tracked events; CUDA frees them when the objects are GC'd."""
        self.events.clear()
62+
63+
64+
class NpuDevice(Device):
    """Ascend NPU implementation backed by the ACL runtime event API."""

    def __init__(self):
        super().__init__()

    def get_event_handle(self) -> int:
        """Record an ACL event on the current NPU stream and return its
        handle as an int; 0 on any failure (caller should fall back to a
        blocking synchronize)."""
        import acl
        import torch_npu

        try:
            npu_stream = torch_npu.npu.current_stream().npu_stream
            event, ret = acl.rt.create_event()
            if ret != 0:
                logger.error(f"acl create_event failed: {ret}")
                return 0
            # Track the event immediately so destroy_event_handles() frees
            # it even when the record below fails.
            self.events.append(event)
            ret = acl.rt.record_event(event, npu_stream)
            if ret != 0:
                logger.error(f"acl record_event failed: {ret}")
                return 0
            handle = int(event)
            return handle if handle else 0
        except Exception as e:
            logger.error(f"get npu event handle failed. {e}")
            return 0

    def synchronize(self):
        """Block the host until the current NPU stream drains.

        NOTE(review): relies on torch_npu having patched ``torch.npu``; if
        synchronize() is ever called before get_event_handle(), confirm the
        torch_npu import has already happened elsewhere.
        """
        torch.npu.current_stream().synchronize()

    def destroy_event_handles(self):
        """Destroy every tracked ACL event (best effort, errors logged)."""
        import acl

        for event in self.events:
            try:
                acl.rt.destroy_event(event)
            except Exception as e:
                logger.error(f"destroy npu event failed. {e}")
                continue
        self.events.clear()
104+
105+
106+
def create_device() -> Optional[Device]:
    """Instantiate the platform-specific :class:`Device` helper.

    Returns ``CudaDevice`` on CUDA-like platforms, ``NpuDevice`` on Ascend
    NPU, and ``None`` when the current platform supports neither.
    """
    if current_platform.is_cuda_alike():
        return CudaDevice()
    return NpuDevice() if current_platform.device_type == "npu" else None

ucm/integration/vllm/ucm_connector.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from vllm.platforms import current_platform
2222
from vllm.v1.core.sched.output import SchedulerOutput
2323

24+
from ucm.integration.vllm.device import create_device
2425
from ucm.logger import init_logger
2526
from ucm.observability import PrometheusStatsLogger
2627
from ucm.shared.metrics import ucmmetrics
@@ -207,6 +208,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
207208
self.launch_config = ucm_config.get_config()
208209
logger.info(f"self.launch_config: {self.launch_config}")
209210
self.connector_configs = self.launch_config.get("ucm_connectors", [])
211+
self.enable_event_sync = self.launch_config.get("enable_event_sync", False)
210212
assert len(self.connector_configs) > 0, "no storage connector name in config."
211213

212214
self.chunk_size = self.block_size
@@ -236,12 +238,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
236238
f"metrics_config_path: {self.metrics_config}, set worker_id: {worker_id}"
237239
)
238240

239-
self.synchronize = lambda: (
240-
torch.cuda.current_stream().synchronize()
241-
if current_platform.is_cuda_alike()
242-
else torch.npu.current_stream().synchronize()
243-
)
244-
245241
# invalid block ids due to load errors
246242
self._invalid_block_ids: set[int] = set()
247243

@@ -319,6 +315,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
319315
}
320316

321317
self.store = self._create_store(self.kv_cache_layout)
318+
self.device = create_device()
319+
if self.device is None:
320+
raise RuntimeError(f"Unsupported device platform for UCMDirectConnector.")
322321

323322
def get_num_new_matched_tokens(
324323
self,
@@ -547,6 +546,16 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
547546
def wait_for_layer_load(self, layer_name: str) -> None:
548547
pass
549548

549+
def _get_dump_event_handle(self) -> int:
550+
if not self.enable_event_sync:
551+
self.device.synchronize()
552+
return 0
553+
554+
event_handle = self.device.get_event_handle()
555+
if event_handle == 0:
556+
self.device.synchronize()
557+
return event_handle
558+
550559
def save_kv_layer(
551560
self,
552561
layer_name: str,
@@ -588,10 +597,10 @@ def wait_for_save(self) -> None:
588597
total_ptrs = total_ptrs.reshape(total_ptrs.shape[0], -1)
589598
shard_indexs = [0] * len(total_ucm_block_ids)
590599
try:
591-
self.synchronize()
600+
event_handle = self._get_dump_event_handle()
592601
save_start_time = time.perf_counter() * 1000
593602
task = self.store.dump_data(
594-
total_ucm_block_ids, shard_indexs, total_ptrs
603+
total_ucm_block_ids, shard_indexs, total_ptrs, event_handle
595604
)
596605
dump_tasks.append(task)
597606
except RuntimeError as e:
@@ -726,9 +735,9 @@ def save_kv_layer(
726735
shard_indexs = [layer_id] * len(total_ucm_block_ids)
727736
try:
728737
layer_ptrs = np.ascontiguousarray(total_ptrs[:, layer_id, :])
729-
self.synchronize()
738+
event_handle = self._get_dump_event_handle()
730739
task = self.store.dump_data(
731-
total_ucm_block_ids, shard_indexs, layer_ptrs
740+
total_ucm_block_ids, shard_indexs, layer_ptrs, event_handle
732741
)
733742
self.dump_tasks[layer_name] = task
734743
except RuntimeError as e:
@@ -745,6 +754,8 @@ def wait_for_save(self) -> None:
745754
logger.error(f"wait for dump kv cache failed. {e}")
746755
self.dump_tasks.clear()
747756
self.is_save = False
757+
if self.enable_event_sync:
758+
self.device.destroy_event_handles()
748759

749760

750761
class UCMCPConnector(UCMLayerWiseConnector):

ucm/shared/trans/ascend/ascend_stream.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,4 +175,12 @@ Status AscendStream::Synchronized()
175175
return Status{ret, std::to_string(ret)};
176176
}
177177

178-
} // namespace UC::Trans
178+
// Make all subsequently-issued work on this stream wait until `event`
// (an aclrtEvent recorded on the compute stream) has completed.
// A null event means "nothing to wait for" and succeeds immediately.
Status AscendStream::WaitEvent(void* event)
{
    if (event == nullptr) { return Status::OK(); }
    const auto rc = aclrtStreamWaitEvent(stream_, static_cast<aclrtEvent>(event));
    return rc == ACL_SUCCESS ? Status::OK() : Status{rc, std::to_string(rc)};
}
185+
186+
} // namespace UC::Trans

ucm/shared/trans/ascend/ascend_stream.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,9 @@ class AscendStream : public Stream {
5757

5858
Status AppendCallback(std::function<void(bool)> cb) override;
5959
Status Synchronized() override;
60+
Status WaitEvent(void* event) override;
6061
};
6162

62-
} // namespace UC::Trans
63+
} // namespace UC::Trans
6364

6465
#endif

ucm/shared/trans/cuda/cuda_stream.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,4 +155,12 @@ Status Trans::CudaStream::Synchronized()
155155
return Status::OK();
156156
}
157157

158-
} // namespace UC::Trans
158+
// Make all subsequently-issued work on this stream wait until `event`
// (a cudaEvent_t recorded on the compute stream) has completed.
// A null event means "nothing to wait for" and succeeds immediately.
Status Trans::CudaStream::WaitEvent(void* event)
{
    if (event == nullptr) { return Status::OK(); }
    const auto rc = cudaStreamWaitEvent(stream_, static_cast<cudaEvent_t>(event), 0);
    return rc == cudaSuccess ? Status::OK() : Status{rc, cudaGetErrorString(rc)};
}
165+
166+
} // namespace UC::Trans

ucm/shared/trans/cuda/cuda_stream.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,9 @@ class CudaStream : public Stream {
5252

5353
Status AppendCallback(std::function<void(bool)> cb) override;
5454
Status Synchronized() override;
55+
Status WaitEvent(void* event) override;
5556
};
5657

57-
} // namespace UC::Trans
58+
} // namespace UC::Trans
5859

5960
#endif

ucm/shared/trans/simu/simu_stream.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,12 @@ Status SimuStream::AppendCallback(std::function<void(bool)> cb)
157157
return Status::OK();
158158
}
159159

160+
// The simulation stream has no real device queue, so there is never a
// prerequisite to wait on; the event is accepted and ignored.
Status SimuStream::WaitEvent(void* event)
{
    (void)event; // unused in simulation
    return Status::OK();
}
165+
160166
Status SimuStream::Synchronized()
161167
{
162168
std::mutex mutex;

ucm/shared/trans/simu/simu_stream.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class SimuStream : public Stream {
6262
Status HostToDeviceAsync(void* host, void* device[], size_t size, size_t number) override;
6363

6464
Status AppendCallback(std::function<void(bool)> cb) override;
65+
Status WaitEvent(void* event) override;
6566
Status Synchronized() override;
6667
};
6768

ucm/shared/trans/stream.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,9 @@ class Stream {
5050

5151
virtual Status AppendCallback(std::function<void(bool)> cb) = 0;
5252
virtual Status Synchronized() = 0;
53+
virtual Status WaitEvent(void* event) = 0;
5354
};
5455

55-
} // namespace UC::Trans
56+
} // namespace UC::Trans
5657

5758
#endif

0 commit comments

Comments
 (0)