Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,11 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids,
msg_sed.mtype = 1;
msg_sed.meta[0] = not_need_stop.data<bool>()[0] ? inference_msg_id_from_env
: -inference_msg_id_from_env;
msg_sed.meta[1] = message_flag;
msg_sed.meta[2] = bsz;
// Pack message_flag (low 8 bits) and max_num_logprobs (high 24 bits) into
// meta[1]. Receiver unpacks both to avoid reading unused topk slots.
int max_num_logprobs = logprob_token_ids.shape()[1];
msg_sed.meta[1] = message_flag | (max_num_logprobs << 8);

This comment was marked as outdated.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 message_flag 打包前未做高位截断

message_flag | (max_num_logprobs << 8) 假设 message_flag < 256,但没有显式截断。当前 message_flag(即 Python 侧解包得到的 mtype)取值为 3/4,实际安全;但若未来扩展 flag 值 ≥ 256,其高位会污染 max_num_logprobs 所在的高 24 位,导致接收端解包错误。

建议加防御性截断:

msg_sed.meta[1] = (message_flag & 0xFF) | (max_num_logprobs << 8);

speculate_save_output_with_topk.cc 同样位置同理。

msg_sed.meta[2] = bsz;
for (int i = 0; i < bsz; i++) {
int cur_token_num;
if (seq_lens_decoder_data[i] < prompt_lens_data[i] ||
Expand All @@ -139,29 +141,24 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids,
auto* cur_batch_msg_sed = &msg_sed.mtext[i];
int token_offset = cu_batch_token_offset_data[i];
for (int j = 0; j < cur_token_num; j++) {
// Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write
// max_num_logprobs columns to avoid filling unused topk slots.
auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)];
auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)];
if (j == 0) {
// first token has full logprobs
for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
for (int k = 0; k < max_num_logprobs; k++) {
if (k == 0) {
cur_tokens[k] =
(int)sampled_token_ids_data[i * max_draft_tokens + j];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
k];
} else if (k < max_num_logprobs) {
// only for first token
cur_tokens[k] =
(int)logprob_token_ids_data[(token_offset + j) *
(SPEC_LOGPROB_K + 1) +
k];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
k];
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
} else {
cur_tokens[k] = -1;
cur_scores[k] = 0.0;
cur_tokens[k] = (int)
logprob_token_ids_data[(token_offset + j) * max_num_logprobs +
k];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
}
}
cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j];
Expand All @@ -174,7 +171,8 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids,
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
std::cout << "msg data: " << std::endl;
std::cout << "stop_flag: " << msg_sed.meta[0]
<< ", message_flag: " << msg_sed.meta[1]
<< ", message_flag: " << (msg_sed.meta[1] & 0xFF)
<< ", max_num_logprobs: " << (msg_sed.meta[1] >> 8)
<< ", bsz: " << msg_sed.meta[2] << std::endl;
for (int i = 0; i < bsz; i++) {
int cur_token_num = msg_sed.meta[3 + i];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,11 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,

int bsz = msg_rcv.meta[2];
output_tokens_data[0] = (int64_t)msg_rcv.meta[0];
// Unpack message_flag (low 8 bits) and actual_topk (high 24 bits) from
// meta[1]. Keep packed value; Python unpacks message_flag and actual_topk.
output_tokens_data[1] = (int64_t)msg_rcv.meta[1];
output_tokens_data[2] = (int64_t)msg_rcv.meta[2];

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问 actual_topk = 0 边界场景是否有保护?

top_logprobs=0 的请求进入 logprobs_reqs 时,max_logprobs = max([0]) = 0,从而 max_num_logprobs = 0,发送端内层循环不执行,sampled token 不写入消息结构体。

接收端 actual_topk = 0,copy 循环同样不执行,传到 Python 侧后:

tokens[:, :, :0]  # shape=[batch, MAX_DRAFT_TOKENS, 0]
token_ids = [row[0] for row in tokens_lists[i][:accept_num[i]]]  # IndexError!

请确认:

  1. top_logprobs=0 的请求是否会进入 logprobs_reqs(若不会则无问题)
  2. 若会,需要在 C++ 侧保证 max_num_logprobs >= 1(至少写入 sampled token),或在 Python 侧对 actual_topk == 0 分支特殊处理。

int actual_topk = msg_rcv.meta[1] >> 8;

int output_tokens_offset = 3 + SPEC_LOGPROB_MAX_BSZ;
for (int i = 0; i < bsz; i++) {
Expand All @@ -89,7 +92,7 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1));
auto* cur_batch_msg_rcv = &msg_rcv.mtext[i];
for (int j = 0; j < cur_token_num; j++) {
for (int k = 0; k < real_k + 1; k++) {
for (int k = 0; k < actual_topk; k++) {
cur_output_token[j * (SPEC_LOGPROB_K + 1) + k] =
(int64_t)cur_batch_msg_rcv->tokens[j * (SPEC_LOGPROB_K + 1) + k];
cur_output_score[j * (SPEC_LOGPROB_K + 1) + k] =
Expand All @@ -102,7 +105,8 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
#ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG
std::cout << "msg data: " << std::endl;
std::cout << "stop_flag: " << output_tokens_data[0]
<< ", message_flag: " << output_tokens_data[1]
<< ", message_flag: " << (output_tokens_data[1] & 0xFF)
<< ", max_num_logprobs: " << (output_tokens_data[1] >> 8)
<< ", bsz: " << output_tokens_data[2] << std::endl;
for (int i = 0; i < output_tokens_data[2]; i++) {
int cur_token_num = output_tokens_data[3 + i];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,11 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
msg_sed.mtype = 1;
msg_sed.meta[0] = not_need_stop.data<bool>()[0] ? inference_msg_id_from_env
: -inference_msg_id_from_env;
msg_sed.meta[1] = message_flag;
msg_sed.meta[2] = bsz;
// Pack message_flag (low 8 bits) and max_num_logprobs (high 24 bits) into
// meta[1]. Receiver unpacks both to avoid reading unused topk slots.
int max_num_logprobs = logprob_token_ids.shape()[1];
msg_sed.meta[1] = message_flag | (max_num_logprobs << 8);
msg_sed.meta[2] = bsz;
for (int i = 0; i < bsz; i++) {
int cur_token_num;
if (seq_lens_decoder_data[i] < prompt_lens_data[i]) {
Expand All @@ -139,24 +141,20 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
auto* cur_batch_msg_sed = &msg_sed.mtext[i];
int token_offset = cu_batch_token_offset_data[i];
for (int j = 0; j < cur_token_num; j++) {
// Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write
// max_num_logprobs columns to avoid filling unused topk slots.
auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)];
auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)];
for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
for (int k = 0; k < max_num_logprobs; k++) {

This comment was marked as outdated.

if (k == 0) {
cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
k];
} else if (k < max_num_logprobs) {
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
} else {
cur_tokens[k] = (int)
logprob_token_ids_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
k];
logprob_token_ids_data[(token_offset + j) * max_num_logprobs + k];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
k];
} else {
cur_tokens[k] = -1;
cur_scores[k] = 0.0;
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
}
}
cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j];
Expand All @@ -165,7 +163,8 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
std::cout << "msg data: " << std::endl;
std::cout << "stop_flag: " << msg_sed.meta[0]
<< ", message_flag: " << msg_sed.meta[1]
<< ", message_flag: " << (msg_sed.meta[1] & 0xFF)
<< ", max_num_logprobs: " << (msg_sed.meta[1] >> 8)
<< ", bsz: " << msg_sed.meta[2] << std::endl;
for (int i = 0; i < bsz; i++) {
int cur_token_num = msg_sed.meta[3 + i];
Expand Down
36 changes: 23 additions & 13 deletions fastdeploy/output/token_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,12 +796,15 @@ def _process_batch_draft_tokens(self, mtype, batch, accept_num, tokens, scores,
metrics=None,
)

token_ids = tokens[i][:, 0].tolist()[: accept_num[i]]
tokens_i = tokens[i].tolist()
scores_i = scores[i].tolist()
ranks_i = ranks[i].tolist()
token_ids = [row[0] for row in tokens_i[: accept_num[i]]]
for batch_token_index in range(len(token_ids)):
result.outputs.logprob = float(scores[i, batch_token_index, 0])
topk_token_ids = tokens[i, batch_token_index, :].tolist()
topk_logprobs = scores[i, batch_token_index, :].tolist()
sampled_rank = ranks[i, batch_token_index].item()
result.outputs.logprob = scores_i[batch_token_index][0]
topk_token_ids = tokens_i[batch_token_index]
topk_logprobs = scores_i[batch_token_index]
sampled_rank = ranks_i[batch_token_index]

if result.outputs.draft_top_logprobs is None:
result.outputs.draft_top_logprobs = LogprobsLists(
Expand All @@ -828,16 +831,19 @@ def _process_batch_output(self):
mtype = 3
if self.cfg.speculative_config.method:
if self.use_logprobs:
mtype = int(self.output_tokens[1, 0].item())
# meta[1] packs message_flag (low 8 bits) and actual_topk (high 24 bits).
packed_meta1 = int(self.output_tokens[1, 0].item())
mtype = packed_meta1 & 0xFF
actual_topk = packed_meta1 >> 8
batch = self.output_tokens[2, 0]
accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]]
tokens = tokens[3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1)].reshape(
[batch, MAX_DRAFT_TOKENS, K + 1]
)
)[:, :, :actual_topk]
scores = (
self.output_scores[: batch * MAX_DRAFT_TOKENS * (K + 1)]
.numpy()
.reshape([batch, MAX_DRAFT_TOKENS, K + 1])
.reshape([batch, MAX_DRAFT_TOKENS, K + 1])[:, :, :actual_topk]
)
ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS])

Expand All @@ -846,6 +852,10 @@ def _process_batch_output(self):
batch_result = self._process_batch_draft_tokens(mtype, batch, accept_num, tokens, scores, ranks)
self.postprocess(batch_result, mtype)
return
# Pre-convert full arrays to Python lists once for MTP target token path.
tokens_lists = tokens.tolist()
scores_lists = scores.tolist()
ranks_list = ranks.tolist()
else:
batch = self.output_tokens[1]
accept_num = tokens[2 : batch + 2]
Expand Down Expand Up @@ -914,7 +924,7 @@ def _process_batch_output(self):
llm_logger.info(f"recovery stop signal found at task {task_id}")
token_ids = [RECOVERY_STOP_SIGNAL]
elif self.use_logprobs:
token_ids = tokens[i][:, 0].tolist()[: accept_num[i]]
token_ids = [row[0] for row in tokens_lists[i][: accept_num[i]]]
else:
token_ids = tokens[
2
Expand Down Expand Up @@ -1033,10 +1043,10 @@ def _process_batch_output(self):
task.output_token_ids.append(token_id)
if self.use_logprobs:
if self.cfg.speculative_config.method:
result.outputs.logprob = float(scores[i, batch_token_index, 0])
topk_token_ids = tokens[i, batch_token_index, :].tolist()
topk_logprobs = scores[i, batch_token_index, :].tolist()
sampled_rank = ranks[i, batch_token_index].item()
result.outputs.logprob = scores_lists[i][batch_token_index][0]
topk_token_ids = tokens_lists[i][batch_token_index]
topk_logprobs = scores_lists[i][batch_token_index]
sampled_rank = ranks_list[i][batch_token_index]
else:
# Use pre-converted lists (batch .tolist() done before the loop).
result.outputs.logprob = scores_lists[i][0]
Expand Down
14 changes: 5 additions & 9 deletions fastdeploy/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,15 +1226,11 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p
req.sampling_params.top_p_normalized_logprobs and req.sampling_params.top_p != 1.0 for req in logprobs_reqs
)
if logprobs_reqs:
self.max_logprobs = (
max(
[
self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs
for req in logprobs_reqs
]
)
if not self.speculative_decoding
else 20
self.max_logprobs = max(
[
self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs
for req in logprobs_reqs
]
)

This comment was marked as outdated.

This comment was marked as outdated.

elif self.enable_logprob:
self.max_logprobs = None if not self.speculative_decoding else 0
Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/test_ernie_21b_mtp_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,11 @@ def test_prefix_cache_text(api_url):
if os.getenv("BASELINE") == "1":
baseline_manager.save("base_21b_step3", result)
baseline_manager.save("base_21b_mtp_metrics_step3", speculate_metrics_2)
baseline_manager.save("base_21b_logprobs_step3", logprobs_2)
baseline_manager.save("base_21b_logprobs_step3_new", logprobs_2)

baseline_result = baseline_manager.load("base_21b_step3")
baseline_mtp_metrics = baseline_manager.load("base_21b_mtp_metrics_step3")
baseline_logprobs = baseline_manager.load("base_21b_logprobs_step3")
baseline_logprobs = baseline_manager.load("base_21b_logprobs_step3_new")

assert logprobs == logprobs_2, (
"logprobs 前后不一致\n"
Expand Down
19 changes: 10 additions & 9 deletions tests/output/test_process_batch_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,9 @@ def test_speculative_decoding_use_logprobs(self):

# stop_flag
processor.output_tokens[0, 0].set_tensor(paddle.to_tensor(2))

This comment was marked as outdated.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问 bit 宽注释不一致

此处注释写 actual_topk (high 16 bits),而 C++ 侧(mtp_save_first_token_with_topk.cc、speculate_save_output_with_topk.cc、speculate_get_output_with_topk.cc)的注释均写 high 24 bits。

实际实现是 >> 8 在 int32 上取高 24 位,建议统一注释为 high 24 bits,避免误导维护者对 max_num_logprobs 的范围误判。

# mtype target = 3, decode = 4
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3))
# meta[1] packs mtype (low 8 bits) and actual_topk (high 16 bits)
actual_topk = K + 1
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3 | (actual_topk << 8)))
# batch
processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(2))
# accept_num
Expand Down Expand Up @@ -244,12 +245,12 @@ def test_speculative_decoding_use_logprobs(self):
assert len(request_output.outputs.token_ids) == accept_num[i]
assert len(request_output.outputs.top_logprobs) == 3
# tokens, scores, ranks
assert len(request_output.outputs.top_logprobs[0][0]) == K + 1
assert len(request_output.outputs.top_logprobs[1][0]) == K + 1
assert len(request_output.outputs.top_logprobs[0][0]) == actual_topk
assert len(request_output.outputs.top_logprobs[1][0]) == actual_topk
assert len(request_output.outputs.top_logprobs[2]) == accept_num[i]

# mtype = 4
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(4))
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(4 | (actual_topk << 8)))
processor._process_batch_output()
cached_generated_tokens: MockCachedGeneratedTokens = processor.cached_generated_tokens
for c in cached_generated_tokens.cache:
Expand All @@ -258,8 +259,8 @@ def test_speculative_decoding_use_logprobs(self):
assert len(request_output.outputs.top_logprobs) == 3
assert len(request_output.outputs.draft_top_logprobs) == 3
# tokens, scores, ranks
assert len(request_output.outputs.draft_top_logprobs[0][0]) == K + 1
assert len(request_output.outputs.draft_top_logprobs[1][0]) == K + 1
assert len(request_output.outputs.draft_top_logprobs[0][0]) == actual_topk
assert len(request_output.outputs.draft_top_logprobs[1][0]) == actual_topk
assert len(request_output.outputs.draft_top_logprobs[2]) == accept_num[i]

def test_process_batch_output_aborted_task_negative_token_speculative_decoding(self):
Expand All @@ -282,8 +283,8 @@ def test_process_batch_output_aborted_task_negative_token_speculative_decoding(s
# Set up output tokens with negative token
# stop_flag
processor.output_tokens[0, 0].set_tensor(paddle.to_tensor(2))
# mtype target = 3
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3))
# mtype target = 3, actual_topk packed in high bits
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3 | ((K + 1) << 8)))
# batch = 2 (so batch_id=0 is < batch_size-1=1)
processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(2))
# Set accept_num = PREEMPTED_TOKEN_ID (-9) for first task to trigger abort logic
Expand Down
4 changes: 2 additions & 2 deletions tests/output/test_token_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def test_process_batch_output_speculative_logprob_handles_draft_batch():
)
processor._batch_result_buffer = [target]
processor.cached_generated_tokens = mock.Mock()
processor.output_tokens[1, 0] = 4
processor.output_tokens[1, 0] = 4 | ((K + 1) << 8)
processor.output_tokens[2, 0] = 1
processor.output_tokens[3, 0] = 1

Expand Down Expand Up @@ -926,7 +926,7 @@ def test_process_batch_output_speculative_logprob_targets_topk_scores():
task.trace_carrier = None
rm.tasks_list[0] = task
rm.req_dict[task.request_id] = task
processor.output_tokens[1, 0] = 3
processor.output_tokens[1, 0] = 3 | ((K + 1) << 8)
processor.output_tokens[2, 0] = 1
processor.output_tokens[3, 0] = 2
token_block = np.arange(MAX_DRAFT_TOKENS * (K + 1), dtype=np.int64).reshape([-1, 1]) + 3
Expand Down
Loading