Skip to content

Commit 88a7479

Browse files
TBD1 and rain7996 authored
[BugFix] Fix stop token sequence pointer offset and actual length computation (#7721)
Co-authored-by: songyuxing <songyuxing@baidu.com>
1 parent a2f636e commit 88a7479

3 files changed

Lines changed: 127 additions & 5 deletions

File tree

custom_ops/gpu_ops/stop_generation_multi_ends.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,9 @@ __global__ void set_value_by_flags(bool *stop_flags,
7979
// dealing stop_seqs
8080
const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid];
8181
if (stop_seq_len <= 0) return;
82-
const int64_t *stop_seq_now =
83-
stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len;
82+
const int64_t *stop_seq_now = stop_seqs +
83+
bid * stop_seqs_bs * stop_seqs_max_len +
84+
tid * stop_seqs_max_len;
8485
const int64_t *pre_ids_now =
8586
token_ids_all + bid * max_model_len + prompt_lens[bid];
8687
const int64_t step_idx_now = step_idx[bid];

fastdeploy/input/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,24 +82,27 @@ def process_stop_token_ids(
8282
update_stop_seq_fn: Callable[[List[str]], Tuple[List[List[int]], List[int]]],
8383
) -> None:
8484
stop_token_ids_final = []
85+
stop_seqs_len_final = []
8586

8687
if request.get("stop_token_ids") is not None:
8788
stop_token_ids = request.get("stop_token_ids")
8889
if isinstance(stop_token_ids, list) and len(stop_token_ids) > 0:
8990
if isinstance(stop_token_ids[0], int):
9091
# List[int] -> List[List[int]]
9192
stop_token_ids_final.extend([[t] for t in stop_token_ids])
93+
stop_seqs_len_final.extend([1] * len(stop_token_ids))
9294
elif isinstance(stop_token_ids[0], list):
9395
# Already List[List[int]]
9496
stop_token_ids_final.extend(stop_token_ids)
97+
stop_seqs_len_final.extend([len(seq) for seq in stop_token_ids])
9598

9699
stop_sequences = request.get("stop", [])
97100
if stop_sequences:
98-
stop_seqs, _ = update_stop_seq_fn(stop_sequences)
101+
stop_seqs, stop_seqs_actual_lens = update_stop_seq_fn(stop_sequences)
99102
stop_token_ids_final.extend(stop_seqs)
103+
stop_seqs_len_final.extend(stop_seqs_actual_lens)
100104

101105
# Update request
102106
if stop_token_ids_final:
103-
stop_seqs_len = [len(seq) for seq in stop_token_ids_final]
104107
request["stop_token_ids"] = stop_token_ids_final
105-
request["stop_seqs_len"] = stop_seqs_len
108+
request["stop_seqs_len"] = stop_seqs_len_final
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for process_stop_token_ids in fastdeploy.input.utils.common."""
16+
17+
from fastdeploy.input.utils.common import process_stop_token_ids
18+
19+
20+
def _mock_update_stop_seq_fn(stop_sequences):
21+
"""Mock update_stop_seq that simulates tokenization and padding.
22+
23+
Simulates: "```" -> [101], "end" -> [201, 202]
24+
Returns padded sequences and actual lengths.
25+
"""
26+
token_map = {
27+
"```": [101],
28+
"end": [201, 202],
29+
"\n\n": [301, 302, 303],
30+
"stop": [401],
31+
}
32+
seqs = [token_map.get(s, [999]) for s in stop_sequences]
33+
actual_lens = [len(s) for s in seqs]
34+
# Simulate pad_batch_data: pad to max length with -1
35+
max_len = max(len(s) for s in seqs) if seqs else 0
36+
padded = [s + [-1] * (max_len - len(s)) for s in seqs]
37+
return padded, actual_lens
38+
39+
40+
def test_stop_token_ids_list_int():
41+
"""stop_token_ids as List[int] should produce length-1 sequences."""
42+
request = {"stop_token_ids": [100, 200, 300]}
43+
process_stop_token_ids(request, _mock_update_stop_seq_fn)
44+
45+
assert request["stop_token_ids"] == [[100], [200], [300]]
46+
assert request["stop_seqs_len"] == [1, 1, 1]
47+
48+
49+
def test_stop_token_ids_list_list_int():
50+
"""stop_token_ids as List[List[int]] should preserve actual lengths."""
51+
request = {"stop_token_ids": [[10, 20], [30]]}
52+
process_stop_token_ids(request, _mock_update_stop_seq_fn)
53+
54+
assert request["stop_token_ids"] == [[10, 20], [30]]
55+
assert request["stop_seqs_len"] == [2, 1]
56+
57+
58+
def test_stop_strings_uses_actual_lengths():
59+
"""stop strings with different tokenized lengths should use actual lengths, not padded."""
60+
request = {"stop": ["```", "end"]}
61+
process_stop_token_ids(request, _mock_update_stop_seq_fn)
62+
63+
# "```" -> [101, -1] (padded), actual len 1
64+
# "end" -> [201, 202], actual len 2
65+
assert request["stop_token_ids"] == [[101, -1], [201, 202]]
66+
assert request["stop_seqs_len"] == [1, 2]
67+
68+
69+
def test_mixed_stop_token_ids_and_stop_strings():
70+
"""Both stop_token_ids and stop strings should have correct lengths."""
71+
request = {
72+
"stop_token_ids": [100],
73+
"stop": ["```", "\n\n"],
74+
}
75+
process_stop_token_ids(request, _mock_update_stop_seq_fn)
76+
77+
# stop_token_ids: [100] -> [[100]], len [1]
78+
# "```" -> [101, -1, -1] (padded to 3), actual len 1
79+
# "\n\n" -> [301, 302, 303], actual len 3
80+
assert request["stop_token_ids"] == [[100], [101, -1, -1], [301, 302, 303]]
81+
assert request["stop_seqs_len"] == [1, 1, 3]
82+
83+
84+
def test_empty_request():
85+
"""No stop tokens or strings should leave request unchanged."""
86+
request = {}
87+
process_stop_token_ids(request, _mock_update_stop_seq_fn)
88+
89+
assert "stop_token_ids" not in request
90+
assert "stop_seqs_len" not in request
91+
92+
93+
def test_stop_token_ids_none():
94+
"""stop_token_ids=None should be treated as absent."""
95+
request = {"stop_token_ids": None, "stop": ["stop"]}
96+
process_stop_token_ids(request, _mock_update_stop_seq_fn)
97+
98+
assert request["stop_token_ids"] == [[401]]
99+
assert request["stop_seqs_len"] == [1]
100+
101+
102+
def test_stop_token_ids_empty_list():
103+
"""stop_token_ids=[] should be treated as absent."""
104+
request = {"stop_token_ids": []}
105+
process_stop_token_ids(request, _mock_update_stop_seq_fn)
106+
107+
assert "stop_seqs_len" not in request
108+
109+
110+
if __name__ == "__main__":
111+
test_stop_token_ids_list_int()
112+
test_stop_token_ids_list_list_int()
113+
test_stop_strings_uses_actual_lengths()
114+
test_mixed_stop_token_ids_and_stop_strings()
115+
test_empty_request()
116+
test_stop_token_ids_none()
117+
test_stop_token_ids_empty_list()
118+
print("All tests passed.")

0 commit comments

Comments
 (0)