Skip to content

Commit ace86b8

Browse files
committed
update test
1 parent 00c9cfe commit ace86b8

3 files changed

Lines changed: 11 additions & 7 deletions

File tree

custom_ops/gpu_ops/speculate_decoding/ngram_match.cu

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -215,13 +215,14 @@ __global__ void ngram_match_gather_kernel(
215215
}
216216
}
217217

218-
// === Pad seq_lens_this_time to K+1 for cudagraph stability ===
219-
// Variable seq_lens_this_time (range [1, K+1]) clashes with cudagraph's
220-
// fixed launch params captured at warm-up time; downstream kernels read
221-
// past valid cu_seqlens / slot_mapping when replay sees a smaller slt,
222-
// leading to OOB / CUDA 700. When pad_to_max=true (cudagraph enabled),
223-
// pad missing positions with a placeholder so slt is fixed at K+1.
224-
// pad_to_max=false skips the padding cost when cudagraph is off.
218+
// === Pad seq_lens_this_time to num_speculative_tokens+1 for cudagraph
219+
// stability === Variable seq_lens_this_time (range [1,
220+
// num_speculative_tokens+1]) clashes with cudagraph's fixed launch params
221+
// captured at warm-up time; downstream kernels read past valid cu_seqlens /
222+
// slot_mapping when replay sees a smaller slt, leading to OOB / CUDA 700.
223+
// When pad_to_max=true (cudagraph enabled), pad missing positions with a
224+
// placeholder so slt is fixed at num_speculative_tokens+1. pad_to_max=false
225+
// skips the padding cost when cudagraph is off.
225226
if (pad_to_max) {
226227
int target_slt = max_draft_tokens_param + 1;
227228
if (actual < target_slt) {

tests/operators/test_ngram_match.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def test_basic_match(self):
6161
max_dec_len,
6262
3,
6363
4,
64+
False, # pad_to_max: match unchanged (no-pad) reference behavior
6465
)
6566

6667
# Extract non-zero tokens and assert the results.
@@ -100,6 +101,7 @@ def test_no_match(self):
100101
max_dec_len,
101102
3,
102103
3,
104+
False, # pad_to_max: match unchanged (no-pad) reference behavior
103105
)
104106

105107
# No match → should only keep 1 token

tests/spec_decode/test_benchmark_ngram_kernel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def _run_gpu(ngram_match_fn, gpu_data):
155155
gpu_data["max_dec_len"],
156156
MAX_NGRAM_SIZE,
157157
MAX_DRAFT_TOKENS,
158+
False, # pad_to_max: benchmark unrelated to cudagraph, measure no-pad cost
158159
)
159160

160161

0 commit comments

Comments
 (0)