2727// the tentative new seq_lens_this_time to a copy buffer.
2828// Phase 2 will decide which ones to keep (threshold logic).
2929// ============================================================
30- __global__ void ngram_match_search_kernel (const int64_t *input_ids,
31- const int64_t *input_ids_len,
32- const int64_t *token_ids_all,
30+ __global__ void ngram_match_search_kernel (const int64_t *token_ids_all,
3331 const int64_t *prompt_lens,
3432 const int64_t *step_idx,
3533 const int *draft_token_num,
@@ -38,7 +36,6 @@ __global__ void ngram_match_search_kernel(const int64_t *input_ids,
3836 const int64_t *max_dec_len,
3937 int64_t *draft_tokens_copy,
4038 int32_t *seq_lens_this_time_copy,
41- int64_t input_ids_stride,
4239 int64_t max_model_len,
4340 int64_t draft_tokens_stride,
4441 int64_t max_batch_size,
@@ -63,9 +60,9 @@ __global__ void ngram_match_search_kernel(const int64_t *input_ids,
6360 // Active decoder item: at least the base token.
6461 if (threadIdx .x == 0 ) seq_lens_this_time_copy[batch_idx] = 1 ;
6562
66- const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
67- const int64_t cur_input_ids_len = input_ids_len[batch_idx];
6863 const int64_t prompt_len = prompt_lens[batch_idx];
64+ const int64_t *cur_input_ids = token_ids_all + batch_idx * max_model_len;
65+ const int64_t cur_input_ids_len = prompt_len;
6966 const int64_t *cur_pre_ids =
7067 token_ids_all + batch_idx * max_model_len + prompt_len;
7168 const int64_t cur_step_idx = step_idx[batch_idx];
@@ -79,7 +76,7 @@ __global__ void ngram_match_search_kernel(const int64_t *input_ids,
7976 for (int ngram_size = max_ngram_size; ngram_size >= 1 ; --ngram_size) {
8077 if (cur_step_idx < ngram_size) continue ;
8178
82- const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);
79+ const int64_t *ngram = cur_pre_ids + (cur_step_idx - ngram_size);
8380
8481 int64_t pos = parallel_ngram_search (
8582 cur_input_ids, cur_input_ids_len, ngram, ngram_size, &s_min_pos);
@@ -235,9 +232,7 @@ static int sum_cpu(const int *value, int num) {
235232 return sum_value;
236233}
237234
238- static void find_candidate_pred_tokens (const int64_t *input_ids,
239- const int64_t *input_ids_len,
240- const int64_t *token_ids_all,
235+ static void find_candidate_pred_tokens (const int64_t *token_ids_all,
241236 const int64_t *prompt_lens,
242237 const int64_t *step_idx,
243238 const int *draft_token_num,
@@ -246,7 +241,6 @@ static void find_candidate_pred_tokens(const int64_t *input_ids,
246241 int32_t *seq_lens_encoder,
247242 int32_t *seq_lens_decoder,
248243 int64_t *max_dec_len,
249- int64_t input_ids_stride,
250244 int64_t max_model_len,
251245 int64_t draft_tokens_stride,
252246 int64_t max_batch_size,
@@ -274,12 +268,12 @@ static void find_candidate_pred_tokens(const int64_t *input_ids,
274268 continue ;
275269 }
276270
277- const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
271+ const int64_t *cur_input_ids = token_ids_all + batch_idx * max_model_len;
272+ const int64_t cur_input_ids_len = prompt_lens[batch_idx];
278273 int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride;
279274 const int64_t *cur_pre_ids =
280- token_ids_all + batch_idx * max_model_len + prompt_lens[batch_idx] ;
275+ token_ids_all + batch_idx * max_model_len + cur_input_ids_len ;
281276 const int64_t cur_step_idx = step_idx[batch_idx];
282- const int64_t cur_input_ids_len = input_ids_len[batch_idx];
283277 seq_lens_this_time[batch_idx] = 1 ;
284278 unprocessed_batch_size--;
285279
@@ -301,7 +295,7 @@ static void find_candidate_pred_tokens(const int64_t *input_ids,
301295 if (cur_step_idx < ngram_size) {
302296 continue ;
303297 }
304- const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);
298+ const int64_t *ngram = cur_pre_ids + (cur_step_idx - ngram_size);
305299
306300 bool match_input = false ;
307301 for (int64_t i = 0 ; i <= cur_input_ids_len - ngram_size; ++i) {
@@ -370,9 +364,7 @@ static void find_candidate_pred_tokens(const int64_t *input_ids,
370364// bsz × NGRAM_BLOCK_THREADS threads. Phase 2 is O(bsz) with scans.
371365// ============================================================
372366
373- void NgramMatch (const paddle::Tensor &input_ids,
374- const paddle::Tensor &input_ids_len,
375- const paddle::Tensor &token_ids_all,
367+ void NgramMatch (const paddle::Tensor &token_ids_all,
376368 const paddle::Tensor &prompt_lens,
377369 const paddle::Tensor &step_idx,
378370 const paddle::Tensor &draft_token_num,
@@ -383,9 +375,6 @@ void NgramMatch(const paddle::Tensor &input_ids,
383375 const paddle::Tensor &max_dec_len,
384376 const int max_ngram_size,
385377 const int max_draft_tokens) {
386- auto input_ids_shape = input_ids.shape ();
387- const int64_t input_ids_stride = input_ids_shape[1 ];
388-
389378 const int64_t max_model_len = token_ids_all.shape ()[1 ];
390379
391380 auto draft_tokens_shape = draft_tokens.shape ();
@@ -399,8 +388,8 @@ void NgramMatch(const paddle::Tensor &input_ids,
399388 threshold = std::stoi (env_var);
400389 }
401390
402- if (input_ids .is_gpu ()) {
403- auto stream = input_ids .stream ();
391+ if (token_ids_all .is_gpu ()) {
392+ auto stream = token_ids_all .stream ();
404393
405394 // Persistent scratch buffers for Phase 1 → Phase 2 communication.
406395 // Cached across calls to avoid per-invocation allocation overhead.
@@ -416,9 +405,9 @@ void NgramMatch(const paddle::Tensor &input_ids,
416405 draft_tokens_stride > s_scratch_stride) {
417406 s_draft_copy = paddle::empty ({max_batch_size, draft_tokens_stride},
418407 paddle::DataType::INT64,
419- input_ids .place ());
408+ token_ids_all .place ());
420409 s_seqlens_copy = paddle::empty (
421- {max_batch_size}, paddle::DataType::INT32, input_ids .place ());
410+ {max_batch_size}, paddle::DataType::INT32, token_ids_all .place ());
422411 s_scratch_batch = max_batch_size;
423412 s_scratch_stride = draft_tokens_stride;
424413 }
@@ -435,8 +424,6 @@ void NgramMatch(const paddle::Tensor &input_ids,
435424 NGRAM_BLOCK_THREADS,
436425 0 ,
437426 stream>>> (
438- input_ids.data <int64_t >(),
439- input_ids_len.data <int64_t >(),
440427 token_ids_all.data <int64_t >(),
441428 prompt_lens.data <int64_t >(),
442429 step_idx.data <int64_t >(),
@@ -446,7 +433,6 @@ void NgramMatch(const paddle::Tensor &input_ids,
446433 max_dec_len.data <int64_t >(),
447434 draft_tokens_copy.data <int64_t >(),
448435 seq_lens_this_time_copy.data <int32_t >(),
449- input_ids_stride,
450436 max_model_len,
451437 draft_tokens_stride,
452438 max_batch_size,
@@ -465,8 +451,6 @@ void NgramMatch(const paddle::Tensor &input_ids,
465451 threshold);
466452 } else {
467453 find_candidate_pred_tokens (
468- input_ids.data <int64_t >(),
469- input_ids_len.data <int64_t >(),
470454 token_ids_all.data <int64_t >(),
471455 prompt_lens.data <int64_t >(),
472456 step_idx.data <int64_t >(),
@@ -476,7 +460,6 @@ void NgramMatch(const paddle::Tensor &input_ids,
476460 const_cast <int32_t *>(seq_lens_encoder.data <int32_t >()),
477461 const_cast <int32_t *>(seq_lens_decoder.data <int32_t >()),
478462 const_cast <int64_t *>(max_dec_len.data <int64_t >()),
479- input_ids_stride,
480463 max_model_len,
481464 draft_tokens_stride,
482465 max_batch_size,
@@ -486,9 +469,7 @@ void NgramMatch(const paddle::Tensor &input_ids,
486469}
487470
488471PD_BUILD_STATIC_OP (ngram_match)
489- .Inputs({" input_ids" ,
490- " input_ids_len" ,
491- " token_ids_all" ,
472+ .Inputs({" token_ids_all" ,
492473 " prompt_lens" ,
493474 " step_idx" ,
494475 " draft_token_num" ,
0 commit comments