Skip to content

Commit 00c9cfe

Browse files
committed
update ngram kernel with the same cudagraph adapting logic
1 parent 10804fd commit 00c9cfe

7 files changed

Lines changed: 116 additions & 22 deletions

File tree

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -960,7 +960,8 @@ void NgramMatch(const paddle::Tensor& token_ids_all,
960960
const paddle::Tensor& seq_lens_decoder,
961961
const paddle::Tensor& max_dec_len,
962962
const int max_ngram_size,
963-
const int max_draft_tokens);
963+
const int max_draft_tokens,
964+
const bool pad_to_max);
964965

965966
void HybridMtpNgram(const paddle::Tensor& token_ids_all,
966967
const paddle::Tensor& prompt_lens,

custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -216,20 +216,21 @@ __global__ void ngram_match_mixed_gather_kernel(
216216
}
217217
}
218218

219-
// === Pad seq_lens_this_time to K+1 for cudagraph stability ===
220-
// Hybrid MTP-ngram produces variable seq_lens_this_time depending on how
221-
// many ngram positions hit (range: [num_model_steps+1, K+1]). cudagraph
222-
// captures launch params (grid dim, kernel args) at capture time; if the
223-
// captured slt differs from replay-time slt, downstream kernels read past
224-
// valid ranges of cu_seqlens / slot_mapping etc., causing CUDA 700.
219+
// === Pad seq_lens_this_time to num_speculative_tokens+1 for cudagraph
220+
// stability === Hybrid MTP-ngram produces variable seq_lens_this_time
221+
// depending on how many ngram positions hit (range: [num_model_steps+1,
222+
// num_speculative_tokens+1]). cudagraph captures launch params (grid dim,
223+
// kernel args) at capture time; if the captured slt differs from
224+
// replay-time slt, downstream kernels read past valid ranges of cu_seqlens
225+
// / slot_mapping etc., causing CUDA 700.
225226
//
226-
// When pad_to_max=true (cudagraph enabled), force slt = K+1 =
227-
// max_draft_tokens + 1: positions beyond actual ngram hits get padded
228-
// with a placeholder token. The target model will verify these
229-
// placeholders and (almost always) reject them, but the verify cost is
230-
// fixed per iteration => grid dim is now invariant. When pad_to_max=
231-
// false (cudagraph disabled), keep the natural variable slt to avoid
232-
// wasting verify compute on placeholders.
227+
// When pad_to_max=true (cudagraph enabled), force slt =
228+
// num_speculative_tokens+1 = max_draft_tokens + 1: positions beyond actual
229+
// ngram hits get padded with a placeholder token. The target model will
230+
// verify these placeholders and (almost always) reject them, but the verify
231+
// cost is fixed per iteration => grid dim is now invariant. When
232+
// pad_to_max= false (cudagraph disabled), keep the natural variable slt to
233+
// avoid wasting verify compute on placeholders.
233234
if (pad_to_max) {
234235
int target_slt = max_draft_tokens_param + 1;
235236
if (actual < target_slt) {

custom_ops/gpu_ops/speculate_decoding/ngram_match.cu

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,9 @@ __global__ void ngram_match_gather_kernel(
138138
int32_t *seq_lens_this_time,
139139
int64_t draft_tokens_stride,
140140
int64_t max_batch_size,
141-
int threshold) {
141+
int threshold,
142+
int max_draft_tokens_param,
143+
bool pad_to_max) {
142144
typedef cub::BlockScan<int, NGRAM_GATHER_THREADS> BlockScanInt;
143145
__shared__ typename BlockScanInt::TempStorage temp_storage1;
144146
__shared__ typename BlockScanInt::TempStorage temp_storage2;
@@ -203,16 +205,39 @@ __global__ void ngram_match_gather_kernel(
203205
actual = min(tentative, budget);
204206
}
205207

206-
seq_lens_this_time[tid] = actual;
207-
208-
// Copy draft tokens (slots 1..actual-1) from scratch to output
208+
// Copy draft tokens (slots 1..actual-1) from scratch to output FIRST
209+
// (so subsequent padding doesn't overwrite real ngram hits)
209210
if (actual > 1) {
210211
int64_t *dst = draft_tokens + tid * draft_tokens_stride;
211212
const int64_t *src = draft_tokens_copy + tid * draft_tokens_stride;
212213
for (int k = 1; k < actual; k++) {
213214
dst[k] = src[k];
214215
}
215216
}
217+
218+
// === Pad seq_lens_this_time to K+1 for cudagraph stability ===
219+
// Variable seq_lens_this_time (range [1, K+1]) clashes with cudagraph's
220+
// fixed launch params captured at warm-up time; downstream kernels read
221+
// past valid cu_seqlens / slot_mapping when replay sees a smaller slt,
222+
// leading to OOB / CUDA 700. When pad_to_max=true (cudagraph enabled),
223+
// pad missing positions with a placeholder so slt is fixed at K+1.
224+
// pad_to_max=false skips the padding cost when cudagraph is off.
225+
if (pad_to_max) {
226+
int target_slt = max_draft_tokens_param + 1;
227+
if (actual < target_slt) {
228+
int64_t *dst = draft_tokens + tid * draft_tokens_stride;
229+
// Reuse the last valid draft token as placeholder. It is a token the
230+
// model could plausibly have produced, so attention math stays
231+
// well-defined; rejection happens at the sampler level.
232+
int64_t pad_token = (actual > 0) ? dst[actual - 1] : 0;
233+
for (int k = actual; k < target_slt; k++) {
234+
dst[k] = pad_token;
235+
}
236+
actual = target_slt;
237+
}
238+
}
239+
240+
seq_lens_this_time[tid] = actual;
216241
}
217242
}
218243

@@ -374,7 +399,8 @@ void NgramMatch(const paddle::Tensor &token_ids_all,
374399
const paddle::Tensor &seq_lens_decoder,
375400
const paddle::Tensor &max_dec_len,
376401
const int max_ngram_size,
377-
const int max_draft_tokens) {
402+
const int max_draft_tokens,
403+
const bool pad_to_max) {
378404
const int64_t max_model_len = token_ids_all.shape()[1];
379405

380406
auto draft_tokens_shape = draft_tokens.shape();
@@ -448,7 +474,9 @@ void NgramMatch(const paddle::Tensor &token_ids_all,
448474
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
449475
draft_tokens_stride,
450476
max_batch_size,
451-
threshold);
477+
threshold,
478+
max_draft_tokens,
479+
pad_to_max);
452480
} else {
453481
find_candidate_pred_tokens(
454482
token_ids_all.data<int64_t>(),
@@ -478,7 +506,7 @@ PD_BUILD_STATIC_OP(ngram_match)
478506
"seq_lens_encoder",
479507
"seq_lens_decoder",
480508
"max_dec_len"})
481-
.Attrs({"max_ngram_size: int", "max_draft_tokens: int"})
509+
.Attrs({"max_ngram_size: int", "max_draft_tokens: int", "pad_to_max: bool"})
482510
.Outputs({"draft_tokens_out", "seq_lens_this_time_out"})
483511
.SetKernelFn(PD_KERNEL(NgramMatch))
484512
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},

fastdeploy/spec_decode/mtp_cuda.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ def _update_status(self):
394394

395395
def _extend_draft_token_with_ngram_match(self):
396396
# pad_to_max forces hybrid kernel to write a fixed seq_lens_this_time
397-
# = K + 1, padding unfilled ngram slots with a placeholder draft token.
397+
# = num_speculative_tokens + 1, padding unfilled ngram slots with a placeholder draft token.
398398
# Required when target cudagraph is enabled (capture-time seq_lens_this_time
399399
# must match replay-time seq_lens_this_time).
400400
hybrid_mtp_ngram(

fastdeploy/spec_decode/ngram.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ def _run_impl(self, share_inputs):
3939
"""
4040
run
4141
"""
42+
# pad_to_max forces the kernel to write a fixed seq_lens_this_time =
43+
# num_speculative_tokens + 1, padding unfilled draft slots with a placeholder token.
44+
# Required when target cudagraph is enabled (capture-time slt must
45+
# match replay-time slt; see ngram_match.cu for details). Disabled
46+
# when cudagraph is off to avoid wasted verify on placeholders.
4247
ngram_match(
4348
share_inputs["token_ids_all"],
4449
share_inputs["prompt_lens"],
@@ -51,4 +56,5 @@ def _run_impl(self, share_inputs):
5156
share_inputs["max_dec_len"],
5257
self.max_ngram_size,
5358
self.max_draft_token_num,
59+
self.graph_opt_config.use_cudagraph,
5460
)

tests/operators/test_hybrid_mtp_ngram.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def setUp(self):
7575
self.ref_draft_tokens = np.array([[8, 7, 6, 10, 9, 8], [8, 7, 6, 10, 9, 8]], dtype="int64")
7676

7777
def test_ngram_match_mixed(self):
78+
"""pad_to_max=False: GPU output matches the CPU reference baseline."""
7879
hybrid_mtp_ngram(
7980
self.token_ids_all,
8081
self.prompt_lens,
@@ -94,6 +95,50 @@ def test_ngram_match_mixed(self):
9495
np.testing.assert_allclose(self.seq_lens_this_time.numpy(), self.ref_seq_lens_this_time)
9596
np.testing.assert_allclose(self.draft_tokens.numpy(), self.ref_draft_tokens)
9697

98+
def test_ngram_match_mixed_pad_to_max(self):
99+
"""pad_to_max=True: slt is forced to K+1 and unfilled draft slots are
100+
padded with the last valid draft token (placeholder for cudagraph
101+
stability).
102+
103+
To exercise the pad path we drive step_idx below min_ngram_size so
104+
the search kernel finds no ngram match. Without pad, slt stays at
105+
ori_seq_len_this_time=2; with pad, slt becomes max_draft_tokens+1=6
106+
and draft_tokens[2:6] are filled with draft_tokens[1] (=7).
107+
"""
108+
# No ngram match path: step_idx < min_ngram_size short-circuits search.
109+
self.step_idx[:] = self.min_ngram_size - 1
110+
111+
hybrid_mtp_ngram(
112+
self.token_ids_all,
113+
self.prompt_lens,
114+
self.pre_ids,
115+
self.step_idx,
116+
self.draft_token_num,
117+
self.draft_tokens,
118+
self.seq_lens_this_time,
119+
self.seq_lens_decoder,
120+
self.max_dec_len,
121+
self.max_ngram_size,
122+
self.min_ngram_size,
123+
self.max_draft_tokens,
124+
True, # pad_to_max
125+
)
126+
127+
target_slt = self.max_draft_tokens + 1 # K+1 = 6
128+
slt = self.seq_lens_this_time.numpy()
129+
assert (slt == target_slt).all(), f"expected all slt == {target_slt}, got {slt.flatten().tolist()}"
130+
131+
# ori_seq_len_this_time was 2; positions [2..6) should be padded with
132+
# draft_tokens[1] (= 7, the last valid draft token before padding).
133+
drafts = self.draft_tokens.numpy()
134+
expected_placeholder = 7
135+
for b in range(self.max_bsz):
136+
np.testing.assert_array_equal(
137+
drafts[b, 2:target_slt],
138+
np.full(target_slt - 2, expected_placeholder, dtype="int64"),
139+
err_msg=f"batch {b}: padded slots [2:{target_slt}) should equal placeholder {expected_placeholder}",
140+
)
141+
97142

98143
if __name__ == "__main__":
99144
unittest.main()

tests/spec_decode/test_ngram_gpu_kernel.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ def test_correctness_basic(self):
365365
gpu_data["max_dec_len"],
366366
max_ngram_size,
367367
max_draft_tokens,
368+
False,
368369
)
369370
paddle.device.synchronize()
370371

@@ -395,6 +396,7 @@ def test_correctness_varied_seeds(self):
395396
data["max_dec_len"],
396397
3,
397398
10,
399+
False,
398400
)
399401
gpu_data = _to_gpu(data)
400402
self.ngram_match(
@@ -409,6 +411,7 @@ def test_correctness_varied_seeds(self):
409411
gpu_data["max_dec_len"],
410412
3,
411413
10,
414+
False,
412415
)
413416
paddle.device.synchronize()
414417
np.testing.assert_array_equal(gpu_data["seq_lens_this_time"].numpy(), cpu_slt)
@@ -456,6 +459,7 @@ def test_large_batch_long_seq(self):
456459
gpu_data["max_dec_len"],
457460
3,
458461
10,
462+
False,
459463
)
460464
paddle.device.synchronize()
461465
finally:
@@ -485,6 +489,7 @@ def test_single_batch_long_seq(self):
485489
data["max_dec_len"],
486490
3,
487491
10,
492+
False,
488493
)
489494
gpu_data = _to_gpu(data)
490495
self.ngram_match(
@@ -499,6 +504,7 @@ def test_single_batch_long_seq(self):
499504
gpu_data["max_dec_len"],
500505
3,
501506
10,
507+
False,
502508
)
503509
paddle.device.synchronize()
504510
np.testing.assert_array_equal(gpu_data["seq_lens_this_time"].numpy(), cpu_slt)
@@ -542,6 +548,7 @@ def test_many_short_seqs(self):
542548
gpu_data["max_dec_len"],
543549
3,
544550
10,
551+
False,
545552
)
546553
paddle.device.synchronize()
547554
finally:
@@ -569,6 +576,7 @@ def test_latency(self):
569576
d["max_dec_len"],
570577
3,
571578
10,
579+
False,
572580
)
573581
paddle.device.synchronize()
574582

@@ -591,6 +599,7 @@ def test_latency(self):
591599
gpu_data["max_dec_len"],
592600
3,
593601
10,
602+
False,
594603
)
595604
paddle.device.synchronize()
596605
t1 = time.perf_counter()
@@ -641,6 +650,7 @@ def test_latency_scaling(self):
641650
gpu_data["max_dec_len"],
642651
3,
643652
10,
653+
False,
644654
)
645655
paddle.device.synchronize()
646656

@@ -660,6 +670,7 @@ def test_latency_scaling(self):
660670
gpu_data["max_dec_len"],
661671
3,
662672
10,
673+
False,
663674
)
664675
paddle.device.synchronize()
665676
gpu_ms = (time.perf_counter() - t0) / n_runs * 1000
@@ -742,6 +753,7 @@ def test_latency_extreme(self):
742753
gpu_data["max_dec_len"],
743754
3,
744755
10,
756+
False,
745757
)
746758
paddle.device.synchronize()
747759

@@ -761,6 +773,7 @@ def test_latency_extreme(self):
761773
gpu_data["max_dec_len"],
762774
3,
763775
10,
776+
False,
764777
)
765778
paddle.device.synchronize()
766779
t1 = time.perf_counter()

0 commit comments

Comments
 (0)