Skip to content

Commit a1cd30f

Browse files
committed
revert layerwise
1 parent 02f81ef commit a1cd30f

1 file changed

Lines changed: 37 additions & 17 deletions

File tree

ucm/integration/vllm/ucm_connector.py

Lines changed: 37 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -722,6 +722,8 @@ def _track_async_dump_requests(
722722
self,
723723
requests_dispatch_meta: dict[str, RequestDispatchMeta],
724724
) -> None:
725+
if self.use_layerwise:
726+
return
725727
self._async_dump_req_ids.update(
726728
request_id
727729
for request_id, dispatch_meta in requests_dispatch_meta.items()
@@ -858,6 +860,8 @@ def _flush_pending_dump_tasks(self, request_ids: Optional[set[str]] = None) -> N
858860
self._pending_dump_tasks = remaining_tasks
859861

860862
def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata):
863+
if self.use_layerwise:
864+
return
861865
preempted_req_ids = getattr(kv_connector_metadata, "preempted_req_ids", None)
862866
if preempted_req_ids:
863867
self._flush_pending_dump_tasks(preempted_req_ids)
@@ -948,6 +952,8 @@ def request_finished(
948952
request: "Request",
949953
block_ids: list[int],
950954
) -> tuple[bool, dict[str, Any] | None]:
955+
if self.use_layerwise:
956+
return False, None
951957
if request.request_id in self._async_dump_req_ids:
952958
self._async_dump_req_ids.discard(request.request_id)
953959
return True, None
@@ -966,6 +972,8 @@ def get_finished(
966972
self,
967973
finished_req_ids: set[str],
968974
) -> tuple[Optional[set[str]], Optional[set[str]]]:
975+
if self.use_layerwise:
976+
return None, None
969977
async_finished_req_ids = finished_req_ids & self._async_dump_req_ids
970978

971979
if async_finished_req_ids:
@@ -1160,46 +1168,58 @@ def save_kv_layer(
11601168
total_vllm_block_ids.extend(vllm_block_ids)
11611169

11621170
if dump_request_ids:
1163-
self._async_dump_req_ids.update(dump_request_ids)
11641171
if self.dump_total_ptrs is None:
11651172
self.dump_total_ptrs = self.kv_cache_layout.extract_block_addrs(
11661173
total_vllm_block_ids, layer_first=True
11671174
)
11681175
shard_indexs = [layer_id] * len(total_ucm_block_ids)
1169-
event_handle = 0
11701176
try:
11711177
layer_ptrs = np.ascontiguousarray(self.dump_total_ptrs[local_layer_id])
11721178
event_handle = self._get_dump_event_handle()
11731179
task = self.store.dump_data(
11741180
total_ucm_block_ids, shard_indexs, layer_ptrs, event_handle
11751181
)
1176-
pending_dump_task = PendingDumpTask(
1177-
task=task,
1178-
request_ids=set(dump_request_ids),
1179-
event_handle=event_handle,
1180-
)
1181-
self._pending_dump_tasks.append(pending_dump_task)
1182+
self.dump_tasks[layer_name] = task
11821183
except Exception as e:
11831184
logger.error(f"submit dump task failed. {type(e).__name__}: {e}")
1184-
if self.enable_event_sync and event_handle and self.device is not None:
1185-
self.device.destroy_event_handle(event_handle)
11861185
if self.is_save:
11871186
submit_end = time.perf_counter()
11881187
ucmmetrics.update_stats(
11891188
{"layerwise_save_submit_ms": (submit_end - submit_start) * 1000}
11901189
)
11911190

11921191
def wait_for_save(self) -> None:
1193-
if self._connector_metadata:
1194-
metadata = self._get_connector_metadata()
1195-
self._async_dump_req_ids.update(
1196-
request_id
1197-
for request_id, request in metadata.request_meta.items()
1198-
if len(request.dump_block_ids[0]) > 0
1199-
)
1192+
if not self.is_save:
1193+
total_end = time.perf_counter()
1194+
if self._layerwise_batch_start is not None:
1195+
batch_total_ms = (total_end - self._layerwise_batch_start) * 1000
1196+
ucmmetrics.update_stats({"layerwise_batch_total_ms": batch_total_ms})
1197+
self._layerwise_batch_start = None
1198+
return
1199+
1200+
total_start = time.perf_counter()
1201+
try:
1202+
for layer_name in self.kv_caches:
1203+
if layer_name not in self.dump_tasks:
1204+
continue
1205+
self.store.wait(self.dump_tasks[layer_name])
1206+
except Exception as e:
1207+
logger.error(f"wait for dump kv cache failed. {type(e).__name__}: {e}")
1208+
1209+
total_end = time.perf_counter()
1210+
stats = {"layerwise_save_tail_total_ms": (total_end - total_start) * 1000}
1211+
if self._layerwise_batch_start is not None:
1212+
stats["layerwise_batch_total_ms"] = (
1213+
total_end - self._layerwise_batch_start
1214+
) * 1000
1215+
self._layerwise_batch_start = None
1216+
ucmmetrics.update_stats(stats)
1217+
12001218
self.dump_tasks.clear()
12011219
self.is_save = False
12021220
self.dump_total_ptrs = None
1221+
if self.enable_event_sync:
1222+
self.device.destroy_event_handles()
12031223

12041224

12051225
class UCMCPConnector(UCMLayerWiseConnector):

0 commit comments

Comments
 (0)