Skip to content

Commit 7a5aa25

Browse files
kevincheng2 and claude committed
[Feature][KVCache] add BatchRequest.from_tasks and refactor worker task parsing
## Motivation

将 worker_process 中重复的 task 解析逻辑收敛到 BatchRequest，减少代码冗余，提升可维护性。

## Modifications

- `fastdeploy/engine/request.py`：新增 `BatchRequest.from_tasks()` 类方法，统一将 task_queue 任务分类为推理请求和控制请求
- `fastdeploy/worker/worker_process.py`：使用 `BatchRequest.from_tasks()` 替代内联解析逻辑，并修复重复的 control_reqs 处理块

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 003d6a8 commit 7a5aa25

2 files changed

Lines changed: 40 additions & 33 deletions

File tree

fastdeploy/engine/request.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -694,6 +694,37 @@ def extend(self, batch_requests: list["BatchRequest"]):
694694
for br in batch_requests:
695695
self.append(br)
696696

697+
@classmethod
def from_tasks(cls, tasks: list) -> tuple["BatchRequest", list, int]:
    """Split worker-queue tasks into inference requests and control requests.

    Args:
        tasks: List of (payload, real_bsz) tuples from task_queue.get_tasks().
            Each payload is one of: BatchRequest, List[Request], or
            [ControlRequest].

    Returns:
        A tuple (batch_request, control_reqs, max_occupied_batch_index):
        - batch_request: a single BatchRequest holding every inference
          request found in ``tasks``
        - control_reqs: list of the ControlRequest objects found in ``tasks``
        - max_occupied_batch_index: real_bsz of the last inference task
          batch, or 0 when ``tasks`` holds no inference batch
    """
    merged = cls()
    controls = []
    last_inference_bsz = 0

    for payload, bsz in tasks:
        # A control task is recognised by its first element being a ControlRequest;
        # only that first element is collected.
        is_control = len(payload) > 0 and isinstance(payload[0], ControlRequest)
        if is_control:
            controls.append(payload[0])
            continue

        # Every non-control payload overwrites the index, so the last one wins.
        last_inference_bsz = int(bsz)

        if isinstance(payload, cls):
            merged.append(payload)
            continue

        # Otherwise the payload is iterated and each request added individually.
        for request in payload:
            merged.add_request(request)

    return merged, controls, last_inference_bsz
727+
697728

698729
class ControlRequest:
699730
"""A generic control request that supports method and args for control operations.

fastdeploy/worker/worker_process.py

Lines changed: 9 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -595,47 +595,23 @@ def event_loop_normal(self) -> None:
595595
len(tasks) > 0
596596
), f"task_queue.get_tasks() should contain at least one tuple, [([req1, ...] ,real_bsz)], but got len(tasks)={len(tasks)}"
597597

598-
control_reqs = []
599-
req_dicts = BatchRequest()
600-
for req_dict, bsz in tasks:
601-
if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest):
602-
control_reqs.append(req_dict[0])
603-
else:
604-
max_occupied_batch_index = int(bsz)
605-
# req_dict can be either List[Request] or BatchRequest
606-
if isinstance(req_dict, BatchRequest):
607-
req_dicts.append(req_dict)
608-
else:
609-
for req in req_dict:
610-
req_dicts.add_request(req)
598+
batch_request, control_reqs, max_occupied_batch_index = BatchRequest.from_tasks(tasks)
611599

612-
# todo: run control request async
613600
if len(control_reqs) > 0:
614601
logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.")
615602
for control_req in control_reqs:
616603
if self.parallel_config.use_ep:
617604
self.cached_control_reqs.append(control_req)
618605
logger.info(f"Rank: {self.local_rank} cached ep control request: {control_req}")
619606
else:
620-
max_occupied_batch_index = int(bsz)
621-
req_dicts.extend(req_dict)
622-
623-
# todo: run control request async
624-
if len(control_reqs) > 0:
625-
logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.")
626-
for control_req in control_reqs:
627-
if self.parallel_config.use_ep:
628-
self.cached_control_reqs.append(control_req)
629-
logger.info(f"Rank: {self.local_rank} cached ep control request: {control_req}")
630-
else:
631-
self.run_control_method(control_req)
632-
self._tp_barrier_wait() if tp_size > 1 else None
633-
634-
if len(req_dicts) > 0:
607+
self.run_control_method(control_req)
608+
self._tp_barrier_wait() if tp_size > 1 else None
609+
610+
if len(batch_request) > 0:
635611
# Count prefill requests in current batch
636-
num_prefill_requests = sum(1 for req in req_dicts if req.task_type == RequestType.PREFILL)
637-
num_scheduled_requests = len(req_dicts)
638-
scheduled_request_ids = [req.request_id for req in req_dicts]
612+
num_prefill_requests = sum(1 for req in batch_request if req.task_type == RequestType.PREFILL)
613+
num_scheduled_requests = len(batch_request)
614+
scheduled_request_ids = [req.request_id for req in batch_request]
639615
logger.info(
640616
f"Rank: {self.local_rank}, num_prefill_requests: {num_prefill_requests}, "
641617
f"max_occupied_batch_index: {max_occupied_batch_index}, "
@@ -644,7 +620,7 @@ def event_loop_normal(self) -> None:
644620
)
645621

646622
# Process prefill inputs
647-
self.worker.preprocess_new_task(req_dicts, max_occupied_batch_index)
623+
self.worker.preprocess_new_task(batch_request, max_occupied_batch_index)
648624
else:
649625
if self.scheduler_config.splitwise_role == "prefill":
650626
if tp_size > 1:

0 commit comments

Comments
 (0)