Skip to content

Commit c349b12

Browse files
author
cloudforge1
committed
Fix test data: step_idx semantics and ngram-matchable patterns
1 parent 477f749 commit c349b12

1 file changed

Lines changed: 22 additions & 10 deletions

File tree

tests/spec_decode/test_ngram_gpu_kernel.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ def _cpu_ngram_match(
5050
threshold=128,
5151
):
5252
"""Pure NumPy reference matching the original ngram_match.cc logic."""
53+
# Flatten (N,1) shaped arrays to 1D for scalar indexing
54+
max_dec_len = max_dec_len.ravel()
55+
step_idx = step_idx.ravel()
56+
draft_token_num = draft_token_num.ravel()
57+
prompt_lens = prompt_lens.ravel()
58+
input_ids_len = input_ids_len.ravel()
5359
max_batch_size = seq_lens_this_time.shape[0]
5460

5561
unprocessed = sum(1 for b in range(max_batch_size) if seq_lens_encoder[b] > 0 or seq_lens_decoder[b] > 0)
@@ -135,6 +141,11 @@ def _cpu_hybrid_mtp_ngram(
135141
threshold=1024,
136142
):
137143
"""Pure NumPy reference matching the original ngram_match_mixed.cu CPU logic."""
144+
# Flatten (N,1) shaped arrays to 1D for scalar indexing
145+
max_dec_len = max_dec_len.ravel()
146+
step_idx = step_idx.ravel()
147+
draft_token_num = draft_token_num.ravel()
148+
input_ids_len = input_ids_len.ravel()
138149
max_batch_size = seq_lens_this_time.shape[0]
139150

140151
unprocessed = sum(1 for b in range(max_batch_size) if seq_lens_decoder[b] > 0)
@@ -223,13 +234,13 @@ def _make_ngram_test_data(batch_size=4, input_len=64, max_model_len=256, max_dra
223234
for b in range(batch_size):
224235
# Copy prompt into token_ids_all
225236
token_ids_all[b, :input_len] = input_ids[b]
226-
# Simulate some generated tokens that repeat parts of the prompt
237+
# Simulate generated tokens: copy contiguous blocks from prompt
238+
# to guarantee ngram matches exist
227239
gen_len = 20
228-
for g in range(gen_len):
229-
# Copy from prompt to create ngram-matchable patterns
230-
src = rng.randint(0, max(1, input_len - 5))
231-
token_ids_all[b, input_len + g] = input_ids[b, src + (g % 5)]
232-
step_idx[b] = gen_len
240+
src = rng.randint(0, max(1, input_len - gen_len))
241+
token_ids_all[b, input_len : input_len + gen_len] = input_ids[b, src : src + gen_len]
242+
# step_idx = 0-based offset of the last generated token, counted from input_len (gen_len tokens were written, so gen_len - 1)
243+
step_idx[b] = gen_len - 1
233244

234245
return {
235246
"input_ids": input_ids,
@@ -264,11 +275,12 @@ def _make_mixed_test_data(batch_size=4, input_len=64, pre_ids_len=256, max_draft
264275
max_dec_len = np.full((batch_size, 1), 200, dtype=np.int64)
265276

266277
for b in range(batch_size):
278+
# Copy contiguous blocks from prompt to guarantee ngram matches
267279
gen_len = 20
268-
for g in range(gen_len):
269-
src = rng.randint(0, max(1, input_len - 5))
270-
pre_ids[b, g] = input_ids[b, src + (g % 5)]
271-
step_idx[b] = gen_len
280+
src = rng.randint(0, max(1, input_len - gen_len))
281+
pre_ids[b, :gen_len] = input_ids[b, src : src + gen_len]
282+
# step_idx = last valid position (0-based index)
283+
step_idx[b] = gen_len - 1
272284

273285
return {
274286
"input_ids": input_ids,

0 commit comments

Comments
 (0)