Skip to content

Commit 14e8de1

Browse files
authored
feat: add dflash support for gemma-4 (#1673)
1 parent c16f370 commit 14e8de1

9 files changed

Lines changed: 269 additions & 34 deletions

File tree

aphrodite/config/speculative.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -936,6 +936,10 @@ def max_num_new_slots_for_drafting(self) -> int:
936936
def use_eagle(self) -> bool:
937937
return self.method in ("eagle", "eagle3", "mtp", "dflash")
938938

939+
def requires_eagle_cache_drop(self) -> bool:
940+
"""Whether prefix cache hits must drop one block for hidden states."""
941+
return self.use_eagle() and not self.use_dflash()
942+
939943
def use_dflash(self) -> bool:
940944
return self.method == "dflash"
941945

aphrodite/model_executor/models/qwen3_dflash.py

Lines changed: 95 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@
3434
from aphrodite.multimodal.inputs import NestedTensors
3535
from aphrodite.transformers_utils.config import set_default_rope_theta
3636
from aphrodite.v1.attention.backend import AttentionType
37+
from aphrodite.v1.attention.selector import get_attn_backend
38+
from aphrodite.v1.kv_cache_interface import (
39+
FullAttentionSpec,
40+
KVCacheSpec,
41+
SlidingWindowSpec,
42+
)
3743

3844
from .qwen2 import Qwen2MLP as Qwen3MLP
3945
from .qwen3 import Qwen3ForCausalLM
@@ -47,6 +53,53 @@
4753
logger = init_logger(__name__)
4854

4955

56+
_DFLASH_VALID_LAYER_TYPES = frozenset({"full_attention", "sliding_attention"})
57+
58+
59+
def _get_dflash_layer_types(config: Qwen3Config) -> tuple[str, ...]:
60+
layer_types = getattr(config, "layer_types", None)
61+
if layer_types is None:
62+
return ("full_attention",) * config.num_hidden_layers
63+
if len(layer_types) != config.num_hidden_layers:
64+
raise ValueError(
65+
f"DFlash layer_types length {len(layer_types)} does not match "
66+
f"num_hidden_layers {config.num_hidden_layers}."
67+
)
68+
invalid = set(layer_types) - _DFLASH_VALID_LAYER_TYPES
69+
if invalid:
70+
raise ValueError(f"Invalid DFlash layer_type(s): {sorted(invalid)}.")
71+
if "sliding_attention" in layer_types and not getattr(
72+
config, "sliding_window", None
73+
):
74+
raise ValueError(
75+
"DFlash sliding_attention layers require `sliding_window` in config."
76+
)
77+
return tuple(layer_types)
78+
79+
80+
class DFlashAttention(Attention):
81+
"""Attention with DFlash-specific KV allocation semantics.
82+
83+
The compute path keeps the layer's configured sliding window. The KV cache
84+
spec is widened to full attention because DFlash writes every context KV
85+
before drafting and cannot evict old context blocks from draft layers.
86+
"""
87+
88+
def get_kv_cache_spec(self, aphrodite_config: AphroditeConfig) -> KVCacheSpec | None:
89+
spec = super().get_kv_cache_spec(aphrodite_config)
90+
if isinstance(spec, SlidingWindowSpec):
91+
return FullAttentionSpec(
92+
block_size=spec.block_size,
93+
num_kv_heads=spec.num_kv_heads,
94+
head_size=spec.head_size,
95+
head_size_v=getattr(spec, "head_size_v", spec.head_size),
96+
dtype=spec.dtype,
97+
kv_quant_mode=spec.kv_quant_mode,
98+
page_size_padded=spec.page_size_padded,
99+
)
100+
return spec
101+
102+
50103
class DFlashQwen3Attention(nn.Module):
51104
"""Attention for DFlash speculative decoding.
52105
@@ -66,6 +119,7 @@ def __init__(
66119
attention_bias: bool = False,
67120
cache_config: CacheConfig | None = None,
68121
quant_config: QuantizationConfig | None = None,
122+
sliding_window: int | None = None,
69123
prefix: str = "",
70124
attn_type: str = AttentionType.DECODER,
71125
) -> None:
@@ -109,15 +163,24 @@ def __init__(
109163
max_position=max_position,
110164
rope_parameters=rope_parameters,
111165
)
112-
self.attn = Attention(
166+
draft_attn_backend = get_attn_backend(
167+
self.head_dim,
168+
torch.get_default_dtype(),
169+
cache_config.cache_dtype if cache_config is not None else "auto",
170+
use_mm_prefix=False,
171+
attn_type=attn_type,
172+
)
173+
self.attn = DFlashAttention(
113174
self.num_heads,
114175
self.head_dim,
115176
self.scaling,
116177
num_kv_heads=self.num_kv_heads,
117178
cache_config=cache_config,
118179
quant_config=quant_config,
180+
per_layer_sliding_window=sliding_window,
119181
prefix=f"{prefix}.attn",
120182
attn_type=attn_type,
183+
attn_backend=draft_attn_backend,
121184
)
122185
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
123186
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
@@ -154,12 +217,17 @@ def __init__(
154217
config: Qwen3Config,
155218
cache_config: CacheConfig | None = None,
156219
quant_config: QuantizationConfig | None = None,
220+
layer_type: str = "full_attention",
157221
prefix: str = "",
158222
) -> None:
159223
super().__init__()
160224
self.hidden_size = config.hidden_size
225+
self.layer_type = layer_type
161226
set_default_rope_theta(config, default_theta=1000000)
162227
attn_type = AttentionType.DECODER
228+
sliding_window = (
229+
config.sliding_window if layer_type == "sliding_attention" else None
230+
)
163231

164232
self.self_attn = DFlashQwen3Attention(
165233
hidden_size=self.hidden_size,
@@ -171,6 +239,7 @@ def __init__(
171239
head_dim=getattr(config, "head_dim", None),
172240
cache_config=cache_config,
173241
quant_config=quant_config,
242+
sliding_window=sliding_window,
174243
rope_parameters=config.rope_parameters,
175244
prefix=f"{prefix}.self_attn",
176245
attn_type=attn_type,
@@ -236,17 +305,30 @@ def __init__(
236305
self.config.hidden_size,
237306
prefix=maybe_prefix(prefix, "embed_tokens"),
238307
)
239-
308+
target_config = aphrodite_config.model_config.hf_text_config
309+
self.embed_normalizer: float | None = None
310+
if str(getattr(target_config, "model_type", "")).startswith("gemma4"):
311+
# Gemma4 scales token embeddings by sqrt(hidden_size). DFlash
312+
# shares the target embeddings, so the draft path must match.
313+
self.embed_normalizer = target_config.hidden_size**0.5
314+
315+
self.layer_types = _get_dflash_layer_types(self.config)
240316
self.layers = nn.ModuleList(
241317
[
242318
DFlashQwen3DecoderLayer(
243319
current_aphrodite_config,
244320
prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
245321
config=self.config,
322+
layer_type=self.layer_types[layer_idx],
246323
)
247324
for layer_idx in range(self.config.num_hidden_layers)
248325
]
249326
)
327+
self.sliding_attention_layer_names = {
328+
layer.self_attn.attn.layer_name
329+
for layer in self.layers
330+
if layer.layer_type == "sliding_attention"
331+
}
250332
if self.use_aux_hidden_state:
251333
num_features_to_use = self.config.num_hidden_layers
252334
if "target_layer_ids" in drafter_config:
@@ -276,7 +358,8 @@ def __init__(
276358
)
277359

278360
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
279-
return self.embed_tokens(input_ids)
361+
embeds = self.embed_tokens(input_ids)
362+
return embeds * self.embed_normalizer if self.embed_normalizer else embeds
280363

281364
def _build_fused_kv_buffers(self) -> None:
282365
"""Build fused weight buffers for precompute_and_store_context_kv.
@@ -504,7 +587,11 @@ def __init__(self, *, aphrodite_config: AphroditeConfig, prefix: str = ""):
504587
self.config.hidden_size,
505588
prefix=maybe_prefix(prefix, "lm_head"),
506589
)
507-
self.logits_processor = LogitsProcessor(self.config.draft_vocab_size, scale=logit_scale)
590+
self.logits_processor = LogitsProcessor(
591+
self.config.draft_vocab_size,
592+
scale=logit_scale,
593+
soft_cap=getattr(self.config, "final_logit_softcapping", None),
594+
)
508595
target_vocab_size = aphrodite_config.model_config.get_vocab_size()
509596
if self.config.draft_vocab_size != target_vocab_size:
510597
self.draft_id_to_target_id = nn.Parameter(
@@ -556,6 +643,10 @@ def precompute_and_store_context_kv(
556643
"""Precompute projected + RoPE'd K/V and write to cache."""
557644
self.model.precompute_and_store_context_kv(context_states, context_positions, context_slot_mapping)
558645

646+
@property
647+
def sliding_attention_layer_names(self) -> set[str]:
648+
return self.model.sliding_attention_layer_names
649+
559650
def combine_hidden_states(
560651
self,
561652
hidden_states: torch.Tensor,

aphrodite/transformers_utils/configs/speculators/algos.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,16 @@ def update_dflash(config_dict: dict, pre_trained_config: dict) -> None:
6060
pre_trained_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
6161
if config_dict.get("target_hidden_size") is not None:
6262
pre_trained_config["target_hidden_size"] = config_dict["target_hidden_size"]
63+
for key in (
64+
"layer_types",
65+
"use_sliding_window",
66+
"sliding_window",
67+
"max_window_layers",
68+
):
69+
if key in config_dict:
70+
pre_trained_config[key] = config_dict[key]
6371

72+
# TODO: does this need to be shifted by 1 like in gpu_model_runner?
6473
aux_layer_ids = config_dict["aux_hidden_state_layer_ids"]
6574
pre_trained_config["eagle_aux_hidden_state_layer_ids"] = aux_layer_ids
6675

aphrodite/v1/attention/backends/triton_attn.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,11 @@ def __init__(
135135

136136
model_config = aphrodite_config.model_config
137137
self.num_heads_q = model_config.get_num_attention_heads(aphrodite_config.parallel_config)
138-
self.num_heads_kv = model_config.get_num_kv_heads(aphrodite_config.parallel_config)
139-
self.headdim = model_config.get_head_size()
138+
# Some models (e.g. Gemma4) use different KV/head geometry for
139+
# different attention layer groups, so size decode metadata from the
140+
# actual KV cache spec instead of the model-wide defaults.
141+
self.num_heads_kv = kv_cache_spec.num_kv_heads
142+
self.headdim = kv_cache_spec.head_size
140143

141144
# Check if CUDA Graphs are enabled for decode
142145
self.decode_cudagraph_enabled = self.aphrodite_config.compilation_config.cudagraph_mode in (

aphrodite/v1/core/kv_cache_utils.py

Lines changed: 81 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import hashlib
77
import math
88
import os
9+
import re
910
from collections import defaultdict
1011
from collections.abc import Callable, Iterable, Iterator, Sequence
1112
from dataclasses import dataclass, replace
@@ -78,6 +79,8 @@ def maybe_convert_block_hash(hash_bytes: BlockHash) -> ExternalBlockHash:
7879

7980
logger = init_logger(__name__)
8081

82+
_LAYER_INDEX_RE = re.compile(r"(?:^|\.)layers\.(\d+)(?:\.|$)")
83+
8184
# The hash seed for the first block of any prefix block sequence.
8285
#
8386
# We use a random value to avoid hash collisions or PYTHONHASHSEED environment
@@ -846,7 +849,10 @@ def may_override_num_blocks(aphrodite_config: AphroditeConfig, num_blocks: int)
846849
return num_blocks
847850

848851

849-
def _pool_bytes_per_block(kv_cache_groups: list[KVCacheGroupSpec]) -> int:
852+
def _pool_bytes_per_block(
853+
kv_cache_groups: list[KVCacheGroupSpec],
854+
aphrodite_config: AphroditeConfig | None = None,
855+
) -> int:
850856
"""
851857
Bytes consumed by one block in the worker's shared KV cache pool, mirroring
852858
the divisor used by `get_kv_cache_config_from_groups` to convert
@@ -863,7 +869,22 @@ def _pool_bytes_per_block(kv_cache_groups: list[KVCacheGroupSpec]) -> int:
863869
cast(UniformTypeKVCacheSpecs, g.kv_cache_spec).get_num_layer_tuples() for g in kv_cache_groups
864870
)
865871
return layer_tuple_page_bytes * num_layer_tuples
866-
group_size = max(len(g.layer_names) for g in kv_cache_groups)
872+
if aphrodite_config is not None:
873+
isolated_group_ids = _get_dflash_isolated_group_ids(
874+
aphrodite_config, kv_cache_groups
875+
)
876+
shared_group_size = max(
877+
(
878+
len(group.layer_names)
879+
for group_id, group in enumerate(kv_cache_groups)
880+
if group_id not in isolated_group_ids
881+
),
882+
default=0,
883+
)
884+
isolated_layers = sum(len(kv_cache_groups[group_id].layer_names) for group_id in isolated_group_ids)
885+
group_size = shared_group_size + isolated_layers
886+
else:
887+
group_size = max(len(g.layer_names) for g in kv_cache_groups)
867888
page_size = get_uniform_page_size([g.kv_cache_spec for g in kv_cache_groups])
868889
return page_size * group_size
869890

@@ -897,6 +918,35 @@ def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int:
897918
return page_sizes.pop()
898919

899920

921+
def _get_dflash_isolated_group_ids(
922+
aphrodite_config: AphroditeConfig,
923+
kv_cache_groups: list[KVCacheGroupSpec],
924+
) -> set[int]:
925+
spec_config = aphrodite_config.speculative_config
926+
if spec_config is None or spec_config.method != "dflash":
927+
return set()
928+
929+
try:
930+
target_num_layers = aphrodite_config.model_config.get_num_layers(
931+
aphrodite_config.parallel_config
932+
)
933+
except Exception:
934+
return set()
935+
936+
group_ids: set[int] = set()
937+
for group_id, group in enumerate(kv_cache_groups):
938+
layer_indices: list[int] = []
939+
for layer_name in group.layer_names:
940+
match = _LAYER_INDEX_RE.search(layer_name)
941+
if match is None:
942+
layer_indices = []
943+
break
944+
layer_indices.append(int(match.group(1)))
945+
if layer_indices and all(idx >= target_num_layers for idx in layer_indices):
946+
group_ids.add(group_id)
947+
return group_ids
948+
949+
900950
def _get_kv_cache_groups_uniform_spec(
901951
kv_cache_specs: dict[str, KVCacheSpec],
902952
) -> list[KVCacheGroupSpec]:
@@ -1222,18 +1272,41 @@ def get_kv_cache_config_from_groups(
12221272
# (sw.1, padding) will be: (group_size = 2)
12231273
# full.0, sw.0, sw.1: share a Tensor with size=available_memory//2
12241274
# full.1, sw.2: share another Tensor with size=available_memory//2
1225-
group_size = max(len(group.layer_names) for group in kv_cache_groups)
1275+
# DFlash writes draft context KVs directly into cache using the draft
1276+
# block table. Do not row-share those tensors with target KV groups, or
1277+
# overlapping physical block ids can overwrite target KVs under batching.
1278+
isolated_group_ids = _get_dflash_isolated_group_ids(
1279+
aphrodite_config, kv_cache_groups
1280+
)
1281+
shared_groups = [
1282+
group
1283+
for group_id, group in enumerate(kv_cache_groups)
1284+
if group_id not in isolated_group_ids
1285+
]
1286+
isolated_layer_names = [
1287+
layer_name
1288+
for group_id in sorted(isolated_group_ids)
1289+
for layer_name in kv_cache_groups[group_id].layer_names
1290+
]
1291+
shared_group_size = (
1292+
max(len(group.layer_names) for group in shared_groups)
1293+
if shared_groups
1294+
else 0
1295+
)
1296+
group_size = shared_group_size + len(isolated_layer_names)
12261297

12271298
page_size = get_uniform_page_size([group.kv_cache_spec for group in kv_cache_groups])
12281299
assert group_size > 0, "group_size must be greater than 0"
12291300
num_blocks = get_num_blocks(aphrodite_config, group_size, available_memory, page_size)
12301301
kv_cache_tensors = []
1231-
for i in range(group_size):
1302+
for i in range(shared_group_size):
12321303
shared_by = []
1233-
for j in range(len(kv_cache_groups)):
1234-
if i < len(kv_cache_groups[j].layer_names):
1235-
shared_by.append(kv_cache_groups[j].layer_names[i])
1304+
for group in shared_groups:
1305+
if i < len(group.layer_names):
1306+
shared_by.append(group.layer_names[i])
12361307
kv_cache_tensors.append(KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by))
1308+
for layer_name in isolated_layer_names:
1309+
kv_cache_tensors.append(KVCacheTensor(size=page_size * num_blocks, shared_by=[layer_name]))
12371310

12381311
return KVCacheConfig(
12391312
num_blocks=num_blocks,
@@ -1839,7 +1912,7 @@ def get_kv_cache_configs(
18391912
if not groups:
18401913
adjusted_memory.append(avail_mem)
18411914
continue
1842-
bytes_per_block = _pool_bytes_per_block(groups)
1915+
bytes_per_block = _pool_bytes_per_block(groups, aphrodite_config)
18431916
logger.info(
18441917
"Overriding num_gpu_blocks=%d with num_gpu_blocks_override=%d",
18451918
avail_mem // bytes_per_block,

0 commit comments

Comments
 (0)