Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions custom_ops/gpu_ops/get_output_msg_with_topk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,17 @@ void GetOutputTopK(const paddle::Tensor& x,
return;
}

int bsz = msg_rcv.mtext[1];
// Unpack bsz (low 16 bits) and actual_topk (high 16 bits) from mtext[1].
// This matches the packing in save_output_msg_with_topk.cc:
// mtext[1] = bsz | (max_num_logprobs << 16)
int bsz = msg_rcv.mtext[1] & 0xFFFF;
int actual_topk = (msg_rcv.mtext[1] >> 16) & 0xFFFF;
out_data[0] = (int64_t)msg_rcv.mtext[0];
out_data[1] = (int64_t)msg_rcv.mtext[1];
out_data[1] = (int64_t)msg_rcv.mtext[1]; // keep packed value; Python unpacks

for (int i = 0; i < bsz; i++) {
for (int j = 0; j < k + 1; j++) {
const int64_t offset = i * (K + 1) + j;
for (int j = 0; j < actual_topk; j++) {
const int64_t offset = i * actual_topk + j;
out_data[offset + 2] = (int64_t)msg_rcv.mtext[offset + 2];
scores_data[offset] = msg_rcv.mtext_f[offset];
}
Expand Down
21 changes: 11 additions & 10 deletions custom_ops/gpu_ops/save_output_msg_with_topk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,20 +109,21 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
: -inference_msg_id_from_env;
int bsz = x.shape()[0];
int max_num_logprobs = logprob_token_ids.shape()[1];
msg_sed.mtext[1] = bsz;
// Pack bsz (low 16 bits) and max_num_logprobs (high 16 bits) into mtext[1].
// The receiver recovers each field with a 0xFFFF mask, so both values must
// fit in 16 bits. token_processor unpacks both fields to avoid reading unused
// topk slots.
msg_sed.mtext[1] = bsz | (max_num_logprobs << 16);
for (int i = 0; i < bsz; i++) {
for (int j = 0; j < K + 1; j++) {
const int64_t offset = i * (K + 1) + j;
// Loop only over actual logprob columns (max_num_logprobs) instead of the
// fixed K+1=21, and use max_num_logprobs as the stride so data is packed
// densely in the message buffer.
for (int j = 0; j < max_num_logprobs; j++) {
const int64_t offset = i * max_num_logprobs + j;
if (j == 0) {
msg_sed.mtext[offset + 2] = (int)x_data[i];
msg_sed.mtext_f[offset] = logprob_scores_data[i * max_num_logprobs + j];
} else if (j < max_num_logprobs) {
msg_sed.mtext[offset + 2] =
(int)logprob_token_ids_data[i * max_num_logprobs + j];
msg_sed.mtext_f[offset] = logprob_scores_data[i * max_num_logprobs + j];
msg_sed.mtext_f[offset] = logprob_scores_data[offset];
} else {
msg_sed.mtext[offset + 2] = -1;
msg_sed.mtext_f[offset] = 0.0;
msg_sed.mtext[offset + 2] = (int)logprob_token_ids_data[offset];
msg_sed.mtext_f[offset] = logprob_scores_data[offset];
}
if (preempted_idx_data[i] == 1) {
msg_sed.mtext[offset + 2] = -9;
Expand Down
12 changes: 8 additions & 4 deletions custom_ops/xpu_ops/src/ops/get_output_msg_with_topk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,17 @@ void GetOutputTopK(const paddle::Tensor& x,
return;
}

int bsz = msg_rcv.mtext[1];
// Unpack bsz (low 16 bits) and actual_topk (high 16 bits) from mtext[1].
// This matches the packing in save_output_msg_with_topk.cc:
// mtext[1] = bsz | (max_num_logprobs << 16)
int bsz = msg_rcv.mtext[1] & 0xFFFF;
int actual_topk = (msg_rcv.mtext[1] >> 16) & 0xFFFF;
out_data[0] = (int64_t)msg_rcv.mtext[0];
out_data[1] = (int64_t)msg_rcv.mtext[1];
out_data[1] = (int64_t)msg_rcv.mtext[1]; // keep packed value; Python unpacks

for (int i = 0; i < bsz; i++) {
for (int j = 0; j < k + 1; j++) {
const int64_t offset = i * (K + 1) + j;
for (int j = 0; j < actual_topk; j++) {
const int64_t offset = i * actual_topk + j;
out_data[offset + 2] = (int64_t)msg_rcv.mtext[offset + 2];
scores_data[offset] = msg_rcv.mtext_f[offset];
}
Expand Down
21 changes: 11 additions & 10 deletions custom_ops/xpu_ops/src/ops/save_output_msg_with_topk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,20 +109,21 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
: -inference_msg_id_from_env;
int bsz = x.shape()[0];
int max_num_logprobs = logprob_token_ids.shape()[1];
msg_sed.mtext[1] = bsz;
// Pack bsz (low 16 bits) and max_num_logprobs (high 16 bits) into mtext[1].
// The receiver recovers each field with a 0xFFFF mask, so both values must
// fit in 16 bits. token_processor unpacks both fields to avoid reading unused
// topk slots.
msg_sed.mtext[1] = bsz | (max_num_logprobs << 16);
for (int i = 0; i < bsz; i++) {
for (int j = 0; j < K + 1; j++) {
const int64_t offset = i * (K + 1) + j;
// Loop only over actual logprob columns (max_num_logprobs) instead of the
// fixed K+1=21, and use max_num_logprobs as the stride so data is packed
// densely in the message buffer.
for (int j = 0; j < max_num_logprobs; j++) {
const int64_t offset = i * max_num_logprobs + j;
if (j == 0) {
msg_sed.mtext[offset + 2] = (int)x_data[i];
msg_sed.mtext_f[offset] = logprob_scores_data[i * max_num_logprobs + j];
} else if (j < max_num_logprobs) {
msg_sed.mtext[offset + 2] =
(int)logprob_token_ids_data[i * max_num_logprobs + j];
msg_sed.mtext_f[offset] = logprob_scores_data[i * max_num_logprobs + j];
msg_sed.mtext_f[offset] = logprob_scores_data[offset];
} else {
msg_sed.mtext[offset + 2] = -1;
msg_sed.mtext_f[offset] = 0.0;
msg_sed.mtext[offset + 2] = (int)logprob_token_ids_data[offset];
msg_sed.mtext_f[offset] = logprob_scores_data[offset];
}
if (preempted_idx_data[i] == 1) {
msg_sed.mtext[offset + 2] = -9;
Expand Down
27 changes: 20 additions & 7 deletions fastdeploy/output/token_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,10 +846,22 @@ def _process_batch_output(self):
batch = self.output_tokens[1]
accept_num = tokens[2 : batch + 2]
elif self.use_logprobs:
batch = self.output_tokens[1, 0]
tokens = tokens[2 : batch * (K + 1) + 2].reshape([batch, K + 1])[:, : (K + 1)]
scores = self.output_scores[: batch * (K + 1)].numpy().reshape([batch, K + 1])[:, : (K + 1)]
# mtext[1] packs two fields: bsz in the low 16 bits and actual_topk in the
# high 16 bits.
# actual_topk is the max_num_logprobs value written by save_output_topk, i.e.
# the number of logprob columns stored per row in this step's message
# (top_logprobs+1 across the batch).
# Rows are packed densely with stride actual_topk, so slicing and reshaping
# by actual_topk avoids processing the fixed-size K+1=21 slots when fewer
# are needed.
packed = int(self.output_tokens[1, 0])
batch = packed & 0xFFFF
actual_topk = (packed >> 16) & 0xFFFF
tokens = tokens[2 : batch * actual_topk + 2].reshape([batch, actual_topk])
scores = self.output_scores[: batch * actual_topk].numpy().reshape([batch, actual_topk])
ranks = self.output_ranks[:batch].numpy()
# Pre-convert the full [batch, actual_topk] arrays to Python lists once,
# avoiding per-row .tolist() calls inside the loop below.
tokens_lists = tokens.tolist()
scores_lists = scores.tolist()
ranks_list = ranks.tolist()
else:
batch = self.output_tokens[1, 0]
tokens = tokens[2 : batch + 2]
Expand Down Expand Up @@ -1022,10 +1034,11 @@ def _process_batch_output(self):
topk_logprobs = scores[i, batch_token_index, :].tolist()
sampled_rank = ranks[i, batch_token_index].item()
else:
result.outputs.logprob = float(scores[i, 0])
topk_token_ids = tokens[i, :].tolist()
topk_logprobs = scores[i, :].tolist()
sampled_rank = ranks[i].item()
# Use pre-converted lists (batch .tolist() done before the loop).
result.outputs.logprob = scores_lists[i][0]
topk_token_ids = tokens_lists[i]
topk_logprobs = scores_lists[i]
sampled_rank = ranks_list[i]

if result.outputs.top_logprobs is None:
result.outputs.top_logprobs = LogprobsLists(
Expand Down
6 changes: 4 additions & 2 deletions tests/output/test_token_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,8 @@ def test_process_batch_output_logprob_records_topk_and_caching():
task.trace_carrier = None
rm.tasks_list[0] = task
rm.req_dict[task.request_id] = task
processor.output_tokens[1, 0] = 1
# mtext[1] packs bsz (low 16 bits) | actual_topk (high 16 bits)
processor.output_tokens[1, 0] = 1 | ((K + 1) << 16)
token_block = np.arange(K + 1, dtype=np.int64) + 3
processor.output_tokens[2 : 2 + K + 1] = paddle.to_tensor(token_block.reshape([-1, 1]))
processor.output_scores[: K + 1] = paddle.ones([K + 1, 1], dtype="float32")
Expand Down Expand Up @@ -842,7 +843,8 @@ def test_process_batch_output_prefill_chunk_and_adapter_skip():
task.get = lambda key, default=None: getattr(task, key, default)
rm.tasks_list[0] = task
rm.req_dict[task.request_id] = task
processor.output_tokens[1, 0] = 1
# mtext[1] packs bsz (low 16 bits) | actual_topk (high 16 bits)
processor.output_tokens[1, 0] = 1 | ((K + 1) << 16)
processor.output_tokens[2 : 2 + K + 1] = paddle.to_tensor(np.ones([K + 1, 1], dtype=np.int64))
processor.output_scores[: K + 1] = paddle.ones([K + 1, 1], dtype="float32")
processor.output_ranks[0] = paddle.to_tensor(0, dtype="int64")
Expand Down
Loading