1818
1919import torch
2020
21- from sglang .srt .distributed import get_tp_group
2221from sglang .srt .layers .dp_attention import get_attention_tp_group
2322from sglang .srt .layers .logits_processor import LogitsProcessorOutput
2423from sglang .srt .layers .moe .utils import speculative_moe_backend_context
@@ -136,7 +135,7 @@ def __init__(
136135 server_args = server_args ,
137136 gpu_id = gpu_id ,
138137 tp_rank = tp_rank ,
139- pp_rank = 0 , # FIXME
138+ pp_rank = 0 , # spec workers don't support pipeline parallelism
140139 dp_rank = dp_rank ,
141140 moe_ep_rank = moe_ep_rank ,
142141 attn_cp_rank = attn_cp_rank ,
@@ -293,8 +292,9 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
293292 set_time_batch (batch .reqs , "set_spec_draft_end_time" , trace_only = True )
294293 set_time_batch (batch .reqs , "set_spec_verify_start_time" , trace_only = True )
295294
295+ # Install verify_input as `batch.spec_info` for the verify forward.
296296 batch .spec_info = verify_input
297- logits_output , verify_output , can_run_cuda_graph = self .verify (batch )
297+ verify_output = self .verify (batch )
298298
299299 if get_global_tracing_enabled ():
300300 for idx , req in enumerate (batch .reqs ):
@@ -320,8 +320,9 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
320320 self .server_args .enable_dp_attention
321321 or draft_extend_input .input_ids .shape [0 ] > 0
322322 ):
323- # decode is not finished; stash for extend, then restash
324- # the next-iter EagleDraftInput it returns.
323+ # decode is not finished; install draft_extend_input for
324+ # the extend forward, then install the next-iter
325+ # EagleDraftInput it returns.
325326 batch .spec_info = draft_extend_input
326327 next_draft_input = self .forward_draft_extend_after_decode (batch )
327328 batch .spec_info = next_draft_input
@@ -337,31 +338,13 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
337338 )
338339
339340 return GenerationBatchResult (
340- logits_output = logits_output ,
341+ logits_output = verify_output . logits_output ,
341342 next_token_ids = verify_output .accept_tokens ,
342343 num_correct_drafts = sum (verify_output .num_correct_drafts_per_req_cpu ),
343344 num_correct_drafts_per_req_cpu = verify_output .num_correct_drafts_per_req_cpu ,
344- can_run_cuda_graph = can_run_cuda_graph ,
345+ can_run_cuda_graph = verify_output . can_run_cuda_graph ,
345346 )
346347
347- def check_forward_draft_extend_after_decode (self , verify_output : EagleVerifyOutput ):
348- local_need_forward = verify_output .draft_extend_input .input_ids .shape [0 ] > 0
349- if not self .server_args .enable_dp_attention :
350- return local_need_forward
351-
352- global_need_forward = torch .tensor (
353- [
354- (local_need_forward ),
355- ],
356- dtype = torch .int64 ,
357- )
358- torch .distributed .all_reduce (
359- global_need_forward , group = get_tp_group ().cpu_group
360- )
361- global_need_forward_cnt = global_need_forward [0 ].item ()
362- need_forward = global_need_forward_cnt > 0
363- return need_forward
364-
365348 def forward_target_extend (
366349 self , batch : ScheduleBatch
367350 ) -> Tuple [LogitsProcessorOutput , torch .Tensor , Optional [torch .Tensor ], bool ]:
@@ -644,7 +627,8 @@ def verify(self, batch: ScheduleBatch):
644627 ForwardMode .DECODE if not batch .forward_mode .is_idle () else ForwardMode .IDLE
645628 )
646629
647- return logits_output , res , can_run_cuda_graph
630+ res .can_run_cuda_graph = can_run_cuda_graph
631+ return res
648632
649633 def forward_draft_extend (
650634 self ,
0 commit comments