@@ -722,6 +722,8 @@ def _track_async_dump_requests(
722722 self ,
723723 requests_dispatch_meta : dict [str , RequestDispatchMeta ],
724724 ) -> None :
725+ if self .use_layerwise :
726+ return
725727 self ._async_dump_req_ids .update (
726728 request_id
727729 for request_id , dispatch_meta in requests_dispatch_meta .items ()
@@ -858,6 +860,8 @@ def _flush_pending_dump_tasks(self, request_ids: Optional[set[str]] = None) -> N
858860 self ._pending_dump_tasks = remaining_tasks
859861
860862 def handle_preemptions (self , kv_connector_metadata : KVConnectorMetadata ):
863+ if self .use_layerwise :
864+ return
861865 preempted_req_ids = getattr (kv_connector_metadata , "preempted_req_ids" , None )
862866 if preempted_req_ids :
863867 self ._flush_pending_dump_tasks (preempted_req_ids )
@@ -948,6 +952,8 @@ def request_finished(
948952 request : "Request" ,
949953 block_ids : list [int ],
950954 ) -> tuple [bool , dict [str , Any ] | None ]:
955+ if self .use_layerwise :
956+ return False , None
951957 if request .request_id in self ._async_dump_req_ids :
952958 self ._async_dump_req_ids .discard (request .request_id )
953959 return True , None
@@ -966,6 +972,8 @@ def get_finished(
966972 self ,
967973 finished_req_ids : set [str ],
968974 ) -> tuple [Optional [set [str ]], Optional [set [str ]]]:
975+ if self .use_layerwise :
976+ return None , None
969977 async_finished_req_ids = finished_req_ids & self ._async_dump_req_ids
970978
971979 if async_finished_req_ids :
@@ -1160,46 +1168,58 @@ def save_kv_layer(
11601168 total_vllm_block_ids .extend (vllm_block_ids )
11611169
11621170 if dump_request_ids :
1163- self ._async_dump_req_ids .update (dump_request_ids )
11641171 if self .dump_total_ptrs is None :
11651172 self .dump_total_ptrs = self .kv_cache_layout .extract_block_addrs (
11661173 total_vllm_block_ids , layer_first = True
11671174 )
11681175 shard_indexs = [layer_id ] * len (total_ucm_block_ids )
1169- event_handle = 0
11701176 try :
11711177 layer_ptrs = np .ascontiguousarray (self .dump_total_ptrs [local_layer_id ])
11721178 event_handle = self ._get_dump_event_handle ()
11731179 task = self .store .dump_data (
11741180 total_ucm_block_ids , shard_indexs , layer_ptrs , event_handle
11751181 )
1176- pending_dump_task = PendingDumpTask (
1177- task = task ,
1178- request_ids = set (dump_request_ids ),
1179- event_handle = event_handle ,
1180- )
1181- self ._pending_dump_tasks .append (pending_dump_task )
1182+ self .dump_tasks [layer_name ] = task
11821183 except Exception as e :
11831184 logger .error (f"submit dump task failed. { type (e ).__name__ } : { e } " )
1184- if self .enable_event_sync and event_handle and self .device is not None :
1185- self .device .destroy_event_handle (event_handle )
11861185 if self .is_save :
11871186 submit_end = time .perf_counter ()
11881187 ucmmetrics .update_stats (
11891188 {"layerwise_save_submit_ms" : (submit_end - submit_start ) * 1000 }
11901189 )
11911190
11921191 def wait_for_save (self ) -> None :
1193- if self ._connector_metadata :
1194- metadata = self ._get_connector_metadata ()
1195- self ._async_dump_req_ids .update (
1196- request_id
1197- for request_id , request in metadata .request_meta .items ()
1198- if len (request .dump_block_ids [0 ]) > 0
1199- )
1192+ if not self .is_save :
1193+ total_end = time .perf_counter ()
1194+ if self ._layerwise_batch_start is not None :
1195+ batch_total_ms = (total_end - self ._layerwise_batch_start ) * 1000
1196+ ucmmetrics .update_stats ({"layerwise_batch_total_ms" : batch_total_ms })
1197+ self ._layerwise_batch_start = None
1198+ return
1199+
1200+ total_start = time .perf_counter ()
1201+ try :
1202+ for layer_name in self .kv_caches :
1203+ if layer_name not in self .dump_tasks :
1204+ continue
1205+ self .store .wait (self .dump_tasks [layer_name ])
1206+ except Exception as e :
1207+ logger .error (f"wait for dump kv cache failed. { type (e ).__name__ } : { e } " )
1208+
1209+ total_end = time .perf_counter ()
1210+ stats = {"layerwise_save_tail_total_ms" : (total_end - total_start ) * 1000 }
1211+ if self ._layerwise_batch_start is not None :
1212+ stats ["layerwise_batch_total_ms" ] = (
1213+ total_end - self ._layerwise_batch_start
1214+ ) * 1000
1215+ self ._layerwise_batch_start = None
1216+ ucmmetrics .update_stats (stats )
1217+
12001218 self .dump_tasks .clear ()
12011219 self .is_save = False
12021220 self .dump_total_ptrs = None
1221+ if self .enable_event_sync :
1222+ self .device .destroy_event_handles ()
12031223
12041224
12051225class UCMCPConnector (UCMLayerWiseConnector ):
0 commit comments