Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions fastdeploy/worker/worker_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,18 @@ def _broadcast_model_weights_signal(self, src: int, group) -> int:
value = model_weights_signal_tensor.numpy()[0]
return int(value)

def _get_exist_task_flag(self) -> bool:
if self.nnode > 1:
return self.task_queue.read_finish_flag.get() == 1
else:
return self.exist_task_signal.value[0] == ExistTaskStatus.EXIST

def _update_exist_task_flag(self, flag: bool) -> None:
if self.nnode > 1:
self.task_queue.read_finish_flag.set(1 if flag else 0)
else:
self.exist_task_signal.value[0] = ExistTaskStatus.EXIST if flag else ExistTaskStatus.EMPTY

def _tp_barrier_wait(self):
if current_platform.is_xpu() or self.enable_overlap_schedule:
self.task_queue.worker_process_tp_barrier.wait()
Expand Down Expand Up @@ -436,7 +448,7 @@ def event_loop_normal(self) -> None:
self._init_eplb_signal()
tp_size = self.parallel_config.tensor_parallel_size
# Currently, only support single node
self.nnode = (tp_size + self.max_chips_per_node) // self.max_chips_per_node
self.nnode = (tp_size + self.max_chips_per_node - 1) // self.max_chips_per_node
max_occupied_batch_index = 0
tp_rank = self.local_rank % tp_size

Expand All @@ -454,16 +466,17 @@ def event_loop_normal(self) -> None:
req_dicts = None
self.worker_healthy_live_signal.value[tp_rank % self.max_chips_per_node] = int(time.time())

self._tp_barrier_wait() if tp_size > 1 else None

# The first worker detects whether there are tasks in the task queue
if tp_rank == 0:
if self.task_queue.exist_tasks():
if envs.ENABLE_V1_KVCACHE_SCHEDULER or not (
self.fd_config.enable_mm_runtime and self.worker.exist_prefill()
):
if self.nnode > 1:
self.task_queue.read_finish_flag.set(1)
else:
self.exist_task_signal.value[0] = ExistTaskStatus.EXIST
self._update_exist_task_flag(True)
else:
self._update_exist_task_flag(False)

# Synchronize the signal set by tp_rank0 visiable to other workers
self._tp_barrier_wait() if tp_size > 1 else None
Expand Down Expand Up @@ -521,17 +534,14 @@ def event_loop_normal(self) -> None:
) # 所有 Rank 已同步唤醒,启动权重更新流程
continue

if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
logger.info(f"Rank: {self.local_rank} Detected new requests.")
if self._get_exist_task_flag():
logger.debug(f"Rank: {self.local_rank} Detected new requests.")

tasks, read_finish = self.task_queue.get_tasks()
# Only one of all tp_size client will get read_finish == True.
if read_finish:
# Reset the two signal.
if self.nnode > 1:
self.task_queue.read_finish_flag.set(0)
else:
self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
self._update_exist_task_flag(False)
self._tp_barrier_wait() if tp_size > 1 else None

req_dicts, control_reqs = [], []
for req_dict, bsz in tasks:
Expand Down