Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/tensorrt_llm/nanobind/thop/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ void initBindings(nb::module_& m)
nb::arg("sage_attn_num_elts_per_blk_k") = 0, nb::arg("sage_attn_num_elts_per_blk_v") = 0,
nb::arg("sage_attn_qk_int8") = false, nb::arg("num_contexts") = 0, nb::arg("num_ctx_tokens") = 0,
nb::arg("trtllm_gen_jit_warmup") = false, nb::arg("compressed_kv_cache_pool_ptr") = std::nullopt,
nb::arg("is_cross") = false, nb::arg("cross_kv") = std::nullopt,
nb::arg("relative_attention_bias") = std::nullopt, nb::arg("relative_attention_max_distance") = 0,
nb::arg("spec_decoding_target_max_draft_tokens") = std::nullopt, "Multi-head attention operation",
nb::call_guard<nb::gil_scoped_release>());

Expand Down
69 changes: 60 additions & 9 deletions cpp/tensorrt_llm/thop/attentionOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,8 @@ class RunnerBase
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer,
std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata,
std::optional<torch::Tensor> flash_mla_num_splits, bool trtllm_gen_jit_warmup,
std::optional<int64_t> compressed_kv_cache_pool_ptr) const
std::optional<int64_t> compressed_kv_cache_pool_ptr, bool const is_cross, std::optional<torch::Tensor> cross_kv,
std::optional<torch::Tensor> relative_attention_bias) const
= 0;
};

Expand Down Expand Up @@ -444,7 +445,8 @@ class Runner : public RunnerBase
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer,
std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata,
std::optional<torch::Tensor> flash_mla_num_splits, bool trtllm_gen_jit_warmup,
std::optional<int64_t> compressed_kv_cache_pool_ptr) const override
std::optional<int64_t> compressed_kv_cache_pool_ptr, bool const is_cross, std::optional<torch::Tensor> cross_kv,
std::optional<torch::Tensor> relative_attention_bias) const override
{
auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device());
T* attention_input = static_cast<T*>(qkv_or_q.slice(0, token_offset).data_ptr());
Expand Down Expand Up @@ -677,6 +679,20 @@ class Runner : public RunnerBase
attention_sinks.value().dtype() == torch::kFloat32, "Expected attention_sinks to have float dtype");
attention_sinks_ptr = attention_sinks.value().data_ptr<float>();
}
T const* relative_attention_bias_ptr = nullptr;
int relative_attention_bias_stride = 0;
if (relative_attention_bias.has_value())
{
auto const& relative_attention_bias_tensor = relative_attention_bias.value();
TORCH_CHECK(relative_attention_bias_tensor.dim() == 2 || relative_attention_bias_tensor.dim() == 3,
"relative_attention_bias must be [num_heads, num_buckets] for implicit mode or "
"[num_heads, max_seq_len, max_seq_len] for explicit mode");
TORCH_CHECK(relative_attention_bias_tensor.is_contiguous(), "relative_attention_bias must be contiguous");
TORCH_CHECK(relative_attention_bias_tensor.scalar_type() == qkv_or_q.scalar_type(),
"relative_attention_bias dtype must match attention input dtype");
relative_attention_bias_ptr = static_cast<T const*>(relative_attention_bias_tensor.data_ptr());
relative_attention_bias_stride = static_cast<int>(relative_attention_bias_tensor.size(1));
}

// Prepare sparse attention parameters
op.mRuntimeSparseAttentionParams.sparse_kv_indices
Expand Down Expand Up @@ -723,6 +739,8 @@ class Runner : public RunnerBase
common_enqueue_params.attention_sinks = attention_sinks_ptr;
common_enqueue_params.rotary_inv_freq = rotary_inv_freq_ptr;
common_enqueue_params.rotary_cos_sin = rotary_cos_sin_ptr;
common_enqueue_params.relative_attention_bias = relative_attention_bias_ptr;
common_enqueue_params.relative_attention_bias_stride = relative_attention_bias_stride;
common_enqueue_params.max_past_kv_length = max_past_kv_length;
common_enqueue_params.max_attention_window_size = max_attention_window_size;
common_enqueue_params.cyclic_attention_window_size = cyclic_attention_window_size;
Expand All @@ -747,6 +765,13 @@ class Runner : public RunnerBase
common_enqueue_params.host_context_lengths = host_context_lengths.data_ptr<int32_t>();
common_enqueue_params.workspace = workspace_ptr;
common_enqueue_params.trtllm_gen_jit_warmup = trtllm_gen_jit_warmup;
if (is_cross)
{
// For cross attention, the KV (encoder) sequence lengths are passed in via
// `sequence_length` (already sliced into `sequence_lengths_ptr`), so reuse
// it directly instead of a redundant `encoder_input_lengths` tensor.
common_enqueue_params.encoder_input_lengths = sequence_lengths_ptr;
}
if (softmax_stats_tensor.has_value())
{
TLLM_CHECK_WITH_INFO(softmax_stats_tensor.value().scalar_type() == at::ScalarType::Float,
Expand Down Expand Up @@ -807,6 +832,14 @@ class Runner : public RunnerBase
{
enqueue_params.v_stride_in_bytes = v->strides()[0] * v->element_size();
}
if (is_cross && cross_kv.has_value())
{
auto const& cross_kv_tensor = cross_kv.value();
enqueue_params.cross_kv = static_cast<T const*>(cross_kv_tensor.data_ptr());
enqueue_params.num_encoder_tokens = static_cast<int32_t>(cross_kv_tensor.size(0));
enqueue_params.cross_kv_length
= host_past_key_value_lengths.slice(0, seq_offset, seq_offset + num_seqs).max().item<int32_t>();
}

if (op.isMLAEnabled())
{
Expand Down Expand Up @@ -993,7 +1026,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata, std::optional<torch::Tensor> flash_mla_num_splits,
int64_t sage_attn_num_elts_per_blk_q, int64_t sage_attn_num_elts_per_blk_k, int64_t sage_attn_num_elts_per_blk_v,
bool sage_attn_qk_int8, int64_t num_contexts, int64_t num_ctx_tokens, bool trtllm_gen_jit_warmup,
std::optional<int64_t> compressed_kv_cache_pool_ptr, std::optional<int64_t> spec_decoding_target_max_draft_tokens)
std::optional<int64_t> compressed_kv_cache_pool_ptr, bool const is_cross, std::optional<torch::Tensor> cross_kv,
std::optional<torch::Tensor> relative_attention_bias, int64_t relative_attention_max_distance,
std::optional<int64_t> spec_decoding_target_max_draft_tokens)
{
TLLM_LOG_TRACE("Attention op starts at layer %d", local_layer_idx);
// Use these tensors to infer if the attention is using KV cache
Expand All @@ -1002,16 +1037,17 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to

bool const use_sage_attn
= sage_attn_num_elts_per_blk_q > 0 || sage_attn_num_elts_per_blk_k > 0 || sage_attn_num_elts_per_blk_v > 0;
TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv || use_sage_attn,
"Context attention only allows these non-MLA cases: fused QKV; separate QKV with SageAttention");
TLLM_CHECK_WITH_INFO(update_kv_cache, "KV cache update cannot be disabled now");
TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv || use_sage_attn || is_cross,
"For non-MLA, non-cross, non-SageAttention attention, only fused QKV is supported now.");
TLLM_CHECK_WITH_INFO(
update_kv_cache || is_cross, "KV cache update cannot be disabled now (except for cross attention).");
auto qkv_or_q = q;
if (is_fused_qkv)
{
TLLM_CHECK_WITH_INFO(!k.has_value(), "The k tensor should be null if using fused QKV");
TLLM_CHECK_WITH_INFO(!v.has_value(), "The v tensor should be null if using fused QKV");
}
if (!is_fused_qkv && update_kv_cache)
if (!is_fused_qkv && update_kv_cache && !is_cross)
{
TLLM_CHECK_WITH_INFO(k.has_value(), "The k tensor should be provided if updating KV cache with unfused K/V");
TLLM_CHECK_WITH_INFO(v.has_value(), "The v tensor should be provided if updating KV cache with unfused K/V");
Expand Down Expand Up @@ -1094,6 +1130,20 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
op->mQScaling = q_scaling;
op->mPositionEmbeddingType
= static_cast<tensorrt_llm::kernels::PositionEmbeddingType>(int8_t(position_embedding_type));
if (relative_attention_bias.has_value())
{
auto const relative_attention_bias_dim = relative_attention_bias.value().dim();
TORCH_CHECK(relative_attention_bias_dim == 2 || relative_attention_bias_dim == 3,
"relative_attention_bias must be [num_heads, num_buckets] for implicit mode or "
"[num_heads, max_seq_len, max_seq_len] for explicit mode");
TORCH_CHECK(relative_attention_bias_dim != 2 || relative_attention_max_distance > 0,
"relative_attention_max_distance must be positive when relative_attention_bias is a bucket table");
TORCH_CHECK(relative_attention_bias_dim != 3 || relative_attention_max_distance == 0,
"relative_attention_max_distance must be 0 when relative_attention_bias is precomputed");
TLLM_CHECK_WITH_INFO(op->mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kRELATIVE,
"relative_attention_bias requires position_embedding_type to be relative.");
op->mMaxDistance = static_cast<int>(relative_attention_max_distance);
}
op->mRotaryEmbeddingDim = rope_dim;
op->mRotaryEmbeddingBase = rope_base;
op->mRotaryEmbeddingScaleType = static_cast<tensorrt_llm::kernels::RotaryScalingType>(int8_t(rope_scale_type));
Expand All @@ -1111,6 +1161,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
op->mSageAttnQkInt8 = sage_attn_qk_int8;
op->mFP8AttenOutput = is_fp8_out;
op->mPagedContextFMHA = use_paged_context_fmha;
op->mCrossAttention = is_cross;

op->mAttentionChunkSize = attention_chunk_size;
op->mSkipSoftmaxThresholdScaleFactorPrefill
Expand Down Expand Up @@ -1275,7 +1326,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets, sparse_attn_indices_block_size,
num_sparse_topk_value, sparse_mla_topk_lens, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter,
mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer, flash_mla_tile_scheduler_metadata, flash_mla_num_splits,
trtllm_gen_jit_warmup, compressed_kv_cache_pool_ptr);
trtllm_gen_jit_warmup, compressed_kv_cache_pool_ptr, is_cross, cross_kv, relative_attention_bias);
}

if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
Expand All @@ -1297,7 +1348,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets, sparse_attn_indices_block_size,
num_sparse_topk_value, sparse_mla_topk_lens, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter,
mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer, flash_mla_tile_scheduler_metadata, flash_mla_num_splits,
trtllm_gen_jit_warmup, compressed_kv_cache_pool_ptr);
trtllm_gen_jit_warmup, compressed_kv_cache_pool_ptr, is_cross, cross_kv, relative_attention_bias);
}

TLLM_LOG_TRACE("Attention op stops at layer %d", local_layer_idx);
Expand Down
4 changes: 3 additions & 1 deletion cpp/tensorrt_llm/thop/attentionOp.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
std::optional<torch::Tensor> flash_mla_num_splits = std::nullopt, int64_t sage_attn_num_elts_per_blk_q = 0,
int64_t sage_attn_num_elts_per_blk_k = 0, int64_t sage_attn_num_elts_per_blk_v = 0, bool sage_attn_qk_int8 = false,
int64_t num_contexts = 0, int64_t num_ctx_tokens = 0, bool trtllm_gen_jit_warmup = false,
std::optional<int64_t> compressed_kv_cache_pool_ptr = std::nullopt,
std::optional<int64_t> compressed_kv_cache_pool_ptr = std::nullopt, bool const is_cross = false,
std::optional<torch::Tensor> cross_kv = std::nullopt,
std::optional<torch::Tensor> relative_attention_bias = std::nullopt, int64_t relative_attention_max_distance = 0,
std::optional<int64_t> spec_decoding_target_max_draft_tokens = std::nullopt);

struct KvCachePoolPointers
Expand Down
10 changes: 9 additions & 1 deletion tensorrt_llm/_torch/attention_backend/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,12 @@ def seq_lens_kv(self, value: Optional[torch.Tensor]):
# The model executor sets seqlens to None initially.
if self._seq_lens_kv is not None:
self._seq_lens_kv = maybe_pin_memory(self._seq_lens_kv)
self._seq_lens_kv_cuda = self._seq_lens_kv.cuda(non_blocking=True)
if self.is_cuda_graph and self._seq_lens_kv_cuda is not None:
self._seq_lens_kv_cuda.copy_(self._seq_lens_kv,
non_blocking=True)
else:
self._seq_lens_kv_cuda = self._seq_lens_kv.cuda(
non_blocking=True)

@property
def seq_lens_kv_cuda(self):
Expand Down Expand Up @@ -747,6 +752,9 @@ class AttentionForwardArgs:
attention_window_size: Optional[int] = None
attention_mask_data: Optional[torch.Tensor] = None
attention_sinks: Optional[torch.Tensor] = None
relative_attention_bias: Optional[torch.Tensor] = None
relative_attention_max_distance: int = 0
cross_kv: Optional[torch.Tensor] = None

latent_cache: Optional[torch.Tensor] = None
q_pe: Optional[torch.Tensor] = None
Expand Down
52 changes: 46 additions & 6 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,16 @@ def max_context_length(self) -> int:
"""
return min(self.max_seq_len, self.max_num_tokens)

@property
def effective_beam_width(self) -> int:
    """Return the beam width that is handed to the attention kernel.

    For cross attention the metadata already holds one row per decoder
    beam and every beam reads the same encoder K/V cache, so kernel-side
    beam indirection is kept off by reporting a width of 1. For any other
    attention path the configured ``beam_width`` is returned unchanged.
    """
    if self.is_cross:
        return 1
    return self.beam_width

@property
def max_seq_len(self) -> int:
"""
Expand Down Expand Up @@ -1439,6 +1449,26 @@ def _run(
metadata: TrtllmAttentionMetadata,
forward_args: AttentionForwardArgs,
) -> None:
if metadata.is_cross:
if k is not None and v is not None:
k_flat = k.contiguous().view(k.shape[0], -1)
v_flat = v.contiguous().view(v.shape[0], -1)
forward_args.cross_kv = torch.cat([k_flat, v_flat],
dim=1).contiguous()

q_hidden_size = self.num_heads * self.head_dim
kv_hidden_size = self.num_kv_heads * self.head_dim
qkv_hidden_size = q_hidden_size + 2 * kv_hidden_size
if q.shape[1] == q_hidden_size:
fused_q = q.new_zeros((q.shape[0], qkv_hidden_size))
fused_q[:, :q_hidden_size].copy_(q)
q = fused_q
else:
assert q.shape[1] == qkv_hidden_size
k = None
v = None
forward_args.is_fused_qkv = True

attention_input_type = forward_args.attention_input_type
if not self.is_mla_enable:
if forward_args.is_fused_qkv:
Expand All @@ -1453,7 +1483,7 @@ def _run(
assert k.shape[1] == kv_hidden_size
assert v.shape[1] == kv_hidden_size
num_tokens = q.shape[0]
if k is not None:
if k is not None and not metadata.is_cross:
assert k.shape[0] == num_tokens
assert v.shape[0] == num_tokens
else:
Expand Down Expand Up @@ -1586,7 +1616,7 @@ def _run(
block_ids_per_seq=metadata.block_ids_per_seq,
tokens_per_block=metadata.tokens_per_block,
max_num_requests=metadata.max_num_requests,
beam_width=metadata.beam_width,
beam_width=metadata.effective_beam_width,
use_paged_context_fmha=metadata.use_paged_context_fmha,
helix_position_offsets=metadata.helix_position_offsets,
helix_is_inactive_rank=metadata.helix_is_inactive_rank,
Expand All @@ -1612,6 +1642,7 @@ def _run(
max_context_length=metadata.max_context_length,
max_seq_len=metadata.max_seq_len,
trtllm_gen_jit_warmup=metadata.trtllm_gen_jit_warmup,
is_cross=metadata.is_cross,

# --- Per-call (AttentionForwardArgs) ---
out_scale=forward_args.out_scale,
Expand Down Expand Up @@ -1643,6 +1674,10 @@ def _run(
sage_attn_qk_int8=forward_args.sage_attn_qk_int8,
is_fused_qkv=forward_args.is_fused_qkv,
update_kv_cache=forward_args.update_kv_cache,
cross_kv=forward_args.cross_kv,
relative_attention_bias=forward_args.relative_attention_bias,
relative_attention_max_distance=(
forward_args.relative_attention_max_distance),

# --- Module config (TrtllmAttention) ---
rotary_inv_freq=self.rotary_inv_freq,
Expand Down Expand Up @@ -1716,7 +1751,8 @@ def forward(
metadata,
TrtllmAttentionMetadata,
)
assert not metadata.is_cross, "TRT-LLM Attention does not support cross attention yet."
# Cross-attention uses the THOP path; the trtllm-gen backend API does
# not carry encoder K/V tensors yet.

# SM90 forces ``use_paged_context_fmha`` on for correctness
# (https://nvbugs/5624818).
Expand Down Expand Up @@ -1750,9 +1786,13 @@ def forward(

forward_args.is_fused_qkv = not metadata.is_cross and k is None
forward_args.update_kv_cache = not metadata.is_cross or k is not None
assert (forward_args.is_fused_qkv and k is None
and v is None) or (not forward_args.is_fused_qkv
and k is not None and v is not None)
has_fused_qkv = forward_args.is_fused_qkv and k is None and v is None
has_unfused_kv = (not forward_args.is_fused_qkv and k is not None
and v is not None)
uses_cached_cross_kv = (metadata.is_cross
and not forward_args.update_kv_cache
and k is None and v is None)
assert has_fused_qkv or has_unfused_kv or uses_cached_cross_kv
if forward_args.cu_q_seqlens is None:
forward_args.cu_q_seqlens = metadata.cu_q_seqlens
if forward_args.cu_kv_seqlens is None:
Expand Down
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/attention_backend/trtllm_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,8 @@ def is_supported(
attn.skip_softmax_threshold_scale_factor_prefill is not None
or attn.skip_softmax_threshold_scale_factor_decode is not None
)
if meta.is_cross:
return False, "trtllm-gen does not support cross attention."
if (
fwd.sage_attn_num_elts_per_blk_q > 0
or fwd.sage_attn_num_elts_per_blk_k > 0
Expand Down
Loading