Skip to content

Commit 8277b95

Browse files
remove speculate_get_padding_offset op (PaddlePaddle#6308)
1 parent 39dc4b0 commit 8277b95

7 files changed

Lines changed: 44 additions & 324 deletions

File tree

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 6 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -407,9 +407,12 @@ void GetBlockShapeAndSplitKVBlock(
407407
const int group_size,
408408
const int block_size);
409409

410-
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
411-
const paddle::Tensor& seq_len,
412-
const int64_t token_num_cpu);
410+
std::vector<paddle::Tensor> GetPaddingOffset(
411+
const paddle::Tensor& input_ids,
412+
const paddle::Tensor& seq_len,
413+
const paddle::optional<paddle::Tensor>& draft_tokens,
414+
const paddle::optional<paddle::Tensor>& seq_lens_encoder,
415+
const int64_t token_num_cpu);
413416

414417
void SetValueByFlagsAndIdx(const paddle::Tensor& pre_ids_all,
415418
const paddle::Tensor& input_ids,
@@ -739,15 +742,6 @@ void free_shared_buffer(int64_t buffer);
739742

740743
void clear_ipc_handles(int64_t _fa);
741744

742-
// speculative decoding Kernel
743-
std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
744-
const paddle::Tensor& input_ids,
745-
const paddle::Tensor& draft_tokens,
746-
const paddle::Tensor& cum_offsets,
747-
const paddle::Tensor& seq_len,
748-
const paddle::Tensor& seq_lens_encoder,
749-
const int64_t token_num_cpu);
750-
751745
std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
752746
const paddle::Tensor& seq_lens_this_time,
753747
const paddle::Tensor& seq_lens_encoder,
@@ -1596,11 +1590,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
15961590
&get_graph_buffer_ipc_meta,
15971591
"get_graph_buffer_ipc_meta");
15981592

1599-
// speculative decoding Kernel
1600-
m.def("speculate_get_padding_offset",
1601-
&SpeculateGetPaddingOffset,
1602-
"speculate_get_padding_offset function");
1603-
16041593
m.def("speculate_get_seq_lens_output",
16051594
&SpeculateGetSeqLensOutput,
16061595
"speculate_get_seq_lens_output function");

custom_ops/gpu_ops/get_padding_offset.cu

Lines changed: 34 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,10 @@ __global__ void PrefixSumKernel(int64_t *ids_remove_padding,
2525
int *cu_seqlens_k,
2626
const int64_t *input_data,
2727
const int *seq_lens,
28-
const int max_seq_len) {
28+
const int max_seq_len,
29+
const int64_t *draft_tokens,
30+
const int *seq_lens_encoder,
31+
const int max_draft_tokens_per_batch) {
2932
const int bi = blockIdx.x;
3033
const int tid = threadIdx.x;
3134
#ifdef PADDLE_WITH_COREX
@@ -62,15 +65,26 @@ __global__ void PrefixSumKernel(int64_t *ids_remove_padding,
6265

6366
for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
6467
const int tgt_seq_id = cum_seq_len - seq_lens[bi] + i;
65-
const int src_seq_id = bi * max_seq_len + i;
66-
ids_remove_padding[tgt_seq_id] = input_data[src_seq_id];
68+
if (max_draft_tokens_per_batch > 0 && seq_lens_encoder[bi] <= 0) {
69+
// speculative decoding
70+
const int src_seq_id = bi * max_draft_tokens_per_batch + i;
71+
ids_remove_padding[tgt_seq_id] = draft_tokens[src_seq_id];
72+
} else {
73+
// Non-speculative decoding
74+
const int src_seq_id = bi * max_seq_len + i;
75+
ids_remove_padding[tgt_seq_id] = input_data[src_seq_id];
76+
}
77+
6778
batch_id_per_token[tgt_seq_id] = bi;
6879
}
6980
}
7081

71-
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
72-
const paddle::Tensor &seq_len,
73-
const int64_t cpu_token_num) {
82+
std::vector<paddle::Tensor> GetPaddingOffset(
83+
const paddle::Tensor &input_ids,
84+
const paddle::Tensor &seq_len,
85+
const paddle::optional<paddle::Tensor> &draft_tokens,
86+
const paddle::optional<paddle::Tensor> &seq_lens_encoder,
87+
const int64_t cpu_token_num) {
7488
#ifdef PADDLE_WITH_CUSTOM_DEVICE
7589
auto dev_ctx = static_cast<const phi::CustomContext *>(
7690
paddle::experimental::DeviceContextPool::Instance().Get(
@@ -98,14 +112,23 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
98112
int blockSize =
99113
min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
100114
#endif
115+
116+
int max_draft_tokens_per_batch = -1;
117+
if (draft_tokens) {
118+
max_draft_tokens_per_batch = draft_tokens.get().shape()[1];
119+
}
120+
101121
PrefixSumKernel<<<bsz, blockSize, 0, cu_stream>>>(
102122
x_remove_padding.data<int64_t>(),
103123
batch_id_per_token.data<int>(),
104124
cu_seqlens_q.data<int>(),
105125
cu_seqlens_k.data<int>(),
106126
input_ids.data<int64_t>(),
107127
seq_len.data<int>(),
108-
max_seq_len);
128+
max_seq_len,
129+
draft_tokens ? draft_tokens.get().data<int64_t>() : nullptr,
130+
seq_lens_encoder ? seq_lens_encoder.get().data<int32_t>() : nullptr,
131+
max_draft_tokens_per_batch);
109132

110133
return {x_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k};
111134
}
@@ -127,7 +150,10 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
127150
}
128151

129152
PD_BUILD_STATIC_OP(get_padding_offset)
130-
.Inputs({"input_ids", "seq_len"})
153+
.Inputs({"input_ids",
154+
"seq_len",
155+
paddle::Optional("draft_tokens"),
156+
paddle::Optional("seq_lens_encoder")})
131157
.Outputs({"x_remove_padding",
132158
"batch_id_per_token",
133159
"cu_seqlens_q",

custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu

Lines changed: 0 additions & 149 deletions
This file was deleted.

fastdeploy/model_executor/pre_and_post_process.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,6 @@
5959
save_output_topk,
6060
set_stop_value_multi_ends,
6161
speculate_get_output_padding_offset,
62-
speculate_get_padding_offset,
6362
speculate_get_seq_lens_output,
6463
speculate_limit_thinking_content_length_v1,
6564
speculate_limit_thinking_content_length_v2,
@@ -86,7 +85,6 @@
8685
save_output_topk,
8786
set_stop_value_multi_ends,
8887
speculate_get_output_padding_offset,
89-
speculate_get_padding_offset,
9088
speculate_get_seq_lens_output,
9189
speculate_save_output,
9290
speculate_save_output_topk,
@@ -226,7 +224,7 @@ def pre_process(
226224
if specific_platform and not speculative_decoding:
227225
# Note(ZKK): This case's code is very simple!
228226
ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
229-
input_ids, seq_lens_this_time, token_num_cpu
227+
input_ids, seq_lens_this_time, None, None, token_num_cpu
230228
)
231229
return (
232230
ids_remove_padding,
@@ -247,9 +245,7 @@ def pre_process(
247245
batch_id_per_token,
248246
cu_seqlens_q,
249247
cu_seqlens_k,
250-
) = speculate_get_padding_offset(
251-
input_ids, draft_tokens, cum_offsets_now, seq_lens_this_time, seq_lens_encoder, token_num_cpu
252-
)
248+
) = get_padding_offset(input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder, token_num_cpu)
253249
seq_lens_output = speculate_get_seq_lens_output(
254250
seq_lens_this_time,
255251
seq_lens_encoder,

tests/layers/test_attention_layer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -274,7 +274,7 @@ def create_forward_meta(
274274
input_ids = paddle.zeros([batch_size, max_model_len], dtype="int64")
275275
token_num = np.sum(seq_lens_this_time)
276276
ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
277-
input_ids, seq_lens_this_time, token_num
277+
input_ids, seq_lens_this_time, None, None, token_num
278278
)
279279

280280
forward_meta = ForwardMeta(

tests/operators/test_get_padding_offset.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ def test_get_padding_offset(self):
3232
batch_id_per_token,
3333
cu_seqlens_q,
3434
cu_seqlens_k,
35-
) = get_padding_offset(paddle.to_tensor(input_ids), paddle.to_tensor(seq_lens), token_num_cpu)
35+
) = get_padding_offset(paddle.to_tensor(input_ids), paddle.to_tensor(seq_lens), None, None, token_num_cpu)
3636

3737
ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6], "int64")
3838
ref_batch_id_per_token = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2], "int32")

0 commit comments

Comments
 (0)