3434from aphrodite .multimodal .inputs import NestedTensors
3535from aphrodite .transformers_utils .config import set_default_rope_theta
3636from aphrodite .v1 .attention .backend import AttentionType
37+ from aphrodite .v1 .attention .selector import get_attn_backend
38+ from aphrodite .v1 .kv_cache_interface import (
39+ FullAttentionSpec ,
40+ KVCacheSpec ,
41+ SlidingWindowSpec ,
42+ )
3743
3844from .qwen2 import Qwen2MLP as Qwen3MLP
3945from .qwen3 import Qwen3ForCausalLM
4753logger = init_logger (__name__ )
4854
4955
# Attention layer kinds a DFlash draft model may declare in ``layer_types``.
_DFLASH_VALID_LAYER_TYPES = frozenset({"full_attention", "sliding_attention"})


def _get_dflash_layer_types(config: Qwen3Config) -> tuple[str, ...]:
    """Return the validated per-layer attention types of a DFlash draft config.

    When ``config`` has no ``layer_types`` attribute (or it is ``None``), every
    hidden layer is reported as ``"full_attention"``.

    Args:
        config: Draft model config. ``num_hidden_layers`` is always read;
            ``layer_types`` and ``sliding_window`` are read when present.

    Returns:
        One layer-type string per hidden layer, in layer order.

    Raises:
        ValueError: If ``layer_types`` does not have ``num_hidden_layers``
            entries, contains an entry outside ``_DFLASH_VALID_LAYER_TYPES``,
            or uses ``"sliding_attention"`` while ``sliding_window`` is
            missing or falsy.
    """
    layer_types = getattr(config, "layer_types", None)
    if layer_types is None:
        return ("full_attention",) * config.num_hidden_layers

    if len(layer_types) != config.num_hidden_layers:
        raise ValueError(
            f"DFlash layer_types length {len(layer_types)} does not match "
            f"num_hidden_layers {config.num_hidden_layers}."
        )

    # Validate entry by entry instead of via set arithmetic: malformed configs
    # may carry unhashable or non-string entries (e.g. ``None`` or a nested
    # list), which must surface as the ValueError below rather than as a
    # TypeError from ``set()`` / ``sorted()``.
    invalid: list[object] = []
    for layer_type in layer_types:
        is_valid = isinstance(layer_type, str) and layer_type in _DFLASH_VALID_LAYER_TYPES
        if not is_valid and layer_type not in invalid:
            invalid.append(layer_type)
    if invalid:
        # Group by type name first so mixed-type entries stay orderable; for
        # all-string input this matches plain lexicographic sorting.
        invalid.sort(key=lambda entry: (type(entry).__name__, str(entry)))
        raise ValueError(f"Invalid DFlash layer_type(s): {invalid}.")

    if "sliding_attention" in layer_types and not getattr(config, "sliding_window", None):
        raise ValueError(
            "DFlash sliding_attention layers require `sliding_window` in config."
        )
    return tuple(layer_types)
78+
79+
class DFlashAttention(Attention):
    """Attention layer whose KV cache is always allocated as full attention.

    The layer still computes with its configured sliding window; only the
    reported KV cache spec changes. DFlash writes every context KV before
    drafting and cannot evict old context blocks from draft layers, so a
    sliding-window spec is widened to a full-attention one.
    """

    def get_kv_cache_spec(self, aphrodite_config: AphroditeConfig) -> KVCacheSpec | None:
        """Return the parent's KV cache spec, widening sliding-window specs.

        Any spec that is not a ``SlidingWindowSpec`` (including ``None``) is
        passed through unchanged.
        """
        base_spec = super().get_kv_cache_spec(aphrodite_config)
        if not isinstance(base_spec, SlidingWindowSpec):
            return base_spec

        # Carry every shared field over to the full-attention spec.
        widened_fields = {
            "block_size": base_spec.block_size,
            "num_kv_heads": base_spec.num_kv_heads,
            "head_size": base_spec.head_size,
            # Fall back to the key head size when the spec has no separate
            # value head size.
            "head_size_v": getattr(base_spec, "head_size_v", base_spec.head_size),
            "dtype": base_spec.dtype,
            "kv_quant_mode": base_spec.kv_quant_mode,
            "page_size_padded": base_spec.page_size_padded,
        }
        return FullAttentionSpec(**widened_fields)
101+
102+
50103class DFlashQwen3Attention (nn .Module ):
51104 """Attention for DFlash speculative decoding.
52105
@@ -66,6 +119,7 @@ def __init__(
66119 attention_bias : bool = False ,
67120 cache_config : CacheConfig | None = None ,
68121 quant_config : QuantizationConfig | None = None ,
122+ sliding_window : int | None = None ,
69123 prefix : str = "" ,
70124 attn_type : str = AttentionType .DECODER ,
71125 ) -> None :
@@ -109,15 +163,24 @@ def __init__(
109163 max_position = max_position ,
110164 rope_parameters = rope_parameters ,
111165 )
112- self .attn = Attention (
166+ draft_attn_backend = get_attn_backend (
167+ self .head_dim ,
168+ torch .get_default_dtype (),
169+ cache_config .cache_dtype if cache_config is not None else "auto" ,
170+ use_mm_prefix = False ,
171+ attn_type = attn_type ,
172+ )
173+ self .attn = DFlashAttention (
113174 self .num_heads ,
114175 self .head_dim ,
115176 self .scaling ,
116177 num_kv_heads = self .num_kv_heads ,
117178 cache_config = cache_config ,
118179 quant_config = quant_config ,
180+ per_layer_sliding_window = sliding_window ,
119181 prefix = f"{ prefix } .attn" ,
120182 attn_type = attn_type ,
183+ attn_backend = draft_attn_backend ,
121184 )
122185 self .q_norm = RMSNorm (self .head_dim , eps = rms_norm_eps )
123186 self .k_norm = RMSNorm (self .head_dim , eps = rms_norm_eps )
@@ -154,12 +217,17 @@ def __init__(
154217 config : Qwen3Config ,
155218 cache_config : CacheConfig | None = None ,
156219 quant_config : QuantizationConfig | None = None ,
220+ layer_type : str = "full_attention" ,
157221 prefix : str = "" ,
158222 ) -> None :
159223 super ().__init__ ()
160224 self .hidden_size = config .hidden_size
225+ self .layer_type = layer_type
161226 set_default_rope_theta (config , default_theta = 1000000 )
162227 attn_type = AttentionType .DECODER
228+ sliding_window = (
229+ config .sliding_window if layer_type == "sliding_attention" else None
230+ )
163231
164232 self .self_attn = DFlashQwen3Attention (
165233 hidden_size = self .hidden_size ,
@@ -171,6 +239,7 @@ def __init__(
171239 head_dim = getattr (config , "head_dim" , None ),
172240 cache_config = cache_config ,
173241 quant_config = quant_config ,
242+ sliding_window = sliding_window ,
174243 rope_parameters = config .rope_parameters ,
175244 prefix = f"{ prefix } .self_attn" ,
176245 attn_type = attn_type ,
@@ -236,17 +305,30 @@ def __init__(
236305 self .config .hidden_size ,
237306 prefix = maybe_prefix (prefix , "embed_tokens" ),
238307 )
239-
308+ target_config = aphrodite_config .model_config .hf_text_config
309+ self .embed_normalizer : float | None = None
310+ if str (getattr (target_config , "model_type" , "" )).startswith ("gemma4" ):
311+ # Gemma4 scales token embeddings by sqrt(hidden_size). DFlash
312+ # shares the target embeddings, so the draft path must match.
313+ self .embed_normalizer = target_config .hidden_size ** 0.5
314+
315+ self .layer_types = _get_dflash_layer_types (self .config )
240316 self .layers = nn .ModuleList (
241317 [
242318 DFlashQwen3DecoderLayer (
243319 current_aphrodite_config ,
244320 prefix = maybe_prefix (prefix , f"layers.{ layer_idx + start_layer_id } " ),
245321 config = self .config ,
322+ layer_type = self .layer_types [layer_idx ],
246323 )
247324 for layer_idx in range (self .config .num_hidden_layers )
248325 ]
249326 )
327+ self .sliding_attention_layer_names = {
328+ layer .self_attn .attn .layer_name
329+ for layer in self .layers
330+ if layer .layer_type == "sliding_attention"
331+ }
250332 if self .use_aux_hidden_state :
251333 num_features_to_use = self .config .num_hidden_layers
252334 if "target_layer_ids" in drafter_config :
@@ -276,7 +358,8 @@ def __init__(
276358 )
277359
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Embed ``input_ids`` and scale by ``embed_normalizer`` when it is set.

    The embeddings come from ``self.embed_tokens``. They are returned as-is
    when ``self.embed_normalizer`` is ``None`` (or otherwise falsy).
    """
    token_embeds = self.embed_tokens(input_ids)
    if not self.embed_normalizer:
        return token_embeds
    return token_embeds * self.embed_normalizer
280363
281364 def _build_fused_kv_buffers (self ) -> None :
282365 """Build fused weight buffers for precompute_and_store_context_kv.
@@ -504,7 +587,11 @@ def __init__(self, *, aphrodite_config: AphroditeConfig, prefix: str = ""):
504587 self .config .hidden_size ,
505588 prefix = maybe_prefix (prefix , "lm_head" ),
506589 )
507- self .logits_processor = LogitsProcessor (self .config .draft_vocab_size , scale = logit_scale )
590+ self .logits_processor = LogitsProcessor (
591+ self .config .draft_vocab_size ,
592+ scale = logit_scale ,
593+ soft_cap = getattr (self .config , "final_logit_softcapping" , None ),
594+ )
508595 target_vocab_size = aphrodite_config .model_config .get_vocab_size ()
509596 if self .config .draft_vocab_size != target_vocab_size :
510597 self .draft_id_to_target_id = nn .Parameter (
@@ -556,6 +643,10 @@ def precompute_and_store_context_kv(
556643 """Precompute projected + RoPE'd K/V and write to cache."""
557644 self .model .precompute_and_store_context_kv (context_states , context_positions , context_slot_mapping )
558645
@property
def sliding_attention_layer_names(self) -> set[str]:
    """Expose the wrapped model's ``sliding_attention_layer_names`` set."""
    inner_model = self.model
    return inner_model.sliding_attention_layer_names
649+
559650 def combine_hidden_states (
560651 self ,
561652 hidden_states : torch .Tensor ,
0 commit comments