
[TRTLLM-13490][feat] Support cross-attention with FlashInfer TRTLLM-Gen kernels on Blackwell #15429

Open
cascade812 wants to merge 5 commits into NVIDIA:main from cascade812:guiju/trtllm

Conversation

@cascade812

@cascade812 cascade812 commented Jun 16, 2026

Collaborator

Description

When TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION=1, cross-attention currently falls back to THOP on both sm90 and sm100. This PR enables the trtllm_gen_backend path for cross-attention, which uses FlashInfer/TRTLLM-Gen kernels.
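
For local experiments, the opt-in can be scoped to a single block of code. The sketch below is only an illustration: the variable name comes from the description above, but `trtllm_gen_attention_enabled` is a hypothetical helper, and it assumes the variable is read after the block is entered (if the library reads it at import time, set it before importing instead).

```python
import os
from contextlib import contextmanager

# Variable name taken from the PR description.
ENV_VAR = "TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION"


@contextmanager
def trtllm_gen_attention_enabled():
    """Hypothetical helper: set the opt-in variable to "1", then restore it."""
    previous = os.environ.get(ENV_VAR)
    os.environ[ENV_VAR] = "1"
    try:
        yield
    finally:
        # Restore the prior state so other tests are not affected.
        if previous is None:
            os.environ.pop(ENV_VAR, None)
        else:
            os.environ[ENV_VAR] = previous
```

Wrapping model construction and generation in `with trtllm_gen_attention_enabled():` keeps the setting from leaking into unrelated tests in the same process.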

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • If PR introduces API changes, an appropriate PR label is added - either api-compatible or api-breaking. For api-breaking, include BREAKING in the PR title.

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

Summary by CodeRabbit

  • New Features

    • Added cross-attention support with configurable KV-cache data types and enhanced cache management.
    • Extended attention backend with improved configuration validation and hardware compatibility checks across architectures.
  • Tests

    • Added comprehensive cross-attention validation tests including hardware support verification and configuration constraint enforcement.

Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com>
Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com>
@cascade812 cascade812 requested review from a team as code owners June 16, 2026 22:26
@cascade812 cascade812 requested review from 2ez4bz and QiJune June 16, 2026 22:26
@cascade812 cascade812 requested a review from yuxianq June 16, 2026 22:30
@coderabbitai

coderabbitai Bot commented Jun 16, 2026

Contributor

Review Change Stack

📝 Walkthrough


Cross-attention support is added to the trtllm-gen attention backend. Changes span a C++ kernel null-check fix, new cross_kv/cross_attention parameters in C++ preprocess ops and their nanobind/Python bindings, Python backend support-gating and execution-path wiring, a refactor of T5Attention to route relative position bias through the TRTLLM backend, and new BART/T5 integration tests and unit tests.

Changes

Cross-attention support for trtllm-gen backend

| Layer | File(s) | Summary |
| --- | --- | --- |
| KV-cache store guard fix and C++ API contracts | cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h, cpp/tensorrt_llm/thop/trtllmGenFusedOps.h | store_encoder_kv_cache adds a params.cross_kv_input != nullptr guard; trtllmGenContextPreprocess and trtllmGenGenerationPreprocess declarations gain cross_kv and cross_attention parameters. |
| C++ preprocess implementation | cpp/tensorrt_llm/thop/trtllmGenQKVProcessOp.cpp | Implements cross-attention in context and generation preprocess: MLA/multi-token validation, effective window sizes switch to max_past_kv_length, cross_kv_input wiring, conditional encoder_seq_lens/cache_seq_lens swap, qkvParams.cross_attention assignment, and windowLeft = -1 in cross-attention mode. |
| nanobind/Python binding wiring | cpp/tensorrt_llm/nanobind/thop/bindings.cpp | Extends trtllmGenContextPreprocessBinding and trtllmGenGenerationPreprocessBinding C++ wrappers and their nanobind Python signatures to accept and forward cross_kv and cross_attention. |
| Python trtllm-gen backend | tensorrt_llm/_torch/attention_backend/trtllm_gen.py | Adds cross_attention field to FmhaParams; tightens is_supported to reject MLA/NVFP4/speculative-decoding/relative-bias for cross-attention; switches beam-width checks to effective_beam_width; wires params.cross_attention in context/generation phases; guards cross_kv for KV-cache updates; passes causal=False for cross-attention FMHA with early return. |
| THOP backend comment cleanup | tensorrt_llm/_torch/attention_backend/trtllm.py, tensorrt_llm/_torch/modules/cross_attention.py | Removes stale comment that cross-attention was forced through THOP; updates CrossAttention docstring to describe current trtllm-gen fast path dispatch. |
| T5Attention refactor | tensorrt_llm/_torch/models/modeling_t5.py | Removes is_decoder from T5Attention.__init__, adds no-op apply_rope, rewrites forward to route relative position bias through TRTLLM backend for both KV-cache and non-cache paths with TP-sliced bias, updates T5EncoderLayer/T5DecoderLayer construction, and adds max_context_q_len_override support in T5Encoder.forward. |
| Unit tests for cross-attention backend | tests/unittest/_torch/attention_backend/test_trtllm_gen.py | New test module with helpers asserting is_supported outcomes for all cross-attention constraint combinations and a behavioral test verifying run_context argument forwarding and early-return in cross-attention mode. |
| BART and T5 integration tests | tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py, tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py, tests/integration/test_lists/test-db/l0_b200.yml | Adds trtllm-gen attention test cases (including fp8 KV cache), _enable_trtllm_gen_attention helper, new end-to-end tests for BART and T5 on Blackwell+, refactors test bodies into shared helpers, updates expected token fixtures, and registers the BART test in the B200 pre-merge list. |
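
The support gating described for trtllm_gen.py can be pictured with a small sketch. This is not the PR's `is_supported` implementation: `AttentionConfig`, its field names, and `cross_attention_supported` are hypothetical stand-ins, and only the four cross-attention rejections listed in the walkthrough are modeled.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class AttentionConfig:
    """Hypothetical stand-in for the fields a support check might inspect."""
    is_cross: bool = False
    is_mla: bool = False
    uses_nvfp4: bool = False
    speculative_decoding: bool = False
    has_relative_bias: bool = False


def cross_attention_supported(cfg: AttentionConfig) -> Tuple[bool, str]:
    """Return (supported, reason) for the cross-attention constraints only."""
    if not cfg.is_cross:
        # Self-attention has its own checks, which this sketch does not model.
        return True, "not cross-attention"
    if cfg.is_mla:
        return False, "MLA is rejected for cross-attention"
    if cfg.uses_nvfp4:
        return False, "NVFP4 is rejected for cross-attention"
    if cfg.speculative_decoding:
        return False, "speculative decoding is rejected for cross-attention"
    if cfg.has_relative_bias:
        return False, "relative bias is rejected for cross-attention"
    return True, "ok"
```

Returning a reason string alongside the boolean is a design choice of this sketch; it makes a fallback to another backend easier to explain in logs and in test failures.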

Sequence Diagram(s)

sequenceDiagram
  participant CrossAttention
  participant FlashInferTrtllmGenAttention as FlashInferTrtllmGenAttention (trtllm_gen.py)
  participant trtllm_gen_context_preprocess as trtllm_gen_context_preprocess (bindings.cpp)
  participant trtllmGenContextPreprocess as trtllmGenContextPreprocess (QKVProcessOp.cpp)
  participant _trtllm_gen_batch_context_with_kv_cache

  CrossAttention->>FlashInferTrtllmGenAttention: run_context(cross_kv, meta.is_cross=True)
  FlashInferTrtllmGenAttention->>FlashInferTrtllmGenAttention: guard cross_kv present if update_kv_cache
  FlashInferTrtllmGenAttention->>trtllm_gen_context_preprocess: cross_kv, cross_attention=True
  trtllm_gen_context_preprocess->>trtllmGenContextPreprocess: cross_kv ptr, cross_attention=true
  trtllmGenContextPreprocess->>trtllmGenContextPreprocess: effectiveWindowSize = max_past_kv_length
  trtllmGenContextPreprocess->>trtllmGenContextPreprocess: qkvParams.cross_kv_input = cross_kv
  trtllmGenContextPreprocess->>trtllmGenContextPreprocess: qkvParams.cross_attention = true
  trtllmGenContextPreprocess-->>trtllm_gen_context_preprocess: windowLeft = -1
  trtllm_gen_context_preprocess-->>FlashInferTrtllmGenAttention: preprocess result
  FlashInferTrtllmGenAttention->>_trtllm_gen_batch_context_with_kv_cache: causal=False
  _trtllm_gen_batch_context_with_kv_cache-->>FlashInferTrtllmGenAttention: attn output
  FlashInferTrtllmGenAttention-->>CrossAttention: early return (skip non-cross-attn postprocess)
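
Read as plain control flow, the diagram reduces to a guard, a preprocess call, a non-causal FMHA call, and an early return. The sketch below mirrors that order with injected callables; `run_context_sketch` and its parameters are hypothetical, the error type is an assumption, and treating the non-cross path as causal is also an assumption rather than something the walkthrough states.

```python
from typing import Any, Callable, Optional


def run_context_sketch(
    q: Any,
    cross_kv: Optional[Any],
    *,
    is_cross: bool,
    update_kv_cache: bool,
    preprocess: Callable[..., Any],
    fmha: Callable[..., Any],
    postprocess: Callable[[Any], Any],
) -> Any:
    """Hypothetical outline of the context path shown in the diagram."""
    # Guard: writing encoder K/V into the cache needs the encoder K/V tensor.
    if is_cross and update_kv_cache and cross_kv is None:
        raise ValueError("cross_kv is required when update_kv_cache is set")

    # Forward both cross-attention arguments to the fused preprocess step.
    prepared = preprocess(q, cross_kv=cross_kv, cross_attention=is_cross)

    # Cross-attention attends to the whole encoder sequence, so no causal mask.
    # Assumption: the self-attention path is causal.
    out = fmha(prepared, causal=not is_cross)

    if is_cross:
        # Early return: skip post-processing meant for the non-cross path.
        return out
    return postprocess(out)
```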

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • NVIDIA/TensorRT-LLM#15345: Directly modifies trtllm_gen.py is_supported/routing logic and cross_kv/cross_attention wiring in the same backend this PR extends.

Suggested labels

api-compatible

Suggested reviewers

  • QiJune
  • brb-nv
  • chang-l
  • laikhtewari
  • bo-nv
🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (2 warnings)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 7.55% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
| Description check | ⚠️ Warning | The PR description is incomplete and missing critical sections from the template. | Add a detailed description explaining the changes, include specific test cases that validate the implementation, and ensure all checklist items are properly addressed. |

✅ Passed checks (3 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The PR title clearly and specifically describes the main change: enabling cross-attention support with FlashInfer TRTLLM-Gen kernels on Blackwell GPUs. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


@coderabbitai coderabbitai Bot left a comment

Contributor


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h (2)

2-2: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Update the NVIDIA copyright year for this modified source file.

This file has a meaningful modification in 2026, but the copyright header still ends at 2024. As per coding guidelines, “All TensorRT-LLM source files should contain an NVIDIA copyright header including the year of latest meaningful modification.”

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h`
at line 2, Update the copyright year in the file header comment at the top of
the file. Change the year range from "2019-2024" to "2019-2026" in the NVIDIA
copyright notice to reflect the current meaningful modification year, as
required by coding guidelines that mandate copyright headers include the year of
latest meaningful modification.

Source: Coding guidelines


1415-1423: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Avoid launching encoder-length no-op work when no encoder KV is stored.

With the new cross_kv_input guard, generation cross-attention skips KV writes, but max_seq_len and the launch grid still scale with the encoder length. For long encoder inputs, every decode step can launch CTAs that do no useful work for token_idx >= decoder_seq_len.

⚡ Proposed fix
-    // The encoder sequence length.
-    int const encoder_seq_len = params.encoder_seq_lens[batch_idx];
     // The encoder sequence offset.
     // Not needed in Gen phase
     int const encoder_seq_offset = params.generation_phase ? -1 : params.cu_kv_seq_lens[batch_idx];
+    // Only the first chunk needs to store encoder kv input to the kv cache.
+    bool const store_encoder_kv_cache = params.cross_kv_input != nullptr && (decoder_seq_len == decoder_cache_seq_len);
+    // The encoder sequence length.
+    int const encoder_seq_len = store_encoder_kv_cache ? params.encoder_seq_lens[batch_idx] : 0;
     // THe maximum sequence length of encoder and decoder.
     int const max_seq_len = max(decoder_seq_len, encoder_seq_len);
-
-    // Only the first chunk needs to store encoder kv input to the kv cache.
-    bool const store_encoder_kv_cache = params.cross_kv_input != nullptr && (decoder_seq_len == decoder_cache_seq_len);
-    // The maximum sequence length of encoder and decoder inputs.
-    int const max_seq_len = std::max(params.max_input_seq_len, params.max_kv_seq_len);
+    // The maximum sequence length of useful Q/KV work.
+    int const kv_max_seq_len = (params.generation_phase || params.cross_kv_input == nullptr) ? 0 : params.max_kv_seq_len;
+    int const max_seq_len = std::max(params.max_input_seq_len, kv_max_seq_len);

Also applies to: 1534-1535
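
The cost described in this comment is easy to see with numbers. The sketch below is plain arithmetic in Python, not the kernel code: `launch_seq_extent` is a hypothetical helper, and the lengths (1 decoder token, 4096 encoder tokens) are illustrative values chosen for the example.

```python
def launch_seq_extent(decoder_seq_len: int, encoder_seq_len: int,
                      store_encoder_kv: bool) -> int:
    """Token positions the launch must cover along the sequence axis (sketch).

    The encoder length only counts when encoder K/V is actually written
    to the cache, which is the behaviour the proposed fix aims for.
    """
    effective_encoder_len = encoder_seq_len if store_encoder_kv else 0
    return max(decoder_seq_len, effective_encoder_len)


# A decode step: one new decoder token, 4096 encoder tokens, nothing to store.
before = max(1, 4096)  # extent that always includes the encoder length
after = launch_seq_extent(1, 4096, store_encoder_kv=False)
print(before, after)  # 4096 1
```

In this example the sequence extent drops from 4096 positions to 1 per decode step, which is the no-op work the comment asks to avoid.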

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h`
around lines 1415 - 1423, The max_seq_len calculation currently always includes
encoder_seq_len, which causes the launch grid to scale with encoder length even
during generation when store_encoder_kv_cache is false, resulting in CTAs doing
no useful work for token indices beyond decoder_seq_len. Modify the max_seq_len
computation to only include encoder_seq_len when encoder KV is actually being
stored to the cache (when store_encoder_kv_cache is true); otherwise use only
decoder_seq_len to avoid launching unnecessary work during generation. This
change should also be applied at the other affected locations mentioned in the
comment.
🧹 Nitpick comments (1)
tests/unittest/_torch/attention_backend/test_trtllm_gen.py (1)

145-233: ⚡ Quick win

Add generation-path unit coverage for cross-attention wiring.

Coverage is currently insufficient in tests/unittest/_torch/attention_backend/test_trtllm_gen.py for the newly changed generation path. Please add focused tests in this same file for:

  1. run_generation() forwarding of cross_attention into generation preprocess, and
  2. the update_kv_cache=True + missing cross_kv failure path.

As per coding guidelines: tests/** reviews should state whether coverage is sufficient/insufficient and suggest concrete file-level follow-ups.
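
One possible shape for the two requested tests is sketched below. It does not touch the real backend: `ToyBackend` is a hypothetical stand-in whose `run_generation` only mimics the guard and the argument forwarding, and the `ValueError` type is an assumption. The real tests would patch the actual preprocess binding instead, following the existing test in this file.

```python
from unittest import mock


class ToyBackend:
    """Hypothetical stand-in for the trtllm-gen backend, not the real class."""

    def __init__(self, generation_preprocess):
        self._generation_preprocess = generation_preprocess

    def run_generation(self, q, *, cross_kv=None, cross_attention=False,
                       update_kv_cache=False):
        # Guard mirrored from the review: cache updates need cross_kv.
        if cross_attention and update_kv_cache and cross_kv is None:
            raise ValueError("cross_kv must be provided when update_kv_cache=True")
        return self._generation_preprocess(
            q, cross_kv=cross_kv, cross_attention=cross_attention)


def test_generation_forwards_cross_attention():
    preprocess = mock.Mock(return_value="preprocessed")
    backend = ToyBackend(preprocess)
    assert backend.run_generation("q", cross_attention=True) == "preprocessed"
    # The flag must reach the preprocess call unchanged.
    assert preprocess.call_args.kwargs["cross_attention"] is True


def test_missing_cross_kv_is_rejected():
    preprocess = mock.Mock()
    backend = ToyBackend(preprocess)
    try:
        backend.run_generation("q", cross_attention=True, update_kv_cache=True)
    except ValueError:
        pass
    else:
        raise AssertionError("expected ValueError")
    # The guard should fire before any preprocess work is launched.
    preprocess.assert_not_called()
```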

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unittest/_torch/attention_backend/test_trtllm_gen.py` around lines 145
- 233, Add two new unit test functions to the test_trtllm_gen.py file following
the pattern of the existing test_cross_attention_context_uses_fused_preprocess
test. First, create a test that verifies the run_generation() method properly
forwards the cross_attention parameter to the generation preprocess function by
mocking the preprocess and generation calls and asserting that
cross_attention=True is passed through correctly. Second, create a test that
verifies the failure path when update_kv_cache=True is set but cross_kv is
missing or None, ensuring the appropriate error is raised. Both tests should
follow the same monkeypatching and assertion patterns established in the
existing test to ensure consistent coverage of the generation path
cross-attention wiring.

Source: Coding guidelines

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tests/integration/test_lists/test-db/l0_b200.yml`:
- Line 90: The new T5 TRTLLM-GEN test
(test_t5_pytorch_generate_encoder_decoder_trtllm_gen_attention) is missing from
the l0_b200.yml pre-merge CI selector configuration, while the corresponding
BART test (test_bart_pytorch_generate_encoder_decoder_trtllm_gen_attention) is
present. Add a selector entry for the T5 TRTLLM-GEN test in l0_b200.yml to
ensure the new T5 route is covered by pre-merge CI and regressions cannot slip
through.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 7c497782-2541-47dc-900a-9921a2ebfb3c

📥 Commits

Reviewing files that changed from the base of the PR and between dfad249 and 63a2def.

📒 Files selected for processing (12)
  • cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h
  • cpp/tensorrt_llm/nanobind/thop/bindings.cpp
  • cpp/tensorrt_llm/thop/trtllmGenFusedOps.h
  • cpp/tensorrt_llm/thop/trtllmGenQKVProcessOp.cpp
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • tensorrt_llm/_torch/attention_backend/trtllm_gen.py
  • tensorrt_llm/_torch/models/modeling_t5.py
  • tensorrt_llm/_torch/modules/cross_attention.py
  • tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py
  • tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py
  • tests/integration/test_lists/test-db/l0_b200.yml
  • tests/unittest/_torch/attention_backend/test_trtllm_gen.py
💤 Files with no reviewable changes (1)
  • tensorrt_llm/_torch/attention_backend/trtllm.py

Comment thread tests/integration/test_lists/test-db/l0_b200.yml Outdated
@cascade812 cascade812 marked this pull request as draft June 16, 2026 22:38
@cascade812

Collaborator Author

/bot run

@tensorrt-cicd

Collaborator

PR_Github #54692 [ run ] triggered by Bot. Commit: 63a2def Link to invocation

@cascade812 cascade812 marked this pull request as ready for review June 17, 2026 01:06
@tensorrt-cicd

Collaborator

PR_Github #54692 [ run ] completed with state FAILURE. Commit: 63a2def
/LLM/main/L0_MergeRequest_PR pipeline #43721 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com>
@cascade812

Collaborator Author

/bot run

@tensorrt-cicd

Collaborator

PR_Github #54700 [ run ] triggered by Bot. Commit: 3815957 Link to invocation

Comment thread tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py
Comment thread tensorrt_llm/_torch/attention_backend/fmha/flashinfer_trtllm_gen.py Outdated
@tensorrt-cicd

Collaborator

PR_Github #54700 [ run ] completed with state FAILURE. Commit: 3815957
/LLM/main/L0_MergeRequest_PR pipeline #43731 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

Comment thread tensorrt_llm/_torch/attention_backend/trtllm_gen.py Outdated
Comment thread tensorrt_llm/_torch/attention_backend/trtllm_gen.py Outdated
Comment thread tensorrt_llm/_torch/attention_backend/trtllm_gen.py Outdated
Comment thread tensorrt_llm/_torch/modules/cross_attention.py Outdated
Comment thread tests/integration/defs/llmapi/test_llm_api_pytorch_bart.py Outdated
Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com>
Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com>
@cascade812

Collaborator Author

/bot run

@tensorrt-cicd

Collaborator

PR_Github #54959 [ run ] triggered by Bot. Commit: 9633568 Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #54959 [ run ] completed with state SUCCESS. Commit: 9633568
/LLM/main/L0_MergeRequest_PR pipeline #43959 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation



3 participants