diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 3750827a2d5..e744c4c7e1f 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -312,6 +312,18 @@ def _broadcast_model_weights_signal(self, src: int, group) -> int: value = model_weights_signal_tensor.numpy()[0] return int(value) + def _get_exist_task_flag(self) -> bool: + if self.nnode > 1: + return self.task_queue.read_finish_flag.get() == 1 + else: + return self.exist_task_signal.value[0] == ExistTaskStatus.EXIST + + def _update_exist_task_flag(self, flag: bool) -> None: + if self.nnode > 1: + self.task_queue.read_finish_flag.set(1 if flag else 0) + else: + self.exist_task_signal.value[0] = ExistTaskStatus.EXIST if flag else ExistTaskStatus.EMPTY + def _tp_barrier_wait(self): if current_platform.is_xpu() or self.enable_overlap_schedule: self.task_queue.worker_process_tp_barrier.wait() @@ -436,7 +448,7 @@ def event_loop_normal(self) -> None: self._init_eplb_signal() tp_size = self.parallel_config.tensor_parallel_size # Currently, only support single node - self.nnode = (tp_size + self.max_chips_per_node) // self.max_chips_per_node + self.nnode = (tp_size + self.max_chips_per_node - 1) // self.max_chips_per_node max_occupied_batch_index = 0 tp_rank = self.local_rank % tp_size @@ -454,16 +466,17 @@ def event_loop_normal(self) -> None: req_dicts = None self.worker_healthy_live_signal.value[tp_rank % self.max_chips_per_node] = int(time.time()) + self._tp_barrier_wait() if tp_size > 1 else None + # The first worker detects whether there are tasks in the task queue if tp_rank == 0: if self.task_queue.exist_tasks(): if envs.ENABLE_V1_KVCACHE_SCHEDULER or not ( self.fd_config.enable_mm_runtime and self.worker.exist_prefill() ): - if self.nnode > 1: - self.task_queue.read_finish_flag.set(1) - else: - self.exist_task_signal.value[0] = ExistTaskStatus.EXIST + self._update_exist_task_flag(True) + else: + self._update_exist_task_flag(False) # Synchronize the signal set by tp_rank0 visiable to other workers self._tp_barrier_wait() if tp_size > 1 else None @@ -521,17 +534,14 @@ def event_loop_normal(self) -> None: ) # 所有 Rank 已同步唤醒,启动权重更新流程 continue - if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1: - logger.info(f"Rank: {self.local_rank} Detected new requests.") + if self._get_exist_task_flag(): + logger.debug(f"Rank: {self.local_rank} Detected new requests.") tasks, read_finish = self.task_queue.get_tasks() # Only one of all tp_size client will get read_finish == True. if read_finish: - # Reset the two signal. - if self.nnode > 1: - self.task_queue.read_finish_flag.set(0) - else: - self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY + self._update_exist_task_flag(False) + self._tp_barrier_wait() if tp_size > 1 else None req_dicts, control_reqs = [], [] for req_dict, bsz in tasks: