-
Notifications
You must be signed in to change notification settings - Fork 752
[Cherry-Pick][Optimization][Speculative Decoding]opt mtp logprob (#7883) #7884
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a988c1a
4f35c72
f8b507f
f90dd1c
2654153
70762af
3d636b5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -119,9 +119,11 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids, | |
| msg_sed.mtype = 1; | ||
| msg_sed.meta[0] = not_need_stop.data<bool>()[0] ? inference_msg_id_from_env | ||
| : -inference_msg_id_from_env; | ||
| msg_sed.meta[1] = message_flag; | ||
| msg_sed.meta[2] = bsz; | ||
| // Pack message_flag (low 8 bits) and max_num_logprobs (high 24 bits) into | ||
| // meta[1]. Receiver unpacks both to avoid reading unused topk slots. | ||
| int max_num_logprobs = logprob_token_ids.shape()[1]; | ||
| msg_sed.meta[1] = message_flag | (max_num_logprobs << 8); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🟡 建议
建议加防御性截断: msg_sed.meta[1] = (message_flag & 0xFF) | (max_num_logprobs << 8);
|
||
| msg_sed.meta[2] = bsz; | ||
| for (int i = 0; i < bsz; i++) { | ||
| int cur_token_num; | ||
| if (seq_lens_decoder_data[i] < prompt_lens_data[i] || | ||
|
|
@@ -139,29 +141,24 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids, | |
| auto* cur_batch_msg_sed = &msg_sed.mtext[i]; | ||
| int token_offset = cu_batch_token_offset_data[i]; | ||
| for (int j = 0; j < cur_token_num; j++) { | ||
| // Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write | ||
| // max_num_logprobs columns to avoid filling unused topk slots. | ||
| auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)]; | ||
| auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)]; | ||
| if (j == 0) { | ||
| // first token has full logprobs | ||
| for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { | ||
| for (int k = 0; k < max_num_logprobs; k++) { | ||
| if (k == 0) { | ||
| cur_tokens[k] = | ||
| (int)sampled_token_ids_data[i * max_draft_tokens + j]; | ||
| cur_scores[k] = | ||
| logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + | ||
| k]; | ||
| } else if (k < max_num_logprobs) { | ||
| // only for first token | ||
| cur_tokens[k] = | ||
| (int)logprob_token_ids_data[(token_offset + j) * | ||
| (SPEC_LOGPROB_K + 1) + | ||
| k]; | ||
| cur_scores[k] = | ||
| logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + | ||
| k]; | ||
| logprob_scores_data[(token_offset + j) * max_num_logprobs + k]; | ||
| } else { | ||
| cur_tokens[k] = -1; | ||
| cur_scores[k] = 0.0; | ||
| cur_tokens[k] = (int) | ||
| logprob_token_ids_data[(token_offset + j) * max_num_logprobs + | ||
| k]; | ||
| cur_scores[k] = | ||
| logprob_scores_data[(token_offset + j) * max_num_logprobs + k]; | ||
| } | ||
| } | ||
| cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j]; | ||
|
|
@@ -174,7 +171,8 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids, | |
| #ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG | ||
| std::cout << "msg data: " << std::endl; | ||
| std::cout << "stop_flag: " << msg_sed.meta[0] | ||
| << ", message_flag: " << msg_sed.meta[1] | ||
| << ", message_flag: " << (msg_sed.meta[1] & 0xFF) | ||
| << ", max_num_logprobs: " << (msg_sed.meta[1] >> 8) | ||
| << ", bsz: " << msg_sed.meta[2] << std::endl; | ||
| for (int i = 0; i < bsz; i++) { | ||
| int cur_token_num = msg_sed.meta[3 + i]; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -75,8 +75,11 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, | |
|
|
||
| int bsz = msg_rcv.meta[2]; | ||
| output_tokens_data[0] = (int64_t)msg_rcv.meta[0]; | ||
| // meta[1] packs message_flag (low 8 bits) and actual_topk (high 24 bits). | ||
| // Forward the packed value as-is (Python unpacks both fields); actual_topk is also extracted below. | ||
| output_tokens_data[1] = (int64_t)msg_rcv.meta[1]; | ||
| output_tokens_data[2] = (int64_t)msg_rcv.meta[2]; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ❓ 疑问 当 接收端 tokens[:, :, :0] # shape=[batch, MAX_DRAFT_TOKENS, 0]
token_ids = [row[0] for row in tokens_lists[i][:accept_num[i]]] # IndexError!请确认:
|
||
| int actual_topk = msg_rcv.meta[1] >> 8; | ||
|
|
||
| int output_tokens_offset = 3 + SPEC_LOGPROB_MAX_BSZ; | ||
| for (int i = 0; i < bsz; i++) { | ||
|
|
@@ -89,7 +92,7 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, | |
| output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1)); | ||
| auto* cur_batch_msg_rcv = &msg_rcv.mtext[i]; | ||
| for (int j = 0; j < cur_token_num; j++) { | ||
| for (int k = 0; k < real_k + 1; k++) { | ||
| for (int k = 0; k < actual_topk; k++) { | ||
| cur_output_token[j * (SPEC_LOGPROB_K + 1) + k] = | ||
| (int64_t)cur_batch_msg_rcv->tokens[j * (SPEC_LOGPROB_K + 1) + k]; | ||
| cur_output_score[j * (SPEC_LOGPROB_K + 1) + k] = | ||
|
|
@@ -102,7 +105,8 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, | |
| #ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG | ||
| std::cout << "msg data: " << std::endl; | ||
| std::cout << "stop_flag: " << output_tokens_data[0] | ||
| << ", message_flag: " << output_tokens_data[1] | ||
| << ", message_flag: " << (output_tokens_data[1] & 0xFF) | ||
| << ", max_num_logprobs: " << (output_tokens_data[1] >> 8) | ||
| << ", bsz: " << output_tokens_data[2] << std::endl; | ||
| for (int i = 0; i < output_tokens_data[2]; i++) { | ||
| int cur_token_num = output_tokens_data[3 + i]; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -211,8 +211,9 @@ def test_speculative_decoding_use_logprobs(self): | |
|
|
||
| # stop_flag | ||
| processor.output_tokens[0, 0].set_tensor(paddle.to_tensor(2)) | ||
This comment was marked as outdated.
Sorry, something went wrong. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ❓ 疑问 bit 宽注释不一致 此处注释写 实际实现是 |
||
| # mtype target = 3, decode = 4 | ||
| processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3)) | ||
| # meta[1] packs mtype (low 8 bits) and actual_topk (high 24 bits) | ||
| actual_topk = K + 1 | ||
| processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3 | (actual_topk << 8))) | ||
| # batch | ||
| processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(2)) | ||
| # accept_num | ||
|
|
@@ -244,12 +245,12 @@ def test_speculative_decoding_use_logprobs(self): | |
| assert len(request_output.outputs.token_ids) == accept_num[i] | ||
| assert len(request_output.outputs.top_logprobs) == 3 | ||
| # tokens, scores, ranks | ||
| assert len(request_output.outputs.top_logprobs[0][0]) == K + 1 | ||
| assert len(request_output.outputs.top_logprobs[1][0]) == K + 1 | ||
| assert len(request_output.outputs.top_logprobs[0][0]) == actual_topk | ||
| assert len(request_output.outputs.top_logprobs[1][0]) == actual_topk | ||
| assert len(request_output.outputs.top_logprobs[2]) == accept_num[i] | ||
|
|
||
| # mtype = 4 | ||
| processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(4)) | ||
| processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(4 | (actual_topk << 8))) | ||
| processor._process_batch_output() | ||
| cached_generated_tokens: MockCachedGeneratedTokens = processor.cached_generated_tokens | ||
| for c in cached_generated_tokens.cache: | ||
|
|
@@ -258,8 +259,8 @@ def test_speculative_decoding_use_logprobs(self): | |
| assert len(request_output.outputs.top_logprobs) == 3 | ||
| assert len(request_output.outputs.draft_top_logprobs) == 3 | ||
| # tokens, scores, ranks | ||
| assert len(request_output.outputs.draft_top_logprobs[0][0]) == K + 1 | ||
| assert len(request_output.outputs.draft_top_logprobs[1][0]) == K + 1 | ||
| assert len(request_output.outputs.draft_top_logprobs[0][0]) == actual_topk | ||
| assert len(request_output.outputs.draft_top_logprobs[1][0]) == actual_topk | ||
| assert len(request_output.outputs.draft_top_logprobs[2]) == accept_num[i] | ||
|
|
||
| def test_process_batch_output_aborted_task_negative_token_speculative_decoding(self): | ||
|
|
@@ -282,8 +283,8 @@ def test_process_batch_output_aborted_task_negative_token_speculative_decoding(s | |
| # Set up output tokens with negative token | ||
| # stop_flag | ||
| processor.output_tokens[0, 0].set_tensor(paddle.to_tensor(2)) | ||
| # mtype target = 3 | ||
| processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3)) | ||
| # mtype target = 3, actual_topk packed in high bits | ||
| processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3 | ((K + 1) << 8))) | ||
| # batch = 2 (so batch_id=0 is < batch_size-1=1) | ||
| processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(2)) | ||
| # Set accept_num = PREEMPTED_TOKEN_ID (-9) for first task to trigger abort logic | ||
|
|
||
This comment was marked as outdated.
Sorry, something went wrong.
Uh oh!
There was an error while loading. Please reload this page.