Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,11 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids,
msg_sed.mtype = 1;
msg_sed.meta[0] = not_need_stop.data<bool>()[0] ? inference_msg_id_from_env
: -inference_msg_id_from_env;
msg_sed.meta[1] = message_flag;
msg_sed.meta[2] = bsz;
// Pack message_flag (low 8 bits) and max_num_logprobs (high 24 bits) into
// meta[1]. Receiver unpacks both to avoid reading unused topk slots.
int max_num_logprobs = logprob_token_ids.shape()[1];
msg_sed.meta[1] = message_flag | (max_num_logprobs << 8);
msg_sed.meta[2] = bsz;
for (int i = 0; i < bsz; i++) {
int cur_token_num;
if (seq_lens_decoder_data[i] < prompt_lens_data[i] ||
Expand All @@ -139,29 +141,24 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids,
auto* cur_batch_msg_sed = &msg_sed.mtext[i];
int token_offset = cu_batch_token_offset_data[i];
for (int j = 0; j < cur_token_num; j++) {
// Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write
// max_num_logprobs columns to avoid filling unused topk slots.
auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)];
auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)];
if (j == 0) {
// first token has full logprobs
for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
for (int k = 0; k < max_num_logprobs; k++) {
if (k == 0) {
cur_tokens[k] =
(int)sampled_token_ids_data[i * max_draft_tokens + j];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
k];
} else if (k < max_num_logprobs) {
// only for first token
cur_tokens[k] =
(int)logprob_token_ids_data[(token_offset + j) *
(SPEC_LOGPROB_K + 1) +
k];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
k];
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
} else {
cur_tokens[k] = -1;
cur_scores[k] = 0.0;
cur_tokens[k] = (int)
logprob_token_ids_data[(token_offset + j) * max_num_logprobs +
k];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
}
}
cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j];
Expand All @@ -174,7 +171,8 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids,
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
std::cout << "msg data: " << std::endl;
std::cout << "stop_flag: " << msg_sed.meta[0]
<< ", message_flag: " << msg_sed.meta[1]
<< ", message_flag: " << (msg_sed.meta[1] & 0xFF)
<< ", max_num_logprobs: " << (msg_sed.meta[1] >> 8)
<< ", bsz: " << msg_sed.meta[2] << std::endl;
for (int i = 0; i < bsz; i++) {
int cur_token_num = msg_sed.meta[3 + i];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,11 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,

int bsz = msg_rcv.meta[2];
output_tokens_data[0] = (int64_t)msg_rcv.meta[0];
// Unpack message_flag (low 8 bits) and actual_topk (high 24 bits) from
// meta[1]. Keep packed value; Python unpacks message_flag and actual_topk.
output_tokens_data[1] = (int64_t)msg_rcv.meta[1];
output_tokens_data[2] = (int64_t)msg_rcv.meta[2];

This comment was marked as outdated.

int actual_topk = msg_rcv.meta[1] >> 8;

int output_tokens_offset = 3 + SPEC_LOGPROB_MAX_BSZ;
for (int i = 0; i < bsz; i++) {
Expand All @@ -89,7 +92,7 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1));
auto* cur_batch_msg_rcv = &msg_rcv.mtext[i];
for (int j = 0; j < cur_token_num; j++) {
for (int k = 0; k < real_k + 1; k++) {
for (int k = 0; k < actual_topk; k++) {
cur_output_token[j * (SPEC_LOGPROB_K + 1) + k] =
(int64_t)cur_batch_msg_rcv->tokens[j * (SPEC_LOGPROB_K + 1) + k];
cur_output_score[j * (SPEC_LOGPROB_K + 1) + k] =
Expand All @@ -102,7 +105,8 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
#ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG
std::cout << "msg data: " << std::endl;
std::cout << "stop_flag: " << output_tokens_data[0]
<< ", message_flag: " << output_tokens_data[1]
<< ", message_flag: " << (output_tokens_data[1] & 0xFF)
<< ", max_num_logprobs: " << (output_tokens_data[1] >> 8)
<< ", bsz: " << output_tokens_data[2] << std::endl;
for (int i = 0; i < output_tokens_data[2]; i++) {
int cur_token_num = output_tokens_data[3 + i];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,11 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
msg_sed.mtype = 1;
msg_sed.meta[0] = not_need_stop.data<bool>()[0] ? inference_msg_id_from_env
: -inference_msg_id_from_env;
msg_sed.meta[1] = message_flag;
msg_sed.meta[2] = bsz;
// Pack message_flag (low 8 bits) and max_num_logprobs (high 24 bits) into
// meta[1]. Receiver unpacks both to avoid reading unused topk slots.
int max_num_logprobs = logprob_token_ids.shape()[1];
msg_sed.meta[1] = message_flag | (max_num_logprobs << 8);

This comment was marked as outdated.

msg_sed.meta[2] = bsz;
for (int i = 0; i < bsz; i++) {
int cur_token_num;
if (seq_lens_decoder_data[i] < prompt_lens_data[i]) {
Expand All @@ -139,24 +141,20 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
auto* cur_batch_msg_sed = &msg_sed.mtext[i];
int token_offset = cu_batch_token_offset_data[i];
for (int j = 0; j < cur_token_num; j++) {
// Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write
// max_num_logprobs columns to avoid filling unused topk slots.
auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)];
auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)];
for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
for (int k = 0; k < max_num_logprobs; k++) {
if (k == 0) {
cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
k];
} else if (k < max_num_logprobs) {
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
} else {
cur_tokens[k] = (int)
logprob_token_ids_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
k];
logprob_token_ids_data[(token_offset + j) * max_num_logprobs + k];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
k];
} else {
cur_tokens[k] = -1;
cur_scores[k] = 0.0;
logprob_scores_data[(token_offset + j) * max_num_logprobs + k];
}
}
cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j];
Expand All @@ -165,7 +163,8 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
std::cout << "msg data: " << std::endl;
std::cout << "stop_flag: " << msg_sed.meta[0]
<< ", message_flag: " << msg_sed.meta[1]
<< ", message_flag: " << (msg_sed.meta[1] & 0xFF)
<< ", max_num_logprobs: " << (msg_sed.meta[1] >> 8)
<< ", bsz: " << msg_sed.meta[2] << std::endl;
for (int i = 0; i < bsz; i++) {
int cur_token_num = msg_sed.meta[3 + i];
Expand Down
36 changes: 23 additions & 13 deletions fastdeploy/output/token_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,12 +739,15 @@ def _process_batch_draft_tokens(self, mtype, batch, accept_num, tokens, scores,
metrics=None,
)

token_ids = tokens[i][:, 0].tolist()[: accept_num[i]]
tokens_i = tokens[i].tolist()
scores_i = scores[i].tolist()
ranks_i = ranks[i].tolist()
token_ids = [row[0] for row in tokens_i[: accept_num[i]]]
for batch_token_index in range(len(token_ids)):
result.outputs.logprob = float(scores[i, batch_token_index, 0])
topk_token_ids = tokens[i, batch_token_index, :].tolist()
topk_logprobs = scores[i, batch_token_index, :].tolist()
sampled_rank = ranks[i, batch_token_index].item()
result.outputs.logprob = scores_i[batch_token_index][0]
topk_token_ids = tokens_i[batch_token_index]
topk_logprobs = scores_i[batch_token_index]
sampled_rank = ranks_i[batch_token_index]

if result.outputs.draft_top_logprobs is None:
result.outputs.draft_top_logprobs = LogprobsLists(
Expand All @@ -771,16 +774,19 @@ def _process_batch_output(self):
mtype = 3
if self.cfg.speculative_config.method:
if self.use_logprobs:
mtype = int(self.output_tokens[1, 0].item())
# meta[1] packs message_flag (low 8 bits) and actual_topk (high 24 bits).
packed_meta1 = int(self.output_tokens[1, 0].item())
mtype = packed_meta1 & 0xFF
actual_topk = packed_meta1 >> 8
batch = self.output_tokens[2, 0]
accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]]
tokens = tokens[3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1)].reshape(
[batch, MAX_DRAFT_TOKENS, K + 1]
)
)[:, :, :actual_topk]

This comment was marked as outdated.

scores = (
self.output_scores[: batch * MAX_DRAFT_TOKENS * (K + 1)]
.numpy()
.reshape([batch, MAX_DRAFT_TOKENS, K + 1])
.reshape([batch, MAX_DRAFT_TOKENS, K + 1])[:, :, :actual_topk]
)
ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS])

Expand All @@ -789,6 +795,10 @@ def _process_batch_output(self):
batch_result = self._process_batch_draft_tokens(mtype, batch, accept_num, tokens, scores, ranks)
self.postprocess(batch_result, mtype)
return
# Pre-convert full arrays to Python lists once for MTP target token path.
tokens_lists = tokens.tolist()
scores_lists = scores.tolist()
ranks_list = ranks.tolist()
else:
batch = self.output_tokens[1]
accept_num = tokens[2 : batch + 2]
Expand Down Expand Up @@ -856,7 +866,7 @@ def _process_batch_output(self):
)

This comment was marked as outdated.

This comment was marked as outdated.

token_ids = [RECOVERY_STOP_SIGNAL]
elif self.use_logprobs:
token_ids = tokens[i][:, 0].tolist()[: accept_num[i]]
token_ids = [row[0] for row in tokens_lists[i][: accept_num[i]]]
else:
token_ids = tokens[
2
Expand Down Expand Up @@ -988,10 +998,10 @@ def _process_batch_output(self):
task.output_token_ids.append(token_id)
if self.use_logprobs:
if self.cfg.speculative_config.method:
result.outputs.logprob = float(scores[i, batch_token_index, 0])
topk_token_ids = tokens[i, batch_token_index, :].tolist()
topk_logprobs = scores[i, batch_token_index, :].tolist()
sampled_rank = ranks[i, batch_token_index].item()
result.outputs.logprob = scores_lists[i][batch_token_index][0]
topk_token_ids = tokens_lists[i][batch_token_index]
topk_logprobs = scores_lists[i][batch_token_index]
sampled_rank = ranks_list[i][batch_token_index]
else:
# Use pre-converted lists (batch .tolist() done before the loop).
result.outputs.logprob = scores_lists[i][0]
Expand Down
14 changes: 5 additions & 9 deletions fastdeploy/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1279,15 +1279,11 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p
req.sampling_params.top_p_normalized_logprobs and req.sampling_params.top_p != 1.0 for req in logprobs_reqs
)
if logprobs_reqs:
self.max_logprobs = (
max(
[
self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs
for req in logprobs_reqs
]
)
if not self.speculative_decoding
else 20
self.max_logprobs = max(
[
self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs
for req in logprobs_reqs
]
)
elif self.enable_logprob:
self.max_logprobs = None if not self.speculative_decoding else 0
Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/test_ernie_21b_mtp_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,11 @@ def test_prefix_cache_text(api_url):
if os.getenv("BASELINE") == "1":
baseline_manager.save("base_21b_step3", result)
baseline_manager.save("base_21b_mtp_metrics_step3", speculate_metrics_2)
baseline_manager.save("base_21b_logprobs_step3", logprobs_2)
baseline_manager.save("base_21b_logprobs_step3_new", logprobs_2)

baseline_result = baseline_manager.load("base_21b_step3")
baseline_mtp_metrics = baseline_manager.load("base_21b_mtp_metrics_step3")
baseline_logprobs = baseline_manager.load("base_21b_logprobs_step3")
baseline_logprobs = baseline_manager.load("base_21b_logprobs_step3_new")

assert logprobs == logprobs_2, (
"logprobs 前后不一致\n"
Expand Down
19 changes: 10 additions & 9 deletions tests/output/test_process_batch_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,9 @@ def test_speculative_decoding_use_logprobs(self):

# stop_flag
processor.output_tokens[0, 0].set_tensor(paddle.to_tensor(2))

This comment was marked as outdated.

# mtype target = 3, decode = 4
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3))
# meta[1] packs mtype (low 8 bits) and actual_topk (high 16 bits)

This comment was marked as outdated.

actual_topk = K + 1
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3 | (actual_topk << 8)))
# batch
processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(2))
# accept_num
Expand Down Expand Up @@ -243,12 +244,12 @@ def test_speculative_decoding_use_logprobs(self):
assert len(request_output.outputs.token_ids) == accept_num[i]
assert len(request_output.outputs.top_logprobs) == 3
# tokens, scores, ranks
assert len(request_output.outputs.top_logprobs[0][0]) == K + 1
assert len(request_output.outputs.top_logprobs[1][0]) == K + 1
assert len(request_output.outputs.top_logprobs[0][0]) == actual_topk
assert len(request_output.outputs.top_logprobs[1][0]) == actual_topk
assert len(request_output.outputs.top_logprobs[2]) == accept_num[i]

# mtype = 4
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(4))
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(4 | (actual_topk << 8)))
processor._process_batch_output()
cached_generated_tokens: MockCachedGeneratedTokens = processor.cached_generated_tokens
for c in cached_generated_tokens.cache:
Expand All @@ -257,8 +258,8 @@ def test_speculative_decoding_use_logprobs(self):
assert len(request_output.outputs.top_logprobs) == 3
assert len(request_output.outputs.draft_top_logprobs) == 3
# tokens, scores, ranks
assert len(request_output.outputs.draft_top_logprobs[0][0]) == K + 1
assert len(request_output.outputs.draft_top_logprobs[1][0]) == K + 1
assert len(request_output.outputs.draft_top_logprobs[0][0]) == actual_topk
assert len(request_output.outputs.draft_top_logprobs[1][0]) == actual_topk
assert len(request_output.outputs.draft_top_logprobs[2]) == accept_num[i]

def test_process_batch_output_aborted_task_negative_token_speculative_decoding(self):
Expand All @@ -281,8 +282,8 @@ def test_process_batch_output_aborted_task_negative_token_speculative_decoding(s
# Set up output tokens with negative token
# stop_flag
processor.output_tokens[0, 0].set_tensor(paddle.to_tensor(2))
# mtype target = 3
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3))
# mtype target = 3, actual_topk packed in high bits
processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3 | ((K + 1) << 8)))
# batch = 2 (so batch_id=0 is < batch_size-1=1)
processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(2))
# Set accept_num = PREEMPTED_TOKEN_ID (-9) for first task to trigger abort logic
Expand Down
4 changes: 2 additions & 2 deletions tests/output/test_token_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ def test_process_batch_output_speculative_logprob_handles_draft_batch():
)
processor._batch_result_buffer = [target]
processor.cached_generated_tokens = mock.Mock()
processor.output_tokens[1, 0] = 4
processor.output_tokens[1, 0] = 4 | ((K + 1) << 8)
processor.output_tokens[2, 0] = 1
processor.output_tokens[3, 0] = 1

Expand Down Expand Up @@ -925,7 +925,7 @@ def test_process_batch_output_speculative_logprob_targets_topk_scores():
task.trace_carrier = None
rm.tasks_list[0] = task
rm.req_dict[task.request_id] = task
processor.output_tokens[1, 0] = 3
processor.output_tokens[1, 0] = 3 | ((K + 1) << 8)
processor.output_tokens[2, 0] = 1
processor.output_tokens[3, 0] = 2
token_block = np.arange(MAX_DRAFT_TOKENS * (K + 1), dtype=np.int64).reshape([-1, 1]) + 3
Expand Down
Loading