@@ -595,47 +595,23 @@ def event_loop_normal(self) -> None:
595595 len (tasks ) > 0
596596 ), f"task_queue.get_tasks() should contain at least one tuple, [([req1, ...] ,real_bsz)], but got len(tasks)={ len (tasks )} "
597597
598- control_reqs = []
599- req_dicts = BatchRequest ()
600- for req_dict , bsz in tasks :
601- if len (req_dict ) > 0 and isinstance (req_dict [0 ], ControlRequest ):
602- control_reqs .append (req_dict [0 ])
603- else :
604- max_occupied_batch_index = int (bsz )
605- # req_dict can be either List[Request] or BatchRequest
606- if isinstance (req_dict , BatchRequest ):
607- req_dicts .append (req_dict )
608- else :
609- for req in req_dict :
610- req_dicts .add_request (req )
598+ batch_request , control_reqs , max_occupied_batch_index = BatchRequest .from_tasks (tasks )
611599
612- # todo: run control request async
613600 if len (control_reqs ) > 0 :
614601 logger .info (f"Rank: { self .local_rank } received { len (control_reqs )} control request." )
615602 for control_req in control_reqs :
616603 if self .parallel_config .use_ep :
617604 self .cached_control_reqs .append (control_req )
618605 logger .info (f"Rank: { self .local_rank } cached ep control request: { control_req } " )
619606 else :
620- max_occupied_batch_index = int (bsz )
621- req_dicts .extend (req_dict )
622-
623- # todo: run control request async
624- if len (control_reqs ) > 0 :
625- logger .info (f"Rank: { self .local_rank } received { len (control_reqs )} control request." )
626- for control_req in control_reqs :
627- if self .parallel_config .use_ep :
628- self .cached_control_reqs .append (control_req )
629- logger .info (f"Rank: { self .local_rank } cached ep control request: { control_req } " )
630- else :
631- self .run_control_method (control_req )
632- self ._tp_barrier_wait () if tp_size > 1 else None
633-
634- if len (req_dicts ) > 0 :
607+ self .run_control_method (control_req )
608+ self ._tp_barrier_wait () if tp_size > 1 else None
609+
610+ if len (batch_request ) > 0 :
635611 # Count prefill requests in current batch
636- num_prefill_requests = sum (1 for req in req_dicts if req .task_type == RequestType .PREFILL )
637- num_scheduled_requests = len (req_dicts )
638- scheduled_request_ids = [req .request_id for req in req_dicts ]
612+ num_prefill_requests = sum (1 for req in batch_request if req .task_type == RequestType .PREFILL )
613+ num_scheduled_requests = len (batch_request )
614+ scheduled_request_ids = [req .request_id for req in batch_request ]
639615 logger .info (
640616 f"Rank: { self .local_rank } , num_prefill_requests: { num_prefill_requests } , "
641617 f"max_occupied_batch_index: { max_occupied_batch_index } , "
@@ -644,7 +620,7 @@ def event_loop_normal(self) -> None:
644620 )
645621
646622 # Process prefill inputs
647- self .worker .preprocess_new_task (req_dicts , max_occupied_batch_index )
623+ self .worker .preprocess_new_task (batch_request , max_occupied_batch_index )
648624 else :
649625 if self .scheduler_config .splitwise_role == "prefill" :
650626 if tp_size > 1 :
0 commit comments