@@ -30,6 +30,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
3030 const paddle::Tensor& slot_mapping,
3131 const paddle::optional<paddle::Tensor>& kv_signal_data,
3232 cudaStream_t& stream,
33+ const std::string& cache_quant_type_str,
3334 paddle::Tensor* kv_cache) {
3435 typedef PDTraits<T> traits_;
3536 typedef typename traits_::DataType DataType_;
@@ -50,27 +51,51 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
5051 int grid_size = 1 ;
5152 GetNumBlocks<128 >(pack_num, &grid_size);
5253
53- using CT = DataType_;
54-
55- prefill_absorb_cache_kernel<DataType_, PackSize, CT>
56- <<<grid_size, blocksize, 0 , stream>>> (
57- reinterpret_cast <DataType_*>(
58- const_cast <data_t *>(kv_nope.data <data_t >())),
59- reinterpret_cast <DataType_*>(
60- const_cast <data_t *>(kv_pe.data <data_t >())),
61- reinterpret_cast <DataType_*>(kv_cache->data <data_t >()),
62- block_tables.data <int >(),
63- slot_mapping.data <int64_t >(),
64- batch_id_per_token.data <int >(),
65- cu_seqlens_q.data <int >(),
66- seq_lens.data <int >(),
67- seq_lens_decoder.data <int >(),
68- max_blocks_per_seq,
69- kv_num_heads,
70- nope_size,
71- pe_size,
72- block_size,
73- elem_nums);
54+ if (cache_quant_type_str == " cache_fp8" ) {
55+ using CT = __nv_fp8_e4m3;
56+ prefill_absorb_cache_kernel<DataType_, PackSize, CT>
57+ <<<grid_size, blocksize, 0 , stream>>> (
58+ reinterpret_cast <DataType_*>(
59+ const_cast <data_t *>(kv_nope.data <data_t >())),
60+ reinterpret_cast <DataType_*>(
61+ const_cast <data_t *>(kv_pe.data <data_t >())),
62+ reinterpret_cast <CT*>(kv_cache->data <uint8_t >()),
63+ block_tables.data <int >(),
64+ slot_mapping.data <int64_t >(),
65+ batch_id_per_token.data <int >(),
66+ cu_seqlens_q.data <int >(),
67+ seq_lens.data <int >(),
68+ seq_lens_decoder.data <int >(),
69+ max_blocks_per_seq,
70+ kv_num_heads,
71+ nope_size,
72+ pe_size,
73+ block_size,
74+ elem_nums);
75+ } else if (cache_quant_type_str == " none" ) {
76+ prefill_absorb_cache_kernel<DataType_, PackSize, DataType_>
77+ <<<grid_size, blocksize, 0 , stream>>> (
78+ reinterpret_cast <DataType_*>(
79+ const_cast <data_t *>(kv_nope.data <data_t >())),
80+ reinterpret_cast <DataType_*>(
81+ const_cast <data_t *>(kv_pe.data <data_t >())),
82+ reinterpret_cast <DataType_*>(kv_cache->data <data_t >()),
83+ block_tables.data <int >(),
84+ slot_mapping.data <int64_t >(),
85+ batch_id_per_token.data <int >(),
86+ cu_seqlens_q.data <int >(),
87+ seq_lens.data <int >(),
88+ seq_lens_decoder.data <int >(),
89+ max_blocks_per_seq,
90+ kv_num_heads,
91+ nope_size,
92+ pe_size,
93+ block_size,
94+ elem_nums);
95+ } else {
96+ PD_THROW (" Unsupported cache_quant_type_str type: %s." ,
97+ cache_quant_type_str.c_str ());
98+ }
7499
75100 const char * fmt_write_cache_completed_signal_str =
76101 std::getenv (" FLAGS_fmt_write_cache_completed_signal" );
@@ -142,6 +167,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
142167 slot_mapping,
143168 kv_signal_data,
144169 stream,
170+ cache_quant_type_str,
145171 const_cast <paddle::Tensor*>(&kv_cache));
146172 }
147173 case paddle::DataType::FLOAT16: {
@@ -157,6 +183,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
157183 slot_mapping,
158184 kv_signal_data,
159185 stream,
186+ cache_quant_type_str,
160187 const_cast <paddle::Tensor*>(&kv_cache));
161188 }
162189 }
0 commit comments