2121#include " utils/utils.h"
2222
2323// retrun: amax and where(>0)
24- std::pair<int , std::vector<int >> get_max_and_where_nonzero (
25- int * seq_lens_encoder, const int elem_cnt) {
26- int max_seq_len = seq_lens_encoder[0 ];
24+ std::tuple<int , int , int , std::vector<int >> get_max_and_where_nonzero (
25+ int * seq_lens_encoder, int * seq_lens_decoder, const int elem_cnt) {
26+ int max_seq_len_without_context = 0 ;
27+ int max_seq_len_with_context = 0 ;
28+ int max_context_len = 0 ;
2729 std::vector<int > valid_batch;
2830 for (int i = 0 ; i < elem_cnt; ++i) {
29- if (seq_lens_encoder[i] > max_seq_len) {
30- max_seq_len = seq_lens_encoder[i];
31- }
3231 if (seq_lens_encoder[i] > 0 ) {
3332 valid_batch.push_back (i);
33+ if (seq_lens_encoder[i] > max_seq_len_without_context) {
34+ max_seq_len_without_context = seq_lens_encoder[i];
35+ max_seq_len_with_context = seq_lens_encoder[i];
36+ }
37+ if (seq_lens_decoder[i] > max_context_len) {
38+ max_context_len = seq_lens_decoder[i];
39+ }
40+ if (seq_lens_decoder[i] > 0 && seq_lens_encoder[i] + seq_lens_decoder[i] >
41+ max_seq_len_with_context) {
42+ max_seq_len_with_context = seq_lens_encoder[i] + seq_lens_decoder[i];
43+ }
3444 }
3545 }
36- return {max_seq_len, valid_batch};
46+ return {max_seq_len_without_context,
47+ max_seq_len_with_context,
48+ max_context_len,
49+ valid_batch};
3750}
3851
3952// return: where(>0)
@@ -90,16 +103,41 @@ void pad_fill(const T* input_p,
90103
91104template <typename T>
92105void pad_fill (const T* input_p,
106+ const T* offsets,
93107 T* padded,
94108 std::vector<int > valid_batches,
95109 int input_linewidth,
96- int padded_linewidth) {
110+ int padded_linewidth,
111+ int max_context_len,
112+ int block_size) {
113+ int copy_len = std::min (input_linewidth, padded_linewidth);
114+ #pragma omp parallel for num_threads(OMP_THREAD_NUM)
115+ for (int i = 0 ; i < static_cast <int >(valid_batches.size ()); ++i) {
116+ for (int j = (max_context_len - offsets[valid_batches[i]]) / block_size,
117+ k = 0 ;
118+ j < copy_len && k < copy_len;
119+ ++j, ++k) {
120+ padded[i * padded_linewidth + j] =
121+ input_p[valid_batches[i] * input_linewidth + k];
122+ }
123+ }
124+ }
125+
126+ template <typename T>
127+ void pad_fill (const T* input_p,
128+ const T* offsets,
129+ T* padded,
130+ std::vector<int > valid_batches,
131+ int input_linewidth,
132+ int padded_linewidth,
133+ int block_size) {
97134 int copy_len = std::min (input_linewidth, padded_linewidth);
98135#pragma omp parallel for num_threads(OMP_THREAD_NUM)
99136 for (int i = 0 ; i < static_cast <int >(valid_batches.size ()); ++i) {
100137 for (int j = 0 ; j < copy_len; ++j) {
101138 padded[i * padded_linewidth + j] =
102- input_p[valid_batches[i] * input_linewidth + j];
139+ input_p[valid_batches[i] * input_linewidth + j +
140+ offsets[valid_batches[i]] / block_size];
103141 }
104142 }
105143}
@@ -183,6 +221,10 @@ std::vector<paddle::Tensor> PrepareBlockMetadata(
183221 const char * env_prefill_batch_step = std::getenv (" BATCH_STEP_PREFILL" );
184222 const int batch_step_prefill =
185223 env_prefill_batch_step ? std::atoi (env_prefill_batch_step) : 1 ;
224+ const char * env_context_block_step =
225+ std::getenv (" CONTEXT_BLOCK_STEP_PREFILL" );
226+ const int context_block_step_prefill =
227+ env_context_block_step ? std::atoi (env_context_block_step) : 1 ;
186228 const char * env_decode_batch_step = std::getenv (" BATCH_STEP_DECODE" );
187229 const int batch_step_decode =
188230 env_decode_batch_step ? std::atoi (env_decode_batch_step) : 4 ;
@@ -194,8 +236,14 @@ std::vector<paddle::Tensor> PrepareBlockMetadata(
194236 const int max_blocks_each = block_tables.shape ()[1 ];
195237 phi::DataType device_dtype = phi::StringToDataType (dtype);
196238
197- auto [max_enc_len, valid_batches_enc] = get_max_and_where_nonzero (
198- const_cast <int *>(seq_lens_encoder_cpu.data <int >()), max_batches_in);
239+ auto [max_enc_len_without_context,
240+ max_enc_len_with_context,
241+ max_context_len,
242+ valid_batches_enc] =
243+ get_max_and_where_nonzero (
244+ const_cast <int *>(seq_lens_encoder_cpu.data <int >()),
245+ const_cast <int *>(seq_lens_decoder_cpu.data <int >()),
246+ max_batches_in);
199247 int enc_count = valid_batches_enc.size ();
200248
201249 auto valid_batches_dec = where_nonzero (
@@ -223,34 +271,83 @@ std::vector<paddle::Tensor> PrepareBlockMetadata(
223271
224272 auto input_ids_cpu = input_ids_selected.copy_to (paddle::CPUPlace (), true );
225273
226- int max_buckets = (max_enc_len + block_size - 1 ) / block_size;
227- int max_prompt_len = max_buckets * block_size;
274+ int max_buckets_without_context =
275+ (max_enc_len_without_context + block_size - 1 ) / block_size;
276+ int max_prompt_len_without_context =
277+ max_buckets_without_context * block_size;
228278
229- auto src_padded = paddle::full ({total_batch * max_prompt_len},
230- 0 ,
231- paddle::DataType::INT64 ,
232- paddle::CPUPlace ());
279+ auto src_padded =
280+ paddle::full ({total_batch * max_prompt_len_without_context},
281+ 0 ,
282+ paddle::DataType::INT64 ,
283+ paddle::CPUPlace ());
233284 pad_fill<int64_t >(const_cast <int64_t *>(input_ids_cpu.data <int64_t >()),
234285 reinterpret_cast <int64_t *>(src_padded.data <int64_t >()),
235286 static_cast <int >(valid_batches_enc.size ()),
236287 max_seq_len,
237- max_prompt_len );
288+ max_prompt_len_without_context );
238289
239- auto blk_padded = paddle::full ({total_batch * max_buckets },
290+ auto blk_padded = paddle::full ({total_batch * max_buckets_without_context },
240291 -1 ,
241292 paddle::DataType::INT32 ,
242293 paddle::CPUPlace ());
243- pad_fill<int32_t >(const_cast <int32_t *>(block_tables_cpu.data <int32_t >()),
244- reinterpret_cast <int32_t *>(blk_padded.data <int32_t >()),
245- valid_batches_enc,
246- max_blocks_each,
247- max_buckets);
294+ pad_fill<int32_t >(
295+ const_cast <int32_t *>(block_tables_cpu.data <int32_t >()),
296+ const_cast <int32_t *>(seq_lens_decoder_cpu.data <int32_t >()),
297+ reinterpret_cast <int32_t *>(blk_padded.data <int32_t >()),
298+ valid_batches_enc,
299+ max_blocks_each,
300+ max_buckets_without_context,
301+ block_size);
248302
249303 auto blk_padded_hpu =
250304 custom_kernel::copy_tensor_wrapper (dev_ctx, blk_padded, hpu_place);
251305
252- auto rope_emb_seg = paddle::experimental::slice (
253- rope_emb, {2 }, {0 }, {max_prompt_len}, {}, {});
306+ int max_buckets_with_context =
307+ (max_enc_len_with_context + block_size - 1 ) / block_size;
308+ max_buckets_with_context =
309+ ((max_buckets_with_context + context_block_step_prefill - 1 ) /
310+ context_block_step_prefill) *
311+ context_block_step_prefill;
312+ int max_prompt_len_with_context = max_buckets_with_context * block_size;
313+
314+ auto block_list_padded =
315+ paddle::full ({total_batch * max_buckets_with_context},
316+ -1 ,
317+ paddle::DataType::INT32 ,
318+ paddle::CPUPlace ());
319+ pad_fill<int32_t >(
320+ const_cast <int32_t *>(block_tables_cpu.data <int32_t >()),
321+ const_cast <int32_t *>(seq_lens_decoder_cpu.data <int32_t >()),
322+ reinterpret_cast <int32_t *>(block_list_padded.data <int32_t >()),
323+ valid_batches_enc,
324+ max_blocks_each,
325+ max_buckets_with_context,
326+ max_context_len,
327+ block_size);
328+
329+ auto block_list_hpu = custom_kernel::copy_tensor_wrapper (
330+ dev_ctx, block_list_padded, hpu_place);
331+
332+ paddle::Tensor rope_emb_seg;
333+ if (max_prompt_len_without_context == max_prompt_len_with_context) {
334+ rope_emb_seg = paddle::experimental::slice (
335+ rope_emb, {2 }, {0 }, {max_prompt_len_without_context}, {}, {});
336+ } else {
337+ std::vector<paddle::Tensor> rope_emb_segs;
338+ for (auto b : valid_batches_enc) {
339+ int start = seq_lens_decoder_cpu.data <int >()[b];
340+ auto seg = paddle::experimental::slice (
341+ rope_emb,
342+ {2 },
343+ {start},
344+ {start + max_prompt_len_without_context},
345+ {},
346+ {});
347+ rope_emb_segs.push_back (seg);
348+ }
349+ rope_emb_seg = paddle::experimental::concat (rope_emb_segs, 1 );
350+ }
254351 rope_emb_seg = paddle::experimental::cast (rope_emb_seg, device_dtype);
255352
256353 auto total_batch_cpu_tensor = paddle::full (
@@ -262,7 +359,7 @@ std::vector<paddle::Tensor> PrepareBlockMetadata(
262359 return {src_padded,
263360 rope_emb_seg,
264361 dummy_tensor,
265- dummy_tensor ,
362+ block_list_hpu ,
266363 blk_padded_hpu,
267364 dummy_tensor,
268365 dummy_tensor,
0 commit comments