77namespace rtp_llm {
88
99XQAAttnOp::XQAAttnOp (const AttentionConfigs& attn_configs): attn_configs_(attn_configs) {}
10+ XQASpecAttnOp::XQASpecAttnOp (const AttentionConfigs& attn_configs): attn_configs_(attn_configs) {}
1011
1112bool XQAAttnOp::support (torch_ext::PyAttentionInputs attn_inputs) {
13+ if (attn_inputs.is_target_verify ) {
14+ return false ;
15+ }
1216 return attn_configs_.kv_cache_dtype != KvCacheDataType::INT8 && get_sm () >= tensorrt_llm::kernels::kSM_90
1317 && supportXqa (DataType::TYPE_BF16,
1418 DataType::TYPE_BF16,
@@ -41,6 +45,51 @@ ParamsBasePtr XQAAttnOp::prepare(torch_ext::PyAttentionInputs attn_inputs) {
4145 return ParamsBasePtr (params);
4246}
4347
48+ bool XQASpecAttnOp::support (torch_ext::PyAttentionInputs attn_inputs) {
49+ if (!attn_inputs.is_target_verify || !attn_inputs.decode_cu_seqlens_d .defined ()
50+ || attn_inputs.decode_cu_seqlens_d .numel () <= 1 || attn_configs_.kv_cache_dtype != KvCacheDataType::FP8
51+ || get_sm () != tensorrt_llm::kernels::kSM_90 ) {
52+ return false ;
53+ }
54+ const auto input_type = attn_configs_.dtype == torch::kBFloat16 ? DataType::TYPE_BF16 : DataType::TYPE_FP16;
55+ const auto kv_type = DataType::TYPE_FP8_E4M3;
56+ return supportXqa (input_type,
57+ input_type,
58+ kv_type,
59+ attn_configs_.head_num / attn_configs_.kv_head_num ,
60+ attn_configs_.size_per_head ,
61+ attn_configs_.kernel_tokens_per_block );
62+ }
63+
64+ ParamsBasePtr XQASpecAttnOp::prepare (torch_ext::PyAttentionInputs attn_inputs) {
65+ XQAParamsPtr params = std::make_shared<XQAParams>();
66+ int batch_size = attn_inputs.sequence_lengths .size (0 );
67+ RTP_LLM_CHECK_WITH_INFO (attn_inputs.kv_cache_kernel_block_id_host .defined ()
68+ && attn_inputs.kv_cache_kernel_block_id_device .defined (),
69+ " decode should have kv cache block id." );
70+
71+ auto run_stream = at::cuda::getCurrentCUDAStream (at::cuda::current_device ()).stream ();
72+ bool use_fp8_fmha = attn_configs_.kv_cache_dtype == KvCacheDataType::FP8;
73+ auto trt_params = prepareTrtAttnParams (
74+ attn_configs_, attn_inputs.kv_cache_kernel_block_id_device , batch_size, use_fp8_fmha, run_stream, false );
75+ params->kv_block_array = ((TRTAttn*)trt_params.get ())->kv_block_array ;
76+ params->kv_cache_offset = ((TRTAttn*)trt_params.get ())->kv_cache_offset .clone ();
77+ params->batch_size = batch_size;
78+ params->sequence_lengths =
79+ (attn_inputs.sequence_lengths + attn_inputs.input_lengths )
80+ .to (torch::TensorOptions ().dtype (torch::kInt32 ).device (torch::kCUDA ), /* non_blocking=*/ true );
81+ params->q_cu_seqlens = attn_inputs.decode_cu_seqlens_d ;
82+ params->max_q_len = static_cast <size_t >(
83+ (attn_inputs.decode_cu_seqlens_d .slice (0 , 1 ) - attn_inputs.decode_cu_seqlens_d .slice (0 , 0 , -1 ))
84+ .max ()
85+ .item <int32_t >());
86+ params->max_seq_len =
87+ attn_inputs.input_lengths .max ().item <int32_t >() + attn_inputs.prefix_lengths .max ().item <int32_t >();
88+ params->kv_block_array .cache_type = attn_configs_.kv_cache_dtype ;
89+
90+ return ParamsBasePtr (params);
91+ }
92+
4493torch::Tensor XQAAttnOp::forward (const torch::Tensor& input,
4594 std::optional<torch_ext::LayerKVCache> kv_cache,
4695 const XQAParamsPtr& params) {
@@ -80,6 +129,49 @@ torch::Tensor XQAAttnOp::forward(const torch::Tensor& input,
80129 return output;
81130}
82131
132+ torch::Tensor XQASpecAttnOp::forward (const torch::Tensor& input,
133+ std::optional<torch_ext::LayerKVCache> kv_cache,
134+ const XQAParamsPtr& params) {
135+ const int batch_size = params->batch_size ;
136+ const int local_head_num = attn_configs_.head_num ;
137+ const int local_head_num_kv = attn_configs_.kv_head_num ;
138+ const int size_per_head = attn_configs_.size_per_head ;
139+ torch::TensorOptions options = torch::TensorOptions (input.dtype ()).device (input.device ());
140+ torch::Tensor output =
141+ torch::empty ({batch_size, static_cast <int64_t >(params->max_q_len ), local_head_num, size_per_head}, options);
142+
143+ KVBlockArray kv_block_array;
144+ if (kv_cache.has_value ()) {
145+ kv_block_array = params->kv_block_array ;
146+ kv_block_array.mPrimaryPoolPtr = kv_cache.value ().kv_cache_base .data_ptr ();
147+ if (kv_cache.value ().kv_scale_base .defined () && kv_cache.value ().kv_scale_base .numel () > 0 ) {
148+ kv_block_array.scale = kv_cache.value ().kv_scale_base .data_ptr ();
149+ }
150+ }
151+
152+ RTP_LLM_CHECK_WITH_INFO (kv_cache.has_value (), " spec decode should have kv cache." );
153+
154+ torch::Tensor xqa_input = input.contiguous ();
155+ runXqa (xqa_input.data_ptr (),
156+ input.dtype () == torch::kBFloat16 ,
157+ output.data_ptr (),
158+ local_head_num,
159+ local_head_num_kv,
160+ size_per_head,
161+ params->batch_size ,
162+ static_cast <size_t >(kv_block_array.mMaxBlocksPerSeq ),
163+ params->max_seq_len ,
164+ attn_configs_.kernel_tokens_per_block ,
165+ kv_block_array.mPrimaryPoolPtr ,
166+ reinterpret_cast <int32_t *>((KVCacheIndex*)(params->kv_cache_offset .data_ptr ())),
167+ kv_block_array.cache_type == KvCacheDataType::FP8,
168+ reinterpret_cast <uint32_t *>(params->sequence_lengths .data_ptr ()),
169+ nullptr ,
170+ params->max_q_len ,
171+ params->q_cu_seqlens .data_ptr ());
172+ return output;
173+ }
174+
83175void registerXQAAttnOp (const py::module & m) {
84176 pybind11::class_<XQAParams, std::shared_ptr<XQAParams>, rtp_llm::ParamsBase>(m, " XQAParams" )
85177 .def (pybind11::init<>())
@@ -93,6 +185,11 @@ void registerXQAAttnOp(const py::module& m) {
93185 .def (" support" , &XQAAttnOp::support, py::arg (" attn_inputs" ).noconvert ())
94186 .def (" prepare" , &XQAAttnOp::prepare, py::arg (" attn_inputs" ))
95187 .def (" forward" , &XQAAttnOp::forward, py::arg (" input" ), py::arg (" kv_cache" ), py::arg (" params" ));
188+ pybind11::class_<XQASpecAttnOp>(m, " XQASpecAttnOp" )
189+ .def (pybind11::init<const AttentionConfigs&>(), py::arg (" attn_configs" ))
190+ .def (" support" , &XQASpecAttnOp::support, py::arg (" attn_inputs" ).noconvert ())
191+ .def (" prepare" , &XQASpecAttnOp::prepare, py::arg (" attn_inputs" ))
192+ .def (" forward" , &XQASpecAttnOp::forward, py::arg (" input" ), py::arg (" kv_cache" ), py::arg (" params" ));
96193}
97194
98195} // namespace rtp_llm
0 commit comments