
[None][feat] AutoDeploy: Gemma4 multimodal support with custom attention mask#12744

Draft
bmarimuthu-nv wants to merge 21 commits into NVIDIA:main from nv-auto-deploy:bala/gemma4-mm

Conversation

@bmarimuthu-nv
Collaborator

@bmarimuthu-nv bmarimuthu-nv commented Apr 3, 2026

Summary

  • Custom attention mask infrastructure: Adds graph-level attention mask injection for AutoDeploy attention backends. A registry-based provider system lets models supply custom masks (e.g., bidirectional within media blobs, causal for text). Supports torch, triton_paged, and torch_attention backends.
  • Gemma4 VLM export pattern: Switches Gemma4ForConditionalGeneration from FullModelExportInfo to TextModelExportInfo (VLM pattern). The Gemma4ForCausalLM (including lm_head + softcapping) is the export target, while the outer wrapper handles token_type_ids plumbing.
  • Gemma4 custom attention mask: Implements a Gemma4-specific mask provider that builds bidirectional-within-media-blob + causal masks from token_type_ids. During prefill with images, the triton paged backend uses this mask via a per-sequence fallback path.
  • Gemma4ADInputProcessor: Ensures every request (text-only or multimodal) always has token_type_ids in py_multimodal_data, solving mixed-batch scenarios where some requests have images and others don't.
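The masking scheme described above (bidirectional attention within each contiguous media blob, causal everywhere else) can be sketched as follows. This is a minimal illustration, not the actual Gemma4 provider; it assumes `token_type_ids` marks image tokens with 1 and text tokens with 0, which is a common HF VLM convention:

```python
import torch

def build_mask(token_type_ids: torch.Tensor) -> torch.Tensor:
    """Causal mask that is bidirectional within each contiguous media blob.

    token_type_ids: [seq_len] with 1 for media (image) tokens, 0 for text.
    Returns a [seq_len, seq_len] boolean mask where entry [q, k] is True
    if query position q may attend to key position k.
    """
    seq_len = token_type_ids.shape[0]
    pos = torch.arange(seq_len, device=token_type_ids.device)
    causal = pos.unsqueeze(0) <= pos.unsqueeze(1)  # [q, k]: allow k <= q

    is_media = token_type_ids == 1
    # Label each contiguous media blob with a unique id: mark blob starts,
    # then cumsum so all tokens in the same blob share a label.
    starts = is_media & ~torch.roll(is_media, 1)
    starts[0] = is_media[0]
    blob_id = torch.cumsum(starts.long(), dim=0)
    same_blob = (blob_id.unsqueeze(0) == blob_id.unsqueeze(1)) & (
        is_media.unsqueeze(0) & is_media.unsqueeze(1)
    )
    return causal | same_blob
```

For `token_type_ids = [0, 1, 1, 0]` the two image tokens can attend to each other in both directions, while the surrounding text tokens stay strictly causal.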

Key design decisions

  • No model-specific code in AD infra (ad_executor, attention_interface). All Gemma4 logic lives in the custom model file.
  • token_type_ids is added as a graph input by the InjectCustomAttentionMask transform post-export (not in the model forward signature).
  • The input processor guarantees token_type_ids is always present so _extra_args covers the full flattened batch on every step.
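The "no model-specific code in AD infra" decision hinges on a registry keyed by (model_type, backend). A minimal sketch of what such a lookup could look like; the class and decorator names here are illustrative, not the actual AutoDeploy API:

```python
from typing import Callable, Dict, Optional, Tuple

import torch

# A provider maps per-sequence token_type_ids to an attention mask.
MaskProvider = Callable[[torch.Tensor], torch.Tensor]


class AttentionMaskProviderRegistry:
    """Illustrative (model_type, backend) -> mask-provider lookup."""

    _providers: Dict[Tuple[str, str], MaskProvider] = {}

    @classmethod
    def register(cls, model_type: str, backend: str):
        def wrap(fn: MaskProvider) -> MaskProvider:
            cls._providers[(model_type, backend)] = fn
            return fn
        return wrap

    @classmethod
    def get(cls, model_type: str, backend: str) -> Optional[MaskProvider]:
        # Models without a registered provider fall back to default masking.
        return cls._providers.get((model_type, backend))


@AttentionMaskProviderRegistry.register("gemma4", "triton_paged")
def gemma4_mask(token_type_ids: torch.Tensor) -> torch.Tensor:
    # Placeholder body: a plain causal mask stands in for the real
    # bidirectional-within-media-blob construction.
    seq = token_type_ids.shape[-1]
    return torch.ones(seq, seq, dtype=torch.bool).tril()
```

The transform can then query the registry at graph-rewrite time and inject the provider's output only for models that registered one, keeping the infrastructure model-agnostic.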

Test plan

  • pytest tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py — all 8 tests pass
  • pytest tests/unittest/auto_deploy/singlegpu/transformations/library/test_inject_custom_attention_mask.py — all 3 tests pass
  • pytest tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py
  • E2E: Gemma4 with image prompts via trtllm-serve

🤖 Generated with Claude Code

Summary by CodeRabbit

Release Notes

  • New Features

    • Added support for Gemma 4 and Gemma 3n models with AutoDeploy backend
    • Implemented linear attention and recurrent state caching for optimized memory usage
    • Enabled shared KV cache across attention layers
    • Added custom attention mask support across multiple attention backends
  • Improvements

    • Enhanced KV cache management with placeholder block handling
    • Updated supported models documentation
  • Documentation

    • Added Gemma 4 deployment cookbook with AutoDeploy guidance

bmarimuthu-nv and others added 11 commits April 2, 2026 18:59
…IA#12205)

Adds Gemma3n custom model with shared KV attention, sliding window attention,
and related attention backend changes for AutoDeploy.

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Adds Gemma4 (MoE) custom model for AutoDeploy with:
- Custom modeling code supporting K=V attention, proportional RoPE,
  parallel dense+MoE, per-layer scalars, and logit softcapping
- Gelu activation support in torch_moe for Gemma4 MoE layers
- Hierarchical equivalence tests
- Model registry config (triton_paged attention backend for head_dim=512)

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
…nd tests

- Remove incorrect +1.0 scale_shift from Gemma4RMSNorm. HF transformers
  5.5.0 stores effective norm weights directly in the checkpoint; the
  previous implementation incorrectly added 1.0 at load time, causing
  compounding numerical drift across layers and garbled generation.
- Add google/gemma-4-26B-A4B base model registry entry with
  gemma4_moe_base.yaml config.
- Strengthen test_full_model_equivalence with end-to-end logits
  comparison against standalone reference model.
- Add export functional equivalence assertion (pre-export vs post-export).
- Update reference _RefRMSNorm to match corrected norm semantics.
- Update MoE block test to manually unfuse weights (hook now on decoder
  layer, not MoE block).

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
…ked prefill

Add piecewise CUDA graph compilation, expanded batch sizes, chunked
prefill, and KV cache config to both gemma4_moe.yaml and
gemma4_moe_base.yaml.

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Squash merge of nv-auto-deploy:bala/custom-attn-mask (PR NVIDIA#12742).

Adds custom attention mask injection infrastructure for AutoDeploy
attention backends and Gemma4 model-specific attention masking based
on token type grouping and media boundaries.

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
… mask support

Switch Gemma4ForConditionalGeneration from FullModelExportInfo to
TextModelExportInfo (VLM pattern). The CausalLM wrapper including
lm_head is now the export target, while the outer wrapper handles
token_type_ids plumbing for the custom attention mask injector.

- Move lm_head into Gemma4ForCausalLM via Gemma4Model.language_model
- Add lm_head weight remapping hook for HF checkpoint compatibility
- Forward **kwargs through wrapper to exported graph (token_type_ids)
- Switch factory to AutoModelForImageTextToTextFactory
- Add Gemma4ADInputProcessor to ensure token_type_ids is always present
  in py_multimodal_data for all requests (text-only gets zeros)

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Use FakeTensor from query metadata to create the token_type_ids
placeholder value, preserving symbolic dimensions from torch.export.
The previous approach tried to create a concrete tensor with symbolic
ints which failed at runtime.

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
…ider

- Use FakeTensor from query metadata for token_type_ids placeholder to
  preserve symbolic dimensions from torch.export
- Derive causal positions from token_type_ids via cumsum to inherit
  correct device during shape propagation (meta) and runtime (cuda),
  replacing bare arange which defaulted to CPU

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
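The device-inheritance point in the commit above can be illustrated with a small sketch (names are hypothetical, not the PR's code): deriving positions from an existing tensor via `cumsum` keeps them on that tensor's device, including `meta` during shape propagation, whereas a bare `torch.arange` without a `device` argument lands on the CPU default.

```python
import torch

def positions_from_token_types(token_type_ids: torch.Tensor) -> torch.Tensor:
    # cumsum runs on whatever device its input lives on (cuda, cpu, or
    # meta during fake-tensor shape propagation), so the derived positions
    # inherit the correct device automatically.
    ones = torch.ones_like(token_type_ids, dtype=torch.long)
    return torch.cumsum(ones, dim=-1) - 1

# During shape propagation inputs are meta tensors; the result stays meta,
# while torch.arange(seq_len) with no device argument would be on CPU.
tti_meta = torch.zeros(6, dtype=torch.long, device="meta")
pos_meta = positions_from_token_types(tti_meta)
```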
@coderabbitai
Contributor

coderabbitai bot commented Apr 3, 2026

📝 Walkthrough

Walkthrough

The PR adds linear-attention (recurrent-state) support to the KV cache management system with placeholder block handling, refactors attention backend abstractions to enable shared-KV and custom attention masks across multiple implementations, introduces Gemma3n and Gemma4 model implementations for auto-deploy, and adds comprehensive tests for new features.

Changes

Cohort / File(s) Summary
Eviction Policy Core
cpp/include/tensorrt_llm/batch_manager/evictionPolicy.h, cpp/tensorrt_llm/batch_manager/evictionPolicy.cpp
Added placeholder-level cache management: initialize now takes blocksPerCacheLevel instead of sizes, getFreeBlock accepts wantPlaceholder boolean, and new initializePlaceholders method manages placeholder blocks. Internal storage migrated to BidirectionalVector supporting negative block IDs.
KV Cache Manager Core
cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h, cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Major expansion for linear-attention support: new LinearAttentionMetadata struct, KVCacheBlock now carries windowSize, createPlaceholder requires windowSize, added copyLinearAttentionBlock, placeholder allocation (tryAllocatePlaceholderForLinearAttention), and modified loadOrAllocateBlocks signature with LlmRequest& and shareLastContextBlockAmongBeams. getBlockById return type changed from reference to value. New getters: getTokenCount, getLinearAttentionMetadata, getWindowBlockManager, getRecurrentStatesPool.
KV Cache Index Sentinel
cpp/include/tensorrt_llm/kernels/kvCacheIndex.h
Added null-index sentinel: kInvalidPoolIndex constant, static nullIndex member, isNull() method, and private default constructor for creating null indices.
Transfer & Configuration
cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp, cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp, cpp/tensorrt_llm/executor/kvCacheConfig.cpp
Updated cache transfer to handle layerFirstLayout pools via per-layer slicing instead of single contiguous copy; added linear-attention validation. Modified KV cache sizing to pass explicit attention parameters instead of ModelConfig. Added #include <numeric> for std::reduce.
Python Bindings
cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
Exposed new C++ APIs to Python: getTokenCount, LinearAttentionMetadata class/enum, updated calculate_max_num_blocks signature, added get_recurrent_states_pool, copy_linear_attention_block methods, and linear_attention_metadata constructor parameter.
Attention Interface Abstraction
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
Introduced backend-agnostic hooks: supports_shared_kv(), get_dynamic_inputs(), get_layer_idx(), get_shared_kv_source_layer_idx() classmethods. Added helper _extract_optional_op_arg() for schema-aware argument extraction. Updated docstring for constant ordering.
FlashInfer Attention Backend
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py
Extended PlanParams with window_left, plumbed sliding-window into FlashInfer planning, added _to_flashinfer_window_left translation helper, extended custom op signature with sliding_window and read_cache_only parameters, added supports_shared_kv() returning True, and refactored constant extraction.
Torch Attention Backends
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py, tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py
Extended signatures with layer_idx, shared_kv_source_layer_idx, custom_attn_mask, read_cache_only parameters. Added readonly attention paths (_torch_generate_mha_readonly, _torch_context_mha_readonly), inverted-mask logic for custom masks, dynamic input extraction, and supports_shared_kv() returning True. Removed cache mutability registration for readonly path.
Triton & TRTllm Attention Backends
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py, tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py, tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py
Updated to use extract_op_args(...) helper for consistent layout extraction; Triton added triton_paged_context_with_custom_mask fallback for custom masks and extended signature with custom_attn_mask and optional scale; added dynamic input handling.
Attention Mask Provider Framework
tensorrt_llm/_torch/auto_deploy/transform/attention_mask_provider.py, tensorrt_llm/_torch/auto_deploy/transform/attention_mask_providers.py
New framework for model-specific attention mask providers: AttentionMaskProviderContext with caching, AttentionMaskProviderRegistry for (model_type, backend) lookup, and Gemma4-specific token-type-derived mask construction with blob-based media masking.
Attention Mask Injection Transform
tensorrt_llm/_torch/auto_deploy/transform/library/inject_custom_attention_mask.py
New graph-level transform to inject backend-specific attention masks into torch.ops.auto_deploy.torch_attention nodes via registry lookup, with override_existing_mask control.
Shared-KV Cache Transform
tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py, tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py
Refactored to support shared-KV via layer-indexed cache ownership, validation of shared_kv_source_layer_idx, dynamic input plumbing, and metadata caching in node. Updated _insert_cached_attn_node signature and dynamic-input handling.
Gemma3n & Gemma4 Models
tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py, tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py, tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py
New AutoDeploy-focused implementations: Gemma3n text-only with RoPE, sliding-window attention, optional Gaussian sparsity; Gemma4 multimodal with MoE router, dual-path feedforward, unfused weight loading. Both support conditional generation and register with custom factories. Exposed via __init__.py.
Mistral & Mixtral Sliding-Window Support
tensorrt_llm/_torch/models/modeling_mistral.py, tensorrt_llm/_torch/models/modeling_mixtral.py
Added per-layer sliding-window configuration: read config.layer_types for Mistral, store in self.attention_window_size, and forward via super().forward(...) with attention_window_size parameter.
Auto-Deploy Config & Utils
tensorrt_llm/_torch/auto_deploy/config/default.yaml, tensorrt_llm/_torch/auto_deploy/utils/node_utils.py, tensorrt_llm/_torch/auto_deploy/utils/_graph.py, tensorrt_llm/_torch/auto_deploy/export/export.py
Added inject_custom_attention_mask transform stage; extracted schema introspection into public get_op_schema(op) helper; updated MOE graph expansion to use get_op_schema instead of private attributes.
MOE Activation Support
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py, tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
Extended Gelu/Geglu support: torch_moe maps Gelu to partial(F.gelu, approximate="tanh"); trtllm_moe validates gated-MLP with Gelu, introduced _normalize_trtllm_act_fn for consistent Silu→Swiglu and Gelu→Geglu mapping across quantized paths.
Transform Module Init
tensorrt_llm/_torch/auto_deploy/transform/__init__.py
Added attention_mask_providers import to ensure provider registration during package initialization.
Gemma Model Registry & Configs
examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml, examples/auto_deploy/model_registry/configs/gemma4_moe.yaml, examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml, examples/auto_deploy/model_registry/models.yaml
Added Gemma3n IT and Gemma4 MoE (base and IT) configurations with FlashInfer/Triton backends, chunked prefill, KV cache settings, and piecewise compilation. Registered model entries in models.yaml.
KV Cache Tests
cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp, cpp/tests/unit_tests/batch_manager/radixBlockTreeTest.cpp
Extended with linear-attention test suites (testBlockManagerLinearAttention_ContextNoReuse, testBlockManagerLinearAttention_ContextReuse, testKVCacheManagerLinearAttention_DecodingBlockGrowth, testKVCacheManagerLinearAttention_BlockCopying), updated placeholder creation calls with windowSize argument.
Model Tests
tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma3n_modeling.py, tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py, tests/unittest/_torch/modeling/test_modeling_mistral.py, tests/unittest/_torch/modeling/test_modeling_mixtral.py
Comprehensive equivalence suites against HF references, sliding-window validation, shared-KV metadata checks, export/torch-compile validation, MoE block equivalence, and reference math implementation.
Attention Operation Tests
tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py, tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_torch_attention_op.py, tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py
Added custom attention mask test cases (test_torch_backend_attention_custom_bool_mask_context, test_context_with_custom_bool_mask_matches_torch_attention), updated op calls with new None/parameter placeholders for sliding-window and read-cache-only support.
Shared-KV & Mask Injection Tests
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py, tests/unittest/auto_deploy/singlegpu/transformations/library/test_inject_custom_attention_mask.py
New comprehensive test suites: shared-KV cache aliasing, readonly attention behavior, backend metadata encoding, piecewise compilation; custom mask injection with caching, Gemma4 token-type mask provider registration, reference output matching.
Compilation & Reference Utils
tests/unittest/auto_deploy/singlegpu/compile/test_captured_graph.py, tests/unittest/auto_deploy/_utils_test/torch_attention_reference.py
Updated BatchInfo serialization to use structured constructor, adjusted encoded payloads; updated cached-attention reference calls with None argument for new parameters.
Documentation & Metadata
docs/source/models/supported-models.md, examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb, security_scanning/metadata.json, tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py, tests/integration/test_lists/waives.txt
Added Gemma4 to supported models list, created Gemma4 deployment cookbook, updated security scanning metadata, clarified docstring in DualModeCapturedGraph, added waived test entries.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • thorjohnsen
  • SimengLiu-nv
  • eopXD

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 19

🧹 Nitpick comments (14)
cpp/tests/unit_tests/batch_manager/radixBlockTreeTest.cpp (1)

537-540: Optional cleanup: name the new placeholder test constants.

Hoisting the new block/window literals into constexprs will make the updated test easier to scan, and ph can be const as well.

Possible cleanup
-    auto ph = KVCacheBlock::createPlaceholder(42, 100);
+    constexpr auto kBLOCK_ID = 42;
+    constexpr auto kWINDOW_SIZE = 100;
+    auto const ph = KVCacheBlock::createPlaceholder(kBLOCK_ID, kWINDOW_SIZE);
     ASSERT_NE(ph, nullptr);
     EXPECT_TRUE(ph->isPlaceholder());
-    EXPECT_EQ(ph->getBlockId(), 42);
+    EXPECT_EQ(ph->getBlockId(), kBLOCK_ID);
As per coding guidelines, "A variable that is not modified after its initialization should be declared as `const`." and "Except `0`, `nullptr`, `true`, `false`, all other literals in C++ should only be used for variable initialization; extract other literal usages to named constants."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tests/unit_tests/batch_manager/radixBlockTreeTest.cpp` around lines 537 -
540, Extract the magic literals used in the test into named constexprs (e.g.,
constexpr auto kPlaceholderBlockId = 42; constexpr auto kPlaceholderWindow =
100;) and use them when calling KVCacheBlock::createPlaceholder, and declare the
returned pointer as const (const auto ph = ...). Keep the existing assertions
(ph != nullptr, ph->isPlaceholder(), ph->getBlockId()) but compare getBlockId()
against the new kPlaceholderBlockId constant so the test uses named constants
instead of bare literals and ph is immutable.
cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp (1)

665-668: Name the new K/V factor instead of hardcoding 2.

This callsite now carries one more cache-sizing assumption. Giving the factor a named constant, or pulling it from model metadata if that exists, will make the new API path much easier to audit.

Possible cleanup
+    constexpr SizeType32 kKV_FACTOR = 2;
     auto const sizePerHead = mModelConfig.getSizePerHead();
     auto blocksPerWindow = KVCacheManager::calculateMaxNumBlocks(kvCacheConfig, kvDtype, numKvHeadsPerLayer,
         sizePerHead, tokensPerBlock, mWorldConfig, windowSizeToLayers, freePrimaryMemBytes, freeSecondaryMemBytes,
-        extraCostMemory, 2, getMaxBatchSize());
+        extraCostMemory, kKV_FACTOR, getMaxBatchSize());
As per coding guidelines, "Except `0`, `nullptr`, `true`, `false`, all other literals in C++ should only be used for variable initialization; extract other literal usages to named constants."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp` around lines
665 - 668, The call to KVCacheManager::calculateMaxNumBlocks currently passes a
hardcoded K/V factor "2"; replace this literal with a named constant or a value
retrieved from model metadata (e.g., add a const int kvFactor = ... or read from
mModelConfig) and pass kvFactor instead of 2 at the callsite that uses
mModelConfig.getSizePerHead(), kvCacheConfig, numKvHeadsPerLayer,
tokensPerBlock, mWorldConfig, windowSizeToLayers, freePrimaryMemBytes,
freeSecondaryMemBytes, extraCostMemory, and getMaxBatchSize(); ensure the
constant name clearly indicates its meaning (e.g., kKvFactor or
kvFactorFromModel) and update any relevant callers or tests accordingly.
cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp (2)

6817-6819: Don't store seq1 before simulating prefill completion.

Unlike seq0, seq1 never has its synthetic prefill state advanced to the full prompt before storeContextBlocks() on Line 6819. That means this helper is persisting only the already-reused prefix, so the checks below end up depending on partial-prefill side effects instead of just the reuse result under test. Either drop this call or mirror the same completed-prefill setup used for seq0.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp` around lines 6817
- 6819, The test calls blockManager.storeContextBlocks(seq1, *llmRequest1) while
seq1's synthetic prefill was never advanced to the full prompt (unlike seq0),
causing the helper to persist only a partial prefix; either remove the
storeContextBlocks/holdSequence calls for seq1 or advance seq1's synthetic
prefill to the same completed-prefill state used for seq0 before calling
blockManager.storeContextBlocks(seq1, *llmRequest1) so that the test asserts
only the reuse behavior; locate references to seq1, seq0, storeContextBlocks,
holdSequence, and llmRequest1 to mirror the prefill-completion steps applied to
seq0.

6591-6598: Make the linear-layer layout explicit in these fixtures.

These helpers later assume a concrete mapping: Line 7089 treats layer 0 as linear, and Line 7134 assumes exactly half the layers contribute recurrent state. Right now every LinearAttentionMetadata initializer leaves linearLayerIndices implicit, so the assertions depend on a hidden default instead of the fixture setup. Setting it explicitly here will make the tests much less brittle.

Also applies to: 6706-6712, 6935-6942, 7047-7054

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp` around lines 6591
- 6598, The fixture omits an explicit linear layer layout, relying on a hidden
default and making tests brittle; update the LinearAttentionMetadata
initializers (e.g., the linearAttentionMetadata instances) to set
linearLayerIndices explicitly to match the tests' expectations (for example,
include layer 0 as linear and choose indices so exactly half the layers
contribute recurrent state where those assertions expect it), so that functions
referencing linearLayerIndices use the explicit mapping rather than an implicit
default.
cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h (4)

1831-1843: Consider more descriptive variable name.

The variable nkvh is cryptic. Consider a more descriptive name like numKvHeadsForWindow or windowLayerKvHeads for better readability.

📝 Suggested improvement
     {
-        std::vector<SizeType32> nkvh;
-        nkvh.reserve(windowSizeLayers.size());
+        std::vector<SizeType32> numKvHeadsForWindow;
+        numKvHeadsForWindow.reserve(windowSizeLayers.size());
         for (auto const layer : windowSizeLayers)
         {
-            nkvh.push_back(numKvHeadsPerLayer.at(layer));
+            numKvHeadsForWindow.push_back(numKvHeadsPerLayer.at(layer));
         }
-        auto const sumLocalHeads = std::reduce(nkvh.cbegin(), nkvh.cend());
+        auto const sumLocalHeads = std::reduce(numKvHeadsForWindow.cbegin(), numKvHeadsForWindow.cend());
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h` around lines 1831 -
1843, Rename the cryptic local variable nkvh to a descriptive name (e.g.,
numKvHeadsForWindow or windowLayerKvHeads) in the function in kvCacheManager.h:
update its declaration, reserve call, push_back loop, and the std::reduce call
that computes sumLocalHeads so all uses reflect the new name and intent; ensure
comments remain accurate and build after the rename.

198-198: Redundant type alias that may cause confusion.

SizeType32 is already defined at line 74 as tensorrt_llm::runtime::SizeType32. This re-declaration using SizeType32 = WindowSizeType; creates a circular alias that adds no value and could confuse readers. Consider removing this line.

🔧 Suggested fix
-using SizeType32 = WindowSizeType;
-
 struct TempAttentionWindowInputs
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h` at line 198, Remove
the redundant circular alias "using SizeType32 = WindowSizeType;" — it conflicts
with the existing tensorrt_llm::runtime::SizeType32 declared earlier; delete
this line and ensure all code in this header continues to use the original
SizeType32 (or WindowSizeType where intended) so there is no ambiguous
re-aliasing; if any references relied on the alias, update them to the correct
symbol (tensorrt_llm::runtime::SizeType32 or WindowSizeType) to keep meanings
clear.

1817-1820: Parameter naming uses snake_case instead of camelCase.

The parameter layer_idx should be layerIdx per coding guidelines (camelCase for C++ parameters). However, this is consistent with the existing methods on lines 1817-1819 that also use layer_idx. Consider updating all occurrences for consistency with the guideline, but this is a pre-existing pattern.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h` around lines 1817 -
1820, Rename the parameter layer_idx to layerIdx in the virtual method
declarations getPrimaryPool, getIndexerKCachePool (if applicable),
getPoolLayerIdx, and isPoolLayerFirst in kvCacheManager.h so they follow
camelCase parameter naming; update the parameter name in each method signature
(and any implementing overrides in derived classes) and adjust references to
that parameter within those functions to use layerIdx to maintain consistency
with the coding guideline.

117-196: Well-structured metadata struct with minor documentation suggestion.

The LinearAttentionMetadata struct is well-designed for managing linear attention/recurrent state caching. A few observations:

  1. Line 121: Consider documenting why 0x80000001 was chosen for kRecurrentStates (e.g., using high bit to distinguish from valid window sizes).

  2. Line 190-191: If allRecurrentStatesBytes or numLayers is zero, perBlockBytes becomes zero, causing division by zero. Consider adding a guard:

📝 Suggested documentation and safety improvement
     enum LinearCacheType : WindowSizeType
     {
+        // High bit set to distinguish from valid window sizes (which are positive integers).
+        // The value 0x80000001 is intentionally negative when interpreted as signed int32.
         kRecurrentStates = static_cast<WindowSizeType>(0x80000001),
     };
         if (hasRecurrentStatesCache(encodedWindowSize))
         {
             TLLM_CHECK_WITH_INFO(
                 encodedWindowSize == kRecurrentStates, "each pool must only serve one type of linear cache");
             TLLM_CHECK_WITH_INFO(statesSnapshotInterval % tokensPerBlock == 0,
                 "statesSnapshotInterval must be multiple of tokensPerBlock");
             // take a snapshot every `blockAlignment` blocks.
             auto perBlockBytes = allRecurrentStatesBytes * numLayers;
+            TLLM_CHECK_WITH_INFO(perBlockBytes > 0, "perBlockBytes must be positive for recurrent states");
             auto numDynamicBlocks = (memoryBudget / perBlockBytes);
             return static_cast<SizeType32>(numDynamicBlocks);
         }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h` around lines 117 -
196, LinearAttentionMetadata uses a special encoded value
LinearCacheType::kRecurrentStates (0x80000001) and its
calcMaxMemoryBlocks(WindowSizeType, SizeType32, size_t, SizeType32) computes
perBlockBytes = allRecurrentStatesBytes * numLayers which can be zero and cause
division by zero; add a brief comment near LinearCacheType::kRecurrentStates
explaining the high-bit encoding rationale, and in calcMaxMemoryBlocks validate
that allRecurrentStatesBytes > 0 and numLayers > 0 (or that perBlockBytes > 0)
before dividing—return 0 or throw a descriptive error via TLLM_CHECK/TLLM_THROW
if the budget cannot be computed—to avoid division-by-zero and surface clear
diagnostics referencing the member allRecurrentStatesBytes and the numLayers
parameter.
tensorrt_llm/_torch/auto_deploy/transform/attention_mask_provider.py (1)

48-53: Narrow the _get_model_config() fallback.

Returning None on every Exception turns real factory bugs into a silent “no provider” fallback, which is especially risky in mask-provider dispatch. Catch only the expected lookup failures here, and log or re-raise unexpected ones so broken custom-mask injection doesn’t get masked. As per coding guidelines: Avoid broad exception handling — catch specific exceptions, not bare except: (see CODING_GUIDELINES.md).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/transform/attention_mask_provider.py` around
lines 48 - 53, The code currently swallows all exceptions from calling
factory._get_model_config via a bare except, hiding real bugs; change the
try/except around get_model_config() to only catch the expected
lookup/missing-config errors (e.g., KeyError, LookupError or a specific
FactoryLookupError if one exists) and return None for those, but for any other
Exception log the error (or re-raise) so real failures aren’t silenced; update
the block around get_model_config / _get_model_config to catch specific
exceptions and either return None for known lookup exceptions or log and
re-raise unexpected exceptions.
tensorrt_llm/_torch/models/modeling_mistral.py (1)

89-97: Per-layer sliding window logic looks correct, but consider adding a bounds check.

The logic correctly distinguishes between Ministral (per-layer via layer_types) and standard Mistral (uniform SWA). However, if layer_idx >= len(config.layer_types), an IndexError would be raised.

🛡️ Optional defensive check
         layer_types = getattr(config, "layer_types", None)
-        if layer_types is not None and layer_idx is not None:
+        if layer_types is not None and layer_idx is not None and layer_idx < len(layer_types):
             is_sliding = layer_types[layer_idx] == "sliding_attention"
             self.attention_window_size = config.sliding_window if is_sliding else None
         else:
             self.attention_window_size = getattr(config, "sliding_window", None)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/models/modeling_mistral.py` around lines 89 - 97, The
per-layer sliding window logic may index past config.layer_types; add a
defensive bounds check before using layer_types[layer_idx] in the block that
checks layer_types is not None and layer_idx is not None: ensure layer_idx is an
int within 0 <= layer_idx < len(layer_types) (or fall back to uniform
config.sliding_window) and then set self.attention_window_size accordingly;
update the logic around layer_types, layer_idx, and getattr(config,
"sliding_window", None) so an IndexError cannot occur.
tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py (1)

134-138: Consider using keyword arguments for the new optional parameters.

The call signature now has multiple consecutive None values which reduces readability. Since the AI summary indicates these correspond to sliding_window and read_cache_only parameters, using keyword arguments would make the intent clearer.

# Current: positional Nones are ambiguous
None,
None,
1.0,
1.0,

# Clearer with keywords:
sliding_window=None,
read_cache_only=None,
k_scale=1.0,
v_scale=1.0,

This applies to all 8 call sites in this file.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py`
around lines 134 - 138, Replace the ambiguous positional None arguments with
explicit keyword arguments at each of the 8 call sites in this file: pass
sliding_window=None and read_cache_only=None instead of the two consecutive None
values, and pass k_scale=1.0 and v_scale=1.0 for the trailing floats so the call
becomes e.g. sliding_window=None, read_cache_only=None, k_scale=1.0,
v_scale=1.0; update every invocation that currently uses the positional None,
None, 1.0, 1.0 sequence so the intent is clear.
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py (1)

185-186: Rename unused head_dim to _head_dim.

The head_dim variable is unpacked from q.shape but never used in _torch_generate_mha_readonly. Prefix it with an underscore to indicate it's intentionally unused.

♻️ Proposed fix
-    b, s, n_heads, head_dim = q.shape
+    b, s, n_heads, _head_dim = q.shape
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py`
around lines 185 - 186, The unpacking in _torch_generate_mha_readonly currently
binds an unused variable head_dim; rename it to _head_dim to indicate it is
intentionally unused and silence linters. Locate the tuple assignment "b, s,
n_heads, head_dim = q.shape" inside _torch_generate_mha_readonly and change the
identifier to _head_dim while leaving the rest of the logic unchanged.
tensorrt_llm/_torch/auto_deploy/transform/attention_mask_providers.py (1)

18-18: Optional: Remove unnecessary from __future__ import annotations.

Since TensorRT-LLM requires Python 3.10+, the from __future__ import annotations import is unnecessary. This import was primarily needed for Python 3.7-3.9 to enable PEP 585/604 style annotations.

♻️ Proposed fix
-from __future__ import annotations
-
 import torch

Based on learnings: "In TensorRT-LLM (Python requires >=3.10 and <4 as per setup.py), you can use Python 3.10+ features... and you do not need to add `from __future__ import annotations`."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/transform/attention_mask_providers.py` at
line 18, The file contains an unnecessary compatibility import "from __future__
import annotations" at the top of attention_mask_providers.py; since the project
targets Python >=3.10, remove that import line from the module (i.e., delete the
"from __future__ import annotations" statement) so type annotations rely on
native Python 3.10+ behavior and avoid redundant future imports.
tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py (1)

196-196: Rename unused loop variable idx to _idx.

The loop variable idx is not used within the loop body. Per Python convention, prefix it with an underscore to indicate it's intentionally unused.

♻️ Proposed fix
-        for idx, attn_node in enumerate(source_attn_nodes):
+        for _idx, attn_node in enumerate(source_attn_nodes):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py` at line 196,
The for-loop uses an unused index variable `idx`; update the loop header to
rename `idx` to `_idx` (i.e., change `for idx, attn_node in
enumerate(source_attn_nodes):` to use `_idx`) to follow Python convention for
unused variables and avoid linter warnings—ensure only the loop header is
changed and references to `attn_node` remain unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@cpp/include/tensorrt_llm/batch_manager/evictionPolicy.h`:
- Line 2: Update the copyright header year range in the file by changing the
existing copyright line that currently reads "Copyright (c) 2022-2025, NVIDIA
CORPORATION.  All rights reserved." to include 2026 (i.e., "2022-2026") so the
header reflects modification in this PR; ensure the exact header string is
updated wherever it appears in the file.

In `@cpp/include/tensorrt_llm/kernels/kvCacheIndex.h`:
- Around line 44-47: The constructor KVCacheIndex currently validates the packed
this->value, allowing KVCacheIndex{kInvalidPoolIndex, true} to bypass the null
invariant; fix it by validating the raw parameter before packing: check that the
incoming UnderlyingType value is >= 0 and != kInvalidPoolIndex (use those
symbols explicitly) prior to applying kSecondaryPoolFlag, then assign
this->value = isSecondary ? value | kSecondaryPoolFlag : value so get() and
isNull() remain consistent; update any TLLM_CHECK_DEBUG call in the KVCacheIndex
constructor to validate the raw value instead of the packed field.
- Line 42: The header currently defines KVCacheIndex::nullIndex out-of-class
which causes ODR violations; fix by either making the out-of-class definition
inline (C++17+) or by performing in-class constexpr initialization and removing
the separate definition: update the declaration/definition for
KVCacheIndex::nullIndex accordingly so there is a single definition across TUs
(refer to KVCacheIndex and nullIndex to locate the declarations/definition to
change).
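The validate-before-pack invariant from the constructor comment above is language-agnostic; a minimal Python sketch of the idea (the constants and function name are illustrative, not the actual C++ API — the point is that the raw value is checked before any flag bits are OR-ed in):

```python
SECONDARY_POOL_FLAG = 1 << 31
INVALID_POOL_INDEX = (1 << 31) - 1  # mirrors the INT32_MAX sentinel

def make_kv_cache_index(raw: int, is_secondary: bool = False) -> int:
    # Validate the RAW value before packing the secondary-pool flag, so a
    # sentinel cannot slip past the check once flag bits are mixed in.
    if raw < 0 or raw == INVALID_POOL_INDEX:
        raise ValueError(f"invalid pool index: {raw}")
    return raw | SECONDARY_POOL_FLAG if is_secondary else raw
```

Validating the packed value instead would let `make_kv_cache_index(INVALID_POOL_INDEX, True)` pass, because the packed result no longer equals the sentinel.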

In `@cpp/tensorrt_llm/batch_manager/evictionPolicy.cpp`:
- Around line 108-124: The validators in levelValidators must be side-effect
free: stop calling block->isPrimary() / isPlaceholder() (which trigger
TLLM_CHECK) and instead compare the block's cached level via a non-asserting
accessor (e.g. block->getCacheLevel() or block->getLevel()) inside the three
lambdas, and update the warning in verifyQueueIntegrity() so it prints the
actual level and expected level (use levelToStr[actualLevel] and
levelToStr[queueLevel] rather than levelToStr twice) referencing
levelValidators, levelToStr and the TLLM_LOG_WARNING call.

In `@cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp`:
- Around line 2901-2905: The wrapper KVCacheManager::copyLinearAttentionBlock
should avoid calling getSequence() unguarded because getSequence throws if the
request was removed; before retrieving the sequence call the existing presence
check (e.g., a hasSequence/findSequence/tryGetSequence method) or otherwise test
whether the requestId exists and return early if not, then call
mBlockManager.copyLinearAttentionBlock(sequence, llmRequest); this mirrors
WindowBlockManager::copyLinearAttentionBlock's graceful handling and prevents
the wrapper from throwing when the request is removed between scheduling and
copy.
- Around line 1817-1858: The code reuses a single placeholder returned to
allocateBlock and then mutates it inside the beam loop
(placeholder->setBlockKey, placeholder->setHash, addBlockToBeam), causing
placeholder state to be overwritten across beams when lastBlockIds differ; fix
by allocating a distinct placeholder per beam before mutating it (call
getFreeBlock(..., /*wantPlaceholder=*/true) inside the beam loop or clone the
original placeholder into a new KVCacheBlock for each beam), then perform
setBlockKey, setHash, addBlockToBeam, and ref-count updates on that per-beam
placeholder so each beam has an independent placeholder instance (refer to
functions/getters: allocateBlock logic, getFreeBlock, addBlockToBeam,
KVCacheBlock::setBlockKey, setHash).
- Around line 1984-1995: The loop that finds a successor using
beamBlockIds/nextBlockIndex/nextBlockId/nextBlock may terminate without finding
a real (non-placeholder) successor, yet the code unconditionally asserts
TLLM_CHECK(nextBlockId != -1) and proceeds to use nextBlockId; change this so
after the loop you validate that nextBlock is non-null and not a placeholder
(e.g., check getBlockById(nextBlockId) and nextBlock->isPlaceholder()) and if no
real successor was found set nextBlockId to -1 and skip the onboarding path that
uses onboardedBlocks and prevBlockId (or return/continue) instead of performing
the copy into a placeholder; update the TLLM_CHECK or replace it with a guarded
branch that only runs the onboarding logic when a real successor exists.
- Around line 470-490: The loop in
KVCacheBlock::detachPreviousPlaceholdersFromLookupTree climbs past the immediate
parent using only a check for non-placeholder siblings against the original
this, which can wrongly detach ancestor placeholders when sibling placeholder
branches exist; fix by tracking the child path as you climb: initialize a
BlockPtr child = this, and in each iteration verify that the current node's
siblings contain exactly one entry and that that single sibling's block.get()
equals child (otherwise stop climbing and return); then perform
current->detachFromLookupNode()/setPrevBlockInSeq(nullptr), set child = current,
and continue using getPrevBlock() to move up. Use the existing methods
getPrevBlock, getNextBlocks, isPlaceholder, detachFromLookupNode, and
setPrevBlockInSeq in your change.

In `@cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp`:
- Around line 184-188: The check for layerFirstLayout currently happens inside
the per-pool loop which can allow earlier pools to be written to POSIX before a
later pool fails; pre-scan the pools vector before performing any file I/O and
abort early if any pool has layerFirstLayout to ensure all-or-nothing behavior.
Move or duplicate the TLLM_CHECK_WITH_INFO(!pools[poolIdx].layerFirstLayout,
...) out of the per-pool transfer loop into a single pre-flight loop over pools
(before calls that may perform POSIX writes such as code paths under
POSIX_DEBUG_FALLBACK), so computeBlockPointer and any subsequent I/O only run
after the pre-scan passes. Ensure the same change is applied to the other
location mentioned (around lines 200-210) where the per-pool guard appears.
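The all-or-nothing pattern suggested above — scan every pool for the unsupported layout before any write happens — can be sketched in Python (the dict-based pool representation and function names are hypothetical; the real code operates on C++ pool structs and POSIX file I/O):

```python
def dump_pools(pools, write_fn):
    # Pre-flight scan: reject the whole batch before ANY I/O happens, so a
    # late failure cannot leave earlier pools partially written to disk.
    for i, pool in enumerate(pools):
        if pool.get("layerFirstLayout"):
            raise ValueError(f"pool {i}: layer-first layout unsupported")
    # Only after every pool passed validation do we perform side effects.
    for pool in pools:
        write_fn(pool["data"])
```

With the per-pool check inside the transfer loop instead, a failure on pool N would leave pools 0..N-1 already written.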

In `@cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp`:
- Around line 114-117: The trampoline slot count is too small now that
PyKvCacheManager adds the new override getTokenCount; update the NB_TRAMPOLINE
declaration for tbk::BaseKVCacheManager from 36 to 37 so the nanobind trampoline
has space for all overrides (locate the NB_TRAMPOLINE(tbk::BaseKVCacheManager,
36) macro and change the literal to 37 to match the 37 overrides including
getTokenCount).

In `@cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp`:
- Around line 7133-7134: The cudaMemset for the recurrent-state pool is using an
incorrect size: it multiplies strideBlockId again by numLayers/2 which overflows
the pool. Change the memset length to use the pool stride already represented by
strideBlockId times the pool block count (blocksInPrimaryPool) instead of
multiplying by numLayers/2; update the call that uses cudaMemset(poolBaseAddr,
0, ...) to use strideBlockId * blocksInPrimaryPool (keeping poolBaseAddr,
strideBlockId, numLayers, and blocksInPrimaryPool in mind when locating the
change).

In `@docs/source/models/supported-models.md`:
- Line 16: The footnote [^7] in supported-models.md references a missing
AutoDeploy config (gemma4_moe.yaml) for Gemma4ForConditionalGeneration; fix by
either adding the correct AutoDeploy config file named gemma4_moe.yaml to the
model registry configs (matching the Gemma 4 entry) or update the footnote [^7]
in supported-models.md to point to the correct existing config filename (or
remove the footnote if no config is intended), and ensure the
Gemma4ForConditionalGeneration documentation entry consistently references that
config.

In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py`:
- Around line 362-438: Add an Optional[torch.Tensor] custom_attn_mask parameter
to _torch_context_mha_readonly (e.g., def _torch_context_mha_readonly(...,
logit_cap: Optional[float] = None, custom_attn_mask: Optional[torch.Tensor] =
None, sliding_window_size: Optional[int] = None, sinks: Optional[torch.Tensor] =
None)) so the call-site argument no longer shifts parameters; then apply the
mask to attn_scores before softcapping (mirror the logic in _torch_context_mha):
build the relevant slice of custom_attn_mask for the current batch/sequence/key
lengths and call attn_scores.masked_fill_(mask.unsqueeze(0).unsqueeze(0),
float("-inf")) (do this before _apply_logit_softcapping and sliding-window
masking as in the other function).
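For intuition, a pure-Python sketch of the bidirectional-within-media-blob + causal mask this comment refers to, in the additive form that masked_fill-style logic consumes (the function name and the token-type convention — 1 for image tokens, 0 for text — are assumptions; the real implementation works on torch tensors over batched flattened sequences):

```python
NEG_INF = float("-inf")

def build_custom_mask(token_type_ids):
    """mask[q][k] is 0.0 where query q may attend to key k, -inf otherwise:
    causal everywhere, plus bidirectional within each contiguous image blob."""
    n = len(token_type_ids)
    # Assign a distinct positive blob id to each contiguous run of image tokens.
    blob, cur = [], 0
    for i, t in enumerate(token_type_ids):
        if t == 1 and (i == 0 or token_type_ids[i - 1] == 0):
            cur += 1
        blob.append(cur if t == 1 else 0)
    mask = []
    for q in range(n):
        row = []
        for k in range(n):
            same_blob = blob[q] != 0 and blob[q] == blob[k]
            row.append(0.0 if (k <= q or same_blob) else NEG_INF)
        mask.append(row)
    return mask
```

Adding this mask to raw attention scores before softmax (or using its -inf positions with masked_fill_) yields the intended pattern: text tokens stay strictly causal, while tokens inside one media blob can see each other in both directions.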

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py`:
- Around line 471-487: The validator _validate_mlp_style_and_act_fn and
normalizer _normalize_trtllm_act_fn currently enable gated GELU (mapped to
GEGLU) but two backends (fp8_block_scale_moe_runner C++ kernel and the
TRTLLM-Gen NVFP4 path that raises ValueError) don't support GEGLU; either
restrict the validator/normalizer to disallow gated GELU (remove
ActivationType.Gelu from the gated branch in _validate_mlp_style_and_act_fn and
stop mapping Gelu→Geglu in _normalize_trtllm_act_fn) OR implement GEGLU support
in the backend code (add act_type support in fp8_block_scale_moe_runner and
handle GEGLU in the NVFP4 path instead of raising ValueError); pick one approach
and make matching changes to the named symbols so validation and backend
behavior remain consistent.

In `@tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py`:
- Around line 603-606: Replace the public input validation assertions with
explicit exceptions: in the model's forward method(s) where you currently have
"assert position_ids is not None, 'position_ids must be provided'" (both
occurrences in modeling_gemma3n.py), change them to raise a ValueError with the
same message; also keep the existing input_ids/inputs_embeds mutual-exclusivity
check but ensure it raises ValueError instead of relying on assert so
caller-controlled inputs are validated reliably.

In `@tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py`:
- Around line 983-990: The factory
Gemma4ForConditionalGenerationFactory.init_processor currently returns
ADGemma4Tokenizer (a text-only tokenizer) which breaks multimodal handling
expected by AutoModelForImageTextToTextFactory and causes
Gemma4ADInputProcessor.__call__ to fall back to zeroed token_type_ids; change
init_processor to return a true multimodal processor (e.g., an ADGemma4Processor
that combines image processor + tokenizer, similar to
ADMistralSmall4Processor/PixtralProcessor in
Mistral3ForConditionalGenerationFactory) or explicitly route image preprocessing
so that apply_chat_template() isn’t the sole path for multimodal inputs and the
processor provides per-blob IDs. Ensure the returned object implements the image
handling APIs used by AutoModelForImageTextToTextFactory.
- Around line 934-938: The code currently builds additional from
config.get("extra_special_tokens") using list(extra.keys()) for dicts, which
adds symbolic names instead of actual token strings; change the branch in the
extra handling so that when extra is a dict you collect the token strings via
list(extra.values()) (and coerce/filter to strings) so additional contains the
actual special token strings to pass to additional_special_tokens during
tokenizer initialization.
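A minimal sketch of the corrected collection logic (helper name is hypothetical): for a dict-shaped extra_special_tokens, the values hold the actual token strings while the keys are only symbolic names, so the values are what must reach additional_special_tokens.

```python
def collect_extra_special_tokens(extra):
    # Dicts map symbolic names -> token strings; take the VALUES and
    # defensively keep only actual strings.
    if isinstance(extra, dict):
        return [t for t in extra.values() if isinstance(t, str)]
    if isinstance(extra, (list, tuple)):
        return [t for t in extra if isinstance(t, str)]
    return []
```

Using list(extra.keys()) instead would register names like "image_token" as special tokens rather than the "<image>"-style strings the tokenizer needs.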

In `@tensorrt_llm/_torch/auto_deploy/utils/node_utils.py`:
- Around line 1052-1058: The get_op_schema function currently picks an arbitrary
overload by iterating op._schemas.values(); instead, when op is an
OpOverloadPacket (has attribute _schemas), use the packet's default overload
schema via op.default._schema to reliably get the intended schema, otherwise
fall back to using op._schema and raise the same RuntimeError if neither exists;
update get_op_schema to check for _schemas and return op.default._schema.

In
`@tests/unittest/auto_deploy/singlegpu/transformations/library/test_inject_custom_attention_mask.py`:
- Around line 1-13: Update the file header to the repo-standard NVIDIA copyright
header for 2026 by replacing the existing 2025 header block at the top of
test_inject_custom_attention_mask.py with the canonical header used across the
repo (change the year to 2026 and ensure the exact wording/format matches other
files), so the license block format and year match the repository standard.

---

Nitpick comments:
In `@cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h`:
- Around line 1831-1843: Rename the cryptic local variable nkvh to a descriptive
name (e.g., numKvHeadsForWindow or windowLayerKvHeads) in the function in
kvCacheManager.h: update its declaration, reserve call, push_back loop, and the
std::reduce call that computes sumLocalHeads so all uses reflect the new name
and intent; ensure comments remain accurate and build after the rename.
- Line 198: Remove the redundant circular alias "using SizeType32 =
WindowSizeType;" — it conflicts with the existing
tensorrt_llm::runtime::SizeType32 declared earlier; delete this line and ensure
all code in this header continues to use the original SizeType32 (or
WindowSizeType where intended) so there is no ambiguous re-aliasing; if any
references relied on the alias, update them to the correct symbol
(tensorrt_llm::runtime::SizeType32 or WindowSizeType) to keep meanings clear.
- Around line 1817-1820: Rename the parameter layer_idx to layerIdx in the
virtual method declarations getPrimaryPool, getIndexerKCachePool (if
applicable), getPoolLayerIdx, and isPoolLayerFirst in kvCacheManager.h so they
follow camelCase parameter naming; update the parameter name in each method
signature (and any implementing overrides in derived classes) and adjust
references to that parameter within those functions to use layerIdx to maintain
consistency with the coding guideline.
- Around line 117-196: LinearAttentionMetadata uses a special encoded value
LinearCacheType::kRecurrentStates (0x80000001) and its
calcMaxMemoryBlocks(WindowSizeType, SizeType32, size_t, SizeType32) computes
perBlockBytes = allRecurrentStatesBytes * numLayers which can be zero and cause
division by zero; add a brief comment near LinearCacheType::kRecurrentStates
explaining the high-bit encoding rationale, and in calcMaxMemoryBlocks validate
that allRecurrentStatesBytes > 0 and numLayers > 0 (or that perBlockBytes > 0)
before dividing—return 0 or throw a descriptive error via TLLM_CHECK/TLLM_THROW
if the budget cannot be computed—to avoid division-by-zero and surface clear
diagnostics referencing the member allRecurrentStatesBytes and the numLayers
parameter.

In `@cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp`:
- Around line 665-668: The call to KVCacheManager::calculateMaxNumBlocks
currently passes a hardcoded K/V factor "2"; replace this literal with a named
constant or a value retrieved from model metadata (e.g., add a const int
kvFactor = ... or read from mModelConfig) and pass kvFactor instead of 2 at the
callsite that uses mModelConfig.getSizePerHead(), kvCacheConfig,
numKvHeadsPerLayer, tokensPerBlock, mWorldConfig, windowSizeToLayers,
freePrimaryMemBytes, freeSecondaryMemBytes, extraCostMemory, and
getMaxBatchSize(); ensure the constant name clearly indicates its meaning (e.g.,
kKvFactor or kvFactorFromModel) and update any relevant callers or tests
accordingly.

In `@cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp`:
- Around line 6817-6819: The test calls blockManager.storeContextBlocks(seq1,
*llmRequest1) while seq1's synthetic prefill was never advanced to the full
prompt (unlike seq0), causing the helper to persist only a partial prefix;
either remove the storeContextBlocks/holdSequence calls for seq1 or advance
seq1's synthetic prefill to the same completed-prefill state used for seq0
before calling blockManager.storeContextBlocks(seq1, *llmRequest1) so that the
test asserts only the reuse behavior; locate references to seq1, seq0,
storeContextBlocks, holdSequence, and llmRequest1 to mirror the
prefill-completion steps applied to seq0.
- Around line 6591-6598: The fixture omits an explicit linear layer layout,
relying on a hidden default and making tests brittle; update the
LinearAttentionMetadata initializers (e.g., the linearAttentionMetadata
instances) to set linearLayerIndices explicitly to match the tests' expectations
(for example, include layer 0 as linear and choose indices so exactly half the
layers contribute recurrent state where those assertions expect it), so that
functions referencing linearLayerIndices use the explicit mapping rather than an
implicit default.

In `@cpp/tests/unit_tests/batch_manager/radixBlockTreeTest.cpp`:
- Around line 537-540: Extract the magic literals used in the test into named
constexprs (e.g., constexpr auto kPlaceholderBlockId = 42; constexpr auto
kPlaceholderWindow = 100;) and use them when calling
KVCacheBlock::createPlaceholder, and declare the returned pointer as const
(const auto ph = ...). Keep the existing assertions (ph != nullptr,
ph->isPlaceholder(), ph->getBlockId()) but compare getBlockId() against the new
kPlaceholderBlockId constant so the test uses named constants instead of bare
literals and ph is immutable.

In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py`:
- Around line 185-186: The unpacking in _torch_generate_mha_readonly currently
binds an unused variable head_dim; rename it to _head_dim to indicate it is
intentionally unused and silence linters. Locate the tuple assignment "b, s,
n_heads, head_dim = q.shape" inside _torch_generate_mha_readonly and change the
identifier to _head_dim while leaving the rest of the logic unchanged.

In `@tensorrt_llm/_torch/auto_deploy/transform/attention_mask_provider.py`:
- Around line 48-53: The code currently swallows all exceptions from calling
factory._get_model_config via a bare except, hiding real bugs; change the
try/except around get_model_config() to only catch the expected
lookup/missing-config errors (e.g., KeyError, LookupError or a specific
FactoryLookupError if one exists) and return None for those, but for any other
Exception log the error (or re-raise) so real failures aren’t silenced; update
the block around get_model_config / _get_model_config to catch specific
exceptions and either return None for known lookup exceptions or log and
re-raise unexpected exceptions.

In `@tensorrt_llm/_torch/auto_deploy/transform/attention_mask_providers.py`:
- Line 18: The file contains an unnecessary compatibility import "from
__future__ import annotations" at the top of attention_mask_providers.py; since
the project targets Python >=3.10, remove that import line from the module
(i.e., delete the "from __future__ import annotations" statement) so type
annotations rely on native Python 3.10+ behavior and avoid redundant future
imports.

In `@tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py`:
- Line 196: The for-loop uses an unused index variable `idx`; update the loop
header to rename `idx` to `_idx` (i.e., change `for idx, attn_node in
enumerate(source_attn_nodes):` to use `_idx`) to follow Python convention for
unused variables and avoid linter warnings—ensure only the loop header is
changed and references to `attn_node` remain unchanged.

In `@tensorrt_llm/_torch/models/modeling_mistral.py`:
- Around line 89-97: The per-layer sliding window logic may index past
config.layer_types; add a defensive bounds check before using
layer_types[layer_idx] in the block that checks layer_types is not None and
layer_idx is not None: ensure layer_idx is an int within 0 <= layer_idx <
len(layer_types) (or fall back to uniform config.sliding_window) and then set
self.attention_window_size accordingly; update the logic around layer_types,
layer_idx, and getattr(config, "sliding_window", None) so an IndexError cannot
occur.

In
`@tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py`:
- Around line 134-138: Replace the ambiguous positional None arguments with
explicit keyword arguments at each of the 8 call sites in this file: pass
sliding_window=None and read_cache_only=None instead of the two consecutive None
values, and pass k_scale=1.0 and v_scale=1.0 for the trailing floats so the call
becomes e.g. sliding_window=None, read_cache_only=None, k_scale=1.0,
v_scale=1.0; update every invocation that currently uses the positional None,
None, 1.0, 1.0 sequence so the intent is clear.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: ec1d7992-27f0-46c7-8dce-06cb76d6dc23

📥 Commits

Reviewing files that changed from the base of the PR and between 3ac2704 and eede8ee.

⛔ Files ignored due to path filters (1)
  • security_scanning/poetry.lock is excluded by !**/*.lock
📒 Files selected for processing (55)
  • cpp/include/tensorrt_llm/batch_manager/evictionPolicy.h
  • cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
  • cpp/include/tensorrt_llm/kernels/kvCacheIndex.h
  • cpp/tensorrt_llm/batch_manager/evictionPolicy.cpp
  • cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
  • cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp
  • cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
  • cpp/tensorrt_llm/executor/kvCacheConfig.cpp
  • cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
  • cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
  • cpp/tests/unit_tests/batch_manager/radixBlockTreeTest.cpp
  • docs/source/models/supported-models.md
  • examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb
  • examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml
  • examples/auto_deploy/model_registry/configs/gemma4_moe.yaml
  • examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml
  • examples/auto_deploy/model_registry/models.yaml
  • security_scanning/metadata.json
  • tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
  • tensorrt_llm/_torch/auto_deploy/export/export.py
  • tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py
  • tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py
  • tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py
  • tensorrt_llm/_torch/auto_deploy/transform/__init__.py
  • tensorrt_llm/_torch/auto_deploy/transform/attention_mask_provider.py
  • tensorrt_llm/_torch/auto_deploy/transform/attention_mask_providers.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/inject_custom_attention_mask.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py
  • tensorrt_llm/_torch/auto_deploy/utils/_graph.py
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
  • tensorrt_llm/_torch/models/modeling_mistral.py
  • tensorrt_llm/_torch/models/modeling_mixtral.py
  • tests/integration/test_lists/waives.txt
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma3n_modeling.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py
  • tests/unittest/_torch/modeling/test_modeling_mistral.py
  • tests/unittest/_torch/modeling/test_modeling_mixtral.py
  • tests/unittest/auto_deploy/_utils_test/torch_attention_reference.py
  • tests/unittest/auto_deploy/singlegpu/compile/test_captured_graph.py
  • tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py
  • tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_torch_attention_op.py
  • tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py
  • tests/unittest/auto_deploy/singlegpu/transformations/library/test_inject_custom_attention_mask.py

@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.

⚠️ Potential issue | 🟡 Minor

Update the header year to 2026.

This header is modified in this PR, so the copyright range should include 2026.

As per coding guidelines, "Add NVIDIA copyright header to ALL new files; update year on modified files".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/include/tensorrt_llm/batch_manager/evictionPolicy.h` at line 2, Update
the copyright header year range in the file by changing the existing copyright
line that currently reads "Copyright (c) 2022-2025, NVIDIA CORPORATION.  All
rights reserved." to include 2026 (i.e., "2022-2026") so the header reflects
modification in this PR; ensure the exact header string is updated wherever it
appears in the file.

// The illegal value (INT32_MAX) ensures accidental use triggers an obvious OOB failure.
static constexpr UnderlyingType kInvalidPoolIndex = std::numeric_limits<UnderlyingType>::max();

static const KVCacheIndex nullIndex;

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🌐 Web query:

In C++17/C++20, if a class declares a static data member and a header later defines it out of class as constexpr, does that header definition need `inline` (or a single .cpp definition) to avoid multiple definitions across translation units?

💡 Result:

No. In C++17 and C++20, a static data member declared with constexpr in-class initialization (e.g., static constexpr int x = 42;) is both the definition and implicitly inline, so no out-of-class definition is needed at all, even if ODR-used; this avoids multiple-definition issues across translation units. However, the question specifies "a header later defines it out of class as constexpr" (e.g., constexpr int Cls::x;). This out-of-class redeclaration:

- Is permitted for compatibility (deprecated since C++17).
- Does NOT make the member inline; it remains a non-inline static data member.
- Therefore, if included in multiple translation units (TUs), it creates multiple definitions, violating the ODR ([class.static.data]/4 notes exactly one definition is allowed if ODR-used; no diagnostic required).

To avoid this, either:

- Use inline explicitly: inline constexpr int Cls::x; (C++17+; allows multiple identical definitions).
- Or provide the definition in exactly one .cpp file (not a header).

Best practice: rely on the in-class constexpr initialization (implicitly inline) and omit the out-of-class definition entirely.



🏁 Script executed:

# First, let me find the exact file location
find . -type f -name "kvCacheIndex.h" -o -name "*kvCacheIndex*"

Repository: NVIDIA/TensorRT-LLM

Length of output: 112


🏁 Script executed:

# View the header file to understand the current structure around lines 42 and 74
head -100 cpp/include/tensorrt_llm/kernels/kvCacheIndex.h | cat -n



Add inline to the out-of-class definition, or move it to a single .cpp file.

Line 74 defines KVCacheIndex::nullIndex in the header without inline linkage. This will instantiate multiple definitions across translation units when the header is included, violating the ODR. Either mark the definition as inline (C++17+):

inline constexpr KVCacheIndex KVCacheIndex::nullIndex{};

Or move it to a .cpp file. The best practice alternative is to eliminate the out-of-class definition entirely: change line 42 to static constexpr KVCacheIndex nullIndex{}; with in-class initialization and remove line 74.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/include/tensorrt_llm/kernels/kvCacheIndex.h` at line 42, The header
currently defines KVCacheIndex::nullIndex out-of-class which causes ODR
violations; fix by either making the out-of-class definition inline (C++17+) or
by performing in-class constexpr initialization and removing the separate
definition: update the declaration/definition for KVCacheIndex::nullIndex
accordingly so there is a single definition across TUs (refer to KVCacheIndex
and nullIndex to locate the declarations/definition to change).

Comment on lines 44 to +47
 explicit KVCacheIndex(UnderlyingType value, bool isSecondary = false)
     : value{isSecondary ? value | kSecondaryPoolFlag : value}
 {
-    TLLM_CHECK_DEBUG(value >= 0);
+    TLLM_CHECK_DEBUG(value >= 0 && this->value != kInvalidPoolIndex);

⚠️ Potential issue | 🟡 Minor

Check the raw pool index before packing the secondary flag.

TLLM_CHECK_DEBUG is validating the packed field, so KVCacheIndex{kInvalidPoolIndex, true} slips through: get() then returns the reserved sentinel while isNull() is still false. The new null invariant only holds for primary indices right now.

Possible fix
-        TLLM_CHECK_DEBUG(value >= 0 && this->value != kInvalidPoolIndex);
+        TLLM_CHECK_DEBUG(value >= 0 && value != kInvalidPoolIndex);
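For intuition, here is a small Python sketch of the invariant (the `pack_index` helper and its constants are made up for illustration, not part of the codebase): validating the raw index before the flag is packed is what keeps the sentinel from slipping through the secondary path.

```python
INVALID_POOL_INDEX = 2**31 - 1   # stand-in for kInvalidPoolIndex (INT32_MAX)
SECONDARY_POOL_FLAG = 1 << 31    # stand-in for kSecondaryPoolFlag

def pack_index(value: int, is_secondary: bool = False) -> int:
    # Validate the raw value *before* packing: a check on the packed field
    # would let pack_index(INVALID_POOL_INDEX, True) through, because
    # value | SECONDARY_POOL_FLAG no longer equals INVALID_POOL_INDEX.
    if value < 0 or value == INVALID_POOL_INDEX:
        raise ValueError(f"invalid pool index: {value}")
    return value | SECONDARY_POOL_FLAG if is_secondary else value
```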
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

Before:

    explicit KVCacheIndex(UnderlyingType value, bool isSecondary = false)
        : value{isSecondary ? value | kSecondaryPoolFlag : value}
    {
        TLLM_CHECK_DEBUG(value >= 0);
        TLLM_CHECK_DEBUG(value >= 0 && this->value != kInvalidPoolIndex);

After:

    explicit KVCacheIndex(UnderlyingType value, bool isSecondary = false)
        : value{isSecondary ? value | kSecondaryPoolFlag : value}
    {
        TLLM_CHECK_DEBUG(value >= 0 && value != kInvalidPoolIndex);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/include/tensorrt_llm/kernels/kvCacheIndex.h` around lines 44 - 47, The
constructor KVCacheIndex currently validates the packed this->value, allowing
KVCacheIndex{kInvalidPoolIndex, true} to bypass the null invariant; fix it by
validating the raw parameter before packing: check that the incoming
UnderlyingType value is >= 0 and != kInvalidPoolIndex (use those symbols
explicitly) prior to applying kSecondaryPoolFlag, then assign this->value =
isSecondary ? value | kSecondaryPoolFlag : value so get() and isNull() remain
consistent; update any TLLM_CHECK_DEBUG call in the KVCacheIndex constructor to
validate the raw value instead of the packed field.

Comment on lines +108 to +124
+static char const* const levelToStr[] = {"primary", "secondary", "placeholder"};
+static const std::function<bool(BlockPtr const&)> levelValidators[]
+    = {[](BlockPtr const& block) { return block->isPrimary(); },
+       [](BlockPtr const& block) { return !block->isPrimary(); },
+       [](BlockPtr const& block) { return block->isPlaceholder(); }};
 bool queueCompromised = false;
-for (SizeType32 cacheLevel = 0; cacheLevel < 2; cacheLevel++)
+for (SizeType32 queueLevel = 0; queueLevel < kNumCacheLevels + 1; queueLevel++)
 {
-    for (SizeType32 level = 0; level < kMaxPriority - kMinPriority + 1; level++)
+    for (SizeType32 pri = 0; pri < kNumPriorities; pri++)
     {
-        for (auto const& block : mFreeQueues[cacheLevel][level])
+        for (auto const& block : mFreeQueues[queueLevel][pri])
         {
-            if ((cacheLevel == 0 && !block->isPrimary()) || (cacheLevel == 1 && block->isPrimary()))
+            bool const valid = levelValidators[queueLevel](block);
+            if (!valid)
             {
-                TLLM_LOG_WARNING("Found %s block (id %d) at cacheLevel %d",
-                    block->isPrimary() ? "primary" : "secondary", block->getBlockId(), cacheLevel);
+                TLLM_LOG_WARNING("Block (id %d) has level=%s, but misplaced at queueLevel %s", block->getBlockId(),
+                    levelToStr[queueLevel], levelToStr[queueLevel]);

⚠️ Potential issue | 🟠 Major

Make placeholder validation side-effect free.

The new validators call block->isPrimary(), but placeholder blocks now TLLM_CHECK there. If a placeholder is ever queued at the wrong level, verifyQueueIntegrity() aborts instead of reporting corruption, and the warning currently prints queueLevel for both the actual and expected levels.
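As a hedged sketch of the suggested shape (Python, with an invented `level` field standing in for a non-asserting accessor such as a cached cache-level getter): the validators read cached state, and the warning reports the actual level versus the expected one.

```python
LEVEL_TO_STR = ("primary", "secondary", "placeholder")

# One validator per queue level; `lvl=i` pins the loop variable so each
# lambda compares against its own level (avoids Python's late binding).
validators = [lambda block, lvl=i: block["level"] == lvl for i in range(3)]

def check_block(queue_level, block):
    # Side-effect free: no asserting accessor is invoked, so a misplaced
    # placeholder is reported instead of aborting the integrity check.
    if validators[queue_level](block):
        return None
    return (f"Block (id {block['id']}) has level={LEVEL_TO_STR[block['level']]}, "
            f"but misplaced at queueLevel {LEVEL_TO_STR[queue_level]}")
```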

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/batch_manager/evictionPolicy.cpp` around lines 108 - 124,
The validators in levelValidators must be side-effect free: stop calling
block->isPrimary() / isPlaceholder() (which trigger TLLM_CHECK) and instead
compare the block's cached level via a non-asserting accessor (e.g.
block->getCacheLevel() or block->getLevel()) inside the three lambdas, and
update the warning in verifyQueueIntegrity() so it prints the actual level and
expected level (use levelToStr[actualLevel] and levelToStr[queueLevel] rather
than levelToStr twice) referencing levelValidators, levelToStr and the
TLLM_LOG_WARNING call.

Comment on lines +470 to +490
void KVCacheBlock::detachPreviousPlaceholdersFromLookupTree() const
{
    BlockPtr current = getPrevBlock();
    while (current != nullptr && current->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
    {
        if (!current->isPlaceholder())
        {
            return;
        }
        auto siblings = current->getNextBlocks();
        for (auto const& [key, block] : siblings)
        {
            if (!block->isPlaceholder() && block.get() != this)
            {
                return;
            }
        }
        BlockPtr prev = current->getPrevBlock();
        current->detachFromLookupNode();
        current->setPrevBlockInSeq(nullptr);
        current = prev;

⚠️ Potential issue | 🟠 Major

Stop climbing when any sibling branch exists.

Once this loop moves above the immediate parent, comparing every child against the original this is no longer sufficient. A placeholder ancestor with another placeholder child still passes this check and gets detached, which can orphan that sibling subtree from the reuse tree.

Possible fix
 void KVCacheBlock::detachPreviousPlaceholdersFromLookupTree() const
 {
+    auto const* childOnPath = this;
     BlockPtr current = getPrevBlock();
     while (current != nullptr && current->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
     {
         if (!current->isPlaceholder())
         {
             return;
         }
         auto siblings = current->getNextBlocks();
         for (auto const& [key, block] : siblings)
         {
-            if (!block->isPlaceholder() && block.get() != this)
+            if (block.get() != childOnPath)
             {
                 return;
             }
         }
         BlockPtr prev = current->getPrevBlock();
         current->detachFromLookupNode();
         current->setPrevBlockInSeq(nullptr);
+        childOnPath = current.get();
         current = prev;
     }
 }
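For intuition, the same climb-with-child-tracking logic can be sketched in Python on a toy block structure (the class and field names below are invented analogues of the C++ members, not real APIs):

```python
class Block:
    def __init__(self, block_id, placeholder=False):
        self.block_id = block_id
        self.placeholder = placeholder
        self.prev = None        # ~ getPrevBlock()
        self.children = {}      # ~ getNextBlocks()

ROOT_ID = -1  # ~ kCachedBlocksRootId

def detach_previous_placeholders(block):
    child_on_path = block
    current = block.prev
    while current is not None and current.block_id != ROOT_ID:
        if not current.placeholder:
            return
        # Stop climbing as soon as any sibling branch exists: every child
        # of `current` must be the node we just climbed from.
        if any(b is not child_on_path for b in current.children.values()):
            return
        prev = current.prev
        current.children.clear()  # ~ detachFromLookupNode()
        current.prev = None       # ~ setPrevBlockInSeq(nullptr)
        child_on_path = current
        current = prev
```

With this tracking, a placeholder ancestor that has a second placeholder child is left intact, so the sibling subtree stays attached to the reuse tree.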
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp` around lines 470 - 490,
The loop in KVCacheBlock::detachPreviousPlaceholdersFromLookupTree climbs past
the immediate parent using only a check for non-placeholder siblings against the
original this, which can wrongly detach ancestor placeholders when sibling
placeholder branches exist; fix by tracking the child path as you climb:
initialize a BlockPtr child = this, and in each iteration verify that the
current node's siblings contain exactly one entry and that that single sibling's
block.get() equals child (otherwise stop climbing and return); then perform
current->detachFromLookupNode()/setPrevBlockInSeq(nullptr), set child = current,
and continue using getPrevBlock() to move up. Use the existing methods
getPrevBlock, getNextBlocks, isPlaceholder, detachFromLookupNode, and
setPrevBlockInSeq in your change.

Comment on lines +603 to +606
del kwargs
assert position_ids is not None, "position_ids must be provided"
if (input_ids is None) == (inputs_embeds is None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, check if the file exists and get basic info
ls -la tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py 2>&1 | head -5
wc -l tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py



🏁 Script executed:

# Extract lines 600-610 to see the context around line 603-606
sed -n '600,610p' tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py | cat -n



🏁 Script executed:

# Extract lines 761-770 to see the context around line 764-766
sed -n '761,770p' tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py | cat -n



🏁 Script executed:

# Check what methods contain these lines to verify they are public APIs
sed -n '580,610p' tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py | cat -n



🏁 Script executed:

# Check what methods contain lines 764-766
sed -n '750,770p' tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py | cat -n



Replace these public assert guards with explicit exceptions.

position_ids is caller-controlled input in public forward() methods. Using assert for input validation violates the coding guideline to "Raise ValueError instead of using assertions." Assertions are removed under python -O optimization, causing validation failures to appear later in execution with less clear error messages.

Suggested fix
-        assert position_ids is not None, "position_ids must be provided"
+        if position_ids is None:
+            raise ValueError("position_ids must be provided")

This applies to both occurrences at lines 605 and 766.
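For illustration, a standalone sketch of the pattern (a hypothetical `validate_inputs` guard, not the actual model code): the explicit ValueError survives `python -O`, whereas an `assert` would be stripped.

```python
def validate_inputs(input_ids=None, inputs_embeds=None, position_ids=None):
    # Explicit exception: still raised under `python -O`, unlike `assert`.
    if position_ids is None:
        raise ValueError("position_ids must be provided")
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
```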

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py` around
lines 603 - 606, Replace the public input validation assertions with explicit
exceptions: in the model's forward method(s) where you currently have "assert
position_ids is not None, 'position_ids must be provided'" (both occurrences in
modeling_gemma3n.py), change them to raise a ValueError with the same message;
also keep the existing input_ids/inputs_embeds mutual-exclusivity check but
ensure it raises ValueError instead of relying on assert so caller-controlled
inputs are validated reliably.

Comment on lines +934 to +938
extra = config.get("extra_special_tokens", [])
if isinstance(extra, list):
    additional = extra
else:
    additional = list(extra.keys()) if isinstance(extra, dict) else []

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, let's look at the file and surrounding context
head -n 950 tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py | tail -n 30



🏁 Script executed:

# Also check how the 'additional' variable is used after line 938
sed -n '934,960p' tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py



🏁 Script executed:

# Search for extra_special_tokens usage in the file to understand the pattern
rg "extra_special_tokens" tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py -A 3 -B 3



🏁 Script executed:

# Check similar patterns in other model files
rg "extra\.keys\(\)" tensorrt_llm/ --type py -A 2 -B 2



🏁 Script executed:

# Search for how tokenizer.add_special_tokens is typically used in codebase
rg "add_special_tokens" tensorrt_llm/ --type py -A 2 -B 2 | head -50



🏁 Script executed:

# Search for Gemma 4 tokenizer config examples or documentation
find tensorrt_llm -name "*gemma*" -type f | grep -E "(config|README)"



🏁 Script executed:

# Check if there are any test cases or examples for this tokenizer
rg "ADGemma4Tokenizer\|extra_special_tokens" tensorrt_llm/ --type py -B 2 -A 5 | head -80



🏁 Script executed:

# Check HF tokenizer config patterns - look for how other models handle special tokens
rg "additional_special_tokens" tensorrt_llm/ --type py -B 3 -A 3 | head -100



🏁 Script executed:

# Look at the actual tokenizer_config.json loading to understand the data structure
rg "tokenizer_config\|tokenizer.json" tensorrt_llm/ --type py -B 2 -A 2 | head -60



🏁 Script executed:

# Look for any actual Gemma 4 config files or test data
find tensorrt_llm -name "*gemma*" -type f



🏁 Script executed:

# Check HF documentation or codebase info about extra_special_tokens structure
rg "extra_special_tokens.*dict\|dict.*extra_special_tokens" tensorrt_llm/ --type py -B 3 -A 3



🏁 Script executed:

# Search for when extra_special_tokens might be a dict vs list in HF configs
rg "extra_special_tokens" tensorrt_llm/ -B 5 -A 5 --type py | grep -E "(dict|keys|values|items|list)" -B 3 -A 3



🏁 Script executed:

# Check the HF tokenizer documentation or patterns about special tokens structure
web_search HuggingFace tokenizer extra_special_tokens dict vs list format



🏁 Script executed:

# Check the similar Gemma 3 and Gemma 3N implementations
sed -n '1,50p' tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py | grep -A 30 "extra_special_tokens"



🏁 Script executed:

# Look for any tests or examples that show the Gemma 4 tokenizer config structure
find tensorrt_llm -name "*test*" -type f | xargs grep -l "gemma4\|Gemma4" 2>/dev/null



🏁 Script executed:

# Check if there are any comments or patterns in the codebase about dict-form special tokens
rg "additional_special_tokens.*dict\|dict.*additional" tensorrt_llm/ --type py -B 2 -A 2



🏁 Script executed:

# Check how PreTrainedTokenizerFast handles additional_special_tokens parameter
rg "additional_special_tokens" tensorrt_llm/_torch/auto_deploy/models/custom/ -B 3 -A 3 --type py



🏁 Script executed:

# Search for any actual tokenizer_config.json examples or references
find tensorrt_llm -name "*.json" -type f | head -20



🏁 Script executed:

# Look at how the Gemma3/Gemma3N handle special tokens for comparison
sed -n '900,1000p' tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py | grep -A 20 "extra_special_tokens"



🏁 Script executed:

# Check Hugging Face PreTrainedTokenizerFast documentation by looking at imports and usages
rg "PreTrainedTokenizerFast\|from_pretrained" tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py -A 5 -B 2



🏁 Script executed:

# Look at the class definition and how additional_special_tokens is used
sed -n '1,100p' tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py | grep -A 10 "class ADGemma4Tokenizer"



🏁 Script executed:

# Check what PreTrainedTokenizerFast signature expects for additional_special_tokens
rg "class ADGemma4Tokenizer" tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py -A 20



🏁 Script executed:

# Search for how add_special_tokens or similar methods work with the additional_special_tokens list
rg "add_special_tokens\|additional_special_tokens" tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py -B 5 -A 5



🏁 Script executed:

# Check the base class or inheritance to understand what additional_special_tokens parameter expects
sed -n '850,920p' tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py



🏁 Script executed:

# Search for any actual Gemma4 model configs or references
rg "google/gemma-4\|gemma-4" tensorrt_llm/ --type py



🏁 Script executed:

# Look for documentation or examples about extra_special_tokens format in HF configs
rg "extra_special_tokens.*=\|extra_special_tokens.*:" tensorrt_llm/ -B 2 -A 2 --type py | head -80



🏁 Script executed:

# Check if there are any comments or documentation about the expected format
sed -n '900,950p' tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py | cat -n



🏁 Script executed:

# Look at the glm_moe_dsa implementation which also handles extra_special_tokens
sed -n '1,100p' tensorrt_llm/tokenizer/glm_moe_dsa/tokenizer.py



🏁 Script executed:

# Look at how HF PreTrainedTokenizerFast uses additional_special_tokens
rg "additional_special_tokens" tensorrt_llm/ --type py -B 3 -A 3 | grep -E "(list|dict|keys|values)" -B 2 -A 2



🏁 Script executed:

# Check if there's any documentation or test showing dict-form extra_special_tokens
sed -n '30,60p' tensorrt_llm/tokenizer/glm_moe_dsa/tokenizer.py



🏁 Script executed:

# Look at the actual full context around line 48 where additional_special_tokens is used
sed -n '934,960p' tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py



🏁 Script executed:

# Check if there are actual test cases or any usage that triggers the dict path
rg "extra_special_tokens.*dict\|extra_special_tokens.*{" tensorrt_llm/ --type py



🏁 Script executed:

# Look at HF transformers documentation pattern for additional_special_tokens
rg "additional_special_tokens.*list\|additional_special_tokens.*dict" tensorrt_llm/ --type py -A 3 -B 3



🏁 Script executed:

# Check if the .keys() vs .values() distinction matters for PreTrainedTokenizerFast
rg "PreTrainedTokenizerFast.*additional_special_tokens\|additional_special_tokens.*PreTrainedTokenizerFast" tensorrt_llm/ --type py -A 5 -B 5



🏁 Script executed:

# Check if there are actual test cases or any usage that triggers the dict path
rg "extra_special_tokens" tensorrt_llm/ --type py | grep -i "dict\|keys\|values"



🏁 Script executed:

# Look at the actual semantics - what does additional_special_tokens expect?
# Let's check PreTrainedTokenizerFast usage patterns in the codebase
rg "additional_special_tokens=" tensorrt_llm/ --type py -B 2 -A 2



🏁 Script executed:

# Check if there's any code that would create a dict-form extra_special_tokens
rg "extra_special_tokens.*=" tensorrt_llm/ --type py -B 1 -A 1 | head -60



🏁 Script executed:

# Let's understand what PreTrainedTokenizerFast expects for additional_special_tokens
# by checking if there are examples showing the actual values
sed -n '920,1000p' tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py



🏁 Script executed:

# Check if there's documentation or comments about what extra_special_tokens should contain
rg "extra_special_tokens\|additional_special_tokens" tensorrt_llm/tokenizer/glm_moe_dsa/ --type py -B 5 -A 5



🏁 Script executed:

# Look for the Gemma 4 wrapper class definition to understand what data types it expects
sed -n '895,945p' tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py



🏁 Script executed:

# Let's verify the semantic expectation by checking how PreTrainedTokenizerFast uses additional_special_tokens
# in the initialization. Look for any examples of what values are expected
python3 << 'EOF'
# Check the standard HF pattern for additional_special_tokens
# It should be a list of token strings to add to the tokenizer vocabulary
# For example: ["<image>", "<video>"] not the symbolic names like ["image_token", "video_token"]
print("According to HF transformers documentation:")
print("additional_special_tokens should be a list of token STRINGS")
print("For dict form {name: token_string}, we need the token_strings (values), not names (keys)")
print("")
print("Example:")
print("  If extra_special_tokens = {'image_token': '<image>', 'video_token': '<video>'}")
print("  Then additional_special_tokens should be: ['<image>', '<video>'] (values)")
print("  NOT: ['image_token', 'video_token'] (keys)")
EOF



🏁 Script executed:

# Double-check the actual Mistral3 implementation to see if it has the same potential issue
sed -n '900,1000p' tensorrt_llm/_torch/auto_deploy/models/custom/modeling_mistral3.py | grep -B 5 -A 10 "extra_special_tokens"



🏁 Script executed:

# Check if there are any actual token examples in the codebase that would clarify the format
rg "<.*>" tensorrt_llm/_torch/auto_deploy/models/custom/ --type py | grep -i "token\|gemma\|mistral" | head -20



Use the token strings from extra_special_tokens, not the symbolic names.

If this branch sees the standard {name: token} form, list(extra.keys()) will register entries like "image_token" instead of the actual special tokens, which breaks tokenizer initialization. The additional_special_tokens parameter expects a list of actual token strings to add to the vocabulary.

Suggested fix
-        else:
-            additional = list(extra.keys()) if isinstance(extra, dict) else []
+        else:
+            additional = list(extra.values()) if isinstance(extra, dict) else []
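A quick illustration of the difference, assuming the standard `{name: token}` dict form of `extra_special_tokens`:

```python
extra = {"image_token": "<image>", "video_token": "<video>"}

# keys() yields the symbolic names -- not valid vocabulary entries:
assert list(extra.keys()) == ["image_token", "video_token"]

# values() yields the actual token strings to register as
# additional_special_tokens:
assert list(extra.values()) == ["<image>", "<video>"]
```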
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py` around
lines 934 - 938, The code currently builds additional from
config.get("extra_special_tokens") using list(extra.keys()) for dicts, which
adds symbolic names instead of actual token strings; change the branch in the
extra handling so that when extra is a dict you collect the token strings via
list(extra.values()) (and coerce/filter to strings) so additional contains the
actual special token strings to pass to additional_special_tokens during
tokenizer initialization.

Comment on lines +983 to +990
def __call__(self, inputs, sampling_params):
    token_ids, extra = self.base(inputs, sampling_params)
    if extra is None:
        extra = {}
    mm_data = extra.setdefault("multimodal_data", {})
    if "token_type_ids" not in mm_data:
        mm_data["token_type_ids"] = torch.zeros(1, len(token_ids), dtype=torch.int64)
    return token_ids, extra

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Inspect the base image-text factory contract.
rg -n -C5 'class\s+AutoModelForImageTextToTextFactory\b|def\s+init_(tokenizer|processor|input_processor)\(' tensorrt_llm/_torch/auto_deploy --type py

# Compare sibling image-text factories, if any.
rg -n -C4 'class\s+\w+\(AutoModelForImageTextToTextFactory\)' tensorrt_llm/_torch/auto_deploy --type py

# Trace where multimodal_data / token_type_ids are produced and consumed.
rg -n -C3 'multimodal_data|token_type_ids' tensorrt_llm/_torch/auto_deploy --type py



🏁 Script executed:

# Find ADGemma4Tokenizer class definition
rg -n 'class ADGemma4Tokenizer' tensorrt_llm/_torch/auto_deploy --type py -A 20

# Find how base factory uses the processor
rg -n 'init_processor' tensorrt_llm/_torch/auto_deploy/models/hf.py -B 5 -A 10

# Check if ADGemma4Tokenizer has processor-like methods
rg -n 'ADGemma4Tokenizer.*\(' tensorrt_llm/_torch/auto_deploy --type py -B 3 -A 3

# Find where multimodal_data is set/used in the llm preprocessing pipeline
rg -n '__call__.*sampling_params|init_input_processor' tensorrt_llm/_torch/auto_deploy/llm.py -B 5 -A 15



🏁 Script executed:

# Check if there's a Gemma4Processor in HF ecosystem
rg -n 'class.*Gemma4.*Processor|from transformers' tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py --type py | head -30

# Check what ADInputProcessor does with the processor argument
rg -n 'class ADInputProcessor' tensorrt_llm/_torch/auto_deploy --type py -A 30

# Find if there's a real Gemma4 multimodal processor in HF
rg -n 'Gemma4Processor|gemma.*processor' tensorrt_llm/_torch/auto_deploy --type py -i

# Check the imports to see what processor classes are available
head -100 tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py | grep -E 'import|from'



🏁 Script executed:

# Check apply_chat_template in ADInputProcessor flow
rg -n 'apply_chat_template' tensorrt_llm/_torch/auto_deploy --type py -B 5 -A 10

# Check if ADGemma4Tokenizer overrides apply_chat_template
rg -n 'apply_chat_template|def.*method' tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py -A 5

# See full ADGemma4Tokenizer class definition
sed -n '907,962p' tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py

# Compare with Mistral's processor
rg -n 'class ADMistralSmall4Processor' tensorrt_llm/_torch/auto_deploy/models/custom/modeling_mistral3.py -A 40



init_processor() returns a tokenizer instead of a multimodal processor, breaking image support.

Gemma4ForConditionalGenerationFactory.init_processor() returns ADGemma4Tokenizer (a text-only tokenizer), while the base AutoModelForImageTextToTextFactory expects a processor with image handling (e.g., PixtralProcessor for Mistral). The tokenizer's apply_chat_template() method cannot generate image embeddings or multimodal tensors, so multimodal requests will only receive the all-zeros fallback token_type_ids from Gemma4ADInputProcessor.__call__(). The docstring claiming "the HF processor already provides [per-blob IDs]" is misleading since the factory returns a tokenizer, not a processor.

Compare with Mistral3ForConditionalGenerationFactory.init_processor(), which correctly returns ADMistralSmall4Processor (combining image processor + tokenizer). Gemma4 should either return a real multimodal processor or document that multimodal image preprocessing happens elsewhere before reaching ADInputProcessor.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py` around
lines 983 - 990, The factory
Gemma4ForConditionalGenerationFactory.init_processor currently returns
ADGemma4Tokenizer (a text-only tokenizer) which breaks multimodal handling
expected by AutoModelForImageTextToTextFactory and causes
Gemma4ADInputProcessor.__call__ to fall back to zeroed token_type_ids; change
init_processor to return a true multimodal processor (e.g., an ADGemma4Processor
that combines image processor + tokenizer, similar to
ADMistralSmall4Processor/PixtralProcessor in
Mistral3ForConditionalGenerationFactory) or explicitly route image preprocessing
so that apply_chat_template() isn’t the sole path for multimodal inputs and the
processor provides per-blob IDs. Ensure the returned object implements the image
handling APIs used by AutoModelForImageTextToTextFactory.

Comment on lines +1052 to 1058
def get_op_schema(op) -> torch.FunctionSchema:
    """Return the schema for an op or op overload packet."""
    if hasattr(op, "_schemas"):
        return next(iter(op._schemas.values()))
    if hasattr(op, "_schema"):
        return op._schema
    raise RuntimeError(f"No schema found on op {op}")

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

In PyTorch, when an OpOverloadPacket has multiple overloads, is iterating over op._schemas.values() guaranteed to return the default overload first? What is the supported way to retrieve the default overload's schema from an OpOverloadPacket?

💡 Result:

No, iterating over op._schemas.values() is not guaranteed to return the default overload first. Nothing in the source code or documentation guarantees that Python's dict iteration order (insertion order since 3.7) places the default overload ('default' or '') first in OpOverloadPacket._schemas; overload order depends on registration order, which is unspecified and does not prioritize the default. The supported way to retrieve the default overload's schema is op.default._schema (when a 'default' overload exists), or to inspect op.overloads() for the default name and use getattr(op, name)._schema. PyTorch itself follows this pattern, e.g., torch/fx/operator_schemas.py fetches schemas via getattr(op, overload)._schema for each overload. Direct _schemas access is internal and undocumented.


🏁 Script executed:

# Read the implementation of get_op_schema function
cat -n tensorrt_llm/_torch/auto_deploy/utils/node_utils.py | sed -n '1050,1070p'

Repository: NVIDIA/TensorRT-LLM

Length of output: 874


🏁 Script executed:

# Search for usages of get_op_schema in the codebase
rg "get_op_schema" -A 2 -B 2

Repository: NVIDIA/TensorRT-LLM

Length of output: 5929


🏁 Script executed:

# Check if op.default exists in typical PyTorch usage - look at related test or example files
fd "node_utils\|_graph\|export" -e py | head -10 | xargs grep -l "get_op_schema\|_schemas\|OpOverload" 2>/dev/null | head -5

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


Use op.default._schema for OpOverloadPackets instead of next(iter(op._schemas.values())).

Iterating over op._schemas.values() does not guarantee returning the default overload first—PyTorch does not specify iteration order. The codebase already demonstrates this awareness in _graph.py (checking base_op.default before passing to get_op_schema()), but other callers like export.py and attention_interface.py pass multi-overload packets directly. Use op.default._schema when the op is an OpOverloadPacket, matching PyTorch's supported pattern in torch/fx/operator_schemas.py.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/utils/node_utils.py` around lines 1052 -
1058, get_op_schema currently picks an arbitrary overload by iterating
op._schemas.values(). When op is an OpOverloadPacket (i.e., it has _schemas),
return op.default._schema to reliably get the intended schema; otherwise fall
back to op._schema, and keep raising the same RuntimeError when neither
attribute exists.
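A minimal sketch of the suggested fix. The `OpOverloadPacket` check and the `op.default._schema` lookup follow the review's recommendation; the fallback to the first registered overload name when a packet has no `default` overload is an added assumption, not something the review or the PR specifies:

```python
import torch


def get_op_schema(op) -> "torch._C.FunctionSchema":
    """Return the schema for an op overload or an op overload packet."""
    if isinstance(op, torch._ops.OpOverloadPacket):
        # Prefer the default overload; falling back to the first registered
        # overload name for packets without "default" is an assumption.
        names = op.overloads()
        name = "default" if "default" in names else names[0]
        return getattr(op, name)._schema
    if hasattr(op, "_schema"):
        # An OpOverload carries its schema directly.
        return op._schema
    raise RuntimeError(f"No schema found on op {op}")
```

With this version, `get_op_schema(torch.ops.aten.relu)` and `get_op_schema(torch.ops.aten.relu.default)` resolve to the same `aten::relu` schema, independent of `_schemas` iteration order.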

Comment on lines +1 to +13
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

⚠️ Potential issue | 🟡 Minor

Use the repo-standard 2026 header on this new file.

This file is new here, but the header still says 2025 and does not match the standard header form used in the rest of the repo.

As per coding guidelines, "Add NVIDIA copyright header to ALL new files; update year on modified files".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tests/unittest/auto_deploy/singlegpu/transformations/library/test_inject_custom_attention_mask.py`
around lines 1 - 13, replace the existing 2025 header block at the top of the
file with the canonical NVIDIA copyright header used across the repo, updating
the year to 2026 so the wording, format, and year match the repository
standard.

During warmup, CUDA graph capture, and other executor-internal calls
that bypass the input processor, token_type_ids is not in named_args.
Supply all-zeros (standard causal mask) when missing.

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
- Use aten.sym_size inside the graph to derive seq_len instead of
  the Python-level symbolic variable which becomes undefined after
  graph recompile during KV cache transforms
- Add register_default_extra_arg infra so token_type_ids defaults to
  zeros during warmup/init forward passes
- Wire default factory through AttentionMaskProviderContext into
  CachedSequenceInterface and SequenceInfo

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
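The `register_default_extra_arg` mechanism this commit describes might look roughly like the sketch below. The registry shape, the `resolve_extra_args` helper, and all signatures are illustrative assumptions, not the PR's actual API:

```python
from typing import Callable, Dict

import torch

# Hypothetical registry: factories that synthesize a default value for an
# extra graph input whenever an executor-internal forward pass (warmup,
# CUDA graph capture) bypasses the input processor.
_DEFAULT_EXTRA_ARG_FACTORIES: Dict[str, Callable[[int, torch.device], torch.Tensor]] = {}


def register_default_extra_arg(
    name: str, factory: Callable[[int, torch.device], torch.Tensor]
) -> None:
    _DEFAULT_EXTRA_ARG_FACTORIES[name] = factory


def resolve_extra_args(named_args: dict, num_tokens: int, device: torch.device) -> dict:
    """Fill in any registered extra args that are missing from named_args."""
    for name, factory in _DEFAULT_EXTRA_ARG_FACTORIES.items():
        named_args.setdefault(name, factory(num_tokens, device))
    return named_args


# token_type_ids defaults to all zeros, which yields the standard causal mask.
register_default_extra_arg(
    "token_type_ids",
    lambda num_tokens, device: torch.zeros(num_tokens, dtype=torch.long, device=device),
)
```

In this sketch, a warmup step calling `resolve_extra_args({}, n, device)` always receives a zero-filled `token_type_ids`, so the attention backend sees a plain causal mask.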
During resize_kv_cache warmup with empty batches, the custom attention
mask may have batch dimension 0. Treat empty masks the same as None
to fall back to the standard causal attention path.

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
During warmup forward passes, sequences may have zero q_len or kv_len.
Skip these in the per-sequence custom mask attention loop to avoid
index-out-of-bounds errors.

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Move mask computation from FX graph nodes to the wrapper forward().
The graph receives the finished custom_attn_mask tensor (or None) as
an optional input — no token_type_ids graph input, no computation
nodes, no symbolic shape/device issues.

During warmup, text-only, and decode: wrapper passes None so the
attention backend uses its fast causal kernel. During prefill with
images: wrapper computes the bidirectional-within-blob + causal mask
from token_type_ids and passes it to the graph.

Reverts band-aid fixes (numel check, empty seq skip) that were needed
with the in-graph approach.

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
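As a rough single-sequence illustration of the mask the wrapper computes (bidirectional within each contiguous image blob, causal elsewhere): `build_custom_attn_mask` is a hypothetical name, and the real code operates per sequence over a flattened batch rather than on one sequence at a time:

```python
import torch


def build_custom_attn_mask(token_type_ids: torch.Tensor) -> torch.Tensor:
    """Causal mask with bidirectional attention inside each contiguous
    image blob (token_type_ids == 1), per-sequence simplification.

    token_type_ids: (seq_len,) with 0 for text and 1 for image tokens.
    Returns a boolean (seq_len, seq_len) mask where True = may attend.
    """
    seq_len = token_type_ids.shape[0]
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Label each contiguous run of image tokens with a distinct blob id.
    is_image = token_type_ids == 1
    starts = is_image & ~torch.roll(is_image, 1)
    starts[0] = is_image[0]  # roll wraps around; fix the first position
    blob_id = torch.cumsum(starts.to(torch.long), dim=0)
    blob_id = torch.where(is_image, blob_id, torch.zeros_like(blob_id))
    # Tokens in the same (nonzero) blob attend to each other bidirectionally.
    same_blob = (blob_id[:, None] == blob_id[None, :]) & (blob_id[:, None] > 0)
    return causal | same_blob
```

With `token_type_ids = [0, 1, 1, 0, 1]`, the two image tokens at positions 1-2 see each other in both directions, while the text token at position 0 remains strictly causal and the separate blob at position 4 does not gain backward access to positions 1-2.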
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>