[TRTLLM-13490][feat] Support cross-attention with FlashInfer TRTLLM-Gen kernels on Blackwell#15429
[TRTLLM-13490][feat] Support cross-attention with FlashInfer TRTLLM-Gen kernels on Blackwell#15429cascade812 wants to merge 5 commits into
Conversation
Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com>
📝 WalkthroughWalkthroughCross-attention support is added to the trtllm-gen attention backend. Changes span a C++ kernel null-check fix, new ChangesCross-attention support for trtllm-gen backend
Sequence Diagram(s)sequenceDiagram
participant CrossAttention
participant FlashInferTrtllmGenAttention as FlashInferTrtllmGenAttention (trtllm_gen.py)
participant trtllm_gen_context_preprocess as trtllm_gen_context_preprocess (bindings.cpp)
participant trtllmGenContextPreprocess as trtllmGenContextPreprocess (QKVProcessOp.cpp)
participant _trtllm_gen_batch_context_with_kv_cache
CrossAttention->>FlashInferTrtllmGenAttention: run_context(cross_kv, meta.is_cross=True)
FlashInferTrtllmGenAttention->>FlashInferTrtllmGenAttention: guard cross_kv present if update_kv_cache
FlashInferTrtllmGenAttention->>trtllm_gen_context_preprocess: cross_kv, cross_attention=True
trtllm_gen_context_preprocess->>trtllmGenContextPreprocess: cross_kv ptr, cross_attention=true
trtllmGenContextPreprocess->>trtllmGenContextPreprocess: effectiveWindowSize = max_past_kv_length
trtllmGenContextPreprocess->>trtllmGenContextPreprocess: qkvParams.cross_kv_input = cross_kv
trtllmGenContextPreprocess->>trtllmGenContextPreprocess: qkvParams.cross_attention = true
trtllmGenContextPreprocess-->>trtllm_gen_context_preprocess: windowLeft = -1
trtllm_gen_context_preprocess-->>FlashInferTrtllmGenAttention: preprocess result
FlashInferTrtllmGenAttention->>_trtllm_gen_batch_context_with_kv_cache: causal=False
_trtllm_gen_batch_context_with_kv_cache-->>FlashInferTrtllmGenAttention: attn output
FlashInferTrtllmGenAttention-->>CrossAttention: early return (skip non-cross-attn postprocess)
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 3 | ❌ 2❌ Failed checks (2 warnings)
✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h (2)
2-2:⚠️ Potential issue | 🟡 Minor | ⚡ Quick winUpdate the NVIDIA copyright year for this modified source file.
This file now has a 2026 meaningful modification but the header still ends at 2024. As per coding guidelines, “All TensorRT-LLM source files should contain an NVIDIA copyright header including the year of latest meaningful modification.”
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h` at line 2, Update the copyright year in the file header comment at the top of the file. Change the year range from "2019-2024" to "2019-2026" in the NVIDIA copyright notice to reflect the current meaningful modification year, as required by coding guidelines that mandate copyright headers include the year of latest meaningful modification.Source: Coding guidelines
1415-1423:⚠️ Potential issue | 🟠 Major | ⚡ Quick winAvoid launching encoder-length no-op work when no encoder KV is stored.
With the new
cross_kv_inputguard, generation cross-attention skips KV writes, butmax_seq_lenand the launch grid still scale with the encoder length. For long encoder inputs, every decode step can launch CTAs that do no useful work fortoken_idx >= decoder_seq_len.⚡ Proposed fix
- // The encoder sequence length. - int const encoder_seq_len = params.encoder_seq_lens[batch_idx]; // The encoder sequence offset. // Not needed in Gen phase int const encoder_seq_offset = params.generation_phase ? -1 : params.cu_kv_seq_lens[batch_idx]; + // Only the first chunk needs to store encoder kv input to the kv cache. + bool const store_encoder_kv_cache = params.cross_kv_input != nullptr && (decoder_seq_len == decoder_cache_seq_len); + // The encoder sequence length. + int const encoder_seq_len = store_encoder_kv_cache ? params.encoder_seq_lens[batch_idx] : 0; // THe maximum sequence length of encoder and decoder. int const max_seq_len = max(decoder_seq_len, encoder_seq_len); - - // Only the first chunk needs to store encoder kv input to the kv cache. - bool const store_encoder_kv_cache = params.cross_kv_input != nullptr && (decoder_seq_len == decoder_cache_seq_len);- // The maximum sequence length of encoder and decoder inputs. - int const max_seq_len = std::max(params.max_input_seq_len, params.max_kv_seq_len); + // The maximum sequence length of useful Q/KV work. + int const kv_max_seq_len = (params.generation_phase || params.cross_kv_input == nullptr) ? 0 : params.max_kv_seq_len; + int const max_seq_len = std::max(params.max_input_seq_len, kv_max_seq_len);Also applies to: 1534-1535
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h` around lines 1415 - 1423, The max_seq_len calculation currently always includes encoder_seq_len, which causes the launch grid to scale with encoder length even during generation when store_encoder_kv_cache is false, resulting in CTAs doing no useful work for token indices beyond decoder_seq_len. Modify the max_seq_len computation to only include encoder_seq_len when encoder KV is actually being stored to the cache (when store_encoder_kv_cache is true); otherwise use only decoder_seq_len to avoid launching unnecessary work during generation. This change should also be applied at the other affected locations mentioned in the comment.
🧹 Nitpick comments (1)
tests/unittest/_torch/attention_backend/test_trtllm_gen.py (1)
145-233: ⚡ Quick winAdd generation-path unit coverage for cross-attention wiring.
Coverage is currently insufficient in
tests/unittest/_torch/attention_backend/test_trtllm_gen.pyfor the newly changed generation path. Please add focused tests in this same file for:
run_generation()forwarding ofcross_attentioninto generation preprocess, and- the
update_kv_cache=True+ missingcross_kvfailure path.As per coding guidelines:
tests/**reviews should state whether coverage is sufficient/insufficient and suggest concrete file-level follow-ups.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/attention_backend/test_trtllm_gen.py` around lines 145 - 233, Add two new unit test functions to the test_trtllm_gen.py file following the pattern of the existing test_cross_attention_context_uses_fused_preprocess test. First, create a test that verifies the run_generation() method properly forwards the cross_attention parameter to the generation preprocess function by mocking the preprocess and generation calls and asserting that cross_attention=True is passed through correctly. Second, create a test that verifies the failure path when update_kv_cache=True is set but cross_kv is missing or None, ensuring the appropriate error is raised. Both tests should follow the same monkeypatching and assertion patterns established in the existing test to ensure consistent coverage of the generation path cross-attention wiring.Source: Coding guidelines
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/integration/test_lists/test-db/l0_b200.yml`:
- Line 90: The new T5 TRTLLM-GEN test
(test_t5_pytorch_generate_encoder_decoder_trtllm_gen_attention) is missing from
the l0_b200.yml pre-merge CI selector configuration, while the corresponding
BART test (test_bart_pytorch_generate_encoder_decoder_trtllm_gen_attention) is
present. Add a selector entry for the T5 TRTLLM-GEN test in l0_b200.yml to
ensure the new T5 route is covered by pre-merge CI and regressions cannot slip
through.
---
Outside diff comments:
In
`@cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h`:
- Line 2: Update the copyright year in the file header comment at the top of the
file. Change the year range from "2019-2024" to "2019-2026" in the NVIDIA
copyright notice to reflect the current meaningful modification year, as
required by coding guidelines that mandate copyright headers include the year of
latest meaningful modification.
- Around line 1415-1423: The max_seq_len calculation currently always includes
encoder_seq_len, which causes the launch grid to scale with encoder length even
during generation when store_encoder_kv_cache is false, resulting in CTAs doing
no useful work for token indices beyond decoder_seq_len. Modify the max_seq_len
computation to only include encoder_seq_len when encoder KV is actually being
stored to the cache (when store_encoder_kv_cache is true); otherwise use only
decoder_seq_len to avoid launching unnecessary work during generation. This
change should also be applied at the other affected locations mentioned in the
comment.
---
Nitpick comments:
In `@tests/unittest/_torch/attention_backend/test_trtllm_gen.py`:
- Around line 145-233: Add two new unit test functions to the test_trtllm_gen.py
file following the pattern of the existing
test_cross_attention_context_uses_fused_preprocess test. First, create a test
that verifies the run_generation() method properly forwards the cross_attention
parameter to the generation preprocess function by mocking the preprocess and
generation calls and asserting that cross_attention=True is passed through
correctly. Second, create a test that verifies the failure path when
update_kv_cache=True is set but cross_kv is missing or None, ensuring the
appropriate error is raised. Both tests should follow the same monkeypatching
and assertion patterns established in the existing test to ensure consistent
coverage of the generation path cross-attention wiring.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 7c497782-2541-47dc-900a-9921a2ebfb3c
📒 Files selected for processing (12)
cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.hcpp/tensorrt_llm/nanobind/thop/bindings.cppcpp/tensorrt_llm/thop/trtllmGenFusedOps.hcpp/tensorrt_llm/thop/trtllmGenQKVProcessOp.cpptensorrt_llm/_torch/attention_backend/trtllm.pytensorrt_llm/_torch/attention_backend/trtllm_gen.pytensorrt_llm/_torch/models/modeling_t5.pytensorrt_llm/_torch/modules/cross_attention.pytests/integration/defs/llmapi/test_llm_api_pytorch_bart.pytests/integration/defs/llmapi/test_llm_api_pytorch_t5.pytests/integration/test_lists/test-db/l0_b200.ymltests/unittest/_torch/attention_backend/test_trtllm_gen.py
💤 Files with no reviewable changes (1)
- tensorrt_llm/_torch/attention_backend/trtllm.py
|
/bot run |
|
PR_Github #54692 [ run ] triggered by Bot. Commit: |
|
PR_Github #54692 [ run ] completed with state
|
Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com>
|
/bot run |
|
PR_Github #54700 [ run ] triggered by Bot. Commit: |
|
PR_Github #54700 [ run ] completed with state
|
Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com>
Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com>
|
/bot run |
|
PR_Github #54959 [ run ] triggered by Bot. Commit: |
|
PR_Github #54959 [ run ] completed with state
|
Description
When TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION=1, cross-attention currently falls back to THOP on both sm90 and sm100. This PR enables the
trtllm_gen_backendpath for cross-attention, which uses FlashInfer/TRTLLM-Gen kernels.Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either
api-compatibleorapi-breaking. Forapi-breaking, includeBREAKINGin the PR title.Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.Summary by CodeRabbit
New Features
Tests