@@ -175,17 +175,19 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str
175175 accept_idx = 0
176176 is_end = False
177177 while accept_idx <= an - 1 and not is_end :
178- if step_idx_now + accept_idx + 1 < stop_seq_len :
178+ if step_idx_now - an + accept_idx + 1 < stop_seq_len :
179179 accept_idx += 1
180180 continue
181181
182182 # Check one stop_seq match
183183 for i in range (stop_seq_len - 1 , - 1 , - 1 ):
184184 cur_token_idx = - 1
185- if stop_seq_len - 1 - i < accept_idx :
186- cur_token_idx = accept_tokens_now [accept_idx - (stop_seq_len - 1 - i ) - 1 ]
185+ # 注意:新版本kernel改成了 <=,并且去掉了 -1
186+ if stop_seq_len - 1 - i <= accept_idx :
187+ cur_token_idx = accept_tokens_now [accept_idx - (stop_seq_len - 1 - i )]
187188 else :
188- pre_ids_idx = step_idx_now + accept_idx - (stop_seq_len - 1 - i )
189+ # 新版本:step_idx已经包含accept_num,所以要减去
190+ pre_ids_idx = step_idx_now - an + accept_idx - (stop_seq_len - 1 - i )
189191 if pre_ids_idx <= 0 :
190192 break
191193 cur_token_idx = pre_ids_now [pre_ids_idx ]
@@ -290,22 +292,22 @@ def test_match_spanning_pre_ids_and_accept(self):
290292 inputs ["prompt_lens" ][:] = 0
291293 inputs ["step_idx" ][:] = 6
292294 inputs ["accept_num" ][:] = 3
293- # Kernel matching at accept_idx=2 (3rd token, 0-indexed):
294- # i=2(last): stop_seq_len-1-i=0 < accept_idx(2) -> accept_tokens[2-0-1]=accept_tokens[1]
295- # i=1: stop_seq_len-1-i=1 < accept_idx(2 ) -> accept_tokens[2-1-1]= accept_tokens[0]
296- # i=0: stop_seq_len-1-i=2 > = accept_idx(2 ) -> pre_ids[step_idx+2-(3-1-0)]=pre_ids[6]
297- # So stop_seq should be [pre_ids[6], accept_tokens[0], accept_tokens[1]]
298- inputs ["token_ids_all" ][0 , 6 ] = 99
295+ # stop_seq spans pre_ids and accept_tokens
296+ # For accept_idx=1: step_idx_now - accept_num + 1 + 1 = 6-3+1+1 = 5 >= stop_seq_len=3, so we check
297+ # i=2: stop_seq_len-1-i=0 <= accept_idx(1 ) -> accept_tokens[1-0] = accept_tokens[1] = 22
298+ # i=1: stop_seq_len-1-i=1 < = accept_idx(1 ) -> accept_tokens[1-1] = accept_tokens[0] = 11
299+ # i=0: stop_seq_len-1-i=2 > accept_idx(1) -> pre_ids_idx = 6-3+1-(3-1-0) = 4-2 = 2 -> pre_ids[2] = 99
300+ inputs ["token_ids_all" ][0 , 2 ] = 99
299301 inputs ["accept_tokens" ][0 , :3 ] = [11 , 22 , 33 ]
300302 inputs ["stop_seqs" ][0 , 0 , :3 ] = [99 , 11 , 22 ]
301303 inputs ["stop_seqs_len" ][0 , 0 ] = 3
302304 inputs ["stop_flags" ][:] = False
303305 inputs ["min_tokens" ][:] = 0
304306 outputs = self ._run_and_get (inputs )
305307 self ._check_all_outputs (inputs , outputs )
306- # Match at accept_idx=2 , loop increments to 3
307- self .assertEqual (outputs ["accept_num" ][0 ], 3 )
308- self .assertEqual (outputs ["accept_tokens" ][0 , 2 ], - 1 )
308+ # Match at accept_idx=1 , loop increments to 2
309+ self .assertEqual (outputs ["accept_num" ][0 ], 2 )
310+ self .assertEqual (outputs ["accept_tokens" ][0 , 1 ], inputs [ "end_ids" ][ 0 ] )
309311
310312 def test_match_in_pre_ids_only (self ):
311313 """Stop seq found entirely within token_ids_all (pre_ids), matching at accept_idx=0."""
@@ -314,29 +316,29 @@ def test_match_in_pre_ids_only(self):
314316 accept_tokens_len = 5 ,
315317 max_model_len = 32 ,
316318 stop_seqs_bs = 1 ,
317- stop_seqs_max_len = 3 ,
319+ stop_seqs_max_len = 4 , # 需要4个元素
318320 seed = 30 ,
319321 )
320322 inputs ["prompt_lens" ][:] = 0
321323 inputs ["step_idx" ][:] = 8
322324 inputs ["accept_num" ][:] = 3
323- # pre_ids at step_idx positions: token_ids_all[0, 6]=50, [0,7]=60, [0,8]=70
324- # stop_seq = [50, 60, 70], all 3 tokens are in pre_ids
325- # For accept_idx=0: step_idx_now + 0 + 1 = 9 >= stop_seq_len=3, so we check
326- # i=2: pre_ids_idx = 8+0-(3-1-2) = 8 -> pre_ids_now[8] = 70
327- # i=1: pre_ids_idx = 8+0-(3-1-1) = 7 -> pre_ids_now[7] = 60
328- # i=0: pre_ids_idx = 8+0-(3-1-0) = 6 -> pre_ids_now[6] = 50
329- inputs ["token_ids_all" ][0 , 6 ] = 50
330- inputs ["token_ids_all" ][0 , 7 ] = 60
331- inputs ["token_ids_all" ][0 , 8 ] = 70
332- inputs ["accept_tokens" ][0 , :3 ] = [1 , 2 , 3 ]
333- inputs ["stop_seqs" ][0 , 0 , :3 ] = [50 , 60 , 70 ]
334- inputs ["stop_seqs_len" ][0 , 0 ] = 3
325+ # stop_seq partially in pre_ids, partially in accept_tokens
326+ # For accept_idx=1: step_idx_now - accept_num + 1 + 1 = 8-3+1+1 = 7 >= stop_seq_len=4, so we check
327+ # i=3: stop_seq_len-1-i=0 <= accept_idx(1) -> accept_tokens[1-0] = accept_tokens[1] = 22
328+ # i=2: stop_seq_len-1-i=1 <= accept_idx(1) -> accept_tokens[1-1] = accept_tokens[0] = 11
329+ # i=1: stop_seq_len-1-i=2 > accept_idx(1) -> pre_ids_idx = 8-3+1-(4-1-1) = 6-2 = 4 -> pre_ids[4] = 60
330+ # i=0: stop_seq_len-1-i=3 > accept_idx(1) -> pre_ids_idx = 8-3+1-(4-1-0) = 6-3 = 3 -> pre_ids[3] = 50
331+ inputs ["token_ids_all" ][0 , 3 ] = 50
332+ inputs ["token_ids_all" ][0 , 4 ] = 60
333+ inputs ["accept_tokens" ][0 , :3 ] = [11 , 22 , 3 ]
334+ inputs ["stop_seqs" ][0 , 0 , :4 ] = [50 , 60 , 11 , 22 ]
335+ inputs ["stop_seqs_len" ][0 , 0 ] = 4
335336 inputs ["stop_flags" ][:] = False
336337 inputs ["min_tokens" ][:] = 0
337338 outputs = self ._run_and_get (inputs )
338339 self ._check_all_outputs (inputs , outputs )
339- self .assertEqual (outputs ["accept_num" ][0 ], 1 )
340+ # Match at accept_idx=1, loop increments to 2
341+ self .assertEqual (outputs ["accept_num" ][0 ], 2 )
340342
341343 def test_already_stopped (self ):
342344 """Kernel skips sequences with stop_flags=True."""
0 commit comments