Skip to content

Commit 46ef1af

Browse files
committed
[TRTLLM-12339][feat] enable TRTLLM cross attention backend
Signed-off-by: Guiju Zhang <guijuz@nvidia.com>
1 parent 4f46653 commit 46ef1af

6 files changed

Lines changed: 122 additions & 17 deletions

File tree

cpp/tensorrt_llm/nanobind/thop/bindings.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ void initBindings(nb::module_& m)
174174
nb::arg("sage_attn_num_elts_per_blk_k") = 0, nb::arg("sage_attn_num_elts_per_blk_v") = 0,
175175
nb::arg("sage_attn_qk_int8") = false, nb::arg("num_contexts") = 0, nb::arg("num_ctx_tokens") = 0,
176176
nb::arg("trtllm_gen_jit_warmup") = false, nb::arg("compressed_kv_cache_pool_ptr") = std::nullopt,
177+
nb::arg("is_cross") = false, nb::arg("cross_kv") = std::nullopt,
178+
nb::arg("relative_attention_bias") = std::nullopt, nb::arg("relative_attention_max_distance") = 0,
177179
nb::arg("spec_decoding_target_max_draft_tokens") = std::nullopt, "Multi-head attention operation",
178180
nb::call_guard<nb::gil_scoped_release>());
179181

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,8 @@ class RunnerBase
376376
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer,
377377
std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata,
378378
std::optional<torch::Tensor> flash_mla_num_splits, bool trtllm_gen_jit_warmup,
379-
std::optional<int64_t> compressed_kv_cache_pool_ptr) const
379+
std::optional<int64_t> compressed_kv_cache_pool_ptr, bool const is_cross, std::optional<torch::Tensor> cross_kv,
380+
std::optional<torch::Tensor> relative_attention_bias) const
380381
= 0;
381382
};
382383

@@ -444,7 +445,8 @@ class Runner : public RunnerBase
444445
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer,
445446
std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata,
446447
std::optional<torch::Tensor> flash_mla_num_splits, bool trtllm_gen_jit_warmup,
447-
std::optional<int64_t> compressed_kv_cache_pool_ptr) const override
448+
std::optional<int64_t> compressed_kv_cache_pool_ptr, bool const is_cross, std::optional<torch::Tensor> cross_kv,
449+
std::optional<torch::Tensor> relative_attention_bias) const override
448450
{
449451
auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device());
450452
T* attention_input = static_cast<T*>(qkv_or_q.slice(0, token_offset).data_ptr());
@@ -677,6 +679,20 @@ class Runner : public RunnerBase
677679
attention_sinks.value().dtype() == torch::kFloat32, "Expected attention_sinks to have float dtype");
678680
attention_sinks_ptr = attention_sinks.value().data_ptr<float>();
679681
}
682+
T const* relative_attention_bias_ptr = nullptr;
683+
int relative_attention_bias_stride = 0;
684+
if (relative_attention_bias.has_value())
685+
{
686+
auto const& relative_attention_bias_tensor = relative_attention_bias.value();
687+
TORCH_CHECK(relative_attention_bias_tensor.dim() == 2 || relative_attention_bias_tensor.dim() == 3,
688+
"relative_attention_bias must be [num_heads, num_buckets] for implicit mode or "
689+
"[num_heads, max_seq_len, max_seq_len] for explicit mode");
690+
TORCH_CHECK(relative_attention_bias_tensor.is_contiguous(), "relative_attention_bias must be contiguous");
691+
TORCH_CHECK(relative_attention_bias_tensor.scalar_type() == qkv_or_q.scalar_type(),
692+
"relative_attention_bias dtype must match attention input dtype");
693+
relative_attention_bias_ptr = static_cast<T const*>(relative_attention_bias_tensor.data_ptr());
694+
relative_attention_bias_stride = static_cast<int>(relative_attention_bias_tensor.size(1));
695+
}
680696

681697
// Prepare sparse attention parameters
682698
op.mRuntimeSparseAttentionParams.sparse_kv_indices
@@ -723,6 +739,8 @@ class Runner : public RunnerBase
723739
common_enqueue_params.attention_sinks = attention_sinks_ptr;
724740
common_enqueue_params.rotary_inv_freq = rotary_inv_freq_ptr;
725741
common_enqueue_params.rotary_cos_sin = rotary_cos_sin_ptr;
742+
common_enqueue_params.relative_attention_bias = relative_attention_bias_ptr;
743+
common_enqueue_params.relative_attention_bias_stride = relative_attention_bias_stride;
726744
common_enqueue_params.max_past_kv_length = max_past_kv_length;
727745
common_enqueue_params.max_attention_window_size = max_attention_window_size;
728746
common_enqueue_params.cyclic_attention_window_size = cyclic_attention_window_size;
@@ -747,6 +765,13 @@ class Runner : public RunnerBase
747765
common_enqueue_params.host_context_lengths = host_context_lengths.data_ptr<int32_t>();
748766
common_enqueue_params.workspace = workspace_ptr;
749767
common_enqueue_params.trtllm_gen_jit_warmup = trtllm_gen_jit_warmup;
768+
if (is_cross)
769+
{
770+
// For cross attention, the KV (encoder) sequence lengths are passed in via
771+
// `sequence_length` (already sliced into `sequence_lengths_ptr`), so reuse
772+
// it directly instead of a redundant `encoder_input_lengths` tensor.
773+
common_enqueue_params.encoder_input_lengths = sequence_lengths_ptr;
774+
}
750775
if (softmax_stats_tensor.has_value())
751776
{
752777
TLLM_CHECK_WITH_INFO(softmax_stats_tensor.value().scalar_type() == at::ScalarType::Float,
@@ -807,6 +832,14 @@ class Runner : public RunnerBase
807832
{
808833
enqueue_params.v_stride_in_bytes = v->strides()[0] * v->element_size();
809834
}
835+
if (is_cross && cross_kv.has_value())
836+
{
837+
auto const& cross_kv_tensor = cross_kv.value();
838+
enqueue_params.cross_kv = static_cast<T const*>(cross_kv_tensor.data_ptr());
839+
enqueue_params.num_encoder_tokens = static_cast<int32_t>(cross_kv_tensor.size(0));
840+
enqueue_params.cross_kv_length
841+
= host_past_key_value_lengths.slice(0, seq_offset, seq_offset + num_seqs).max().item<int32_t>();
842+
}
810843

811844
if (op.isMLAEnabled())
812845
{
@@ -993,7 +1026,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
9931026
std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata, std::optional<torch::Tensor> flash_mla_num_splits,
9941027
int64_t sage_attn_num_elts_per_blk_q, int64_t sage_attn_num_elts_per_blk_k, int64_t sage_attn_num_elts_per_blk_v,
9951028
bool sage_attn_qk_int8, int64_t num_contexts, int64_t num_ctx_tokens, bool trtllm_gen_jit_warmup,
996-
std::optional<int64_t> compressed_kv_cache_pool_ptr, std::optional<int64_t> spec_decoding_target_max_draft_tokens)
1029+
std::optional<int64_t> compressed_kv_cache_pool_ptr, bool const is_cross, std::optional<torch::Tensor> cross_kv,
1030+
std::optional<torch::Tensor> relative_attention_bias, int64_t relative_attention_max_distance,
1031+
std::optional<int64_t> spec_decoding_target_max_draft_tokens)
9971032
{
9981033
TLLM_LOG_TRACE("Attention op starts at layer %d", local_layer_idx);
9991034
// Use these tensors to infer if the attention is using KV cache
@@ -1002,16 +1037,17 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
10021037

10031038
bool const use_sage_attn
10041039
= sage_attn_num_elts_per_blk_q > 0 || sage_attn_num_elts_per_blk_k > 0 || sage_attn_num_elts_per_blk_v > 0;
1005-
TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv || use_sage_attn,
1006-
"Context attention only allows these non-MLA cases: fused QKV; separate QKV with SageAttention");
1007-
TLLM_CHECK_WITH_INFO(update_kv_cache, "KV cache update cannot be disabled now");
1040+
TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv || use_sage_attn || is_cross,
1041+
"For non-MLA, non-cross, non-SageAttention attention, only fused QKV is supported now.");
1042+
TLLM_CHECK_WITH_INFO(
1043+
update_kv_cache || is_cross, "KV cache update cannot be disabled now (except for cross attention).");
10081044
auto qkv_or_q = q;
10091045
if (is_fused_qkv)
10101046
{
10111047
TLLM_CHECK_WITH_INFO(!k.has_value(), "The k tensor should be null if using fused QKV");
10121048
TLLM_CHECK_WITH_INFO(!v.has_value(), "The v tensor should be null if using fused QKV");
10131049
}
1014-
if (!is_fused_qkv && update_kv_cache)
1050+
if (!is_fused_qkv && update_kv_cache && !is_cross)
10151051
{
10161052
TLLM_CHECK_WITH_INFO(k.has_value(), "The k tensor should be provided if updating KV cache with unfused K/V");
10171053
TLLM_CHECK_WITH_INFO(v.has_value(), "The v tensor should be provided if updating KV cache with unfused K/V");
@@ -1094,6 +1130,20 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
10941130
op->mQScaling = q_scaling;
10951131
op->mPositionEmbeddingType
10961132
= static_cast<tensorrt_llm::kernels::PositionEmbeddingType>(int8_t(position_embedding_type));
1133+
if (relative_attention_bias.has_value())
1134+
{
1135+
auto const relative_attention_bias_dim = relative_attention_bias.value().dim();
1136+
TORCH_CHECK(relative_attention_bias_dim == 2 || relative_attention_bias_dim == 3,
1137+
"relative_attention_bias must be [num_heads, num_buckets] for implicit mode or "
1138+
"[num_heads, max_seq_len, max_seq_len] for explicit mode");
1139+
TORCH_CHECK(relative_attention_bias_dim != 2 || relative_attention_max_distance > 0,
1140+
"relative_attention_max_distance must be positive when relative_attention_bias is a bucket table");
1141+
TORCH_CHECK(relative_attention_bias_dim != 3 || relative_attention_max_distance == 0,
1142+
"relative_attention_max_distance must be 0 when relative_attention_bias is precomputed");
1143+
TLLM_CHECK_WITH_INFO(op->mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kRELATIVE,
1144+
"relative_attention_bias requires position_embedding_type to be relative.");
1145+
op->mMaxDistance = static_cast<int>(relative_attention_max_distance);
1146+
}
10971147
op->mRotaryEmbeddingDim = rope_dim;
10981148
op->mRotaryEmbeddingBase = rope_base;
10991149
op->mRotaryEmbeddingScaleType = static_cast<tensorrt_llm::kernels::RotaryScalingType>(int8_t(rope_scale_type));
@@ -1111,6 +1161,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
11111161
op->mSageAttnQkInt8 = sage_attn_qk_int8;
11121162
op->mFP8AttenOutput = is_fp8_out;
11131163
op->mPagedContextFMHA = use_paged_context_fmha;
1164+
op->mCrossAttention = is_cross;
11141165

11151166
op->mAttentionChunkSize = attention_chunk_size;
11161167
op->mSkipSoftmaxThresholdScaleFactorPrefill
@@ -1275,7 +1326,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
12751326
sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets, sparse_attn_indices_block_size,
12761327
num_sparse_topk_value, sparse_mla_topk_lens, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter,
12771328
mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer, flash_mla_tile_scheduler_metadata, flash_mla_num_splits,
1278-
trtllm_gen_jit_warmup, compressed_kv_cache_pool_ptr);
1329+
trtllm_gen_jit_warmup, compressed_kv_cache_pool_ptr, is_cross, cross_kv, relative_attention_bias);
12791330
}
12801331

12811332
if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
@@ -1297,7 +1348,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
12971348
sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets, sparse_attn_indices_block_size,
12981349
num_sparse_topk_value, sparse_mla_topk_lens, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter,
12991350
mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer, flash_mla_tile_scheduler_metadata, flash_mla_num_splits,
1300-
trtllm_gen_jit_warmup, compressed_kv_cache_pool_ptr);
1351+
trtllm_gen_jit_warmup, compressed_kv_cache_pool_ptr, is_cross, cross_kv, relative_attention_bias);
13011352
}
13021353

13031354
TLLM_LOG_TRACE("Attention op stops at layer %d", local_layer_idx);

cpp/tensorrt_llm/thop/attentionOp.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
9090
std::optional<torch::Tensor> flash_mla_num_splits = std::nullopt, int64_t sage_attn_num_elts_per_blk_q = 0,
9191
int64_t sage_attn_num_elts_per_blk_k = 0, int64_t sage_attn_num_elts_per_blk_v = 0, bool sage_attn_qk_int8 = false,
9292
int64_t num_contexts = 0, int64_t num_ctx_tokens = 0, bool trtllm_gen_jit_warmup = false,
93-
std::optional<int64_t> compressed_kv_cache_pool_ptr = std::nullopt,
93+
std::optional<int64_t> compressed_kv_cache_pool_ptr = std::nullopt, bool const is_cross = false,
94+
std::optional<torch::Tensor> cross_kv = std::nullopt,
95+
std::optional<torch::Tensor> relative_attention_bias = std::nullopt, int64_t relative_attention_max_distance = 0,
9496
std::optional<int64_t> spec_decoding_target_max_draft_tokens = std::nullopt);
9597

9698
struct KvCachePoolPointers

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,12 @@ def seq_lens_kv(self, value: Optional[torch.Tensor]):
263263
# The model executor sets seqlens to None initially.
264264
if self._seq_lens_kv is not None:
265265
self._seq_lens_kv = maybe_pin_memory(self._seq_lens_kv)
266-
self._seq_lens_kv_cuda = self._seq_lens_kv.cuda(non_blocking=True)
266+
if self.is_cuda_graph and self._seq_lens_kv_cuda is not None:
267+
self._seq_lens_kv_cuda.copy_(self._seq_lens_kv,
268+
non_blocking=True)
269+
else:
270+
self._seq_lens_kv_cuda = self._seq_lens_kv.cuda(
271+
non_blocking=True)
267272

268273
@property
269274
def seq_lens_kv_cuda(self):
@@ -747,6 +752,9 @@ class AttentionForwardArgs:
747752
attention_window_size: Optional[int] = None
748753
attention_mask_data: Optional[torch.Tensor] = None
749754
attention_sinks: Optional[torch.Tensor] = None
755+
relative_attention_bias: Optional[torch.Tensor] = None
756+
relative_attention_max_distance: int = 0
757+
cross_kv: Optional[torch.Tensor] = None
750758

751759
latent_cache: Optional[torch.Tensor] = None
752760
q_pe: Optional[torch.Tensor] = None

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,16 @@ def max_context_length(self) -> int:
190190
"""
191191
return min(self.max_seq_len, self.max_num_tokens)
192192

193+
@property
194+
def effective_beam_width(self) -> int:
195+
"""Beam width visible to the kernel.
196+
197+
Cross-attention metadata is already expanded to one row per decoder
198+
beam, and all beams read the same encoder K/V cache. Keep kernel beam
199+
indirection disabled for that path.
200+
"""
201+
return 1 if self.is_cross else self.beam_width
202+
193203
@property
194204
def max_seq_len(self) -> int:
195205
"""
@@ -1439,6 +1449,26 @@ def _run(
14391449
metadata: TrtllmAttentionMetadata,
14401450
forward_args: AttentionForwardArgs,
14411451
) -> None:
1452+
if metadata.is_cross:
1453+
if k is not None and v is not None:
1454+
k_flat = k.contiguous().view(k.shape[0], -1)
1455+
v_flat = v.contiguous().view(v.shape[0], -1)
1456+
forward_args.cross_kv = torch.cat([k_flat, v_flat],
1457+
dim=1).contiguous()
1458+
1459+
q_hidden_size = self.num_heads * self.head_dim
1460+
kv_hidden_size = self.num_kv_heads * self.head_dim
1461+
qkv_hidden_size = q_hidden_size + 2 * kv_hidden_size
1462+
if q.shape[1] == q_hidden_size:
1463+
fused_q = q.new_zeros((q.shape[0], qkv_hidden_size))
1464+
fused_q[:, :q_hidden_size].copy_(q)
1465+
q = fused_q
1466+
else:
1467+
assert q.shape[1] == qkv_hidden_size
1468+
k = None
1469+
v = None
1470+
forward_args.is_fused_qkv = True
1471+
14421472
attention_input_type = forward_args.attention_input_type
14431473
if not self.is_mla_enable:
14441474
if forward_args.is_fused_qkv:
@@ -1453,7 +1483,7 @@ def _run(
14531483
assert k.shape[1] == kv_hidden_size
14541484
assert v.shape[1] == kv_hidden_size
14551485
num_tokens = q.shape[0]
1456-
if k is not None:
1486+
if k is not None and not metadata.is_cross:
14571487
assert k.shape[0] == num_tokens
14581488
assert v.shape[0] == num_tokens
14591489
else:
@@ -1586,7 +1616,7 @@ def _run(
15861616
block_ids_per_seq=metadata.block_ids_per_seq,
15871617
tokens_per_block=metadata.tokens_per_block,
15881618
max_num_requests=metadata.max_num_requests,
1589-
beam_width=metadata.beam_width,
1619+
beam_width=metadata.effective_beam_width,
15901620
use_paged_context_fmha=metadata.use_paged_context_fmha,
15911621
helix_position_offsets=metadata.helix_position_offsets,
15921622
helix_is_inactive_rank=metadata.helix_is_inactive_rank,
@@ -1612,6 +1642,7 @@ def _run(
16121642
max_context_length=metadata.max_context_length,
16131643
max_seq_len=metadata.max_seq_len,
16141644
trtllm_gen_jit_warmup=metadata.trtllm_gen_jit_warmup,
1645+
is_cross=metadata.is_cross,
16151646

16161647
# --- Per-call (AttentionForwardArgs) ---
16171648
out_scale=forward_args.out_scale,
@@ -1643,6 +1674,10 @@ def _run(
16431674
sage_attn_qk_int8=forward_args.sage_attn_qk_int8,
16441675
is_fused_qkv=forward_args.is_fused_qkv,
16451676
update_kv_cache=forward_args.update_kv_cache,
1677+
cross_kv=forward_args.cross_kv,
1678+
relative_attention_bias=forward_args.relative_attention_bias,
1679+
relative_attention_max_distance=(
1680+
forward_args.relative_attention_max_distance),
16461681

16471682
# --- Module config (TrtllmAttention) ---
16481683
rotary_inv_freq=self.rotary_inv_freq,
@@ -1716,7 +1751,8 @@ def forward(
17161751
metadata,
17171752
TrtllmAttentionMetadata,
17181753
)
1719-
assert not metadata.is_cross, "TRT-LLM Attention does not support cross attention yet."
1754+
# Cross-attention uses the THOP path; the trtllm-gen backend API does
1755+
# not carry encoder K/V tensors yet.
17201756

17211757
# SM90 forces ``use_paged_context_fmha`` on for correctness
17221758
# (https://nvbugs/5624818).
@@ -1750,9 +1786,13 @@ def forward(
17501786

17511787
forward_args.is_fused_qkv = not metadata.is_cross and k is None
17521788
forward_args.update_kv_cache = not metadata.is_cross or k is not None
1753-
assert (forward_args.is_fused_qkv and k is None
1754-
and v is None) or (not forward_args.is_fused_qkv
1755-
and k is not None and v is not None)
1789+
has_fused_qkv = forward_args.is_fused_qkv and k is None and v is None
1790+
has_unfused_kv = (not forward_args.is_fused_qkv and k is not None
1791+
and v is not None)
1792+
uses_cached_cross_kv = (metadata.is_cross
1793+
and not forward_args.update_kv_cache
1794+
and k is None and v is None)
1795+
assert has_fused_qkv or has_unfused_kv or uses_cached_cross_kv
17561796
if forward_args.cu_q_seqlens is None:
17571797
forward_args.cu_q_seqlens = metadata.cu_q_seqlens
17581798
if forward_args.cu_kv_seqlens is None:

tensorrt_llm/_torch/attention_backend/trtllm_gen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,8 @@ def is_supported(
597597
attn.skip_softmax_threshold_scale_factor_prefill is not None
598598
or attn.skip_softmax_threshold_scale_factor_decode is not None
599599
)
600+
if meta.is_cross:
601+
return False, "trtllm-gen does not support cross attention."
600602
if (
601603
fwd.sage_attn_num_elts_per_blk_q > 0
602604
or fwd.sage_attn_num_elts_per_blk_k > 0

0 commit comments

Comments
 (0)