Skip to content

Commit f01a1d2

Browse files
authored
[Spec] fold can_run_cuda_graph into EagleVerifyOutput; drop dead extend-after-decode check (sgl-project#25566)
1 parent 4a1ddd4 commit f01a1d2

9 files changed

Lines changed: 42 additions & 65 deletions

python/sglang/srt/speculative/eagle_info.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,6 +1012,10 @@ class EagleVerifyOutput:
10121012
num_correct_drafts_per_req_cpu: List[int]
10131013
# Accepted indices from logits_output.next_token_logits
10141014
accept_indices: torch.Tensor
1015+
# Whether the target verify forward ran a captured cuda graph. Set by
1016+
# the worker after `EagleVerifyInput.sample` returns; default kept so
1017+
# idle / direct constructions don't have to pass it.
1018+
can_run_cuda_graph: bool = False
10151019

10161020
@classmethod
10171021
def create_idle(

python/sglang/srt/speculative/eagle_worker.py

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import torch
77

8-
from sglang.srt.distributed import get_tp_group
98
from sglang.srt.hardware_backend.npu.graph_runner.eagle_draft_npu_graph_runner import (
109
EAGLEDraftNpuGraphRunner,
1110
)
@@ -162,7 +161,7 @@ def __init__(
162161
server_args=server_args,
163162
gpu_id=gpu_id,
164163
tp_rank=tp_rank,
165-
pp_rank=0, # FIXME
164+
pp_rank=0, # spec workers don't support pipeline parallelism
166165
dp_rank=dp_rank,
167166
moe_ep_rank=moe_ep_rank,
168167
attn_cp_rank=attn_cp_rank,
@@ -492,8 +491,9 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
492491
set_time_batch(batch.reqs, "set_spec_draft_end_time", trace_only=True)
493492
set_time_batch(batch.reqs, "set_spec_verify_start_time", trace_only=True)
494493

494+
# Install verify_input as `batch.spec_info` for the verify forward.
495495
batch.spec_info = verify_input
496-
logits_output, verify_output, can_run_cuda_graph = self.verify(batch)
496+
verify_output = self.verify(batch)
497497

498498
if get_global_tracing_enabled():
499499
for idx, req in enumerate(batch.reqs):
@@ -520,8 +520,9 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
520520
self.server_args.enable_dp_attention
521521
or draft_extend_input.input_ids.shape[0] > 0
522522
):
523-
# decode is not finished; stash for extend, then restash
524-
# the next-iter EagleDraftInput it returns.
523+
# decode is not finished; install draft_extend_input for
524+
# the extend forward, then install the next-iter
525+
# EagleDraftInput it returns.
525526
batch.spec_info = draft_extend_input
526527
next_draft_input = self.forward_draft_extend_after_decode(batch)
527528
batch.spec_info = next_draft_input
@@ -542,31 +543,13 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
542543
)
543544

544545
return GenerationBatchResult(
545-
logits_output=logits_output,
546+
logits_output=verify_output.logits_output,
546547
next_token_ids=verify_output.accept_tokens,
547548
num_correct_drafts=sum(verify_output.num_correct_drafts_per_req_cpu),
548549
num_correct_drafts_per_req_cpu=verify_output.num_correct_drafts_per_req_cpu,
549-
can_run_cuda_graph=can_run_cuda_graph,
550+
can_run_cuda_graph=verify_output.can_run_cuda_graph,
550551
)
551552

552-
def check_forward_draft_extend_after_decode(self, verify_output: EagleVerifyOutput):
553-
local_need_forward = verify_output.draft_extend_input.input_ids.shape[0] > 0
554-
if not self.server_args.enable_dp_attention:
555-
return local_need_forward
556-
557-
global_need_forward = torch.tensor(
558-
[
559-
(local_need_forward),
560-
],
561-
dtype=torch.int64,
562-
)
563-
torch.distributed.all_reduce(
564-
global_need_forward, group=get_tp_group().cpu_group
565-
)
566-
global_need_forward_cnt = global_need_forward[0].item()
567-
need_forward = global_need_forward_cnt > 0
568-
return need_forward
569-
570553
def forward_target_extend(
571554
self, batch: ScheduleBatch
572555
) -> Tuple[LogitsProcessorOutput, torch.Tensor, Optional[torch.Tensor], bool]:
@@ -1015,7 +998,8 @@ def verify(self, batch: ScheduleBatch):
1015998
ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
1016999
)
10171000

1018-
return logits_output, res, can_run_cuda_graph
1001+
res.can_run_cuda_graph = can_run_cuda_graph
1002+
return res
10191003

10201004
def _mamba_verify_update(
10211005
self,

python/sglang/srt/speculative/eagle_worker_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def __init__(
154154
server_args=server_args,
155155
gpu_id=gpu_id,
156156
tp_rank=tp_rank,
157-
pp_rank=0, # FIXME
157+
pp_rank=0, # spec workers don't support pipeline parallelism
158158
dp_rank=dp_rank,
159159
moe_ep_rank=moe_ep_rank,
160160
attn_cp_rank=attn_cp_rank,

python/sglang/srt/speculative/frozen_kv_mtp_worker.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -449,8 +449,9 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
449449
set_time_batch(batch.reqs, "set_spec_draft_end_time", trace_only=True)
450450
set_time_batch(batch.reqs, "set_spec_verify_start_time", trace_only=True)
451451

452+
# Install verify_input as `batch.spec_info` for the verify forward.
452453
batch.spec_info = verify_input
453-
logits_output, verify_output, can_run_cuda_graph = self.verify(batch)
454+
verify_output = self.verify(batch)
454455

455456
if get_global_tracing_enabled():
456457
for idx, req in enumerate(batch.reqs):
@@ -470,18 +471,19 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
470471
self.server_args.enable_dp_attention
471472
or draft_extend_input.input_ids.shape[0] > 0
472473
):
473-
# Stash for the seed step; _run_assistant_seed_step swaps in
474-
# a fresh FrozenKVMTPDraftInput for next iter.
474+
# Install draft_extend_input as `batch.spec_info` for the seed
475+
# step; `_run_assistant_seed_step` replaces it with a fresh
476+
# `FrozenKVMTPDraftInput` for next iter.
475477
batch.spec_info = draft_extend_input
476478
self.forward_draft_extend_after_decode(batch)
477479
set_time_batch(batch.reqs, "set_spec_draft_extend_end_time", trace_only=True)
478480

479481
return GenerationBatchResult(
480-
logits_output=logits_output,
482+
logits_output=verify_output.logits_output,
481483
next_token_ids=verify_output.accept_tokens,
482484
num_correct_drafts=sum(verify_output.num_correct_drafts_per_req_cpu),
483485
num_correct_drafts_per_req_cpu=verify_output.num_correct_drafts_per_req_cpu,
484-
can_run_cuda_graph=can_run_cuda_graph,
486+
can_run_cuda_graph=verify_output.can_run_cuda_graph,
485487
)
486488

487489
def forward_target_extend(
@@ -518,7 +520,7 @@ def forward_draft_extend_after_decode(self, batch: ScheduleBatch) -> None:
518520
input_is_idle = batch.forward_mode.is_idle()
519521

520522
if not input_is_idle and draft_extend_input.input_ids.shape[0] == 0:
521-
# All reqs finished; stash an idle FrozenKVMTPDraftInput so the
523+
# All reqs finished. Install an idle FrozenKVMTPDraftInput so the
522524
# next-iter draft sees a valid spec_info.
523525
batch = batch.copy()
524526
batch.prepare_for_idle()
@@ -775,4 +777,5 @@ def verify(self, batch: ScheduleBatch):
775777
)
776778

777779
del seq_lens_pre_verify
778-
return logits_output, res, can_run_cuda_graph
780+
res.can_run_cuda_graph = can_run_cuda_graph
781+
return res

python/sglang/srt/speculative/multi_layer_eagle_worker.py

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
import torch
2020

21-
from sglang.srt.distributed import get_tp_group
2221
from sglang.srt.layers.dp_attention import get_attention_tp_group
2322
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
2423
from sglang.srt.layers.moe.utils import speculative_moe_backend_context
@@ -136,7 +135,7 @@ def __init__(
136135
server_args=server_args,
137136
gpu_id=gpu_id,
138137
tp_rank=tp_rank,
139-
pp_rank=0, # FIXME
138+
pp_rank=0, # spec workers don't support pipeline parallelism
140139
dp_rank=dp_rank,
141140
moe_ep_rank=moe_ep_rank,
142141
attn_cp_rank=attn_cp_rank,
@@ -293,8 +292,9 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
293292
set_time_batch(batch.reqs, "set_spec_draft_end_time", trace_only=True)
294293
set_time_batch(batch.reqs, "set_spec_verify_start_time", trace_only=True)
295294

295+
# Install verify_input as `batch.spec_info` for the verify forward.
296296
batch.spec_info = verify_input
297-
logits_output, verify_output, can_run_cuda_graph = self.verify(batch)
297+
verify_output = self.verify(batch)
298298

299299
if get_global_tracing_enabled():
300300
for idx, req in enumerate(batch.reqs):
@@ -320,8 +320,9 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
320320
self.server_args.enable_dp_attention
321321
or draft_extend_input.input_ids.shape[0] > 0
322322
):
323-
# decode is not finished; stash for extend, then restash
324-
# the next-iter EagleDraftInput it returns.
323+
# decode is not finished; install draft_extend_input for
324+
# the extend forward, then install the next-iter
325+
# EagleDraftInput it returns.
325326
batch.spec_info = draft_extend_input
326327
next_draft_input = self.forward_draft_extend_after_decode(batch)
327328
batch.spec_info = next_draft_input
@@ -337,31 +338,13 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
337338
)
338339

339340
return GenerationBatchResult(
340-
logits_output=logits_output,
341+
logits_output=verify_output.logits_output,
341342
next_token_ids=verify_output.accept_tokens,
342343
num_correct_drafts=sum(verify_output.num_correct_drafts_per_req_cpu),
343344
num_correct_drafts_per_req_cpu=verify_output.num_correct_drafts_per_req_cpu,
344-
can_run_cuda_graph=can_run_cuda_graph,
345+
can_run_cuda_graph=verify_output.can_run_cuda_graph,
345346
)
346347

347-
def check_forward_draft_extend_after_decode(self, verify_output: EagleVerifyOutput):
348-
local_need_forward = verify_output.draft_extend_input.input_ids.shape[0] > 0
349-
if not self.server_args.enable_dp_attention:
350-
return local_need_forward
351-
352-
global_need_forward = torch.tensor(
353-
[
354-
(local_need_forward),
355-
],
356-
dtype=torch.int64,
357-
)
358-
torch.distributed.all_reduce(
359-
global_need_forward, group=get_tp_group().cpu_group
360-
)
361-
global_need_forward_cnt = global_need_forward[0].item()
362-
need_forward = global_need_forward_cnt > 0
363-
return need_forward
364-
365348
def forward_target_extend(
366349
self, batch: ScheduleBatch
367350
) -> Tuple[LogitsProcessorOutput, torch.Tensor, Optional[torch.Tensor], bool]:
@@ -644,7 +627,8 @@ def verify(self, batch: ScheduleBatch):
644627
ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
645628
)
646629

647-
return logits_output, res, can_run_cuda_graph
630+
res.can_run_cuda_graph = can_run_cuda_graph
631+
return res
648632

649633
def forward_draft_extend(
650634
self,

python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def __init__(
128128
server_args=server_args,
129129
gpu_id=gpu_id,
130130
tp_rank=tp_rank,
131-
pp_rank=0, # FIXME
131+
pp_rank=0, # spec workers don't support pipeline parallelism
132132
dp_rank=dp_rank,
133133
moe_ep_rank=moe_ep_rank,
134134
attn_cp_rank=attn_cp_rank,

python/sglang/srt/speculative/spec_info.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,11 @@ class SpecInput(ABC):
232232
def __init__(self, spec_input_type: SpecInputType):
233233
self.spec_input_type = spec_input_type
234234

235+
# Cross-algorithm phase guards. Used by attention backends and
236+
# ForwardBatch padding logic to dispatch on phase without hardcoding the
237+
# specific algo class (EAGLE / FROZEN_KV_MTP / DFLASH / NGRAM each have
238+
# their own draft / verify SpecInput subclasses).
235239
def is_draft_input(self) -> bool:
236-
# FIXME: remove this function which is only used for assertion
237-
# or use another variable name like `draft_input` to substitute `spec_info`
238240
return self.spec_input_type in {
239241
SpecInputType.EAGLE_DRAFT,
240242
SpecInputType.EAGLE_DRAFT_EXTEND,

python/sglang/srt/speculative/standalone_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def __init__(
8787
server_args=server_args,
8888
gpu_id=gpu_id,
8989
tp_rank=tp_rank,
90-
pp_rank=0, # FIXME
90+
pp_rank=0, # spec workers don't support pipeline parallelism
9191
dp_rank=dp_rank,
9292
moe_ep_rank=moe_ep_rank,
9393
attn_cp_rank=attn_cp_rank,

python/sglang/srt/speculative/standalone_worker_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def __init__(
9393
server_args=server_args,
9494
gpu_id=gpu_id,
9595
tp_rank=tp_rank,
96-
pp_rank=0, # FIXME
96+
pp_rank=0, # spec workers don't support pipeline parallelism
9797
dp_rank=dp_rank,
9898
moe_ep_rank=moe_ep_rank,
9999
attn_cp_rank=attn_cp_rank,

0 commit comments

Comments
 (0)