[None][feat] Add Gemma3n model + Shared-kv attention support in AutoDeploy #244

Open
bmarimuthu-nv wants to merge 8 commits into feat/paperclip_maximizer from bala/gemma3n



bmarimuthu-nv commented Mar 13, 2026

Summary

  • add FlashInfer cached-attention support for Gemma 3n shared-KV decode layers
  • thread per-layer sliding-window metadata through the FlashInfer AutoDeploy backend so alternating sliding/full Gemma 3n layers preserve their attention semantics
  • register the shared-KV FlashInfer cached op as a dynamic cached-attention op for piecewise execution
  • extend shared-KV tests to cover FlashInfer descriptor routing, cache aliasing, and a CUDA runtime check against a manual attention reference
  • update the Gemma 3n shared-KV registry config to prefer attn_backend: flashinfer on the validated world_size_1 path
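The per-layer metadata that has to be threaded through the backend can be sketched as follows. This is a minimal illustration, assuming a config shaped like the Hugging Face Gemma 3n text config (a per-layer `layer_types` list plus a `num_kv_shared_layers` count) and assuming each shared tail layer reuses the cache of the last non-shared layer of the same attention type. `LayerPlan`, `build_layer_plan`, and the toy config values are hypothetical and are not the AutoDeploy API.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class LayerPlan:
    layer_idx: int
    sliding_window: Optional[int]  # None means full (global) attention
    cache_source: int  # layer whose KV cache this layer reads
    writes_kv: bool  # False for read-only shared-KV consumers


def build_layer_plan(
    layer_types: List[str], sliding_window: int, num_kv_shared_layers: int
) -> List[LayerPlan]:
    """Hypothetical helper: per-layer window plus KV-cache aliasing."""
    first_shared = len(layer_types) - num_kv_shared_layers
    plan = []
    for idx, layer_type in enumerate(layer_types):
        window = sliding_window if layer_type == "sliding_attention" else None
        if idx < first_shared:
            # Producer layer: owns its cache and writes new K/V every step.
            plan.append(LayerPlan(idx, window, idx, True))
        else:
            # Consumer layer (assumed rule): alias the last non-shared layer
            # that has the same attention type.
            producers = [j for j in range(first_shared) if layer_types[j] == layer_type]
            if not producers:
                raise ValueError(f"no producer layer for type {layer_type!r}")
            plan.append(LayerPlan(idx, window, producers[-1], False))
    return plan


# Toy config (assumed values): 4 sliding layers then 1 full layer, repeated twice.
types = (["sliding_attention"] * 4 + ["full_attention"]) * 2
for p in build_layer_plan(types, sliding_window=512, num_kv_shared_layers=5):
    print(p)
```

The property that matters for backend choice is `writes_kv=False` on the tail layers: their cached-attention op must read the aliased cache without appending to it, while still honoring the window of their own layer type.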

Backend Decision

  • trtllm was evaluated first, but its cached-attention kernel currently asserts when KV cache writes are disabled
  • evidence from the GPU-shell feasibility run: the kernel fails with "KV cache update cannot be disabled now"
  • because Gemma 3n shared-KV tail layers require a read-only KV-cache consumer, the supported backend path in this PR is FlashInfer, not TRT-LLM
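The backend decision above reduces to a capability check. The sketch below is hypothetical: `BACKEND_CAPS` and `pick_attn_backend` are illustrative names that only encode the findings stated in this PR (TRT-LLM's cached-attention kernel cannot run with KV writes disabled; the new FlashInfer op can), not a real AutoDeploy registry.

```python
from typing import Dict, List

# Hypothetical capability table summarizing the findings above; not a real
# AutoDeploy registry. "read_only_kv" means the cached-attention op can run
# with KV cache writes disabled.
BACKEND_CAPS: Dict[str, Dict[str, bool]] = {
    "trtllm": {"read_only_kv": False},  # kernel asserts when KV writes are off
    "flashinfer": {"read_only_kv": True},  # shared-KV cached op added in this PR
}


def pick_attn_backend(preferred: List[str], needs_read_only_kv: bool) -> str:
    """Return the first preferred backend that satisfies the model's needs."""
    for name in preferred:
        caps = BACKEND_CAPS.get(name, {})
        if needs_read_only_kv and not caps.get("read_only_kv", False):
            continue  # skip backends that cannot serve shared-KV tail layers
        return name
    raise RuntimeError("no attention backend supports the required features")


# Gemma 3n has shared-KV tail layers, so TRT-LLM is skipped.
print(pick_attn_backend(["trtllm", "flashinfer"], needs_read_only_kv=True))
```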

Validation

  • pytest -q tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma3n_modeling.py
    • result: 22 passed
    • log: /tmp/codex-gpu-shells/7a1e5e171e/gpu/logs/job_34.log
  • full AutoDeploy e2e run on a local google/gemma-3n-E2B-it snapshot with:
    • world_size_1.yaml
    • gemma3n_shared_kv.yaml
    • --args.attn-backend=flashinfer
    • --args.compile-backend=torch-simple
    • TMPDIR=/tmp to keep executor IPC socket paths under the Unix domain socket path-length limit
    • log: /tmp/codex-gpu-shells/7a1e5e171e/gpu/logs/job_33.log
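The TMPDIR workaround matters because a Unix domain socket path must fit in `sockaddr_un.sun_path`, which is 108 bytes on Linux (104 on macOS), including the trailing NUL for filesystem paths. A minimal pre-flight check might look like this; `socket_path_fits` and the socket file name are hypothetical, and the executor's real naming scheme may differ.

```python
import os
import tempfile

# sockaddr_un.sun_path is 108 bytes on Linux (104 on macOS); a filesystem
# socket path also needs one byte for the trailing NUL.
SUN_PATH_MAX = 108


def socket_path_fits(sock_dir: str, sock_name: str, limit: int = SUN_PATH_MAX) -> bool:
    """Return True if sock_dir/sock_name plus its NUL byte fits in sun_path."""
    path = os.path.join(sock_dir, sock_name)
    return len(os.fsencode(path)) + 1 <= limit


# Hypothetical socket name, for illustration only.
name = "ad_executor_ipc_" + "0" * 32 + ".sock"
print(socket_path_fits("/tmp", name))
# tempfile.gettempdir() honors the TMPDIR environment variable.
print(socket_path_fits(tempfile.gettempdir(), name))
```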

E2E Output Check

The full 1-GPU FlashInfer run produced coherent outputs. Examples from the validation log:

  • gravity: Gravity is the force that pulls everything with mass towards each other.
  • Iceland: The capital of Iceland is Reykjavik.
  • Romeo and Juliet: coherent two-sentence plot summary
  • prime function: generated a correct Python is_prime implementation
  • northern lights: coherent explanation of solar wind / aurora cause
  • universe / golf / compiler prompts: long-form coherent answers rather than the earlier broken fragments / markdown repetition

Notes

  • the validated e2e path used the local checkpoint snapshot directly because the direct Hub-name flow still depends on gated-repo auth at runtime
  • the earlier bad-generation issue is now understood to be a backend-path selection problem: Gemma 3n decode needs a shared-KV-aware cached attention backend, and FlashInfer now provides that path in AutoDeploy
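To make the read-only requirement concrete, here is a minimal pure-Python sketch of a single-head decode step in the spirit of a manual attention reference (the PR's actual test reference is not shown here). It assumes standard scaled dot-product attention and does not model Gemma 3n's exact scaling; `decode_attention` and the toy cache values are hypothetical.

```python
import math
from typing import List, Optional, Sequence

Vec = List[float]


def decode_attention(
    q: Sequence[float],
    k_cache: Sequence[Sequence[float]],
    v_cache: Sequence[Sequence[float]],
    sliding_window: Optional[int] = None,
) -> Vec:
    """Hypothetical single-head decode step that only reads the KV cache.

    k_cache/v_cache hold one entry per cached token, including the current one.
    With a sliding window, only the last `sliding_window` entries are visible.
    """
    start = 0 if sliding_window is None else max(0, len(k_cache) - sliding_window)
    keys, values = k_cache[start:], v_cache[start:]
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [
        sum(w * v[d] for w, v in zip(weights, values)) / total
        for d in range(len(values[0]))
    ]


# Producer layer: owns the cache and appends this step's K/V before attending.
k_cache: List[Vec] = [[1.0, 0.0], [0.0, 1.0]]
v_cache: List[Vec] = [[1.0, 2.0], [3.0, 4.0]]
k_cache.append([1.0, 1.0])
v_cache.append([5.0, 6.0])
producer_out = decode_attention([1.0, 0.0], k_cache, v_cache, sliding_window=2)

# Shared-KV consumer: its own query attends over the producer's cache passed
# as tuples, so an accidental append raises AttributeError instead of
# silently growing the producer's cache.
consumer_out = decode_attention([0.0, 1.0], tuple(k_cache), tuple(v_cache), sliding_window=2)
print(producer_out, consumer_out, len(k_cache))
```

A backend that can only run the producer path (attend and write in one kernel) would append a second entry for the same token on every consumer layer, which is consistent with the degraded generations seen before the backend switch.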

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
@bmarimuthu-nv bmarimuthu-nv changed the title [None][fix] route Gemma3n shared-kv decode through torch backend [None][feat] add flashinfer shared-kv attention for Gemma3n Mar 13, 2026
bmarimuthu-nv (Author)

The FlashInfer shared-KV backend update was pushed in 1006914e07.

What changed in this follow-up:

  • added a FlashInfer read-only shared-KV cached attention op for Gemma 3n tail layers
  • threaded alternating sliding-window metadata through the FlashInfer backend
  • registered the shared FlashInfer cached op in piecewise dynamic-op handling
  • switched the Gemma 3n shared-KV config from attn_backend: torch to attn_backend: flashinfer

Why FlashInfer and not TRT-LLM:

  • I evaluated TRT-LLM first, but the cached-attention kernel currently asserts when KV writes are disabled
  • evidence from the GPU-shell feasibility run: the kernel fails with "KV cache update cannot be disabled now"
  • Gemma 3n shared-KV layers need a read-only KV consumer, so TRT-LLM is not a viable backend path today without lower-level kernel support

Validation:

  • pytest -q tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma3n_modeling.py
    • result: 22 passed
    • log: /tmp/codex-gpu-shells/7a1e5e171e/gpu/logs/job_34.log
  • full 1-GPU Gemma 3n AutoDeploy run with local snapshot, attn_backend=flashinfer, compile_backend=torch-simple
    • log: /tmp/codex-gpu-shells/7a1e5e171e/gpu/logs/job_33.log

Generated text in the full FlashInfer run is coherent. Examples from the log:

  • gravity: Gravity is the force that pulls everything with mass towards each other.
  • Iceland: The capital of Iceland is Reykjavik.
  • Romeo and Juliet: coherent two-sentence summary
  • prime check: generated a correct Python is_prime function
  • northern lights: coherent aurora / solar-wind explanation
  • broader prompts like universe, golf, and compiler/interpreter also produced long-form coherent answers instead of the earlier broken fragments / markdown repetition

Comment thread tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
Comment thread tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py Outdated
Comment thread tensorrt_llm/_torch/pyexecutor/py_executor.py Outdated
Comment thread tensorrt_llm/_torch/pyexecutor/resource_manager.py Outdated
Comment thread tensorrt_llm/_torch/pyexecutor/resource_manager.py Outdated
Comment thread tensorrt_llm/_torch/pyexecutor/sampler.py Outdated
Comment thread tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py Outdated
Comment thread tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py Outdated
# limitations under the License.

import enum
import os

bmarimuthu-nv (Author)

clean up

bmarimuthu-nv (Author)

Removed in de7ca6d. The os import was another stale leftover from the earlier debug-block cleanup in sampler.py. After this import-only cleanup I reran the Gemma shared-KV unit shard: 22 passed (job_4.log).

Comment thread tensorrt_llm/_torch/pyexecutor/sampler.py Outdated
@bmarimuthu-nv bmarimuthu-nv changed the title [None][feat] add flashinfer shared-kv attention for Gemma3n [None][feat] Add Gemma3n model + Shared-kv attention support in AutoDeploy Mar 13, 2026
suyoggupta

@bmarimuthu-nv: can we prioritize landing this PR to main?

