@@ -259,8 +259,12 @@ def _build(self, requests: GuidedRequests) -> List[Tuple[int, str]]:
259259 matcher .fill_next_token_bitmask (self .bitmask_host , offset )
260260 self .token_mask_host [offset ] = 1
261261 self .num_guided_tokens [slot ] += 1
262- # Process draft tokens
263- for i , tid in enumerate (req .draft_tokens , 1 ):
262+ # Process draft tokens. Bound by the layout's draft length:
263+ # the new_tokens buffer always holds the static max, but only
264+ # `max_num_draft_tokens` slots are reserved this iteration.
265+ for i , tid in enumerate (
266+ req .draft_tokens [:requests .max_num_draft_tokens ],
267+ 1 ):
264268 accepted = matcher .accept_token (tid )
265269 if not accepted :
266270 break
@@ -332,9 +336,13 @@ def _apply_bitmask(self,
332336 d2t = d2t )
333337
334338 @nvtx_range ("GuidedDecoder.add_batch" )
335- def add_batch (self , scheduled_requests : ScheduledRequests ) -> None :
339+ def add_batch (self ,
340+ scheduled_requests : ScheduledRequests ,
341+ runtime_draft_len : Optional [int ] = None ) -> None :
342+ num_draft_tokens = (self .max_num_draft_tokens
343+ if runtime_draft_len is None else runtime_draft_len )
336344 self .requests = GuidedRequests .from_scheduled_requests (
337- scheduled_requests , self . max_num_draft_tokens )
345+ scheduled_requests , num_draft_tokens )
338346
339347 @nvtx_range ("GuideDecoder.build" )
340348 def build (self ) -> List [Tuple [int , str ]]:
@@ -470,9 +478,14 @@ def __init__(self,
470478 @nvtx_range ("GuidedDecoder.add_batch" )
471479 def add_batch (self ,
472480 scheduled_requests : ScheduledRequests ,
473- new_tokens : Optional [torch .Tensor ] = None ) -> None :
481+ new_tokens : Optional [torch .Tensor ] = None ,
482+ runtime_draft_len : Optional [int ] = None ) -> None :
483+ # See GuidedDecoder.add_batch: the layout must follow the runtime draft
484+ # length so the captured graph's bitmask matches the target logits.
485+ num_draft_tokens = (self .max_num_draft_tokens
486+ if runtime_draft_len is None else runtime_draft_len )
474487 self .requests = GuidedRequests .from_scheduled_requests (
475- scheduled_requests , self . max_num_draft_tokens )
488+ scheduled_requests , num_draft_tokens )
476489 if new_tokens is not None :
477490 self .new_tokens .copy_ (new_tokens .squeeze (- 1 ), non_blocking = True )
478491 self .queue .put ((self .requests , new_tokens is not None ))
0 commit comments