Skip to content

Commit 0f4325c

Browse files
author
guanshihui
committed
[BugFix][Speculative Decoding] fix bug of speculate limit_thinking and stop_seqs
1 parent 98f3fc9 commit 0f4325c

File tree

3 files changed

+101
-71
lines changed

3 files changed

+101
-71
lines changed

custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -98,10 +98,7 @@ __global__ void speculate_limit_thinking_content_length_kernel(
9898
if (max_think_len > 0) {
9999
// A) 超长触发:到达 max_think_len 时开始注入(从本 token 起输出
100100
// inject_token_ids[0])
101-
if (status == 0 &&
102-
(current_step - 1) ==
103-
max_think_len) { // current_step - 1 是因为 speculate_verify 里
104-
// step_idx + 1 了
101+
if (status == 0 && current_step == max_think_len) {
105102
status = (inject_len > 0) ? 1 : done_status;
106103
}
107104
} else if (max_think_len == 0) {

custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu

Lines changed: 20 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
2424
int *accept_nums,
2525
const int64_t *token_ids_all,
2626
const int64_t *prompt_lens,
27-
const int64_t *step_idx,
27+
int64_t *step_idx,
2828
const int64_t *stop_seqs,
2929
const int *stop_seqs_len,
3030
const int *seq_lens,
@@ -56,9 +56,10 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
5656
if (!stop_flags[bid]) {
5757
int accept_idx = 0;
5858
bool is_end = false;
59-
// 遍历起始位置
60-
for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
61-
if (step_idx_now + accept_idx + 1 < stop_seq_len) {
59+
// 遍历起始位置(不包含最后一个 accept token,由
60+
// unified_update_model_status 处理 EOS 检测)
61+
for (; accept_idx < accept_num - 1 && !is_end; accept_idx++) {
62+
if (step_idx_now - accept_num + accept_idx + 1 < stop_seq_len) {
6263
#ifdef DEBUG_SPEC_STOP_SEQS
6364
printf("num %d < stop_seq_len %d\n",
6465
step_idx_now - accept_num + accept_idx + 1,
@@ -71,7 +72,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
7172
int64_t cur_token_idx = -1;
7273

7374
// 通过当前值判断 token 是在 pre_ids 还是 accept_token 里
74-
if (stop_seq_len - 1 - i < accept_idx) {
75+
if (stop_seq_len - 1 - i <= accept_idx) {
7576
#ifdef DEBUG_SPEC_STOP_SEQS
7677
printf(
7778
"AcceptTokens bid:%d. tid:%d, accept_idx:%d, "
@@ -83,7 +84,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
8384
accept_idx - (stop_seq_len - 1 - i) - 1);
8485
#endif
8586
cur_token_idx =
86-
accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1];
87+
accept_tokens_now[accept_idx - (stop_seq_len - 1 - i)];
8788
} else {
8889
#ifdef DEBUG_SPEC_STOP_SEQS
8990
printf(
@@ -98,7 +99,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
9899
(stop_seq_len - 1 - i));
99100
#endif
100101
int pre_ids_idx =
101-
step_idx_now + accept_idx - (stop_seq_len - 1 - i);
102+
step_idx_now - accept_num + accept_idx - (stop_seq_len - 1 - i);
102103
// EC3
103104
// 特殊拼接会导致input_ids最后一位无特殊token,即pre_ids[0]可能为23,
104105
// 导致异常结束
@@ -129,8 +130,17 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
129130
printf("bid:%d end with accept_idx %d", bid, accept_idx);
130131
#endif
131132

132-
accept_nums[bid] = accept_idx;
133-
accept_tokens_now[accept_idx - 1] = end_ids[0];
133+
// accept_idx
134+
// 已自增1,stop_seq的最后一个token在accept_tokens[accept_idx-1]
135+
// 回退逻辑:丢弃stop token之后的多余token,保留stop token并追加eos
136+
// 对齐非MTP行为: ...<|im_end|> <eos>
137+
int keep_count = accept_idx + 1;
138+
int discarded = accept_num - keep_count;
139+
if (discarded > 0) {
140+
step_idx[bid] -= discarded;
141+
}
142+
accept_nums[bid] = keep_count;
143+
accept_tokens_now[accept_idx] = end_ids[0];
134144
// stop_flags[bid] = true;
135145
}
136146
}
@@ -167,7 +177,7 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
167177
const_cast<int *>(accept_num.data<int>()),
168178
token_ids_all.data<int64_t>(),
169179
prompt_lens.data<int64_t>(),
170-
step_idx.data<int64_t>(),
180+
const_cast<int64_t *>(step_idx.data<int64_t>()),
171181
stop_seqs.data<int64_t>(),
172182
stop_seqs_len.data<int>(),
173183
seq_lens.data<int>(),

tests/operators/test_speculate_set_stop_value_multi_seqs.py

Lines changed: 80 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ def run_kernel(paddle_inputs, inputs):
6161

6262
def get_outputs(paddle_inputs) -> Dict[str, np.ndarray]:
6363
"""Extract all in-place-modified tensors back to numpy."""
64-
keys = ["accept_tokens", "accept_num"]
64+
keys = ["accept_tokens", "accept_num", "step_idx"]
6565
return {k: paddle_inputs[k].numpy() for k in keys}
6666

6767

@@ -100,7 +100,9 @@ def gen_inputs(
100100
accept_tokens[i, : accept_num[i]] = rng.integers(1, vocab_size, size=accept_num[i])
101101

102102
stop_flags = np.zeros(real_bsz, dtype="bool")
103-
seq_lens = (step_idx + accept_num).astype("int32")
103+
# New semantics: step_idx already includes accept_num,
104+
# so seq_lens = step_idx (not step_idx + accept_num)
105+
seq_lens = step_idx.astype("int32")
104106

105107
# stop_seqs: [bsz, stop_seqs_bs, stop_seqs_max_len]
106108
stop_seqs = rng.integers(1, vocab_size, size=(real_bsz, stop_seqs_bs, stop_seqs_max_len)).astype("int64")
@@ -140,10 +142,10 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str
140142
"""Python reference — must match CUDA kernel logic exactly."""
141143
accept_tokens = inputs["accept_tokens"].copy()
142144
accept_num = inputs["accept_num"].copy()
145+
step_idx = inputs["step_idx"].copy()
143146
stop_flags = inputs["stop_flags"].copy()
144147
token_ids_all = inputs["token_ids_all"]
145148
prompt_lens = inputs["prompt_lens"]
146-
step_idx = inputs["step_idx"]
147149
stop_seqs = inputs["stop_seqs"]
148150
stop_seqs_len = inputs["stop_seqs_len"]
149151
end_ids = inputs["end_ids"]
@@ -174,18 +176,22 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str
174176

175177
accept_idx = 0
176178
is_end = False
177-
while accept_idx <= an - 1 and not is_end:
178-
if step_idx_now + accept_idx + 1 < stop_seq_len:
179+
# Loop excludes last accept token (accept_idx < an - 1)
180+
while accept_idx < an - 1 and not is_end:
181+
if step_idx_now - an + accept_idx + 1 < stop_seq_len:
179182
accept_idx += 1
180183
continue
181184

182185
# Check one stop_seq match
183186
for i in range(stop_seq_len - 1, -1, -1):
184187
cur_token_idx = -1
185-
if stop_seq_len - 1 - i < accept_idx:
186-
cur_token_idx = accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1]
188+
# Token boundary: <= (not <)
189+
if stop_seq_len - 1 - i <= accept_idx:
190+
# Accept token index: no -1 offset
191+
cur_token_idx = accept_tokens_now[accept_idx - (stop_seq_len - 1 - i)]
187192
else:
188-
pre_ids_idx = step_idx_now + accept_idx - (stop_seq_len - 1 - i)
193+
# Pre_ids index: step_idx_now - an + accept_idx
194+
pre_ids_idx = step_idx_now - an + accept_idx - (stop_seq_len - 1 - i)
189195
if pre_ids_idx <= 0:
190196
break
191197
cur_token_idx = pre_ids_now[pre_ids_idx]
@@ -199,13 +205,19 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str
199205
accept_idx += 1
200206

201207
if is_end:
202-
accept_num[bid] = accept_idx
203-
accept_tokens[bid, accept_idx - 1] = end_ids[0]
204-
# stop_flags[bid] = True # kernel no longer sets stop_flags
208+
# accept_idx already incremented by 1
209+
# keep stop token + append end_id, rollback step_idx
210+
keep_count = accept_idx + 1
211+
discarded = an - keep_count
212+
if discarded > 0:
213+
step_idx[bid] -= discarded
214+
accept_num[bid] = keep_count
215+
accept_tokens[bid, accept_idx] = end_ids[0]
205216

206217
return {
207218
"accept_tokens": accept_tokens,
208219
"accept_num": accept_num,
220+
"step_idx": step_idx,
209221
}
210222

211223

@@ -245,7 +257,7 @@ def _run_and_get(self, inputs):
245257
def _check_all_outputs(self, inputs, outputs):
246258
"""Compare ALL output tensors against reference."""
247259
ref = reference_spec_set_stop_value_multi_seqs(inputs)
248-
for key in ["accept_tokens", "accept_num"]:
260+
for key in ["accept_tokens", "accept_num", "step_idx"]:
249261
np.testing.assert_array_equal(outputs[key], ref[key], err_msg=f"{key} mismatch")
250262

251263
def _run_full_test(self, config):
@@ -266,16 +278,20 @@ def test_configs(self):
266278
def test_match_in_accept_tokens_only(self):
267279
"""Stop seq found entirely within accept_tokens."""
268280
inputs = gen_inputs(real_bsz=1, accept_tokens_len=5, stop_seqs_bs=1, stop_seqs_max_len=3, seed=10)
269-
# Place stop seq [A, B, C] at accept_tokens positions [0,1,2]
281+
# Place stop seq [10, 20, 30] matching at accept_idx=2
282+
# New semantics: step_idx already includes accept_num
270283
inputs["accept_num"][:] = 4
271284
inputs["accept_tokens"][0, :4] = [10, 20, 30, 40]
272285
inputs["stop_seqs"][0, 0, :3] = [10, 20, 30]
273286
inputs["stop_seqs_len"][0, 0] = 3
274-
inputs["step_idx"][:] = 10
287+
inputs["step_idx"][:] = 14 # old_step=10 + accept_num=4
275288
inputs["stop_flags"][:] = False
276289
inputs["min_tokens"][:] = 0
277290
outputs = self._run_and_get(inputs)
278291
self._check_all_outputs(inputs, outputs)
292+
# Match at accept_idx=2: keep_count=2+1+1=4, accept_tokens[3]=end_id
293+
self.assertEqual(outputs["accept_num"][0], 4)
294+
self.assertEqual(outputs["accept_tokens"][0, 3], -1)
279295

280296
def test_match_spanning_pre_ids_and_accept(self):
281297
"""Stop seq spans token_ids_all (pre_ids) and accept_tokens."""
@@ -288,27 +304,32 @@ def test_match_spanning_pre_ids_and_accept(self):
288304
seed=20,
289305
)
290306
inputs["prompt_lens"][:] = 0
291-
inputs["step_idx"][:] = 6
307+
# New semantics: step_idx = old_step(6) + accept_num(3) = 9
308+
inputs["step_idx"][:] = 9
292309
inputs["accept_num"][:] = 3
293-
# Kernel matching at accept_idx=2 (3rd token, 0-indexed):
294-
# i=2(last): stop_seq_len-1-i=0 < accept_idx(2) -> accept_tokens[2-0-1]=accept_tokens[1]
295-
# i=1: stop_seq_len-1-i=1 < accept_idx(2) -> accept_tokens[2-1-1]=accept_tokens[0]
296-
# i=0: stop_seq_len-1-i=2 >= accept_idx(2) -> pre_ids[step_idx+2-(3-1-0)]=pre_ids[6]
297-
# So stop_seq should be [pre_ids[6], accept_tokens[0], accept_tokens[1]]
298-
inputs["token_ids_all"][0, 6] = 99
310+
# New kernel matching at accept_idx=1 (for loop: 0..accept_num-2=1):
311+
# i=2(last): j=0, 0<=1 -> accept_tokens[1-0]=accept_tokens[1]=22
312+
# i=1: j=1, 1<=1 -> accept_tokens[1-1]=accept_tokens[0]=11
313+
# i=0: j=2, 2<=1 false -> pre_ids[9-3+1-2]=pre_ids[5]
314+
# So stop_seq should be [pre_ids[5], accept_tokens[0], accept_tokens[1]]
315+
inputs["token_ids_all"][0, 5] = 99
299316
inputs["accept_tokens"][0, :3] = [11, 22, 33]
300317
inputs["stop_seqs"][0, 0, :3] = [99, 11, 22]
301318
inputs["stop_seqs_len"][0, 0] = 3
302319
inputs["stop_flags"][:] = False
303320
inputs["min_tokens"][:] = 0
304321
outputs = self._run_and_get(inputs)
305322
self._check_all_outputs(inputs, outputs)
306-
# Match at accept_idx=2, loop increments to 3
323+
# Match at accept_idx=1, loop increments to 2, keep_count=3
307324
self.assertEqual(outputs["accept_num"][0], 3)
308325
self.assertEqual(outputs["accept_tokens"][0, 2], -1)
309326

310-
def test_match_in_pre_ids_only(self):
311-
"""Stop seq found entirely within token_ids_all (pre_ids), matching at accept_idx=0."""
327+
def test_match_mostly_in_pre_ids(self):
328+
"""Stop seq mostly in pre_ids, last token in accept_tokens[0], matching at accept_idx=0.
329+
330+
New kernel: at accept_idx=0, stop_seq[last] always comes from accept_tokens[0],
331+
remaining tokens come from pre_ids.
332+
"""
312333
inputs = gen_inputs(
313334
real_bsz=1,
314335
accept_tokens_len=5,
@@ -318,25 +339,27 @@ def test_match_in_pre_ids_only(self):
318339
seed=30,
319340
)
320341
inputs["prompt_lens"][:] = 0
321-
inputs["step_idx"][:] = 8
342+
# New semantics: step_idx = old_step(8) + accept_num(3) = 11
343+
inputs["step_idx"][:] = 11
322344
inputs["accept_num"][:] = 3
323-
# pre_ids at step_idx positions: token_ids_all[0, 6]=50, [0,7]=60, [0,8]=70
324-
# stop_seq = [50, 60, 70], all 3 tokens are in pre_ids
325-
# For accept_idx=0: step_idx_now + 0 + 1 = 9 >= stop_seq_len=3, so we check
326-
# i=2: pre_ids_idx = 8+0-(3-1-2) = 8 -> pre_ids_now[8] = 70
327-
# i=1: pre_ids_idx = 8+0-(3-1-1) = 7 -> pre_ids_now[7] = 60
328-
# i=0: pre_ids_idx = 8+0-(3-1-0) = 6 -> pre_ids_now[6] = 50
345+
# At accept_idx=0 with stop_seq_len=3:
346+
# i=2: j=0, 0<=0 -> accept_tokens[0]=70
347+
# i=1: j=1, 1<=0 false -> pre_ids[11-3+0-1]=pre_ids[7]=60
348+
# i=0: j=2, 2<=0 false -> pre_ids[11-3+0-2]=pre_ids[6]=50
329349
inputs["token_ids_all"][0, 6] = 50
330350
inputs["token_ids_all"][0, 7] = 60
331-
inputs["token_ids_all"][0, 8] = 70
332-
inputs["accept_tokens"][0, :3] = [1, 2, 3]
351+
inputs["accept_tokens"][0, :3] = [70, 2, 3]
333352
inputs["stop_seqs"][0, 0, :3] = [50, 60, 70]
334353
inputs["stop_seqs_len"][0, 0] = 3
335354
inputs["stop_flags"][:] = False
336355
inputs["min_tokens"][:] = 0
337356
outputs = self._run_and_get(inputs)
338357
self._check_all_outputs(inputs, outputs)
339-
self.assertEqual(outputs["accept_num"][0], 1)
358+
# Match at accept_idx=0, loop increments to 1, keep_count=2
359+
# discarded = 3 - 2 = 1, step_idx rolled back by 1 (11->10)
360+
self.assertEqual(outputs["accept_num"][0], 2)
361+
self.assertEqual(outputs["accept_tokens"][0, 1], -1)
362+
self.assertEqual(outputs["step_idx"][0], 10)
340363

341364
def test_already_stopped(self):
342365
"""Kernel skips sequences with stop_flags=True."""
@@ -346,9 +369,10 @@ def test_already_stopped(self):
346369
inputs["stop_seqs_len"][:] = 2
347370
outputs = self._run_and_get(inputs)
348371
self._check_all_outputs(inputs, outputs)
349-
# accept_tokens and accept_num should be unchanged
372+
# accept_tokens, accept_num and step_idx should be unchanged
350373
np.testing.assert_array_equal(outputs["accept_tokens"], inputs["accept_tokens"])
351374
np.testing.assert_array_equal(outputs["accept_num"], inputs["accept_num"])
375+
np.testing.assert_array_equal(outputs["step_idx"], inputs["step_idx"])
352376

353377
def test_min_tokens_blocks_stop(self):
354378
"""Kernel skips stop check when step_idx < min_tokens."""
@@ -361,17 +385,17 @@ def test_min_tokens_blocks_stop(self):
361385
seed=50,
362386
)
363387
inputs["prompt_lens"][:] = 0
364-
inputs["step_idx"][:] = 8
388+
# New semantics: step_idx = old_step(8) + accept_num(3) = 11
389+
inputs["step_idx"][:] = 11
365390
inputs["accept_num"][:] = 3
366-
# Same setup that would match (like test_match_in_pre_ids_only)
391+
# Same setup that would match (like test_match_mostly_in_pre_ids)
367392
inputs["token_ids_all"][0, 6] = 50
368393
inputs["token_ids_all"][0, 7] = 60
369-
inputs["token_ids_all"][0, 8] = 70
370-
inputs["accept_tokens"][0, :3] = [1, 2, 3]
394+
inputs["accept_tokens"][0, :3] = [70, 2, 3]
371395
inputs["stop_seqs"][0, 0, :3] = [50, 60, 70]
372396
inputs["stop_seqs_len"][0, 0] = 3
373397
inputs["stop_flags"][:] = False
374-
inputs["min_tokens"][:] = 100 # step_idx=8 < 100, should NOT stop
398+
inputs["min_tokens"][:] = 100 # step_idx=11 < 100, should NOT stop
375399
outputs = self._run_and_get(inputs)
376400
self._check_all_outputs(inputs, outputs)
377401

@@ -386,17 +410,17 @@ def test_min_tokens_allows_stop(self):
386410
seed=60,
387411
)
388412
inputs["prompt_lens"][:] = 0
389-
inputs["step_idx"][:] = 8
413+
# New semantics: step_idx = old_step(8) + accept_num(3) = 11
414+
inputs["step_idx"][:] = 11
390415
inputs["accept_num"][:] = 3
391-
# Put stop_seq entirely in pre_ids (same pattern as test_match_in_pre_ids_only)
416+
# Same setup as test_match_mostly_in_pre_ids
392417
inputs["token_ids_all"][0, 6] = 50
393418
inputs["token_ids_all"][0, 7] = 60
394-
inputs["token_ids_all"][0, 8] = 70
395-
inputs["accept_tokens"][0, :3] = [1, 2, 3]
419+
inputs["accept_tokens"][0, :3] = [70, 2, 3]
396420
inputs["stop_seqs"][0, 0, :3] = [50, 60, 70]
397421
inputs["stop_seqs_len"][0, 0] = 3
398422
inputs["stop_flags"][:] = False
399-
inputs["min_tokens"][:] = 5 # step_idx=8 >= 5, should stop
423+
inputs["min_tokens"][:] = 5 # step_idx=11 >= 5, should stop
400424
outputs = self._run_and_get(inputs)
401425
self._check_all_outputs(inputs, outputs)
402426

@@ -411,11 +435,12 @@ def test_multiple_stop_seqs_second_matches(self):
411435
seed=70,
412436
)
413437
inputs["prompt_lens"][:] = 0
414-
inputs["step_idx"][:] = 8
438+
# New semantics: step_idx = old_step(8) + accept_num(3) = 11
439+
inputs["step_idx"][:] = 11
415440
inputs["accept_num"][:] = 3
416-
# accept_tokens: stop_seq[20,30] matches at accept_idx=2:
417-
# i=1: accept_tokens[2-0-1]=accept_tokens[1]=30 vs stop_seq[1]=30 OK
418-
# i=0: accept_tokens[2-1-1]=accept_tokens[0]=20 vs stop_seq[0]=20 OK
441+
# accept_tokens: stop_seq[20,30] matches at accept_idx=1:
442+
# i=1(last): j=0, 0<=1 -> accept_tokens[1-0]=accept_tokens[1]=30
443+
# i=0: j=1, 1<=1 -> accept_tokens[1-1]=accept_tokens[0]=20
419444
inputs["accept_tokens"][0, :3] = [20, 30, 40]
420445
# First stop seq doesn't match
421446
inputs["stop_seqs"][0, 0, :3] = [99, 98, 97]
@@ -440,17 +465,15 @@ def test_nonzero_prompt_lens(self):
440465
)
441466
prompt_len = 10
442467
inputs["prompt_lens"][:] = prompt_len
443-
inputs["step_idx"][:] = 5
468+
# New semantics: step_idx = old_step(5) + accept_num(2) = 7
469+
inputs["step_idx"][:] = 7
444470
inputs["accept_num"][:] = 2
445471
inputs["accept_tokens"][0, :2] = [55, 66]
446472
# pre_ids_now starts at token_ids_all[0, prompt_len:]
447-
# stop_seq = [X, 55] where X = token_ids_all[0, prompt_len + step_idx]
448-
# For accept_idx=0: pre_ids_idx = step_idx + 0 - (2-1-0) = 5-1 = 4
449-
# -> pre_ids_now[4] = token_ids_all[0, prompt_len + 4]
450-
# For accept_idx=1 (second token is accept_tokens[0,0]=55):
451-
# i=1: accept_tokens_now[1-(2-1-1)-1] = accept_tokens_now[0] = 55
452-
# i=0: pre_ids_idx = step_idx + 1 - (2-1-0) = 5+1-1 = 5 -> pre_ids_now[5]
453-
target_val = int(inputs["token_ids_all"][0, prompt_len + 5])
473+
# At accept_idx=0 with stop_seq_len=2:
474+
# i=1(last): j=0, 0<=0 -> accept_tokens[0]=55
475+
# i=0: j=1, 1<=0 false -> pre_ids[7-2+0-1]=pre_ids[4]=token_ids_all[0, prompt_len+4]
476+
target_val = int(inputs["token_ids_all"][0, prompt_len + 4])
454477
inputs["stop_seqs"][0, 0, :2] = [target_val, 55]
455478
inputs["stop_seqs_len"][0, 0] = 2
456479
inputs["stop_flags"][:] = False

0 commit comments

Comments
 (0)