@@ -50,6 +50,12 @@ def _cpu_ngram_match(
5050 threshold = 128 ,
5151):
5252 """Pure NumPy reference matching the original ngram_match.cc logic."""
53+ # Flatten (N,1) shaped arrays to 1D for scalar indexing
54+ max_dec_len = max_dec_len .ravel ()
55+ step_idx = step_idx .ravel ()
56+ draft_token_num = draft_token_num .ravel ()
57+ prompt_lens = prompt_lens .ravel ()
58+ input_ids_len = input_ids_len .ravel ()
5359 max_batch_size = seq_lens_this_time .shape [0 ]
5460
5561 unprocessed = sum (1 for b in range (max_batch_size ) if seq_lens_encoder [b ] > 0 or seq_lens_decoder [b ] > 0 )
@@ -135,6 +141,11 @@ def _cpu_hybrid_mtp_ngram(
135141 threshold = 1024 ,
136142):
137143 """Pure NumPy reference matching the original ngram_match_mixed.cu CPU logic."""
144+ # Flatten (N,1) shaped arrays to 1D for scalar indexing
145+ max_dec_len = max_dec_len .ravel ()
146+ step_idx = step_idx .ravel ()
147+ draft_token_num = draft_token_num .ravel ()
148+ input_ids_len = input_ids_len .ravel ()
138149 max_batch_size = seq_lens_this_time .shape [0 ]
139150
140151 unprocessed = sum (1 for b in range (max_batch_size ) if seq_lens_decoder [b ] > 0 )
@@ -223,13 +234,13 @@ def _make_ngram_test_data(batch_size=4, input_len=64, max_model_len=256, max_dra
223234 for b in range (batch_size ):
224235 # Copy prompt into token_ids_all
225236 token_ids_all [b , :input_len ] = input_ids [b ]
226- # Simulate some generated tokens that repeat parts of the prompt
237+ # Simulate generated tokens: copy contiguous blocks from prompt
238+ # to guarantee ngram matches exist
227239 gen_len = 20
228- for g in range (gen_len ):
229- # Copy from prompt to create ngram-matchable patterns
230- src = rng .randint (0 , max (1 , input_len - 5 ))
231- token_ids_all [b , input_len + g ] = input_ids [b , src + (g % 5 )]
232- step_idx [b ] = gen_len
240+ src = rng .randint (0 , max (1 , input_len - gen_len ))
241+ token_ids_all [b , input_len : input_len + gen_len ] = input_ids [b , src : src + gen_len ]
242+ # step_idx = last valid position (0-based index)
243+ step_idx [b ] = gen_len - 1
233244
234245 return {
235246 "input_ids" : input_ids ,
@@ -264,11 +275,12 @@ def _make_mixed_test_data(batch_size=4, input_len=64, pre_ids_len=256, max_draft
264275 max_dec_len = np .full ((batch_size , 1 ), 200 , dtype = np .int64 )
265276
266277 for b in range (batch_size ):
278+ # Copy contiguous blocks from prompt to guarantee ngram matches
267279 gen_len = 20
268- for g in range ( gen_len ):
269- src = rng . randint ( 0 , max ( 1 , input_len - 5 ))
270- pre_ids [ b , g ] = input_ids [ b , src + ( g % 5 )]
271- step_idx [b ] = gen_len
280+ src = rng . randint ( 0 , max ( 1 , input_len - gen_len ))
281+ pre_ids [ b , : gen_len ] = input_ids [ b , src : src + gen_len ]
282+ # step_idx = last valid position (0-based index)
283+ step_idx [b ] = gen_len - 1
272284
273285 return {
274286 "input_ids" : input_ids ,
0 commit comments