@@ -138,7 +138,9 @@ __global__ void ngram_match_gather_kernel(
138138 int32_t *seq_lens_this_time,
139139 int64_t draft_tokens_stride,
140140 int64_t max_batch_size,
141- int threshold) {
141+ int threshold,
142+ int max_draft_tokens_param,
143+ bool pad_to_max) {
142144 typedef cub::BlockScan<int , NGRAM_GATHER_THREADS> BlockScanInt;
143145 __shared__ typename BlockScanInt::TempStorage temp_storage1;
144146 __shared__ typename BlockScanInt::TempStorage temp_storage2;
@@ -203,16 +205,39 @@ __global__ void ngram_match_gather_kernel(
203205 actual = min (tentative, budget);
204206 }
205207
206- seq_lens_this_time[tid] = actual;
207-
208- // Copy draft tokens (slots 1..actual-1) from scratch to output
208+ // Copy draft tokens (slots 1..actual-1) from scratch to output FIRST
209+ // (so subsequent padding doesn't overwrite real ngram hits)
209210 if (actual > 1 ) {
210211 int64_t *dst = draft_tokens + tid * draft_tokens_stride;
211212 const int64_t *src = draft_tokens_copy + tid * draft_tokens_stride;
212213 for (int k = 1 ; k < actual; k++) {
213214 dst[k] = src[k];
214215 }
215216 }
217+
218+ // === Pad seq_lens_this_time to K+1 for cudagraph stability ===
219+ // Variable seq_lens_this_time (range [1, K+1]) clashes with cudagraph's
220+ // fixed launch params captured at warm-up time; downstream kernels read
221+ // past valid cu_seqlens / slot_mapping when replay sees a smaller slt,
222+ // leading to OOB / CUDA 700. When pad_to_max=true (cudagraph enabled),
223+ // pad missing positions with a placeholder so slt is fixed at K+1.
224+ // pad_to_max=false skips the padding cost when cudagraph is off.
225+ if (pad_to_max) {
226+ int target_slt = max_draft_tokens_param + 1 ;
227+ if (actual < target_slt) {
228+ int64_t *dst = draft_tokens + tid * draft_tokens_stride;
229+ // Reuse the last valid draft token as placeholder. It is a token the
230+ // model could plausibly have produced, so attention math stays
231+ // well-defined; rejection happens at the sampler level.
232+ int64_t pad_token = (actual > 0 ) ? dst[actual - 1 ] : 0 ;
233+ for (int k = actual; k < target_slt; k++) {
234+ dst[k] = pad_token;
235+ }
236+ actual = target_slt;
237+ }
238+ }
239+
240+ seq_lens_this_time[tid] = actual;
216241 }
217242}
218243
@@ -374,7 +399,8 @@ void NgramMatch(const paddle::Tensor &token_ids_all,
374399 const paddle::Tensor &seq_lens_decoder,
375400 const paddle::Tensor &max_dec_len,
376401 const int max_ngram_size,
377- const int max_draft_tokens) {
402+ const int max_draft_tokens,
403+ const bool pad_to_max) {
378404 const int64_t max_model_len = token_ids_all.shape ()[1 ];
379405
380406 auto draft_tokens_shape = draft_tokens.shape ();
@@ -448,7 +474,9 @@ void NgramMatch(const paddle::Tensor &token_ids_all,
448474 const_cast <int32_t *>(seq_lens_this_time.data <int32_t >()),
449475 draft_tokens_stride,
450476 max_batch_size,
451- threshold);
477+ threshold,
478+ max_draft_tokens,
479+ pad_to_max);
452480 } else {
453481 find_candidate_pred_tokens (
454482 token_ids_all.data <int64_t >(),
@@ -478,7 +506,7 @@ PD_BUILD_STATIC_OP(ngram_match)
478506 " seq_lens_encoder" ,
479507 " seq_lens_decoder" ,
480508 " max_dec_len" })
481- .Attrs({" max_ngram_size: int" , " max_draft_tokens: int" })
509+ .Attrs({" max_ngram_size: int" , " max_draft_tokens: int" , " pad_to_max: bool " })
482510 .Outputs({" draft_tokens_out" , " seq_lens_this_time_out" })
483511 .SetKernelFn(PD_KERNEL(NgramMatch))
484512 .SetInplaceMap({{" draft_tokens" , " draft_tokens_out" },
0 commit comments