Skip to content

Commit 4f14845

Browse files
support cfp8 in blackwell mla (#7876)
1 parent 4402396 commit 4f14845

3 files changed

Lines changed: 92 additions & 32 deletions

File tree

custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
3030
const paddle::Tensor& slot_mapping,
3131
const paddle::optional<paddle::Tensor>& kv_signal_data,
3232
cudaStream_t& stream,
33+
const std::string& cache_quant_type_str,
3334
paddle::Tensor* kv_cache) {
3435
typedef PDTraits<T> traits_;
3536
typedef typename traits_::DataType DataType_;
@@ -50,27 +51,51 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
5051
int grid_size = 1;
5152
GetNumBlocks<128>(pack_num, &grid_size);
5253

53-
using CT = DataType_;
54-
55-
prefill_absorb_cache_kernel<DataType_, PackSize, CT>
56-
<<<grid_size, blocksize, 0, stream>>>(
57-
reinterpret_cast<DataType_*>(
58-
const_cast<data_t*>(kv_nope.data<data_t>())),
59-
reinterpret_cast<DataType_*>(
60-
const_cast<data_t*>(kv_pe.data<data_t>())),
61-
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
62-
block_tables.data<int>(),
63-
slot_mapping.data<int64_t>(),
64-
batch_id_per_token.data<int>(),
65-
cu_seqlens_q.data<int>(),
66-
seq_lens.data<int>(),
67-
seq_lens_decoder.data<int>(),
68-
max_blocks_per_seq,
69-
kv_num_heads,
70-
nope_size,
71-
pe_size,
72-
block_size,
73-
elem_nums);
54+
if (cache_quant_type_str == "cache_fp8") {
55+
using CT = __nv_fp8_e4m3;
56+
prefill_absorb_cache_kernel<DataType_, PackSize, CT>
57+
<<<grid_size, blocksize, 0, stream>>>(
58+
reinterpret_cast<DataType_*>(
59+
const_cast<data_t*>(kv_nope.data<data_t>())),
60+
reinterpret_cast<DataType_*>(
61+
const_cast<data_t*>(kv_pe.data<data_t>())),
62+
reinterpret_cast<CT*>(kv_cache->data<uint8_t>()),
63+
block_tables.data<int>(),
64+
slot_mapping.data<int64_t>(),
65+
batch_id_per_token.data<int>(),
66+
cu_seqlens_q.data<int>(),
67+
seq_lens.data<int>(),
68+
seq_lens_decoder.data<int>(),
69+
max_blocks_per_seq,
70+
kv_num_heads,
71+
nope_size,
72+
pe_size,
73+
block_size,
74+
elem_nums);
75+
} else if (cache_quant_type_str == "none") {
76+
prefill_absorb_cache_kernel<DataType_, PackSize, DataType_>
77+
<<<grid_size, blocksize, 0, stream>>>(
78+
reinterpret_cast<DataType_*>(
79+
const_cast<data_t*>(kv_nope.data<data_t>())),
80+
reinterpret_cast<DataType_*>(
81+
const_cast<data_t*>(kv_pe.data<data_t>())),
82+
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
83+
block_tables.data<int>(),
84+
slot_mapping.data<int64_t>(),
85+
batch_id_per_token.data<int>(),
86+
cu_seqlens_q.data<int>(),
87+
seq_lens.data<int>(),
88+
seq_lens_decoder.data<int>(),
89+
max_blocks_per_seq,
90+
kv_num_heads,
91+
nope_size,
92+
pe_size,
93+
block_size,
94+
elem_nums);
95+
} else {
96+
PD_THROW("Unsupported cache_quant_type_str type: %s.",
97+
cache_quant_type_str.c_str());
98+
}
7499

75100
const char* fmt_write_cache_completed_signal_str =
76101
std::getenv("FLAGS_fmt_write_cache_completed_signal");
@@ -142,6 +167,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
142167
slot_mapping,
143168
kv_signal_data,
144169
stream,
170+
cache_quant_type_str,
145171
const_cast<paddle::Tensor*>(&kv_cache));
146172
}
147173
case paddle::DataType::FLOAT16: {
@@ -157,6 +183,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
157183
slot_mapping,
158184
kv_signal_data,
159185
stream,
186+
cache_quant_type_str,
160187
const_cast<paddle::Tensor*>(&kv_cache));
161188
}
162189
}

custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@ __global__ void prefill_absorb_cache_kernel(
186186
const uint32_t elem_cnt) {
187187
using LoadT = AlignedVector<T, VecSize>;
188188
LoadT src_vec;
189+
using StoreT = AlignedVector<CT, VecSize>;
190+
StoreT dst_vec;
189191

190192
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
191193
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
@@ -227,7 +229,20 @@ __global__ void prefill_absorb_cache_kernel(
227229
hi * block_size * all_size + block_offset * all_size + h_bias;
228230
const uint32_t ori_idx = token_idx * nope_hidden_size + inner_bias;
229231
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
230-
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
232+
233+
if constexpr (std::is_same_v<CT, __nv_fp8_e4m3>) {
234+
for (int i = 0; i < VecSize; i++) {
235+
float quant_value = (float)(src_vec[i]);
236+
quant_value = quant_value > 448.0f ? 448.0f : quant_value;
237+
quant_value = quant_value < -448.0f ? -448.0f : quant_value;
238+
dst_vec[i] = static_cast<__nv_fp8_e4m3>(quant_value);
239+
}
240+
241+
Store<CT, VecSize>(dst_vec, &kv_cache[tgt_idx]);
242+
} else {
243+
Store<CT, VecSize>(src_vec, &kv_cache[tgt_idx]);
244+
}
245+
231246
} else {
232247
const uint32_t inner_bias = bias - nope_hidden_size;
233248
const uint32_t hi = inner_bias / pe_size;
@@ -238,7 +253,18 @@ __global__ void prefill_absorb_cache_kernel(
238253
h_bias;
239254
const uint32_t ori_idx = token_idx * pe_hidden_size + inner_bias;
240255
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
241-
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
256+
257+
if constexpr (std::is_same_v<CT, __nv_fp8_e4m3>) {
258+
for (int i = 0; i < VecSize; i++) {
259+
float quant_value = (float)(src_vec[i]);
260+
quant_value = quant_value > 448.0f ? 448.0f : quant_value;
261+
quant_value = quant_value < -448.0f ? -448.0f : quant_value;
262+
dst_vec[i] = static_cast<__nv_fp8_e4m3>(quant_value);
263+
}
264+
Store<CT, VecSize>(dst_vec, &kv_cache[tgt_idx]);
265+
} else {
266+
Store<CT, VecSize>(src_vec, &kv_cache[tgt_idx]);
267+
}
242268
}
243269
}
244270
}

fastdeploy/model_executor/layers/attention/mla_attention_backend.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,7 @@ def forward_mixed(
860860
metadata.block_tables,
861861
forward_meta.slot_mapping,
862862
metadata.kv_signal_data_list[layer.layer_id],
863-
"none",
863+
getattr(layer, "cache_quant_type_str", "none"),
864864
)
865865

866866
# Prefill branch: k is not None
@@ -998,11 +998,13 @@ def forward_mixed(
998998
@staticmethod
999999
def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_softmax_scale):
10001000

1001-
# decoder_q = decoder_q.cast(paddle.float8_e4m3fn)
1002-
# latent_cache = latent_cache.cast(paddle.float8_e4m3fn)
1001+
assert latent_cache.dtype in [paddle.bfloat16, paddle.uint8], latent_cache.dtype
1002+
use_fp8_cache_kv = latent_cache.dtype == paddle.uint8
1003+
if use_fp8_cache_kv:
1004+
decoder_q = decoder_q.cast(paddle.float8_e4m3fn)
1005+
latent_cache = latent_cache.view(paddle.float8_e4m3fn)
10031006

10041007
assert decoder_q.dtype == latent_cache.dtype
1005-
q_dtype = decoder_q.dtype
10061008

10071009
page_size = latent_cache.shape[2]
10081010
q_num_heads = decoder_q.shape[2]
@@ -1049,11 +1051,16 @@ def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_soft
10491051
softmax_scale = attn_softmax_scale
10501052
output_scale = 1.0
10511053

1052-
from mla_decode_fp16 import BlackwellMultiHeadLatentAttentionForwardFP16
1053-
1054-
# from mla_decode_fp8 import BlackwellMultiHeadLatentAttentionForwardFP8
1054+
if use_fp8_cache_kv:
1055+
from mla_decode_fp8 import (
1056+
BlackwellMultiHeadLatentAttentionForwardFP8 as kernel,
1057+
)
1058+
else:
1059+
from mla_decode_fp16 import (
1060+
BlackwellMultiHeadLatentAttentionForwardFP16 as kernel,
1061+
)
10551062

1056-
mla = BlackwellMultiHeadLatentAttentionForwardFP16(
1063+
mla = kernel(
10571064
cutlass.Float32,
10581065
cutlass.Float32,
10591066
mma_qk_tiler_mn=(128, 128),
@@ -1108,7 +1115,7 @@ def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_soft
11081115
stream,
11091116
)
11101117

1111-
if q_dtype == paddle.float8_e4m3fn:
1118+
if use_fp8_cache_kv:
11121119
paddle_output = paddle_output.cast("bfloat16")
11131120
return paddle_output
11141121

0 commit comments

Comments
 (0)