diff --git a/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc index 02203a51cff..0ec49b854ae 100644 --- a/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc +++ b/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc @@ -119,9 +119,11 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids, msg_sed.mtype = 1; msg_sed.meta[0] = not_need_stop.data()[0] ? inference_msg_id_from_env : -inference_msg_id_from_env; - msg_sed.meta[1] = message_flag; - msg_sed.meta[2] = bsz; + // Pack message_flag (low 8 bits) and max_num_logprobs (high 24 bits) into + // meta[1]. Receiver unpacks both to avoid reading unused topk slots. int max_num_logprobs = logprob_token_ids.shape()[1]; + msg_sed.meta[1] = message_flag | (max_num_logprobs << 8); + msg_sed.meta[2] = bsz; for (int i = 0; i < bsz; i++) { int cur_token_num; if (seq_lens_decoder_data[i] < prompt_lens_data[i] || @@ -139,29 +141,24 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids, auto* cur_batch_msg_sed = &msg_sed.mtext[i]; int token_offset = cu_batch_token_offset_data[i]; for (int j = 0; j < cur_token_num; j++) { + // Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write + // max_num_logprobs columns to avoid filling unused topk slots. 
auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)]; auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)]; if (j == 0) { // first token has full logprobs - for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { + for (int k = 0; k < max_num_logprobs; k++) { if (k == 0) { cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j]; cur_scores[k] = - logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + - k]; - } else if (k < max_num_logprobs) { - // only for first token - cur_tokens[k] = - (int)logprob_token_ids_data[(token_offset + j) * - (SPEC_LOGPROB_K + 1) + - k]; - cur_scores[k] = - logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + - k]; + logprob_scores_data[(token_offset + j) * max_num_logprobs + k]; } else { - cur_tokens[k] = -1; - cur_scores[k] = 0.0; + cur_tokens[k] = (int) + logprob_token_ids_data[(token_offset + j) * max_num_logprobs + + k]; + cur_scores[k] = + logprob_scores_data[(token_offset + j) * max_num_logprobs + k]; } } cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j]; @@ -174,7 +171,8 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids, #ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG std::cout << "msg data: " << std::endl; std::cout << "stop_flag: " << msg_sed.meta[0] - << ", message_flag: " << msg_sed.meta[1] + << ", message_flag: " << (msg_sed.meta[1] & 0xFF) + << ", max_num_logprobs: " << (msg_sed.meta[1] >> 8) << ", bsz: " << msg_sed.meta[2] << std::endl; for (int i = 0; i < bsz; i++) { int cur_token_num = msg_sed.meta[3 + i]; diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc index 4fd7d4103c4..3e5ed2430b0 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc @@ -75,8 +75,11 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& 
output_tokens, int bsz = msg_rcv.meta[2]; output_tokens_data[0] = (int64_t)msg_rcv.meta[0]; + // Unpack message_flag (low 8 bits) and actual_topk (high 24 bits) from + // meta[1]. Keep packed value; Python unpacks message_flag and actual_topk. output_tokens_data[1] = (int64_t)msg_rcv.meta[1]; output_tokens_data[2] = (int64_t)msg_rcv.meta[2]; + int actual_topk = msg_rcv.meta[1] >> 8; int output_tokens_offset = 3 + SPEC_LOGPROB_MAX_BSZ; for (int i = 0; i < bsz; i++) { @@ -89,7 +92,7 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1)); auto* cur_batch_msg_rcv = &msg_rcv.mtext[i]; for (int j = 0; j < cur_token_num; j++) { - for (int k = 0; k < real_k + 1; k++) { + for (int k = 0; k < actual_topk; k++) { cur_output_token[j * (SPEC_LOGPROB_K + 1) + k] = (int64_t)cur_batch_msg_rcv->tokens[j * (SPEC_LOGPROB_K + 1) + k]; cur_output_score[j * (SPEC_LOGPROB_K + 1) + k] = @@ -102,7 +105,8 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, #ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG std::cout << "msg data: " << std::endl; std::cout << "stop_flag: " << output_tokens_data[0] - << ", message_flag: " << output_tokens_data[1] + << ", message_flag: " << (output_tokens_data[1] & 0xFF) + << ", max_num_logprobs: " << (output_tokens_data[1] >> 8) << ", bsz: " << output_tokens_data[2] << std::endl; for (int i = 0; i < output_tokens_data[2]; i++) { int cur_token_num = output_tokens_data[3 + i]; diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc index 0b3de384cee..a11897b7ff3 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc @@ -121,9 +121,11 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, msg_sed.mtype = 1; msg_sed.meta[0] = 
not_need_stop.data()[0] ? inference_msg_id_from_env : -inference_msg_id_from_env; - msg_sed.meta[1] = message_flag; - msg_sed.meta[2] = bsz; + // Pack message_flag (low 8 bits) and max_num_logprobs (high 24 bits) into + // meta[1]. Receiver unpacks both to avoid reading unused topk slots. int max_num_logprobs = logprob_token_ids.shape()[1]; + msg_sed.meta[1] = message_flag | (max_num_logprobs << 8); + msg_sed.meta[2] = bsz; for (int i = 0; i < bsz; i++) { int cur_token_num; if (seq_lens_decoder_data[i] < prompt_lens_data[i]) { @@ -139,24 +141,20 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, auto* cur_batch_msg_sed = &msg_sed.mtext[i]; int token_offset = cu_batch_token_offset_data[i]; for (int j = 0; j < cur_token_num; j++) { + // Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write + // max_num_logprobs columns to avoid filling unused topk slots. auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)]; auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)]; - for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { + for (int k = 0; k < max_num_logprobs; k++) { if (k == 0) { cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j]; cur_scores[k] = - logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + - k]; - } else if (k < max_num_logprobs) { + logprob_scores_data[(token_offset + j) * max_num_logprobs + k]; + } else { cur_tokens[k] = (int) - logprob_token_ids_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + - k]; + logprob_token_ids_data[(token_offset + j) * max_num_logprobs + k]; cur_scores[k] = - logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + - k]; - } else { - cur_tokens[k] = -1; - cur_scores[k] = 0.0; + logprob_scores_data[(token_offset + j) * max_num_logprobs + k]; } } cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j]; @@ -165,7 +163,8 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, #ifdef 
SPECULATE_SAVE_WITH_OUTPUT_DEBUG std::cout << "msg data: " << std::endl; std::cout << "stop_flag: " << msg_sed.meta[0] - << ", message_flag: " << msg_sed.meta[1] + << ", message_flag: " << (msg_sed.meta[1] & 0xFF) + << ", max_num_logprobs: " << (msg_sed.meta[1] >> 8) << ", bsz: " << msg_sed.meta[2] << std::endl; for (int i = 0; i < bsz; i++) { int cur_token_num = msg_sed.meta[3 + i]; diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index f0cd22e1309..bdd2bf305d4 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -739,12 +739,15 @@ def _process_batch_draft_tokens(self, mtype, batch, accept_num, tokens, scores, metrics=None, ) - token_ids = tokens[i][:, 0].tolist()[: accept_num[i]] + tokens_i = tokens[i].tolist() + scores_i = scores[i].tolist() + ranks_i = ranks[i].tolist() + token_ids = [row[0] for row in tokens_i[: accept_num[i]]] for batch_token_index in range(len(token_ids)): - result.outputs.logprob = float(scores[i, batch_token_index, 0]) - topk_token_ids = tokens[i, batch_token_index, :].tolist() - topk_logprobs = scores[i, batch_token_index, :].tolist() - sampled_rank = ranks[i, batch_token_index].item() + result.outputs.logprob = scores_i[batch_token_index][0] + topk_token_ids = tokens_i[batch_token_index] + topk_logprobs = scores_i[batch_token_index] + sampled_rank = ranks_i[batch_token_index] if result.outputs.draft_top_logprobs is None: result.outputs.draft_top_logprobs = LogprobsLists( @@ -771,16 +774,19 @@ def _process_batch_output(self): mtype = 3 if self.cfg.speculative_config.method: if self.use_logprobs: - mtype = int(self.output_tokens[1, 0].item()) + # meta[1] packs message_flag (low 8 bits) and actual_topk (high 24 bits). 
+ packed_meta1 = int(self.output_tokens[1, 0].item()) + mtype = packed_meta1 & 0xFF + actual_topk = packed_meta1 >> 8 batch = self.output_tokens[2, 0] accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]] tokens = tokens[3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1)].reshape( [batch, MAX_DRAFT_TOKENS, K + 1] - ) + )[:, :, :actual_topk] scores = ( self.output_scores[: batch * MAX_DRAFT_TOKENS * (K + 1)] .numpy() - .reshape([batch, MAX_DRAFT_TOKENS, K + 1]) + .reshape([batch, MAX_DRAFT_TOKENS, K + 1])[:, :, :actual_topk] ) ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS]) @@ -789,6 +795,10 @@ def _process_batch_output(self): batch_result = self._process_batch_draft_tokens(mtype, batch, accept_num, tokens, scores, ranks) self.postprocess(batch_result, mtype) return + # Pre-convert full arrays to Python lists once for MTP target token path. + tokens_lists = tokens.tolist() + scores_lists = scores.tolist() + ranks_list = ranks.tolist() else: batch = self.output_tokens[1] accept_num = tokens[2 : batch + 2] @@ -856,7 +866,7 @@ def _process_batch_output(self): ) token_ids = [RECOVERY_STOP_SIGNAL] elif self.use_logprobs: - token_ids = tokens[i][:, 0].tolist()[: accept_num[i]] + token_ids = [row[0] for row in tokens_lists[i][: accept_num[i]]] else: token_ids = tokens[ 2 @@ -988,10 +998,10 @@ def _process_batch_output(self): task.output_token_ids.append(token_id) if self.use_logprobs: if self.cfg.speculative_config.method: - result.outputs.logprob = float(scores[i, batch_token_index, 0]) - topk_token_ids = tokens[i, batch_token_index, :].tolist() - topk_logprobs = scores[i, batch_token_index, :].tolist() - sampled_rank = ranks[i, batch_token_index].item() + result.outputs.logprob = scores_lists[i][batch_token_index][0] + topk_token_ids = tokens_lists[i][batch_token_index] + topk_logprobs = scores_lists[i][batch_token_index] + sampled_rank = ranks_list[i][batch_token_index] else: # Use 
pre-converted lists (batch .tolist() done before the loop). result.outputs.logprob = scores_lists[i][0] diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 8ee28db1003..233f1461abf 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1279,15 +1279,11 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p req.sampling_params.top_p_normalized_logprobs and req.sampling_params.top_p != 1.0 for req in logprobs_reqs ) if logprobs_reqs: - self.max_logprobs = ( - max( - [ - self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs - for req in logprobs_reqs - ] - ) - if not self.speculative_decoding - else 20 + self.max_logprobs = max( + [ + self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs + for req in logprobs_reqs + ] ) elif self.enable_logprob: self.max_logprobs = None if not self.speculative_decoding else 0 diff --git a/tests/e2e/test_ernie_21b_mtp_multistep.py b/tests/e2e/test_ernie_21b_mtp_multistep.py index 8c4e3b6bab4..9f84b495f8b 100644 --- a/tests/e2e/test_ernie_21b_mtp_multistep.py +++ b/tests/e2e/test_ernie_21b_mtp_multistep.py @@ -212,11 +212,11 @@ def test_prefix_cache_text(api_url): if os.getenv("BASELINE") == "1": baseline_manager.save("base_21b_step3", result) baseline_manager.save("base_21b_mtp_metrics_step3", speculate_metrics_2) - baseline_manager.save("base_21b_logprobs_step3", logprobs_2) + baseline_manager.save("base_21b_logprobs_step3_new", logprobs_2) baseline_result = baseline_manager.load("base_21b_step3") baseline_mtp_metrics = baseline_manager.load("base_21b_mtp_metrics_step3") - baseline_logprobs = baseline_manager.load("base_21b_logprobs_step3") + baseline_logprobs = baseline_manager.load("base_21b_logprobs_step3_new") assert logprobs == logprobs_2, ( "logprobs 前后不一致\n" diff --git a/tests/output/test_process_batch_output.py 
b/tests/output/test_process_batch_output.py index 5cc7d4c88f9..fb82a35bed2 100644 --- a/tests/output/test_process_batch_output.py +++ b/tests/output/test_process_batch_output.py @@ -210,8 +210,9 @@ def test_speculative_decoding_use_logprobs(self): # stop_flag processor.output_tokens[0, 0].set_tensor(paddle.to_tensor(2)) - # mtype target = 3, decode = 4 - processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3)) + # meta[1] packs mtype (low 8 bits) and actual_topk (high 24 bits) + actual_topk = K + 1 + processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3 | (actual_topk << 8))) # batch processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(2)) # accept_num @@ -243,12 +244,12 @@ def test_speculative_decoding_use_logprobs(self): assert len(request_output.outputs.token_ids) == accept_num[i] assert len(request_output.outputs.top_logprobs) == 3 # tokens, scores, ranks - assert len(request_output.outputs.top_logprobs[0][0]) == K + 1 - assert len(request_output.outputs.top_logprobs[1][0]) == K + 1 + assert len(request_output.outputs.top_logprobs[0][0]) == actual_topk + assert len(request_output.outputs.top_logprobs[1][0]) == actual_topk assert len(request_output.outputs.top_logprobs[2]) == accept_num[i] # mtype = 4 - processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(4)) + processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(4 | (actual_topk << 8))) processor._process_batch_output() cached_generated_tokens: MockCachedGeneratedTokens = processor.cached_generated_tokens for c in cached_generated_tokens.cache: @@ -257,8 +258,8 @@ def test_speculative_decoding_use_logprobs(self): assert len(request_output.outputs.top_logprobs) == 3 assert len(request_output.outputs.draft_top_logprobs) == 3 # tokens, scores, ranks - assert len(request_output.outputs.draft_top_logprobs[0][0]) == K + 1 - assert len(request_output.outputs.draft_top_logprobs[1][0]) == K + 1 + assert len(request_output.outputs.draft_top_logprobs[0][0]) == actual_topk + assert
len(request_output.outputs.draft_top_logprobs[1][0]) == actual_topk assert len(request_output.outputs.draft_top_logprobs[2]) == accept_num[i] def test_process_batch_output_aborted_task_negative_token_speculative_decoding(self): @@ -281,8 +282,8 @@ def test_process_batch_output_aborted_task_negative_token_speculative_decoding(s # Set up output tokens with negative token # stop_flag processor.output_tokens[0, 0].set_tensor(paddle.to_tensor(2)) - # mtype target = 3 - processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3)) + # mtype target = 3, actual_topk packed in high bits + processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3 | ((K + 1) << 8))) # batch = 2 (so batch_id=0 is < batch_size-1=1) processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(2)) # Set accept_num = PREEMPTED_TOKEN_ID (-9) for first task to trigger abort logic diff --git a/tests/output/test_token_processor.py b/tests/output/test_token_processor.py index 4ca70b9a689..0b60caa78be 100644 --- a/tests/output/test_token_processor.py +++ b/tests/output/test_token_processor.py @@ -748,7 +748,7 @@ def test_process_batch_output_speculative_logprob_handles_draft_batch(): ) processor._batch_result_buffer = [target] processor.cached_generated_tokens = mock.Mock() - processor.output_tokens[1, 0] = 4 + processor.output_tokens[1, 0] = 4 | ((K + 1) << 8) processor.output_tokens[2, 0] = 1 processor.output_tokens[3, 0] = 1 @@ -925,7 +925,7 @@ def test_process_batch_output_speculative_logprob_targets_topk_scores(): task.trace_carrier = None rm.tasks_list[0] = task rm.req_dict[task.request_id] = task - processor.output_tokens[1, 0] = 3 + processor.output_tokens[1, 0] = 3 | ((K + 1) << 8) processor.output_tokens[2, 0] = 1 processor.output_tokens[3, 0] = 2 token_block = np.arange(MAX_DRAFT_TOKENS * (K + 1), dtype=np.int64).reshape([-1, 1]) + 3