@@ -177,7 +177,7 @@ def dispatch_task_loop(self):
177177 self .waiting_dict [task .get_key ()] = task
178178 else :
179179 task .start_trans_time = time .time ()
180- self .success_queue .put ((None , task ))
180+ self .success_queue .put ((None , None , task ))
181181
182182 # up status
183183 task = trans_task_group .task_list [0 ]
@@ -335,7 +335,10 @@ def read_page_to_mems_loop(self):
335335 while True :
336336 trans_task : NIXLChunckedTransTask = self .ready_page_task_queue .get ()
337337 # 将数据写回 mem manger
338+ copy_start_event = torch .cuda .Event (enable_timing = True )
339+ copy_end_event = torch .cuda .Event (enable_timing = True )
338340 with torch .cuda .stream (stream = self .copy_cuda_stream ):
341+ copy_start_event .record (self .copy_cuda_stream )
339342 cur_mem = self .mem_managers [self .device_id ]
340343 cur_mem .read_page_kv_move_buffer_to_mem (
341344 mem_indexes = trans_task .mem_indexes ,
@@ -344,22 +347,21 @@ def read_page_to_mems_loop(self):
344347 mem_managers = self .mem_managers ,
345348 dp_world_size = self .dp_world_size ,
346349 )
347- sync_event = torch .cuda .Event ()
348- sync_event .record ()
350+ copy_end_event .record (self .copy_cuda_stream )
349351
350- self .success_queue .put ((sync_event , trans_task ))
352+ self .success_queue .put ((copy_end_event , copy_start_event , trans_task ))
351353 return
352354
353355 @log_exception
354356 def success_loop (self ):
355357 torch .cuda .set_device (self .device_id )
356358 while True :
357- sync_event , trans_task = self .success_queue .get ()
359+ copy_end_event , copy_start_event , trans_task = self .success_queue .get ()
358360 trans_task : NIXLChunckedTransTask = trans_task
359- sync_event : Optional [ torch . cuda . Event ] = sync_event
360- # 兼容传输kv 数量为0的时候, sync_event 为 None的情况。
361- if sync_event is not None :
362- sync_event . synchronize ( )
361+ read_page_gpu_time_ms = - 1.0
362+ if copy_end_event is not None :
363+ copy_end_event . synchronize ()
364+ read_page_gpu_time_ms = copy_start_event . elapsed_time ( copy_end_event )
363365
364366 if trans_task .nixl_dst_page_index is not None :
365367 self .page_index_queue .put (trans_task .nixl_dst_page_index )
@@ -369,7 +371,13 @@ def success_loop(self):
369371
370372 ret = trans_task .createRetObj ()
371373 self .task_out_queue .put (ret )
372- logger .info (f"trans task ret success:{ ret } cost time: { trans_task .transfer_time ()} s" )
374+ if read_page_gpu_time_ms >= 0 :
375+ logger .info (
376+ f"trans task ret success:{ ret } cost time: { trans_task .transfer_time ()} s "
377+ f"read_page_gpu_time: { read_page_gpu_time_ms :.3f} ms"
378+ )
379+ else :
380+ logger .info (f"trans task ret success:{ ret } cost time: { trans_task .transfer_time ()} s" )
373381
374382 @log_exception
375383 def fail_loop (self ):
0 commit comments