feat: add dflash support for gemma-4 #1673
Code Review
This pull request introduces several enhancements and fixes for DFlash speculative decoding. Key updates include support for sliding window attention (SWA) within DFlash layers, embedding normalization for Gemma4-style models, and improved KV cache isolation to prevent draft KVs from overwriting target KVs. Additionally, the Triton kernel for input expansion was refined to handle rejected tokens more robustly, and the scheduler's cache pruning logic was updated. Feedback highlights a need for consistency in auxiliary layer ID mapping within the configuration update logic to match the model runner's implementation.
    # TODO: does this need to be shifted by 1 like in gpu_model_runner?
    aux_layer_ids = config_dict["aux_hidden_state_layer_ids"]
    pre_trained_config["eagle_aux_hidden_state_layer_ids"] = aux_layer_ids
The TODO here should be addressed to ensure consistency with the changes in gpu_model_runner.py. Since gpu_model_runner.py now explicitly shifts DFlash target_layer_ids by 1 when converting them to Eagle-style auxiliary layer IDs, this function should perform the same transformation. Failing to do so will result in incorrect layer indices being used for hidden states when the config is updated via this path.
    - # TODO: does this need to be shifted by 1 like in gpu_model_runner?
    - aux_layer_ids = config_dict["aux_hidden_state_layer_ids"]
    - pre_trained_config["eagle_aux_hidden_state_layer_ids"] = aux_layer_ids
    + # Shift by 1 to convert DFlash's aux layer id semantics to match Eagle
    + aux_layer_ids = [i + 1 for i in config_dict["aux_hidden_state_layer_ids"]]
    + pre_trained_config["eagle_aux_hidden_state_layer_ids"] = aux_layer_ids
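A minimal sketch of keeping the two code paths consistent, assuming the shift is a plain `+1` on every id, as the suggestion above implies. The helper name `dflash_to_eagle_aux_layer_ids` and the wrapper `update_pretrained_config` are hypothetical and do not appear in the PR; only the two config keys come from the diff. Routing both the config update and the model runner through one helper would keep the conversion from drifting between them.

```python
def dflash_to_eagle_aux_layer_ids(target_layer_ids):
    """Hypothetical helper: shift each DFlash layer id by 1.

    Assumes the Eagle-style id is always the DFlash id plus one,
    as described in the review comment.
    """
    return [layer_id + 1 for layer_id in target_layer_ids]


def update_pretrained_config(config_dict, pre_trained_config):
    """Hypothetical stand-in for the config update path under review."""
    if "aux_hidden_state_layer_ids" in config_dict:
        # Same conversion the model runner is said to apply.
        pre_trained_config["eagle_aux_hidden_state_layer_ids"] = (
            dflash_to_eagle_aux_layer_ids(
                config_dict["aux_hidden_state_layer_ids"]
            )
        )
    return pre_trained_config
```

For example, under this assumption a DFlash config listing layers `[1, 9, 17]` would be written out as `[2, 10, 18]`.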
Based on vllm-project/vllm#41703