feat: add dflash support for gemma-4#1673

Merged: AlpinDale merged 1 commit into main from feat/gemma4-dflash on May 7, 2026
@AlpinDale AlpinDale merged commit 14e8de1 into main May 7, 2026
1 check failed
@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces several enhancements and fixes for DFlash speculative decoding. Key updates include support for sliding window attention (SWA) within DFlash layers, embedding normalization for Gemma4-style models, and improved KV cache isolation to prevent draft KVs from overwriting target KVs. Additionally, the Triton kernel for input expansion was refined to handle rejected tokens more robustly, and the scheduler's cache pruning logic was updated. Feedback highlights a need for consistency in auxiliary layer ID mapping within the configuration update logic to match the model runner's implementation.

Comment on lines +72 to 74
# TODO: does this need to be shifted by 1 like in gpu_model_runner?
aux_layer_ids = config_dict["aux_hidden_state_layer_ids"]
pre_trained_config["eagle_aux_hidden_state_layer_ids"] = aux_layer_ids
high

The TODO here should be addressed to ensure consistency with the changes in gpu_model_runner.py. Since gpu_model_runner.py now explicitly shifts DFlash target_layer_ids by 1 when converting them to Eagle-style auxiliary layer IDs, this function should perform the same transformation. Failing to do so will result in incorrect layer indices being used for hidden states when the config is updated via this path.

Suggested change
- # TODO: does this need to be shifted by 1 like in gpu_model_runner?
- aux_layer_ids = config_dict["aux_hidden_state_layer_ids"]
- pre_trained_config["eagle_aux_hidden_state_layer_ids"] = aux_layer_ids
+ # Shift by 1 to convert DFlash's aux layer id semantics to match Eagle
+ aux_layer_ids = [i + 1 for i in config_dict["aux_hidden_state_layer_ids"]]
+ pre_trained_config["eagle_aux_hidden_state_layer_ids"] = aux_layer_ids
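
To make the effect of the suggested shift concrete, here is a minimal standalone sketch. It is not the code from this PR: the helper names `shift_dflash_layer_ids` and `apply_aux_layer_ids` and the sample layer ids are hypothetical, and it assumes the ids stored under `aux_hidden_state_layer_ids` use DFlash semantics and have not already been shifted elsewhere.

```python
from typing import Any


def shift_dflash_layer_ids(layer_ids: list[int]) -> list[int]:
    """Add 1 to each id (the DFlash -> Eagle conversion the review proposes)."""
    return [i + 1 for i in layer_ids]


def apply_aux_layer_ids(config_dict: dict[str, Any],
                        pre_trained_config: dict[str, Any]) -> None:
    """Hypothetical stand-in for the config-update path under review."""
    # Assumption: the incoming ids are unshifted DFlash-style ids.
    aux_layer_ids = shift_dflash_layer_ids(
        config_dict["aux_hidden_state_layer_ids"])
    pre_trained_config["eagle_aux_hidden_state_layer_ids"] = aux_layer_ids


# Illustrative values only; real layer ids depend on the draft model config.
config_dict = {"aux_hidden_state_layer_ids": [1, 9, 17]}
pre_trained_config: dict[str, Any] = {}
apply_aux_layer_ids(config_dict, pre_trained_config)
```

If the ids reaching this path were already shifted by the model runner, applying the helper again would offset them twice, so the shift should live in exactly one place.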
