@@ -463,13 +463,20 @@ class SpecMetadata:
463463 use_sampling_params_for_draft_tokens : bool = False
464464 # Vocab size used for draft_probs buffer allocation.
465465 vocab_size : int = 0
466- # Draft probabilities buffer for rejection sampling, stored flat.
466+ # Draft probabilities buffer for rejection sampling, indexed by py_seq_slot
467+ # so per-request data is stable across iterations regardless of batch
468+ # composition shifts (chunking ctx, gen completion, new ctx joining).
469+ # Shape: [max_num_requests, max_draft_len, vocab_size].
467470 draft_probs : Optional [torch .Tensor ] = None
468471 draft_probs_vocab_size : int = 0
469472 # Whether draft_probs contains valid data.
470473 draft_probs_valid : bool = False
471474 # Last dimension size of the draft logits/probs stored in draft_probs.
472475 draft_probs_last_dim : int = 0
476+ # Per-request slot ids (py_seq_slot) for the current batch, in batch order.
477+ # Used to scatter draft probs by slot at write time and gather them by slot
478+ # at the next iter's verify. Shape: [max_num_requests], dtype=long.
479+ batch_slot_ids : Optional [torch .Tensor ] = None
473480 # Draft-to-target vocab offset tensor.
474481 d2t : Optional [torch .Tensor ] = None
475482
@@ -482,12 +489,18 @@ def prepare(self):
482489 """
483490 if (self .use_rejection_sampling and self .draft_probs is None
484491 and self .vocab_size > 0 ):
485- buffer_size = (self .max_num_requests * self .max_draft_len *
486- self .vocab_size )
487- self .draft_probs = torch .empty (buffer_size ,
488- dtype = torch .float32 ,
489- device = 'cuda' )
492+ # 3D [slot, draft_step, vocab] so we can scatter/gather by slot id
493+ # and avoid the brittle "batch position == buffer position" mapping.
494+ self .draft_probs = torch .empty (
495+ (self .max_num_requests , self .max_draft_len , self .vocab_size ),
496+ dtype = torch .float32 ,
497+ device = 'cuda' )
490498 self .draft_probs_vocab_size = self .vocab_size
499+ if (self .use_rejection_sampling and self .batch_slot_ids is None
500+ and self .max_num_requests > 0 ):
501+ self .batch_slot_ids = torch .empty ((self .max_num_requests , ),
502+ dtype = torch .long ,
503+ device = 'cuda' )
491504
492505 def create_cuda_graph_metadata (self , max_batch_size : int ):
493506 """
@@ -587,6 +600,7 @@ def _normalize_request_sampling_params(
587600 top_k_enabled = False
588601 top_p_enabled = False
589602 has_greedy_requests = False
603+ per_request_slot_ids : list [int ] = []
590604
591605 for request in requests :
592606 sampling_config = request .sampling_config
@@ -618,6 +632,12 @@ def _normalize_request_sampling_params(
618632
619633 per_request_normalized .append (
620634 (temp_val , tk_val , tp_val , num_tokens ))
635+ # py_seq_slot is a stable per-request id used to scatter / gather
636+ # draft probs across iterations. Dummies / unallocated slots fall
637+ # back to 0 (any valid index is fine — the data at that slot will
638+ # be overwritten on the next real iteration before being read).
639+ per_request_slot_ids .append (
640+ request .py_seq_slot if request .py_seq_slot is not None else 0 )
621641
622642 self .skip_temperature = not temperature_enabled
623643 self .skip_top_k = not top_k_enabled
@@ -653,9 +673,20 @@ def _normalize_request_sampling_params(
653673 dtype = torch .float32 ,
654674 device = 'cuda' )
655675
676+ # Always-populate the per-request slot id table when rejection sampling
677+ # is configured: it's tiny (max_num_requests longs) and needed at
678+ # _compute_and_store_draft_probs time to scatter draft probs by slot.
679+ if self .use_rejection_sampling and self .batch_slot_ids is not None :
680+ self .batch_slot_ids [:len (per_request_slot_ids )].copy_ (
681+ torch .tensor (per_request_slot_ids ,
682+ dtype = torch .long ,
683+ pin_memory = prefer_pinned ()),
684+ non_blocking = True ,
685+ )
686+
656687 # All-greedy: sampler takes the argmax branch (and rejection sampling
657- # is also bypassed for all-greedy), so the buffers are never read.
658- # Skip the H->D copies.
688+ # is also bypassed for all-greedy), so the per-token buffers are never
689+ # read. Skip the heavier H->D copies.
659690 if self .is_all_greedy_sample :
660691 return
661692
@@ -1014,112 +1045,162 @@ def _accept_draft_tokens(self, logits, draft_tokens, num_contexts,
10141045 batch_size , spec_metadata ):
10151046 """
10161047 Accept draft tokens with optional rejection sampling support.
1048+
1049+ Mixed batches (num_contexts > 0) are supported: context rows take the
1050+ first sampled target token via the base logic, and rejection sampling
1051+ runs on the gen subset. Draft probs for the gen subset are gathered
1052+ from the slot-indexed buffer by `py_seq_slot`.
10171053 """
1018- if self ._can_use_rejection_sampling (spec_metadata , num_contexts ):
1054+ num_gens = batch_size - num_contexts
1055+ if num_gens > 0 and self ._can_use_rejection_sampling (spec_metadata ):
10191056 draft_len = draft_tokens .shape [1 ]
10201057 stored_vocab = (spec_metadata .draft_probs_last_dim
10211058 if spec_metadata .draft_probs_last_dim > 0 else
10221059 spec_metadata .draft_probs_vocab_size )
1023- draft_probs = spec_metadata .draft_probs [:batch_size * draft_len *
1024- stored_vocab ].reshape (
1025- batch_size , draft_len ,
1026- stored_vocab )
1060+ # Gather the slot rows for the gen subset. The buffer was filled
1061+ # at the previous draft step indexed by py_seq_slot, so each gen
1062+ # request reads back exactly its own probs, regardless of batch
1063+ # composition changes since then.
1064+ gen_slot_ids = spec_metadata .batch_slot_ids [num_contexts :batch_size ]
1065+ draft_probs = spec_metadata .draft_probs [
1066+ gen_slot_ids , :draft_len , :stored_vocab ]
10271067 return self ._sample_and_accept_draft_tokens_rejection (
1028- logits , draft_tokens , draft_probs , batch_size , spec_metadata )
1068+ logits , draft_tokens , draft_probs , num_contexts , batch_size ,
1069+ spec_metadata )
10291070 return self ._sample_and_accept_draft_tokens_base (
10301071 logits , draft_tokens , num_contexts , batch_size , spec_metadata )
10311072
1032- def _can_use_rejection_sampling (self , spec_metadata : SpecMetadata ,
1033- num_contexts : int ) -> bool :
1034- # Skip rejection sampling when the whole batch is greedy: the
1035- # accepted result is identical to argmax and the base path is cheaper.
1073+ def _can_use_rejection_sampling (self , spec_metadata : SpecMetadata ) -> bool :
1074+ # Skip rejection sampling when the whole batch is greedy: the accepted
1075+ # result is identical to argmax and the base path is cheaper. Mixed
1076+ # batches (context + gen) are handled via slot-indexed draft probs and
1077+ # are split inside _sample_and_accept_draft_tokens_rejection.
10361078 return (spec_metadata .use_rejection_sampling
1037- and spec_metadata .draft_probs_valid and num_contexts == 0
1079+ and spec_metadata .draft_probs_valid
10381080 and not spec_metadata .is_all_greedy_sample )
10391081
10401082 def _sample_and_accept_draft_tokens_rejection (
10411083 self ,
10421084 logits : torch .Tensor ,
10431085 draft_tokens : torch .Tensor ,
10441086 draft_probs : torch .Tensor ,
1087+ num_contexts : int ,
10451088 batch_size : int ,
10461089 spec_metadata ,
10471090 ):
10481091 """
10491092 Rejection-sampling acceptance for one-model speculative decoding.
1093+
1094+ Mixed batches are handled by treating the two subsets separately:
1095+ - context rows (first `num_contexts`) take the target's sampled first
1096+ token; no draft tokens to verify.
1097+ - generation rows (`[num_contexts:batch_size]`) run the rejection
1098+ sampling kernel on slot-gathered draft probs.
1099+
1100+ Per-token sampling-parameter tensors (`temperatures / top_ks / top_ps`)
1101+ are laid out as `[ctx (1 each), gen (draft_len+1 each)]`, matching the
1102+ logits layout, so slicing is symmetric for both subsets.
10501103 """
10511104 device = logits .device
10521105 vocab_size = logits .shape [- 1 ]
1106+ num_gens = batch_size - num_contexts
1107+ runtime_draft_len = draft_tokens .shape [1 ]
10531108
10541109 if logits .dim () == 1 :
10551110 logits = logits .unsqueeze (0 )
10561111
1057- runtime_draft_len = draft_tokens .shape [1 ]
1058- draft_vocab_size = draft_probs .shape [- 1 ]
1059- num_target_tokens = batch_size * (runtime_draft_len + 1 )
1060-
1061- temperatures = spec_metadata .temperatures [:num_target_tokens ]
1062- # Pass None instead of an all-disabled tensor so the C++ op can short-circuit
1063- # on a host-side check rather than a `.item<bool>()` sync, which would break
1064- # CUDA graph capture.
1065- top_ks = None if spec_metadata .skip_top_k else spec_metadata .top_ks [:
1066- num_target_tokens ]
1067- top_ps = None if spec_metadata .skip_top_p else spec_metadata .top_ps [:
1068- num_target_tokens ]
1069-
1070- target_probs_flat = compute_probs_from_logits (logits .clone (),
1071- temperatures , top_ks ,
1072- top_ps )
1073- target_probs = target_probs_flat .reshape (batch_size ,
1074- runtime_draft_len + 1 ,
1075- vocab_size )
1076-
1077- assert draft_probs .shape [1 ] == runtime_draft_len , (
1078- f"draft_probs draft length mismatch: { draft_probs .shape [1 ]} != "
1079- f"{ runtime_draft_len } " )
1080- d2t = getattr (spec_metadata , "d2t" , None )
1081- if draft_vocab_size != vocab_size :
1082- full_draft_probs = torch .zeros (
1083- (batch_size , runtime_draft_len , vocab_size ),
1084- dtype = torch .float32 ,
1085- device = device )
1086- if d2t is not None :
1087- assert d2t .numel () == draft_vocab_size , (
1088- f"d2t size mismatch: { d2t .numel ()} != { draft_vocab_size } " )
1089- d2t = d2t .to (device = device )
1090- source_indices = torch .arange (draft_vocab_size ,
1091- device = device ,
1092- dtype = torch .long )
1093- target_indices = (source_indices + d2t ) % vocab_size
1094- full_draft_probs [:, :runtime_draft_len ,
1095- target_indices ] = draft_probs
1112+ accepted_tokens = torch .empty ((batch_size , runtime_draft_len + 1 ),
1113+ dtype = torch .int ,
1114+ device = device )
1115+ num_accepted_tokens = torch .ones (batch_size ,
1116+ dtype = torch .int ,
1117+ device = device )
1118+
1119+ # === Context subset: sample target's first token directly ===
1120+ if num_contexts > 0 :
1121+ ctx_target_tokens = self ._sample_tokens_for_batch (
1122+ logits [:num_contexts ], spec_metadata , num_contexts ,
1123+ num_contexts )
1124+ accepted_tokens [:num_contexts , 0 ] = ctx_target_tokens
1125+
1126+ # === Generation subset: rejection sampling on the gen slice ===
1127+ if num_gens > 0 :
1128+ num_gen_logits = num_gens * (runtime_draft_len + 1 )
1129+ gen_logits = logits [num_contexts :num_contexts + num_gen_logits ]
1130+ gen_start = num_contexts
1131+ gen_end = num_contexts + num_gen_logits
1132+
1133+ temperatures = spec_metadata .temperatures [gen_start :gen_end ]
1134+ # Pass None instead of an all-disabled tensor so the C++ op can short-circuit
1135+ # on a host-side check rather than a `.item<bool>()` sync, which would break
1136+ # CUDA graph capture.
1137+ top_ks = (None if spec_metadata .skip_top_k else
1138+ spec_metadata .top_ks [gen_start :gen_end ])
1139+ top_ps = (None if spec_metadata .skip_top_p else
1140+ spec_metadata .top_ps [gen_start :gen_end ])
1141+
1142+ target_probs_flat = compute_probs_from_logits (
1143+ gen_logits .clone (), temperatures , top_ks , top_ps )
1144+ target_probs = target_probs_flat .reshape (num_gens ,
1145+ runtime_draft_len + 1 ,
1146+ vocab_size )
1147+
1148+ draft_vocab_size = draft_probs .shape [- 1 ]
1149+ assert draft_probs .shape [0 ] == num_gens , (
1150+ f"draft_probs batch mismatch: { draft_probs .shape [0 ]} != "
1151+ f"num_gens={ num_gens } " )
1152+ assert draft_probs .shape [1 ] == runtime_draft_len , (
1153+ f"draft_probs draft length mismatch: { draft_probs .shape [1 ]} != "
1154+ f"{ runtime_draft_len } " )
1155+ d2t = getattr (spec_metadata , "d2t" , None )
1156+ if draft_vocab_size != vocab_size :
1157+ full_draft_probs = torch .zeros (
1158+ (num_gens , runtime_draft_len , vocab_size ),
1159+ dtype = torch .float32 ,
1160+ device = device )
1161+ if d2t is not None :
1162+ assert d2t .numel () == draft_vocab_size , (
1163+ f"d2t size mismatch: { d2t .numel ()} != { draft_vocab_size } "
1164+ )
1165+ d2t = d2t .to (device = device )
1166+ source_indices = torch .arange (draft_vocab_size ,
1167+ device = device ,
1168+ dtype = torch .long )
1169+ target_indices = (source_indices + d2t ) % vocab_size
1170+ full_draft_probs [:, :runtime_draft_len ,
1171+ target_indices ] = draft_probs
1172+ else :
1173+ assert draft_vocab_size < vocab_size
1174+ full_draft_probs [:, :runtime_draft_len , :
1175+ draft_vocab_size ] = (draft_probs )
10961176 else :
1097- assert draft_vocab_size < vocab_size
1098- full_draft_probs [:, :runtime_draft_len , :draft_vocab_size ] = (
1099- draft_probs )
1100- else :
1101- full_draft_probs = draft_probs
1102-
1103- full_draft_tokens = draft_tokens .to (torch .int32 ).contiguous ()
1104-
1105- if self .seed is None :
1106- self .seed = torch .tensor ([0 ], dtype = torch .int64 , device = device )
1107- if self .offset is None :
1108- self .offset = torch .tensor ([0 ], dtype = torch .int64 , device = device )
1109- self .seed += 1
1110- self .seed %= 2 ** 31
1111-
1112- accepted_tokens , num_accepted_tokens = rejection_sampling_one_model (
1113- draft_probs = full_draft_probs ,
1114- draft_token_ids = full_draft_tokens ,
1115- target_probs = target_probs ,
1116- deterministic = True ,
1117- seed = self .seed ,
1118- offset = self .offset ,
1119- )
1177+ full_draft_probs = draft_probs
1178+
1179+ full_draft_tokens = draft_tokens .to (torch .int32 ).contiguous ()
1180+
1181+ if self .seed is None :
1182+ self .seed = torch .tensor ([0 ], dtype = torch .int64 , device = device )
1183+ if self .offset is None :
1184+ self .offset = torch .tensor ([0 ],
1185+ dtype = torch .int64 ,
1186+ device = device )
1187+ self .seed += 1
1188+ self .seed %= 2 ** 31
1189+
1190+ gen_accepted , gen_num_accepted = rejection_sampling_one_model (
1191+ draft_probs = full_draft_probs ,
1192+ draft_token_ids = full_draft_tokens ,
1193+ target_probs = target_probs ,
1194+ deterministic = True ,
1195+ seed = self .seed ,
1196+ offset = self .offset ,
1197+ )
1198+
1199+ accepted_tokens [num_contexts :] = gen_accepted
1200+ num_accepted_tokens [num_contexts :] = gen_num_accepted
11201201
11211202 num_accepted_tokens = self ._apply_force_accepted_tokens (
1122- num_accepted_tokens , 0 , draft_tokens . shape [ 1 ] )
1203+ num_accepted_tokens , num_contexts , runtime_draft_len )
11231204 return accepted_tokens , num_accepted_tokens
11241205
11251206 def _draft_sampler_greedy (self , logits : torch .Tensor , d2t = None ):
@@ -1204,7 +1285,10 @@ def _compute_and_store_draft_probs(
12041285 batch_size : int ,
12051286 ):
12061287 """
1207- Compute draft probabilities and store them for next-step rejection sampling.
1288+ Compute draft probabilities and store them for next-step rejection
1289+ sampling. The storage is keyed by py_seq_slot, so the data is robust
1290+ to batch composition shifts across iterations (chunking ctxs, gen
1291+ completion, new ctxs joining).
12081292 """
12091293 draft_tokens_per_request = len (draft_logits_list )
12101294 vocab_size = draft_logits_list [0 ].shape [- 1 ]
@@ -1233,9 +1317,21 @@ def _compute_and_store_draft_probs(
12331317 draft_probs_flat = compute_probs_from_logits (draft_logits_flat ,
12341318 draft_temps , draft_top_ks ,
12351319 draft_top_ps )
1236- num_elements = batch_size * draft_tokens_per_request * vocab_size
1237- spec_metadata .draft_probs [:num_elements ].copy_ (
1238- draft_probs_flat .flatten ())
1320+ # [batch_size, draft_len, draft_vocab]
1321+ draft_probs_per_request = draft_probs_flat .reshape (
1322+ batch_size , draft_tokens_per_request , vocab_size )
1323+
1324+ # Scatter into draft_probs[slot] for each request in the current batch.
1325+ # spec_metadata.draft_probs is shaped [max_num_requests, max_draft_len,
1326+ # vocab_size]. Different iterations may have different batch
1327+ # compositions, but a given request's data always lives at its
1328+ # py_seq_slot row, so reads at the next iter pick up the right data.
1329+ assert spec_metadata .batch_slot_ids is not None , (
1330+ "batch_slot_ids must be populated by "
1331+ "populate_sampling_params_for_one_model before draft probs storage" )
1332+ batch_slots = spec_metadata .batch_slot_ids [:batch_size ]
1333+ spec_metadata .draft_probs [batch_slots , :draft_tokens_per_request , :
1334+ vocab_size ] = draft_probs_per_request
12391335 spec_metadata .draft_probs_last_dim = vocab_size
12401336 spec_metadata .draft_probs_valid = True
12411337
0 commit comments