Skip to content

Commit 7a2e330

Browse files
cmcamdy and lizan1999 authored
[XPU] Refactor pre process (#6993)
* [XPU] support speculate_pre_process
* merge develop
* fix codestype
* fix mtp, support cu_seqlens_q_output
* fix mtp, support cu_seqlens_q_output
* fix test

---------

Co-authored-by: lizan1999 <lizan03@baidu.com>
1 parent fba8a51 commit 7a2e330

36 files changed

Lines changed: 2725 additions & 511 deletions

custom_ops/cpu_ops/get_padding_offset.cc

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -58,9 +58,8 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
5858
const int bsz = seq_len.shape()[0];
5959
const int seq_length = input_ids_shape[1];
6060
auto cum_offsets_out = cum_offsets.copy_to(paddle::CPUPlace(), false);
61-
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
62-
63-
const int token_num_data = cpu_token_num.data<int64_t>()[0];
61+
// token num is cpu tensor
62+
const int token_num_data = token_num.data<int64_t>()[0];
6463
auto x_remove_padding = paddle::empty(
6564
{token_num_data}, paddle::DataType::INT64, input_ids.place());
6665
auto padding_offset = paddle::empty(

custom_ops/xpu_ops/src/ops/adjust_batch.cc

Lines changed: 3 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -24,8 +24,7 @@
2424

2525
template <paddle::DataType T>
2626
std::vector<paddle::Tensor> AdjustBatchKernel(
27-
const paddle::Tensor &x, // [token_num, dim_embed]
28-
const paddle::Tensor &cum_offsets, // [bsz, 1]
27+
const paddle::Tensor &x, // [token_num, dim_embed]
2928
const paddle::Tensor &encoder_seq_lod,
3029
const paddle::Tensor &decoder_seq_lod,
3130
const paddle::Tensor &encoder_batch_idx,
@@ -49,7 +48,6 @@ std::vector<paddle::Tensor> AdjustBatchKernel(
4948
using data_t = typename PDTraits<T>::data_t;
5049
const int token_num = x.dims()[0];
5150
const int dim = x.dims()[1];
52-
const int bsz = cum_offsets.shape()[0];
5351
int enc_batch = len_info_cpu.data<int32_t>()[0];
5452
int dec_batch = len_info_cpu.data<int32_t>()[1];
5553

@@ -87,8 +85,7 @@ std::vector<paddle::Tensor> AdjustBatchKernel(
8785
}
8886

8987
using AdjustBatchKernelFuncPtr = std::vector<paddle::Tensor> (*)(
90-
const paddle::Tensor &x, // [token_num, dim_embed]
91-
const paddle::Tensor &cum_offsets, // [bsz, 1]
88+
const paddle::Tensor &x, // [token_num, dim_embed]
9289
const paddle::Tensor &encoder_seq_lod,
9390
const paddle::Tensor &decoder_seq_lod,
9491
const paddle::Tensor &encoder_batch_idx,
@@ -102,8 +99,7 @@ using AdjustBatchKernelFuncPtr = std::vector<paddle::Tensor> (*)(
10299
int max_input_length);
103100

104101
std::vector<paddle::Tensor> AdjustBatch(
105-
const paddle::Tensor &x, // [token_num, dim_embed]
106-
const paddle::Tensor &cum_offsets, // [bsz, 1]
102+
const paddle::Tensor &x, // [token_num, dim_embed]
107103
const paddle::Tensor &encoder_seq_lod,
108104
const paddle::Tensor &decoder_seq_lod,
109105
const paddle::Tensor &encoder_batch_idx,
@@ -135,7 +131,6 @@ std::vector<paddle::Tensor> AdjustBatch(
135131
}
136132

137133
return func(x,
138-
cum_offsets,
139134
encoder_seq_lod,
140135
decoder_seq_lod,
141136
encoder_batch_idx,
@@ -151,7 +146,6 @@ std::vector<paddle::Tensor> AdjustBatch(
151146

152147
std::vector<std::vector<int64_t>> AdjustBatchInferShape(
153148
const std::vector<int64_t> &x_shape,
154-
const std::vector<int64_t> &cum_offsets_shape,
155149
const std::vector<int64_t> &encoder_seq_lod_shape,
156150
const std::vector<int64_t> &decoder_seq_lod_shape,
157151
const std::vector<int64_t> &encoder_batch_idx_shape,
@@ -172,7 +166,6 @@ std::vector<std::vector<int64_t>> AdjustBatchInferShape(
172166

173167
std::vector<paddle::DataType> AdjustBatchInferDtype(
174168
const paddle::DataType &x_dtype,
175-
const paddle::DataType &cum_offsets_dtype,
176169
const paddle::DataType &encoder_seq_lod_dtype,
177170
const paddle::DataType &decoder_seq_lod_dtype,
178171
const paddle::DataType &encoder_batch_idx_dtype,
@@ -188,7 +181,6 @@ std::vector<paddle::DataType> AdjustBatchInferDtype(
188181

189182
PD_BUILD_STATIC_OP(adjust_batch)
190183
.Inputs({"x",
191-
"cum_offsets",
192184
"encoder_seq_lod",
193185
"decoder_seq_lod",
194186
"encoder_batch_idx",

custom_ops/xpu_ops/src/ops/block_attn.cc

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,6 @@ std::vector<paddle::Tensor> BlockAttnKernel(
6666
const paddle::Tensor& qkv,
6767
const paddle::Tensor& key_cache,
6868
const paddle::Tensor& value_cache,
69-
const paddle::Tensor& cum_offsets,
7069
const paddle::Tensor& rotary_embs,
7170
const paddle::Tensor& block_tables,
7271
const paddle::Tensor& prefix_block_tables,
@@ -122,7 +121,6 @@ std::vector<paddle::Tensor> BlockAttnKernel(
122121
auto qkv_shape = qkv.dims();
123122
auto cache_shape = key_cache.dims();
124123
auto block_table_shape = block_tables.dims();
125-
const int bsz = cum_offsets.dims()[0];
126124
const int block_batch = block_table_shape[0];
127125
const int max_block_per_seq = block_table_shape[1];
128126
const int kv_num_heads = cache_shape[1];
@@ -984,7 +982,6 @@ std::vector<paddle::Tensor> BlockAttn(
984982
const paddle::Tensor& qkv,
985983
const paddle::Tensor& key_cache,
986984
const paddle::Tensor& value_cache,
987-
const paddle::Tensor& cum_offsets,
988985
const paddle::Tensor& rotary_embs,
989986
const paddle::Tensor& block_tables,
990987
const paddle::Tensor& prefix_block_tables,
@@ -1023,7 +1020,6 @@ std::vector<paddle::Tensor> BlockAttn(
10231020
return BlockAttnKernel<TX, TC, TS>(qkv, \
10241021
key_cache, \
10251022
value_cache, \
1026-
cum_offsets, \
10271023
rotary_embs, \
10281024
block_tables, \
10291025
prefix_block_tables, \
@@ -1099,7 +1095,6 @@ PD_BUILD_STATIC_OP(block_attn)
10991095
.Inputs({"qkv",
11001096
"key_cache",
11011097
"value_cache",
1102-
"cum_offsets",
11031098
"rotary_embs",
11041099
"block_tables",
11051100
"prefix_block_tables",

custom_ops/xpu_ops/src/ops/gather_next_token.cc

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -22,8 +22,7 @@
2222
#endif
2323

2424
std::vector<paddle::Tensor> GatherNextToken(
25-
const paddle::Tensor& x, // [token_num, dim_embed]
26-
const paddle::Tensor& cum_offsets, // [bsz, 1]
25+
const paddle::Tensor& x, // [token_num, dim_embed]
2726
const paddle::Tensor& encoder_seq_lod,
2827
const paddle::Tensor& decoder_seq_lod,
2928
const paddle::Tensor& encoder_batch_map,
@@ -46,7 +45,7 @@ std::vector<paddle::Tensor> GatherNextToken(
4645
typedef paddle::bfloat16 data_t;
4746
const int dim = x.dims()[1];
4847
const int token_num = x.shape()[0];
49-
int bsz = cum_offsets.shape()[0];
48+
int bsz = -1;
5049
int enc_batch = len_info_cpu.data<int32_t>()[0];
5150
int dec_batch = len_info_cpu.data<int32_t>()[1];
5251
if (max_bsz > 0) {
@@ -116,7 +115,6 @@ std::vector<paddle::Tensor> GatherNextToken(
116115

117116
std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
118117
const std::vector<int64_t>& x_shape,
119-
const std::vector<int64_t>& cum_offsets_shape,
120118
const std::vector<int64_t>& encoder_seq_lod_shape,
121119
const std::vector<int64_t>& decoder_seq_lod_shape,
122120
const std::vector<int64_t>& encoder_batch_map_shape,
@@ -130,19 +128,18 @@ std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
130128
// if (output_padding_offset_shape) {
131129
// PD_THROW("speculative decoding is not supported in XPU.");
132130
// }
133-
int64_t bsz = cum_offsets_shape[0];
131+
// int64_t bsz = cum_offsets_shape[0];
132+
int64_t bsz = 0;
134133
int64_t dim_embed = x_shape[1];
135134
if (output_padding_offset_shape) {
136135
return {{-1, dim_embed}};
137136
} else {
138-
int64_t bsz = cum_offsets_shape[0];
139137
return {{bsz, dim_embed}};
140138
}
141139
}
142140

143141
std::vector<paddle::DataType> GatherNextTokenInferDtype(
144142
const paddle::DataType& x_dtype,
145-
const paddle::DataType& cum_offsets_dtype,
146143
const paddle::DataType& encoder_seq_lod_dtype,
147144
const paddle::DataType& decoder_seq_lod_dtype,
148145
const paddle::DataType& encoder_batch_map_dtype,
@@ -158,7 +155,6 @@ std::vector<paddle::DataType> GatherNextTokenInferDtype(
158155

159156
PD_BUILD_STATIC_OP(gather_next_token)
160157
.Inputs({"x",
161-
"cum_offsets",
162158
"encoder_seq_lod",
163159
"decoder_seq_lod",
164160
"encoder_batch_map",

custom_ops/xpu_ops/src/ops/mtp/draft_model_update.cc

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
2828
const paddle::Tensor& seq_lens_encoder,
2929
const paddle::Tensor& seq_lens_decoder,
3030
const paddle::Tensor& step_idx,
31-
const paddle::Tensor& output_cum_offsets,
31+
const paddle::Tensor& cu_seqlens_q_output,
3232
const paddle::Tensor& stop_flags,
3333
const paddle::Tensor& not_need_stop,
3434
const paddle::Tensor& max_dec_len,
@@ -72,7 +72,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
7272
const_cast<int*>(seq_lens_encoder.data<int>()),
7373
const_cast<int*>(seq_lens_decoder.data<int>()),
7474
const_cast<int64_t*>(step_idx.data<int64_t>()),
75-
output_cum_offsets.data<int>(),
75+
cu_seqlens_q_output.data<int>(),
7676
const_cast<bool*>(stop_flags.data<bool>()),
7777
const_cast<bool*>(not_need_stop_device.data<bool>()),
7878
max_dec_len.data<int64_t>(),
@@ -102,7 +102,7 @@ PD_BUILD_STATIC_OP(draft_model_update)
102102
"seq_lens_encoder",
103103
"seq_lens_decoder",
104104
"step_idx",
105-
"output_cum_offsets",
105+
"cu_seqlens_q_output",
106106
"stop_flags",
107107
"not_need_stop",
108108
"max_dec_len",
Lines changed: 133 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,133 @@
1+
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <paddle/phi/backends/xpu/xpu_context.h>
16+
#include "paddle/extension.h"
17+
#include "xpu/plugin.h"
18+
19+
#ifndef PD_BUILD_STATIC_OP
20+
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
21+
#endif
22+
23+
namespace api = baidu::xpu::api;
24+
25+
std::vector<paddle::Tensor> SpeculatePreProcess(
26+
const int64_t cpu_token_num,
27+
const paddle::Tensor &input_ids,
28+
const paddle::Tensor &seq_len,
29+
const paddle::Tensor &draft_tokens,
30+
const paddle::Tensor &seq_lens_encoder,
31+
const paddle::Tensor &seq_lens_decoder) {
32+
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
33+
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
34+
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
35+
api::Context *ctx = xpu_ctx->x_context();
36+
37+
// just for ut to run base line
38+
std::unique_ptr<baidu::xpu::api::Context> cpu_ctx;
39+
if (input_ids.place().GetType() == phi::AllocationType::CPU) {
40+
cpu_ctx = std::make_unique<baidu::xpu::api::Context>(baidu::xpu::api::kCPU);
41+
ctx = cpu_ctx.get();
42+
}
43+
44+
std::vector<int64_t> input_ids_shape = input_ids.shape();
45+
const int bsz = seq_len.shape()[0];
46+
const int max_seq_len = input_ids_shape[1];
47+
const int token_num_data = cpu_token_num;
48+
auto ids_remove_padding = paddle::empty(
49+
{token_num_data}, paddle::DataType::INT64, input_ids.place());
50+
auto batch_id_per_token = paddle::empty(
51+
{token_num_data}, paddle::DataType::INT32, input_ids.place());
52+
auto cu_seqlens_q =
53+
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
54+
auto cu_seqlens_k =
55+
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
56+
const int max_draft_tokens_per_batch = draft_tokens.shape()[1];
57+
58+
auto seq_lens_output =
59+
paddle::empty({bsz}, paddle::DataType::INT32, input_ids.place());
60+
auto cu_seq_lens_q_output =
61+
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
62+
auto batch_id_per_token_output =
63+
paddle::empty({bsz * max_draft_tokens_per_batch},
64+
paddle::DataType::INT32,
65+
input_ids.place());
66+
auto real_output_token_num =
67+
paddle::empty({1}, paddle::DataType::INT32, input_ids.place());
68+
if (token_num_data == 0) {
69+
return {ids_remove_padding,
70+
batch_id_per_token,
71+
cu_seqlens_q,
72+
cu_seqlens_k,
73+
cu_seq_lens_q_output,
74+
batch_id_per_token_output,
75+
real_output_token_num};
76+
}
77+
78+
int64_t *ids_remove_padding_ptr = ids_remove_padding.data<int64_t>();
79+
int *batch_id_per_token_ptr = batch_id_per_token.data<int>();
80+
int *cu_seqlens_q_ptr = cu_seqlens_q.data<int>();
81+
int *cu_seqlens_k_ptr = cu_seqlens_k.data<int>();
82+
int *seq_lens_output_ptr = seq_lens_output.data<int>();
83+
int *cu_seq_lens_q_output_ptr = cu_seq_lens_q_output.data<int>();
84+
int *batch_id_per_token_output_ptr = batch_id_per_token_output.data<int>();
85+
int *real_output_token_num_ptr = real_output_token_num.data<int>();
86+
const int64_t *input_data_ptr = input_ids.data<int64_t>();
87+
const int *seq_len_ptr = seq_len.data<int>();
88+
const int64_t *draft_tokens_ptr = draft_tokens.data<int64_t>();
89+
const int *seq_lens_encoder_ptr = seq_lens_encoder.data<int>();
90+
91+
int r =
92+
fastdeploy::plugin::speculate_preprocess(ctx,
93+
ids_remove_padding_ptr,
94+
batch_id_per_token_ptr,
95+
cu_seqlens_q_ptr,
96+
cu_seqlens_k_ptr,
97+
seq_lens_output_ptr,
98+
cu_seq_lens_q_output_ptr,
99+
batch_id_per_token_output_ptr,
100+
real_output_token_num_ptr,
101+
input_data_ptr,
102+
seq_len_ptr,
103+
draft_tokens_ptr,
104+
seq_lens_encoder_ptr,
105+
max_seq_len,
106+
max_draft_tokens_per_batch,
107+
token_num_data,
108+
bsz);
109+
110+
return {ids_remove_padding,
111+
batch_id_per_token,
112+
cu_seqlens_q,
113+
cu_seqlens_k,
114+
cu_seq_lens_q_output,
115+
batch_id_per_token_output,
116+
real_output_token_num};
117+
}
118+
119+
// Static-op registration for `speculate_pre_process`, bound to the
// SpeculatePreProcess kernel: five tensor inputs, seven tensor outputs and a
// single int64 attribute (`cpu_token_num`).
PD_BUILD_STATIC_OP(speculate_pre_process)
    .Inputs({"input_ids", "seq_len", "draft_tokens",
             "seq_lens_encoder", "seq_lens_decoder"})
    .Outputs({"ids_remove_padding", "batch_id_per_token",
              "cu_seqlens_q", "cu_seqlens_k",
              "cu_seq_lens_q_output", "batch_id_per_token_output",
              "real_output_token_num"})
    .Attrs({"cpu_token_num: int64_t"})
    .SetKernelFn(PD_KERNEL(SpeculatePreProcess));

custom_ops/xpu_ops/src/ops/mtp/speculate_token_penalty_multi_scores.cc

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -33,8 +33,8 @@ void SpeculateTokenPenaltyMultiScores(
3333
const paddle::Tensor& min_len,
3434
const paddle::Tensor& eos_token_id,
3535
const paddle::Tensor& seq_lens_this_time,
36-
const paddle::Tensor& output_padding_offset,
37-
const paddle::Tensor& output_cum_offsets,
36+
const paddle::Tensor& batch_id_per_token_output,
37+
const paddle::Tensor& cu_seqlens_q_output,
3838
const int max_seq_len) {
3939
namespace api = baidu::xpu::api;
4040
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
@@ -72,8 +72,8 @@ void SpeculateTokenPenaltyMultiScores(
7272
min_len.data<int64_t>(),
7373
eos_token_id.data<int64_t>(),
7474
bad_tokens.data<int64_t>(),
75-
output_padding_offset.data<int>(),
76-
output_cum_offsets.data<int>(),
75+
batch_id_per_token_output.data<int>(),
76+
cu_seqlens_q_output.data<int>(),
7777
bs,
7878
length,
7979
length_id,
@@ -100,8 +100,8 @@ void SpeculateTokenPenaltyMultiScores(
100100
min_len.data<int64_t>(),
101101
eos_token_id.data<int64_t>(),
102102
bad_tokens.data<int64_t>(),
103-
output_padding_offset.data<int>(),
104-
output_cum_offsets.data<int>(),
103+
batch_id_per_token_output.data<int>(),
104+
cu_seqlens_q_output.data<int>(),
105105
bs,
106106
length,
107107
length_id,
@@ -125,8 +125,8 @@ void SpeculateTokenPenaltyMultiScores(
125125
min_len.data<int64_t>(),
126126
eos_token_id.data<int64_t>(),
127127
bad_tokens.data<int64_t>(),
128-
output_padding_offset.data<int>(),
129-
output_cum_offsets.data<int>(),
128+
batch_id_per_token_output.data<int>(),
129+
cu_seqlens_q_output.data<int>(),
130130
bs,
131131
length,
132132
length_id,
@@ -157,8 +157,8 @@ PD_BUILD_STATIC_OP(speculate_get_token_penalty_multi_scores)
157157
"min_len",
158158
"eos_token_id",
159159
"seq_lens_this_time",
160-
"output_padding_offset",
161-
"output_cum_offsets"})
160+
"batch_id_per_token_output",
161+
"cu_seqlens_q_output"})
162162
.Outputs({"logits_out"})
163163
.Attrs({"max_seq_len: int"})
164164
.SetInplaceMap({{"logits", "logits_out"}})

0 commit comments

Comments
 (0)