From 37e550d11a214fb861ea77503f0ef7d0b51abfb0 Mon Sep 17 00:00:00 2001 From: Guiju Zhang <7135567+cascade812@users.noreply.github.com> Date: Tue, 9 Jun 2026 14:09:04 -0700 Subject: [PATCH 1/6] cross attention supported by trtllm-gen backend Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com> --- .../unfusedAttentionKernels_2_template.h | 2 +- cpp/tensorrt_llm/nanobind/thop/bindings.cpp | 13 +- cpp/tensorrt_llm/thop/trtllmGenFusedOps.h | 5 +- .../thop/trtllmGenQKVProcessOp.cpp | 53 ++-- .../_torch/attention_backend/trtllm.py | 3 - .../_torch/attention_backend/trtllm_gen.py | 47 +++- .../_torch/modules/cross_attention.py | 9 +- .../defs/llmapi/test_llm_api_pytorch_bart.py | 87 ++++++- .../test_lists/test-db/l0_b200.yml | 1 + .../attention_backend/test_trtllm_gen.py | 232 ++++++++++++++++++ 10 files changed, 403 insertions(+), 49 deletions(-) create mode 100644 tests/unittest/_torch/attention_backend/test_trtllm_gen.py diff --git a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h index 8c24d7e5c220..560a153ffa70 100644 --- a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h +++ b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h @@ -1420,7 +1420,7 @@ __global__ void updateKVCacheForCrossAttention(QKVPreprocessingParams cross_kv, bool cross_attention) { auto result = [&]() { @@ -70,7 +70,7 @@ nb::tuple trtllmGenContextPreprocessBinding(torch::Tensor qkv_input, torch::Tens max_past_kv_length, rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type, bmm1_scale, bmm2_scale, attention_chunk_size, fp8_context_fmha, paged_context_fmha, is_mla_enable, multi_processor_count, - total_num_blocks, kv_factor, need_build_kv_cache_metadata); + total_num_blocks, kv_factor, need_build_kv_cache_metadata, cross_kv, cross_attention); }(); return nb::make_tuple(std::get<0>(result), optionalToObject(std::get<1>(result)), @@ -92,7 +92,7 @@ nb::tuple trtllmGenGenerationPreprocessBinding(torch::Tensor qkv_input, torch::T int64_t rotary_embedding_scale_type, double rotary_embedding_scale, int64_t rotary_embedding_max_positions, int64_t position_embedding_type, double bmm1_scale, double bmm2_scale, bool fp8_context_fmha, int64_t predicted_tokens_per_seq, int64_t attention_chunk_size, int64_t multi_processor_count, - int64_t total_num_blocks, int64_t kv_factor, bool need_build_kv_cache_metadata) + int64_t total_num_blocks, int64_t kv_factor, bool need_build_kv_cache_metadata, bool cross_attention) { auto result = [&]() { @@ -106,7 +106,7 @@ nb::tuple trtllmGenGenerationPreprocessBinding(torch::Tensor qkv_input, torch::T rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type, bmm1_scale, bmm2_scale, fp8_context_fmha, predicted_tokens_per_seq, attention_chunk_size, multi_processor_count, total_num_blocks, kv_factor, - need_build_kv_cache_metadata); + need_build_kv_cache_metadata, cross_attention); }(); return nb::make_tuple(std::get<0>(result), optionalToObject(std::get<1>(result)), @@ -273,7 +273,8 @@ void initBindings(nb::module_& m) nb::arg("position_embedding_type"), nb::arg("bmm1_scale"), nb::arg("bmm2_scale"), nb::arg("attention_chunk_size"), nb::arg("fp8_context_fmha"), nb::arg("paged_context_fmha"), nb::arg("is_mla_enable"), nb::arg("multi_processor_count"), nb::arg("total_num_blocks"), nb::arg("kv_factor"), - nb::arg("need_build_kv_cache_metadata") = true, "Fused nanobind context preprocess for trtllm-gen attention."); + nb::arg("need_build_kv_cache_metadata") = true, nb::arg("cross_kv").none() = nb::none(), + nb::arg("cross_attention") = false, "Fused nanobind context preprocess for trtllm-gen attention."); m.def("trtllm_gen_context_postprocess", &torch_ext::trtllmGenContextPostprocess, nb::arg("qkv_input"), nb::arg("workspace"), nb::arg("sequence_lengths"), nb::arg("context_lengths"), @@ -332,6 +333,6 @@ void initBindings(nb::module_& m) nb::arg("position_embedding_type"), nb::arg("bmm1_scale"), nb::arg("bmm2_scale"), nb::arg("fp8_context_fmha"), nb::arg("predicted_tokens_per_seq"), nb::arg("attention_chunk_size"), nb::arg("multi_processor_count"), nb::arg("total_num_blocks"), nb::arg("kv_factor"), nb::arg("need_build_kv_cache_metadata") = true, - "Fused nanobind generation preprocess for trtllm-gen attention."); + nb::arg("cross_attention") = false, "Fused nanobind generation preprocess for trtllm-gen attention."); } } // namespace tensorrt_llm::nanobind::thop diff --git a/cpp/tensorrt_llm/thop/trtllmGenFusedOps.h b/cpp/tensorrt_llm/thop/trtllmGenFusedOps.h index cb4db3535756..2b3dead3fb57 100644 --- a/cpp/tensorrt_llm/thop/trtllmGenFusedOps.h +++ b/cpp/tensorrt_llm/thop/trtllmGenFusedOps.h @@ -42,7 +42,8 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor double rotary_embedding_scale, int64_t rotary_embedding_max_positions, int64_t position_embedding_type, double bmm1_scale, double bmm2_scale, int64_t attention_chunk_size, bool fp8_context_fmha, bool paged_context_fmha, bool is_mla_enable, int64_t multi_processor_count, int64_t total_num_blocks, int64_t kv_factor, - bool need_build_kv_cache_metadata); + bool need_build_kv_cache_metadata, std::optional cross_kv = std::nullopt, + bool cross_attention = false); void trtllmGenContextPostprocess(torch::Tensor qkv_input, torch::Tensor workspace, torch::Tensor sequence_lengths, torch::Tensor context_lengths, std::optional kv_cache_block_offsets, @@ -72,7 +73,7 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, int64_t rotary_embedding_scale_type, double rotary_embedding_scale, int64_t rotary_embedding_max_positions, int64_t position_embedding_type, double bmm1_scale, double bmm2_scale, bool fp8_context_fmha, int64_t predicted_tokens_per_seq, int64_t attention_chunk_size, int64_t multi_processor_count, - int64_t total_num_blocks, int64_t kv_factor, bool need_build_kv_cache_metadata); + int64_t total_num_blocks, int64_t kv_factor, bool need_build_kv_cache_metadata, bool cross_attention = false); } // namespace torch_ext diff --git a/cpp/tensorrt_llm/thop/trtllmGenQKVProcessOp.cpp b/cpp/tensorrt_llm/thop/trtllmGenQKVProcessOp.cpp index 9302e2a4978f..64931f687680 100644 --- a/cpp/tensorrt_llm/thop/trtllmGenQKVProcessOp.cpp +++ b/cpp/tensorrt_llm/thop/trtllmGenQKVProcessOp.cpp @@ -266,17 +266,22 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor int64_t const position_embedding_type, double const bmm1_scale, double const bmm2_scale, int64_t const attention_chunk_size, bool const fp8_context_fmha, bool const paged_context_fmha, bool const is_mla_enable, int64_t const multi_processor_count, int64_t const total_num_blocks, - int64_t const kv_factor, bool const need_build_kv_cache_metadata) + int64_t const kv_factor, bool const need_build_kv_cache_metadata, std::optional cross_kv, + bool const cross_attention) { (void) bmm2_scale; TORCH_CHECK(host_kv_cache_pool_pointers.has_value(), "host_kv_cache_pool_pointers is required."); TORCH_CHECK(host_kv_cache_pool_mapping.has_value(), "host_kv_cache_pool_mapping is required."); TORCH_CHECK(kv_cache_block_offsets.has_value(), "kv_cache_block_offsets is required."); + TORCH_CHECK(!cross_attention || !is_mla_enable, "trtllm-gen cross attention does not support MLA."); - bool const separateQKvOutput = paged_context_fmha || fp8_context_fmha; + bool const separateQKvOutput = paged_context_fmha || fp8_context_fmha || cross_attention; auto const qkvScalarType = qkv_input.scalar_type(); auto const qkvElementSize = static_cast(qkv_input.element_size()); auto const quantMode = tensorrt_llm::common::QuantMode(static_cast(kv_cache_quant_mode)); + int64_t const effectiveMaxAttentionWindowSize = cross_attention ? max_past_kv_length : max_attention_window_size; + int64_t const effectiveCyclicAttentionWindowSize + = cross_attention ? max_past_kv_length : cyclic_attention_window_size; auto const views = [&] { auto const layout = TrtllmAttentionWorkspaceManager::buildContextLayout( @@ -307,8 +312,8 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor decoderInfoParams.fmhaBmm2Scale = ptrs.fmhaBmm2ScalePtr; decoderInfoParams.batchSize = static_cast(batch_size); decoderInfoParams.maxQSeqLength = static_cast(input_seq_length); - decoderInfoParams.maxEncoderQSeqLength = 0; - decoderInfoParams.attentionWindowSize = static_cast(cyclic_attention_window_size); + decoderInfoParams.maxEncoderQSeqLength = cross_attention ? static_cast(max_past_kv_length) : 0; + decoderInfoParams.attentionWindowSize = static_cast(effectiveCyclicAttentionWindowSize); decoderInfoParams.numTokens = static_cast(num_tokens); decoderInfoParams.removePadding = true; decoderInfoParams.attentionMaskType = static_cast(mask_type); @@ -333,12 +338,13 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor { return buildPagedKvCacheBuffers(kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, quantMode, layer_idx, batch_size, tokens_per_block, num_kv_heads, head_size, - cyclic_attention_window_size, max_attention_window_size, 0, 0, is_mla_enable, qkvElementSize); + effectiveCyclicAttentionWindowSize, effectiveMaxAttentionWindowSize, 0, 0, is_mla_enable, + qkvElementSize); }(); QKVPreprocessingParams qkvParams{}; qkvParams.qkv_input = qkv_input.data_ptr(); - qkvParams.cross_kv_input = nullptr; + qkvParams.cross_kv_input = optPtr(cross_kv); qkvParams.quantized_qkv_output = nullptr; qkvParams.q_output = ptrs.qBufPtr; qkvParams.kv_cache_buffer = kvArrays.kvCacheBuffer; @@ -353,8 +359,9 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor qkvParams.logn_scaling = nullptr; qkvParams.tokens_info = ptrs.tokensInfoPtr; qkvParams.seq_lens = static_cast(context_lengths.data_ptr()); - qkvParams.cache_seq_lens = static_cast(sequence_lengths.data_ptr()); - qkvParams.encoder_seq_lens = nullptr; + qkvParams.cache_seq_lens = cross_attention ? static_cast(context_lengths.data_ptr()) + : static_cast(sequence_lengths.data_ptr()); + qkvParams.encoder_seq_lens = cross_attention ? static_cast(sequence_lengths.data_ptr()) : nullptr; qkvParams.cu_seq_lens = ptrs.cuQSeqlensPtr; qkvParams.cu_kv_seq_lens = ptrs.cuKvSeqlensPtr; qkvParams.sparse_kv_offsets = nullptr; @@ -367,11 +374,11 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor qkvParams.batch_size = static_cast(batch_size); qkvParams.max_input_seq_len = static_cast(input_seq_length); qkvParams.max_kv_seq_len = static_cast(max_past_kv_length); - qkvParams.cyclic_kv_cache_len = static_cast(cyclic_attention_window_size); + qkvParams.cyclic_kv_cache_len = static_cast(effectiveCyclicAttentionWindowSize); qkvParams.token_num = static_cast(num_tokens); qkvParams.remove_padding = true; qkvParams.is_last_chunk = attention_chunk_size == 0 || input_seq_length == max_past_kv_length; - qkvParams.cross_attention = false; + qkvParams.cross_attention = cross_attention; qkvParams.head_num = static_cast(num_heads); qkvParams.kv_head_num = static_cast(num_kv_heads); qkvParams.qheads_per_kv_head = static_cast(num_heads / num_kv_heads); @@ -438,7 +445,9 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor // FlashInfer paged context launches trtllm-gen with multi-CTA-KV mode disabled, so it does not // consume the counter slab reserved at the head of the workspace. - auto const windowLeft = computeWindowLeft(cyclic_attention_window_size, max_past_kv_length, attention_chunk_size); + auto const windowLeft = cross_attention + ? int64_t{-1} + : computeWindowLeft(cyclic_attention_window_size, max_past_kv_length, attention_chunk_size); return {qProcessed, kvPool, blockTables, kvScalePool, views.fmhaBmm1Scale, views.fmhaBmm2Scale, views.trtllmGenWorkspace, views.cuQSeqlens, views.cuKvSeqlens, input_seq_length, max_past_kv_length, windowLeft}; @@ -577,7 +586,7 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, int64_t const rotary_embedding_max_positions, int64_t const position_embedding_type, double const bmm1_scale, double const bmm2_scale, bool const fp8_context_fmha, int64_t const predicted_tokens_per_seq, int64_t const attention_chunk_size, int64_t const multi_processor_count, int64_t const total_num_blocks, - int64_t const kv_factor, bool const need_build_kv_cache_metadata) + int64_t const kv_factor, bool const need_build_kv_cache_metadata, bool const cross_attention) { TORCH_CHECK(host_kv_cache_pool_pointers.has_value(), "host_kv_cache_pool_pointers is required."); TORCH_CHECK(host_kv_cache_pool_mapping.has_value(), "host_kv_cache_pool_mapping is required."); @@ -585,9 +594,14 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, (void) bmm2_scale; bool const isMultiTokenGen = spec_decoding_generation_lengths.has_value() && predicted_tokens_per_seq > 1; + TORCH_CHECK( + !cross_attention || !isMultiTokenGen, "trtllm-gen cross attention does not support multi-token generation."); auto const qkvScalarType = qkv_input.scalar_type(); auto const qkvElementSize = static_cast(qkv_input.element_size()); auto const quantMode = tensorrt_llm::common::QuantMode(static_cast(kv_cache_quant_mode)); + int64_t const effectiveMaxAttentionWindowSize = cross_attention ? max_past_kv_length : max_attention_window_size; + int64_t const effectiveCyclicAttentionWindowSize + = cross_attention ? max_past_kv_length : cyclic_attention_window_size; auto const views = [&] { auto const layout = TrtllmAttentionWorkspaceManager::buildGenerationLayout( @@ -617,7 +631,7 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, decoderInfoParams.fmhaBmm2Scale = nullptr; decoderInfoParams.batchSize = static_cast(batch_beam); decoderInfoParams.maxQSeqLength = static_cast(input_seq_length); - decoderInfoParams.maxEncoderQSeqLength = 0; + decoderInfoParams.maxEncoderQSeqLength = cross_attention ? static_cast(max_past_kv_length) : 0; decoderInfoParams.attentionWindowSize = 0; decoderInfoParams.sinkTokenLength = 0; decoderInfoParams.numTokens = static_cast(num_tokens); @@ -655,7 +669,8 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, { return buildPagedKvCacheBuffers(kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, quantMode, layer_idx, batch_beam, tokens_per_block, num_kv_heads, head_size, - cyclic_attention_window_size, max_attention_window_size, 1, seq_offset, false, qkvElementSize); + effectiveCyclicAttentionWindowSize, effectiveMaxAttentionWindowSize, 1, seq_offset, false, + qkvElementSize); }(); QKVPreprocessingParams qkvParams{}; @@ -676,7 +691,7 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, qkvParams.tokens_info = isMultiTokenGen ? views.tokensInfoPtr : nullptr; qkvParams.seq_lens = isMultiTokenGen ? optPtr(spec_decoding_generation_lengths) : nullptr; qkvParams.cache_seq_lens = static_cast(sequence_lengths.data_ptr()); - qkvParams.encoder_seq_lens = nullptr; + qkvParams.encoder_seq_lens = cross_attention ? static_cast(sequence_lengths.data_ptr()) : nullptr; qkvParams.cu_seq_lens = buildDecoderInfoNeeded ? views.cuSeqlensPtr : nullptr; qkvParams.cu_kv_seq_lens = buildDecoderInfoNeeded ? views.cuKvSeqlensPtr : nullptr; qkvParams.sparse_kv_offsets = nullptr; @@ -690,11 +705,11 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, qkvParams.batch_size = static_cast(batch_beam); qkvParams.max_input_seq_len = static_cast(input_seq_length); qkvParams.max_kv_seq_len = static_cast(max_past_kv_length); - qkvParams.cyclic_kv_cache_len = static_cast(cyclic_attention_window_size); + qkvParams.cyclic_kv_cache_len = static_cast(effectiveCyclicAttentionWindowSize); qkvParams.token_num = static_cast(num_tokens); qkvParams.remove_padding = true; qkvParams.is_last_chunk = false; - qkvParams.cross_attention = false; + qkvParams.cross_attention = cross_attention; qkvParams.head_num = static_cast(num_heads); qkvParams.kv_head_num = static_cast(num_kv_heads); qkvParams.qheads_per_kv_head = static_cast(num_heads / num_kv_heads); @@ -751,7 +766,9 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, auto qProcessed = views.qBuf.view({num_tokens, num_heads, head_size}); - auto const windowLeft = computeWindowLeft(cyclic_attention_window_size, max_past_kv_length, attention_chunk_size); + auto const windowLeft = cross_attention + ? int64_t{-1} + : computeWindowLeft(cyclic_attention_window_size, max_past_kv_length, attention_chunk_size); return {qProcessed, kvPool, blockTables, kvScalePool, views.bmm1Scale, views.bmm2Scale, views.trtllmGenWorkspace, cuSeqlens, input_seq_length, max_past_kv_length, windowLeft, isMultiTokenGen}; } diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 0f009fc8f64a..4c46fcfbcb7c 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -1750,9 +1750,6 @@ def forward( metadata, TrtllmAttentionMetadata, ) - # Cross-attention uses the THOP path; the trtllm-gen backend API does - # not carry encoder K/V tensors yet. - if forward_args.multi_item_part_lens is not None: raise ValueError( "TRT-LLM Attention does not support multi-item scoring") diff --git a/tensorrt_llm/_torch/attention_backend/trtllm_gen.py b/tensorrt_llm/_torch/attention_backend/trtllm_gen.py index 192dd71d1617..f8596633b514 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm_gen.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm_gen.py @@ -182,6 +182,7 @@ def _trtllm_gen_batch_context_with_kv_cache( enable_pdl: bool, kv_scale_pool: Optional[torch.Tensor], uses_shared_paged_kv_idx: bool, + causal: bool, ) -> None: bmm1_scale_arg = ( _get_bmm1_scale_log2(bmm1_scale) if isinstance(bmm1_scale, torch.Tensor) else bmm1_scale @@ -217,7 +218,7 @@ def _trtllm_gen_batch_context_with_kv_cache( kv_scale_pool, # value_block_scales None, # skip_softmax_threshold_scale_factor uses_shared_paged_kv_idx, - True, # causal + causal, # causal None, # lse 0, # lse_stride_tokens 0, # lse_stride_heads @@ -377,6 +378,7 @@ class FmhaParams: num_requests: int = 0 spec_decoding_generation_lengths: Optional[torch.Tensor] = None spec_decoding_position_offsets: Optional[torch.Tensor] = None + cross_attention: bool = False class FlashInferTrtllmGenAttention: @@ -597,8 +599,13 @@ def is_supported( attn.skip_softmax_threshold_scale_factor_prefill is not None or attn.skip_softmax_threshold_scale_factor_decode is not None ) - if meta.is_cross: - return False, "trtllm-gen does not support cross attention." + if meta.is_cross and is_mla_enable: + return False, "Cross attention with MLA is not supported by trtllm-gen backend." + if meta.is_cross and (meta.is_spec_decoding_enabled or meta.use_spec_decoding): + return ( + False, + "Cross attention with speculative decoding is not supported by trtllm-gen backend.", + ) if ( fwd.sage_attn_num_elts_per_blk_q > 0 or fwd.sage_attn_num_elts_per_blk_k > 0 @@ -617,6 +624,8 @@ def is_supported( return False, "trtllm-gen does not support sparse attention." if has_skip_softmax: return False, "trtllm-gen does not support skip-softmax attention." + if fwd.relative_attention_bias is not None: + return False, "Relative attention bias is not supported by trtllm-gen backend." if meta.use_spec_decoding and meta.is_spec_dec_tree: return ( False, @@ -655,6 +664,11 @@ def is_supported( kv_cache_dtype, _ = self._get_kv_cache_dtype_and_total_blocks(meta, is_mla_enable) if kv_cache_dtype is None: kv_cache_dtype = torch_dtype_to_binding(q_dtype) + if meta.is_cross and kv_cache_dtype == DataType.NVFP4: + return ( + False, + "Cross attention with NVFP4 KV cache is not supported by trtllm-gen backend.", + ) is_fp8_out = output.dtype == torch.float8_e4m3fn is_fp4_out = output.dtype == torch.uint8 @@ -701,10 +715,10 @@ def is_supported( ) if has_generation_phase: - if meta.beam_width != 1: + if meta.effective_beam_width != 1: return ( False, - f"[Generation] Beam search (beam_width={meta.beam_width}) " + f"[Generation] Beam search (beam_width={meta.effective_beam_width}) " "is not supported. Must be 1.", ) sink_token_length = 0 @@ -804,7 +818,7 @@ def forward( max_num_requests = meta.max_num_requests max_context_length = meta.max_context_length attention_window_size = fwd.attention_window_size - beam_width = meta.beam_width + beam_width = meta.effective_beam_width num_tokens = q.size(0) attn_input_type = fwd.attention_input_type @@ -912,6 +926,7 @@ def forward( params.seq_offset = seq_offset params.input_seq_length = max_context_q_len params.batch_size = num_seqs + params.cross_attention = meta.is_cross self.run_context(params) if num_generations > 0 and attn_input_type != AttentionInputType.context_only: @@ -945,6 +960,7 @@ def forward( params.num_requests = num_seqs // beam_width params.spec_decoding_generation_lengths = spec_gen_lengths params.spec_decoding_position_offsets = spec_pos_offsets + params.cross_attention = meta.is_cross if is_mla_enable: self.run_mla_generation(params) else: @@ -994,6 +1010,10 @@ def run_context( rope_params = attn.rope_params bmm1_scale_static = self._get_bmm1_scale(attn) attention_chunk_size = self._get_attention_chunk_size(attn) + if params.cross_attention and fwd.update_kv_cache and fwd.cross_kv is None: + raise RuntimeError( + "trtllm-gen cross attention requires cross_kv when update_kv_cache=True." + ) ( q_processed, @@ -1051,6 +1071,8 @@ def run_context( params.total_num_blocks, # total_num_blocks params.kv_factor, # kv_factor True, # need_build_kv_cache_metadata + fwd.cross_kv if params.cross_attention and fwd.update_kv_cache else None, + params.cross_attention, # cross_attention ) has_fp4_kv = QuantMode(attn.quant_mode).has_fp4_kv_cache() @@ -1087,8 +1109,16 @@ def run_context( self._enable_pdl, # enable_pdl kv_scale_pool, # kv_scale_pool self.USE_SHARED_PAGED_KV_IDX, # uses_shared_paged_kv_idx + ( + False + if params.cross_attention + else AttentionMaskType(fwd.mask_type) == AttentionMaskType.causal + ), # causal ) + if params.cross_attention: + return + thop.trtllm_gen_context_postprocess( params.qkv_input, # qkv_input params.workspace, # workspace @@ -1139,7 +1169,7 @@ def run_generation( rope_params = attn.rope_params bmm1_scale_static = self._get_bmm1_scale(attn) attention_chunk_size = self._get_attention_chunk_size(attn) - batch_beam = params.num_requests * meta.beam_width + batch_beam = params.num_requests * meta.effective_beam_width ( q_processed, kv_pool, @@ -1195,6 +1225,7 @@ def run_generation( params.total_num_blocks, # total_num_blocks params.kv_factor, # kv_factor True, # need_build_kv_cache_metadata + params.cross_attention, # cross_attention ) # FIXME: Flashinfer trtllm-gen API doesn't support a separate @@ -1267,7 +1298,7 @@ def run_mla_generation( if self._get_attention_chunk_size(attn) != 0: raise NotImplementedError("Chunked-attention is not supported by MLA decode path.") - batch_beam = params.num_requests * meta.beam_width + batch_beam = params.num_requests * meta.effective_beam_width if params.attention_input is None: raise RuntimeError("MLA generation requires attention_input.") kv_cache, block_tables = thop.build_trtllm_gen_kv_cache_metadata( diff --git a/tensorrt_llm/_torch/modules/cross_attention.py b/tensorrt_llm/_torch/modules/cross_attention.py index ff0f440e245c..a2ecdc44d0a9 100644 --- a/tensorrt_llm/_torch/modules/cross_attention.py +++ b/tensorrt_llm/_torch/modules/cross_attention.py @@ -42,10 +42,11 @@ class CrossAttention(nn.Module): re-projection. The cross-attention sub-layer honors ``ModelConfig.attn_backend``: when - set to ``"TRTLLM"`` it dispatches through the production C++ attention op - on every supported architecture. Cross-attention currently uses the THOP - attention path because the ``trtllm_gen`` backend API does not yet carry - encoder K/V tensors. + set to ``"TRTLLM"`` it dispatches through the production attention backend + on every supported architecture. If the internal ``trtllm_gen`` fast path + is enabled and supports the current shape, cross-attention writes encoder + K/V to the cross-KV pool and uses the trtllm-gen kernels; otherwise it + falls back to the standard THOP attention path. Encoder and decoder self-attention are unaffected and continue to use whatever backend ``ModelConfig.attn_backend`` selects. diff --git a/tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py b/tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py index 3ae0e3652a3d..8b19894e6143 100644 --- a/tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py +++ b/tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py @@ -27,7 +27,7 @@ SchedulerConfig, ) -from ..conftest import llm_models_root +from ..conftest import llm_models_root, skip_pre_blackwell _SOURCE_TEXT = ( "Summarize: NVIDIA builds fast inference software for large language models. " @@ -63,6 +63,7 @@ def _test_case( num_return_sequences: int, exact_match: bool, feature_id: str, + marks=(), ): expected_output_token_ids = [_EXPECTED_GREEDY_OUTPUT_TOKEN_IDS] if num_beams == 1 else None assert not exact_match or expected_output_token_ids is not None @@ -76,6 +77,7 @@ def _test_case( num_return_sequences, exact_match, id=f"{feature_id}-{_MODEL_NAME}", + marks=marks, ) @@ -118,6 +120,19 @@ def _test_case( ), ] +_TRTLLM_GEN_TEST_CASES = [ + _test_case( + torch_dtype="bfloat16", + use_kv_cache_manager_v2=False, + enable_cuda_graph=False, + num_beams=1, + num_return_sequences=1, + exact_match=True, + feature_id="trtllm-gen-bf16-kv-v1-cuda-graph-off-greedy", + marks=skip_pre_blackwell, + ), +] + def _mixed_batch_test_case( torch_dtype: str, @@ -195,6 +210,14 @@ def _cuda_graph_config( return CudaGraphConfig(batch_sizes=batch_sizes or [1]) if enabled else None +def _enable_trtllm_gen_attention(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION", "1") + + from tensorrt_llm._torch.attention_backend import trtllm + + monkeypatch.setattr(trtllm, "_TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION", True) + + def _assert_bart_response( response: RequestOutput, num_return_sequences: int, @@ -245,12 +268,7 @@ def _assert_expected_generation( assert token_ids_by_output == expected_token_ids_by_output -@pytest.mark.parametrize( - "expected_output_token_ids_by_output,torch_dtype,use_kv_cache_manager_v2," - "enable_cuda_graph,num_beams,num_return_sequences,exact_match", - _TEST_CASES, -) -def test_bart_pytorch_generate_encoder_decoder_end_to_end( +def _run_bart_pytorch_generate_encoder_decoder( monkeypatch: pytest.MonkeyPatch, expected_output_token_ids_by_output: list[list[int]] | None, torch_dtype: str, @@ -312,6 +330,61 @@ def test_bart_pytorch_generate_encoder_decoder_end_to_end( ) +@pytest.mark.parametrize( + "expected_output_token_ids_by_output,torch_dtype,use_kv_cache_manager_v2," + "enable_cuda_graph,num_beams,num_return_sequences,exact_match", + _TEST_CASES, +) +def test_bart_pytorch_generate_encoder_decoder_end_to_end( + monkeypatch: pytest.MonkeyPatch, + expected_output_token_ids_by_output: list[list[int]] | None, + torch_dtype: str, + use_kv_cache_manager_v2: bool, + enable_cuda_graph: bool, + num_beams: int, + num_return_sequences: int, + exact_match: bool, +) -> None: + _run_bart_pytorch_generate_encoder_decoder( + monkeypatch, + expected_output_token_ids_by_output, + torch_dtype, + use_kv_cache_manager_v2, + enable_cuda_graph, + num_beams, + num_return_sequences, + exact_match, + ) + + +@pytest.mark.parametrize( + "expected_output_token_ids_by_output,torch_dtype,use_kv_cache_manager_v2," + "enable_cuda_graph,num_beams,num_return_sequences,exact_match", + _TRTLLM_GEN_TEST_CASES, +) +def test_bart_pytorch_generate_encoder_decoder_trtllm_gen_attention( + monkeypatch: pytest.MonkeyPatch, + expected_output_token_ids_by_output: list[list[int]] | None, + torch_dtype: str, + use_kv_cache_manager_v2: bool, + enable_cuda_graph: bool, + num_beams: int, + num_return_sequences: int, + exact_match: bool, +) -> None: + _enable_trtllm_gen_attention(monkeypatch) + _run_bart_pytorch_generate_encoder_decoder( + monkeypatch, + expected_output_token_ids_by_output, + torch_dtype, + use_kv_cache_manager_v2, + enable_cuda_graph, + num_beams, + num_return_sequences, + exact_match, + ) + + @pytest.mark.parametrize( "torch_dtype,use_kv_cache_manager_v2,num_beams,num_return_sequences", _MIXED_BATCH_TEST_CASES, diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 43f513bbe6aa..c3f2a2b855c3 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -87,6 +87,7 @@ l0_b200: - llmapi/test_llm_api_pytorch_bart.py::test_bart_pytorch_generate_encoder_decoder_end_to_end[bf16-kv-v2-cuda-graph-off-greedy-bart-large-cnn] - llmapi/test_llm_api_pytorch_bart.py::test_bart_pytorch_generate_encoder_decoder_mixed_encoder_lengths_batch[bf16-kv-v1-cuda-graph-off-greedy-batch2-bart-large-cnn] - llmapi/test_llm_api_pytorch_bart.py::test_bart_pytorch_generate_encoder_decoder_mixed_encoder_lengths_batch[bf16-kv-v2-cuda-graph-off-greedy-batch2-bart-large-cnn] + - llmapi/test_llm_api_pytorch_bart.py::test_bart_pytorch_generate_encoder_decoder_trtllm_gen_attention - llmapi/test_llm_api_pytorch_t5.py::test_t5_pytorch_generate_encoder_decoder_end_to_end[bf16-kv-v1-cuda-graph-off-beam2-t5-small0] - llmapi/test_llm_api_pytorch_t5.py::test_t5_pytorch_generate_encoder_decoder_end_to_end[bf16-kv-v1-cuda-graph-off-beam2-flan-t5-small] - llmapi/test_llm_api_pytorch_t5.py::test_t5_pytorch_generate_encoder_decoder_end_to_end[bf16-kv-v1-cuda-graph-off-beam2-t5-base] diff --git a/tests/unittest/_torch/attention_backend/test_trtllm_gen.py b/tests/unittest/_torch/attention_backend/test_trtllm_gen.py new file mode 100644 index 000000000000..bb7ebd0fd0f5 --- /dev/null +++ b/tests/unittest/_torch/attention_backend/test_trtllm_gen.py @@ -0,0 +1,232 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from types import SimpleNamespace + +import torch + +import tensorrt_llm._torch.attention_backend.trtllm_gen as trtllm_gen +from tensorrt_llm._torch.attention_backend.interface import AttentionForwardArgs, AttentionInputType +from tensorrt_llm.bindings import DataType + + +def _patch_blackwell(monkeypatch): + monkeypatch.setattr(trtllm_gen, "IS_FLASHINFER_AVAILABLE", True) + monkeypatch.setattr(trtllm_gen, "get_sm_version", lambda: 100) + monkeypatch.setattr(trtllm_gen, "is_sm_100f", lambda sm: sm in (100, 103)) + + +def _make_attn(**overrides): + fields = dict( + is_mla_enable=False, + skip_softmax_threshold_scale_factor_prefill=None, + skip_softmax_threshold_scale_factor_decode=None, + position_embedding_type=0, + num_heads=8, + num_kv_heads=8, + head_dim=64, + kv_lora_rank=None, + qk_rope_head_dim=None, + attention_chunk_size=None, + ) + fields.update(overrides) + return SimpleNamespace(**fields) + + +def _make_meta(**overrides): + fields = dict( + is_cross=True, + is_spec_decoding_enabled=False, + use_spec_decoding=False, + is_spec_dec_tree=False, + kv_cache_block_offsets=torch.empty((1, 1, 1), dtype=torch.int32), + kv_cache_manager=SimpleNamespace(dtype=DataType.BF16), + tokens_per_block=64, + beam_width=1, + effective_beam_width=1, + helix_position_offsets=None, + num_sparse_topk=0, + ) + fields.update(overrides) + return SimpleNamespace(**fields) + + +def _make_forward_args(**overrides): + fields = dict( + output=torch.empty((2, 8, 64), dtype=torch.bfloat16), + attention_input_type=AttentionInputType.context_only, + ) + fields.update(overrides) + return AttentionForwardArgs(**fields) + + +def _check_support(monkeypatch, *, attn=None, meta=None, fwd=None, q_dtype=torch.bfloat16): + _patch_blackwell(monkeypatch) + backend = object.__new__(trtllm_gen.FlashInferTrtllmGenAttention) + q = torch.empty((2, 8 * 64), dtype=q_dtype) + return backend.is_supported( + q, + None, + None, + attn=attn or _make_attn(), + meta=meta or _make_meta(), + fwd=fwd or _make_forward_args(), + ) + + +def test_is_supported_allows_cross_attention_on_blackwell(monkeypatch): + supported, reason = _check_support(monkeypatch) + + assert supported, reason + + +def test_is_supported_rejects_cross_attention_mla(monkeypatch): + supported, reason = _check_support( + monkeypatch, + attn=_make_attn(is_mla_enable=True), + fwd=_make_forward_args(attention_input_type=AttentionInputType.generation_only), + ) + + assert not supported + assert "Cross attention with MLA" in reason + + +def test_is_supported_rejects_cross_attention_nvfp4(monkeypatch): + supported, reason = _check_support( + monkeypatch, + meta=_make_meta(kv_cache_manager=SimpleNamespace(dtype=DataType.NVFP4)), + ) + + assert not supported + assert "Cross attention with NVFP4" in reason + + +def test_is_supported_rejects_cross_attention_spec_decoding(monkeypatch): + supported, reason = _check_support( + monkeypatch, + meta=_make_meta(is_spec_decoding_enabled=True, use_spec_decoding=True), + ) + + assert not supported + assert "Cross attention with speculative decoding" in reason + + +def test_is_supported_rejects_relative_attention_bias(monkeypatch): + supported, reason = _check_support( + monkeypatch, + fwd=_make_forward_args(relative_attention_bias=torch.empty((1, 1, 1, 1))), + ) + + assert not supported + assert "Relative attention bias" in reason + + +def test_is_supported_uses_effective_beam_width_for_cross_attention(monkeypatch): + supported, reason = _check_support( + monkeypatch, + meta=_make_meta(beam_width=4, effective_beam_width=1), + fwd=_make_forward_args(attention_input_type=AttentionInputType.generation_only), + ) + + assert supported, reason + + +def test_cross_attention_context_uses_fused_preprocess(monkeypatch): + preprocess_calls = [] + context_calls = [] + postprocess_called = False + + def fake_context_preprocess(*args): + preprocess_calls.append(args) + return ( + torch.empty((2, 2, 4), dtype=torch.bfloat16), + torch.empty((1,), dtype=torch.bfloat16), + torch.empty((1, 1), dtype=torch.int32), + None, + None, + None, + torch.empty((1,), dtype=torch.uint8), + torch.tensor([0, 2], dtype=torch.int32), + torch.tensor([0, 3], dtype=torch.int32), + 2, + 3, + -1, + ) + + def fake_context_attention(*args): + context_calls.append(args) + + def fake_context_postprocess(*args): + nonlocal postprocess_called + postprocess_called = True + + monkeypatch.setattr(trtllm_gen.thop, "trtllm_gen_context_preprocess", fake_context_preprocess) + monkeypatch.setattr(trtllm_gen.thop, "trtllm_gen_context_postprocess", fake_context_postprocess) + monkeypatch.setattr( + trtllm_gen, "_trtllm_gen_batch_context_with_kv_cache", fake_context_attention + ) + + backend = object.__new__(trtllm_gen.FlashInferTrtllmGenAttention) + backend._multi_processor_count = 1 + backend._enable_pdl = False + + cross_kv = torch.empty((3, 16), dtype=torch.bfloat16) + attn = SimpleNamespace( + rope_params=SimpleNamespace(dim=0, theta=10000.0, scale_type=0, scale=1.0, max_positions=0), + rotary_inv_freq=None, + rotary_cos_sin=None, + local_layer_idx=0, + num_heads=2, + num_kv_heads=2, + head_dim=4, + quant_mode=0, + q_scaling=1.0, + position_embedding_type=0, + is_mla_enable=False, + attention_chunk_size=None, + ) + meta = SimpleNamespace( + kv_cache_block_offsets=torch.empty((1, 1, 1), dtype=torch.int32), + host_kv_cache_pool_pointers=torch.empty((1,), dtype=torch.int64), + host_kv_cache_pool_mapping=torch.empty((1,), dtype=torch.int32), + use_paged_context_fmha=False, + ) + fwd = AttentionForwardArgs(cross_kv=cross_kv, update_kv_cache=True) + params = trtllm_gen.FmhaParams( + attn=attn, + meta=meta, + fwd=fwd, + workspace=torch.empty((1,), dtype=torch.uint8), + qkv_input=torch.empty((2, 24), dtype=torch.bfloat16), + context_buf=torch.empty((2, 2, 4), dtype=torch.bfloat16), + sequence_lengths=torch.tensor([3], dtype=torch.int32), + context_lengths=torch.tensor([2], dtype=torch.int32), + input_seq_length=2, + max_past_kv_length=3, + max_attention_window_size=3, + cyclic_attention_window_size=3, + num_tokens=2, + tokens_per_block=64, + kv_factor=2, + total_num_blocks=1, + batch_size=1, + cross_attention=True, + ) + + backend.run_context(params) + + assert preprocess_calls[0][-2] is cross_kv + assert preprocess_calls[0][-1] is True + assert context_calls[0][-1] is False + assert not postprocess_called From 63a2defc9705d36940b9e0c86ebaa765ab22351b Mon Sep 17 00:00:00 2001 From: Guiju Zhang <7135567+cascade812@users.noreply.github.com> Date: Wed, 10 Jun 2026 15:57:51 -0700 Subject: [PATCH 2/6] update Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com> --- tensorrt_llm/_torch/models/modeling_t5.py | 128 ++++++------------ .../defs/llmapi/test_llm_api_pytorch_bart.py | 36 ++++- .../defs/llmapi/test_llm_api_pytorch_t5.py | 105 ++++++++++++-- 3 files changed, 169 insertions(+), 100 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_t5.py b/tensorrt_llm/_torch/models/modeling_t5.py index 5dc5a5c0e389..4cee8fb31102 100644 --- a/tensorrt_llm/_torch/models/modeling_t5.py +++ b/tensorrt_llm/_torch/models/modeling_t5.py @@ -224,17 +224,17 @@ class T5Attention(Attention): When ``position_bias`` is provided (from a ``T5RelativePositionBias`` module living on layer 0), it is added to the QK^T scores before - softmax. Without a KV cache the module computes SDPA directly - (bypassing the VANILLA backend's ``flash_attn_varlen_func`` which - cannot accept an additive bias). With a KV cache it passes the learned - relative-attention table to the TRTLLM backend. + softmax. T5 self-attention requires the TRTLLM backend so the relative + bias can be routed through the attention backend. Without a KV cache, this + module passes a precomputed dense relative bias as explicit attention bias. + With a KV cache, decoder attention passes the learned relative-attention + table to the TRTLLM backend. """ def __init__( self, model_config: ModelConfig[T5Config], layer_idx: Optional[int] = None, - is_decoder: bool = True, ): config = model_config.pretrained_config num_heads = config.num_heads @@ -254,37 +254,11 @@ def __init__( q_scaling=_t5_q_scaling(config), head_dim=_t5_head_dim(config), ) - self._is_decoder = is_decoder - self._head_dim = _t5_head_dim(config) def apply_rope(self, q, k, v, position_ids): """T5 has no RoPE — pass through unchanged.""" return q, k, v - def _split_qkv( - self, - hidden_states: torch.Tensor, - num_tokens: int, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - qkv = self.qkv_proj(hidden_states) - q_size = self.num_heads * self._head_dim - kv_size = self.num_key_value_heads * self._head_dim - q, k, v = qkv[:num_tokens].split([q_size, kv_size, kv_size], dim=-1) - - q = q.view(-1, self.num_heads, self._head_dim) - k = k.view(-1, self.num_key_value_heads, self._head_dim) - v = v.view(-1, self.num_key_value_heads, self._head_dim) - return q, k, v - - @staticmethod - def _slice_position_bias( - position_bias: torch.Tensor, - query_length: int, - key_length: int, - ) -> torch.Tensor: - query_start = key_length - query_length - return position_bias[:, :, query_start:key_length, :key_length].squeeze(0) - def _local_position_bias( self, position_bias: torch.Tensor, @@ -336,70 +310,43 @@ def forward( relative_attention_max_distance: int = 0, **kwargs, ) -> torch.Tensor: - if position_bias is None and relative_attention_bias is None: - return super().forward( - position_ids=position_ids, - hidden_states=hidden_states, - attn_metadata=attn_metadata, - attention_mask=attention_mask, - **kwargs, + if self.attn_backend != "TRTLLM": + raise ValueError( + "T5 self-attention with relative position bias requires " + f"attn_backend='TRTLLM'. Current backend: {self.attn_backend}." ) - assert attn_metadata is not None - assert hidden_states is not None - if attn_metadata.kv_cache_manager is not None: + forward_kwargs = dict(kwargs) + if attn_metadata is not None and attn_metadata.kv_cache_manager is not None: if relative_attention_bias is None: raise ValueError("Cached T5 attention requires a relative attention bias table.") + assert hidden_states is not None relative_attention_bias = self._local_relative_attention_bias( relative_attention_bias, hidden_states, ) - return super().forward( - position_ids=position_ids, - hidden_states=hidden_states, - attn_metadata=attn_metadata, - attention_mask=attention_mask, - relative_attention_bias=relative_attention_bias, - relative_attention_max_distance=relative_attention_max_distance, - **kwargs, + elif position_bias is not None: + assert hidden_states is not None + position_bias = self._local_position_bias(position_bias, hidden_states) + relative_attention_bias = position_bias.squeeze(0).contiguous() + relative_attention_max_distance = 0 + forward_kwargs["attention_window_size"] = relative_attention_bias.shape[-1] + elif relative_attention_bias is not None: + assert hidden_states is not None + relative_attention_bias = self._local_relative_attention_bias( + relative_attention_bias, + hidden_states, ) - # Manual SDPA with additive position bias (no-KV-cache path). - assert position_bias is not None - position_bias = self._local_position_bias(position_bias, hidden_states) - num_tokens = attn_metadata.num_tokens - q, k, v = self._split_qkv(hidden_states, num_tokens) - - # Per-request SDPA with position bias applied to each request's scores. - seq_lens = attn_metadata.seq_lens - offset = 0 - outputs = [] - for seq_len in seq_lens: - sl = int(seq_len) - q_s = q[offset : offset + sl].transpose(0, 1) # (H, S, D) - k_s = k[offset : offset + sl].transpose(0, 1) - v_s = v[offset : offset + sl].transpose(0, 1) - - scores = torch.matmul(q_s, k_s.transpose(-2, -1)) - # position_bias: (1, H, qlen, klen) — slice to this request's lengths - scores = scores + self._slice_position_bias(position_bias, sl, sl) - - if self._is_decoder: - causal_mask = torch.triu( - torch.full((sl, sl), float("-inf"), device=scores.device, dtype=scores.dtype), - diagonal=1, - ) - scores = scores + causal_mask - - attn_weights = F.softmax(scores.float(), dim=-1).to(q.dtype) - out = torch.matmul(attn_weights, v_s) # (H, S, D) - outputs.append(out.transpose(0, 1)) # (S, H, D) - offset += sl - - attn_output = torch.cat(outputs, dim=0) # (T, H, D) - attn_output = attn_output.reshape(num_tokens, -1) - attn_output = self.o_proj(attn_output) - return attn_output + return super().forward( + position_ids=position_ids, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + attention_mask=attention_mask, + relative_attention_bias=relative_attention_bias, + relative_attention_max_distance=relative_attention_max_distance, + **forward_kwargs, + ) class T5CrossAttention(CrossAttention): @@ -450,7 +397,7 @@ def __init__( act_fn = _t5_gated_act_fn(config) if is_gated else _t5_dense_act_fn(config) - self.self_attn = T5Attention(model_config, layer_idx=layer_idx, is_decoder=False) + self.self_attn = T5Attention(model_config, layer_idx=layer_idx) self.input_layernorm = RMSNorm( hidden_size=hidden_size, @@ -536,7 +483,7 @@ def __init__( act_fn = _t5_gated_act_fn(config) if is_gated else _t5_dense_act_fn(config) - self.self_attn = T5Attention(model_config, layer_idx=layer_idx, is_decoder=True) + self.self_attn = T5Attention(model_config, layer_idx=layer_idx) self.cross_attn = T5CrossAttention(model_config, layer_idx=layer_idx) @@ -664,7 +611,12 @@ def forward( attn_metadata: AttentionMetadata, position_ids: Optional[torch.IntTensor] = None, ) -> torch.Tensor: - seq_len = hidden_states.shape[0] + max_context_q_len_override = getattr(attn_metadata, "max_context_q_len_override", None) + if max_context_q_len_override is not None: + seq_len = int(max_context_q_len_override) + else: + seq_lens = attn_metadata.seq_lens + seq_len = hidden_states.shape[0] if seq_lens is None else int(seq_lens.max().item()) position_bias = self.relative_position_bias(seq_len, seq_len, hidden_states.device) for layer in self.layers: diff --git a/tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py b/tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py index 8b19894e6143..3bba330e92db 100644 --- a/tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py +++ b/tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py @@ -64,6 +64,7 @@ def _test_case( exact_match: bool, feature_id: str, marks=(), + kv_cache_dtype: str = "auto", ): expected_output_token_ids = [_EXPECTED_GREEDY_OUTPUT_TOKEN_IDS] if num_beams == 1 else None assert not exact_match or expected_output_token_ids is not None @@ -76,6 +77,7 @@ def _test_case( num_beams, num_return_sequences, exact_match, + kv_cache_dtype, id=f"{feature_id}-{_MODEL_NAME}", marks=marks, ) @@ -131,6 +133,27 @@ def _test_case( feature_id="trtllm-gen-bf16-kv-v1-cuda-graph-off-greedy", marks=skip_pre_blackwell, ), + _test_case( + torch_dtype="bfloat16", + use_kv_cache_manager_v2=False, + enable_cuda_graph=False, + num_beams=2, + num_return_sequences=2, + exact_match=False, + feature_id="trtllm-gen-bf16-kv-v1-cuda-graph-off-beam2", + marks=skip_pre_blackwell, + ), + _test_case( + torch_dtype="bfloat16", + kv_cache_dtype="fp8", + use_kv_cache_manager_v2=False, + enable_cuda_graph=False, + num_beams=1, + num_return_sequences=1, + exact_match=False, + feature_id="trtllm-gen-bf16-fp8kv-kv-v1-cuda-graph-off-greedy", + marks=skip_pre_blackwell, + ), ] @@ -277,6 +300,7 @@ def _run_bart_pytorch_generate_encoder_decoder( num_beams: int, num_return_sequences: int, exact_match: bool, + kv_cache_dtype: str = "auto", ) -> None: monkeypatch.setenv("TLLM_WORKER_USE_SINGLE_PROCESS", "1") monkeypatch.setenv("TRTLLM_SKIP_KV_CACHE_ESTIMATION", "1") @@ -285,7 +309,8 @@ def _run_bart_pytorch_generate_encoder_decoder( tokenizer = AutoTokenizer.from_pretrained(model_path) case_id = ( f"model={_MODEL_NAME}, dtype={torch_dtype}, kv_v2={use_kv_cache_manager_v2}, " - f"cuda_graph={enable_cuda_graph}, beams={num_beams}, returns={num_return_sequences}" + f"cuda_graph={enable_cuda_graph}, beams={num_beams}, returns={num_return_sequences}, " + f"kv_dtype={kv_cache_dtype}" ) sampling_params = _sampling_params(num_beams, num_return_sequences) @@ -303,6 +328,7 @@ def _run_bart_pytorch_generate_encoder_decoder( free_gpu_memory_fraction=_FREE_GPU_MEMORY_FRACTION, cross_kv_cache_fraction=_CROSS_KV_CACHE_FRACTION, use_kv_cache_manager_v2=use_kv_cache_manager_v2, + dtype=kv_cache_dtype, ), max_batch_size=1, max_beam_width=num_beams, @@ -332,7 +358,7 @@ def _run_bart_pytorch_generate_encoder_decoder( @pytest.mark.parametrize( "expected_output_token_ids_by_output,torch_dtype,use_kv_cache_manager_v2," - "enable_cuda_graph,num_beams,num_return_sequences,exact_match", + "enable_cuda_graph,num_beams,num_return_sequences,exact_match,kv_cache_dtype", _TEST_CASES, ) def test_bart_pytorch_generate_encoder_decoder_end_to_end( @@ -344,6 +370,7 @@ def test_bart_pytorch_generate_encoder_decoder_end_to_end( num_beams: int, num_return_sequences: int, exact_match: bool, + kv_cache_dtype: str, ) -> None: _run_bart_pytorch_generate_encoder_decoder( monkeypatch, @@ -354,12 +381,13 @@ def test_bart_pytorch_generate_encoder_decoder_end_to_end( num_beams, num_return_sequences, exact_match, + kv_cache_dtype, ) @pytest.mark.parametrize( "expected_output_token_ids_by_output,torch_dtype,use_kv_cache_manager_v2," - "enable_cuda_graph,num_beams,num_return_sequences,exact_match", + "enable_cuda_graph,num_beams,num_return_sequences,exact_match,kv_cache_dtype", _TRTLLM_GEN_TEST_CASES, ) def test_bart_pytorch_generate_encoder_decoder_trtllm_gen_attention( @@ -371,6 +399,7 @@ def test_bart_pytorch_generate_encoder_decoder_trtllm_gen_attention( num_beams: int, num_return_sequences: int, exact_match: bool, + kv_cache_dtype: str, ) -> None: _enable_trtllm_gen_attention(monkeypatch) _run_bart_pytorch_generate_encoder_decoder( @@ -382,6 +411,7 @@ def test_bart_pytorch_generate_encoder_decoder_trtllm_gen_attention( num_beams, num_return_sequences, exact_match, + kv_cache_dtype, ) diff --git a/tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py b/tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py index c42bb0a40057..bea9c4232c09 100644 --- a/tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py +++ b/tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py @@ -27,7 +27,7 @@ SchedulerConfig, ) -from ..conftest import llm_models_root +from ..conftest import llm_models_root, skip_pre_blackwell _SOURCE_TEXT = "translate English to German: The house is wonderful." _MIXED_ENCODER_SOURCE_TEXTS = [ @@ -47,7 +47,7 @@ "t5-base": [644, 4598, 229, 19250], "t5-large": [644, 4598, 229, 19250], "flan-t5-small": [644, 4598, 229, 9685], - "byt5-small": [258, 35, 119, 114], + "byt5-small": [258, 35, 119, 35], } # Known HF references for returned beam hypotheses. The tests exact-match greedy # outputs and the best beam when a reference is available; lower-ranked BF16 @@ -79,8 +79,8 @@ ("t5-small", 2): [ _HF_BEAM_OUTPUT_TOKEN_IDS_BY_MODEL_AND_BEAMS[("t5-small", 2)], [ - [644, 4675, 229, 219], [644, 4675, 4186, 219], + [644, 4675, 229, 219], ], ], ("flan-t5-small", 2): [ @@ -362,6 +362,31 @@ def _test_case( ), ] +_TRTLLM_GEN_TEST_CASES = [ + _test_case( + model_name="t5-small", + torch_dtype="bfloat16", + use_kv_cache_manager_v2=False, + enable_cuda_graph=False, + num_beams=1, + num_return_sequences=1, + exact_match=True, + feature_id="trtllm-gen-bf16-kv-v1-cuda-graph-off-greedy", + marks=skip_pre_blackwell, + ), + _test_case( + model_name="t5-small", + torch_dtype="bfloat16", + use_kv_cache_manager_v2=False, + enable_cuda_graph=False, + num_beams=2, + num_return_sequences=2, + exact_match=False, + feature_id="trtllm-gen-bf16-kv-v1-cuda-graph-off-beam2", + marks=skip_pre_blackwell, + ), +] + def _mixed_batch_test_case( model_name: str, @@ -475,6 +500,14 @@ def _cuda_graph_config( return CudaGraphConfig(batch_sizes=batch_sizes or [1]) if enabled else None +def _enable_trtllm_gen_attention(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION", "1") + + from tensorrt_llm._torch.attention_backend import trtllm + + monkeypatch.setattr(trtllm, "_TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION", True) + + def _assert_t5_response( response: RequestOutput, num_return_sequences: int, @@ -524,12 +557,7 @@ def _assert_expected_generation( assert token_ids_by_output == expected_token_ids_by_output -@pytest.mark.parametrize( - "model_name,expected_output_token_ids_by_output,torch_dtype,use_kv_cache_manager_v2," - "enable_cuda_graph,num_beams,num_return_sequences,exact_match", - _TEST_CASES, -) -def test_t5_pytorch_generate_encoder_decoder_end_to_end( +def _run_t5_pytorch_generate_encoder_decoder( monkeypatch: pytest.MonkeyPatch, model_name: str, expected_output_token_ids_by_output: list[list[int]] | None, @@ -592,6 +620,65 @@ def test_t5_pytorch_generate_encoder_decoder_end_to_end( ) +@pytest.mark.parametrize( + "model_name,expected_output_token_ids_by_output,torch_dtype,use_kv_cache_manager_v2," + "enable_cuda_graph,num_beams,num_return_sequences,exact_match", + _TEST_CASES, +) +def test_t5_pytorch_generate_encoder_decoder_end_to_end( + monkeypatch: pytest.MonkeyPatch, + model_name: str, + expected_output_token_ids_by_output: list[list[int]] | None, + torch_dtype: str, + use_kv_cache_manager_v2: bool, + enable_cuda_graph: bool, + num_beams: int, + num_return_sequences: int, + exact_match: bool, +) -> None: + _run_t5_pytorch_generate_encoder_decoder( + monkeypatch, + model_name, + expected_output_token_ids_by_output, + torch_dtype, + use_kv_cache_manager_v2, + enable_cuda_graph, + num_beams, + num_return_sequences, + exact_match, + ) + + +@pytest.mark.parametrize( + "model_name,expected_output_token_ids_by_output,torch_dtype,use_kv_cache_manager_v2," + "enable_cuda_graph,num_beams,num_return_sequences,exact_match", + _TRTLLM_GEN_TEST_CASES, +) +def test_t5_pytorch_generate_encoder_decoder_trtllm_gen_attention( + monkeypatch: pytest.MonkeyPatch, + model_name: str, + expected_output_token_ids_by_output: list[list[int]] | None, + torch_dtype: str, + use_kv_cache_manager_v2: bool, + enable_cuda_graph: bool, + num_beams: int, + num_return_sequences: int, + exact_match: bool, +) -> None: + _enable_trtllm_gen_attention(monkeypatch) + _run_t5_pytorch_generate_encoder_decoder( + monkeypatch, + model_name, + expected_output_token_ids_by_output, + torch_dtype, + use_kv_cache_manager_v2, + enable_cuda_graph, + num_beams, + num_return_sequences, + exact_match, + ) + + @pytest.mark.parametrize( "model_name,expected_output_token_ids_by_request,torch_dtype,use_kv_cache_manager_v2," "num_beams,num_return_sequences,exact_match", From 381595702481e9884e0ecaf782bdb8df6b10fb41 Mon Sep 17 00:00:00 2001 From: Guiju Zhang <7135567+cascade812@users.noreply.github.com> Date: Tue, 16 Jun 2026 18:41:04 -0700 Subject: [PATCH 3/6] update test Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com> --- .../defs/llmapi/test_llm_api_pytorch_bart.py | 74 +------------------ .../defs/llmapi/test_llm_api_pytorch_t5.py | 65 +--------------- .../test_lists/test-db/l0_b200.yml | 1 - 3 files changed, 2 insertions(+), 138 deletions(-) diff --git a/tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py b/tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py index 3bba330e92db..36e26c3562f7 100644 --- a/tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py +++ b/tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py @@ -27,7 +27,7 @@ SchedulerConfig, ) -from ..conftest import llm_models_root, skip_pre_blackwell +from ..conftest import llm_models_root _SOURCE_TEXT = ( "Summarize: NVIDIA builds fast inference software for large language models. " @@ -122,40 +122,6 @@ def _test_case( ), ] -_TRTLLM_GEN_TEST_CASES = [ - _test_case( - torch_dtype="bfloat16", - use_kv_cache_manager_v2=False, - enable_cuda_graph=False, - num_beams=1, - num_return_sequences=1, - exact_match=True, - feature_id="trtllm-gen-bf16-kv-v1-cuda-graph-off-greedy", - marks=skip_pre_blackwell, - ), - _test_case( - torch_dtype="bfloat16", - use_kv_cache_manager_v2=False, - enable_cuda_graph=False, - num_beams=2, - num_return_sequences=2, - exact_match=False, - feature_id="trtllm-gen-bf16-kv-v1-cuda-graph-off-beam2", - marks=skip_pre_blackwell, - ), - _test_case( - torch_dtype="bfloat16", - kv_cache_dtype="fp8", - use_kv_cache_manager_v2=False, - enable_cuda_graph=False, - num_beams=1, - num_return_sequences=1, - exact_match=False, - feature_id="trtllm-gen-bf16-fp8kv-kv-v1-cuda-graph-off-greedy", - marks=skip_pre_blackwell, - ), -] - def _mixed_batch_test_case( torch_dtype: str, @@ -233,14 +199,6 @@ def _cuda_graph_config( return CudaGraphConfig(batch_sizes=batch_sizes or [1]) if enabled else None -def _enable_trtllm_gen_attention(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setenv("TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION", "1") - - from tensorrt_llm._torch.attention_backend import trtllm - - monkeypatch.setattr(trtllm, "_TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION", True) - - def _assert_bart_response( response: RequestOutput, num_return_sequences: int, @@ -385,36 +343,6 @@ def test_bart_pytorch_generate_encoder_decoder_end_to_end( ) -@pytest.mark.parametrize( - "expected_output_token_ids_by_output,torch_dtype,use_kv_cache_manager_v2," - "enable_cuda_graph,num_beams,num_return_sequences,exact_match,kv_cache_dtype", - _TRTLLM_GEN_TEST_CASES, -) -def test_bart_pytorch_generate_encoder_decoder_trtllm_gen_attention( - monkeypatch: pytest.MonkeyPatch, - expected_output_token_ids_by_output: list[list[int]] | None, - torch_dtype: str, - use_kv_cache_manager_v2: bool, - enable_cuda_graph: bool, - num_beams: int, - num_return_sequences: int, - exact_match: bool, - kv_cache_dtype: str, -) -> None: - _enable_trtllm_gen_attention(monkeypatch) - _run_bart_pytorch_generate_encoder_decoder( - monkeypatch, - expected_output_token_ids_by_output, - torch_dtype, - use_kv_cache_manager_v2, - enable_cuda_graph, - num_beams, - num_return_sequences, - exact_match, - kv_cache_dtype, - ) - - @pytest.mark.parametrize( "torch_dtype,use_kv_cache_manager_v2,num_beams,num_return_sequences", _MIXED_BATCH_TEST_CASES, diff --git a/tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py b/tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py index bea9c4232c09..1d1452ad3fac 100644 --- a/tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py +++ b/tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py @@ -27,7 +27,7 @@ SchedulerConfig, ) -from ..conftest import llm_models_root, skip_pre_blackwell +from ..conftest import llm_models_root _SOURCE_TEXT = "translate English to German: The house is wonderful." _MIXED_ENCODER_SOURCE_TEXTS = [ @@ -362,31 +362,6 @@ def _test_case( ), ] -_TRTLLM_GEN_TEST_CASES = [ - _test_case( - model_name="t5-small", - torch_dtype="bfloat16", - use_kv_cache_manager_v2=False, - enable_cuda_graph=False, - num_beams=1, - num_return_sequences=1, - exact_match=True, - feature_id="trtllm-gen-bf16-kv-v1-cuda-graph-off-greedy", - marks=skip_pre_blackwell, - ), - _test_case( - model_name="t5-small", - torch_dtype="bfloat16", - use_kv_cache_manager_v2=False, - enable_cuda_graph=False, - num_beams=2, - num_return_sequences=2, - exact_match=False, - feature_id="trtllm-gen-bf16-kv-v1-cuda-graph-off-beam2", - marks=skip_pre_blackwell, - ), -] - def _mixed_batch_test_case( model_name: str, @@ -500,14 +475,6 @@ def _cuda_graph_config( return CudaGraphConfig(batch_sizes=batch_sizes or [1]) if enabled else None -def _enable_trtllm_gen_attention(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setenv("TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION", "1") - - from tensorrt_llm._torch.attention_backend import trtllm - - monkeypatch.setattr(trtllm, "_TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION", True) - - def _assert_t5_response( response: RequestOutput, num_return_sequences: int, @@ -649,36 +616,6 @@ def test_t5_pytorch_generate_encoder_decoder_end_to_end( ) -@pytest.mark.parametrize( - "model_name,expected_output_token_ids_by_output,torch_dtype,use_kv_cache_manager_v2," - "enable_cuda_graph,num_beams,num_return_sequences,exact_match", - _TRTLLM_GEN_TEST_CASES, -) -def test_t5_pytorch_generate_encoder_decoder_trtllm_gen_attention( - monkeypatch: pytest.MonkeyPatch, - model_name: str, - expected_output_token_ids_by_output: list[list[int]] | None, - torch_dtype: str, - use_kv_cache_manager_v2: bool, - enable_cuda_graph: bool, - num_beams: int, - num_return_sequences: int, - exact_match: bool, -) -> None: - _enable_trtllm_gen_attention(monkeypatch) - _run_t5_pytorch_generate_encoder_decoder( - monkeypatch, - model_name, - expected_output_token_ids_by_output, - torch_dtype, - use_kv_cache_manager_v2, - enable_cuda_graph, - num_beams, - num_return_sequences, - exact_match, - ) - - @pytest.mark.parametrize( "model_name,expected_output_token_ids_by_request,torch_dtype,use_kv_cache_manager_v2," "num_beams,num_return_sequences,exact_match", diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index c3f2a2b855c3..43f513bbe6aa 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -87,7 +87,6 @@ l0_b200: - llmapi/test_llm_api_pytorch_bart.py::test_bart_pytorch_generate_encoder_decoder_end_to_end[bf16-kv-v2-cuda-graph-off-greedy-bart-large-cnn] - llmapi/test_llm_api_pytorch_bart.py::test_bart_pytorch_generate_encoder_decoder_mixed_encoder_lengths_batch[bf16-kv-v1-cuda-graph-off-greedy-batch2-bart-large-cnn] - llmapi/test_llm_api_pytorch_bart.py::test_bart_pytorch_generate_encoder_decoder_mixed_encoder_lengths_batch[bf16-kv-v2-cuda-graph-off-greedy-batch2-bart-large-cnn] - - llmapi/test_llm_api_pytorch_bart.py::test_bart_pytorch_generate_encoder_decoder_trtllm_gen_attention - llmapi/test_llm_api_pytorch_t5.py::test_t5_pytorch_generate_encoder_decoder_end_to_end[bf16-kv-v1-cuda-graph-off-beam2-t5-small0] - llmapi/test_llm_api_pytorch_t5.py::test_t5_pytorch_generate_encoder_decoder_end_to_end[bf16-kv-v1-cuda-graph-off-beam2-flan-t5-small] - llmapi/test_llm_api_pytorch_t5.py::test_t5_pytorch_generate_encoder_decoder_end_to_end[bf16-kv-v1-cuda-graph-off-beam2-t5-base] From 96335682ed3688d977219b5ed0b85ccc8e0fef12 Mon Sep 17 00:00:00 2001 From: Guiju Zhang <7135567+cascade812@users.noreply.github.com> Date: Thu, 18 Jun 2026 22:56:31 -0700 Subject: [PATCH 4/6] address comments and fix test Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com> --- .../fmha/flashinfer_trtllm_gen.py | 28 ++- .../_torch/attention_backend/fmha/phased.py | 1 + .../_torch/attention_backend/trtllm.py | 8 +- .../_torch/modules/cross_attention.py | 9 - .../defs/llmapi/test_llm_api_pytorch_bart.py | 2 - .../defs/llmapi/test_llm_api_pytorch_t5.py | 11 +- .../attention_backend/test_trtllm_gen.py | 232 ------------------ 7 files changed, 26 insertions(+), 265 deletions(-) delete mode 100644 tests/unittest/_torch/attention_backend/test_trtllm_gen.py diff --git a/tensorrt_llm/_torch/attention_backend/fmha/flashinfer_trtllm_gen.py b/tensorrt_llm/_torch/attention_backend/fmha/flashinfer_trtllm_gen.py index 543e2504cbdd..278d5bb5e13d 100644 --- a/tensorrt_llm/_torch/attention_backend/fmha/flashinfer_trtllm_gen.py +++ b/tensorrt_llm/_torch/attention_backend/fmha/flashinfer_trtllm_gen.py @@ -596,13 +596,6 @@ def _is_supported_with_reason( sparse_params = attn.sparse_params has_skip_softmax = getattr(sparse_params, "algorithm", None) == "skip_softmax" has_sparse_attention = sparse_params is not None and not has_skip_softmax - if meta.is_cross and is_mla_enable: - return False, "Cross attention with MLA is not supported by trtllm-gen backend." - if meta.is_cross and (meta.is_spec_decoding_enabled or meta.use_spec_decoding): - return ( - False, - "Cross attention with speculative decoding is not supported by trtllm-gen backend.", - ) if ( fwd.sage_attn_num_elts_per_blk_q > 0 or fwd.sage_attn_num_elts_per_blk_k > 0 @@ -656,11 +649,20 @@ def _is_supported_with_reason( kv_cache_dtype = self._get_kv_cache_dtype(meta) if kv_cache_dtype is None: kv_cache_dtype = torch_dtype_to_binding(q_dtype) - if meta.is_cross and kv_cache_dtype == DataType.NVFP4: - return ( - False, - "Cross attention with NVFP4 KV cache is not supported by trtllm-gen backend.", - ) + if meta.is_cross: + if kv_cache_dtype == DataType.NVFP4: + return ( + False, + "Cross attention with NVFP4 KV cache is not supported by trtllm-gen backend.", + ) + if is_mla_enable: + return False, "Cross attention with MLA is not supported by trtllm-gen backend." + if meta.is_spec_decoding_enabled or meta.use_spec_decoding: + return ( + False, + "Cross attention with speculative decoding is not supported by " + "trtllm-gen backend.", + ) is_fp8_out = output.dtype == torch.float8_e4m3fn is_fp4_out = output.dtype == torch.uint8 @@ -698,7 +700,7 @@ def _is_supported_with_reason( ) if has_generation_phase: - if meta.effective_beam_width != 1: + if meta.effective_beam_width != 1 and not meta.is_cross: return ( False, f"[Generation] Beam search (beam_width={meta.effective_beam_width}) " diff --git a/tensorrt_llm/_torch/attention_backend/fmha/phased.py b/tensorrt_llm/_torch/attention_backend/fmha/phased.py index 3654690fc6f6..3752403ecc89 100644 --- a/tensorrt_llm/_torch/attention_backend/fmha/phased.py +++ b/tensorrt_llm/_torch/attention_backend/fmha/phased.py @@ -185,6 +185,7 @@ def forward( fp8_context_fmha=fp8_context_fmha, kv_factor=self.kv_factor, total_num_blocks=self._get_total_num_blocks(metadata), + is_cross=metadata.is_cross, ) sequence_length = metadata.kv_lens_cuda_runtime diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 93e1a2bfe4c4..57cee026bb1c 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -189,11 +189,11 @@ def max_context_length(self) -> int: def effective_beam_width(self) -> int: """Beam width visible to the kernel. - Cross-attention metadata is already expanded to one row per decoder - beam, and all beams read the same encoder K/V cache. Keep kernel beam - indirection disabled for that path. + Cross-attention K/V storage is request-scoped, but its metadata is + expanded to one row per decoder beam so kernels still need the active + decoder beam width. """ - return 1 if self.is_cross else self.beam_width + return self.beam_width @property def max_seq_len(self) -> int: diff --git a/tensorrt_llm/_torch/modules/cross_attention.py b/tensorrt_llm/_torch/modules/cross_attention.py index a2ecdc44d0a9..dece6d3fb400 100644 --- a/tensorrt_llm/_torch/modules/cross_attention.py +++ b/tensorrt_llm/_torch/modules/cross_attention.py @@ -41,15 +41,6 @@ class CrossAttention(nn.Module): subsequent generation steps, K/V are read from the cache without re-projection. - The cross-attention sub-layer honors ``ModelConfig.attn_backend``: when - set to ``"TRTLLM"`` it dispatches through the production attention backend - on every supported architecture. If the internal ``trtllm_gen`` fast path - is enabled and supports the current shape, cross-attention writes encoder - K/V to the cross-KV pool and uses the trtllm-gen kernels; otherwise it - falls back to the standard THOP attention path. - - Encoder and decoder self-attention are unaffected and continue to use - whatever backend ``ModelConfig.attn_backend`` selects. """ def __init__( diff --git a/tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py b/tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py index 36e26c3562f7..e35486872abc 100644 --- a/tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py +++ b/tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py @@ -63,7 +63,6 @@ def _test_case( num_return_sequences: int, exact_match: bool, feature_id: str, - marks=(), kv_cache_dtype: str = "auto", ): expected_output_token_ids = [_EXPECTED_GREEDY_OUTPUT_TOKEN_IDS] if num_beams == 1 else None @@ -79,7 +78,6 @@ def _test_case( exact_match, kv_cache_dtype, id=f"{feature_id}-{_MODEL_NAME}", - marks=marks, ) diff --git a/tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py b/tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py index 1d1452ad3fac..7e10f56ff6d0 100644 --- a/tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py +++ b/tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py @@ -106,7 +106,7 @@ def _test_case( num_return_sequences: int, exact_match: bool, feature_id: str, - marks=(), + marks=None, ): if num_beams == 1: expected_output_token_ids = ( @@ -127,6 +127,10 @@ def _test_case( assert not exact_match or expected_output_token_ids is not None + param_kwargs = {"id": f"{feature_id}-{model_name}"} + if marks is not None: + param_kwargs["marks"] = marks + return pytest.param( model_name, expected_output_token_ids, @@ -136,8 +140,7 @@ def _test_case( num_beams, num_return_sequences, exact_match, - id=f"{feature_id}-{model_name}", - marks=marks, + **param_kwargs, ) @@ -371,7 +374,6 @@ def _mixed_batch_test_case( num_return_sequences: int, exact_match: bool, feature_id: str, - marks=(), ): expected_output_token_ids_by_request = ( _MIXED_ENCODER_OUTPUT_TOKEN_IDS_BY_MODEL_AND_BEAMS.get((model_name, num_beams)) @@ -389,7 +391,6 @@ def _mixed_batch_test_case( num_return_sequences, exact_match, id=f"{feature_id}-{model_name}", - marks=marks, ) diff --git a/tests/unittest/_torch/attention_backend/test_trtllm_gen.py b/tests/unittest/_torch/attention_backend/test_trtllm_gen.py deleted file mode 100644 index bb7ebd0fd0f5..000000000000 --- a/tests/unittest/_torch/attention_backend/test_trtllm_gen.py +++ /dev/null @@ -1,232 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from types import SimpleNamespace - -import torch - -import tensorrt_llm._torch.attention_backend.trtllm_gen as trtllm_gen -from tensorrt_llm._torch.attention_backend.interface import AttentionForwardArgs, AttentionInputType -from tensorrt_llm.bindings import DataType - - -def _patch_blackwell(monkeypatch): - monkeypatch.setattr(trtllm_gen, "IS_FLASHINFER_AVAILABLE", True) - monkeypatch.setattr(trtllm_gen, "get_sm_version", lambda: 100) - monkeypatch.setattr(trtllm_gen, "is_sm_100f", lambda sm: sm in (100, 103)) - - -def _make_attn(**overrides): - fields = dict( - is_mla_enable=False, - skip_softmax_threshold_scale_factor_prefill=None, - skip_softmax_threshold_scale_factor_decode=None, - position_embedding_type=0, - num_heads=8, - num_kv_heads=8, - head_dim=64, - kv_lora_rank=None, - qk_rope_head_dim=None, - attention_chunk_size=None, - ) - fields.update(overrides) - return SimpleNamespace(**fields) - - -def _make_meta(**overrides): - fields = dict( - is_cross=True, - is_spec_decoding_enabled=False, - use_spec_decoding=False, - is_spec_dec_tree=False, - kv_cache_block_offsets=torch.empty((1, 1, 1), dtype=torch.int32), - kv_cache_manager=SimpleNamespace(dtype=DataType.BF16), - tokens_per_block=64, - beam_width=1, - effective_beam_width=1, - helix_position_offsets=None, - num_sparse_topk=0, - ) - fields.update(overrides) - return SimpleNamespace(**fields) - - -def _make_forward_args(**overrides): - fields = dict( - output=torch.empty((2, 8, 64), dtype=torch.bfloat16), - attention_input_type=AttentionInputType.context_only, - ) - fields.update(overrides) - return AttentionForwardArgs(**fields) - - -def _check_support(monkeypatch, *, attn=None, meta=None, fwd=None, q_dtype=torch.bfloat16): - _patch_blackwell(monkeypatch) - backend = object.__new__(trtllm_gen.FlashInferTrtllmGenAttention) - q = torch.empty((2, 8 * 64), dtype=q_dtype) - return backend.is_supported( - q, - None, - None, - attn=attn or _make_attn(), - meta=meta or _make_meta(), - fwd=fwd or _make_forward_args(), - ) - - -def test_is_supported_allows_cross_attention_on_blackwell(monkeypatch): - supported, reason = _check_support(monkeypatch) - - assert supported, reason - - -def test_is_supported_rejects_cross_attention_mla(monkeypatch): - supported, reason = _check_support( - monkeypatch, - attn=_make_attn(is_mla_enable=True), - fwd=_make_forward_args(attention_input_type=AttentionInputType.generation_only), - ) - - assert not supported - assert "Cross attention with MLA" in reason - - -def test_is_supported_rejects_cross_attention_nvfp4(monkeypatch): - supported, reason = _check_support( - monkeypatch, - meta=_make_meta(kv_cache_manager=SimpleNamespace(dtype=DataType.NVFP4)), - ) - - assert not supported - assert "Cross attention with NVFP4" in reason - - -def test_is_supported_rejects_cross_attention_spec_decoding(monkeypatch): - supported, reason = _check_support( - monkeypatch, - meta=_make_meta(is_spec_decoding_enabled=True, use_spec_decoding=True), - ) - - assert not supported - assert "Cross attention with speculative decoding" in reason - - -def test_is_supported_rejects_relative_attention_bias(monkeypatch): - supported, reason = _check_support( - monkeypatch, - fwd=_make_forward_args(relative_attention_bias=torch.empty((1, 1, 1, 1))), - ) - - assert not supported - assert "Relative attention bias" in reason - - -def test_is_supported_uses_effective_beam_width_for_cross_attention(monkeypatch): - supported, reason = _check_support( - monkeypatch, - meta=_make_meta(beam_width=4, effective_beam_width=1), - fwd=_make_forward_args(attention_input_type=AttentionInputType.generation_only), - ) - - assert supported, reason - - -def test_cross_attention_context_uses_fused_preprocess(monkeypatch): - preprocess_calls = [] - context_calls = [] - postprocess_called = False - - def fake_context_preprocess(*args): - preprocess_calls.append(args) - return ( - torch.empty((2, 2, 4), dtype=torch.bfloat16), - torch.empty((1,), dtype=torch.bfloat16), - torch.empty((1, 1), dtype=torch.int32), - None, - None, - None, - torch.empty((1,), dtype=torch.uint8), - torch.tensor([0, 2], dtype=torch.int32), - torch.tensor([0, 3], dtype=torch.int32), - 2, - 3, - -1, - ) - - def fake_context_attention(*args): - context_calls.append(args) - - def fake_context_postprocess(*args): - nonlocal postprocess_called - postprocess_called = True - - monkeypatch.setattr(trtllm_gen.thop, "trtllm_gen_context_preprocess", fake_context_preprocess) - monkeypatch.setattr(trtllm_gen.thop, "trtllm_gen_context_postprocess", fake_context_postprocess) - monkeypatch.setattr( - trtllm_gen, "_trtllm_gen_batch_context_with_kv_cache", fake_context_attention - ) - - backend = object.__new__(trtllm_gen.FlashInferTrtllmGenAttention) - backend._multi_processor_count = 1 - backend._enable_pdl = False - - cross_kv = torch.empty((3, 16), dtype=torch.bfloat16) - attn = SimpleNamespace( - rope_params=SimpleNamespace(dim=0, theta=10000.0, scale_type=0, scale=1.0, max_positions=0), - rotary_inv_freq=None, - rotary_cos_sin=None, - local_layer_idx=0, - num_heads=2, - num_kv_heads=2, - head_dim=4, - quant_mode=0, - q_scaling=1.0, - position_embedding_type=0, - is_mla_enable=False, - attention_chunk_size=None, - ) - meta = SimpleNamespace( - kv_cache_block_offsets=torch.empty((1, 1, 1), dtype=torch.int32), - host_kv_cache_pool_pointers=torch.empty((1,), dtype=torch.int64), - host_kv_cache_pool_mapping=torch.empty((1,), dtype=torch.int32), - use_paged_context_fmha=False, - ) - fwd = AttentionForwardArgs(cross_kv=cross_kv, update_kv_cache=True) - params = trtllm_gen.FmhaParams( - attn=attn, - meta=meta, - fwd=fwd, - workspace=torch.empty((1,), dtype=torch.uint8), - qkv_input=torch.empty((2, 24), dtype=torch.bfloat16), - context_buf=torch.empty((2, 2, 4), dtype=torch.bfloat16), - sequence_lengths=torch.tensor([3], dtype=torch.int32), - context_lengths=torch.tensor([2], dtype=torch.int32), - input_seq_length=2, - max_past_kv_length=3, - max_attention_window_size=3, - cyclic_attention_window_size=3, - num_tokens=2, - tokens_per_block=64, - kv_factor=2, - total_num_blocks=1, - batch_size=1, - cross_attention=True, - ) - - backend.run_context(params) - - assert preprocess_calls[0][-2] is cross_kv - assert preprocess_calls[0][-1] is True - assert context_calls[0][-1] is False - assert not postprocess_called From 3c7b75a3d8904f7ebb82051cba2350157305c58d Mon Sep 17 00:00:00 2001 From: Guiju Zhang <7135567+cascade812@users.noreply.github.com> Date: Mon, 22 Jun 2026 00:07:19 -0700 Subject: [PATCH 5/6] fix tests Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com> --- tensorrt_llm/_torch/models/modeling_t5.py | 49 ++++++++++++++++--- .../defs/llmapi/test_llm_api_pytorch_t5.py | 13 ++--- 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_t5.py b/tensorrt_llm/_torch/models/modeling_t5.py index 4cee8fb31102..caddd80c6694 100644 --- a/tensorrt_llm/_torch/models/modeling_t5.py +++ b/tensorrt_llm/_torch/models/modeling_t5.py @@ -38,6 +38,7 @@ from torch import nn from transformers import T5Config +from tensorrt_llm._utils import get_sm_version from tensorrt_llm.functional import PositionEmbeddingType from ..attention_backend import AttentionMetadata @@ -132,6 +133,40 @@ def _clamp_fp16_infs(hidden_states: torch.Tensor) -> torch.Tensor: return torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) +class T5LayerNorm(RMSNorm): + """T5 RMSNorm with HF-compatible architecture-specific precision. + + On Hopper, HF uses Apex FusedRMSNorm when Apex is available. The generic + TRT-LLM RMSNorm path matches that behavior for ByT5 BF16 decoding. On + Blackwell, the generic fused path drifts from the HF reference, so use the + explicit T5 computation. + """ + + def __init__( + self, + *, + hidden_size: int, + eps: float, + dtype: Optional[torch.dtype] = None, + ): + super().__init__(hidden_size=hidden_size, eps=eps, dtype=dtype) + self._use_hopper_rms_norm: Optional[bool] = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self._use_hopper_rms_norm is None and hidden_states.is_cuda: + sm_version = get_sm_version() + self._use_hopper_rms_norm = 90 <= sm_version < 100 + + if self._use_hopper_rms_norm and hidden_states.dtype in (torch.float16, torch.bfloat16): + return super().forward(hidden_states) + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + if self.weight.dtype in (torch.float16, torch.bfloat16): + hidden_states = hidden_states.to(self.weight.dtype) + return self.weight * hidden_states + + def _t5_encoder_num_layers(config: T5Config) -> int: return config.num_layers @@ -399,12 +434,12 @@ def __init__( self.self_attn = T5Attention(model_config, layer_idx=layer_idx) - self.input_layernorm = RMSNorm( + self.input_layernorm = T5LayerNorm( hidden_size=hidden_size, eps=config.layer_norm_epsilon, dtype=config.torch_dtype, ) - self.post_attention_layernorm = RMSNorm( + self.post_attention_layernorm = T5LayerNorm( hidden_size=hidden_size, eps=config.layer_norm_epsilon, dtype=config.torch_dtype, @@ -487,17 +522,17 @@ def __init__( self.cross_attn = T5CrossAttention(model_config, layer_idx=layer_idx) - self.input_layernorm = RMSNorm( + self.input_layernorm = T5LayerNorm( hidden_size=hidden_size, eps=config.layer_norm_epsilon, dtype=config.torch_dtype, ) - self.post_attention_layernorm = RMSNorm( + self.post_attention_layernorm = T5LayerNorm( hidden_size=hidden_size, eps=config.layer_norm_epsilon, dtype=config.torch_dtype, ) - self.cross_attn_layernorm = RMSNorm( + self.cross_attn_layernorm = T5LayerNorm( hidden_size=hidden_size, eps=config.layer_norm_epsilon, dtype=config.torch_dtype, @@ -599,7 +634,7 @@ def __init__(self, model_config: ModelConfig[T5Config]): self.layers = nn.ModuleList( [T5EncoderLayer(model_config, layer_idx=i) for i in range(num_layers)] ) - self.final_layernorm = RMSNorm( + self.final_layernorm = T5LayerNorm( hidden_size=config.d_model, eps=config.layer_norm_epsilon, dtype=config.torch_dtype, @@ -654,7 +689,7 @@ def __init__(self, model_config: ModelConfig[T5Config]): self.layers = nn.ModuleList( [T5DecoderLayer(model_config, layer_idx=i) for i in range(num_layers)] ) - self.final_layernorm = RMSNorm( + self.final_layernorm = T5LayerNorm( hidden_size=config.d_model, eps=config.layer_norm_epsilon, dtype=config.torch_dtype, diff --git a/tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py b/tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py index 7e10f56ff6d0..b9a68bcac8be 100644 --- a/tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py +++ b/tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py @@ -47,7 +47,7 @@ "t5-base": [644, 4598, 229, 19250], "t5-large": [644, 4598, 229, 19250], "flan-t5-small": [644, 4598, 229, 9685], - "byt5-small": [258, 35, 119, 35], + "byt5-small": [258, 35, 119, 114], } # Known HF references for returned beam hypotheses. The tests exact-match greedy # outputs and the best beam when a reference is available; lower-ranked BF16 @@ -512,17 +512,18 @@ def _assert_expected_generation( assert all(decoded_text_by_output) if expected_token_ids_by_output is None: assert all(expected_text_fragment in text for text in decoded_text_by_output) + elif exact_match: + assert token_ids_by_output == expected_token_ids_by_output + elif len(expected_token_ids_by_output) > 1: + assert {tuple(token_ids) for token_ids in token_ids_by_output} == { + tuple(token_ids) for token_ids in expected_token_ids_by_output + } else: assert token_ids_by_output[0] == expected_token_ids_by_output[0] if len(token_ids_by_output) > 1: assert len({tuple(token_ids) for token_ids in token_ids_by_output}) == len( token_ids_by_output ) - if not exact_match: - return - - assert expected_token_ids_by_output is not None - assert token_ids_by_output == expected_token_ids_by_output def _run_t5_pytorch_generate_encoder_decoder( From 738a8c30e02d0d00a1b926e22aa0e09bb013b6c0 Mon Sep 17 00:00:00 2001 From: Guiju Zhang <7135567+cascade812@users.noreply.github.com> Date: Tue, 23 Jun 2026 13:27:18 -0700 Subject: [PATCH 6/6] address comment Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com> --- .../_torch/attention_backend/fmha/fallback.py | 2 +- .../fmha/flashinfer_trtllm_gen.py | 20 +++++++++---------- .../_torch/attention_backend/trtllm.py | 10 ---------- 3 files changed, 11 insertions(+), 21 deletions(-) diff --git a/tensorrt_llm/_torch/attention_backend/fmha/fallback.py b/tensorrt_llm/_torch/attention_backend/fmha/fallback.py index 73c5778379d8..c13a3298dafa 100644 --- a/tensorrt_llm/_torch/attention_backend/fmha/fallback.py +++ b/tensorrt_llm/_torch/attention_backend/fmha/fallback.py @@ -94,7 +94,7 @@ def forward( block_ids_per_seq=metadata.block_ids_per_seq, tokens_per_block=metadata.tokens_per_block, max_num_requests=metadata.max_num_requests, - beam_width=metadata.effective_beam_width, + beam_width=metadata.beam_width, use_paged_context_fmha=metadata.use_paged_context_fmha, helix_position_offsets=metadata.helix_position_offsets, helix_is_inactive_rank=metadata.helix_is_inactive_rank, diff --git a/tensorrt_llm/_torch/attention_backend/fmha/flashinfer_trtllm_gen.py b/tensorrt_llm/_torch/attention_backend/fmha/flashinfer_trtllm_gen.py index 278d5bb5e13d..803a8d0172a5 100644 --- a/tensorrt_llm/_torch/attention_backend/fmha/flashinfer_trtllm_gen.py +++ b/tensorrt_llm/_torch/attention_backend/fmha/flashinfer_trtllm_gen.py @@ -663,6 +663,11 @@ def _is_supported_with_reason( "Cross attention with speculative decoding is not supported by " "trtllm-gen backend.", ) + if fwd.update_kv_cache and fwd.cross_kv is None: + return ( + False, + "trtllm-gen cross attention requires cross_kv when update_kv_cache=True.", + ) is_fp8_out = output.dtype == torch.float8_e4m3fn is_fp4_out = output.dtype == torch.uint8 @@ -700,10 +705,10 @@ def _is_supported_with_reason( ) if has_generation_phase: - if meta.effective_beam_width != 1 and not meta.is_cross: + if meta.beam_width != 1 and not meta.is_cross: return ( False, - f"[Generation] Beam search (beam_width={meta.effective_beam_width}) " + f"[Generation] Beam search (beam_width={meta.beam_width}) " "is not supported. Must be 1.", ) sink_token_length = 0 @@ -877,11 +882,6 @@ def run_context( rope_params = attn.rope_params bmm1_scale_static = self._get_bmm1_scale(attn) attention_chunk_size = self._get_attention_chunk_size(attn) - if params.is_cross and fwd.update_kv_cache and fwd.cross_kv is None: - raise RuntimeError( - "trtllm-gen cross attention requires cross_kv when update_kv_cache=True." - ) - ( q_processed, kv_pool, @@ -938,7 +938,7 @@ def run_context( params.total_num_blocks, # total_num_blocks params.kv_factor, # kv_factor True, # need_build_kv_cache_metadata - fwd.cross_kv, + fwd.cross_kv, # cross_kv params.is_cross, # is_cross ) @@ -1036,7 +1036,7 @@ def run_generation( rope_params = attn.rope_params bmm1_scale_static = self._get_bmm1_scale(attn) attention_chunk_size = self._get_attention_chunk_size(attn) - batch_beam = params.num_requests * meta.effective_beam_width + batch_beam = params.num_requests * meta.beam_width ( q_processed, kv_pool, @@ -1165,7 +1165,7 @@ def run_mla_generation( if self._get_attention_chunk_size(attn) != 0: raise NotImplementedError("Chunked-attention is not supported by MLA decode path.") - batch_beam = params.num_requests * meta.effective_beam_width + batch_beam = params.num_requests * meta.beam_width if params.attention_input is None: raise RuntimeError("MLA generation requires attention_input.") kv_cache, block_tables = thop.build_trtllm_gen_kv_cache_metadata( diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 57cee026bb1c..f79f5da4287d 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -185,16 +185,6 @@ def max_context_length(self) -> int: """ return min(self.max_seq_len, self.max_num_tokens) - @property - def effective_beam_width(self) -> int: - """Beam width visible to the kernel. - - Cross-attention K/V storage is request-scoped, but its metadata is - expanded to one row per decoder beam so kernels still need the active - decoder beam width. - """ - return self.beam_width - @property def max_seq_len(self) -> int: """