@@ -61,7 +61,7 @@ def run_kernel(paddle_inputs, inputs):
6161
6262def get_outputs (paddle_inputs ) -> Dict [str , np .ndarray ]:
6363 """Extract all in-place-modified tensors back to numpy."""
64- keys = ["accept_tokens" , "accept_num" ]
64+ keys = ["accept_tokens" , "accept_num" , "step_idx" ]
6565 return {k : paddle_inputs [k ].numpy () for k in keys }
6666
6767
@@ -100,7 +100,9 @@ def gen_inputs(
100100 accept_tokens [i , : accept_num [i ]] = rng .integers (1 , vocab_size , size = accept_num [i ])
101101
102102 stop_flags = np .zeros (real_bsz , dtype = "bool" )
103- seq_lens = (step_idx + accept_num ).astype ("int32" )
103+ # New semantics: step_idx already includes accept_num,
104+ # so seq_lens = step_idx (not step_idx + accept_num)
105+ seq_lens = step_idx .astype ("int32" )
104106
105107 # stop_seqs: [bsz, stop_seqs_bs, stop_seqs_max_len]
106108 stop_seqs = rng .integers (1 , vocab_size , size = (real_bsz , stop_seqs_bs , stop_seqs_max_len )).astype ("int64" )
@@ -140,10 +142,10 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str
140142 """Python reference — must match CUDA kernel logic exactly."""
141143 accept_tokens = inputs ["accept_tokens" ].copy ()
142144 accept_num = inputs ["accept_num" ].copy ()
145+ step_idx = inputs ["step_idx" ].copy ()
143146 stop_flags = inputs ["stop_flags" ].copy ()
144147 token_ids_all = inputs ["token_ids_all" ]
145148 prompt_lens = inputs ["prompt_lens" ]
146- step_idx = inputs ["step_idx" ]
147149 stop_seqs = inputs ["stop_seqs" ]
148150 stop_seqs_len = inputs ["stop_seqs_len" ]
149151 end_ids = inputs ["end_ids" ]
@@ -174,18 +176,22 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str
174176
175177 accept_idx = 0
176178 is_end = False
177- while accept_idx <= an - 1 and not is_end :
178- if step_idx_now + accept_idx + 1 < stop_seq_len :
179+ # Loop excludes last accept token (accept_idx < an - 1)
180+ while accept_idx < an - 1 and not is_end :
181+ if step_idx_now - an + accept_idx + 1 < stop_seq_len :
179182 accept_idx += 1
180183 continue
181184
182185 # Check one stop_seq match
183186 for i in range (stop_seq_len - 1 , - 1 , - 1 ):
184187 cur_token_idx = - 1
185- if stop_seq_len - 1 - i < accept_idx :
186- cur_token_idx = accept_tokens_now [accept_idx - (stop_seq_len - 1 - i ) - 1 ]
188+ # Token boundary: <= (not <)
189+ if stop_seq_len - 1 - i <= accept_idx :
190+ # Accept token index: no -1 offset
191+ cur_token_idx = accept_tokens_now [accept_idx - (stop_seq_len - 1 - i )]
187192 else :
188- pre_ids_idx = step_idx_now + accept_idx - (stop_seq_len - 1 - i )
193+ # Pre_ids index: step_idx_now - an + accept_idx
194+ pre_ids_idx = step_idx_now - an + accept_idx - (stop_seq_len - 1 - i )
189195 if pre_ids_idx <= 0 :
190196 break
191197 cur_token_idx = pre_ids_now [pre_ids_idx ]
@@ -199,13 +205,19 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str
199205 accept_idx += 1
200206
201207 if is_end :
202- accept_num [bid ] = accept_idx
203- accept_tokens [bid , accept_idx - 1 ] = end_ids [0 ]
204- # stop_flags[bid] = True # kernel no longer sets stop_flags
208+ # accept_idx already incremented by 1
209+ # keep stop token + append end_id, rollback step_idx
210+ keep_count = accept_idx + 1
211+ discarded = an - keep_count
212+ if discarded > 0 :
213+ step_idx [bid ] -= discarded
214+ accept_num [bid ] = keep_count
215+ accept_tokens [bid , accept_idx ] = end_ids [0 ]
205216
206217 return {
207218 "accept_tokens" : accept_tokens ,
208219 "accept_num" : accept_num ,
220+ "step_idx" : step_idx ,
209221 }
210222
211223
@@ -245,7 +257,7 @@ def _run_and_get(self, inputs):
245257 def _check_all_outputs (self , inputs , outputs ):
246258 """Compare ALL output tensors against reference."""
247259 ref = reference_spec_set_stop_value_multi_seqs (inputs )
248- for key in ["accept_tokens" , "accept_num" ]:
260+ for key in ["accept_tokens" , "accept_num" , "step_idx" ]:
249261 np .testing .assert_array_equal (outputs [key ], ref [key ], err_msg = f"{ key } mismatch" )
250262
251263 def _run_full_test (self , config ):
@@ -266,16 +278,20 @@ def test_configs(self):
266278 def test_match_in_accept_tokens_only (self ):
267279 """Stop seq found entirely within accept_tokens."""
268280 inputs = gen_inputs (real_bsz = 1 , accept_tokens_len = 5 , stop_seqs_bs = 1 , stop_seqs_max_len = 3 , seed = 10 )
269- # Place stop seq [A, B, C] at accept_tokens positions [0,1,2]
281+ # Place stop seq [10, 20, 30] matching at accept_idx=2
282+ # New semantics: step_idx already includes accept_num
270283 inputs ["accept_num" ][:] = 4
271284 inputs ["accept_tokens" ][0 , :4 ] = [10 , 20 , 30 , 40 ]
272285 inputs ["stop_seqs" ][0 , 0 , :3 ] = [10 , 20 , 30 ]
273286 inputs ["stop_seqs_len" ][0 , 0 ] = 3
274- inputs ["step_idx" ][:] = 10
287+ inputs ["step_idx" ][:] = 14 # old_step=10 + accept_num=4
275288 inputs ["stop_flags" ][:] = False
276289 inputs ["min_tokens" ][:] = 0
277290 outputs = self ._run_and_get (inputs )
278291 self ._check_all_outputs (inputs , outputs )
292+ # Match at accept_idx=2: keep_count=2+1+1=4, accept_tokens[3]=end_id
293+ self .assertEqual (outputs ["accept_num" ][0 ], 4 )
294+ self .assertEqual (outputs ["accept_tokens" ][0 , 3 ], - 1 )
279295
280296 def test_match_spanning_pre_ids_and_accept (self ):
281297 """Stop seq spans token_ids_all (pre_ids) and accept_tokens."""
@@ -288,27 +304,32 @@ def test_match_spanning_pre_ids_and_accept(self):
288304 seed = 20 ,
289305 )
290306 inputs ["prompt_lens" ][:] = 0
291- inputs ["step_idx" ][:] = 6
307+ # New semantics: step_idx = old_step(6) + accept_num(3) = 9
308+ inputs ["step_idx" ][:] = 9
292309 inputs ["accept_num" ][:] = 3
293- # Kernel matching at accept_idx=2 (3rd token, 0-indexed ):
294- # i=2(last): stop_seq_len-1-i=0 < accept_idx(2) -> accept_tokens[2-0-1 ]=accept_tokens[1]
295- # i=1: stop_seq_len-1-i=1 < accept_idx(2) -> accept_tokens[2- 1-1]=accept_tokens[0]
296- # i=0: stop_seq_len-1-i=2 >= accept_idx(2) -> pre_ids[step_idx+2-(3-1-0) ]=pre_ids[6 ]
297- # So stop_seq should be [pre_ids[6 ], accept_tokens[0], accept_tokens[1]]
298- inputs ["token_ids_all" ][0 , 6 ] = 99
310+ # New kernel matching at accept_idx=1 (for loop: 0..accept_num-2=1 ):
311+ # i=2(last): j=0, 0<=1 -> accept_tokens[1-0 ]=accept_tokens[1]=22
312+ # i=1: j=1, 1<=1 -> accept_tokens[1-1]=accept_tokens[0]=11
313+ # i=0: j=2, 2<=1 false -> pre_ids[9-3+1-2 ]=pre_ids[5 ]
314+ # So stop_seq should be [pre_ids[5 ], accept_tokens[0], accept_tokens[1]]
315+ inputs ["token_ids_all" ][0 , 5 ] = 99
299316 inputs ["accept_tokens" ][0 , :3 ] = [11 , 22 , 33 ]
300317 inputs ["stop_seqs" ][0 , 0 , :3 ] = [99 , 11 , 22 ]
301318 inputs ["stop_seqs_len" ][0 , 0 ] = 3
302319 inputs ["stop_flags" ][:] = False
303320 inputs ["min_tokens" ][:] = 0
304321 outputs = self ._run_and_get (inputs )
305322 self ._check_all_outputs (inputs , outputs )
306- # Match at accept_idx=2 , loop increments to 3
323+ # Match at accept_idx=1 , loop increments to 2, keep_count= 3
307324 self .assertEqual (outputs ["accept_num" ][0 ], 3 )
308325 self .assertEqual (outputs ["accept_tokens" ][0 , 2 ], - 1 )
309326
310- def test_match_in_pre_ids_only (self ):
311- """Stop seq found entirely within token_ids_all (pre_ids), matching at accept_idx=0."""
327+ def test_match_mostly_in_pre_ids (self ):
328+ """Stop seq mostly in pre_ids, last token in accept_tokens[0], matching at accept_idx=0.
329+
330+ New kernel: at accept_idx=0, stop_seq[last] always comes from accept_tokens[0],
331+ remaining tokens come from pre_ids.
332+ """
312333 inputs = gen_inputs (
313334 real_bsz = 1 ,
314335 accept_tokens_len = 5 ,
@@ -318,25 +339,27 @@ def test_match_in_pre_ids_only(self):
318339 seed = 30 ,
319340 )
320341 inputs ["prompt_lens" ][:] = 0
321- inputs ["step_idx" ][:] = 8
342+ # New semantics: step_idx = old_step(8) + accept_num(3) = 11
343+ inputs ["step_idx" ][:] = 11
322344 inputs ["accept_num" ][:] = 3
323- # pre_ids at step_idx positions: token_ids_all[0, 6]=50, [0,7]=60, [0,8]=70
324- # stop_seq = [50, 60, 70], all 3 tokens are in pre_ids
325- # For accept_idx=0: step_idx_now + 0 + 1 = 9 >= stop_seq_len=3, so we check
326- # i=2: pre_ids_idx = 8+0-(3-1-2) = 8 -> pre_ids_now[8] = 70
327- # i=1: pre_ids_idx = 8+0-(3-1-1) = 7 -> pre_ids_now[7] = 60
328- # i=0: pre_ids_idx = 8+0-(3-1-0) = 6 -> pre_ids_now[6] = 50
345+ # At accept_idx=0 with stop_seq_len=3:
346+ # i=2: j=0, 0<=0 -> accept_tokens[0]=70
347+ # i=1: j=1, 1<=0 false -> pre_ids[11-3+0-1]=pre_ids[7]=60
348+ # i=0: j=2, 2<=0 false -> pre_ids[11-3+0-2]=pre_ids[6]=50
329349 inputs ["token_ids_all" ][0 , 6 ] = 50
330350 inputs ["token_ids_all" ][0 , 7 ] = 60
331- inputs ["token_ids_all" ][0 , 8 ] = 70
332- inputs ["accept_tokens" ][0 , :3 ] = [1 , 2 , 3 ]
351+ inputs ["accept_tokens" ][0 , :3 ] = [70 , 2 , 3 ]
333352 inputs ["stop_seqs" ][0 , 0 , :3 ] = [50 , 60 , 70 ]
334353 inputs ["stop_seqs_len" ][0 , 0 ] = 3
335354 inputs ["stop_flags" ][:] = False
336355 inputs ["min_tokens" ][:] = 0
337356 outputs = self ._run_and_get (inputs )
338357 self ._check_all_outputs (inputs , outputs )
339- self .assertEqual (outputs ["accept_num" ][0 ], 1 )
358+ # Match at accept_idx=0, loop increments to 1, keep_count=2
359+ # discarded = 3 - 2 = 1, step_idx rolled back by 1 (11->10)
360+ self .assertEqual (outputs ["accept_num" ][0 ], 2 )
361+ self .assertEqual (outputs ["accept_tokens" ][0 , 1 ], - 1 )
362+ self .assertEqual (outputs ["step_idx" ][0 ], 10 )
340363
341364 def test_already_stopped (self ):
342365 """Kernel skips sequences with stop_flags=True."""
@@ -346,9 +369,10 @@ def test_already_stopped(self):
346369 inputs ["stop_seqs_len" ][:] = 2
347370 outputs = self ._run_and_get (inputs )
348371 self ._check_all_outputs (inputs , outputs )
349- # accept_tokens and accept_num should be unchanged
372+ # accept_tokens, accept_num and step_idx should be unchanged
350373 np .testing .assert_array_equal (outputs ["accept_tokens" ], inputs ["accept_tokens" ])
351374 np .testing .assert_array_equal (outputs ["accept_num" ], inputs ["accept_num" ])
375+ np .testing .assert_array_equal (outputs ["step_idx" ], inputs ["step_idx" ])
352376
353377 def test_min_tokens_blocks_stop (self ):
354378 """Kernel skips stop check when step_idx < min_tokens."""
@@ -361,17 +385,17 @@ def test_min_tokens_blocks_stop(self):
361385 seed = 50 ,
362386 )
363387 inputs ["prompt_lens" ][:] = 0
364- inputs ["step_idx" ][:] = 8
388+ # New semantics: step_idx = old_step(8) + accept_num(3) = 11
389+ inputs ["step_idx" ][:] = 11
365390 inputs ["accept_num" ][:] = 3
366- # Same setup that would match (like test_match_in_pre_ids_only )
391+ # Same setup that would match (like test_match_mostly_in_pre_ids )
367392 inputs ["token_ids_all" ][0 , 6 ] = 50
368393 inputs ["token_ids_all" ][0 , 7 ] = 60
369- inputs ["token_ids_all" ][0 , 8 ] = 70
370- inputs ["accept_tokens" ][0 , :3 ] = [1 , 2 , 3 ]
394+ inputs ["accept_tokens" ][0 , :3 ] = [70 , 2 , 3 ]
371395 inputs ["stop_seqs" ][0 , 0 , :3 ] = [50 , 60 , 70 ]
372396 inputs ["stop_seqs_len" ][0 , 0 ] = 3
373397 inputs ["stop_flags" ][:] = False
374- inputs ["min_tokens" ][:] = 100 # step_idx=8 < 100, should NOT stop
398+ inputs ["min_tokens" ][:] = 100 # step_idx=11 < 100, should NOT stop
375399 outputs = self ._run_and_get (inputs )
376400 self ._check_all_outputs (inputs , outputs )
377401
@@ -386,17 +410,17 @@ def test_min_tokens_allows_stop(self):
386410 seed = 60 ,
387411 )
388412 inputs ["prompt_lens" ][:] = 0
389- inputs ["step_idx" ][:] = 8
413+ # New semantics: step_idx = old_step(8) + accept_num(3) = 11
414+ inputs ["step_idx" ][:] = 11
390415 inputs ["accept_num" ][:] = 3
391- # Put stop_seq entirely in pre_ids (same pattern as test_match_in_pre_ids_only)
416+ # Same setup as test_match_mostly_in_pre_ids
392417 inputs ["token_ids_all" ][0 , 6 ] = 50
393418 inputs ["token_ids_all" ][0 , 7 ] = 60
394- inputs ["token_ids_all" ][0 , 8 ] = 70
395- inputs ["accept_tokens" ][0 , :3 ] = [1 , 2 , 3 ]
419+ inputs ["accept_tokens" ][0 , :3 ] = [70 , 2 , 3 ]
396420 inputs ["stop_seqs" ][0 , 0 , :3 ] = [50 , 60 , 70 ]
397421 inputs ["stop_seqs_len" ][0 , 0 ] = 3
398422 inputs ["stop_flags" ][:] = False
399- inputs ["min_tokens" ][:] = 5 # step_idx=8 >= 5, should stop
423+ inputs ["min_tokens" ][:] = 5 # step_idx=11 >= 5, should stop
400424 outputs = self ._run_and_get (inputs )
401425 self ._check_all_outputs (inputs , outputs )
402426
@@ -411,11 +435,12 @@ def test_multiple_stop_seqs_second_matches(self):
411435 seed = 70 ,
412436 )
413437 inputs ["prompt_lens" ][:] = 0
414- inputs ["step_idx" ][:] = 8
438+ # New semantics: step_idx = old_step(8) + accept_num(3) = 11
439+ inputs ["step_idx" ][:] = 11
415440 inputs ["accept_num" ][:] = 3
416- # accept_tokens: stop_seq[20,30] matches at accept_idx=2 :
417- # i=1: accept_tokens[2-0-1 ]=accept_tokens[1]=30 vs stop_seq[1]=30 OK
418- # i=0: accept_tokens[2- 1-1]=accept_tokens[0]=20 vs stop_seq[0]=20 OK
441+ # accept_tokens: stop_seq[20,30] matches at accept_idx=1 :
442+ # i=1(last): j=0, 0<=1 -> accept_tokens[1-0 ]=accept_tokens[1]=30
443+ # i=0: j=1, 1<=1 -> accept_tokens[1-1]=accept_tokens[0]=20
419444 inputs ["accept_tokens" ][0 , :3 ] = [20 , 30 , 40 ]
420445 # First stop seq doesn't match
421446 inputs ["stop_seqs" ][0 , 0 , :3 ] = [99 , 98 , 97 ]
@@ -440,17 +465,15 @@ def test_nonzero_prompt_lens(self):
440465 )
441466 prompt_len = 10
442467 inputs ["prompt_lens" ][:] = prompt_len
443- inputs ["step_idx" ][:] = 5
468+ # New semantics: step_idx = old_step(5) + accept_num(2) = 7
469+ inputs ["step_idx" ][:] = 7
444470 inputs ["accept_num" ][:] = 2
445471 inputs ["accept_tokens" ][0 , :2 ] = [55 , 66 ]
446472 # pre_ids_now starts at token_ids_all[0, prompt_len:]
447- # stop_seq = [X, 55] where X = token_ids_all[0, prompt_len + step_idx]
448- # For accept_idx=0: pre_ids_idx = step_idx + 0 - (2-1-0) = 5-1 = 4
449- # -> pre_ids_now[4] = token_ids_all[0, prompt_len + 4]
450- # For accept_idx=1 (second token is accept_tokens[0,0]=55):
451- # i=1: accept_tokens_now[1-(2-1-1)-1] = accept_tokens_now[0] = 55
452- # i=0: pre_ids_idx = step_idx + 1 - (2-1-0) = 5+1-1 = 5 -> pre_ids_now[5]
453- target_val = int (inputs ["token_ids_all" ][0 , prompt_len + 5 ])
473+ # At accept_idx=0 with stop_seq_len=2:
474+ # i=1(last): j=0, 0<=0 -> accept_tokens[0]=55
475+ # i=0: j=1, 1<=0 false -> pre_ids[7-2+0-1]=pre_ids[4]=token_ids_all[0, prompt_len+4]
476+ target_val = int (inputs ["token_ids_all" ][0 , prompt_len + 4 ])
454477 inputs ["stop_seqs" ][0 , 0 , :2 ] = [target_val , 55 ]
455478 inputs ["stop_seqs_len" ][0 , 0 ] = 2
456479 inputs ["stop_flags" ][:] = False
0 commit comments