Skip to content

Commit dad5a43

Browse files
authored
[Speculative Decoding]【Hackathon 10th Spring No.54】ngram 端到端验证 (#7774)
* refine ngram kernel signature and adapt ngram proposer logic * update old unittest * add e2e test
1 parent bda1756 commit dad5a43

10 files changed

Lines changed: 509 additions & 132 deletions

File tree

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -951,9 +951,7 @@ void SpeculateScheduleCache(const paddle::Tensor& draft_tokens,
951951
const int block_size,
952952
const int max_draft_tokens);
953953

954-
void NgramMatch(const paddle::Tensor& input_ids,
955-
const paddle::Tensor& input_ids_len,
956-
const paddle::Tensor& token_ids_all,
954+
void NgramMatch(const paddle::Tensor& token_ids_all,
957955
const paddle::Tensor& prompt_lens,
958956
const paddle::Tensor& step_idx,
959957
const paddle::Tensor& draft_token_num,

custom_ops/gpu_ops/speculate_decoding/ngram_match.cu

Lines changed: 15 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,7 @@
2727
// the tentative new seq_lens_this_time to a copy buffer.
2828
// Phase 2 will decide which ones to keep (threshold logic).
2929
// ============================================================
30-
__global__ void ngram_match_search_kernel(const int64_t *input_ids,
31-
const int64_t *input_ids_len,
32-
const int64_t *token_ids_all,
30+
__global__ void ngram_match_search_kernel(const int64_t *token_ids_all,
3331
const int64_t *prompt_lens,
3432
const int64_t *step_idx,
3533
const int *draft_token_num,
@@ -38,7 +36,6 @@ __global__ void ngram_match_search_kernel(const int64_t *input_ids,
3836
const int64_t *max_dec_len,
3937
int64_t *draft_tokens_copy,
4038
int32_t *seq_lens_this_time_copy,
41-
int64_t input_ids_stride,
4239
int64_t max_model_len,
4340
int64_t draft_tokens_stride,
4441
int64_t max_batch_size,
@@ -63,9 +60,9 @@ __global__ void ngram_match_search_kernel(const int64_t *input_ids,
6360
// Active decoder item: at least the base token.
6461
if (threadIdx.x == 0) seq_lens_this_time_copy[batch_idx] = 1;
6562

66-
const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
67-
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
6863
const int64_t prompt_len = prompt_lens[batch_idx];
64+
const int64_t *cur_input_ids = token_ids_all + batch_idx * max_model_len;
65+
const int64_t cur_input_ids_len = prompt_len;
6966
const int64_t *cur_pre_ids =
7067
token_ids_all + batch_idx * max_model_len + prompt_len;
7168
const int64_t cur_step_idx = step_idx[batch_idx];
@@ -79,7 +76,7 @@ __global__ void ngram_match_search_kernel(const int64_t *input_ids,
7976
for (int ngram_size = max_ngram_size; ngram_size >= 1; --ngram_size) {
8077
if (cur_step_idx < ngram_size) continue;
8178

82-
const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);
79+
const int64_t *ngram = cur_pre_ids + (cur_step_idx - ngram_size);
8380

8481
int64_t pos = parallel_ngram_search(
8582
cur_input_ids, cur_input_ids_len, ngram, ngram_size, &s_min_pos);
@@ -235,9 +232,7 @@ static int sum_cpu(const int *value, int num) {
235232
return sum_value;
236233
}
237234

238-
static void find_candidate_pred_tokens(const int64_t *input_ids,
239-
const int64_t *input_ids_len,
240-
const int64_t *token_ids_all,
235+
static void find_candidate_pred_tokens(const int64_t *token_ids_all,
241236
const int64_t *prompt_lens,
242237
const int64_t *step_idx,
243238
const int *draft_token_num,
@@ -246,7 +241,6 @@ static void find_candidate_pred_tokens(const int64_t *input_ids,
246241
int32_t *seq_lens_encoder,
247242
int32_t *seq_lens_decoder,
248243
int64_t *max_dec_len,
249-
int64_t input_ids_stride,
250244
int64_t max_model_len,
251245
int64_t draft_tokens_stride,
252246
int64_t max_batch_size,
@@ -274,12 +268,12 @@ static void find_candidate_pred_tokens(const int64_t *input_ids,
274268
continue;
275269
}
276270

277-
const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
271+
const int64_t *cur_input_ids = token_ids_all + batch_idx * max_model_len;
272+
const int64_t cur_input_ids_len = prompt_lens[batch_idx];
278273
int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride;
279274
const int64_t *cur_pre_ids =
280-
token_ids_all + batch_idx * max_model_len + prompt_lens[batch_idx];
275+
token_ids_all + batch_idx * max_model_len + cur_input_ids_len;
281276
const int64_t cur_step_idx = step_idx[batch_idx];
282-
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
283277
seq_lens_this_time[batch_idx] = 1;
284278
unprocessed_batch_size--;
285279

@@ -301,7 +295,7 @@ static void find_candidate_pred_tokens(const int64_t *input_ids,
301295
if (cur_step_idx < ngram_size) {
302296
continue;
303297
}
304-
const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);
298+
const int64_t *ngram = cur_pre_ids + (cur_step_idx - ngram_size);
305299

306300
bool match_input = false;
307301
for (int64_t i = 0; i <= cur_input_ids_len - ngram_size; ++i) {
@@ -370,9 +364,7 @@ static void find_candidate_pred_tokens(const int64_t *input_ids,
370364
// bsz × NGRAM_BLOCK_THREADS threads. Phase 2 is O(bsz) with scans.
371365
// ============================================================
372366

373-
void NgramMatch(const paddle::Tensor &input_ids,
374-
const paddle::Tensor &input_ids_len,
375-
const paddle::Tensor &token_ids_all,
367+
void NgramMatch(const paddle::Tensor &token_ids_all,
376368
const paddle::Tensor &prompt_lens,
377369
const paddle::Tensor &step_idx,
378370
const paddle::Tensor &draft_token_num,
@@ -383,9 +375,6 @@ void NgramMatch(const paddle::Tensor &input_ids,
383375
const paddle::Tensor &max_dec_len,
384376
const int max_ngram_size,
385377
const int max_draft_tokens) {
386-
auto input_ids_shape = input_ids.shape();
387-
const int64_t input_ids_stride = input_ids_shape[1];
388-
389378
const int64_t max_model_len = token_ids_all.shape()[1];
390379

391380
auto draft_tokens_shape = draft_tokens.shape();
@@ -399,8 +388,8 @@ void NgramMatch(const paddle::Tensor &input_ids,
399388
threshold = std::stoi(env_var);
400389
}
401390

402-
if (input_ids.is_gpu()) {
403-
auto stream = input_ids.stream();
391+
if (token_ids_all.is_gpu()) {
392+
auto stream = token_ids_all.stream();
404393

405394
// Persistent scratch buffers for Phase 1 → Phase 2 communication.
406395
// Cached across calls to avoid per-invocation allocation overhead.
@@ -416,9 +405,9 @@ void NgramMatch(const paddle::Tensor &input_ids,
416405
draft_tokens_stride > s_scratch_stride) {
417406
s_draft_copy = paddle::empty({max_batch_size, draft_tokens_stride},
418407
paddle::DataType::INT64,
419-
input_ids.place());
408+
token_ids_all.place());
420409
s_seqlens_copy = paddle::empty(
421-
{max_batch_size}, paddle::DataType::INT32, input_ids.place());
410+
{max_batch_size}, paddle::DataType::INT32, token_ids_all.place());
422411
s_scratch_batch = max_batch_size;
423412
s_scratch_stride = draft_tokens_stride;
424413
}
@@ -435,8 +424,6 @@ void NgramMatch(const paddle::Tensor &input_ids,
435424
NGRAM_BLOCK_THREADS,
436425
0,
437426
stream>>>(
438-
input_ids.data<int64_t>(),
439-
input_ids_len.data<int64_t>(),
440427
token_ids_all.data<int64_t>(),
441428
prompt_lens.data<int64_t>(),
442429
step_idx.data<int64_t>(),
@@ -446,7 +433,6 @@ void NgramMatch(const paddle::Tensor &input_ids,
446433
max_dec_len.data<int64_t>(),
447434
draft_tokens_copy.data<int64_t>(),
448435
seq_lens_this_time_copy.data<int32_t>(),
449-
input_ids_stride,
450436
max_model_len,
451437
draft_tokens_stride,
452438
max_batch_size,
@@ -465,8 +451,6 @@ void NgramMatch(const paddle::Tensor &input_ids,
465451
threshold);
466452
} else {
467453
find_candidate_pred_tokens(
468-
input_ids.data<int64_t>(),
469-
input_ids_len.data<int64_t>(),
470454
token_ids_all.data<int64_t>(),
471455
prompt_lens.data<int64_t>(),
472456
step_idx.data<int64_t>(),
@@ -476,7 +460,6 @@ void NgramMatch(const paddle::Tensor &input_ids,
476460
const_cast<int32_t *>(seq_lens_encoder.data<int32_t>()),
477461
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
478462
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
479-
input_ids_stride,
480463
max_model_len,
481464
draft_tokens_stride,
482465
max_batch_size,
@@ -486,9 +469,7 @@ void NgramMatch(const paddle::Tensor &input_ids,
486469
}
487470

488471
PD_BUILD_STATIC_OP(ngram_match)
489-
.Inputs({"input_ids",
490-
"input_ids_len",
491-
"token_ids_all",
472+
.Inputs({"token_ids_all",
492473
"prompt_lens",
493474
"step_idx",
494475
"draft_token_num",

fastdeploy/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1964,6 +1964,7 @@ def __init__(
19641964
in [
19651965
SpecMethod.MTP,
19661966
SpecMethod.SUFFIX,
1967+
SpecMethod.NGRAM,
19671968
]
19681969
)
19691970
else 0

fastdeploy/spec_decode/ngram.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616

1717
from typing import TYPE_CHECKING
1818

19-
import paddle
20-
2119
from fastdeploy.model_executor.ops.gpu import ngram_match
2220

2321
from .base import Proposer
@@ -36,23 +34,12 @@ class NgramProposer(Proposer):
3634
def __init__(self, fd_config: "FDConfig"):
3735
super().__init__(fd_config)
3836
self.max_ngram_size = self.speculative_config.max_ngram_size
39-
self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()
40-
self.input_ids_len_gpu = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cuda()
41-
42-
def update(self, bid: int, seq_len: int):
43-
"""
44-
update
45-
"""
46-
self.input_ids_len[bid] = seq_len
47-
self.input_ids_len_gpu[bid] = seq_len
4837

4938
def _run_impl(self, share_inputs):
5039
"""
5140
run
5241
"""
5342
ngram_match(
54-
share_inputs["input_ids_cpu"].cuda(),
55-
self.input_ids_len_gpu,
5643
share_inputs["token_ids_all"],
5744
share_inputs["prompt_lens"],
5845
share_inputs["step_idx"],

fastdeploy/worker/gpu_model_runner.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2094,7 +2094,11 @@ def capture_model(self) -> None:
20942094
logger.info(
20952095
f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}"
20962096
)
2097-
elif self.speculative_decoding and self.spec_method in [SpecMethod.MTP, SpecMethod.SUFFIX]:
2097+
elif self.speculative_decoding and self.spec_method in [
2098+
SpecMethod.MTP,
2099+
SpecMethod.SUFFIX,
2100+
SpecMethod.NGRAM,
2101+
]:
20982102
for capture_size in sorted(capture_sizes, reverse=True):
20992103
expected_decode_len = (self.speculative_config.num_speculative_tokens + 1) * 2
21002104
self._dummy_run(

0 commit comments

Comments
 (0)