diff --git a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h index 8c24d7e5c220..560a153ffa70 100644 --- a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h +++ b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h @@ -1420,7 +1420,7 @@ __global__ void updateKVCacheForCrossAttention(QKVPreprocessingParams cross_kv, bool cross_attention) { auto result = [&]() { @@ -70,7 +70,7 @@ nb::tuple trtllmGenContextPreprocessBinding(torch::Tensor qkv_input, torch::Tens max_past_kv_length, rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type, bmm1_scale, bmm2_scale, attention_chunk_size, fp8_context_fmha, paged_context_fmha, is_mla_enable, multi_processor_count, - total_num_blocks, kv_factor, need_build_kv_cache_metadata); + total_num_blocks, kv_factor, need_build_kv_cache_metadata, cross_kv, cross_attention); }(); return nb::make_tuple(std::get<0>(result), optionalToObject(std::get<1>(result)), @@ -92,7 +92,7 @@ nb::tuple trtllmGenGenerationPreprocessBinding(torch::Tensor qkv_input, torch::T int64_t rotary_embedding_scale_type, double rotary_embedding_scale, int64_t rotary_embedding_max_positions, int64_t position_embedding_type, double bmm1_scale, double bmm2_scale, bool fp8_context_fmha, int64_t predicted_tokens_per_seq, int64_t attention_chunk_size, int64_t multi_processor_count, - int64_t total_num_blocks, int64_t kv_factor, bool need_build_kv_cache_metadata) + int64_t total_num_blocks, int64_t kv_factor, bool need_build_kv_cache_metadata, bool cross_attention) { auto result = [&]() { @@ -106,7 +106,7 @@ nb::tuple trtllmGenGenerationPreprocessBinding(torch::Tensor qkv_input, torch::T rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type, bmm1_scale, bmm2_scale, fp8_context_fmha, predicted_tokens_per_seq, attention_chunk_size, multi_processor_count, total_num_blocks, kv_factor, - need_build_kv_cache_metadata); + need_build_kv_cache_metadata, cross_attention); }(); return nb::make_tuple(std::get<0>(result), optionalToObject(std::get<1>(result)), @@ -273,7 +273,8 @@ void initBindings(nb::module_& m) nb::arg("position_embedding_type"), nb::arg("bmm1_scale"), nb::arg("bmm2_scale"), nb::arg("attention_chunk_size"), nb::arg("fp8_context_fmha"), nb::arg("paged_context_fmha"), nb::arg("is_mla_enable"), nb::arg("multi_processor_count"), nb::arg("total_num_blocks"), nb::arg("kv_factor"), - nb::arg("need_build_kv_cache_metadata") = true, "Fused nanobind context preprocess for trtllm-gen attention."); + nb::arg("need_build_kv_cache_metadata") = true, nb::arg("cross_kv").none() = nb::none(), + nb::arg("cross_attention") = false, "Fused nanobind context preprocess for trtllm-gen attention."); m.def("trtllm_gen_context_postprocess", &torch_ext::trtllmGenContextPostprocess, nb::arg("qkv_input"), nb::arg("workspace"), nb::arg("sequence_lengths"), nb::arg("context_lengths"), @@ -332,6 +333,6 @@ void initBindings(nb::module_& m) nb::arg("position_embedding_type"), nb::arg("bmm1_scale"), nb::arg("bmm2_scale"), nb::arg("fp8_context_fmha"), nb::arg("predicted_tokens_per_seq"), nb::arg("attention_chunk_size"), nb::arg("multi_processor_count"), nb::arg("total_num_blocks"), nb::arg("kv_factor"), nb::arg("need_build_kv_cache_metadata") = true, - "Fused nanobind generation preprocess for trtllm-gen attention."); + nb::arg("cross_attention") = false, "Fused nanobind generation preprocess for trtllm-gen attention."); } } // namespace tensorrt_llm::nanobind::thop diff --git a/cpp/tensorrt_llm/thop/trtllmGenFusedOps.h b/cpp/tensorrt_llm/thop/trtllmGenFusedOps.h index cb4db3535756..2b3dead3fb57 100644 --- a/cpp/tensorrt_llm/thop/trtllmGenFusedOps.h +++ b/cpp/tensorrt_llm/thop/trtllmGenFusedOps.h @@ -42,7 +42,8 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor double rotary_embedding_scale, int64_t rotary_embedding_max_positions, int64_t position_embedding_type, double bmm1_scale, double bmm2_scale, int64_t attention_chunk_size, bool fp8_context_fmha, bool paged_context_fmha, bool is_mla_enable, int64_t multi_processor_count, int64_t total_num_blocks, int64_t kv_factor, - bool need_build_kv_cache_metadata); + bool need_build_kv_cache_metadata, std::optional cross_kv = std::nullopt, + bool cross_attention = false); void trtllmGenContextPostprocess(torch::Tensor qkv_input, torch::Tensor workspace, torch::Tensor sequence_lengths, torch::Tensor context_lengths, std::optional kv_cache_block_offsets, @@ -72,7 +73,7 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, int64_t rotary_embedding_scale_type, double rotary_embedding_scale, int64_t rotary_embedding_max_positions, int64_t position_embedding_type, double bmm1_scale, double bmm2_scale, bool fp8_context_fmha, int64_t predicted_tokens_per_seq, int64_t attention_chunk_size, int64_t multi_processor_count, - int64_t total_num_blocks, int64_t kv_factor, bool need_build_kv_cache_metadata); + int64_t total_num_blocks, int64_t kv_factor, bool need_build_kv_cache_metadata, bool cross_attention = false); } // namespace torch_ext diff --git a/cpp/tensorrt_llm/thop/trtllmGenQKVProcessOp.cpp b/cpp/tensorrt_llm/thop/trtllmGenQKVProcessOp.cpp index 9302e2a4978f..64931f687680 100644 --- a/cpp/tensorrt_llm/thop/trtllmGenQKVProcessOp.cpp +++ b/cpp/tensorrt_llm/thop/trtllmGenQKVProcessOp.cpp @@ -266,17 +266,22 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor int64_t const position_embedding_type, double const bmm1_scale, double const bmm2_scale, int64_t const attention_chunk_size, bool const fp8_context_fmha, bool const paged_context_fmha, bool const is_mla_enable, int64_t const multi_processor_count, int64_t const total_num_blocks, - int64_t const kv_factor, bool const need_build_kv_cache_metadata) + int64_t const kv_factor, bool const need_build_kv_cache_metadata, std::optional cross_kv, + bool const cross_attention) { (void) bmm2_scale; TORCH_CHECK(host_kv_cache_pool_pointers.has_value(), "host_kv_cache_pool_pointers is required."); TORCH_CHECK(host_kv_cache_pool_mapping.has_value(), "host_kv_cache_pool_mapping is required."); TORCH_CHECK(kv_cache_block_offsets.has_value(), "kv_cache_block_offsets is required."); + TORCH_CHECK(!cross_attention || !is_mla_enable, "trtllm-gen cross attention does not support MLA."); - bool const separateQKvOutput = paged_context_fmha || fp8_context_fmha; + bool const separateQKvOutput = paged_context_fmha || fp8_context_fmha || cross_attention; auto const qkvScalarType = qkv_input.scalar_type(); auto const qkvElementSize = static_cast(qkv_input.element_size()); auto const quantMode = tensorrt_llm::common::QuantMode(static_cast(kv_cache_quant_mode)); + int64_t const effectiveMaxAttentionWindowSize = cross_attention ? max_past_kv_length : max_attention_window_size; + int64_t const effectiveCyclicAttentionWindowSize + = cross_attention ? max_past_kv_length : cyclic_attention_window_size; auto const views = [&] { auto const layout = TrtllmAttentionWorkspaceManager::buildContextLayout( @@ -307,8 +312,8 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor decoderInfoParams.fmhaBmm2Scale = ptrs.fmhaBmm2ScalePtr; decoderInfoParams.batchSize = static_cast(batch_size); decoderInfoParams.maxQSeqLength = static_cast(input_seq_length); - decoderInfoParams.maxEncoderQSeqLength = 0; - decoderInfoParams.attentionWindowSize = static_cast(cyclic_attention_window_size); + decoderInfoParams.maxEncoderQSeqLength = cross_attention ? static_cast(max_past_kv_length) : 0; + decoderInfoParams.attentionWindowSize = static_cast(effectiveCyclicAttentionWindowSize); decoderInfoParams.numTokens = static_cast(num_tokens); decoderInfoParams.removePadding = true; decoderInfoParams.attentionMaskType = static_cast(mask_type); @@ -333,12 +338,13 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor { return buildPagedKvCacheBuffers(kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, quantMode, layer_idx, batch_size, tokens_per_block, num_kv_heads, head_size, - cyclic_attention_window_size, max_attention_window_size, 0, 0, is_mla_enable, qkvElementSize); + effectiveCyclicAttentionWindowSize, effectiveMaxAttentionWindowSize, 0, 0, is_mla_enable, + qkvElementSize); }(); QKVPreprocessingParams qkvParams{}; qkvParams.qkv_input = qkv_input.data_ptr(); - qkvParams.cross_kv_input = nullptr; + qkvParams.cross_kv_input = optPtr(cross_kv); qkvParams.quantized_qkv_output = nullptr; qkvParams.q_output = ptrs.qBufPtr; qkvParams.kv_cache_buffer = kvArrays.kvCacheBuffer; @@ -353,8 +359,9 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor qkvParams.logn_scaling = nullptr; qkvParams.tokens_info = ptrs.tokensInfoPtr; qkvParams.seq_lens = static_cast(context_lengths.data_ptr()); - qkvParams.cache_seq_lens = static_cast(sequence_lengths.data_ptr()); - qkvParams.encoder_seq_lens = nullptr; + qkvParams.cache_seq_lens = cross_attention ? static_cast(context_lengths.data_ptr()) + : static_cast(sequence_lengths.data_ptr()); + qkvParams.encoder_seq_lens = cross_attention ? static_cast(sequence_lengths.data_ptr()) : nullptr; qkvParams.cu_seq_lens = ptrs.cuQSeqlensPtr; qkvParams.cu_kv_seq_lens = ptrs.cuKvSeqlensPtr; qkvParams.sparse_kv_offsets = nullptr; @@ -367,11 +374,11 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor qkvParams.batch_size = static_cast(batch_size); qkvParams.max_input_seq_len = static_cast(input_seq_length); qkvParams.max_kv_seq_len = static_cast(max_past_kv_length); - qkvParams.cyclic_kv_cache_len = static_cast(cyclic_attention_window_size); + qkvParams.cyclic_kv_cache_len = static_cast(effectiveCyclicAttentionWindowSize); qkvParams.token_num = static_cast(num_tokens); qkvParams.remove_padding = true; qkvParams.is_last_chunk = attention_chunk_size == 0 || input_seq_length == max_past_kv_length; - qkvParams.cross_attention = false; + qkvParams.cross_attention = cross_attention; qkvParams.head_num = static_cast(num_heads); qkvParams.kv_head_num = static_cast(num_kv_heads); qkvParams.qheads_per_kv_head = static_cast(num_heads / num_kv_heads); @@ -438,7 +445,9 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor // FlashInfer paged context launches trtllm-gen with multi-CTA-KV mode disabled, so it does not // consume the counter slab reserved at the head of the workspace. - auto const windowLeft = computeWindowLeft(cyclic_attention_window_size, max_past_kv_length, attention_chunk_size); + auto const windowLeft = cross_attention + ? int64_t{-1} + : computeWindowLeft(cyclic_attention_window_size, max_past_kv_length, attention_chunk_size); return {qProcessed, kvPool, blockTables, kvScalePool, views.fmhaBmm1Scale, views.fmhaBmm2Scale, views.trtllmGenWorkspace, views.cuQSeqlens, views.cuKvSeqlens, input_seq_length, max_past_kv_length, windowLeft}; @@ -577,7 +586,7 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, int64_t const rotary_embedding_max_positions, int64_t const position_embedding_type, double const bmm1_scale, double const bmm2_scale, bool const fp8_context_fmha, int64_t const predicted_tokens_per_seq, int64_t const attention_chunk_size, int64_t const multi_processor_count, int64_t const total_num_blocks, - int64_t const kv_factor, bool const need_build_kv_cache_metadata) + int64_t const kv_factor, bool const need_build_kv_cache_metadata, bool const cross_attention) { TORCH_CHECK(host_kv_cache_pool_pointers.has_value(), "host_kv_cache_pool_pointers is required."); TORCH_CHECK(host_kv_cache_pool_mapping.has_value(), "host_kv_cache_pool_mapping is required."); @@ -585,9 +594,14 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, (void) bmm2_scale; bool const isMultiTokenGen = spec_decoding_generation_lengths.has_value() && predicted_tokens_per_seq > 1; + TORCH_CHECK( + !cross_attention || !isMultiTokenGen, "trtllm-gen cross attention does not support multi-token generation."); auto const qkvScalarType = qkv_input.scalar_type(); auto const qkvElementSize = static_cast(qkv_input.element_size()); auto const quantMode = tensorrt_llm::common::QuantMode(static_cast(kv_cache_quant_mode)); + int64_t const effectiveMaxAttentionWindowSize = cross_attention ? max_past_kv_length : max_attention_window_size; + int64_t const effectiveCyclicAttentionWindowSize + = cross_attention ? max_past_kv_length : cyclic_attention_window_size; auto const views = [&] { auto const layout = TrtllmAttentionWorkspaceManager::buildGenerationLayout( @@ -617,7 +631,7 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, decoderInfoParams.fmhaBmm2Scale = nullptr; decoderInfoParams.batchSize = static_cast(batch_beam); decoderInfoParams.maxQSeqLength = static_cast(input_seq_length); - decoderInfoParams.maxEncoderQSeqLength = 0; + decoderInfoParams.maxEncoderQSeqLength = cross_attention ? static_cast(max_past_kv_length) : 0; decoderInfoParams.attentionWindowSize = 0; decoderInfoParams.sinkTokenLength = 0; decoderInfoParams.numTokens = static_cast(num_tokens); @@ -655,7 +669,8 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, { return buildPagedKvCacheBuffers(kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, quantMode, layer_idx, batch_beam, tokens_per_block, num_kv_heads, head_size, - cyclic_attention_window_size, max_attention_window_size, 1, seq_offset, false, qkvElementSize); + effectiveCyclicAttentionWindowSize, effectiveMaxAttentionWindowSize, 1, seq_offset, false, + qkvElementSize); }(); QKVPreprocessingParams qkvParams{}; @@ -676,7 +691,7 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, qkvParams.tokens_info = isMultiTokenGen ? views.tokensInfoPtr : nullptr; qkvParams.seq_lens = isMultiTokenGen ? optPtr(spec_decoding_generation_lengths) : nullptr; qkvParams.cache_seq_lens = static_cast(sequence_lengths.data_ptr()); - qkvParams.encoder_seq_lens = nullptr; + qkvParams.encoder_seq_lens = cross_attention ? static_cast(sequence_lengths.data_ptr()) : nullptr; qkvParams.cu_seq_lens = buildDecoderInfoNeeded ? views.cuSeqlensPtr : nullptr; qkvParams.cu_kv_seq_lens = buildDecoderInfoNeeded ? views.cuKvSeqlensPtr : nullptr; qkvParams.sparse_kv_offsets = nullptr; @@ -690,11 +705,11 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, qkvParams.batch_size = static_cast(batch_beam); qkvParams.max_input_seq_len = static_cast(input_seq_length); qkvParams.max_kv_seq_len = static_cast(max_past_kv_length); - qkvParams.cyclic_kv_cache_len = static_cast(cyclic_attention_window_size); + qkvParams.cyclic_kv_cache_len = static_cast(effectiveCyclicAttentionWindowSize); qkvParams.token_num = static_cast(num_tokens); qkvParams.remove_padding = true; qkvParams.is_last_chunk = false; - qkvParams.cross_attention = false; + qkvParams.cross_attention = cross_attention; qkvParams.head_num = static_cast(num_heads); qkvParams.kv_head_num = static_cast(num_kv_heads); qkvParams.qheads_per_kv_head = static_cast(num_heads / num_kv_heads); @@ -751,7 +766,9 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, auto qProcessed = views.qBuf.view({num_tokens, num_heads, head_size}); - auto const windowLeft = computeWindowLeft(cyclic_attention_window_size, max_past_kv_length, attention_chunk_size); + auto const windowLeft = cross_attention + ? int64_t{-1} + : computeWindowLeft(cyclic_attention_window_size, max_past_kv_length, attention_chunk_size); return {qProcessed, kvPool, blockTables, kvScalePool, views.bmm1Scale, views.bmm2Scale, views.trtllmGenWorkspace, cuSeqlens, input_seq_length, max_past_kv_length, windowLeft, isMultiTokenGen}; } diff --git a/tensorrt_llm/_torch/attention_backend/fmha/fallback.py b/tensorrt_llm/_torch/attention_backend/fmha/fallback.py index 73c5778379d8..c13a3298dafa 100644 --- a/tensorrt_llm/_torch/attention_backend/fmha/fallback.py +++ b/tensorrt_llm/_torch/attention_backend/fmha/fallback.py @@ -94,7 +94,7 @@ def forward( block_ids_per_seq=metadata.block_ids_per_seq, tokens_per_block=metadata.tokens_per_block, max_num_requests=metadata.max_num_requests, - beam_width=metadata.effective_beam_width, + beam_width=metadata.beam_width, use_paged_context_fmha=metadata.use_paged_context_fmha, helix_position_offsets=metadata.helix_position_offsets, helix_is_inactive_rank=metadata.helix_is_inactive_rank, diff --git a/tensorrt_llm/_torch/attention_backend/fmha/flashinfer_trtllm_gen.py b/tensorrt_llm/_torch/attention_backend/fmha/flashinfer_trtllm_gen.py index d880335a8b75..803a8d0172a5 100644 --- a/tensorrt_llm/_torch/attention_backend/fmha/flashinfer_trtllm_gen.py +++ b/tensorrt_llm/_torch/attention_backend/fmha/flashinfer_trtllm_gen.py @@ -184,6 +184,7 @@ def _trtllm_gen_batch_context_with_kv_cache( enable_pdl: bool, kv_scale_pool: Optional[torch.Tensor], uses_shared_paged_kv_idx: bool, + causal: bool, ) -> None: bmm1_scale_arg = ( _get_bmm1_scale_log2(bmm1_scale) if isinstance(bmm1_scale, torch.Tensor) else bmm1_scale @@ -219,7 +220,7 @@ def _trtllm_gen_batch_context_with_kv_cache( kv_scale_pool, # value_block_scales None, # skip_softmax_threshold_scale_factor uses_shared_paged_kv_idx, - True, # causal + causal, # causal None, # lse 0, # lse_stride_tokens 0, # lse_stride_heads @@ -595,8 +596,6 @@ def _is_supported_with_reason( sparse_params = attn.sparse_params has_skip_softmax = getattr(sparse_params, "algorithm", None) == "skip_softmax" has_sparse_attention = sparse_params is not None and not has_skip_softmax - if meta.is_cross: - return False, "trtllm-gen does not support cross attention." if ( fwd.sage_attn_num_elts_per_blk_q > 0 or fwd.sage_attn_num_elts_per_blk_k > 0 @@ -616,6 +615,8 @@ def _is_supported_with_reason( return False, "trtllm-gen does not support sparse attention." if has_skip_softmax: return False, "trtllm-gen does not support skip-softmax attention." + if fwd.relative_attention_bias is not None: + return False, "Relative attention bias is not supported by trtllm-gen backend." if meta.use_spec_decoding and meta.is_spec_dec_tree: return ( False, @@ -648,6 +649,25 @@ def _is_supported_with_reason( kv_cache_dtype = self._get_kv_cache_dtype(meta) if kv_cache_dtype is None: kv_cache_dtype = torch_dtype_to_binding(q_dtype) + if meta.is_cross: + if kv_cache_dtype == DataType.NVFP4: + return ( + False, + "Cross attention with NVFP4 KV cache is not supported by trtllm-gen backend.", + ) + if is_mla_enable: + return False, "Cross attention with MLA is not supported by trtllm-gen backend." + if meta.is_spec_decoding_enabled or meta.use_spec_decoding: + return ( + False, + "Cross attention with speculative decoding is not supported by " + "trtllm-gen backend.", + ) + if fwd.update_kv_cache and fwd.cross_kv is None: + return ( + False, + "trtllm-gen cross attention requires cross_kv when update_kv_cache=True.", + ) is_fp8_out = output.dtype == torch.float8_e4m3fn is_fp4_out = output.dtype == torch.uint8 @@ -685,7 +705,7 @@ def _is_supported_with_reason( ) if has_generation_phase: - if meta.beam_width != 1: + if meta.beam_width != 1 and not meta.is_cross: return ( False, f"[Generation] Beam search (beam_width={meta.beam_width}) " @@ -862,7 +882,6 @@ def run_context( rope_params = attn.rope_params bmm1_scale_static = self._get_bmm1_scale(attn) attention_chunk_size = self._get_attention_chunk_size(attn) - ( q_processed, kv_pool, @@ -919,6 +938,8 @@ def run_context( params.total_num_blocks, # total_num_blocks params.kv_factor, # kv_factor True, # need_build_kv_cache_metadata + fwd.cross_kv, # cross_kv + params.is_cross, # is_cross ) has_fp4_kv = QuantMode(attn.quant_mode).has_fp4_kv_cache() @@ -935,7 +956,11 @@ def run_context( bmm1_scale if params.fp8_context_fmha and bmm1_scale is not None else bmm1_scale_static ) ctx_bmm2_scale = bmm2_scale if params.fp8_context_fmha and bmm2_scale is not None else 1.0 - + causal = ( + False + if params.is_cross + else AttentionMaskType(fwd.mask_type) == AttentionMaskType.causal + ) _trtllm_gen_batch_context_with_kv_cache( q_processed, # query kv_pool, # kv_pool @@ -955,8 +980,12 @@ def run_context( self._enable_pdl, # enable_pdl kv_scale_pool, # kv_scale_pool self.USE_SHARED_PAGED_KV_IDX, # uses_shared_paged_kv_idx + causal, # causal ) + if params.is_cross: + return + thop.trtllm_gen_context_postprocess( params.qkv_input, # qkv_input params.workspace, # workspace @@ -1063,6 +1092,7 @@ def run_generation( params.total_num_blocks, # total_num_blocks params.kv_factor, # kv_factor True, # need_build_kv_cache_metadata + params.is_cross, # is_cross ) # FIXME: Flashinfer trtllm-gen API doesn't support a separate diff --git a/tensorrt_llm/_torch/attention_backend/fmha/phased.py b/tensorrt_llm/_torch/attention_backend/fmha/phased.py index aced12909cc9..3752403ecc89 100644 --- a/tensorrt_llm/_torch/attention_backend/fmha/phased.py +++ b/tensorrt_llm/_torch/attention_backend/fmha/phased.py @@ -56,6 +56,7 @@ class FmhaParams: num_requests: int = 0 spec_decoding_generation_lengths: Optional[torch.Tensor] = None spec_decoding_position_offsets: Optional[torch.Tensor] = None + is_cross: bool = False class PhasedFmha(Fmha): @@ -184,6 +185,7 @@ def forward( fp8_context_fmha=fp8_context_fmha, kv_factor=self.kv_factor, total_num_blocks=self._get_total_num_blocks(metadata), + is_cross=metadata.is_cross, ) sequence_length = metadata.kv_lens_cuda_runtime diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 93e1a2bfe4c4..f79f5da4287d 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -185,16 +185,6 @@ def max_context_length(self) -> int: """ return min(self.max_seq_len, self.max_num_tokens) - @property - def effective_beam_width(self) -> int: - """Beam width visible to the kernel. - - Cross-attention metadata is already expanded to one row per decoder - beam, and all beams read the same encoder K/V cache. Keep kernel beam - indirection disabled for that path. - """ - return 1 if self.is_cross else self.beam_width - @property def max_seq_len(self) -> int: """ diff --git a/tensorrt_llm/_torch/models/modeling_t5.py b/tensorrt_llm/_torch/models/modeling_t5.py index 5dc5a5c0e389..caddd80c6694 100644 --- a/tensorrt_llm/_torch/models/modeling_t5.py +++ b/tensorrt_llm/_torch/models/modeling_t5.py @@ -38,6 +38,7 @@ from torch import nn from transformers import T5Config +from tensorrt_llm._utils import get_sm_version from tensorrt_llm.functional import PositionEmbeddingType from ..attention_backend import AttentionMetadata @@ -132,6 +133,40 @@ def _clamp_fp16_infs(hidden_states: torch.Tensor) -> torch.Tensor: return torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) +class T5LayerNorm(RMSNorm): + """T5 RMSNorm with HF-compatible architecture-specific precision. + + On Hopper, HF uses Apex FusedRMSNorm when Apex is available. The generic + TRT-LLM RMSNorm path matches that behavior for ByT5 BF16 decoding. On + Blackwell, the generic fused path drifts from the HF reference, so use the + explicit T5 computation. + """ + + def __init__( + self, + *, + hidden_size: int, + eps: float, + dtype: Optional[torch.dtype] = None, + ): + super().__init__(hidden_size=hidden_size, eps=eps, dtype=dtype) + self._use_hopper_rms_norm: Optional[bool] = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self._use_hopper_rms_norm is None and hidden_states.is_cuda: + sm_version = get_sm_version() + self._use_hopper_rms_norm = 90 <= sm_version < 100 + + if self._use_hopper_rms_norm and hidden_states.dtype in (torch.float16, torch.bfloat16): + return super().forward(hidden_states) + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + if self.weight.dtype in (torch.float16, torch.bfloat16): + hidden_states = hidden_states.to(self.weight.dtype) + return self.weight * hidden_states + + def _t5_encoder_num_layers(config: T5Config) -> int: return config.num_layers @@ -224,17 +259,17 @@ class T5Attention(Attention): When ``position_bias`` is provided (from a ``T5RelativePositionBias`` module living on layer 0), it is added to the QK^T scores before - softmax. Without a KV cache the module computes SDPA directly - (bypassing the VANILLA backend's ``flash_attn_varlen_func`` which - cannot accept an additive bias). With a KV cache it passes the learned - relative-attention table to the TRTLLM backend. + softmax. T5 self-attention requires the TRTLLM backend so the relative + bias can be routed through the attention backend. Without a KV cache, this + module passes a precomputed dense relative bias as explicit attention bias. + With a KV cache, decoder attention passes the learned relative-attention + table to the TRTLLM backend. """ def __init__( self, model_config: ModelConfig[T5Config], layer_idx: Optional[int] = None, - is_decoder: bool = True, ): config = model_config.pretrained_config num_heads = config.num_heads @@ -254,37 +289,11 @@ def __init__( q_scaling=_t5_q_scaling(config), head_dim=_t5_head_dim(config), ) - self._is_decoder = is_decoder - self._head_dim = _t5_head_dim(config) def apply_rope(self, q, k, v, position_ids): """T5 has no RoPE — pass through unchanged.""" return q, k, v - def _split_qkv( - self, - hidden_states: torch.Tensor, - num_tokens: int, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - qkv = self.qkv_proj(hidden_states) - q_size = self.num_heads * self._head_dim - kv_size = self.num_key_value_heads * self._head_dim - q, k, v = qkv[:num_tokens].split([q_size, kv_size, kv_size], dim=-1) - - q = q.view(-1, self.num_heads, self._head_dim) - k = k.view(-1, self.num_key_value_heads, self._head_dim) - v = v.view(-1, self.num_key_value_heads, self._head_dim) - return q, k, v - - @staticmethod - def _slice_position_bias( - position_bias: torch.Tensor, - query_length: int, - key_length: int, - ) -> torch.Tensor: - query_start = key_length - query_length - return position_bias[:, :, query_start:key_length, :key_length].squeeze(0) - def _local_position_bias( self, position_bias: torch.Tensor, @@ -336,70 +345,43 @@ def forward( relative_attention_max_distance: int = 0, **kwargs, ) -> torch.Tensor: - if position_bias is None and relative_attention_bias is None: - return super().forward( - position_ids=position_ids, - hidden_states=hidden_states, - attn_metadata=attn_metadata, - attention_mask=attention_mask, - **kwargs, + if self.attn_backend != "TRTLLM": + raise ValueError( + "T5 self-attention with relative position bias requires " + f"attn_backend='TRTLLM'. Current backend: {self.attn_backend}." ) - assert attn_metadata is not None - assert hidden_states is not None - if attn_metadata.kv_cache_manager is not None: + forward_kwargs = dict(kwargs) + if attn_metadata is not None and attn_metadata.kv_cache_manager is not None: if relative_attention_bias is None: raise ValueError("Cached T5 attention requires a relative attention bias table.") + assert hidden_states is not None relative_attention_bias = self._local_relative_attention_bias( relative_attention_bias, hidden_states, ) - return super().forward( - position_ids=position_ids, - hidden_states=hidden_states, - attn_metadata=attn_metadata, - attention_mask=attention_mask, - relative_attention_bias=relative_attention_bias, - relative_attention_max_distance=relative_attention_max_distance, - **kwargs, + elif position_bias is not None: + assert hidden_states is not None + position_bias = self._local_position_bias(position_bias, hidden_states) + relative_attention_bias = position_bias.squeeze(0).contiguous() + relative_attention_max_distance = 0 + forward_kwargs["attention_window_size"] = relative_attention_bias.shape[-1] + elif relative_attention_bias is not None: + assert hidden_states is not None + relative_attention_bias = self._local_relative_attention_bias( + relative_attention_bias, + hidden_states, ) - # Manual SDPA with additive position bias (no-KV-cache path). - assert position_bias is not None - position_bias = self._local_position_bias(position_bias, hidden_states) - num_tokens = attn_metadata.num_tokens - q, k, v = self._split_qkv(hidden_states, num_tokens) - - # Per-request SDPA with position bias applied to each request's scores. - seq_lens = attn_metadata.seq_lens - offset = 0 - outputs = [] - for seq_len in seq_lens: - sl = int(seq_len) - q_s = q[offset : offset + sl].transpose(0, 1) # (H, S, D) - k_s = k[offset : offset + sl].transpose(0, 1) - v_s = v[offset : offset + sl].transpose(0, 1) - - scores = torch.matmul(q_s, k_s.transpose(-2, -1)) - # position_bias: (1, H, qlen, klen) — slice to this request's lengths - scores = scores + self._slice_position_bias(position_bias, sl, sl) - - if self._is_decoder: - causal_mask = torch.triu( - torch.full((sl, sl), float("-inf"), device=scores.device, dtype=scores.dtype), - diagonal=1, - ) - scores = scores + causal_mask - - attn_weights = F.softmax(scores.float(), dim=-1).to(q.dtype) - out = torch.matmul(attn_weights, v_s) # (H, S, D) - outputs.append(out.transpose(0, 1)) # (S, H, D) - offset += sl - - attn_output = torch.cat(outputs, dim=0) # (T, H, D) - attn_output = attn_output.reshape(num_tokens, -1) - attn_output = self.o_proj(attn_output) - return attn_output + return super().forward( + position_ids=position_ids, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + attention_mask=attention_mask, + relative_attention_bias=relative_attention_bias, + relative_attention_max_distance=relative_attention_max_distance, + **forward_kwargs, + ) class T5CrossAttention(CrossAttention): @@ -450,14 +432,14 @@ def __init__( act_fn = _t5_gated_act_fn(config) if is_gated else _t5_dense_act_fn(config) - self.self_attn = T5Attention(model_config, layer_idx=layer_idx, is_decoder=False) + self.self_attn = T5Attention(model_config, layer_idx=layer_idx) - self.input_layernorm = RMSNorm( + self.input_layernorm = T5LayerNorm( hidden_size=hidden_size, eps=config.layer_norm_epsilon, dtype=config.torch_dtype, ) - self.post_attention_layernorm = RMSNorm( + self.post_attention_layernorm = T5LayerNorm( hidden_size=hidden_size, eps=config.layer_norm_epsilon, dtype=config.torch_dtype, @@ -536,21 +518,21 @@ def __init__( act_fn = _t5_gated_act_fn(config) if is_gated else _t5_dense_act_fn(config) - self.self_attn = T5Attention(model_config, layer_idx=layer_idx, is_decoder=True) + self.self_attn = T5Attention(model_config, layer_idx=layer_idx) self.cross_attn = T5CrossAttention(model_config, layer_idx=layer_idx) - self.input_layernorm = RMSNorm( + self.input_layernorm = T5LayerNorm( hidden_size=hidden_size, eps=config.layer_norm_epsilon, dtype=config.torch_dtype, ) - self.post_attention_layernorm = RMSNorm( + self.post_attention_layernorm = T5LayerNorm( hidden_size=hidden_size, eps=config.layer_norm_epsilon, dtype=config.torch_dtype, ) - self.cross_attn_layernorm = RMSNorm( + self.cross_attn_layernorm = T5LayerNorm( hidden_size=hidden_size, eps=config.layer_norm_epsilon, dtype=config.torch_dtype, @@ -652,7 +634,7 @@ def __init__(self, model_config: ModelConfig[T5Config]): self.layers = nn.ModuleList( [T5EncoderLayer(model_config, layer_idx=i) for i in range(num_layers)] ) - self.final_layernorm = RMSNorm( + self.final_layernorm = T5LayerNorm( hidden_size=config.d_model, eps=config.layer_norm_epsilon, dtype=config.torch_dtype, @@ -664,7 +646,12 @@ def forward( attn_metadata: AttentionMetadata, position_ids: Optional[torch.IntTensor] = None, ) -> torch.Tensor: - seq_len = hidden_states.shape[0] + max_context_q_len_override = getattr(attn_metadata, "max_context_q_len_override", None) + if max_context_q_len_override is not None: + seq_len = int(max_context_q_len_override) + else: + seq_lens = attn_metadata.seq_lens + seq_len = hidden_states.shape[0] if seq_lens is None else int(seq_lens.max().item()) position_bias = self.relative_position_bias(seq_len, seq_len, hidden_states.device) for layer in self.layers: @@ -702,7 +689,7 @@ def __init__(self, model_config: ModelConfig[T5Config]): self.layers = nn.ModuleList( [T5DecoderLayer(model_config, layer_idx=i) for i in range(num_layers)] ) - self.final_layernorm = RMSNorm( + self.final_layernorm = T5LayerNorm( hidden_size=config.d_model, eps=config.layer_norm_epsilon, dtype=config.torch_dtype, diff --git a/tensorrt_llm/_torch/modules/cross_attention.py b/tensorrt_llm/_torch/modules/cross_attention.py index ff0f440e245c..dece6d3fb400 100644 --- a/tensorrt_llm/_torch/modules/cross_attention.py +++ b/tensorrt_llm/_torch/modules/cross_attention.py @@ -41,14 +41,6 @@ class CrossAttention(nn.Module): subsequent generation steps, K/V are read from the cache without re-projection. - The cross-attention sub-layer honors ``ModelConfig.attn_backend``: when - set to ``"TRTLLM"`` it dispatches through the production C++ attention op - on every supported architecture. Cross-attention currently uses the THOP - attention path because the ``trtllm_gen`` backend API does not yet carry - encoder K/V tensors. - - Encoder and decoder self-attention are unaffected and continue to use - whatever backend ``ModelConfig.attn_backend`` selects. """ def __init__( diff --git a/tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py b/tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py index 3ae0e3652a3d..e35486872abc 100644 --- a/tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py +++ b/tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py @@ -63,6 +63,7 @@ def _test_case( num_return_sequences: int, exact_match: bool, feature_id: str, + kv_cache_dtype: str = "auto", ): expected_output_token_ids = [_EXPECTED_GREEDY_OUTPUT_TOKEN_IDS] if num_beams == 1 else None assert not exact_match or expected_output_token_ids is not None @@ -75,6 +76,7 @@ def _test_case( num_beams, num_return_sequences, exact_match, + kv_cache_dtype, id=f"{feature_id}-{_MODEL_NAME}", ) @@ -245,12 +247,7 @@ def _assert_expected_generation( assert token_ids_by_output == expected_token_ids_by_output -@pytest.mark.parametrize( - "expected_output_token_ids_by_output,torch_dtype,use_kv_cache_manager_v2," - "enable_cuda_graph,num_beams,num_return_sequences,exact_match", - _TEST_CASES, -) -def test_bart_pytorch_generate_encoder_decoder_end_to_end( +def _run_bart_pytorch_generate_encoder_decoder( monkeypatch: pytest.MonkeyPatch, expected_output_token_ids_by_output: list[list[int]] | None, torch_dtype: str, @@ -259,6 +256,7 @@ def test_bart_pytorch_generate_encoder_decoder_end_to_end( num_beams: int, num_return_sequences: int, exact_match: bool, + kv_cache_dtype: str = "auto", ) -> None: monkeypatch.setenv("TLLM_WORKER_USE_SINGLE_PROCESS", "1") monkeypatch.setenv("TRTLLM_SKIP_KV_CACHE_ESTIMATION", "1") @@ -267,7 +265,8 @@ def test_bart_pytorch_generate_encoder_decoder_end_to_end( tokenizer = AutoTokenizer.from_pretrained(model_path) case_id = ( f"model={_MODEL_NAME}, dtype={torch_dtype}, kv_v2={use_kv_cache_manager_v2}, " - f"cuda_graph={enable_cuda_graph}, beams={num_beams}, returns={num_return_sequences}" + f"cuda_graph={enable_cuda_graph}, beams={num_beams}, returns={num_return_sequences}, " + f"kv_dtype={kv_cache_dtype}" ) sampling_params = _sampling_params(num_beams, num_return_sequences) @@ -285,6 +284,7 @@ def test_bart_pytorch_generate_encoder_decoder_end_to_end( free_gpu_memory_fraction=_FREE_GPU_MEMORY_FRACTION, cross_kv_cache_fraction=_CROSS_KV_CACHE_FRACTION, use_kv_cache_manager_v2=use_kv_cache_manager_v2, + dtype=kv_cache_dtype, ), max_batch_size=1, max_beam_width=num_beams, @@ -312,6 +312,35 @@ def test_bart_pytorch_generate_encoder_decoder_end_to_end( ) +@pytest.mark.parametrize( + "expected_output_token_ids_by_output,torch_dtype,use_kv_cache_manager_v2," + "enable_cuda_graph,num_beams,num_return_sequences,exact_match,kv_cache_dtype", + _TEST_CASES, +) +def test_bart_pytorch_generate_encoder_decoder_end_to_end( + monkeypatch: pytest.MonkeyPatch, + expected_output_token_ids_by_output: list[list[int]] | None, + torch_dtype: str, + use_kv_cache_manager_v2: bool, + enable_cuda_graph: bool, + num_beams: int, + num_return_sequences: int, + exact_match: bool, + kv_cache_dtype: str, +) -> None: + _run_bart_pytorch_generate_encoder_decoder( + monkeypatch, + expected_output_token_ids_by_output, + torch_dtype, + use_kv_cache_manager_v2, + enable_cuda_graph, + num_beams, + num_return_sequences, + exact_match, + kv_cache_dtype, + ) + + @pytest.mark.parametrize( "torch_dtype,use_kv_cache_manager_v2,num_beams,num_return_sequences", _MIXED_BATCH_TEST_CASES, diff --git a/tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py b/tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py index c42bb0a40057..b9a68bcac8be 100644 --- a/tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py +++ b/tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py @@ -79,8 +79,8 @@ ("t5-small", 2): [ _HF_BEAM_OUTPUT_TOKEN_IDS_BY_MODEL_AND_BEAMS[("t5-small", 2)], [ - [644, 4675, 229, 219], [644, 4675, 4186, 219], + [644, 4675, 229, 219], ], ], ("flan-t5-small", 2): [ @@ -106,7 +106,7 @@ def _test_case( num_return_sequences: int, exact_match: bool, feature_id: str, - marks=(), + marks=None, ): if num_beams == 1: expected_output_token_ids = ( @@ -127,6 +127,10 @@ def _test_case( assert not exact_match or expected_output_token_ids is not None + param_kwargs = {"id": f"{feature_id}-{model_name}"} + if marks is not None: + param_kwargs["marks"] = marks + return pytest.param( model_name, expected_output_token_ids, @@ -136,8 +140,7 @@ def _test_case( num_beams, num_return_sequences, exact_match, - id=f"{feature_id}-{model_name}", - marks=marks, + **param_kwargs, ) @@ -371,7 +374,6 @@ def _mixed_batch_test_case( num_return_sequences: int, exact_match: bool, feature_id: str, - marks=(), ): expected_output_token_ids_by_request = ( _MIXED_ENCODER_OUTPUT_TOKEN_IDS_BY_MODEL_AND_BEAMS.get((model_name, num_beams)) @@ -389,7 +391,6 @@ def _mixed_batch_test_case( num_return_sequences, exact_match, id=f"{feature_id}-{model_name}", - marks=marks, ) @@ -511,25 +512,21 @@ def _assert_expected_generation( assert all(decoded_text_by_output) if expected_token_ids_by_output is None: assert all(expected_text_fragment in text for text in decoded_text_by_output) + elif exact_match: + assert token_ids_by_output == expected_token_ids_by_output + elif len(expected_token_ids_by_output) > 1: + assert {tuple(token_ids) for token_ids in token_ids_by_output} == { + tuple(token_ids) for token_ids in expected_token_ids_by_output + } else: assert token_ids_by_output[0] == expected_token_ids_by_output[0] if len(token_ids_by_output) > 1: assert len({tuple(token_ids) for token_ids in token_ids_by_output}) == len( token_ids_by_output ) - if not exact_match: - return - - assert expected_token_ids_by_output is not None - assert token_ids_by_output == expected_token_ids_by_output -@pytest.mark.parametrize( - "model_name,expected_output_token_ids_by_output,torch_dtype,use_kv_cache_manager_v2," - "enable_cuda_graph,num_beams,num_return_sequences,exact_match", - _TEST_CASES, -) -def test_t5_pytorch_generate_encoder_decoder_end_to_end( +def _run_t5_pytorch_generate_encoder_decoder( monkeypatch: pytest.MonkeyPatch, model_name: str, expected_output_token_ids_by_output: list[list[int]] | None, @@ -592,6 +589,35 @@ def test_t5_pytorch_generate_encoder_decoder_end_to_end( ) +@pytest.mark.parametrize( + "model_name,expected_output_token_ids_by_output,torch_dtype,use_kv_cache_manager_v2," + "enable_cuda_graph,num_beams,num_return_sequences,exact_match", + _TEST_CASES, +) +def test_t5_pytorch_generate_encoder_decoder_end_to_end( + monkeypatch: pytest.MonkeyPatch, + model_name: str, + expected_output_token_ids_by_output: list[list[int]] | None, + torch_dtype: str, + use_kv_cache_manager_v2: bool, + enable_cuda_graph: bool, + num_beams: int, + num_return_sequences: int, + exact_match: bool, +) -> None: + _run_t5_pytorch_generate_encoder_decoder( + monkeypatch, + model_name, + expected_output_token_ids_by_output, + torch_dtype, + use_kv_cache_manager_v2, + enable_cuda_graph, + num_beams, + num_return_sequences, + exact_match, + ) + + @pytest.mark.parametrize( "model_name,expected_output_token_ids_by_request,torch_dtype,use_kv_cache_manager_v2," "num_beams,num_return_sequences,exact_match",