Skip to content

Commit 68ff089

Browse files
committed
feat - support xqa spec
1 parent 4fcdfe2 commit 68ff089

7 files changed

Lines changed: 443 additions & 4 deletions

File tree

rtp_llm/cpp/cuda/ops/CudaXqa.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ struct XQAParams: public ParamsBase {
1313
size_t max_seq_len;
1414
torch::Tensor kv_cache_offset;
1515
torch::Tensor sequence_lengths;
16+
torch::Tensor q_cu_seqlens;
17+
size_t max_q_len{1};
1618
};
1719

1820
using XQAParamsPtr = std::shared_ptr<XQAParams>;

rtp_llm/models_py/bindings/cuda/XQAAttnOp.cc

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,12 @@
77
namespace rtp_llm {
88

99
XQAAttnOp::XQAAttnOp(const AttentionConfigs& attn_configs): attn_configs_(attn_configs) {}
10+
XQASpecAttnOp::XQASpecAttnOp(const AttentionConfigs& attn_configs): attn_configs_(attn_configs) {}
1011

1112
bool XQAAttnOp::support(torch_ext::PyAttentionInputs attn_inputs) {
13+
if (attn_inputs.is_target_verify) {
14+
return false;
15+
}
1216
return attn_configs_.kv_cache_dtype != KvCacheDataType::INT8 && get_sm() >= tensorrt_llm::kernels::kSM_90
1317
&& supportXqa(DataType::TYPE_BF16,
1418
DataType::TYPE_BF16,
@@ -41,6 +45,51 @@ ParamsBasePtr XQAAttnOp::prepare(torch_ext::PyAttentionInputs attn_inputs) {
4145
return ParamsBasePtr(params);
4246
}
4347

48+
bool XQASpecAttnOp::support(torch_ext::PyAttentionInputs attn_inputs) {
49+
if (!attn_inputs.is_target_verify || !attn_inputs.decode_cu_seqlens_d.defined()
50+
|| attn_inputs.decode_cu_seqlens_d.numel() <= 1 || attn_configs_.kv_cache_dtype != KvCacheDataType::FP8
51+
|| get_sm() != tensorrt_llm::kernels::kSM_90) {
52+
return false;
53+
}
54+
const auto input_type = attn_configs_.dtype == torch::kBFloat16 ? DataType::TYPE_BF16 : DataType::TYPE_FP16;
55+
const auto kv_type = DataType::TYPE_FP8_E4M3;
56+
return supportXqa(input_type,
57+
input_type,
58+
kv_type,
59+
attn_configs_.head_num / attn_configs_.kv_head_num,
60+
attn_configs_.size_per_head,
61+
attn_configs_.kernel_tokens_per_block);
62+
}
63+
64+
ParamsBasePtr XQASpecAttnOp::prepare(torch_ext::PyAttentionInputs attn_inputs) {
65+
XQAParamsPtr params = std::make_shared<XQAParams>();
66+
int batch_size = attn_inputs.sequence_lengths.size(0);
67+
RTP_LLM_CHECK_WITH_INFO(attn_inputs.kv_cache_kernel_block_id_host.defined()
68+
&& attn_inputs.kv_cache_kernel_block_id_device.defined(),
69+
"decode should have kv cache block id.");
70+
71+
auto run_stream = at::cuda::getCurrentCUDAStream(at::cuda::current_device()).stream();
72+
bool use_fp8_fmha = attn_configs_.kv_cache_dtype == KvCacheDataType::FP8;
73+
auto trt_params = prepareTrtAttnParams(
74+
attn_configs_, attn_inputs.kv_cache_kernel_block_id_device, batch_size, use_fp8_fmha, run_stream, false);
75+
params->kv_block_array = ((TRTAttn*)trt_params.get())->kv_block_array;
76+
params->kv_cache_offset = ((TRTAttn*)trt_params.get())->kv_cache_offset.clone();
77+
params->batch_size = batch_size;
78+
params->sequence_lengths =
79+
(attn_inputs.sequence_lengths + attn_inputs.input_lengths)
80+
.to(torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA), /*non_blocking=*/true);
81+
params->q_cu_seqlens = attn_inputs.decode_cu_seqlens_d;
82+
params->max_q_len = static_cast<size_t>(
83+
(attn_inputs.decode_cu_seqlens_d.slice(0, 1) - attn_inputs.decode_cu_seqlens_d.slice(0, 0, -1))
84+
.max()
85+
.item<int32_t>());
86+
params->max_seq_len =
87+
attn_inputs.input_lengths.max().item<int32_t>() + attn_inputs.prefix_lengths.max().item<int32_t>();
88+
params->kv_block_array.cache_type = attn_configs_.kv_cache_dtype;
89+
90+
return ParamsBasePtr(params);
91+
}
92+
4493
torch::Tensor XQAAttnOp::forward(const torch::Tensor& input,
4594
std::optional<torch_ext::LayerKVCache> kv_cache,
4695
const XQAParamsPtr& params) {
@@ -80,6 +129,49 @@ torch::Tensor XQAAttnOp::forward(const torch::Tensor& input,
80129
return output;
81130
}
82131

132+
torch::Tensor XQASpecAttnOp::forward(const torch::Tensor& input,
133+
std::optional<torch_ext::LayerKVCache> kv_cache,
134+
const XQAParamsPtr& params) {
135+
const int batch_size = params->batch_size;
136+
const int local_head_num = attn_configs_.head_num;
137+
const int local_head_num_kv = attn_configs_.kv_head_num;
138+
const int size_per_head = attn_configs_.size_per_head;
139+
torch::TensorOptions options = torch::TensorOptions(input.dtype()).device(input.device());
140+
torch::Tensor output =
141+
torch::empty({batch_size, static_cast<int64_t>(params->max_q_len), local_head_num, size_per_head}, options);
142+
143+
KVBlockArray kv_block_array;
144+
if (kv_cache.has_value()) {
145+
kv_block_array = params->kv_block_array;
146+
kv_block_array.mPrimaryPoolPtr = kv_cache.value().kv_cache_base.data_ptr();
147+
if (kv_cache.value().kv_scale_base.defined() && kv_cache.value().kv_scale_base.numel() > 0) {
148+
kv_block_array.scale = kv_cache.value().kv_scale_base.data_ptr();
149+
}
150+
}
151+
152+
RTP_LLM_CHECK_WITH_INFO(kv_cache.has_value(), "spec decode should have kv cache.");
153+
154+
torch::Tensor xqa_input = input.contiguous();
155+
runXqa(xqa_input.data_ptr(),
156+
input.dtype() == torch::kBFloat16,
157+
output.data_ptr(),
158+
local_head_num,
159+
local_head_num_kv,
160+
size_per_head,
161+
params->batch_size,
162+
static_cast<size_t>(kv_block_array.mMaxBlocksPerSeq),
163+
params->max_seq_len,
164+
attn_configs_.kernel_tokens_per_block,
165+
kv_block_array.mPrimaryPoolPtr,
166+
reinterpret_cast<int32_t*>((KVCacheIndex*)(params->kv_cache_offset.data_ptr())),
167+
kv_block_array.cache_type == KvCacheDataType::FP8,
168+
reinterpret_cast<uint32_t*>(params->sequence_lengths.data_ptr()),
169+
nullptr,
170+
params->max_q_len,
171+
params->q_cu_seqlens.data_ptr());
172+
return output;
173+
}
174+
83175
void registerXQAAttnOp(const py::module& m) {
84176
pybind11::class_<XQAParams, std::shared_ptr<XQAParams>, rtp_llm::ParamsBase>(m, "XQAParams")
85177
.def(pybind11::init<>())
@@ -93,6 +185,11 @@ void registerXQAAttnOp(const py::module& m) {
93185
.def("support", &XQAAttnOp::support, py::arg("attn_inputs").noconvert())
94186
.def("prepare", &XQAAttnOp::prepare, py::arg("attn_inputs"))
95187
.def("forward", &XQAAttnOp::forward, py::arg("input"), py::arg("kv_cache"), py::arg("params"));
188+
pybind11::class_<XQASpecAttnOp>(m, "XQASpecAttnOp")
189+
.def(pybind11::init<const AttentionConfigs&>(), py::arg("attn_configs"))
190+
.def("support", &XQASpecAttnOp::support, py::arg("attn_inputs").noconvert())
191+
.def("prepare", &XQASpecAttnOp::prepare, py::arg("attn_inputs"))
192+
.def("forward", &XQASpecAttnOp::forward, py::arg("input"), py::arg("kv_cache"), py::arg("params"));
96193
}
97194

98195
} // namespace rtp_llm

rtp_llm/models_py/bindings/cuda/XQAAttnOp.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,20 @@ class XQAAttnOp {
2424
AttentionConfigs attn_configs_;
2525
};
2626

27+
class XQASpecAttnOp {
28+
public:
29+
XQASpecAttnOp(const AttentionConfigs& attn_configs);
30+
bool support(torch_ext::PyAttentionInputs attn_inputs);
31+
32+
ParamsBasePtr prepare(torch_ext::PyAttentionInputs attn_inputs);
33+
34+
torch::Tensor
35+
forward(const torch::Tensor& input, std::optional<torch_ext::LayerKVCache> kv_cache, const XQAParamsPtr& params);
36+
37+
protected:
38+
AttentionConfigs attn_configs_;
39+
};
40+
2741
void registerXQAAttnOp(const py::module& m);
2842

2943
} // namespace rtp_llm

rtp_llm/models_py/modules/factory/attention/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
)
7070
from rtp_llm.models_py.modules.factory.attention.cuda_impl.xqa import (
7171
get_xqa_impl,
72+
XQASpecImpl,
7273
)
7374

7475
PREFILL_MHA_IMPS.extend(
@@ -77,6 +78,7 @@
7778
HeadWisePrefillImpl,
7879
FlashInferTRTLLMSpecDecodeImpl,
7980
FlashInferTRTLLMPrefillImpl,
81+
XQASpecImpl,
8082
TRTMHAImpl,
8183
PyFlashinferPrefillImpl,
8284
PyFlashinferPagedPrefillImpl,

rtp_llm/models_py/modules/factory/attention/cuda_impl/test/base_attention_test.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import torch
77

8-
from rtp_llm.ops import AttentionConfigs, ParallelismConfig
8+
from rtp_llm.ops import AttentionConfigs, KvCacheDataType, ParallelismConfig
99
from rtp_llm.ops.compute_ops import LayerKVCache, PyAttentionInputs, get_typemeta
1010

1111
logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -72,6 +72,7 @@ def _create_config(
7272
seq_size_per_block: int = 64,
7373
tp_size: int = 1,
7474
data_type: str = "fp16",
75+
kv_cache_dtype: KvCacheDataType = KvCacheDataType.BASE,
7576
) -> TestConfig:
7677
"""Helper to create a test config"""
7778
attn_configs = AttentionConfigs()
@@ -89,6 +90,7 @@ def _create_config(
8990
"bf16": torch.bfloat16,
9091
}
9192
attn_configs.dtype = dtype_map.get(data_type, torch.float16)
93+
attn_configs.kv_cache_dtype = kv_cache_dtype
9294

9395
parallelism_config = ParallelismConfig()
9496
parallelism_config.tp_size = tp_size
@@ -263,15 +265,24 @@ def _create_kv_cache(
263265

264266
# Create combined KV cache with shape [total_blocks, 2, num_kv_heads, seq_size_per_block, head_dim]
265267
# where dim=1, index=0 is K and index=1 is V
268+
sample_dtype = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
266269
kv_cache_combined = torch.randn(
267270
total_blocks,
268271
2, # K and V
269272
num_kv_heads,
270273
seq_size_per_block,
271274
head_dim,
272-
dtype=dtype,
275+
dtype=sample_dtype,
273276
device=self.device,
274277
)
278+
if dtype == torch.float8_e4m3fn:
279+
kv_cache_combined = kv_cache_combined.to(dtype)
280+
kv_cache.kv_scale_base = torch.ones(
281+
total_blocks,
282+
num_kv_heads * seq_size_per_block,
283+
dtype=torch.float32,
284+
device=self.device,
285+
)
275286

276287
kv_cache.kv_cache_base = kv_cache_combined
277288

0 commit comments

Comments
 (0)