Skip to content

Commit 00ed78c

Browse files
authored
[#15022][fix] Guided decoding (xgrammar) + EAGLE-3 crashes during CUDA graph capture with "bitmask must have the same batch size as logits" when draft_len_schedule reaches 0 (#15023)
Signed-off-by: chungen04 <b09901027@ntu.edu.tw>
1 parent 1b360ee commit 00ed78c

2 files changed

Lines changed: 23 additions & 8 deletions

File tree

tensorrt_llm/_torch/pyexecutor/guided_decoder.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,12 @@ def _build(self, requests: GuidedRequests) -> List[Tuple[int, str]]:
259259
matcher.fill_next_token_bitmask(self.bitmask_host, offset)
260260
self.token_mask_host[offset] = 1
261261
self.num_guided_tokens[slot] += 1
262-
# Process draft tokens
263-
for i, tid in enumerate(req.draft_tokens, 1):
262+
# Process draft tokens. Bound by the layout's draft length:
263+
# the new_tokens buffer always holds the static max, but only
264+
# `requests.max_num_draft_tokens` slots are reserved this iteration.
265+
for i, tid in enumerate(
266+
req.draft_tokens[:requests.max_num_draft_tokens],
267+
1):
264268
accepted = matcher.accept_token(tid)
265269
if not accepted:
266270
break
@@ -332,9 +336,13 @@ def _apply_bitmask(self,
332336
d2t=d2t)
333337

334338
@nvtx_range("GuidedDecoder.add_batch")
335-
def add_batch(self, scheduled_requests: ScheduledRequests) -> None:
339+
def add_batch(self,
340+
scheduled_requests: ScheduledRequests,
341+
runtime_draft_len: Optional[int] = None) -> None:
342+
num_draft_tokens = (self.max_num_draft_tokens
343+
if runtime_draft_len is None else runtime_draft_len)
336344
self.requests = GuidedRequests.from_scheduled_requests(
337-
scheduled_requests, self.max_num_draft_tokens)
345+
scheduled_requests, num_draft_tokens)
338346

339347
@nvtx_range("GuideDecoder.build")
340348
def build(self) -> List[Tuple[int, str]]:
@@ -470,9 +478,14 @@ def __init__(self,
470478
@nvtx_range("GuidedDecoder.add_batch")
471479
def add_batch(self,
472480
scheduled_requests: ScheduledRequests,
473-
new_tokens: Optional[torch.Tensor] = None) -> None:
481+
new_tokens: Optional[torch.Tensor] = None,
482+
runtime_draft_len: Optional[int] = None) -> None:
483+
# See GuidedDecoder.add_batch: the layout must follow the runtime draft
484+
# length so the captured graph's bitmask matches the target logits.
485+
num_draft_tokens = (self.max_num_draft_tokens
486+
if runtime_draft_len is None else runtime_draft_len)
474487
self.requests = GuidedRequests.from_scheduled_requests(
475-
scheduled_requests, self.max_num_draft_tokens)
488+
scheduled_requests, num_draft_tokens)
476489
if new_tokens is not None:
477490
self.new_tokens.copy_(new_tokens.squeeze(-1), non_blocking=True)
478491
self.queue.put((self.requests, new_tokens is not None))

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2639,8 +2639,10 @@ def _prepare_tp_inputs(
26392639

26402640
# Must be before the update of py_batch_idx
26412641
if self.guided_decoder is not None:
2642-
self.guided_decoder.add_batch(scheduled_requests,
2643-
new_tokens=new_tokens_device)
2642+
self.guided_decoder.add_batch(
2643+
scheduled_requests,
2644+
new_tokens=new_tokens_device,
2645+
runtime_draft_len=self.runtime_draft_len)
26442646

26452647
if self._can_use_incremental_update(scheduled_requests,
26462648
new_tokens_device,

0 commit comments

Comments
 (0)