@@ -376,7 +376,8 @@ class RunnerBase
376376 std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer,
377377 std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata,
378378 std::optional<torch::Tensor> flash_mla_num_splits, bool trtllm_gen_jit_warmup,
379- std::optional<int64_t > compressed_kv_cache_pool_ptr) const
379+ std::optional<int64_t > compressed_kv_cache_pool_ptr, bool const is_cross, std::optional<torch::Tensor> cross_kv,
380+ std::optional<torch::Tensor> relative_attention_bias) const
380381 = 0;
381382};
382383
@@ -444,7 +445,8 @@ class Runner : public RunnerBase
444445 std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer,
445446 std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata,
446447 std::optional<torch::Tensor> flash_mla_num_splits, bool trtllm_gen_jit_warmup,
447- std::optional<int64_t > compressed_kv_cache_pool_ptr) const override
448+ std::optional<int64_t > compressed_kv_cache_pool_ptr, bool const is_cross, std::optional<torch::Tensor> cross_kv,
449+ std::optional<torch::Tensor> relative_attention_bias) const override
448450 {
449451 auto stream = at::cuda::getCurrentCUDAStream (qkv_or_q.get_device ());
450452 T* attention_input = static_cast <T*>(qkv_or_q.slice (0 , token_offset).data_ptr ());
@@ -677,6 +679,20 @@ class Runner : public RunnerBase
677679 attention_sinks.value ().dtype () == torch::kFloat32 , " Expected attention_sinks to have float dtype" );
678680 attention_sinks_ptr = attention_sinks.value ().data_ptr <float >();
679681 }
682+ T const * relative_attention_bias_ptr = nullptr ;
683+ int relative_attention_bias_stride = 0 ;
684+ if (relative_attention_bias.has_value ())
685+ {
686+ auto const & relative_attention_bias_tensor = relative_attention_bias.value ();
687+ TORCH_CHECK (relative_attention_bias_tensor.dim () == 2 || relative_attention_bias_tensor.dim () == 3 ,
688+ " relative_attention_bias must be [num_heads, num_buckets] for implicit mode or "
689+ " [num_heads, max_seq_len, max_seq_len] for explicit mode" );
690+ TORCH_CHECK (relative_attention_bias_tensor.is_contiguous (), " relative_attention_bias must be contiguous" );
691+ TORCH_CHECK (relative_attention_bias_tensor.scalar_type () == qkv_or_q.scalar_type (),
692+ " relative_attention_bias dtype must match attention input dtype" );
693+ relative_attention_bias_ptr = static_cast <T const *>(relative_attention_bias_tensor.data_ptr ());
694+ relative_attention_bias_stride = static_cast <int >(relative_attention_bias_tensor.size (1 ));
695+ }
680696
681697 // Prepare sparse attention parameters
682698 op.mRuntimeSparseAttentionParams .sparse_kv_indices
@@ -723,6 +739,8 @@ class Runner : public RunnerBase
723739 common_enqueue_params.attention_sinks = attention_sinks_ptr;
724740 common_enqueue_params.rotary_inv_freq = rotary_inv_freq_ptr;
725741 common_enqueue_params.rotary_cos_sin = rotary_cos_sin_ptr;
742+ common_enqueue_params.relative_attention_bias = relative_attention_bias_ptr;
743+ common_enqueue_params.relative_attention_bias_stride = relative_attention_bias_stride;
726744 common_enqueue_params.max_past_kv_length = max_past_kv_length;
727745 common_enqueue_params.max_attention_window_size = max_attention_window_size;
728746 common_enqueue_params.cyclic_attention_window_size = cyclic_attention_window_size;
@@ -747,6 +765,13 @@ class Runner : public RunnerBase
747765 common_enqueue_params.host_context_lengths = host_context_lengths.data_ptr <int32_t >();
748766 common_enqueue_params.workspace = workspace_ptr;
749767 common_enqueue_params.trtllm_gen_jit_warmup = trtllm_gen_jit_warmup;
768+ if (is_cross)
769+ {
770+ // For cross attention, the KV (encoder) sequence lengths are passed in via
771+ // `sequence_length` (already sliced into `sequence_lengths_ptr`), so reuse
772+ // it directly instead of a redundant `encoder_input_lengths` tensor.
773+ common_enqueue_params.encoder_input_lengths = sequence_lengths_ptr;
774+ }
750775 if (softmax_stats_tensor.has_value ())
751776 {
752777 TLLM_CHECK_WITH_INFO (softmax_stats_tensor.value ().scalar_type () == at::ScalarType::Float,
@@ -807,6 +832,14 @@ class Runner : public RunnerBase
807832 {
808833 enqueue_params.v_stride_in_bytes = v->strides ()[0 ] * v->element_size ();
809834 }
835+ if (is_cross && cross_kv.has_value ())
836+ {
837+ auto const & cross_kv_tensor = cross_kv.value ();
838+ enqueue_params.cross_kv = static_cast <T const *>(cross_kv_tensor.data_ptr ());
839+ enqueue_params.num_encoder_tokens = static_cast <int32_t >(cross_kv_tensor.size (0 ));
840+ enqueue_params.cross_kv_length
841+ = host_past_key_value_lengths.slice (0 , seq_offset, seq_offset + num_seqs).max ().item <int32_t >();
842+ }
810843
811844 if (op.isMLAEnabled ())
812845 {
@@ -993,7 +1026,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
9931026 std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata, std::optional<torch::Tensor> flash_mla_num_splits,
9941027 int64_t sage_attn_num_elts_per_blk_q, int64_t sage_attn_num_elts_per_blk_k, int64_t sage_attn_num_elts_per_blk_v,
9951028 bool sage_attn_qk_int8, int64_t num_contexts, int64_t num_ctx_tokens, bool trtllm_gen_jit_warmup,
996- std::optional<int64_t > compressed_kv_cache_pool_ptr, std::optional<int64_t > spec_decoding_target_max_draft_tokens)
1029+ std::optional<int64_t > compressed_kv_cache_pool_ptr, bool const is_cross, std::optional<torch::Tensor> cross_kv,
1030+ std::optional<torch::Tensor> relative_attention_bias, int64_t relative_attention_max_distance,
1031+ std::optional<int64_t > spec_decoding_target_max_draft_tokens)
9971032{
9981033 TLLM_LOG_TRACE (" Attention op starts at layer %d" , local_layer_idx);
9991034 // Use these tensors to infer if the attention is using KV cache
@@ -1002,16 +1037,17 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
10021037
10031038 bool const use_sage_attn
10041039 = sage_attn_num_elts_per_blk_q > 0 || sage_attn_num_elts_per_blk_k > 0 || sage_attn_num_elts_per_blk_v > 0 ;
1005- TLLM_CHECK_WITH_INFO (is_mla_enable || is_fused_qkv || use_sage_attn,
1006- " Context attention only allows these non-MLA cases: fused QKV; separate QKV with SageAttention" );
1007- TLLM_CHECK_WITH_INFO (update_kv_cache, " KV cache update cannot be disabled now" );
1040+ TLLM_CHECK_WITH_INFO (is_mla_enable || is_fused_qkv || use_sage_attn || is_cross,
1041+ " For non-MLA, non-cross, non-SageAttention attention, only fused QKV is supported now." );
1042+ TLLM_CHECK_WITH_INFO (
1043+ update_kv_cache || is_cross, " KV cache update cannot be disabled now (except for cross attention)." );
10081044 auto qkv_or_q = q;
10091045 if (is_fused_qkv)
10101046 {
10111047 TLLM_CHECK_WITH_INFO (!k.has_value (), " The k tensor should be null if using fused QKV" );
10121048 TLLM_CHECK_WITH_INFO (!v.has_value (), " The v tensor should be null if using fused QKV" );
10131049 }
1014- if (!is_fused_qkv && update_kv_cache)
1050+ if (!is_fused_qkv && update_kv_cache && !is_cross )
10151051 {
10161052 TLLM_CHECK_WITH_INFO (k.has_value (), " The k tensor should be provided if updating KV cache with unfused K/V" );
10171053 TLLM_CHECK_WITH_INFO (v.has_value (), " The v tensor should be provided if updating KV cache with unfused K/V" );
@@ -1094,6 +1130,20 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
10941130 op->mQScaling = q_scaling;
10951131 op->mPositionEmbeddingType
10961132 = static_cast <tensorrt_llm::kernels::PositionEmbeddingType>(int8_t (position_embedding_type));
1133+ if (relative_attention_bias.has_value ())
1134+ {
1135+ auto const relative_attention_bias_dim = relative_attention_bias.value ().dim ();
1136+ TORCH_CHECK (relative_attention_bias_dim == 2 || relative_attention_bias_dim == 3 ,
1137+ " relative_attention_bias must be [num_heads, num_buckets] for implicit mode or "
1138+ " [num_heads, max_seq_len, max_seq_len] for explicit mode" );
1139+ TORCH_CHECK (relative_attention_bias_dim != 2 || relative_attention_max_distance > 0 ,
1140+ " relative_attention_max_distance must be positive when relative_attention_bias is a bucket table" );
1141+ TORCH_CHECK (relative_attention_bias_dim != 3 || relative_attention_max_distance == 0 ,
1142+ " relative_attention_max_distance must be 0 when relative_attention_bias is precomputed" );
1143+ TLLM_CHECK_WITH_INFO (op->mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kRELATIVE ,
1144+ " relative_attention_bias requires position_embedding_type to be relative." );
1145+ op->mMaxDistance = static_cast <int >(relative_attention_max_distance);
1146+ }
10971147 op->mRotaryEmbeddingDim = rope_dim;
10981148 op->mRotaryEmbeddingBase = rope_base;
10991149 op->mRotaryEmbeddingScaleType = static_cast <tensorrt_llm::kernels::RotaryScalingType>(int8_t (rope_scale_type));
@@ -1111,6 +1161,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
11111161 op->mSageAttnQkInt8 = sage_attn_qk_int8;
11121162 op->mFP8AttenOutput = is_fp8_out;
11131163 op->mPagedContextFMHA = use_paged_context_fmha;
1164+ op->mCrossAttention = is_cross;
11141165
11151166 op->mAttentionChunkSize = attention_chunk_size;
11161167 op->mSkipSoftmaxThresholdScaleFactorPrefill
@@ -1275,7 +1326,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
12751326 sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets, sparse_attn_indices_block_size,
12761327 num_sparse_topk_value, sparse_mla_topk_lens, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter,
12771328 mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer, flash_mla_tile_scheduler_metadata, flash_mla_num_splits,
1278- trtllm_gen_jit_warmup, compressed_kv_cache_pool_ptr);
1329+ trtllm_gen_jit_warmup, compressed_kv_cache_pool_ptr, is_cross, cross_kv, relative_attention_bias );
12791330 }
12801331
12811332 if ((num_generations > 0 ) && (attn_input_type != AttentionInputType::ContextOnly))
@@ -1297,7 +1348,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
12971348 sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets, sparse_attn_indices_block_size,
12981349 num_sparse_topk_value, sparse_mla_topk_lens, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter,
12991350 mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer, flash_mla_tile_scheduler_metadata, flash_mla_num_splits,
1300- trtllm_gen_jit_warmup, compressed_kv_cache_pool_ptr);
1351+ trtllm_gen_jit_warmup, compressed_kv_cache_pool_ptr, is_cross, cross_kv, relative_attention_bias );
13011352 }
13021353
13031354 TLLM_LOG_TRACE (" Attention op stops at layer %d" , local_layer_idx);
0 commit comments