diff --git a/cpp/tensorrt_llm/nanobind/thop/bindings.cpp b/cpp/tensorrt_llm/nanobind/thop/bindings.cpp index 04cac49d8e3d..90d4fb86a938 100644 --- a/cpp/tensorrt_llm/nanobind/thop/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/thop/bindings.cpp @@ -174,6 +174,8 @@ void initBindings(nb::module_& m) nb::arg("sage_attn_num_elts_per_blk_k") = 0, nb::arg("sage_attn_num_elts_per_blk_v") = 0, nb::arg("sage_attn_qk_int8") = false, nb::arg("num_contexts") = 0, nb::arg("num_ctx_tokens") = 0, nb::arg("trtllm_gen_jit_warmup") = false, nb::arg("compressed_kv_cache_pool_ptr") = std::nullopt, + nb::arg("is_cross") = false, nb::arg("cross_kv") = std::nullopt, + nb::arg("relative_attention_bias") = std::nullopt, nb::arg("relative_attention_max_distance") = 0, nb::arg("spec_decoding_target_max_draft_tokens") = std::nullopt, "Multi-head attention operation", nb::call_guard()); diff --git a/cpp/tensorrt_llm/thop/attentionOp.cpp b/cpp/tensorrt_llm/thop/attentionOp.cpp index 583d05f4acba..64b205e54085 100644 --- a/cpp/tensorrt_llm/thop/attentionOp.cpp +++ b/cpp/tensorrt_llm/thop/attentionOp.cpp @@ -376,7 +376,8 @@ class RunnerBase std::optional mla_bmm2_scale, std::optional quant_q_buffer, std::optional flash_mla_tile_scheduler_metadata, std::optional flash_mla_num_splits, bool trtllm_gen_jit_warmup, - std::optional compressed_kv_cache_pool_ptr) const + std::optional compressed_kv_cache_pool_ptr, bool const is_cross, std::optional cross_kv, + std::optional relative_attention_bias) const = 0; }; @@ -444,7 +445,8 @@ class Runner : public RunnerBase std::optional mla_bmm2_scale, std::optional quant_q_buffer, std::optional flash_mla_tile_scheduler_metadata, std::optional flash_mla_num_splits, bool trtllm_gen_jit_warmup, - std::optional compressed_kv_cache_pool_ptr) const override + std::optional compressed_kv_cache_pool_ptr, bool const is_cross, std::optional cross_kv, + std::optional relative_attention_bias) const override { auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device()); T* attention_input = static_cast(qkv_or_q.slice(0, token_offset).data_ptr()); @@ -677,6 +679,20 @@ class Runner : public RunnerBase attention_sinks.value().dtype() == torch::kFloat32, "Expected attention_sinks to have float dtype"); attention_sinks_ptr = attention_sinks.value().data_ptr(); } + T const* relative_attention_bias_ptr = nullptr; + int relative_attention_bias_stride = 0; + if (relative_attention_bias.has_value()) + { + auto const& relative_attention_bias_tensor = relative_attention_bias.value(); + TORCH_CHECK(relative_attention_bias_tensor.dim() == 2 || relative_attention_bias_tensor.dim() == 3, + "relative_attention_bias must be [num_heads, num_buckets] for implicit mode or " + "[num_heads, max_seq_len, max_seq_len] for explicit mode"); + TORCH_CHECK(relative_attention_bias_tensor.is_contiguous(), "relative_attention_bias must be contiguous"); + TORCH_CHECK(relative_attention_bias_tensor.scalar_type() == qkv_or_q.scalar_type(), + "relative_attention_bias dtype must match attention input dtype"); + relative_attention_bias_ptr = static_cast(relative_attention_bias_tensor.data_ptr()); + relative_attention_bias_stride = static_cast(relative_attention_bias_tensor.size(1)); + } // Prepare sparse attention parameters op.mRuntimeSparseAttentionParams.sparse_kv_indices @@ -723,6 +739,8 @@ class Runner : public RunnerBase common_enqueue_params.attention_sinks = attention_sinks_ptr; common_enqueue_params.rotary_inv_freq = rotary_inv_freq_ptr; common_enqueue_params.rotary_cos_sin = rotary_cos_sin_ptr; + common_enqueue_params.relative_attention_bias = relative_attention_bias_ptr; + common_enqueue_params.relative_attention_bias_stride = relative_attention_bias_stride; common_enqueue_params.max_past_kv_length = max_past_kv_length; common_enqueue_params.max_attention_window_size = max_attention_window_size; common_enqueue_params.cyclic_attention_window_size = cyclic_attention_window_size; @@ -747,6 +765,13 @@ class Runner : public RunnerBase common_enqueue_params.host_context_lengths = host_context_lengths.data_ptr(); common_enqueue_params.workspace = workspace_ptr; common_enqueue_params.trtllm_gen_jit_warmup = trtllm_gen_jit_warmup; + if (is_cross) + { + // For cross attention, the KV (encoder) sequence lengths are passed in via + // `sequence_length` (already sliced into `sequence_lengths_ptr`), so reuse + // it directly instead of a redundant `encoder_input_lengths` tensor. + common_enqueue_params.encoder_input_lengths = sequence_lengths_ptr; + } if (softmax_stats_tensor.has_value()) { TLLM_CHECK_WITH_INFO(softmax_stats_tensor.value().scalar_type() == at::ScalarType::Float, @@ -807,6 +832,14 @@ class Runner : public RunnerBase { enqueue_params.v_stride_in_bytes = v->strides()[0] * v->element_size(); } + if (is_cross && cross_kv.has_value()) + { + auto const& cross_kv_tensor = cross_kv.value(); + enqueue_params.cross_kv = static_cast(cross_kv_tensor.data_ptr()); + enqueue_params.num_encoder_tokens = static_cast(cross_kv_tensor.size(0)); + enqueue_params.cross_kv_length + = host_past_key_value_lengths.slice(0, seq_offset, seq_offset + num_seqs).max().item(); + } if (op.isMLAEnabled()) { @@ -993,7 +1026,9 @@ void attention(torch::Tensor q, std::optional k, std::optional flash_mla_tile_scheduler_metadata, std::optional flash_mla_num_splits, int64_t sage_attn_num_elts_per_blk_q, int64_t sage_attn_num_elts_per_blk_k, int64_t sage_attn_num_elts_per_blk_v, bool sage_attn_qk_int8, int64_t num_contexts, int64_t num_ctx_tokens, bool trtllm_gen_jit_warmup, - std::optional compressed_kv_cache_pool_ptr, std::optional spec_decoding_target_max_draft_tokens) + std::optional compressed_kv_cache_pool_ptr, bool const is_cross, std::optional cross_kv, + std::optional relative_attention_bias, int64_t relative_attention_max_distance, + std::optional spec_decoding_target_max_draft_tokens) { TLLM_LOG_TRACE("Attention op starts at layer %d", local_layer_idx); // Use these tensors to infer if the attention is using KV cache @@ -1002,16 +1037,17 @@ void attention(torch::Tensor q, std::optional k, std::optional 0 || sage_attn_num_elts_per_blk_k > 0 || sage_attn_num_elts_per_blk_v > 0; - TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv || use_sage_attn, - "Context attention only allows these non-MLA cases: fused QKV; separate QKV with SageAttention"); - TLLM_CHECK_WITH_INFO(update_kv_cache, "KV cache update cannot be disabled now"); + TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv || use_sage_attn || is_cross, + "For non-MLA, non-cross, non-SageAttention attention, only fused QKV is supported now."); + TLLM_CHECK_WITH_INFO( + update_kv_cache || is_cross, "KV cache update cannot be disabled now (except for cross attention)."); auto qkv_or_q = q; if (is_fused_qkv) { TLLM_CHECK_WITH_INFO(!k.has_value(), "The k tensor should be null if using fused QKV"); TLLM_CHECK_WITH_INFO(!v.has_value(), "The v tensor should be null if using fused QKV"); } - if (!is_fused_qkv && update_kv_cache) + if (!is_fused_qkv && update_kv_cache && !is_cross) { TLLM_CHECK_WITH_INFO(k.has_value(), "The k tensor should be provided if updating KV cache with unfused K/V"); TLLM_CHECK_WITH_INFO(v.has_value(), "The v tensor should be provided if updating KV cache with unfused K/V"); @@ -1094,6 +1130,20 @@ void attention(torch::Tensor q, std::optional k, std::optionalmQScaling = q_scaling; op->mPositionEmbeddingType = static_cast(int8_t(position_embedding_type)); + if (relative_attention_bias.has_value()) + { + auto const relative_attention_bias_dim = relative_attention_bias.value().dim(); + TORCH_CHECK(relative_attention_bias_dim == 2 || relative_attention_bias_dim == 3, + "relative_attention_bias must be [num_heads, num_buckets] for implicit mode or " + "[num_heads, max_seq_len, max_seq_len] for explicit mode"); + TORCH_CHECK(relative_attention_bias_dim != 2 || relative_attention_max_distance > 0, + "relative_attention_max_distance must be positive when relative_attention_bias is a bucket table"); + TORCH_CHECK(relative_attention_bias_dim != 3 || relative_attention_max_distance == 0, + "relative_attention_max_distance must be 0 when relative_attention_bias is precomputed"); + TLLM_CHECK_WITH_INFO(op->mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kRELATIVE, + "relative_attention_bias requires position_embedding_type to be relative."); + op->mMaxDistance = static_cast(relative_attention_max_distance); + } op->mRotaryEmbeddingDim = rope_dim; op->mRotaryEmbeddingBase = rope_base; op->mRotaryEmbeddingScaleType = static_cast(int8_t(rope_scale_type)); @@ -1111,6 +1161,7 @@ void attention(torch::Tensor q, std::optional k, std::optionalmSageAttnQkInt8 = sage_attn_qk_int8; op->mFP8AttenOutput = is_fp8_out; op->mPagedContextFMHA = use_paged_context_fmha; + op->mCrossAttention = is_cross; op->mAttentionChunkSize = attention_chunk_size; op->mSkipSoftmaxThresholdScaleFactorPrefill @@ -1275,7 +1326,7 @@ void attention(torch::Tensor q, std::optional k, std::optional 0) && (attn_input_type != AttentionInputType::ContextOnly)) @@ -1297,7 +1348,7 @@ void attention(torch::Tensor q, std::optional k, std::optional k, std::optional flash_mla_num_splits = std::nullopt, int64_t sage_attn_num_elts_per_blk_q = 0, int64_t sage_attn_num_elts_per_blk_k = 0, int64_t sage_attn_num_elts_per_blk_v = 0, bool sage_attn_qk_int8 = false, int64_t num_contexts = 0, int64_t num_ctx_tokens = 0, bool trtllm_gen_jit_warmup = false, - std::optional compressed_kv_cache_pool_ptr = std::nullopt, + std::optional compressed_kv_cache_pool_ptr = std::nullopt, bool const is_cross = false, + std::optional cross_kv = std::nullopt, + std::optional relative_attention_bias = std::nullopt, int64_t relative_attention_max_distance = 0, std::optional spec_decoding_target_max_draft_tokens = std::nullopt); struct KvCachePoolPointers diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py index 0b554d6c77e2..1e7b953836f3 100644 --- a/tensorrt_llm/_torch/attention_backend/interface.py +++ b/tensorrt_llm/_torch/attention_backend/interface.py @@ -263,7 +263,12 @@ def seq_lens_kv(self, value: Optional[torch.Tensor]): # The model executor sets seqlens to None initially. if self._seq_lens_kv is not None: self._seq_lens_kv = maybe_pin_memory(self._seq_lens_kv) - self._seq_lens_kv_cuda = self._seq_lens_kv.cuda(non_blocking=True) + if self.is_cuda_graph and self._seq_lens_kv_cuda is not None: + self._seq_lens_kv_cuda.copy_(self._seq_lens_kv, + non_blocking=True) + else: + self._seq_lens_kv_cuda = self._seq_lens_kv.cuda( + non_blocking=True) @property def seq_lens_kv_cuda(self): @@ -747,6 +752,9 @@ class AttentionForwardArgs: attention_window_size: Optional[int] = None attention_mask_data: Optional[torch.Tensor] = None attention_sinks: Optional[torch.Tensor] = None + relative_attention_bias: Optional[torch.Tensor] = None + relative_attention_max_distance: int = 0 + cross_kv: Optional[torch.Tensor] = None latent_cache: Optional[torch.Tensor] = None q_pe: Optional[torch.Tensor] = None diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index c8bc1b0e66cf..62a8a28a07f4 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -190,6 +190,16 @@ def max_context_length(self) -> int: """ return min(self.max_seq_len, self.max_num_tokens) + @property + def effective_beam_width(self) -> int: + """Beam width visible to the kernel. + + Cross-attention metadata is already expanded to one row per decoder + beam, and all beams read the same encoder K/V cache. Keep kernel beam + indirection disabled for that path. + """ + return 1 if self.is_cross else self.beam_width + @property def max_seq_len(self) -> int: """ @@ -1439,6 +1449,26 @@ def _run( metadata: TrtllmAttentionMetadata, forward_args: AttentionForwardArgs, ) -> None: + if metadata.is_cross: + if k is not None and v is not None: + k_flat = k.contiguous().view(k.shape[0], -1) + v_flat = v.contiguous().view(v.shape[0], -1) + forward_args.cross_kv = torch.cat([k_flat, v_flat], + dim=1).contiguous() + + q_hidden_size = self.num_heads * self.head_dim + kv_hidden_size = self.num_kv_heads * self.head_dim + qkv_hidden_size = q_hidden_size + 2 * kv_hidden_size + if q.shape[1] == q_hidden_size: + fused_q = q.new_zeros((q.shape[0], qkv_hidden_size)) + fused_q[:, :q_hidden_size].copy_(q) + q = fused_q + else: + assert q.shape[1] == qkv_hidden_size + k = None + v = None + forward_args.is_fused_qkv = True + attention_input_type = forward_args.attention_input_type if not self.is_mla_enable: if forward_args.is_fused_qkv: @@ -1453,7 +1483,7 @@ def _run( assert k.shape[1] == kv_hidden_size assert v.shape[1] == kv_hidden_size num_tokens = q.shape[0] - if k is not None: + if k is not None and not metadata.is_cross: assert k.shape[0] == num_tokens assert v.shape[0] == num_tokens else: @@ -1586,7 +1616,7 @@ def _run( block_ids_per_seq=metadata.block_ids_per_seq, tokens_per_block=metadata.tokens_per_block, max_num_requests=metadata.max_num_requests, - beam_width=metadata.beam_width, + beam_width=metadata.effective_beam_width, use_paged_context_fmha=metadata.use_paged_context_fmha, helix_position_offsets=metadata.helix_position_offsets, helix_is_inactive_rank=metadata.helix_is_inactive_rank, @@ -1612,6 +1642,7 @@ def _run( max_context_length=metadata.max_context_length, max_seq_len=metadata.max_seq_len, trtllm_gen_jit_warmup=metadata.trtllm_gen_jit_warmup, + is_cross=metadata.is_cross, # --- Per-call (AttentionForwardArgs) --- out_scale=forward_args.out_scale, @@ -1643,6 +1674,10 @@ def _run( sage_attn_qk_int8=forward_args.sage_attn_qk_int8, is_fused_qkv=forward_args.is_fused_qkv, update_kv_cache=forward_args.update_kv_cache, + cross_kv=forward_args.cross_kv, + relative_attention_bias=forward_args.relative_attention_bias, + relative_attention_max_distance=( + forward_args.relative_attention_max_distance), # --- Module config (TrtllmAttention) --- rotary_inv_freq=self.rotary_inv_freq, @@ -1716,7 +1751,8 @@ def forward( metadata, TrtllmAttentionMetadata, ) - assert not metadata.is_cross, "TRT-LLM Attention does not support cross attention yet." + # Cross-attention uses the THOP path; the trtllm-gen backend API does + # not carry encoder K/V tensors yet. # SM90 forces ``use_paged_context_fmha`` on for correctness # (https://nvbugs/5624818). @@ -1750,9 +1786,13 @@ def forward( forward_args.is_fused_qkv = not metadata.is_cross and k is None forward_args.update_kv_cache = not metadata.is_cross or k is not None - assert (forward_args.is_fused_qkv and k is None - and v is None) or (not forward_args.is_fused_qkv - and k is not None and v is not None) + has_fused_qkv = forward_args.is_fused_qkv and k is None and v is None + has_unfused_kv = (not forward_args.is_fused_qkv and k is not None + and v is not None) + uses_cached_cross_kv = (metadata.is_cross + and not forward_args.update_kv_cache + and k is None and v is None) + assert has_fused_qkv or has_unfused_kv or uses_cached_cross_kv if forward_args.cu_q_seqlens is None: forward_args.cu_q_seqlens = metadata.cu_q_seqlens if forward_args.cu_kv_seqlens is None: diff --git a/tensorrt_llm/_torch/attention_backend/trtllm_gen.py b/tensorrt_llm/_torch/attention_backend/trtllm_gen.py index 70adab96e959..192dd71d1617 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm_gen.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm_gen.py @@ -597,6 +597,8 @@ def is_supported( attn.skip_softmax_threshold_scale_factor_prefill is not None or attn.skip_softmax_threshold_scale_factor_decode is not None ) + if meta.is_cross: + return False, "trtllm-gen does not support cross attention." if ( fwd.sage_attn_num_elts_per_blk_q > 0 or fwd.sage_attn_num_elts_per_blk_k > 0