[None][feat] add Gemma4 MTP speculative decoding support#15473
[None][feat] add Gemma4 MTP speculative decoding support#15473Tushar-ml wants to merge 1 commit into
Conversation
Implements Multi-Token Prediction (MTP) speculative decoding for Gemma4 models using the one-engine path where the assistant's Q-only attention layers share the backbone's KV cache via cache_layer_idx. Key changes: modeling_gemma4.py: - Promote Gemma4ForCausalLM to extend SpecDecOneEngineForCausalLM - Add Gemma4MTPHead, Gemma4MTPDecoderLayer, Gemma4MTP classes - Gemma4MTPDecoderLayer uses is_kv_shared=True (Q-only, no KV projections) - cache_layer_idx computed from backbone layer_types per attention type - Gemma4TextModel gains aux_stream_dict for MTPForCausalLM interface compat modeling_speculative.py: - Add gemma4_text case in MTPForCausalLM to create Gemma4MTP layers - Gemma4 uses max_draft_len (not num_nextn_predict_layers) - load_weights maps HF assistant checkpoint layout to TRT-LLM paths - load_weights_from_target_model shares lm_head/embed_tokens with backbone flashinfer.py: - Mirror spec-dec fields from TrtllmAttentionMetadata onto FlashInferAttentionMetadata so the MTP draft loop can use FlashInfer for Gemma4 (needed for head_dim 256/512 and VSWA on non-Blackwell) py_executor_creator.py: - Detect uses_shared_backbone_kv_for_mtp class attribute to route Gemma4 MTP through the correct KV-cache and backend path model_loader.py: - Return model_cls from load_config_and_apply_defaults for class-level capability checks in py_executor_creator _util.py: - _target_only_kv_layer_mask: restrict main KV cache to backbone layers when draft layers share backbone KV or use a separate draft KV manager speculative/utils.py: - Disable separate draft KV cache for Gemma4 MTP one-engine mode (assistant reads backbone KV via Q-only attention, no separate pool needed) modeling_gemma4mm.py: - Propagate flashinfer_supports_one_engine_spec_decode and uses_shared_backbone_kv_for_mtp from Gemma4ForCausalLM examples/auto_deploy/model_registry/configs/gemma4_dense.yaml: - Update world_size, batch sizes, and token limits for 4-GPU serving Signed-off-by: varadrane1707 <varad@simplismart.tech>
|
Caution Review failedThe pull request is closed. ℹ️ Recent review info⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Enterprise Run ID: 📒 Files selected for processing (9)
📝 WalkthroughWalkthroughAdds Gemma4 one-engine MTP speculative decoding with shared-backbone Q-only attention. Introduces ChangesGemma4 MTP One-Engine Speculative Decoding
Sequence Diagram(s)sequenceDiagram
rect rgba(30, 100, 200, 0.5)
Note over create_py_executor: Executor setup
create_py_executor->>_load_config_and_create_checkpoint_loader: create loader
_load_config_and_create_checkpoint_loader->>ModelLoader: load_config_and_apply_defaults()
ModelLoader-->>_load_config_and_create_checkpoint_loader: (llm_args, model_cls)
_load_config_and_create_checkpoint_loader-->>create_py_executor: (llm_args, checkpoint_loader, model_cls)
create_py_executor->>create_py_executor: model_cls.uses_shared_backbone_kv_for_mtp → disable separate draft KV
create_py_executor->>create_py_executor: SM check → enforce FLASHINFER / TRTLLM compat
create_py_executor->>KvCacheCreator: create KV cache manager
KvCacheCreator->>KvCacheCreator: _target_only_kv_layer_mask → backbone-only mask
end
rect rgba(30, 180, 100, 0.5)
Note over Gemma4ForCausalLM: Speculative decode forward
Gemma4ForCausalLM->>Gemma4TextModel: forward(hidden_states)
Gemma4ForCausalLM->>Gemma4MTP: forward(input_ids, hidden_states, spec_metadata)
Gemma4MTP->>Gemma4MTPDecoderLayer: forward per layer (Q-only KV-shared attn)
Gemma4MTP->>Gemma4MTPHead: get_last_token_states → shared lm_head logits
Gemma4MTPHead-->>Gemma4ForCausalLM: draft logits
Gemma4ForCausalLM->>FlashInferAttention: forward_impl → _forward_triton_paged
end
Estimated code review effort🎯 5 (Critical) | ⏱️ ~120 minutes Possibly related PRs
Suggested labels
Suggested reviewers
✨ Finishing Touches🧪 Generate unit tests (beta)
Warning Tools execution failed with the following error: Failed to run tools: Ping-pong health check failed Comment |
Summary
Implements Multi-Token Prediction (MTP) speculative decoding for Gemma4 models using the one-engine path. The Gemma4 assistant model uses Q-only attention layers that share the backbone KV cache via
cache_layer_idx, making it incompatible with the two-engine (separate draft KV pool) approach.Gemma4MTP/Gemma4MTPDecoderLayer/Gemma4MTPHead: New MTP layer classes. EachGemma4MTPDecoderLayerusesis_kv_shared=True(no K/V projections) and reads backbone KV viacache_layer_idx, which is computed as the last backbone layer with its own KV cache of matching attention type (sliding/full).Gemma4ForCausalLMpromoted fromDecoderModelForCausalLM→SpecDecOneEngineForCausalLM.MTPForCausalLMextended with agemma4_textcase: usesmax_draft_len(notnum_nextn_predict_layers), loads assistant checkpoint weights with HF→TRT-LLM key remapping.FlashInferAttentionMetadatagains spec-dec fields mirrored fromTrtllmAttentionMetadataso MTP draft loop buffers can be allocated when FlashInfer is used (required for Gemma4 VSWA + head_dim 256/512 on non-Blackwell GPUs).py_executor_creator/_util/speculative/utils: detectuses_shared_backbone_kv_for_mtpclass attribute to disable separate draft KV pool and route Gemma4 MTP through the correct KV-cache path.model_loader:load_config_and_apply_defaultsnow returnsmodel_clsfor class-level capability checks.Test plan
MTPDecodingConfig(max_draft_len=1, speculative_model="google/gemma-4-31b-it-assistant", mtp_eagle_one_model=True)on H100 (4× TP)free_gpu_memory_fraction=0.6attn_backend=FLASHINFER) on Blackwell; TRTLLM backend on Hopperpytest tests/unittest/_torch/speculative/ -k mtp🤖 Generated with Claude Code
Summary by CodeRabbit
New Features
Performance Improvements