Skip to content

[None][feat] add Gemma4 MTP speculative decoding support#15473

Closed
Tushar-ml wants to merge 1 commit into
NVIDIA:mainfrom
Tushar-ml:feat/gemma4-mtp-speculative-decoding
Closed

[None][feat] add Gemma4 MTP speculative decoding support#15473
Tushar-ml wants to merge 1 commit into
NVIDIA:mainfrom
Tushar-ml:feat/gemma4-mtp-speculative-decoding

Conversation

@Tushar-ml

@Tushar-ml Tushar-ml commented Jun 18, 2026

Copy link
Copy Markdown
Contributor

Summary

Implements Multi-Token Prediction (MTP) speculative decoding for Gemma4 models using the one-engine path. The Gemma4 assistant model uses Q-only attention layers that share the backbone KV cache via cache_layer_idx, making it incompatible with the two-engine (separate draft KV pool) approach.

  • Gemma4MTP / Gemma4MTPDecoderLayer / Gemma4MTPHead: New MTP layer classes. Each Gemma4MTPDecoderLayer uses is_kv_shared=True (no K/V projections) and reads backbone KV via cache_layer_idx, which is computed as the last backbone layer with its own KV cache of matching attention type (sliding/full).
  • Gemma4ForCausalLM promoted from DecoderModelForCausalLMSpecDecOneEngineForCausalLM.
  • MTPForCausalLM extended with a gemma4_text case: uses max_draft_len (not num_nextn_predict_layers), loads assistant checkpoint weights with HF→TRT-LLM key remapping.
  • FlashInferAttentionMetadata gains spec-dec fields mirrored from TrtllmAttentionMetadata so MTP draft loop buffers can be allocated when FlashInfer is used (required for Gemma4 VSWA + head_dim 256/512 on non-Blackwell GPUs).
  • py_executor_creator / _util / speculative/utils: detect uses_shared_backbone_kv_for_mtp class attribute to disable separate draft KV pool and route Gemma4 MTP through the correct KV-cache path.
  • model_loader: load_config_and_apply_defaults now returns model_cls for class-level capability checks.

Test plan

  • Text-only Gemma4-27B/31B with MTPDecodingConfig(max_draft_len=1, speculative_model="google/gemma-4-31b-it-assistant", mtp_eagle_one_model=True) on H100 (4× TP)
  • Verify acceptance rate > baseline greedy throughput
  • Verify no KV cache OOM with free_gpu_memory_fraction=0.6
  • Smoke test with FlashInfer backend (attn_backend=FLASHINFER) on Blackwell; TRTLLM backend on Hopper
  • Run existing MTP unit tests: pytest tests/unittest/_torch/speculative/ -k mtp

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features

    • Gemma4 now supports one-engine speculative decoding with improved inference efficiency.
    • Added Multi-Token Prediction (MTP) capability for Gemma4 assistant models.
    • Gemma4 multimodal model enhanced with speculative decoding support.
  • Performance Improvements

    • Gemma4 configuration parameters optimized for better batch processing and memory utilization.
    • Improved KV cache management for speculative decoding scenarios.
    • Dynamic backend selection for attention operations based on hardware capabilities.

Implements Multi-Token Prediction (MTP) speculative decoding for Gemma4
models using the one-engine path where the assistant's Q-only attention
layers share the backbone's KV cache via cache_layer_idx.

Key changes:

modeling_gemma4.py:
- Promote Gemma4ForCausalLM to extend SpecDecOneEngineForCausalLM
- Add Gemma4MTPHead, Gemma4MTPDecoderLayer, Gemma4MTP classes
- Gemma4MTPDecoderLayer uses is_kv_shared=True (Q-only, no KV projections)
- cache_layer_idx computed from backbone layer_types per attention type
- Gemma4TextModel gains aux_stream_dict for MTPForCausalLM interface compat

modeling_speculative.py:
- Add gemma4_text case in MTPForCausalLM to create Gemma4MTP layers
- Gemma4 uses max_draft_len (not num_nextn_predict_layers)
- load_weights maps HF assistant checkpoint layout to TRT-LLM paths
- load_weights_from_target_model shares lm_head/embed_tokens with backbone

flashinfer.py:
- Mirror spec-dec fields from TrtllmAttentionMetadata onto
  FlashInferAttentionMetadata so the MTP draft loop can use FlashInfer
  for Gemma4 (needed for head_dim 256/512 and VSWA on non-Blackwell)

py_executor_creator.py:
- Detect uses_shared_backbone_kv_for_mtp class attribute to route
  Gemma4 MTP through the correct KV-cache and backend path

model_loader.py:
- Return model_cls from load_config_and_apply_defaults for class-level
  capability checks in py_executor_creator

_util.py:
- _target_only_kv_layer_mask: restrict main KV cache to backbone layers
  when draft layers share backbone KV or use a separate draft KV manager

speculative/utils.py:
- Disable separate draft KV cache for Gemma4 MTP one-engine mode
  (assistant reads backbone KV via Q-only attention, no separate pool needed)

modeling_gemma4mm.py:
- Propagate flashinfer_supports_one_engine_spec_decode and
  uses_shared_backbone_kv_for_mtp from Gemma4ForCausalLM

examples/auto_deploy/model_registry/configs/gemma4_dense.yaml:
- Update world_size, batch sizes, and token limits for 4-GPU serving

Signed-off-by: varadrane1707 <varad@simplismart.tech>
@Tushar-ml Tushar-ml requested review from a team as code owners June 18, 2026 06:36
@Tushar-ml Tushar-ml closed this Jun 18, 2026
@coderabbitai

coderabbitai Bot commented Jun 18, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

Caution

Review failed

The pull request is closed.

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: c4a89d02-d4d4-4640-a6bc-a96494ca7e7b

📥 Commits

Reviewing files that changed from the base of the PR and between c390d2f and 7c75198.

📒 Files selected for processing (9)
  • examples/auto_deploy/model_registry/configs/gemma4_dense.yaml
  • tensorrt_llm/_torch/attention_backend/flashinfer.py
  • tensorrt_llm/_torch/models/modeling_gemma4.py
  • tensorrt_llm/_torch/models/modeling_gemma4mm.py
  • tensorrt_llm/_torch/models/modeling_speculative.py
  • tensorrt_llm/_torch/pyexecutor/_util.py
  • tensorrt_llm/_torch/pyexecutor/model_loader.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • tensorrt_llm/_torch/speculative/utils.py

📝 Walkthrough

Walkthrough

Adds Gemma4 one-engine MTP speculative decoding with shared-backbone Q-only attention. Introduces Gemma4MTP, Gemma4MTPDecoderLayer, and Gemma4MTPHead classes, a Triton paged attention backend in FlashInferAttentionMetadata, Gemma4ForCausalLM inheriting from SpecDecOneEngineForCausalLM, weight remapping in MTPForCausalLM, executor model_cls propagation, KV cache layer masking, and an updated Gemma4 dense deployment config.

Changes

Gemma4 MTP One-Engine Speculative Decoding

Layer / File(s) Summary
FlashInfer metadata and Triton paged attention for spec-dec
tensorrt_llm/_torch/attention_backend/flashinfer.py
Extends FlashInferAttentionMetadata with spec-decoding flags, draft/mask/offset tensors, kv_lens_cuda, host_request_types, and Triton helpers. Adds _forward_triton_paged() for context and spec-dec generation phases, early-returns from plan construction for the "triton" backend, and routes forward_impl() to the Triton path when spec-dec or multi-token generation is active.
Gemma4 MTP model classes and FlashInfer backend selection
tensorrt_llm/_torch/models/modeling_gemma4.py
Adds _gemma4_flashinfer_backend(head_dim) for SM/head-dim-based sub-backend selection; Gemma4Attention now calls it. Gemma4TextModel gains aux_stream_dict. Introduces Gemma4MTPHead, Gemma4MTPDecoderLayer, and Gemma4MTP implementing the Q-only KV-shared MTP assistant forward path with pre_projection, TP-aware chunking, and KV-cache sharing indices.
Gemma4ForCausalLM one-engine integration and MM delegation
tensorrt_llm/_torch/models/modeling_gemma4.py, tensorrt_llm/_torch/models/modeling_gemma4mm.py
Gemma4ForCausalLM now inherits from SpecDecOneEngineForCausalLM, gains flashinfer_supports_one_engine_spec_decode(), uses_shared_backbone_kv_for_mtp = True, conditional cuda_graph_config, _build_attention_masks(), and updated forward()/load_weights() signatures. Gemma4ForConditionalGeneration delegates spec-dec capability and model defaults to the causal LM.
MTPForCausalLM Gemma4 branch and weight remapping
tensorrt_llm/_torch/models/modeling_speculative.py
MTPForCausalLM.__init__ adds a gemma4_text branch selecting Gemma4MTP and optionally deriving mtp_layer_types from Gemma4AssistantConfig. load_weights() remaps HF assistant checkpoint keys to mtp_layers.* paths including Q-only q_projqkv_proj remapping, updates matching state dict entries, and loads with strict=False.
Executor model_cls propagation, KV cache masking, and spec config
tensorrt_llm/_torch/pyexecutor/model_loader.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py, tensorrt_llm/_torch/pyexecutor/_util.py, tensorrt_llm/_torch/speculative/utils.py
load_config_and_apply_defaults returns (llm_args, model_cls). create_py_executor uses model_cls.uses_shared_backbone_kv_for_mtp to disable separate draft KV cache and enforce SM-conditional TRTLLM/FLASHINFER compatibility. KvCacheCreator derives spec_dec_layer_mask via new _target_only_kv_layer_mask. Spec config disables separate draft KV cache for Gemma4 hybrid mtp_eagle_one_model.
Gemma4 dense auto-deploy config
examples/auto_deploy/model_registry/configs/gemma4_dense.yaml
Sets world_size: 4, revises cuda_graph_config token/batch limits, lowers free_gpu_memory_fraction to 0.6, and expands the piecewise_num_tokens schedule.

Sequence Diagram(s)

sequenceDiagram
    rect rgba(30, 100, 200, 0.5)
        Note over create_py_executor: Executor setup
        create_py_executor->>_load_config_and_create_checkpoint_loader: create loader
        _load_config_and_create_checkpoint_loader->>ModelLoader: load_config_and_apply_defaults()
        ModelLoader-->>_load_config_and_create_checkpoint_loader: (llm_args, model_cls)
        _load_config_and_create_checkpoint_loader-->>create_py_executor: (llm_args, checkpoint_loader, model_cls)
        create_py_executor->>create_py_executor: model_cls.uses_shared_backbone_kv_for_mtp → disable separate draft KV
        create_py_executor->>create_py_executor: SM check → enforce FLASHINFER / TRTLLM compat
        create_py_executor->>KvCacheCreator: create KV cache manager
        KvCacheCreator->>KvCacheCreator: _target_only_kv_layer_mask → backbone-only mask
    end
    rect rgba(30, 180, 100, 0.5)
        Note over Gemma4ForCausalLM: Speculative decode forward
        Gemma4ForCausalLM->>Gemma4TextModel: forward(hidden_states)
        Gemma4ForCausalLM->>Gemma4MTP: forward(input_ids, hidden_states, spec_metadata)
        Gemma4MTP->>Gemma4MTPDecoderLayer: forward per layer (Q-only KV-shared attn)
        Gemma4MTP->>Gemma4MTPHead: get_last_token_states → shared lm_head logits
        Gemma4MTPHead-->>Gemma4ForCausalLM: draft logits
        Gemma4ForCausalLM->>FlashInferAttention: forward_impl → _forward_triton_paged
    end
Loading

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

  • NVIDIA/TensorRT-LLM#10502: Both PRs modify one-model speculative decoding KV-cache handling — this PR adds _target_only_kv_layer_mask and shared-backbone KV masking while the referenced PR adds draft_kv_cache_manager support in speculative workers and attention metadata.
  • NVIDIA/TensorRT-LLM#14745: Both PRs modify the speculative-decoding execution path in py_executor_creator.py and spec_metadata handling — the referenced PR refactors SpecMetadata/CUDA-graph dispatch while this PR wires Gemma4/FlashInfer into the metadata-driven one-engine flow.

Suggested labels

LLM API, SW Architecture

Suggested reviewers

  • yuxianq
  • byshiue
  • Superjomn
  • QiJune
  • mikeiovine
  • Wanli-Jiang
  • nv-guomingz
  • zhenhuaw-me
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Warning

Tools execution failed with the following error:

Failed to run tools: Ping-pong health check failed


Comment @coderabbitai help to get the list of available commands and usage tips.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants