Skip to content

Commit 6467da5

Browse files
authored
Merge pull request #232 from howard0su/gemma4
Gemma4: Full DFlash Integration (Speculative Decode + BSA Prefill + Prefix Cache)
2 parents 839f912 + a013349 commit 6467da5

27 files changed

Lines changed: 2592 additions & 280 deletions

dflash/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ add_library(dflash_common STATIC
215215
src/qwen35/qwen35_target_graph.cpp
216216
src/draft/draft_gguf_loader.cpp
217217
src/draft/draft_safetensors_loader.cpp
218-
src/draft/draft_dflash_graph.cpp
218+
src/draft/draft_graph.cpp
219219
src/qwen3/qwen3_drafter.cpp
220220
src/qwen3/qwen3_loader.cpp
221221
src/qwen3/qwen3_graph.cpp
@@ -225,6 +225,7 @@ add_library(dflash_common STATIC
225225
src/gemma4/gemma4_graph.cpp
226226
src/gemma4/gemma4_backend.cpp
227227
src/gemma4/gemma4_daemon.cpp
228+
src/gemma4/gemma4_dflash_target.cpp
228229
src/flashprefill_q8.cpp
229230
src/kv_cache.cpp
230231
src/kv_quant.cpp

dflash/docs/SPEC_PREFILL.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ src/
8989
qwen35_target_graph.cpp Qwen3.5/3.6 target graph (ggml)
9090
gguf_target_loader.cpp Qwen3.5 target GGUF loader
9191
draft/ Special DFlash draft model code
92-
draft_dflash_graph.cpp DFlash speculative draft head
92+
draft_graph.cpp DFlash speculative draft head
9393
draft_gguf_loader.cpp Draft GGUF loader
9494
draft_safetensors_loader.cpp Draft safetensors loader
9595
laguna/ Laguna target + daemon model code

dflash/include/dflash27b.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,12 @@ extern "C" {
2424
// Qwen3.5-27B qwen35 hybrid uses 24 Q heads, 4 KV heads, 256 head_dim, which
2525
// live in `src/internal.h` (n_embd_head_k/v, N_HEAD, N_HEAD_KV). Naming is
2626
// historical — do not change without updating draft_safetensors_loader.cpp +
27-
// draft_dflash_graph.cpp which consume these as draft-side constants.
27+
// draft_graph.cpp which consume these as draft-side constants.
2828
#define DFLASH27B_TARGET_N_HEADS 32
2929
#define DFLASH27B_TARGET_N_KV_HEADS 8
3030
#define DFLASH27B_TARGET_HEAD_DIM 128
3131
#define DFLASH27B_TARGET_INTERMEDIATE 17408
3232
#define DFLASH27B_TARGET_VOCAB 248320
33-
#define DFLASH27B_ROPE_THETA 10000000.0f
3433
#define DFLASH27B_RMS_EPS 1e-6f
3534

3635
#define DFLASH27B_DRAFT_LAYERS 5

dflash/src/common/backend_factory.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,13 @@ std::unique_ptr<ModelBackend> create_backend(const BackendArgs & args) {
8888

8989
} else if (arch == "gemma4") {
9090
Gemma4BackendConfig gcfg;
91-
gcfg.model_path = args.model_path;
92-
gcfg.device = args.device;
93-
gcfg.stream_fd = args.stream_fd;
94-
gcfg.chunk = args.chunk;
91+
gcfg.model_path = args.model_path;
92+
gcfg.draft_path = args.draft_path;
93+
gcfg.draft_gpu = args.draft_gpu;
94+
gcfg.draft_ctx_max = args.draft_ctx_max;
95+
gcfg.device = args.device;
96+
gcfg.stream_fd = args.stream_fd;
97+
gcfg.chunk = args.chunk;
9598

9699
auto backend = std::make_unique<Gemma4Backend>(gcfg);
97100
if (!backend->init()) {

dflash/src/common/dflash_draft_graph.cpp

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,26 @@
22
#include "draft/draft_graph.h" // DraftGraphInputs, DraftGraphOutputs, build_draft_graph
33

44
#include "ggml-alloc.h"
5+
#include "ggml-backend.h"
56

7+
#include <cstdint>
68
#include <cstdio>
9+
#include <vector>
710

811
namespace dflash::common {
912

13+
// Minimum alignment required by ggml flash_attn_ext for mask rows.
14+
static constexpr int MASK_KV_PAD = 32;
15+
16+
// Round x up to the next multiple of a (for x >= 0 and a > 0).
// Used to pad the mask's KV length to a multiple of MASK_KV_PAD.
static inline int mask_align_up(int x, int a) {
    const int n_blocks = (x + a - 1) / a;  // ceil(x / a) for non-negative x
    return n_blocks * a;
}
17+
18+
// Check whether any layer in the draft is SWA.
19+
static bool draft_has_swa_layers(const DraftWeights & dw) {
20+
for (int i = 0; i < dw.n_layer; i++)
21+
if (dw.layers[i].is_swa) return true;
22+
return false;
23+
}
24+
1025
// Build draft graph at a given ctx_len into sg. Does NOT touch sg.alloc.
1126
// mirror_view: if true, uses a view into mirror->target_feat at slot0.
1227
static bool build_draft_graph_internal(
@@ -56,6 +71,21 @@ static bool build_draft_graph_internal(
5671
ggml_set_name(sg.positions_k, "positions_k");
5772
ggml_set_input(sg.positions_k);
5873

74+
// Causal mask for SWA layers (if any).
75+
// Shape: [kv_pad, q_len] F16 (directly, no cast needed — matches attn_masks.h pattern).
76+
sg.attn_mask = nullptr;
77+
const bool has_swa = draft_has_swa_layers(dw);
78+
if (has_swa) {
79+
// SWA layers' effective KV length (windowed or full ctx)
80+
const bool swa_active = dw.swa_window > 0 && ctx_len > dw.swa_window;
81+
const int eff_ctx = swa_active ? dw.swa_window : ctx_len;
82+
const int eff_total_k = eff_ctx + q_len;
83+
const int kv_pad = mask_align_up(eff_total_k, MASK_KV_PAD);
84+
sg.attn_mask = ggml_new_tensor_2d(sg.ctx, GGML_TYPE_F16, kv_pad, q_len);
85+
ggml_set_name(sg.attn_mask, "causal_mask_swa");
86+
ggml_set_input(sg.attn_mask);
87+
}
88+
5989
sg.gf = ggml_new_graph_custom(sg.ctx, 4096, false);
6090

6191
DraftGraphInputs gi{};
@@ -65,6 +95,7 @@ static bool build_draft_graph_internal(
6595
gi.positions_q = sg.positions;
6696
gi.positions_k = sg.positions_k;
6797
gi.lm_head = lm_head;
98+
gi.causal_mask_swa = sg.attn_mask;
6899
DraftGraphOutputs go = build_draft_graph(sg.ctx, dw, gi);
69100
sg.hidden_states = go.hidden_states;
70101
sg.logits = go.logits;
@@ -125,7 +156,35 @@ bool build_draft_step(
125156
return false;
126157
}
127158

128-
return ggml_gallocr_alloc_graph(sg.alloc, sg.gf);
159+
if (!ggml_gallocr_alloc_graph(sg.alloc, sg.gf)) {
160+
return false;
161+
}
162+
163+
// Fill causal mask data for SWA layers (after allocation gives memory to the tensor).
164+
if (sg.attn_mask) {
165+
const int q_len = dw.block_size;
166+
const bool swa_active = dw.swa_window > 0 && ctx_len > dw.swa_window;
167+
const int eff_ctx = swa_active ? dw.swa_window : ctx_len;
168+
const int eff_total_k = eff_ctx + q_len;
169+
const int kv_pad = mask_align_up(eff_total_k, MASK_KV_PAD);
170+
171+
// Build causal mask in F16 directly (same pattern as attn_masks.h):
172+
// Context keys (k < eff_ctx): always visible.
173+
// Noise keys (k = eff_ctx + j): visible if j <= q (causal).
174+
static constexpr uint16_t ZERO = 0x0000;
175+
static constexpr uint16_t NEG_INF = 0xFC00;
176+
std::vector<uint16_t> mask_data((size_t)kv_pad * q_len, NEG_INF);
177+
for (int q = 0; q < q_len; q++) {
178+
for (int k = 0; k < eff_ctx; k++)
179+
mask_data[(size_t)q * kv_pad + k] = ZERO;
180+
for (int j = 0; j <= q; j++)
181+
mask_data[(size_t)q * kv_pad + (eff_ctx + j)] = ZERO;
182+
}
183+
ggml_backend_tensor_set(sg.attn_mask, mask_data.data(), 0,
184+
sizeof(uint16_t) * mask_data.size());
185+
}
186+
187+
return true;
129188
}
130189

131190
} // namespace dflash::common

dflash/src/draft/draft_gguf_loader.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// on the CUDA backend.
33
//
44
// This is the Q8_0-quantized counterpart of draft_safetensors_loader.cpp. The
5-
// draft graph builder (draft_dflash_graph.cpp) doesn't care about tensor storage
5+
// draft graph builder (draft_graph.cpp) doesn't care about tensor storage
66
// types — ggml's ggml_mul_mat handles Q8_0 × F32 dequantization transparently.
77
//
88
// GGUF arch: "qwen35-dflash-draft" (from convert_dflash_to_gguf.py /
@@ -159,6 +159,12 @@ bool load_draft_gguf(const std::string & path,
159159
std::snprintf(key, sizeof(key), "%s.%s", A, suffix);
160160
return get_u32_or(gctx, key, fallback);
161161
};
162+
auto read_f32 = [&](const char * suffix, float fallback) -> float {
163+
std::snprintf(key, sizeof(key), "%s.%s", A, suffix);
164+
int64_t id = gguf_find_key(gctx, key);
165+
if (id < 0) return fallback;
166+
return gguf_get_val_f32(gctx, id);
167+
};
162168

163169
const uint32_t n_embd = read_u32("embedding_length", 0);
164170
const uint32_t n_layer = read_u32("block_count", 0);
@@ -235,6 +241,10 @@ bool load_draft_gguf(const std::string & path,
235241
out.head_dim = (int)head_dim;
236242
out.n_embd = (int)n_embd;
237243
out.n_ff = (int)n_ff;
244+
out.rope_theta = read_f32("rope.freq_base", 0.0f);
245+
if (out.rope_theta == 0.0f) {
246+
fprintf(stderr, "[draft-gguf] WARNING: rope.freq_base not found in GGUF, draft RoPE will be wrong\n");
247+
}
238248
out.layers.assign((size_t)n_layer, DraftLayer{});
239249

240250
auto g = [&](const char * name) -> ggml_tensor * {
Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,29 @@
11
// Builds a ggml compute graph for one forward pass of the DFlash draft
2-
// (5-layer non-causal Qwen3-flavored block-diffusion model).
2+
// (5-layer Qwen3-flavored block-diffusion model).
33
//
44
// Stateless: no KV cache. Each call takes:
5-
// - noise_embed [hidden, q_len, 1] bf16 (target.tok_embd on [last_tok, MASK*15])
6-
// - target_hidden_cat [5*hidden, ctx_len, 1] bf16 (5 target layers concat along features)
5+
// - noise_embed [hidden, q_len, 1] f32 (target.tok_embd on [last_tok, MASK*15])
6+
// - target_hidden_cat [N*hidden, ctx_len, 1] f32 (N target layers concat along features)
77
// - positions_q [q_len] i32 values [ctx_len..ctx_len+q_len-1]
88
// - positions_k [ctx_len+q_len] i32 values [0..ctx_len+q_len-1]
9+
// - causal_mask_swa [kv_pad, q_len] f16 (optional; causal mask for SWA layers)
910
// and returns:
10-
// - hidden_states [hidden, q_len, 1] bf16 (final RMSNorm; NO lm_head here)
11+
// - hidden_states [hidden, q_len, 1] f32 (final RMSNorm; NO lm_head here)
1112
//
1213
// The caller projects `hidden_states` through the TARGET's lm_head separately
1314
// (the draft has no lm_head of its own, it shares the target's).
1415
//
15-
// Semantics match megaqwen3_27b_dflash/reference/dflash_reference.py exactly:
16+
// Semantics:
1617
// - fc @ target_hidden_cat -> rms_norm with hidden_norm -> target_feat
17-
// - Per layer (non-causal):
18+
// - Per layer:
1819
// h_norm = rms_norm(h) * input_layernorm
1920
// Q = wq @ h_norm -> per-head q_norm
2021
// K_ctx/V_ctx = wk/wv @ target_feat
2122
// K_noi/V_noi = wk/wv @ h_norm
2223
// K = concat[K_ctx, K_noi] -> per-head k_norm
2324
// V = concat[V_ctx, V_noi]
24-
// RoPE(Q, positions_q); RoPE(K, positions_k) (NEOX style, theta=10M)
25-
// attn = flash_attn_ext(Q, K, V, mask=null, scale=1/sqrt(head_dim)) non-causal
25+
// RoPE(Q, positions_q); RoPE(K, positions_k) (NEOX style)
26+
// attn = flash_attn_ext(Q, K, V, mask, scale) SWA=causal, full=non-causal
2627
// h += wo @ attn
2728
// h_norm = rms_norm(h) * post_attention_layernorm
2829
// h += w_down @ (silu(w_gate @ h_norm) * (w_up @ h_norm))
@@ -46,7 +47,7 @@ DraftGraphOutputs build_draft_graph(
4647
const int n_kv = w.n_head_kv;
4748
const int head_dim = w.head_dim;
4849
const float eps = DFLASH27B_RMS_EPS;
49-
const float rope_base = DFLASH27B_ROPE_THETA;
50+
const float rope_base = w.rope_theta;
5051

5152
// ── 1. Feature fusion: target_feat = rms_norm(fc @ target_hidden_cat, hidden_norm)
5253
// fc: [N*hidden, hidden] (ggml: ne[0]=N*hidden, ne[1]=hidden; N = number of concatenated target layers)
@@ -134,9 +135,10 @@ DraftGraphOutputs build_draft_graph(
134135
V = ggml_permute(ctx, V, 0, 2, 1, 3); // [head_dim, eff_total_k, n_kv, 1]
135136
V = ggml_cont (ctx, V);
136137

137-
// ── 2f. Non-causal flash attention; GQA broadcast handled internally.
138+
// ── 2f. Attention: causal for SWA layers, non-causal for full layers.
138139
const float scale = 1.0f / std::sqrt((float)head_dim);
139-
ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, /*mask=*/nullptr,
140+
ggml_tensor * mask = (L.is_swa && in.causal_mask_swa) ? in.causal_mask_swa : nullptr;
141+
ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, mask,
140142
scale, /*max_bias=*/0.0f,
141143
/*logit_softcap=*/0.0f);
142144
// attn result: [n_embd_v=head_dim, n_head, n_batch=q_len, 1]

dflash/src/draft/draft_graph.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ struct DraftGraphInputs {
1818
// hidden states. Used for DFlash integration where the draft shares the
1919
// target's lm_head.
2020
ggml_tensor * lm_head;
21+
// Optional: causal mask for SWA layers [kv_pad, q_len] F16.
22+
// nullptr = all layers non-causal.
23+
ggml_tensor * causal_mask_swa = nullptr;
2124
};
2225

2326
struct DraftGraphOutputs {

dflash/src/flashprefill.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,37 @@ int flash_prefill_forward_q8(
9090
ggml_type qkv_type,
9191
const FlashPrefillConfig & cfg);
9292

93+
// ── Unified dispatch ──────────────────────────────────────────────────────────
94+
// Picks the best available kernel at compile time + runtime buffer type:
95+
// BF16 buffers + sm_80 build → flash_prefill_forward_bf16
96+
// F16 buffers + Volta/Pascal build → flash_prefill_forward_f16
97+
// otherwise → flash_prefill_forward_q8 (ggml FA fallback)
98+
//
99+
// Callers no longer need to duplicate the ifdef/dispatch boilerplate.
100+
inline int flash_prefill_forward(
101+
ggml_backend_t backend,
102+
const void * Q, const void * K, const void * V, void * O,
103+
int batch, int seq_len, int n_q_heads, int n_k_heads, int head_dim,
104+
float scale,
105+
ggml_type qkv_type,
106+
const FlashPrefillConfig & cfg)
107+
{
108+
#if defined(DFLASH27B_HAVE_FLASHPREFILL) || defined(DFLASH27B_HAVE_SM80_FLASHPREFILL)
109+
if (qkv_type == GGML_TYPE_BF16) {
110+
return flash_prefill_forward_bf16(Q, K, V, O,
111+
batch, seq_len, n_q_heads, n_k_heads, head_dim, scale, cfg);
112+
}
113+
#endif
114+
#if defined(DFLASH27B_HAVE_VOLTA_FLASHPREFILL) || defined(DFLASH27B_HAVE_PASCAL_FLASHPREFILL)
115+
if (qkv_type == GGML_TYPE_F16) {
116+
return flash_prefill_forward_f16(Q, K, V, O,
117+
batch, seq_len, n_q_heads, n_k_heads, head_dim, scale, cfg);
118+
}
119+
#endif
120+
return flash_prefill_forward_q8(backend, Q, K, V, O,
121+
batch, seq_len, n_q_heads, n_k_heads, head_dim, scale, qkv_type, cfg);
122+
}
123+
93124
#ifdef DFLASH27B_HAVE_BSA
94125
// Free BSA persistent device buffers (blockmask, head_mask_type, softmax_lse).
95126
// Safe to call any time; idempotent. Useful before unloading the drafter to

0 commit comments

Comments
 (0)