@@ -312,6 +312,18 @@ def _broadcast_model_weights_signal(self, src: int, group) -> int:
312312 value = model_weights_signal_tensor .numpy ()[0 ]
313313 return int (value )
314314
315+ def _get_exist_task_flag (self ) -> bool :
316+ if self .nnode > 1 :
317+ return self .task_queue .read_finish_flag .get () == 1
318+ else :
319+ return self .exist_task_signal .value [0 ] == ExistTaskStatus .EXIST
320+
321+ def _update_exist_task_flag (self , flag : bool ) -> None :
322+ if self .nnode > 1 :
323+ self .task_queue .read_finish_flag .set (1 if flag else 0 )
324+ else :
325+ self .exist_task_signal .value [0 ] = ExistTaskStatus .EXIST if flag else ExistTaskStatus .EMPTY
326+
315327 def _tp_barrier_wait (self ):
316328 if current_platform .is_xpu () or self .enable_overlap_schedule :
317329 self .task_queue .worker_process_tp_barrier .wait ()
@@ -436,7 +448,7 @@ def event_loop_normal(self) -> None:
436448 self ._init_eplb_signal ()
437449 tp_size = self .parallel_config .tensor_parallel_size
438450 # Currently, only support single node
439- self .nnode = (tp_size + self .max_chips_per_node ) // self .max_chips_per_node
451+ self .nnode = (tp_size + self .max_chips_per_node - 1 ) // self .max_chips_per_node
440452 max_occupied_batch_index = 0
441453 tp_rank = self .local_rank % tp_size
442454
@@ -454,16 +466,17 @@ def event_loop_normal(self) -> None:
454466 req_dicts = None
455467 self .worker_healthy_live_signal .value [tp_rank % self .max_chips_per_node ] = int (time .time ())
456468
469+ self ._tp_barrier_wait () if tp_size > 1 else None
470+
457471 # The first worker detects whether there are tasks in the task queue
458472 if tp_rank == 0 :
459473 if self .task_queue .exist_tasks ():
460474 if envs .ENABLE_V1_KVCACHE_SCHEDULER or not (
461475 self .fd_config .enable_mm_runtime and self .worker .exist_prefill ()
462476 ):
463- if self .nnode > 1 :
464- self .task_queue .read_finish_flag .set (1 )
465- else :
466- self .exist_task_signal .value [0 ] = ExistTaskStatus .EXIST
477+ self ._update_exist_task_flag (True )
478+ else :
479+ self ._update_exist_task_flag (False )
467480
468481 # Synchronize the signal set by tp_rank0 visiable to other workers
469482 self ._tp_barrier_wait () if tp_size > 1 else None
@@ -521,17 +534,14 @@ def event_loop_normal(self) -> None:
521534 ) # 所有 Rank 已同步唤醒,启动权重更新流程
522535 continue
523536
524- if self .exist_task_signal . value [ 0 ] == ExistTaskStatus . EXIST or self . task_queue . read_finish_flag . get () == 1 :
525- logger .info (f"Rank: { self .local_rank } Detected new requests." )
537+ if self ._get_exist_task_flag () :
538+ logger .debug (f"Rank: { self .local_rank } Detected new requests." )
526539
527540 tasks , read_finish = self .task_queue .get_tasks ()
528541 # Only one of all tp_size client will get read_finish == True.
529542 if read_finish :
530- # Reset the two signal.
531- if self .nnode > 1 :
532- self .task_queue .read_finish_flag .set (0 )
533- else :
534- self .exist_task_signal .value [0 ] = ExistTaskStatus .EMPTY
543+ self ._update_exist_task_flag (False )
544+ self ._tp_barrier_wait () if tp_size > 1 else None
535545
536546 req_dicts , control_reqs = [], []
537547 for req_dict , bsz in tasks :
0 commit comments