[#12808][feat] AutoDeploy: Custom attn mask support for attention backends #12742

bmarimuthu-nv wants to merge 13 commits into NVIDIA:main from nv-auto-deploy:bala/custom-attn-mask
Conversation
📝 Walkthrough

Adds backend-native custom attention-mask support: a provider registry and context, plus a graph transform that injects an optional custom attention mask into attention nodes.
Sequence Diagram

```mermaid
sequenceDiagram
    participant FX as FX GraphModule
    participant Transform as InjectCustomAttentionMask
    participant Registry as AttentionMaskProviderRegistry
    participant Provider as Mask Provider
    participant Context as AttentionMaskProviderContext
    participant FXNode as Modified FX Graph Node
    FX->>Transform: scan graph for torch_attention nodes
    Transform->>Registry: lookup(model_type, backend)
    Registry-->>Transform: provider(fn) or None
    Transform->>Provider: invoke(provider, context, attn_node)
    Provider->>Context: add_or_retrieve_input("custom_attn_mask")
    Context-->>Provider: placeholder node (or create)
    Provider->>Context: get_or_create_cached_node(mask_key)
    Context-->>Provider: memoized mask node
    Provider->>Transform: return mask node
    Transform->>FXNode: inject mask into attn_node args/kwargs
    FXNode-->>FX: updated attention node consumes custom_attn_mask at runtime
```
Squash merge of nv-auto-deploy:bala/custom-attn-mask (PR NVIDIA#12742). Adds custom attention mask injection infrastructure for AutoDeploy attention backends and Gemma4 model-specific attention masking based on token type grouping and media boundaries. Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Force-pushed 2a095f4 to e383ad1
/bot run --disable-fail-fast --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"
Force-pushed e383ad1 to 937ccf6
/bot run --disable-fail-fast --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"
PR_Github #41973 [ run ] triggered by Bot.

PR_Github #41975 [ run ] triggered by Bot.
Actionable comments posted: 3
🧹 Nitpick comments (1)
tensorrt_llm/_torch/auto_deploy/transform/attention_mask_provider.py (1)
50-53: Consider catching more specific exceptions.

The bare `except Exception` catch is flagged by static analysis. While this is "best-effort" inference and suppressing failures is intentional, catching a narrower set of exceptions (e.g., `AttributeError`, `TypeError`, `ValueError`) would avoid masking unexpected errors or system-level issues.

♻️ Suggested narrower exception handling

```diff
 if callable(get_model_config):
     try:
         model_config, _unused_kwargs = get_model_config()
-    except Exception:
+    except (AttributeError, TypeError, ValueError, KeyError):
         return None
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/transform/attention_mask_provider.py` around lines 50 - 53, Replace the bare "except Exception" around the get_model_config() call with a narrower exception clause that only catches expected failures (e.g., AttributeError, TypeError, ValueError) so unexpected system-level exceptions aren't swallowed; specifically update the try/except that assigns model_config, _unused_kwargs = get_model_config() in attention_mask_provider.py to catch those specific exceptions and return None in that narrow except block.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py`:
- Around line 362-437: The context-path helper _torch_context_mha_readonly
currently has the wrong parameter order (logit_cap is in the slot where
torch_backend_mha_with_cache now forwards custom_attn_mask), causing TypeError
and never applying the custom mask; update _torch_context_mha_readonly's
signature to accept custom_attn_mask (e.g., add custom_attn_mask:
Optional[torch.Tensor] before logit_cap or match the caller's order), propagate
that argument to where masks are built, and apply custom_attn_mask (combined
with causal/sliding-window masks) to attn_scores before softmax; make the
identical signature and mask-application fix for the other context helper
referenced around the 525-539 block so both helpers accept and use
custom_attn_mask consistently.
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py`:
- Around line 849-877: The code incorrectly hard-codes mask_head_offset = 0 so
all heads use head-0's mask; update the logic in triton_paged_attention (around
mask_head_offset, mask_offsets computation and loading custom_mask via
custom_mask_ptr) to incorporate the current head index (e.g., head_id or a
passed-in head offset) when computing mask_head_offset, or alternatively add a
guard in the wrapper that validates the incoming mask is broadcastable across
heads and fails fast if not; ensure mask_offsets uses the computed
mask_head_offset so per-head masks [B,H,Q,K] are indexed correctly when loading
custom_mask.
In
`@tests/unittest/auto_deploy/singlegpu/transformations/library/test_inject_custom_attention_mask.py`:
- Around line 1-13: Replace the legacy header block that begins with "#
Copyright (c) 2025, NVIDIA CORPORATION." with the repo-standard SPDX NVIDIA
header and update the year to the latest modification year used elsewhere in the
PR; ensure the new header matches the SPDX format used across the project
(including the SPDX-License-Identifier line and the correct copyright year) so
the top-of-file header in test_inject_custom_attention_mask.py conforms to
repository conventions.
---
Nitpick comments:
In `@tensorrt_llm/_torch/auto_deploy/transform/attention_mask_provider.py`:
- Around line 50-53: Replace the bare "except Exception" around the
get_model_config() call with a narrower exception clause that only catches
expected failures (e.g., AttributeError, TypeError, ValueError) so unexpected
system-level exceptions aren't swallowed; specifically update the try/except
that assigns model_config, _unused_kwargs = get_model_config() in
attention_mask_provider.py to catch those specific exceptions and return None in
that narrow except block.
📒 Files selected for processing (16)
- tensorrt_llm/_torch/auto_deploy/config/default.yaml
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
- tensorrt_llm/_torch/auto_deploy/transform/__init__.py
- tensorrt_llm/_torch/auto_deploy/transform/attention_mask_provider.py
- tensorrt_llm/_torch/auto_deploy/transform/attention_mask_providers.py
- tensorrt_llm/_torch/auto_deploy/transform/library/inject_custom_attention_mask.py
- tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
- tests/unittest/auto_deploy/_utils_test/torch_attention_reference.py
- tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_torch_attention_op.py
- tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py
- tests/unittest/auto_deploy/singlegpu/transformations/library/test_inject_custom_attention_mask.py
/bot run --disable-fail-fast --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

PR_Github #41984 [ run ] triggered by Bot.
```diff
     triton.Config({}, num_warps=8, num_stages=3),
 ],
-key=["HEAD_DIM", "PAGE_SIZE", "HEAD_RATIO_PADDED"],
+key=["HEAD_DIM", "PAGE_SIZE", "HEAD_RATIO_PADDED", "SLIDING_WINDOW"],
```
Sliding window changes pulled from 141e0d4#diff-e5e901759fe5eb0bfbd8a6fbbdbe5ed498a49b93bff19830e9c95c10745d6b14 and updated to support user custom_attn_mask + sliding_window on top
Remove unused ("gemma4", "torch") provider registration, fix test to
use production backend string "torch_attention", and replace hardcoded
arg position lookups with extract_op_args / _get_op_schema utilities.
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Change get_dynamic_inputs() to return Dict[str, Optional[Node]] and pass them as kwargs to the cached attention op via call_function, eliminating fragile positional ordering between caches and constants. Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
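The kwargs-based wiring this commit describes can be sketched in plain torch.fx (the `fake_cached_attn` stand-in and node names are illustrative, not the real op):

```python
import torch
from torch import fx

# Stand-in for a cached attention op that takes its dynamic input by name.
def fake_cached_attn(q, *, custom_attn_mask=None):
    return q if custom_attn_mask is None else q + custom_attn_mask

g = fx.Graph()
q = g.placeholder("q")
mask = g.placeholder("custom_attn_mask")
# kwargs-based call_function: the node binds the mask by name, not by
# position, so reordering caches/constants cannot silently misroute it.
out = g.call_function(fake_cached_attn, args=(q,), kwargs={"custom_attn_mask": mask})
g.output(out)
gm = fx.GraphModule(torch.nn.Module(), g)
```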
/bot run --disable-fail-fast --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

PR_Github #41984 [ run ] completed with state

/bot run --disable-fail-fast --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

PR_Github #42053 [ run ] triggered by Bot.

PR_Github #42053 [ run ] completed with state
Add missing dynamic_kwargs parameter to match base class signature. Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
/bot run --disable-fail-fast --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

PR_Github #42151 [ run ] triggered by Bot.
lucaslie left a comment:
Design Review: Custom Attention Mask Infrastructure
Thanks for the work here — several components are valuable and should be kept regardless of the overall architecture. Specifically:
- Sliding window in triton_paged (both context and decode kernels) is a solid improvement
- `extract_op_args` refactoring replacing ad-hoc positional arg extraction is clean
- The `_paged_context_masked_kernel` Triton kernel for masked prefill will be needed regardless of design
- `_torch_context_mha_readonly` for read-only context attention is useful
However, the overall design has several structural issues that will make it fragile at scale. Below is a detailed analysis and an alternative proposal.
Critical Issues
1. Type Variance for Graph Inputs
The PR passes custom_attn_mask as a graph input that alternates between None (decode/warmup/text-only) and a tensor (VLM prefill). This is problematic because:
- `torch.export` requires tensor-typed graph inputs; `None` is not a valid dynamic shape source.
- CUDA graph capture requires fixed tensor addresses and shapes. The monolithic decode path uses `CapturedGraph`, which hashes static inputs; switching between `None` and a tensor changes the hash, causing fallback to eager execution.
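A toy illustration of the hashing problem (the `capture_key` helper here is a stand-in, not the actual `CapturedGraph` implementation):

```python
import torch

def capture_key(static_inputs):
    # Simplified stand-in for how a captured-graph cache might key its static
    # inputs: None carries no shape/dtype, so it yields a different key entry
    # than any tensor does.
    return tuple(
        None if t is None else (tuple(t.shape), str(t.dtype)) for t in static_inputs
    )

# VLM prefill passes a mask tensor; decode/warmup passes None for the same slot.
with_mask = capture_key([torch.zeros(1, 1, 8, 8, dtype=torch.uint8)])
without_mask = capture_key([None])
# The two keys differ, so alternating None/tensor misses the captured graph
# and falls back to eager execution.
```

Keeping the slot tensor-typed at all times (e.g., an all-False indicator padded to the bucket size) keeps the key stable.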
2. _default_extra_arg_factories Is Over-Engineered
The new factory system on SequenceInfo is unnecessary because:
- The existing `capture_forward_kwargs` mechanism in `export_to_gm.py` already captures whatever kwargs the submodule receives during export, including additional inputs.
- Monolithic decode warmup uses `set_generate_only_batch()`, which doesn't pass multimodal extras.
- Piecewise prefill warmup in `_setup_piecewise_mixed_batch` also uses synthetic sequences without extras.
3. Provider Registry Creates Tight Coupling
The AttentionMaskProviderRegistry keyed by (model_type, backend) means:
- Every new model needing masks must register providers for every backend.
- Model-specific mask semantics are coupled to backend code.
- Adding a new backend requires updating all model providers.
4. mutates_args Regression
`torch_cached_attention_with_cache` changed from `mutates_args=("k_cache", "v_cache")` to `mutates_args=()`. The k/v caches are mutated in-place during context attention (`_torch_context_mha` writes to them). Removing this annotation means the graph compiler may not track these mutations correctly.
5. Mask Computation Outside Graph Is Fragile
Computing the mask in the VLM wrapper (outside the graph) means:
- The wrapper must know backend-specific mask formats (e.g., triton requires `[B, 1, S_q, S_k]` uint8).
- The mask must be constructed before the graph call, requiring access to runtime metadata.
- The graph cannot be tested in isolation for mask correctness.
Alternative Design Proposal
Principle: Two Modes for attn_mask, Analogous to Attention's Own Cached/Uncached Swap Pattern
The canonical IR op torch_attention already has an attn_mask parameter. Instead of building a separate injection transform, make better use of this parameter with two modes. The existing is_causal boolean and attn_mask: Optional[torch.Tensor] parameters are preserved as-is.
Mode A: String Mask Types (Design Now, Implement Later)
In addition to Optional[torch.Tensor], allow attn_mask to accept well-known string literals (e.g. "causal", "bidirectional"). Each attention backend maps recognized strings to its optimized implementation or raises an error if unsupported. The existing is_causal boolean is kept; when attn_mask is a string, it takes precedence.
This mode is not required for the immediate VLM iteration — it should be part of the design but can be implemented as a follow-up.
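As a sketch, Mode A resolution inside a backend could look like the following (the `resolve_attn_mask` helper and the exact string set are illustrative assumptions, not the proposed API):

```python
from typing import Optional, Union
import torch

def make_causal_mask(s_q: int, s_k: int) -> torch.Tensor:
    # Lower-triangular boolean mask: query i may attend to keys j <= i.
    return torch.ones(s_q, s_k, dtype=torch.bool).tril()

def resolve_attn_mask(
    attn_mask: Union[str, torch.Tensor, None],
    is_causal: bool,
    s_q: int,
    s_k: int,
) -> Optional[torch.Tensor]:
    # Mode A: well-known string literals; per the proposal they take
    # precedence over the is_causal boolean.
    if isinstance(attn_mask, str):
        if attn_mask == "causal":
            return make_causal_mask(s_q, s_k)
        if attn_mask == "bidirectional":
            return None  # no masking
        raise ValueError(f"unsupported mask type: {attn_mask!r}")
    # Plain tensor (e.g., produced by a mask custom op): use as-is.
    if attn_mask is not None:
        return attn_mask
    return make_causal_mask(s_q, s_k) if is_causal else None
```

A backend that cannot honor a recognized string would raise instead of silently falling back, matching the "raises an error if unsupported" behavior described above.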
Mode B: Mask Custom Op (for VLM/Gemma4)
For complex, model-specific masks that cannot be described by a string:
Key layout constraint: During export, the model runs in [B, S] layout (e.g., set_example_sequence creates [2, 4] inputs). Only after the kvcache transform does the graph switch to ragged [1, total_tokens] layout. The mask custom op's reference implementation must therefore operate in [B, S] layout.
- Define a mask custom op in model code (not backend code):
```python
@torch.library.custom_op("auto_deploy::gemma4_build_vlm_mask", mutates_args=())
def gemma4_build_vlm_mask(
    is_vision_token: torch.Tensor,  # [B, S] bool - same layout as input_ids at export time
    seq_len: int,
) -> torch.Tensor:
    # Reference: build VLM attention mask from vision token indicator.
    # Operates in [B, S] layout (pre-kvcache-transform).
    # Returns mask in [B, 1, S, S] compatible with torch_attention attn_mask.
    ...
```
- During export: The model's forward calls this custom op to build the mask from `is_vision_token` and passes the result to `torch_attention(attn_mask=...)`. The existing `capture_forward_kwargs` mechanism naturally captures `is_vision_token` as a graph input with dynamic shapes. No new infrastructure needed.
- During kvcache transform: Just like `torch_attention` is swapped for `cached_attention`, the mask custom op is identified as a single node in the graph and swapped for a backend-specific version. The backend version:
  - Receives the original args (e.g., `is_vision_token`, now in ragged layout)
  - Receives metadata args (e.g., `batch_info_host`, sequence boundaries)
  - Produces a mask in the format the backend's cached attention kernel expects
- Backend-specific mask functions registered from model code:

```python
# In model code (e.g., _torch/auto_deploy/models/gemma4.py), not backend code
MaskOpRegistry.register(
    source_op="auto_deploy::gemma4_build_vlm_mask",
    backend="triton_paged",
)(gemma4_triton_paged_mask_fn)
```

The registration is model-driven: each model declares which backends it supports masks for. A backend with no registered mask function for a given mask op raises a clear error. This avoids the (model_type, backend) cartesian product problem.
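The registry in this proposal could be as small as the following sketch (the class and method names are assumed from the registration snippet above, not existing code):

```python
from typing import Callable, Dict, Tuple

class MaskOpRegistry:
    """Maps (source mask op, backend) to a backend-specific mask function."""

    _registry: Dict[Tuple[str, str], Callable] = {}

    @classmethod
    def register(cls, source_op: str, backend: str):
        # Decorator-style registration so model code can attach its own
        # backend implementations next to the mask op definition.
        def decorator(fn: Callable) -> Callable:
            cls._registry[(source_op, backend)] = fn
            return fn
        return decorator

    @classmethod
    def lookup(cls, source_op: str, backend: str) -> Callable:
        try:
            return cls._registry[(source_op, backend)]
        except KeyError:
            # Fail fast with a clear error instead of silently falling back.
            raise KeyError(
                f"no mask function registered for {source_op!r} on backend {backend!r}"
            )
```

The kvcache transform would call `lookup` with the mask node's target and the active backend, mirroring the attention-op swap.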
VLM Prefill/Decode Handling
Following the existing prefill/decode specialization pattern:
- Prefill with images: Backend mask function reads `is_vision_token` (ragged layout), uses metadata (`batch_info_host`, `cu_seqlen`) to build the VLM mask.
- Prefill without images: `is_vision_token` is all-False; backend mask function produces a standard causal mask or returns a sentinel for the fast path.
- Decode: Backend mask function checks `batch_info_host`, sees decode-only, returns a no-op. This is CUDA-graph compatible because:
  - Decode runs monolithic CUDA graphs.
  - The mask function still executes (same graph structure) but produces a dummy output.
  - The attention kernel's decode path does not consume the mask output.
  - `is_vision_token` is padded to the CUDA graph bucket size (all-False), so tensor shapes are stable.
The mask computation itself is not expected to be CUDA-graph compatible (it may involve data-dependent control flow). This is fine because mask computation only matters for prefill, and prefill is not CUDA-graphed. For decode CUDA graphs, the mask function runs but is a no-op.
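A toy sketch of that decode behavior (the function name, the `num_context_tokens` scalar, and the simplified vision-block rule are all illustrative assumptions):

```python
import torch

def build_backend_mask(is_vision_token: torch.Tensor, num_context_tokens: int) -> torch.Tensor:
    # Decode-only: nothing to mask; return a fixed-shape dummy that the decode
    # kernel never reads, so the graph structure is identical every step.
    if num_context_tokens == 0:
        return torch.zeros(1, dtype=torch.uint8, device=is_vision_token.device)
    # Prefill: build a [1, 1, S, S] mask where vision tokens attend
    # bidirectionally among themselves and text tokens remain causal
    # (a simplified stand-in for the real VLM mask rule).
    s = num_context_tokens
    causal = torch.ones(s, s, dtype=torch.bool).tril()
    vision = is_vision_token[:s]
    bidir = vision.unsqueeze(0) & vision.unsqueeze(1)  # vision-to-vision pairs
    return (causal | bidir).to(torch.uint8).view(1, 1, s, s)
```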
Key Advantages
- No type variance: `is_vision_token` is always a tensor. No `None`/tensor switching for graph inputs.
- No `_default_extra_arg_factories`: Uses existing `capture_forward_kwargs` and `extra_args` mechanisms.
- Separation of concerns: Model code defines mask semantics and registers backend implementations. Backend attention code only needs to accept the mask output.
- CUDA graph compatible by design: Decode ignores mask output; mask function is a no-op.
- Testable: Mask computation is in the graph; graph-level tests verify correctness end-to-end.
- Preserves existing API: `is_causal` kept as-is; `attn_mask` extended naturally.
Recommendation
I'd suggest splitting the valuable parts (sliding window support, extract_op_args cleanup, the triton masked kernel) into separate PRs, and redesigning the mask injection infrastructure along the lines above.
> for ``name``. It is called at the start of every ``nest_sequences`` so that initialization-time forward passes (e.g. ``resize_kv_cache``) always receive a valid tensor for ``name`` even when no per-request data is available.
The (model_type, backend) key creates a tight coupling: every new model needing custom masks must register a provider for every backend, and adding a new backend requires updating all model providers.
An alternative is a model-driven registry keyed by (source_mask_op, backend). The model defines a mask custom op (e.g. auto_deploy::gemma4_build_vlm_mask) and registers backend-specific replacements from its own code. The kvcache transform then identifies the mask op node in the graph and swaps it — analogous to how torch_attention is swapped for cached_attention. This way model-specific mask semantics stay in model code, not in the attention backend.
```python
@staticmethod
def _get_attn_mask_arg(node: Node):
    return extract_op_args(node, "attn_mask")[0]
```
Rather than a separate injection transform that rewires attn_mask on existing torch_attention nodes, consider having the model code produce the mask in-graph via a mask custom op during export. The mask op output feeds into torch_attention(attn_mask=...) naturally.
During the kvcache transform, the mask custom op is swapped for a backend-specific version (just like torch_attention is swapped for cached_attention). This eliminates the need for this transform entirely and keeps mask computation inside the graph where it can be tested and verified.
```python
# EXTRA TENSOR FIELDS ######################################################################
self._extra_args: Dict[str, Optional[torch.Tensor]] = {}
# Default factories for extra args: callables that accept this SequenceInfo instance
```
This _default_extra_arg_factories machinery is unnecessary. The existing capture_forward_kwargs mechanism in export_to_gm.py already captures whatever kwargs the submodule receives during export — including additional inputs like is_vision_token. For warmup/CUDA-graph paths:
- Monolithic decode warmup uses `set_generate_only_batch()`, which doesn't pass multimodal extras.
- Piecewise prefill warmup in `_setup_piecewise_mixed_batch` uses synthetic sequences without extras.
Both paths already produce valid inputs without needing factories. Adding this callback system to SequenceInfo increases the surface area of a critical class without a clear necessity.
```python
logit_cap: Optional[float] = None,
read_cache_only: bool = False,
# DYNAMIC INPUTS
custom_attn_mask: Optional[torch.Tensor] = None,
```
Two concerns with the signature changes in this op:
- `mutates_args` regression: The decorator was changed from `mutates_args=("k_cache", "v_cache")` to `mutates_args=()`. The k/v caches are still mutated in-place during context attention (`_torch_context_mha` writes to `k_cache` and `v_cache`). Removing this annotation means the graph compiler may not track these mutations correctly.
- Dynamic inputs as kwargs: Rather than threading `custom_attn_mask` through as a dynamic kwarg on the cached attention op, consider having the mask computation live inside the graph as a separate custom op node whose output feeds into the attention op. During the kvcache transform, that mask node gets swapped for a backend-specific version, keeping the cached attention op signature stable.
```python
meta_nodes_std: List[Node],
meta_nodes_extra: List[Node],
cache_nodes: List[Node],
dynamic_kwargs: Dict[str, Optional[Node]],
```
The dynamic_kwargs / get_dynamic_inputs() mechanism adds a generic escape hatch for forwarding arbitrary tensor kwargs through the cached attention op. This is powerful but also fragile — it changes the calling convention of every cached attention op and requires all backends to accept these kwargs.
With the mask-custom-op approach (where the mask is computed by a separate node in the graph that gets swapped during this same transform), this generic mechanism wouldn't be needed. The mask node swap would happen alongside the attention node swap, and the mask output would be consumed by the attention op as a regular positional arg — no kwargs plumbing required.
PR_Github #42151 [ run ] completed with state
Description
This branch adds custom attention mask injection infrastructure for the AutoDeploy pipeline, specifically targeting Gemma4's VLM attention patterns. The work spans:

This is infra only. A follow-up Gemma4 multimodal PR will implement the end-to-end path for the model.
Custom Attention Mask Flow
Key Design Points
- `custom_attn_mask` placeholder (defaults to `None`)
- `(model_type, backend)` → mask wiring function
- `get_dynamic_inputs()` on `AttentionDescriptor` ensures the mask tensor flows through the cached attention op (between caches and constants)
- Triton backend requires `[B, 1, S_q, S_k]`; torch backend supports `[B, N, S_q, S_k]` via slicing

Block Diagram
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.

Summary by CodeRabbit
Release Notes
New Features
Tests