Skip to content

Commit 592b992

Browse files
[Optimization][Speculative Decoding]opt mtp logprob (#7883)
* opt mtp logprob * fix * fix test and log * fix bits * Adapt logprobs baseline update in test_ernie_21b_mtp_multistep.py --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
1 parent afbd674 commit 592b992

8 files changed

Lines changed: 76 additions & 68 deletions

File tree

custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,11 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids,
119119
msg_sed.mtype = 1;
120120
msg_sed.meta[0] = not_need_stop.data<bool>()[0] ? inference_msg_id_from_env
121121
: -inference_msg_id_from_env;
122-
msg_sed.meta[1] = message_flag;
123-
msg_sed.meta[2] = bsz;
122+
// Pack message_flag (low 8 bits) and max_num_logprobs (high 24 bits) into
123+
// meta[1]. Receiver unpacks both to avoid reading unused topk slots.
124124
int max_num_logprobs = logprob_token_ids.shape()[1];
125+
msg_sed.meta[1] = message_flag | (max_num_logprobs << 8);
126+
msg_sed.meta[2] = bsz;
125127
for (int i = 0; i < bsz; i++) {
126128
int cur_token_num;
127129
if (seq_lens_decoder_data[i] < prompt_lens_data[i] ||
@@ -139,29 +141,24 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids,
139141
auto* cur_batch_msg_sed = &msg_sed.mtext[i];
140142
int token_offset = cu_batch_token_offset_data[i];
141143
for (int j = 0; j < cur_token_num; j++) {
144+
// Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write
145+
// max_num_logprobs columns to avoid filling unused topk slots.
142146
auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)];
143147
auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)];
144148
if (j == 0) {
145149
// first token has full logprobs
146-
for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
150+
for (int k = 0; k < max_num_logprobs; k++) {
147151
if (k == 0) {
148152
cur_tokens[k] =
149153
(int)sampled_token_ids_data[i * max_draft_tokens + j];
150154
cur_scores[k] =
151-
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
152-
k];
153-
} else if (k < max_num_logprobs) {
154-
// only for first token
155-
cur_tokens[k] =
156-
(int)logprob_token_ids_data[(token_offset + j) *
157-
(SPEC_LOGPROB_K + 1) +
158-
k];
159-
cur_scores[k] =
160-
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
161-
k];
155+
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
162156
} else {
163-
cur_tokens[k] = -1;
164-
cur_scores[k] = 0.0;
157+
cur_tokens[k] = (int)
158+
logprob_token_ids_data[(token_offset + j) * max_num_logprobs +
159+
k];
160+
cur_scores[k] =
161+
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
165162
}
166163
}
167164
cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j];
@@ -174,7 +171,8 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids,
174171
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
175172
std::cout << "msg data: " << std::endl;
176173
std::cout << "stop_flag: " << msg_sed.meta[0]
177-
<< ", message_flag: " << msg_sed.meta[1]
174+
<< ", message_flag: " << (msg_sed.meta[1] & 0xFF)
175+
<< ", max_num_logprobs: " << (msg_sed.meta[1] >> 8)
178176
<< ", bsz: " << msg_sed.meta[2] << std::endl;
179177
for (int i = 0; i < bsz; i++) {
180178
int cur_token_num = msg_sed.meta[3 + i];

custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,11 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
7575

7676
int bsz = msg_rcv.meta[2];
7777
output_tokens_data[0] = (int64_t)msg_rcv.meta[0];
78+
// Unpack message_flag (low 8 bits) and actual_topk (high 24 bits) from
79+
// meta[1]. Keep packed value; Python unpacks message_flag and actual_topk.
7880
output_tokens_data[1] = (int64_t)msg_rcv.meta[1];
7981
output_tokens_data[2] = (int64_t)msg_rcv.meta[2];
82+
int actual_topk = msg_rcv.meta[1] >> 8;
8083

8184
int output_tokens_offset = 3 + SPEC_LOGPROB_MAX_BSZ;
8285
for (int i = 0; i < bsz; i++) {
@@ -89,7 +92,7 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
8992
output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1));
9093
auto* cur_batch_msg_rcv = &msg_rcv.mtext[i];
9194
for (int j = 0; j < cur_token_num; j++) {
92-
for (int k = 0; k < real_k + 1; k++) {
95+
for (int k = 0; k < actual_topk; k++) {
9396
cur_output_token[j * (SPEC_LOGPROB_K + 1) + k] =
9497
(int64_t)cur_batch_msg_rcv->tokens[j * (SPEC_LOGPROB_K + 1) + k];
9598
cur_output_score[j * (SPEC_LOGPROB_K + 1) + k] =
@@ -102,7 +105,8 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
102105
#ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG
103106
std::cout << "msg data: " << std::endl;
104107
std::cout << "stop_flag: " << output_tokens_data[0]
105-
<< ", message_flag: " << output_tokens_data[1]
108+
<< ", message_flag: " << (output_tokens_data[1] & 0xFF)
109+
<< ", max_num_logprobs: " << (output_tokens_data[1] >> 8)
106110
<< ", bsz: " << output_tokens_data[2] << std::endl;
107111
for (int i = 0; i < output_tokens_data[2]; i++) {
108112
int cur_token_num = output_tokens_data[3 + i];

custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,11 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
121121
msg_sed.mtype = 1;
122122
msg_sed.meta[0] = not_need_stop.data<bool>()[0] ? inference_msg_id_from_env
123123
: -inference_msg_id_from_env;
124-
msg_sed.meta[1] = message_flag;
125-
msg_sed.meta[2] = bsz;
124+
// Pack message_flag (low 8 bits) and max_num_logprobs (high 24 bits) into
125+
// meta[1]. Receiver unpacks both to avoid reading unused topk slots.
126126
int max_num_logprobs = logprob_token_ids.shape()[1];
127+
msg_sed.meta[1] = message_flag | (max_num_logprobs << 8);
128+
msg_sed.meta[2] = bsz;
127129
for (int i = 0; i < bsz; i++) {
128130
int cur_token_num;
129131
if (seq_lens_decoder_data[i] < prompt_lens_data[i]) {
@@ -139,24 +141,20 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
139141
auto* cur_batch_msg_sed = &msg_sed.mtext[i];
140142
int token_offset = cu_batch_token_offset_data[i];
141143
for (int j = 0; j < cur_token_num; j++) {
144+
// Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write
145+
// max_num_logprobs columns to avoid filling unused topk slots.
142146
auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)];
143147
auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)];
144-
for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
148+
for (int k = 0; k < max_num_logprobs; k++) {
145149
if (k == 0) {
146150
cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j];
147151
cur_scores[k] =
148-
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
149-
k];
150-
} else if (k < max_num_logprobs) {
152+
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
153+
} else {
151154
cur_tokens[k] = (int)
152-
logprob_token_ids_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
153-
k];
155+
logprob_token_ids_data[(token_offset + j) * max_num_logprobs + k];
154156
cur_scores[k] =
155-
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
156-
k];
157-
} else {
158-
cur_tokens[k] = -1;
159-
cur_scores[k] = 0.0;
157+
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
160158
}
161159
}
162160
cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j];
@@ -165,7 +163,8 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
165163
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
166164
std::cout << "msg data: " << std::endl;
167165
std::cout << "stop_flag: " << msg_sed.meta[0]
168-
<< ", message_flag: " << msg_sed.meta[1]
166+
<< ", message_flag: " << (msg_sed.meta[1] & 0xFF)
167+
<< ", max_num_logprobs: " << (msg_sed.meta[1] >> 8)
169168
<< ", bsz: " << msg_sed.meta[2] << std::endl;
170169
for (int i = 0; i < bsz; i++) {
171170
int cur_token_num = msg_sed.meta[3 + i];

fastdeploy/output/token_processor.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -739,12 +739,15 @@ def _process_batch_draft_tokens(self, mtype, batch, accept_num, tokens, scores,
739739
metrics=None,
740740
)
741741

742-
token_ids = tokens[i][:, 0].tolist()[: accept_num[i]]
742+
tokens_i = tokens[i].tolist()
743+
scores_i = scores[i].tolist()
744+
ranks_i = ranks[i].tolist()
745+
token_ids = [row[0] for row in tokens_i[: accept_num[i]]]
743746
for batch_token_index in range(len(token_ids)):
744-
result.outputs.logprob = float(scores[i, batch_token_index, 0])
745-
topk_token_ids = tokens[i, batch_token_index, :].tolist()
746-
topk_logprobs = scores[i, batch_token_index, :].tolist()
747-
sampled_rank = ranks[i, batch_token_index].item()
747+
result.outputs.logprob = scores_i[batch_token_index][0]
748+
topk_token_ids = tokens_i[batch_token_index]
749+
topk_logprobs = scores_i[batch_token_index]
750+
sampled_rank = ranks_i[batch_token_index]
748751

749752
if result.outputs.draft_top_logprobs is None:
750753
result.outputs.draft_top_logprobs = LogprobsLists(
@@ -771,16 +774,19 @@ def _process_batch_output(self):
771774
mtype = 3
772775
if self.cfg.speculative_config.method:
773776
if self.use_logprobs:
774-
mtype = int(self.output_tokens[1, 0].item())
777+
# meta[1] packs message_flag (low 8 bits) and actual_topk (high 24 bits).
778+
packed_meta1 = int(self.output_tokens[1, 0].item())
779+
mtype = packed_meta1 & 0xFF
780+
actual_topk = packed_meta1 >> 8
775781
batch = self.output_tokens[2, 0]
776782
accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]]
777783
tokens = tokens[3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1)].reshape(
778784
[batch, MAX_DRAFT_TOKENS, K + 1]
779-
)
785+
)[:, :, :actual_topk]
780786
scores = (
781787
self.output_scores[: batch * MAX_DRAFT_TOKENS * (K + 1)]
782788
.numpy()
783-
.reshape([batch, MAX_DRAFT_TOKENS, K + 1])
789+
.reshape([batch, MAX_DRAFT_TOKENS, K + 1])[:, :, :actual_topk]
784790
)
785791
ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS])
786792

@@ -789,6 +795,10 @@ def _process_batch_output(self):
789795
batch_result = self._process_batch_draft_tokens(mtype, batch, accept_num, tokens, scores, ranks)
790796
self.postprocess(batch_result, mtype)
791797
return
798+
# Pre-convert full arrays to Python lists once for MTP target token path.
799+
tokens_lists = tokens.tolist()
800+
scores_lists = scores.tolist()
801+
ranks_list = ranks.tolist()
792802
else:
793803
batch = self.output_tokens[1]
794804
accept_num = tokens[2 : batch + 2]
@@ -856,7 +866,7 @@ def _process_batch_output(self):
856866
)
857867
token_ids = [RECOVERY_STOP_SIGNAL]
858868
elif self.use_logprobs:
859-
token_ids = tokens[i][:, 0].tolist()[: accept_num[i]]
869+
token_ids = [row[0] for row in tokens_lists[i][: accept_num[i]]]
860870
else:
861871
token_ids = tokens[
862872
2
@@ -988,10 +998,10 @@ def _process_batch_output(self):
988998
task.output_token_ids.append(token_id)
989999
if self.use_logprobs:
9901000
if self.cfg.speculative_config.method:
991-
result.outputs.logprob = float(scores[i, batch_token_index, 0])
992-
topk_token_ids = tokens[i, batch_token_index, :].tolist()
993-
topk_logprobs = scores[i, batch_token_index, :].tolist()
994-
sampled_rank = ranks[i, batch_token_index].item()
1001+
result.outputs.logprob = scores_lists[i][batch_token_index][0]
1002+
topk_token_ids = tokens_lists[i][batch_token_index]
1003+
topk_logprobs = scores_lists[i][batch_token_index]
1004+
sampled_rank = ranks_list[i][batch_token_index]
9951005
else:
9961006
# Use pre-converted lists (batch .tolist() done before the loop).
9971007
result.outputs.logprob = scores_lists[i][0]

fastdeploy/worker/gpu_model_runner.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,15 +1279,11 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p
12791279
req.sampling_params.top_p_normalized_logprobs and req.sampling_params.top_p != 1.0 for req in logprobs_reqs
12801280
)
12811281
if logprobs_reqs:
1282-
self.max_logprobs = (
1283-
max(
1284-
[
1285-
self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs
1286-
for req in logprobs_reqs
1287-
]
1288-
)
1289-
if not self.speculative_decoding
1290-
else 20
1282+
self.max_logprobs = max(
1283+
[
1284+
self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs
1285+
for req in logprobs_reqs
1286+
]
12911287
)
12921288
elif self.enable_logprob:
12931289
self.max_logprobs = None if not self.speculative_decoding else 0

tests/e2e/test_ernie_21b_mtp_multistep.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,11 @@ def test_prefix_cache_text(api_url):
212212
if os.getenv("BASELINE") == "1":
213213
baseline_manager.save("base_21b_step3", result)
214214
baseline_manager.save("base_21b_mtp_metrics_step3", speculate_metrics_2)
215-
baseline_manager.save("base_21b_logprobs_step3", logprobs_2)
215+
baseline_manager.save("base_21b_logprobs_step3_new", logprobs_2)
216216

217217
baseline_result = baseline_manager.load("base_21b_step3")
218218
baseline_mtp_metrics = baseline_manager.load("base_21b_mtp_metrics_step3")
219-
baseline_logprobs = baseline_manager.load("base_21b_logprobs_step3")
219+
baseline_logprobs = baseline_manager.load("base_21b_logprobs_step3_new")
220220

221221
assert logprobs == logprobs_2, (
222222
"logprobs 前后不一致\n"

tests/output/test_process_batch_output.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,9 @@ def test_speculative_decoding_use_logprobs(self):
210210

211211
# stop_flag
212212
processor.output_tokens[0, 0].set_tensor(paddle.to_tensor(2))
213-
# mtype target = 3, decode = 4
214-
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3))
213+
# meta[1] packs mtype (low 8 bits) and actual_topk (high 16 bits)
214+
actual_topk = K + 1
215+
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3 | (actual_topk << 8)))
215216
# batch
216217
processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(2))
217218
# accept_num
@@ -243,12 +244,12 @@ def test_speculative_decoding_use_logprobs(self):
243244
assert len(request_output.outputs.token_ids) == accept_num[i]
244245
assert len(request_output.outputs.top_logprobs) == 3
245246
# tokens, scores, ranks
246-
assert len(request_output.outputs.top_logprobs[0][0]) == K + 1
247-
assert len(request_output.outputs.top_logprobs[1][0]) == K + 1
247+
assert len(request_output.outputs.top_logprobs[0][0]) == actual_topk
248+
assert len(request_output.outputs.top_logprobs[1][0]) == actual_topk
248249
assert len(request_output.outputs.top_logprobs[2]) == accept_num[i]
249250

250251
# mtype = 4
251-
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(4))
252+
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(4 | (actual_topk << 8)))
252253
processor._process_batch_output()
253254
cached_generated_tokens: MockCachedGeneratedTokens = processor.cached_generated_tokens
254255
for c in cached_generated_tokens.cache:
@@ -257,8 +258,8 @@ def test_speculative_decoding_use_logprobs(self):
257258
assert len(request_output.outputs.top_logprobs) == 3
258259
assert len(request_output.outputs.draft_top_logprobs) == 3
259260
# tokens, scores, ranks
260-
assert len(request_output.outputs.draft_top_logprobs[0][0]) == K + 1
261-
assert len(request_output.outputs.draft_top_logprobs[1][0]) == K + 1
261+
assert len(request_output.outputs.draft_top_logprobs[0][0]) == actual_topk
262+
assert len(request_output.outputs.draft_top_logprobs[1][0]) == actual_topk
262263
assert len(request_output.outputs.draft_top_logprobs[2]) == accept_num[i]
263264

264265
def test_process_batch_output_aborted_task_negative_token_speculative_decoding(self):
@@ -281,8 +282,8 @@ def test_process_batch_output_aborted_task_negative_token_speculative_decoding(s
281282
# Set up output tokens with negative token
282283
# stop_flag
283284
processor.output_tokens[0, 0].set_tensor(paddle.to_tensor(2))
284-
# mtype target = 3
285-
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3))
285+
# mtype target = 3, actual_topk packed in high bits
286+
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3 | ((K + 1) << 8)))
286287
# batch = 2 (so batch_id=0 is < batch_size-1=1)
287288
processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(2))
288289
# Set accept_num = PREEMPTED_TOKEN_ID (-9) for first task to trigger abort logic

tests/output/test_token_processor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ def test_process_batch_output_speculative_logprob_handles_draft_batch():
748748
)
749749
processor._batch_result_buffer = [target]
750750
processor.cached_generated_tokens = mock.Mock()
751-
processor.output_tokens[1, 0] = 4
751+
processor.output_tokens[1, 0] = 4 | ((K + 1) << 8)
752752
processor.output_tokens[2, 0] = 1
753753
processor.output_tokens[3, 0] = 1
754754

@@ -925,7 +925,7 @@ def test_process_batch_output_speculative_logprob_targets_topk_scores():
925925
task.trace_carrier = None
926926
rm.tasks_list[0] = task
927927
rm.req_dict[task.request_id] = task
928-
processor.output_tokens[1, 0] = 3
928+
processor.output_tokens[1, 0] = 3 | ((K + 1) << 8)
929929
processor.output_tokens[2, 0] = 1
930930
processor.output_tokens[3, 0] = 2
931931
token_block = np.arange(MAX_DRAFT_TOKENS * (K + 1), dtype=np.int64).reshape([-1, 1]) + 3

0 commit comments

Comments
 (0)