Skip to content

Commit 7f594d0

Browse files
authored
[INTEL_HPU] support prefix caching in prepare_block_metadata and fused_sdpa_proj_t (PaddlePaddle#2086)
1 parent f20ee5c commit 7f594d0

2 files changed

Lines changed: 144 additions & 29 deletions

File tree

backends/intel_hpu/custom_ops/llama_infer/fused_sdpa_proj_t.cc

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,13 @@ class FusedSdpaProjBTMH : public HpuFusedOperator {
171171
attn_inputs.push_back(q_r);
172172
attn_inputs.push_back(k_r);
173173
attn_inputs.push_back(v_r);
174-
174+
if (!params.sdpa_params.is_causal) {
175+
attn_inputs.push_back(createTensor(inputs[3].dims.size(),
176+
inputs[3].type,
177+
inputs[3].dims,
178+
true,
179+
inputs[3].name));
180+
}
175181
if (params.fp8_sdpa) {
176182
attn_inputs.push_back(nullptr); // Mask
177183
attn_inputs.push_back(nullptr); // Seed
@@ -310,6 +316,7 @@ void FusedSdpaProjBTMHKernel(
310316
const Context& dev_ctx,
311317
const phi::DenseTensor& query_states,
312318
const phi::DenseTensor& key_value_states,
319+
const phi::DenseTensor& attn_mask,
313320
const phi::DenseTensor& linear_weights,
314321
phi::DenseTensor* out_linear,
315322
const phi::Scalar& scaling_factor,
@@ -329,6 +336,9 @@ void FusedSdpaProjBTMHKernel(
329336
std::vector<DIMS> in_out_dims = ct.GetDims();
330337

331338
ct.Add(linear_weights);
339+
if (causal.to<bool>() == false) {
340+
ct.Add(attn_mask);
341+
}
332342

333343
unsigned int flags = 0;
334344
SDPA_SET_INPUT_AND_FLAGS(d_scale_q.get_ptr(), D_SCALE_Q)
@@ -422,6 +432,12 @@ std::vector<paddle::Tensor> FusedBaseSdpaProjBTMH(
422432
static_cast<const phi::DenseTensor*>(query_states.impl().get());
423433
auto key_value_states_tensor =
424434
static_cast<const phi::DenseTensor*>(key_value_states.impl().get());
435+
phi::DenseTensor* attn_mask_tensor = nullptr;
436+
if (attn_mask) {
437+
auto attn_mask_ptr = *(attn_mask.get_ptr());
438+
attn_mask_tensor =
439+
static_cast<phi::DenseTensor*>(attn_mask_ptr.impl().get());
440+
}
425441
auto linear_weights_tensor =
426442
static_cast<const phi::DenseTensor*>(linear_weights.impl().get());
427443

@@ -503,12 +519,13 @@ std::vector<paddle::Tensor> FusedBaseSdpaProjBTMH(
503519
dev_ctx->Alloc(out_linear.get(), query_states_tensor->dtype());
504520
}
505521

506-
if (!attn_mask && !valid_seq_len) {
522+
if (!valid_seq_len) {
507523
if (query_states.dtype() == phi::DataType::FLOAT16) {
508524
custom_kernel::FusedSdpaProjBTMHKernel<phi::dtype::float16>(
509525
*dev_ctx,
510526
*query_states_tensor,
511527
*key_value_states_tensor,
528+
attn_mask_tensor ? *attn_mask_tensor : phi::DenseTensor(),
512529
*linear_weights_tensor,
513530
out_linear.get(),
514531
phi::Scalar(scaling_factor),
@@ -528,6 +545,7 @@ std::vector<paddle::Tensor> FusedBaseSdpaProjBTMH(
528545
*dev_ctx,
529546
*query_states_tensor,
530547
*key_value_states_tensor,
548+
attn_mask_tensor ? *attn_mask_tensor : phi::DenseTensor(),
531549
*linear_weights_tensor,
532550
out_linear.get(),
533551
phi::Scalar(scaling_factor),

backends/intel_hpu/custom_ops/llama_infer/prepare_block_metadata.cc

Lines changed: 124 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,32 @@
2121
#include "utils/utils.h"
2222

2323
// return: maxima over the batches whose encoder length is > 0, plus their
// indices.
//
// Tuple layout:
//   <0> largest encoder length,
//   <1> largest (encoder length + decoder length); a non-positive decoder
//       length contributes nothing, so this is never smaller than <0>,
//   <2> largest decoder length (floored at 0),
//   <3> indices i in [0, elem_cnt) with seq_lens_encoder[i] > 0, ascending.
//
// NOTE(review): seq_lens_decoder[i] is presumably the cached-context length
// of batch i (prefix caching) -- confirm against the caller.
//
// Each maximum is tracked independently. Resetting <1> whenever a new
// encoder maximum is seen would drop a larger encoder+decoder sum that an
// earlier batch had already recorded.
std::tuple<int, int, int, std::vector<int>> get_max_and_where_nonzero(
    int* seq_lens_encoder, int* seq_lens_decoder, const int elem_cnt) {
  int max_seq_len_without_context = 0;
  int max_seq_len_with_context = 0;
  int max_context_len = 0;
  std::vector<int> valid_batch;
  for (int i = 0; i < elem_cnt; ++i) {
    const int enc_len = seq_lens_encoder[i];
    if (enc_len <= 0) {
      continue;  // batch has no encoder tokens in this step
    }
    valid_batch.push_back(i);

    // Non-positive decoder lengths are treated as "no context".
    const int ctx_len = seq_lens_decoder[i] > 0 ? seq_lens_decoder[i] : 0;
    const int total_len = enc_len + ctx_len;

    if (enc_len > max_seq_len_without_context) {
      max_seq_len_without_context = enc_len;
    }
    if (ctx_len > max_context_len) {
      max_context_len = ctx_len;
    }
    if (total_len > max_seq_len_with_context) {
      max_seq_len_with_context = total_len;
    }
  }
  return {max_seq_len_without_context,
          max_seq_len_with_context,
          max_context_len,
          valid_batch};
}
3851

3952
// return: where(>0)
@@ -90,16 +103,41 @@ void pad_fill(const T* input_p,
90103

91104
// Copies one row per valid batch from `input_p` into `padded`, shifting each
// row to the right by (max_context_len - offsets[batch]) / block_size columns.
//
//   input_p          row-major source, `input_linewidth` entries per row,
//                    indexed by the batch ids stored in `valid_batches`
//   offsets          per-batch offset, indexed by batch id
//   padded           row-major destination, `padded_linewidth` entries per
//                    row, indexed by position in `valid_batches`
//   max_context_len  reference offset the rows are aligned against
//   block_size       divisor converting offsets into columns (must be != 0)
//
// Destination cells that receive no source entry are left untouched.
// If a batch's offset exceeds `max_context_len` the shift is negative; the
// source entries that would land left of the row are skipped rather than
// written before the start of the destination row.
template <typename T>
void pad_fill(const T* input_p,
              const T* offsets,
              T* padded,
              std::vector<int> valid_batches,
              int input_linewidth,
              int padded_linewidth,
              int max_context_len,
              int block_size) {
  const int copy_len = std::min(input_linewidth, padded_linewidth);
#pragma omp parallel for num_threads(OMP_THREAD_NUM)
  for (int i = 0; i < static_cast<int>(valid_batches.size()); ++i) {
    const int src_row = valid_batches[i];
    const int shift = static_cast<int>(
        (max_context_len - offsets[src_row]) / block_size);
    // Clamp the destination start to column 0 and advance the source by the
    // same amount so that (dst column - src column) stays equal to `shift`.
    const int dst_begin = std::max(shift, 0);
    const int src_begin = dst_begin - shift;
    for (int j = dst_begin, k = src_begin; j < copy_len && k < copy_len;
         ++j, ++k) {
      padded[i * padded_linewidth + j] = input_p[src_row * input_linewidth + k];
    }
  }
}
125+
126+
// Copies one row per valid batch from `input_p` into `padded`, starting the
// read at source column offsets[batch] / block_size.
//
//   input_p     row-major source, `input_linewidth` entries per row, indexed
//               by the batch ids stored in `valid_batches`
//   offsets     per-batch offset, indexed by batch id
//   padded      row-major destination, `padded_linewidth` entries per row,
//               indexed by position in `valid_batches`
//   block_size  divisor converting offsets into columns (must be != 0)
//
// The number of entries copied per row is limited both by the destination
// width and by what remains of the source row after the starting column, so
// the read never crosses into the next source row or past the source buffer.
// Destination cells that receive no source entry are left untouched.
template <typename T>
void pad_fill(const T* input_p,
              const T* offsets,
              T* padded,
              std::vector<int> valid_batches,
              int input_linewidth,
              int padded_linewidth,
              int block_size) {
#pragma omp parallel for num_threads(OMP_THREAD_NUM)
  for (int i = 0; i < static_cast<int>(valid_batches.size()); ++i) {
    const int src_row = valid_batches[i];
    // A negative offset would read before the row; start at column 0 instead.
    const int first_col =
        std::max(static_cast<int>(offsets[src_row] / block_size), 0);
    const int copy_len =
        std::min(input_linewidth - first_col, padded_linewidth);
    for (int j = 0; j < copy_len; ++j) {
      padded[i * padded_linewidth + j] =
          input_p[src_row * input_linewidth + first_col + j];
    }
  }
}
@@ -183,6 +221,10 @@ std::vector<paddle::Tensor> PrepareBlockMetadata(
183221
const char* env_prefill_batch_step = std::getenv("BATCH_STEP_PREFILL");
184222
const int batch_step_prefill =
185223
env_prefill_batch_step ? std::atoi(env_prefill_batch_step) : 1;
224+
const char* env_context_block_step =
225+
std::getenv("CONTEXT_BLOCK_STEP_PREFILL");
226+
const int context_block_step_prefill =
227+
env_context_block_step ? std::atoi(env_context_block_step) : 1;
186228
const char* env_decode_batch_step = std::getenv("BATCH_STEP_DECODE");
187229
const int batch_step_decode =
188230
env_decode_batch_step ? std::atoi(env_decode_batch_step) : 4;
@@ -194,8 +236,14 @@ std::vector<paddle::Tensor> PrepareBlockMetadata(
194236
const int max_blocks_each = block_tables.shape()[1];
195237
phi::DataType device_dtype = phi::StringToDataType(dtype);
196238

197-
auto [max_enc_len, valid_batches_enc] = get_max_and_where_nonzero(
198-
const_cast<int*>(seq_lens_encoder_cpu.data<int>()), max_batches_in);
239+
auto [max_enc_len_without_context,
240+
max_enc_len_with_context,
241+
max_context_len,
242+
valid_batches_enc] =
243+
get_max_and_where_nonzero(
244+
const_cast<int*>(seq_lens_encoder_cpu.data<int>()),
245+
const_cast<int*>(seq_lens_decoder_cpu.data<int>()),
246+
max_batches_in);
199247
int enc_count = valid_batches_enc.size();
200248

201249
auto valid_batches_dec = where_nonzero(
@@ -223,34 +271,83 @@ std::vector<paddle::Tensor> PrepareBlockMetadata(
223271

224272
auto input_ids_cpu = input_ids_selected.copy_to(paddle::CPUPlace(), true);
225273

226-
int max_buckets = (max_enc_len + block_size - 1) / block_size;
227-
int max_prompt_len = max_buckets * block_size;
274+
int max_buckets_without_context =
275+
(max_enc_len_without_context + block_size - 1) / block_size;
276+
int max_prompt_len_without_context =
277+
max_buckets_without_context * block_size;
228278

229-
auto src_padded = paddle::full({total_batch * max_prompt_len},
230-
0,
231-
paddle::DataType::INT64,
232-
paddle::CPUPlace());
279+
auto src_padded =
280+
paddle::full({total_batch * max_prompt_len_without_context},
281+
0,
282+
paddle::DataType::INT64,
283+
paddle::CPUPlace());
233284
pad_fill<int64_t>(const_cast<int64_t*>(input_ids_cpu.data<int64_t>()),
234285
reinterpret_cast<int64_t*>(src_padded.data<int64_t>()),
235286
static_cast<int>(valid_batches_enc.size()),
236287
max_seq_len,
237-
max_prompt_len);
288+
max_prompt_len_without_context);
238289

239-
auto blk_padded = paddle::full({total_batch * max_buckets},
290+
auto blk_padded = paddle::full({total_batch * max_buckets_without_context},
240291
-1,
241292
paddle::DataType::INT32,
242293
paddle::CPUPlace());
243-
pad_fill<int32_t>(const_cast<int32_t*>(block_tables_cpu.data<int32_t>()),
244-
reinterpret_cast<int32_t*>(blk_padded.data<int32_t>()),
245-
valid_batches_enc,
246-
max_blocks_each,
247-
max_buckets);
294+
pad_fill<int32_t>(
295+
const_cast<int32_t*>(block_tables_cpu.data<int32_t>()),
296+
const_cast<int32_t*>(seq_lens_decoder_cpu.data<int32_t>()),
297+
reinterpret_cast<int32_t*>(blk_padded.data<int32_t>()),
298+
valid_batches_enc,
299+
max_blocks_each,
300+
max_buckets_without_context,
301+
block_size);
248302

249303
auto blk_padded_hpu =
250304
custom_kernel::copy_tensor_wrapper(dev_ctx, blk_padded, hpu_place);
251305

252-
auto rope_emb_seg = paddle::experimental::slice(
253-
rope_emb, {2}, {0}, {max_prompt_len}, {}, {});
306+
int max_buckets_with_context =
307+
(max_enc_len_with_context + block_size - 1) / block_size;
308+
max_buckets_with_context =
309+
((max_buckets_with_context + context_block_step_prefill - 1) /
310+
context_block_step_prefill) *
311+
context_block_step_prefill;
312+
int max_prompt_len_with_context = max_buckets_with_context * block_size;
313+
314+
auto block_list_padded =
315+
paddle::full({total_batch * max_buckets_with_context},
316+
-1,
317+
paddle::DataType::INT32,
318+
paddle::CPUPlace());
319+
pad_fill<int32_t>(
320+
const_cast<int32_t*>(block_tables_cpu.data<int32_t>()),
321+
const_cast<int32_t*>(seq_lens_decoder_cpu.data<int32_t>()),
322+
reinterpret_cast<int32_t*>(block_list_padded.data<int32_t>()),
323+
valid_batches_enc,
324+
max_blocks_each,
325+
max_buckets_with_context,
326+
max_context_len,
327+
block_size);
328+
329+
auto block_list_hpu = custom_kernel::copy_tensor_wrapper(
330+
dev_ctx, block_list_padded, hpu_place);
331+
332+
paddle::Tensor rope_emb_seg;
333+
if (max_prompt_len_without_context == max_prompt_len_with_context) {
334+
rope_emb_seg = paddle::experimental::slice(
335+
rope_emb, {2}, {0}, {max_prompt_len_without_context}, {}, {});
336+
} else {
337+
std::vector<paddle::Tensor> rope_emb_segs;
338+
for (auto b : valid_batches_enc) {
339+
int start = seq_lens_decoder_cpu.data<int>()[b];
340+
auto seg = paddle::experimental::slice(
341+
rope_emb,
342+
{2},
343+
{start},
344+
{start + max_prompt_len_without_context},
345+
{},
346+
{});
347+
rope_emb_segs.push_back(seg);
348+
}
349+
rope_emb_seg = paddle::experimental::concat(rope_emb_segs, 1);
350+
}
254351
rope_emb_seg = paddle::experimental::cast(rope_emb_seg, device_dtype);
255352

256353
auto total_batch_cpu_tensor = paddle::full(
@@ -262,7 +359,7 @@ std::vector<paddle::Tensor> PrepareBlockMetadata(
262359
return {src_padded,
263360
rope_emb_seg,
264361
dummy_tensor,
265-
dummy_tensor,
362+
block_list_hpu,
266363
blk_padded_hpu,
267364
dummy_tensor,
268365
dummy_tensor,

0 commit comments

Comments
 (0)