
[#12808][feat] AutoDeploy: Custom attn mask support for attention backends#12742

Open
bmarimuthu-nv wants to merge 13 commits into NVIDIA:main from nv-auto-deploy:bala/custom-attn-mask

Conversation

@bmarimuthu-nv
Collaborator

@bmarimuthu-nv bmarimuthu-nv commented Apr 3, 2026

Description

This branch adds custom attention mask injection infrastructure for the AutoDeploy pipeline, specifically targeting Gemma4's VLM attention patterns. The
work spans:

  1. A new transform (inject_custom_attention_mask) with a provider registry pattern
  2. A new get_dynamic_inputs() hook in the AttentionDescriptor ABC to forward runtime tensors through the cached attention op
  3. A default extra-arg factory system on SequenceInfo for warmup/CUDA-graph passes
  4. A new Triton kernel (_paged_context_masked_kernel) for masked prefill
  5. Mask support plumbed into torch and triton_paged backends
  6. Refactoring of layout extraction to use extract_op_args() consistently

This is infrastructure only. A follow-up Gemma4 multimodal PR will implement the end-to-end path for the model.

Custom Attention Mask Flow

Key Design Points

  • Mask is computed outside the graph (no mask-building ops in the exported graph). In-graph mask computation hits symbolic shape and device errors after graph recompilation during KV cache transforms
  • The graph only receives a custom_attn_mask placeholder (defaults to None)
  • Provider registry maps (model_type, backend) → mask wiring function
  • get_dynamic_inputs() on AttentionDescriptor ensures the mask tensor flows through the cached attention op (between caches and constants)
  • Triton kernel requires broadcast mask [B, 1, S_q, S_k]; torch backend supports [B, N, S_q, S_k] via slicing
  • custom_attn_mask should not be expected to carry sliding-window logic. The backend must apply sliding window on top of whatever user mask is provided.
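To make the broadcast mask format above concrete, here is an illustrative sketch (not the PR's actual mask builder) of computing a [B, 1, S_q, S_k] boolean mask from token_type_ids. It assumes token_type_ids == 1 marks vision tokens, that vision tokens attend bidirectionally among themselves while text tokens stay causal, and that True means "may attend":

```python
import torch

def build_vlm_attn_mask(token_type_ids: torch.Tensor, s_kv: int) -> torch.Tensor:
    """Illustrative VLM mask builder (hypothetical, Gemma-style semantics).

    token_type_ids: [B, S_q] int tensor (1 = vision token).
    Returns a bool mask [B, 1, S_q, S_kv] where True means "may attend".
    Queries are assumed to occupy the last S_q positions of the KV range.
    """
    b, s_q = token_type_ids.shape
    # Standard causal mask over the full KV range (earlier prefix included).
    q_pos = torch.arange(s_kv - s_q, s_kv).view(1, s_q, 1)
    k_pos = torch.arange(s_kv).view(1, 1, s_kv)
    causal = k_pos <= q_pos  # [1, S_q, S_kv]

    is_vision = token_type_ids == 1  # [B, S_q]
    # Vision tokens additionally attend bidirectionally to other vision tokens.
    vision_pair = is_vision.unsqueeze(2) & is_vision.unsqueeze(1)  # [B, S_q, S_q]
    mask = causal.expand(b, s_q, s_kv).clone()
    mask[..., -s_q:] |= vision_pair
    return mask.unsqueeze(1)  # [B, 1, S_q, S_kv], broadcast over heads
```

The leading head dimension of 1 matches the Triton kernel's broadcast requirement noted above; the torch backend would also accept a per-head [B, N, S_q, S_k] variant.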

Block Diagram

┌─────────────────────────────────────────────────────────────────────────┐
│                        TRANSFORM TIME (graph rewrite)                   │
│                                                                         │
│  InferenceOptimizer                                                     │
│    │                                                                    │
│    ▼                                                                    │
│  InjectCustomAttentionMask transform                                    │
│    │                                                                    │
│    ├─ infer_model_type(factory) ──► "gemma4"                           │
│    │                                                                    │
│    ├─ AttentionMaskProviderRegistry.get("gemma4", backend)             │
│    │     │                                                              │
│    │     ▼                                                              │
│    │   _gemma4_*_mask_provider(ctx, attn_node)                         │
│    │     │                                                              │
│    │     ▼                                                              │
│    │   ctx.add_or_retrieve_input("custom_attn_mask", val=None)         │
│    │     │                  ▲                                           │
│    │     │                  │ (cached — built once, shared across       │
│    │     │                  │  all attn nodes in the graph)             │
│    │     ▼                                                              │
│    │   placeholder node: custom_attn_mask                              │
│    │                                                                    │
│    ├─ For each torch_attention node:                                    │
│    │     set attn_mask arg ──► placeholder node                        │
│    │                                                                    │
│    ▼                                                                    │
│  _InsertCachedOperator (kvcache.py)                                    │
│    │                                                                    │
│    ├─ get_dynamic_inputs(attn_node) ──► [attn_mask node]               │
│    │                                                                    │
│    ▼                                                                    │
│  cached_attention_op(q, k, v, *meta, *caches, *dynamic, *constants)    │
│                                            ▲                            │
│                                            │                            │
│                              custom_attn_mask inserted here             │
└─────────────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────────────┐
│                        RUNTIME (forward pass)                           │
│                                                                         │
│  VLM Wrapper (outside the graph)                                        │
│    │                                                                    │
│    ├─ Prefill with vision tokens?                                       │
│    │     YES ──► compute mask [B, 1, S_q, S_k] from token_type_ids     │
│    │     NO  ──► pass None (decode / text-only)                        │
│    │                                                                    │
│    ▼                                                                    │
│  graph(input_ids, position_ids, custom_attn_mask=mask_or_none)         │
│    │                                                                    │
│    ▼                                                                    │
│  ┌───────────────────────────────────────────────────────┐              │
│  │  Backend dispatch (per attention layer)                │              │
│  │                                                       │              │
│  │  if custom_attn_mask is None:                         │              │
│  │    ├─ torch_backend ──► standard causal kernel         │              │
│  │    └─ triton_paged  ──► triton_paged_context()        │              │
│  │                                                       │              │
│  │  if custom_attn_mask is not None:                     │              │
│  │    ├─ torch_backend ──► _torch_context_mha()          │              │
│  │    │    mask[idx, :, :S_q, :S_kv] ──► masked_fill_   │              │
│  │    │                                                  │              │
│  │    └─ triton_paged  ──► _paged_context_masked_kernel  │              │
│  │         mask[B, 0, q_pos, kv_pos] ──► per-page load  │              │
│  │         (broadcast over heads, head dim must be 1)    │              │
│  └───────────────────────────────────────────────────────┘              │
│                                                                         │
│  Warmup / CUDA-graph capture passes:                                    │
│    SequenceInfo.nest_sequences()                                        │
│      │                                                                  │
│      ├─ _default_extra_arg_factories["custom_attn_mask"](info)         │
│      │    ──► returns None (no mask during warmup)                      │
│      │                                                                  │
│      └─ per-request extra_args override if present                      │
└─────────────────────────────────────────────────────────────────────────┘
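The backend dispatch in the runtime diagram can be sketched in plain PyTorch. This is a hedged stand-in for illustration, not the PR's _torch_context_mha() implementation; shapes ([B, N, S, D] for q/k/v, [B, 1, S, S] for the mask) and the True-means-attend convention are assumptions:

```python
import math
import torch

def sdpa_with_optional_mask(q, k, v, custom_attn_mask=None):
    """Sketch of the None-vs-mask dispatch: with no mask, fall through to the
    standard causal kernel; with a mask, mask scores before softmax.
    custom_attn_mask is a bool tensor [B, 1, S, S] (True = may attend) and
    broadcasts over the head dimension."""
    if custom_attn_mask is None:
        # Fast path: standard causal attention kernel.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    # Masked path: blank out disallowed positions, then softmax.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    scores = scores.masked_fill(~custom_attn_mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v
```

Passing an explicitly causal mask through the masked path should reproduce the fast path's output, which is a useful sanity check for any backend implementing both branches.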

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

Summary by CodeRabbit

Release Notes

  • New Features

    • Added custom attention mask support across multiple attention backends (torch, triton_paged, torch_attention). Introduced an attention mask provider registry enabling model-specific mask handling. Enhanced graph transformation pipeline with mask injection capabilities. Extended attention descriptors to support dynamic inputs and shared KV caching.
  • Tests

    • Added comprehensive test coverage for custom attention mask functionality and provider registry integration.

@bmarimuthu-nv
Collaborator Author

@coderabbitai summary

@coderabbitai
Contributor

coderabbitai bot commented Apr 3, 2026

✅ Actions performed

Summary regeneration triggered.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Apr 3, 2026

📝 Walkthrough

Walkthrough

Adds backend-native custom attention-mask support: a provider registry and context, a graph transform that injects an optional custom_attn_mask into attention ops, attention-backend plumbing (torch/triton) to accept and apply custom masks along with read-only and shared-KV cache paths, and tests validating behavior for Gemma4-style token-type masks.

Changes

  • Pipeline config — tensorrt_llm/_torch/auto_deploy/config/default.yaml
    Registered the inject_custom_attention_mask transform to run at pattern_matcher with the torch_attention backend.
  • Transform registry & init — tensorrt_llm/_torch/auto_deploy/transform/__init__.py, .../transform/attention_mask_provider.py, .../transform/attention_mask_providers.py
    Added attention-mask provider infrastructure: context, registry, model-type inference, and Gemma4-specific provider registrations that expose a custom_attn_mask graph input.
  • Inject transform — tensorrt_llm/_torch/auto_deploy/transform/library/inject_custom_attention_mask.py
    New InjectCustomAttentionMask transform (registered as inject_custom_attention_mask) that looks up providers and injects/overwrites backend-native attn_mask arguments on torch_attention nodes (configurable override/backend/model_type).
  • Attention descriptor / kvcache transform — tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py, .../transform/library/kvcache.py
    Extended the attention descriptor contract to accept dynamic inputs and shared-KV helpers; the kvcache transform threads get_dynamic_inputs() into cached-attention op insertion.
  • Torch backend attention — tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py
    Added optional custom_attn_mask and a read_cache_only flag; introduced read-only context MHA routines, extracted a generate-phase cache-write helper, added support for shared KV and a dynamic attn_mask input, and updated the constant list to include read_cache_only.
  • Triton paged attention — tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py
    Added a masked paged-context kernel and a triton_paged_context_with_custom_mask wrapper; custom-op signatures accept an optional custom_attn_mask, scale is defaulted, prefill dispatches to masked/unmasked kernels, and dynamic input handling plus relaxed constants filtering were added.
  • Other attention backends — .../custom_ops/attention/flashinfer_attention.py, triton_attention.py, trtllm_attention.py
    Standardized constant extraction by using extract_op_args(..., "layout")[0] for layout checks.
  • Module transform tests and references — tests/.../torch_attention_reference.py, tests/.../test_torch_attention_op.py, tests/.../test_triton_paged_attention.py, tests/.../test_inject_custom_attention_mask.py
    Updated op-call sites to include the extra dynamic arg slot; added tests validating custom boolean masks for torch and triton paged kernels, plus comprehensive transform/unit tests for inject_custom_attention_mask.

Sequence Diagram

sequenceDiagram
    participant FX as FX GraphModule
    participant Transform as InjectCustomAttentionMask
    participant Registry as AttentionMaskProviderRegistry
    participant Provider as Mask Provider
    participant Context as AttentionMaskProviderContext
    participant FXNode as Modified FX Graph Node
    FX->>Transform: scan graph for torch_attention nodes
    Transform->>Registry: lookup(model_type, backend)
    Registry-->>Transform: provider(fn) or None
    Transform->>Provider: invoke(provider, context, attn_node)
    Provider->>Context: add_or_retrieve_input("custom_attn_mask")
    Context-->>Provider: placeholder node (or create)
    Provider->>Context: get_or_create_cached_node(mask_key)
    Context-->>Provider: memoized mask node
    Provider->>Transform: return mask node
    Transform->>FXNode: inject mask into attn_node args/kwargs
    FXNode-->>FX: updated attention node consumes custom_attn_mask at runtime

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

  • Docstring Coverage — ⚠️ Warning. Docstring coverage is 52.33%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Description check — ⚠️ Warning. The pull request has no description provided by the author, violating the template requirements for explaining changes, test coverage, and checklist items. Resolution: add a comprehensive PR description following the provided template, including what the changes do, why they are needed, test coverage details, and confirmation of the PR checklist items.
✅ Passed checks (1 passed)
  • Title check — ✅ Passed. The PR title clearly and concisely summarizes the main feature: adding custom attention mask support to the AutoDeploy attention backends. It directly reflects the primary objective described in the PR.


Comment @coderabbitai help to get the list of available commands and usage tips.

bmarimuthu-nv added a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Apr 3, 2026
Squash merge of nv-auto-deploy:bala/custom-attn-mask (PR NVIDIA#12742).

Adds custom attention mask injection infrastructure for AutoDeploy
attention backends and Gemma4 model-specific attention masking based
on token type grouping and media boundaries.

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
@bmarimuthu-nv bmarimuthu-nv force-pushed the bala/custom-attn-mask branch from 2a095f4 to e383ad1 Compare April 6, 2026 17:09
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
@bmarimuthu-nv bmarimuthu-nv changed the title [None][feat] AutoDeploy: Custom attn mask support for attention backends [None][feat] AutoDeploy: Custom attn mask support for attention backend Apr 6, 2026
@bmarimuthu-nv bmarimuthu-nv marked this pull request as ready for review April 6, 2026 18:40
@bmarimuthu-nv bmarimuthu-nv requested a review from a team as a code owner April 6, 2026 18:41
@bmarimuthu-nv
Collaborator Author

/bot run --disable-fail-fast --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@bmarimuthu-nv bmarimuthu-nv force-pushed the bala/custom-attn-mask branch from e383ad1 to 937ccf6 Compare April 6, 2026 18:42
@bmarimuthu-nv
Collaborator Author

/bot run --disable-fail-fast --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Collaborator

PR_Github #41973 [ run ] triggered by Bot. Commit: 937ccf6 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #41975 [ run ] triggered by Bot. Commit: 937ccf6 Link to invocation

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
tensorrt_llm/_torch/auto_deploy/transform/attention_mask_provider.py (1)

50-53: Consider catching more specific exceptions.

The bare except Exception catch is flagged by static analysis. While this is "best-effort" inference and suppressing failures is intentional, catching a narrower set of exceptions (e.g., AttributeError, TypeError, ValueError) would avoid masking unexpected errors like KeyboardInterrupt or system-level issues.

♻️ Suggested narrower exception handling
     if callable(get_model_config):
         try:
             model_config, _unused_kwargs = get_model_config()
-        except Exception:
+        except (AttributeError, TypeError, ValueError, KeyError):
             return None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/transform/attention_mask_provider.py` around
lines 50 - 53, Replace the bare "except Exception" around the get_model_config()
call with a narrower exception clause that only catches expected failures (e.g.,
AttributeError, TypeError, ValueError) so unexpected system-level exceptions
aren't swallowed; specifically update the try/except that assigns model_config,
_unused_kwargs = get_model_config() in attention_mask_provider.py to catch those
specific exceptions and return None in that narrow except block.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py`:
- Around line 362-437: The context-path helper _torch_context_mha_readonly
currently has the wrong parameter order (logit_cap is in the slot where
torch_backend_mha_with_cache now forwards custom_attn_mask), causing TypeError
and never applying the custom mask; update _torch_context_mha_readonly's
signature to accept custom_attn_mask (e.g., add custom_attn_mask:
Optional[torch.Tensor] before logit_cap or match the caller's order), propagate
that argument to where masks are built, and apply custom_attn_mask (combined
with causal/sliding-window masks) to attn_scores before softmax; make the
identical signature and mask-application fix for the other context helper
referenced around the 525-539 block so both helpers accept and use
custom_attn_mask consistently.

In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py`:
- Around line 849-877: The code incorrectly hard-codes mask_head_offset = 0 so
all heads use head-0's mask; update the logic in triton_paged_attention (around
mask_head_offset, mask_offsets computation and loading custom_mask via
custom_mask_ptr) to incorporate the current head index (e.g., head_id or a
passed-in head offset) when computing mask_head_offset, or alternatively add a
guard in the wrapper that validates the incoming mask is broadcastable across
heads and fails fast if not; ensure mask_offsets uses the computed
mask_head_offset so per-head masks [B,H,Q,K] are indexed correctly when loading
custom_mask.

In
`@tests/unittest/auto_deploy/singlegpu/transformations/library/test_inject_custom_attention_mask.py`:
- Around line 1-13: Replace the legacy header block that begins with "#
Copyright (c) 2025, NVIDIA CORPORATION." with the repo-standard SPDX NVIDIA
header and update the year to the latest modification year used elsewhere in the
PR; ensure the new header matches the SPDX format used across the project
(including the SPDX-License-Identifier line and the correct copyright year) so
the top-of-file header in test_inject_custom_attention_mask.py conforms to
repository conventions.

---

Nitpick comments:
In `@tensorrt_llm/_torch/auto_deploy/transform/attention_mask_provider.py`:
- Around line 50-53: Replace the bare "except Exception" around the
get_model_config() call with a narrower exception clause that only catches
expected failures (e.g., AttributeError, TypeError, ValueError) so unexpected
system-level exceptions aren't swallowed; specifically update the try/except
that assigns model_config, _unused_kwargs = get_model_config() in
attention_mask_provider.py to catch those specific exceptions and return None in
that narrow except block.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 104e9f2b-e54d-4d9a-b81e-992c80a6c74b

📥 Commits

Reviewing files that changed from the base of the PR and between 662e45f and e383ad1.

📒 Files selected for processing (16)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
  • tensorrt_llm/_torch/auto_deploy/transform/__init__.py
  • tensorrt_llm/_torch/auto_deploy/transform/attention_mask_provider.py
  • tensorrt_llm/_torch/auto_deploy/transform/attention_mask_providers.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/inject_custom_attention_mask.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
  • tests/unittest/auto_deploy/_utils_test/torch_attention_reference.py
  • tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_torch_attention_op.py
  • tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py
  • tests/unittest/auto_deploy/singlegpu/transformations/library/test_inject_custom_attention_mask.py

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
@bmarimuthu-nv
Collaborator Author

/bot run --disable-fail-fast --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Collaborator

PR_Github #41984 [ run ] triggered by Bot. Commit: f903596 Link to invocation

@bmarimuthu-nv bmarimuthu-nv changed the title [None][feat] AutoDeploy: Custom attn mask support for attention backend [None][feat] AutoDeploy: Custom attn mask support for attention backends Apr 6, 2026
@suyoggupta suyoggupta requested review from lucaslie and removed request for marinayanov April 6, 2026 20:49
triton.Config({}, num_warps=8, num_stages=3),
],
key=["HEAD_DIM", "PAGE_SIZE", "HEAD_RATIO_PADDED"],
key=["HEAD_DIM", "PAGE_SIZE", "HEAD_RATIO_PADDED", "SLIDING_WINDOW"],
Collaborator Author


Sliding window changes pulled from 141e0d4#diff-e5e901759fe5eb0bfbd8a6fbbdbe5ed498a49b93bff19830e9c95c10745d6b14 and updated to support user custom_attn_mask + sliding_window on top

Remove unused ("gemma4", "torch") provider registration, fix test to
use production backend string "torch_attention", and replace hardcoded
arg position lookups with extract_op_args / _get_op_schema utilities.

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Change get_dynamic_inputs() to return Dict[str, Optional[Node]] and
pass them as kwargs to the cached attention op via call_function,
eliminating fragile positional ordering between caches and constants.

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
@bmarimuthu-nv
Collaborator Author

/bot run --disable-fail-fast --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Collaborator

PR_Github #41984 [ run ] completed with state SUCCESS. Commit: f903596
/LLM/main/L0_MergeRequest_PR pipeline #32836 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@bmarimuthu-nv
Collaborator Author

/bot run --disable-fail-fast --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Collaborator

PR_Github #42053 [ run ] triggered by Bot. Commit: f0fe102 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #42053 [ run ] completed with state SUCCESS. Commit: f0fe102
/LLM/main/L0_MergeRequest_PR pipeline #32894 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Add missing dynamic_kwargs parameter to match base class signature.

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
@bmarimuthu-nv
Collaborator Author

/bot run --disable-fail-fast --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Collaborator

PR_Github #42151 [ run ] triggered by Bot. Commit: a92711c Link to invocation

@bmarimuthu-nv bmarimuthu-nv changed the title [None][feat] AutoDeploy: Custom attn mask support for attention backends [#12808][feat] AutoDeploy: Custom attn mask support for attention backends Apr 7, 2026
Member

@lucaslie lucaslie left a comment


Design Review: Custom Attention Mask Infrastructure

Thanks for the work here — several components are valuable and should be kept regardless of the overall architecture. Specifically:

  • Sliding window in triton_paged (both context and decode kernels) is a solid improvement
  • extract_op_args refactoring replacing ad-hoc positional arg extraction is clean
  • _paged_context_masked_kernel Triton kernel for masked prefill will be needed regardless of design
  • _torch_context_mha_readonly for read-only context attention is useful

However, the overall design has several structural issues that will make it fragile at scale. Below is a detailed analysis and an alternative proposal.


Critical Issues

1. Type Variance for Graph Inputs

The PR passes custom_attn_mask as a graph input that alternates between None (decode/warmup/text-only) and a tensor (VLM prefill). This is problematic because:

  • torch.export requires tensor-typed graph inputs; None is not a valid dynamic shape source.
  • CUDA graph capture requires fixed tensor addresses and shapes. The monolithic decode path uses CapturedGraph which hashes static inputs; switching between None and a tensor changes the hash, causing fallback to eager execution.

2. _default_extra_arg_factories Is Over-Engineered

The new factory system on SequenceInfo is unnecessary because:

  • The existing capture_forward_kwargs mechanism in export_to_gm.py already captures whatever kwargs the submodule receives during export, including additional inputs.
  • Monolithic decode warmup uses set_generate_only_batch() which doesn't pass multimodal extras.
  • Piecewise prefill warmup in _setup_piecewise_mixed_batch also uses synthetic sequences without extras.

3. Provider Registry Creates Tight Coupling

The AttentionMaskProviderRegistry keyed by (model_type, backend) means:

  • Every new model needing masks must register providers for every backend.
  • Model-specific mask semantics are coupled to backend code.
  • Adding a new backend requires updating all model providers.

4. mutates_args Regression

torch_cached_attention_with_cache changed from mutates_args=("k_cache", "v_cache") to mutates_args=(). The k/v caches are mutated in-place during context attention (_torch_context_mha writes to them). Removing this annotation means the graph compiler may not track these mutations correctly.

5. Mask Computation Outside Graph Is Fragile

Computing the mask in the VLM wrapper (outside the graph) means:

  • The wrapper must know backend-specific mask formats (e.g., triton requires [B, 1, S_q, S_k] uint8).
  • The mask must be constructed before the graph call, requiring access to runtime metadata.
  • The graph cannot be tested in isolation for mask correctness.

Alternative Design Proposal

Principle: Two Modes for attn_mask, Analogous to Attention's Own Cached/Uncached Swap Pattern

The canonical IR op torch_attention already has an attn_mask parameter. Instead of building a separate injection transform, make better use of this parameter with two modes. The existing is_causal boolean and attn_mask: Optional[torch.Tensor] parameters are preserved as-is.

Mode A: String Mask Types (Design Now, Implement Later)

In addition to Optional[torch.Tensor], allow attn_mask to accept well-known string literals (e.g. "causal", "bidirectional"). Each attention backend maps recognized strings to its optimized implementation or raises an error if unsupported. The existing is_causal boolean is kept; when attn_mask is a string, it takes precedence.

This mode is not required for the immediate VLM iteration — it should be part of the design but can be implemented as a follow-up.
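A minimal sketch of how Mode A's string handling could look at a backend boundary. The string names "causal" and "bidirectional" come from the proposal; the resolve_mask helper itself is hypothetical, and a real backend would map the strings to fused kernels rather than materializing dense masks:

```python
from typing import Optional, Union

import torch

def resolve_mask(attn_mask: Union[str, torch.Tensor, None], s: int) -> Optional[torch.Tensor]:
    """Map a well-known string mask type to a dense [S, S] bool mask, pass
    tensors and None through unchanged, and fail fast on unknown strings."""
    if isinstance(attn_mask, str):
        if attn_mask == "causal":
            return torch.tril(torch.ones(s, s, dtype=torch.bool))
        if attn_mask == "bidirectional":
            return torch.ones(s, s, dtype=torch.bool)
        raise ValueError(f"unsupported mask type: {attn_mask!r}")
    return attn_mask  # tensor mask or None (backend default behavior)
```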

Mode B: Mask Custom Op (for VLM/Gemma4)

For complex, model-specific masks that cannot be described by a string:

Key layout constraint: During export, the model runs in [B, S] layout (e.g., set_example_sequence creates [2, 4] inputs). Only after the kvcache transform does the graph switch to ragged [1, total_tokens] layout. The mask custom op's reference implementation must therefore operate in [B, S] layout.

  1. Define a mask custom op in model code (not backend code):
@torch.library.custom_op("auto_deploy::gemma4_build_vlm_mask", mutates_args=())
def gemma4_build_vlm_mask(
    is_vision_token: torch.Tensor,  # [B, S] bool - same layout as input_ids at export time
    seq_len: int,
) -> torch.Tensor:
    # Reference: build VLM attention mask from vision token indicator.
    # Operates in [B, S] layout (pre-kvcache-transform).
    # Returns mask in [B, 1, S, S] compatible with torch_attention attn_mask.
    ...
  2. During export: The model's forward calls this custom op to build the mask from is_vision_token and passes the result to torch_attention(attn_mask=...). The existing capture_forward_kwargs mechanism naturally captures is_vision_token as a graph input with dynamic shapes. No new infrastructure needed.

  2. During kvcache transform: Just like torch_attention is swapped for cached_attention, the mask custom op is identified as a single node in the graph and swapped for a backend-specific version. The backend version:

    • Receives the original args (e.g., is_vision_token — now in ragged layout)
    • Receives metadata args (e.g., batch_info_host, sequence boundaries)
    • Produces a mask in the format the backend's cached attention kernel expects
  3. Backend-specific mask functions registered from model code:

# In model code (e.g., _torch/auto_deploy/models/gemma4.py), not backend code
MaskOpRegistry.register(
    source_op="auto_deploy::gemma4_build_vlm_mask",
    backend="triton_paged",
)(gemma4_triton_paged_mask_fn)

The registration is model-driven: each model declares which backends it supports masks for. A backend with no registered mask function for a given mask op raises a clear error. This avoids the (model_type, backend) cartesian product problem.
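A minimal sketch of what such a registry could look like (the MaskOpRegistry name and decorator form follow the registration snippet above, but the implementation details here are hypothetical):

```python
from typing import Callable, Dict, Tuple

class MaskOpRegistry:
    """Sketch of a registry keyed by (source_mask_op, backend).

    Models register backend-specific mask functions from their own code;
    lookup failures raise loudly instead of silently falling back.
    """

    _registry: Dict[Tuple[str, str], Callable] = {}

    @classmethod
    def register(cls, source_op: str, backend: str) -> Callable:
        def decorator(fn: Callable) -> Callable:
            cls._registry[(source_op, backend)] = fn
            return fn
        return decorator

    @classmethod
    def get(cls, source_op: str, backend: str) -> Callable:
        key = (source_op, backend)
        if key not in cls._registry:
            # A backend with no registered mask function for this op
            # surfaces a clear error at transform time.
            raise NotImplementedError(
                f"no mask function registered for {source_op!r} on {backend!r}"
            )
        return cls._registry[key]
```

The kvcache transform would call MaskOpRegistry.get(mask_op_target, backend_name) when it encounters a mask-op node and swap the node for the returned function.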

VLM Prefill/Decode Handling

Following the existing prefill/decode specialization pattern:

  • Prefill with images: Backend mask function reads is_vision_token (ragged layout), uses metadata (batch_info_host, cu_seqlen) to build the VLM mask.
  • Prefill without images: is_vision_token is all-False; backend mask function produces standard causal mask or returns a sentinel for the fast path.
  • Decode: Backend mask function checks batch_info_host, sees decode-only, returns a no-op. This is CUDA-graph compatible because:
    • Decode runs monolithic CUDA graphs.
    • The mask function still executes (same graph structure) but produces a dummy output.
    • The attention kernel's decode path does not consume the mask output.
    • is_vision_token is padded to the CUDA graph bucket size (all-False), so tensor shapes are stable.

The mask computation itself is not expected to be CUDA-graph compatible (it may involve data-dependent control flow). This is fine because mask computation only matters for prefill, and prefill is not CUDA-graphed. For decode CUDA graphs, the mask function runs but is a no-op.
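Under these assumptions, a backend mask function might be sketched as follows. Everything here is illustrative: the function names, the metadata arguments, and the per-sequence Python loop are simplifications (a real triton_paged implementation would emit a packed broadcast [B, 1, S_q, S_k] tensor rather than a Python list). The mask semantics shown, causal attention plus bidirectional attention among vision tokens, follow the Gemma-style VLM pattern described above:

```python
import torch

def build_vlm_seq_mask(is_vision: torch.Tensor) -> torch.Tensor:
    """[S] bool -> [S, S] bool mask: causal attention everywhere, plus
    bidirectional attention among vision tokens (Gemma-style VLM pattern)."""
    s = is_vision.shape[0]
    causal = torch.tril(torch.ones(s, s, dtype=torch.bool, device=is_vision.device))
    vision_block = is_vision[:, None] & is_vision[None, :]
    return causal | vision_block

def triton_paged_mask_fn(is_vision_token, num_prefill_tokens, cu_seqlen):
    # Decode-only batch: return a dummy tensor with a stable shape so the
    # CUDA-graph structure is unchanged; the decode path never reads it.
    if num_prefill_tokens == 0:
        return torch.zeros(1, 1, 1, 1, dtype=torch.bool)
    # Prefill: is_vision_token arrives in ragged [1, total_tokens] layout;
    # slice out each sequence using the cu_seqlen boundaries.
    masks = []
    for i in range(len(cu_seqlen) - 1):
        seq = is_vision_token[0, cu_seqlen[i] : cu_seqlen[i + 1]]
        masks.append(build_vlm_seq_mask(seq))
    return masks
```

Note the prefill-without-images case falls out for free: an all-False is_vision_token makes vision_block all-False, so the result degenerates to a plain causal mask.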

Key Advantages

  1. No type variance: is_vision_token is always a tensor. No None/tensor switching for graph inputs.
  2. No _default_extra_arg_factories: Uses existing capture_forward_kwargs and extra_args mechanisms.
  3. Separation of concerns: Model code defines mask semantics and registers backend implementations. Backend attention code only needs to accept the mask output.
  4. CUDA graph compatible by design: Decode ignores mask output; mask function is a no-op.
  5. Testable: Mask computation is in the graph — graph-level tests verify correctness end-to-end.
  6. Preserves existing API: is_causal kept as-is; attn_mask extended naturally.

Recommendation

I'd suggest splitting the valuable parts (sliding window support, extract_op_args cleanup, the triton masked kernel) into separate PRs, and redesigning the mask injection infrastructure along the lines above.

for ``name``. It is called at the start of every ``nest_sequences``
so that initialization-time forward passes (e.g. ``resize_kv_cache``)
always receive a valid tensor for ``name`` even when no per-request data
is available.

The (model_type, backend) key creates a tight coupling: every new model needing custom masks must register a provider for every backend, and adding a new backend requires updating all model providers.

An alternative is a model-driven registry keyed by (source_mask_op, backend). The model defines a mask custom op (e.g. auto_deploy::gemma4_build_vlm_mask) and registers backend-specific replacements from its own code. The kvcache transform then identifies the mask op node in the graph and swaps it — analogous to how torch_attention is swapped for cached_attention. This way model-specific mask semantics stay in model code, not in the attention backend.

@staticmethod
def _get_attn_mask_arg(node: Node):
return extract_op_args(node, "attn_mask")[0]


Rather than a separate injection transform that rewires attn_mask on existing torch_attention nodes, consider having the model code produce the mask in-graph via a mask custom op during export. The mask op output feeds into torch_attention(attn_mask=...) naturally.

During the kvcache transform, the mask custom op is swapped for a backend-specific version (just like torch_attention is swapped for cached_attention). This eliminates the need for this transform entirely and keeps mask computation inside the graph where it can be tested and verified.


# EXTRA TENSOR FIELDS ######################################################################
self._extra_args: Dict[str, Optional[torch.Tensor]] = {}
# Default factories for extra args: callables that accept this SequenceInfo instance

This _default_extra_arg_factories machinery is unnecessary. The existing capture_forward_kwargs mechanism in export_to_gm.py already captures whatever kwargs the submodule receives during export — including additional inputs like is_vision_token. For warmup/CUDA-graph paths:

  • Monolithic decode warmup uses set_generate_only_batch() which doesn't pass multimodal extras.
  • Piecewise prefill warmup in _setup_piecewise_mixed_batch uses synthetic sequences without extras.

Both paths already produce valid inputs without needing factories. Adding this callback system to SequenceInfo increases the surface area of a critical class without a clear necessity.

logit_cap: Optional[float] = None,
read_cache_only: bool = False,
# DYNAMIC INPUTS
custom_attn_mask: Optional[torch.Tensor] = None,

Two concerns with the signature changes in this op:

  1. mutates_args regression: The decorator was changed from mutates_args=("k_cache", "v_cache") to mutates_args=(). The k/v caches are still mutated in-place during context attention (_torch_context_mha writes to k_cache and v_cache). Removing this annotation means the graph compiler may not track these mutations correctly.

  2. Dynamic inputs as kwargs: Rather than threading custom_attn_mask through as a dynamic kwarg on the cached attention op, consider having the mask computation live inside the graph as a separate custom op node whose output feeds into the attention op. During the kvcache transform, that mask node gets swapped for a backend-specific version — keeping the cached attention op signature stable.

meta_nodes_std: List[Node],
meta_nodes_extra: List[Node],
cache_nodes: List[Node],
dynamic_kwargs: Dict[str, Optional[Node]],

The dynamic_kwargs / get_dynamic_inputs() mechanism adds a generic escape hatch for forwarding arbitrary tensor kwargs through the cached attention op. This is powerful but also fragile — it changes the calling convention of every cached attention op and requires all backends to accept these kwargs.

With the mask-custom-op approach (where the mask is computed by a separate node in the graph that gets swapped during this same transform), this generic mechanism wouldn't be needed. The mask node swap would happen alongside the attention node swap, and the mask output would be consumed by the attention op as a regular positional arg — no kwargs plumbing required.

@tensorrt-cicd
Collaborator

PR_Github #42151 [ run ] completed with state SUCCESS. Commit: a92711c
/LLM/main/L0_MergeRequest_PR pipeline #32983 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation
