Skip to content

Commit c52b063

Browse files
[Cherry-Pick][Optimization][Speculative Decoding]opt mtp logprob (#7883) (#7884)
* opt mtp logprob * fix * fix test and log * fix bits * Adapt logprobs baseline update in test_ernie_21b_mtp_multistep.py --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
1 parent a095d6f commit c52b063

8 files changed

Lines changed: 76 additions & 68 deletions

File tree

custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,11 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids,
119119
msg_sed.mtype = 1;
120120
msg_sed.meta[0] = not_need_stop.data<bool>()[0] ? inference_msg_id_from_env
121121
: -inference_msg_id_from_env;
122-
msg_sed.meta[1] = message_flag;
123-
msg_sed.meta[2] = bsz;
122+
// Pack message_flag (low 8 bits) and max_num_logprobs (high 24 bits) into
123+
// meta[1]. Receiver unpacks both to avoid reading unused topk slots.
124124
int max_num_logprobs = logprob_token_ids.shape()[1];
125+
msg_sed.meta[1] = message_flag | (max_num_logprobs << 8);
126+
msg_sed.meta[2] = bsz;
125127
for (int i = 0; i < bsz; i++) {
126128
int cur_token_num;
127129
if (seq_lens_decoder_data[i] < prompt_lens_data[i] ||
@@ -139,29 +141,24 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids,
139141
auto* cur_batch_msg_sed = &msg_sed.mtext[i];
140142
int token_offset = cu_batch_token_offset_data[i];
141143
for (int j = 0; j < cur_token_num; j++) {
144+
// Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write
145+
// max_num_logprobs columns to avoid filling unused topk slots.
142146
auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)];
143147
auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)];
144148
if (j == 0) {
145149
// first token has full logprobs
146-
for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
150+
for (int k = 0; k < max_num_logprobs; k++) {
147151
if (k == 0) {
148152
cur_tokens[k] =
149153
(int)sampled_token_ids_data[i * max_draft_tokens + j];
150154
cur_scores[k] =
151-
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
152-
k];
153-
} else if (k < max_num_logprobs) {
154-
// only for first token
155-
cur_tokens[k] =
156-
(int)logprob_token_ids_data[(token_offset + j) *
157-
(SPEC_LOGPROB_K + 1) +
158-
k];
159-
cur_scores[k] =
160-
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
161-
k];
155+
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
162156
} else {
163-
cur_tokens[k] = -1;
164-
cur_scores[k] = 0.0;
157+
cur_tokens[k] = (int)
158+
logprob_token_ids_data[(token_offset + j) * max_num_logprobs +
159+
k];
160+
cur_scores[k] =
161+
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
165162
}
166163
}
167164
cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j];
@@ -174,7 +171,8 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids,
174171
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
175172
std::cout << "msg data: " << std::endl;
176173
std::cout << "stop_flag: " << msg_sed.meta[0]
177-
<< ", message_flag: " << msg_sed.meta[1]
174+
<< ", message_flag: " << (msg_sed.meta[1] & 0xFF)
175+
<< ", max_num_logprobs: " << (msg_sed.meta[1] >> 8)
178176
<< ", bsz: " << msg_sed.meta[2] << std::endl;
179177
for (int i = 0; i < bsz; i++) {
180178
int cur_token_num = msg_sed.meta[3 + i];

custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,11 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
7575

7676
int bsz = msg_rcv.meta[2];
7777
output_tokens_data[0] = (int64_t)msg_rcv.meta[0];
78+
// meta[1] packs message_flag (low 8 bits) and actual_topk (high 24 bits).
79+
// The packed value is forwarded as-is; Python unpacks both fields from it.
7880
output_tokens_data[1] = (int64_t)msg_rcv.meta[1];
7981
output_tokens_data[2] = (int64_t)msg_rcv.meta[2];
82+
int actual_topk = msg_rcv.meta[1] >> 8;
8083

8184
int output_tokens_offset = 3 + SPEC_LOGPROB_MAX_BSZ;
8285
for (int i = 0; i < bsz; i++) {
@@ -89,7 +92,7 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
8992
output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1));
9093
auto* cur_batch_msg_rcv = &msg_rcv.mtext[i];
9194
for (int j = 0; j < cur_token_num; j++) {
92-
for (int k = 0; k < real_k + 1; k++) {
95+
for (int k = 0; k < actual_topk; k++) {
9396
cur_output_token[j * (SPEC_LOGPROB_K + 1) + k] =
9497
(int64_t)cur_batch_msg_rcv->tokens[j * (SPEC_LOGPROB_K + 1) + k];
9598
cur_output_score[j * (SPEC_LOGPROB_K + 1) + k] =
@@ -102,7 +105,8 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
102105
#ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG
103106
std::cout << "msg data: " << std::endl;
104107
std::cout << "stop_flag: " << output_tokens_data[0]
105-
<< ", message_flag: " << output_tokens_data[1]
108+
<< ", message_flag: " << (output_tokens_data[1] & 0xFF)
109+
<< ", max_num_logprobs: " << (output_tokens_data[1] >> 8)
106110
<< ", bsz: " << output_tokens_data[2] << std::endl;
107111
for (int i = 0; i < output_tokens_data[2]; i++) {
108112
int cur_token_num = output_tokens_data[3 + i];

custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,11 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
121121
msg_sed.mtype = 1;
122122
msg_sed.meta[0] = not_need_stop.data<bool>()[0] ? inference_msg_id_from_env
123123
: -inference_msg_id_from_env;
124-
msg_sed.meta[1] = message_flag;
125-
msg_sed.meta[2] = bsz;
124+
// Pack message_flag (low 8 bits) and max_num_logprobs (high 24 bits) into
125+
// meta[1]. Receiver unpacks both to avoid reading unused topk slots.
126126
int max_num_logprobs = logprob_token_ids.shape()[1];
127+
msg_sed.meta[1] = message_flag | (max_num_logprobs << 8);
128+
msg_sed.meta[2] = bsz;
127129
for (int i = 0; i < bsz; i++) {
128130
int cur_token_num;
129131
if (seq_lens_decoder_data[i] < prompt_lens_data[i]) {
@@ -139,24 +141,20 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
139141
auto* cur_batch_msg_sed = &msg_sed.mtext[i];
140142
int token_offset = cu_batch_token_offset_data[i];
141143
for (int j = 0; j < cur_token_num; j++) {
144+
// Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write
145+
// max_num_logprobs columns to avoid filling unused topk slots.
142146
auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)];
143147
auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)];
144-
for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
148+
for (int k = 0; k < max_num_logprobs; k++) {
145149
if (k == 0) {
146150
cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j];
147151
cur_scores[k] =
148-
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
149-
k];
150-
} else if (k < max_num_logprobs) {
152+
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
153+
} else {
151154
cur_tokens[k] = (int)
152-
logprob_token_ids_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
153-
k];
155+
logprob_token_ids_data[(token_offset + j) * max_num_logprobs + k];
154156
cur_scores[k] =
155-
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
156-
k];
157-
} else {
158-
cur_tokens[k] = -1;
159-
cur_scores[k] = 0.0;
157+
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
160158
}
161159
}
162160
cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j];
@@ -165,7 +163,8 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
165163
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
166164
std::cout << "msg data: " << std::endl;
167165
std::cout << "stop_flag: " << msg_sed.meta[0]
168-
<< ", message_flag: " << msg_sed.meta[1]
166+
<< ", message_flag: " << (msg_sed.meta[1] & 0xFF)
167+
<< ", max_num_logprobs: " << (msg_sed.meta[1] >> 8)
169168
<< ", bsz: " << msg_sed.meta[2] << std::endl;
170169
for (int i = 0; i < bsz; i++) {
171170
int cur_token_num = msg_sed.meta[3 + i];

fastdeploy/output/token_processor.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -796,12 +796,15 @@ def _process_batch_draft_tokens(self, mtype, batch, accept_num, tokens, scores,
796796
metrics=None,
797797
)
798798

799-
token_ids = tokens[i][:, 0].tolist()[: accept_num[i]]
799+
tokens_i = tokens[i].tolist()
800+
scores_i = scores[i].tolist()
801+
ranks_i = ranks[i].tolist()
802+
token_ids = [row[0] for row in tokens_i[: accept_num[i]]]
800803
for batch_token_index in range(len(token_ids)):
801-
result.outputs.logprob = float(scores[i, batch_token_index, 0])
802-
topk_token_ids = tokens[i, batch_token_index, :].tolist()
803-
topk_logprobs = scores[i, batch_token_index, :].tolist()
804-
sampled_rank = ranks[i, batch_token_index].item()
804+
result.outputs.logprob = scores_i[batch_token_index][0]
805+
topk_token_ids = tokens_i[batch_token_index]
806+
topk_logprobs = scores_i[batch_token_index]
807+
sampled_rank = ranks_i[batch_token_index]
805808

806809
if result.outputs.draft_top_logprobs is None:
807810
result.outputs.draft_top_logprobs = LogprobsLists(
@@ -828,16 +831,19 @@ def _process_batch_output(self):
828831
mtype = 3
829832
if self.cfg.speculative_config.method:
830833
if self.use_logprobs:
831-
mtype = int(self.output_tokens[1, 0].item())
834+
# meta[1] packs message_flag (low 8 bits) and actual_topk (high 24 bits).
835+
packed_meta1 = int(self.output_tokens[1, 0].item())
836+
mtype = packed_meta1 & 0xFF
837+
actual_topk = packed_meta1 >> 8
832838
batch = self.output_tokens[2, 0]
833839
accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]]
834840
tokens = tokens[3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1)].reshape(
835841
[batch, MAX_DRAFT_TOKENS, K + 1]
836-
)
842+
)[:, :, :actual_topk]
837843
scores = (
838844
self.output_scores[: batch * MAX_DRAFT_TOKENS * (K + 1)]
839845
.numpy()
840-
.reshape([batch, MAX_DRAFT_TOKENS, K + 1])
846+
.reshape([batch, MAX_DRAFT_TOKENS, K + 1])[:, :, :actual_topk]
841847
)
842848
ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS])
843849

@@ -846,6 +852,10 @@ def _process_batch_output(self):
846852
batch_result = self._process_batch_draft_tokens(mtype, batch, accept_num, tokens, scores, ranks)
847853
self.postprocess(batch_result, mtype)
848854
return
855+
# Pre-convert full arrays to Python lists once for MTP target token path.
856+
tokens_lists = tokens.tolist()
857+
scores_lists = scores.tolist()
858+
ranks_list = ranks.tolist()
849859
else:
850860
batch = self.output_tokens[1]
851861
accept_num = tokens[2 : batch + 2]
@@ -914,7 +924,7 @@ def _process_batch_output(self):
914924
llm_logger.info(f"recovery stop signal found at task {task_id}")
915925
token_ids = [RECOVERY_STOP_SIGNAL]
916926
elif self.use_logprobs:
917-
token_ids = tokens[i][:, 0].tolist()[: accept_num[i]]
927+
token_ids = [row[0] for row in tokens_lists[i][: accept_num[i]]]
918928
else:
919929
token_ids = tokens[
920930
2
@@ -1033,10 +1043,10 @@ def _process_batch_output(self):
10331043
task.output_token_ids.append(token_id)
10341044
if self.use_logprobs:
10351045
if self.cfg.speculative_config.method:
1036-
result.outputs.logprob = float(scores[i, batch_token_index, 0])
1037-
topk_token_ids = tokens[i, batch_token_index, :].tolist()
1038-
topk_logprobs = scores[i, batch_token_index, :].tolist()
1039-
sampled_rank = ranks[i, batch_token_index].item()
1046+
result.outputs.logprob = scores_lists[i][batch_token_index][0]
1047+
topk_token_ids = tokens_lists[i][batch_token_index]
1048+
topk_logprobs = scores_lists[i][batch_token_index]
1049+
sampled_rank = ranks_list[i][batch_token_index]
10401050
else:
10411051
# Use pre-converted lists (batch .tolist() done before the loop).
10421052
result.outputs.logprob = scores_lists[i][0]

fastdeploy/worker/gpu_model_runner.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1226,15 +1226,11 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p
12261226
req.sampling_params.top_p_normalized_logprobs and req.sampling_params.top_p != 1.0 for req in logprobs_reqs
12271227
)
12281228
if logprobs_reqs:
1229-
self.max_logprobs = (
1230-
max(
1231-
[
1232-
self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs
1233-
for req in logprobs_reqs
1234-
]
1235-
)
1236-
if not self.speculative_decoding
1237-
else 20
1229+
self.max_logprobs = max(
1230+
[
1231+
self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs
1232+
for req in logprobs_reqs
1233+
]
12381234
)
12391235
elif self.enable_logprob:
12401236
self.max_logprobs = None if not self.speculative_decoding else 0

tests/e2e/test_ernie_21b_mtp_multistep.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,11 @@ def test_prefix_cache_text(api_url):
212212
if os.getenv("BASELINE") == "1":
213213
baseline_manager.save("base_21b_step3", result)
214214
baseline_manager.save("base_21b_mtp_metrics_step3", speculate_metrics_2)
215-
baseline_manager.save("base_21b_logprobs_step3", logprobs_2)
215+
baseline_manager.save("base_21b_logprobs_step3_new", logprobs_2)
216216

217217
baseline_result = baseline_manager.load("base_21b_step3")
218218
baseline_mtp_metrics = baseline_manager.load("base_21b_mtp_metrics_step3")
219-
baseline_logprobs = baseline_manager.load("base_21b_logprobs_step3")
219+
baseline_logprobs = baseline_manager.load("base_21b_logprobs_step3_new")
220220

221221
assert logprobs == logprobs_2, (
222222
"logprobs 前后不一致\n"

tests/output/test_process_batch_output.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,9 @@ def test_speculative_decoding_use_logprobs(self):
211211

212212
# stop_flag
213213
processor.output_tokens[0, 0].set_tensor(paddle.to_tensor(2))
214-
# mtype target = 3, decode = 4
215-
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3))
214+
# meta[1] packs mtype (low 8 bits) and actual_topk (high 24 bits)
215+
actual_topk = K + 1
216+
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3 | (actual_topk << 8)))
216217
# batch
217218
processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(2))
218219
# accept_num
@@ -244,12 +245,12 @@ def test_speculative_decoding_use_logprobs(self):
244245
assert len(request_output.outputs.token_ids) == accept_num[i]
245246
assert len(request_output.outputs.top_logprobs) == 3
246247
# tokens, scores, ranks
247-
assert len(request_output.outputs.top_logprobs[0][0]) == K + 1
248-
assert len(request_output.outputs.top_logprobs[1][0]) == K + 1
248+
assert len(request_output.outputs.top_logprobs[0][0]) == actual_topk
249+
assert len(request_output.outputs.top_logprobs[1][0]) == actual_topk
249250
assert len(request_output.outputs.top_logprobs[2]) == accept_num[i]
250251

251252
# mtype = 4
252-
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(4))
253+
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(4 | (actual_topk << 8)))
253254
processor._process_batch_output()
254255
cached_generated_tokens: MockCachedGeneratedTokens = processor.cached_generated_tokens
255256
for c in cached_generated_tokens.cache:
@@ -258,8 +259,8 @@ def test_speculative_decoding_use_logprobs(self):
258259
assert len(request_output.outputs.top_logprobs) == 3
259260
assert len(request_output.outputs.draft_top_logprobs) == 3
260261
# tokens, scores, ranks
261-
assert len(request_output.outputs.draft_top_logprobs[0][0]) == K + 1
262-
assert len(request_output.outputs.draft_top_logprobs[1][0]) == K + 1
262+
assert len(request_output.outputs.draft_top_logprobs[0][0]) == actual_topk
263+
assert len(request_output.outputs.draft_top_logprobs[1][0]) == actual_topk
263264
assert len(request_output.outputs.draft_top_logprobs[2]) == accept_num[i]
264265

265266
def test_process_batch_output_aborted_task_negative_token_speculative_decoding(self):
@@ -282,8 +283,8 @@ def test_process_batch_output_aborted_task_negative_token_speculative_decoding(s
282283
# Set up output tokens with negative token
283284
# stop_flag
284285
processor.output_tokens[0, 0].set_tensor(paddle.to_tensor(2))
285-
# mtype target = 3
286-
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3))
286+
# mtype target = 3, actual_topk packed in high bits
287+
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3 | ((K + 1) << 8)))
287288
# batch = 2 (so batch_id=0 is < batch_size-1=1)
288289
processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(2))
289290
# Set accept_num = PREEMPTED_TOKEN_ID (-9) for first task to trigger abort logic

tests/output/test_token_processor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,7 @@ def test_process_batch_output_speculative_logprob_handles_draft_batch():
749749
)
750750
processor._batch_result_buffer = [target]
751751
processor.cached_generated_tokens = mock.Mock()
752-
processor.output_tokens[1, 0] = 4
752+
processor.output_tokens[1, 0] = 4 | ((K + 1) << 8)
753753
processor.output_tokens[2, 0] = 1
754754
processor.output_tokens[3, 0] = 1
755755

@@ -926,7 +926,7 @@ def test_process_batch_output_speculative_logprob_targets_topk_scores():
926926
task.trace_carrier = None
927927
rm.tasks_list[0] = task
928928
rm.req_dict[task.request_id] = task
929-
processor.output_tokens[1, 0] = 3
929+
processor.output_tokens[1, 0] = 3 | ((K + 1) << 8)
930930
processor.output_tokens[2, 0] = 1
931931
processor.output_tokens[3, 0] = 2
932932
token_block = np.arange(MAX_DRAFT_TOKENS * (K + 1), dtype=np.int64).reshape([-1, 1]) + 3

0 commit comments

Comments
 (0)