From d5f32bff4a44de3432c7580cadf3d70fc0995985 Mon Sep 17 00:00:00 2001 From: Howard Su Date: Wed, 20 May 2026 14:22:18 +0800 Subject: [PATCH 01/18] gemma4: fix loader + graph for actual GGUF format Loader fixes: - Handle array-typed metadata (head_count_kv is per-layer array) - Fallback n_vocab from token_embd.weight tensor shape - Default missing keys (expert_count, etc.) to 0 - Separate head_dim_full (512) and head_dim_swa (256) - Per-layer n_head_kv_per_layer vector from GGUF array - SWA pattern: read bool/uint8 array or infer from head_kv - Tied embeddings: output = tok_embd when output.weight absent - Tensor name mapping: post_attention_norm, post_ffw_norm, layer_output_scale - Global rope_freqs_global tensor support Graph fixes: - Per-layer head_dim and n_head_kv via helper functions - FA mask padding to 256 (FATTN_KQ_STRIDE) for CUDA compat - Use global rope_freqs for full-attn layers Cache: - Per-layer KV allocation with correct dimensions Validated: load + prefill + decode + snapshot + restore all pass on gemma-4-31B-it-Q4_K_M.gguf (RTX 2080 Ti). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dflash/src/gemma4/gemma4_backend.cpp | 113 ++++++++++++++++++++++--- dflash/src/gemma4/gemma4_backend.h | 6 +- dflash/src/gemma4/gemma4_graph.cpp | 102 +++++++++++++++++------ dflash/src/gemma4/gemma4_internal.h | 23 +++++- dflash/src/gemma4/gemma4_loader.cpp | 118 ++++++++++++++++++++------- 5 files changed, 288 insertions(+), 74 deletions(-) diff --git a/dflash/src/gemma4/gemma4_backend.cpp b/dflash/src/gemma4/gemma4_backend.cpp index 5b6a0725..922f7051 100644 --- a/dflash/src/gemma4/gemma4_backend.cpp +++ b/dflash/src/gemma4/gemma4_backend.cpp @@ -1,6 +1,6 @@ // Gemma4Backend implementation. // -// Uses gemma4_step() for forward pass (currently stubbed). +// Uses gemma4_step() for forward pass with per-layer embedding support. 
// Structure mirrors Qwen3Backend: prefill in chunks, autoregressive decode, // KV cache with layer sharing, snapshot/restore. @@ -84,7 +84,7 @@ bool Gemma4Backend::unpark(const std::string & what) { // ── Prefill ──────────────────────────────────────────────────────────── int Gemma4Backend::do_prefill(const std::vector & tokens, - const DaemonIO & io) { + const DaemonIO & io, int kv_offset) { (void)io; const int n = (int)tokens.size(); const int hidden = w_.n_embd; @@ -104,16 +104,18 @@ int Gemma4Backend::do_prefill(const std::vector & tokens, float scale = std::sqrt((float)hidden); for (int i = 0; i < len * hidden; ++i) embed[i] *= scale; - if (!gemma4_step(backend_, w_, cache_, embed.data(), len, pos, logits)) { - std::fprintf(stderr, "[gemma4] prefill step failed at pos=%d\n", pos); + const int kv_pos = kv_offset + pos; + if (!gemma4_step(backend_, w_, cache_, embed.data(), + tokens.data() + pos, len, kv_pos, logits)) { + std::fprintf(stderr, "[gemma4] prefill step failed at pos=%d\n", kv_pos); return -1; } pos += len; - cache_.cur_pos = pos; + cache_.cur_pos = kv_offset + pos; } - return pos; + return kv_offset + pos; } // ── Decode ───────────────────────────────────────────────────────────── @@ -134,7 +136,8 @@ bool Gemma4Backend::do_decode(int committed, int n_gen, float scale = std::sqrt((float)hidden); for (int j = 0; j < hidden; ++j) embed_buf[j] *= scale; - if (!gemma4_step(backend_, w_, cache_, embed_buf.data(), 1, committed, logits)) { + if (!gemma4_step(backend_, w_, cache_, embed_buf.data(), + &tok, 1, committed, logits)) { return false; } @@ -177,7 +180,7 @@ GenerateResult Gemma4Backend::generate(const GenerateRequest & req, cache_.cur_pos = 0; - const int committed = do_prefill(req.prompt, out_io); + const int committed = do_prefill(req.prompt, out_io, /*kv_offset=*/0); if (committed < 0) { result.error = "prefill"; return result; @@ -195,8 +198,8 @@ GenerateResult Gemma4Backend::generate(const GenerateRequest & req, float scale = 
std::sqrt((float)hidden); for (int j = 0; j < hidden; ++j) embed_buf[j] *= scale; - if (!gemma4_step(backend_, w_, cache_, embed_buf.data(), 1, - committed - 1, logits)) { + if (!gemma4_step(backend_, w_, cache_, embed_buf.data(), + &last_tok, 1, committed - 1, logits)) { result.error = "first logits"; return result; } @@ -246,9 +249,11 @@ GenerateResult Gemma4Backend::restore_and_generate(int slot, const GenerateRequest & req, const DaemonIO & io) { GenerateResult result; + DaemonIO out_io = io.with_token_callback(req.on_token); + if (slot < 0 || slot >= PREFIX_SLOTS || !snapshots_[slot].ctx) { result.error = "bad slot"; - io.emit(-1); + out_io.emit(-1); return result; } @@ -261,9 +266,97 @@ GenerateResult Gemma4Backend::restore_and_generate(int slot, ggml_backend_tensor_set(cache_.v[il], snap.v_snap[il]->data, 0, nbytes); } } - cache_.cur_pos = snap.cur_pos; - return generate(req, io); + const int snap_pos = snap.cur_pos; + cache_.cur_pos = snap_pos; + + // Set up sampler + sampler_ = req.sampler; + if (req.do_sample && sampler_.seed != 0) { + sampler_rng_.seed(sampler_.seed); + } + + // Diff-prefill: only prefill tokens beyond the cached prefix + const int prompt_len = (int)req.prompt.size(); + int committed = snap_pos; + + if (prompt_len > snap_pos) { + // Compute delta (tokens after the snapshot) + std::vector delta(req.prompt.begin() + snap_pos, req.prompt.end()); + committed = do_prefill(delta, out_io, /*kv_offset=*/snap_pos); + if (committed < 0) { + result.error = "prefill"; + return result; + } + } else if (prompt_len > 0 && prompt_len < snap_pos) { + result.error = "snapshot_longer_than_prompt"; + out_io.emit(-1); + return result; + } + // else: prompt_len == snap_pos (or an empty prompt) → no delta, committed stays at snap_pos + + // Generate + if (req.n_gen > 0) { + // An empty prompt has no last token to re-step: fail cleanly instead of calling back() on an empty vector (UB). + if (req.prompt.empty()) { + result.error = "empty_prompt"; + out_io.emit(-1); + return result; + } + const int hidden = w_.n_embd; + const int vocab = w_.n_vocab; + std::vector logits; + + // Re-step last token to get logits for first generated token + int32_t last_tok = req.prompt.back(); + std::vector 
embed_buf(hidden); + w_.embedder.embed(&last_tok, 1, embed_buf.data()); + float scale = std::sqrt((float)hidden); + for (int j = 0; j < hidden; ++j) embed_buf[j] *= scale; + + if (!gemma4_step(backend_, w_, cache_, embed_buf.data(), + &last_tok, 1, committed - 1, logits)) { + result.error = "first logits"; + return result; + } + + // Sample first token + int32_t first; + if (sampler_.temp > 0) { + first = sample_logits(logits.data(), vocab, sampler_, + result.tokens, sampler_rng_); + } else { + first = 0; + float best = logits[0]; + for (int j = 1; j < vocab; ++j) { + if (logits[j] > best) { best = logits[j]; first = j; } + } + } + result.tokens.push_back(first); + out_io.emit(first); + if (out_io.cancelled) { + out_io.emit(-1); + result.ok = true; + return result; + } + + if (first == w_.eos_id || first == w_.eos_chat_id) { + out_io.emit(-1); + result.ok = true; + return result; + } + + if (req.n_gen > 1) { + if (!do_decode(committed, req.n_gen - 1, result.tokens, out_io)) { + result.error = "decode"; + return result; + } + } + } + + out_io.emit(-1); + result.ok = true; + return result; } // ── Snapshots ────────────────────────────────────────────────────────── diff --git a/dflash/src/gemma4/gemma4_backend.h b/dflash/src/gemma4/gemma4_backend.h index 84fa08b3..18c26fda 100644 --- a/dflash/src/gemma4/gemma4_backend.h +++ b/dflash/src/gemma4/gemma4_backend.h @@ -80,8 +80,10 @@ class Gemma4Backend : public ModelBackend { static constexpr int PREFIX_SLOTS = 64; Gemma4Snapshot snapshots_[PREFIX_SLOTS]; - // Prefill prompt tokens in chunks, return committed position. - int do_prefill(const std::vector & tokens, const DaemonIO & io); + // Prefill prompt tokens in chunks, return absolute committed position. + // kv_offset: starting KV cache position (0 for fresh prefill, snap_pos for restore). + int do_prefill(const std::vector & tokens, const DaemonIO & io, + int kv_offset = 0); // Autoregressive decode loop. 
bool do_decode(int committed, int n_gen, diff --git a/dflash/src/gemma4/gemma4_graph.cpp b/dflash/src/gemma4/gemma4_graph.cpp index c4522edd..e218ac60 100644 --- a/dflash/src/gemma4/gemma4_graph.cpp +++ b/dflash/src/gemma4/gemma4_graph.cpp @@ -150,9 +150,9 @@ static ggml_tensor * build_gemma4_attn_block( int kv_start, int n_tokens) { - const int head_dim = w.head_dim; + const int head_dim = gemma4_head_dim(w, il); const int n_head = w.n_head; - const int n_head_kv = w.n_head_kv; + const int n_head_kv = gemma4_n_head_kv(w, il); const int q_dim = n_head * head_dim; const bool is_swa = gemma4_is_swa_layer(w, il); const bool has_kv = gemma4_has_kv(w, il); @@ -168,7 +168,7 @@ static ggml_tensor * build_gemma4_attn_block( // RoPE for Q const float rope_base = is_swa ? w.rope_freq_base_swa : w.rope_freq_base_full; - ggml_tensor * freq_factors = is_swa ? nullptr : L.rope_freqs; + ggml_tensor * freq_factors = is_swa ? nullptr : (L.rope_freqs ? L.rope_freqs : w.rope_freqs_global); Qcur = ggml_rope_ext(ctx, Qcur, positions, freq_factors, head_dim, GGML_ROPE_TYPE_NEOX, 0, rope_base, 1.0f, @@ -216,7 +216,9 @@ static ggml_tensor * build_gemma4_attn_block( // else: KV-sharing layer — cache already written by source layer // Flash attention - const int kv_len = kv_start + n_tokens; + // Pad kv_len to multiple of 256 for CUDA FA kernel compatibility (FATTN_KQ_STRIDE=256) + const int kv_len_raw = kv_start + n_tokens; + const int kv_len = (kv_len_raw + 255) & ~255; // round up to 256 ggml_tensor * Qfa = ggml_permute(ctx, Qcur, 0, 2, 1, 3); Qfa = ggml_cont(ctx, Qfa); @@ -315,11 +317,19 @@ static ggml_tensor * build_gemma4_layer( return cur; } +// Helper: get a 2D slice from a 3D tensor along ne[2] (same as llama.cpp ggml_view_2d_slice). 
+static ggml_tensor * gemma4_view_2d_slice(ggml_context * ctx, ggml_tensor * x, int idx) { + return ggml_view_2d(ctx, x, x->ne[0], x->ne[1], + ggml_row_size(x->type, x->ne[0]), + (size_t)idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); +} + bool gemma4_step( ggml_backend_t backend, const Gemma4Weights & w, Gemma4Cache & cache, const float * embed, + const int32_t * token_ids, int n_tokens, int kv_start, std::vector & out_logits) @@ -337,32 +347,65 @@ bool gemma4_step( ggml_tensor * pp = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens); ggml_set_input(pp); + // Token IDs input (for per-layer embedding lookup) + ggml_tensor * tok_ids = nullptr; + if (token_ids && w.per_layer_tok_embd && w.per_layer_model_proj && w.n_embd_per_layer > 0) { + tok_ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens); + ggml_set_input(tok_ids); + } + // Attention masks (full + SWA) - const int kv_len = kv_start + n_tokens; - ggml_tensor * mk_full = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv_len, n_tokens, 1, 1); + // Pad kv_len to 256 for CUDA FA kernel compatibility (FATTN_KQ_STRIDE=256) + const int kv_len_raw = kv_start + n_tokens; + const int kv_len_padded = (kv_len_raw + 255) & ~255; + ggml_tensor * mk_full = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv_len_padded, n_tokens, 1, 1); ggml_set_input(mk_full); ggml_tensor * mk_full_f16 = ggml_cast(ctx, mk_full, GGML_TYPE_F16); - ggml_tensor * mk_swa = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv_len, n_tokens, 1, 1); + ggml_tensor * mk_swa = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv_len_padded, n_tokens, 1, 1); ggml_set_input(mk_swa); ggml_tensor * mk_swa_f16 = ggml_cast(ctx, mk_swa, GGML_TYPE_F16); - // Per-layer embedding input (if model has per-layer embeddings) - // For simplicity, we precompute per-layer inputs for all layers at once - // Shape: [n_embd_per_layer, n_tokens, n_layer] → slice per layer - ggml_tensor * per_layer_all = nullptr; - if (w.per_layer_tok_embd && w.per_layer_model_proj && w.n_embd_per_layer > 0) { - // We need 
token IDs for per-layer embedding lookup, but we only have - // float embeddings at this point. Per-layer embedding requires a separate - // token ID input. For now, skip per-layer embeddings in the step function - // (they're computed on the embedding path in the backend). - // TODO: Add token ID input to gemma4_step for per-layer embedding support. + // Per-layer embedding computation (reference: gemma4-iswa.cpp build_inp_per_layer + project_per_layer_inputs) + ggml_tensor * per_layer_all = nullptr; // final shape: [n_embd_per_layer, n_tokens, n_layer] + if (tok_ids) { + const int D = w.n_embd_per_layer; + const int L = w.n_layer; + + // 1. Token per-layer embedding lookup + scale + // get_rows(per_layer_tok_embd[D*L, n_vocab], tok_ids) → [D*L, n_tokens] + ggml_tensor * inp_pl = ggml_get_rows(ctx, w.per_layer_tok_embd, tok_ids); + inp_pl = ggml_reshape_3d(ctx, inp_pl, D, L, n_tokens); // [D, L, n_tokens] + inp_pl = ggml_scale(ctx, inp_pl, std::sqrt((float)D)); + + // 2. Project main embedding through per_layer_model_proj + // mul_mat(per_layer_model_proj[n_embd, D*L], ie[n_embd, n_tokens]) → [D*L, n_tokens] + ggml_tensor * proj = ggml_mul_mat(ctx, w.per_layer_model_proj, ie); + proj = ggml_scale(ctx, proj, 1.0f / std::sqrt((float)w.n_embd)); + proj = ggml_reshape_3d(ctx, proj, D, L, n_tokens); // [D, L, n_tokens] + + // 3. RMS norm on projection (normalizes over ne[0]=D for each (layer, token)) + proj = ggml_rms_norm(ctx, proj, w.norm_eps); + // Reshape norm weight from [D*L] to [D, L] for broadcast mul over n_tokens + ggml_tensor * norm_w = ggml_reshape_2d(ctx, w.per_layer_proj_norm, D, L); + proj = ggml_mul(ctx, proj, norm_w); + + // 4. Add token embedding + projection, scale by 1/sqrt(2) + per_layer_all = ggml_add(ctx, proj, inp_pl); + per_layer_all = ggml_scale(ctx, per_layer_all, 1.0f / std::sqrt(2.0f)); + + // 5. 
Permute to [D, n_tokens, L] for easy per-layer slicing + per_layer_all = ggml_cont(ctx, ggml_permute(ctx, per_layer_all, 0, 2, 1, 3)); } // Build the graph ggml_tensor * cur = ie; // [n_embd, n_tokens] already scaled by sqrt(n_embd) in caller for (int il = 0; il < w.n_layer; ++il) { - ggml_tensor * pl_input = nullptr; // TODO: per-layer embedding per layer + ggml_tensor * pl_input = nullptr; + if (per_layer_all) { + // Slice [n_embd_per_layer, n_tokens] for this layer + pl_input = gemma4_view_2d_slice(ctx, per_layer_all, il); + } cur = build_gemma4_layer(ctx, gf, w, cache, il, cur, pp, mk_full_f16, mk_swa_f16, pl_input, kv_start, n_tokens); @@ -406,24 +449,29 @@ bool gemma4_step( for (int i = 0; i < n_tokens; ++i) pos[i] = kv_start + i; ggml_backend_tensor_set(pp, pos.data(), 0, ggml_nbytes(pp)); - // Causal mask (full attention) - std::vector mfull((size_t)kv_len * n_tokens, -INFINITY); + // Set token IDs for per-layer embedding + if (tok_ids && token_ids) { + ggml_backend_tensor_set(tok_ids, token_ids, 0, (size_t)n_tokens * sizeof(int32_t)); + } + + // Causal mask (full attention) — padded positions are masked with -inf + std::vector mfull((size_t)kv_len_padded * n_tokens, -INFINITY); for (int q = 0; q < n_tokens; ++q) { const int abs_q = kv_start + q; - for (int k = 0; k <= abs_q && k < kv_len; ++k) { - mfull[(size_t)q * kv_len + k] = 0.0f; + for (int k = 0; k <= abs_q && k < kv_len_raw; ++k) { + mfull[(size_t)q * kv_len_padded + k] = 0.0f; } } ggml_backend_tensor_set(mk_full, mfull.data(), 0, ggml_nbytes(mk_full)); - // SWA mask - std::vector mswa((size_t)kv_len * n_tokens, -INFINITY); + // SWA mask — padded positions are masked with -inf + std::vector mswa((size_t)kv_len_padded * n_tokens, -INFINITY); const int W = w.sliding_window; for (int q = 0; q < n_tokens; ++q) { const int abs_q = kv_start + q; const int win_lo = std::max(0, abs_q - W + 1); - for (int k = win_lo; k <= abs_q && k < kv_len; ++k) { - mswa[(size_t)q * kv_len + k] = 0.0f; + for (int k = 
win_lo; k <= abs_q && k < kv_len_raw; ++k) { + mswa[(size_t)q * kv_len_padded + k] = 0.0f; } } ggml_backend_tensor_set(mk_swa, mswa.data(), 0, ggml_nbytes(mk_swa)); @@ -440,7 +488,7 @@ bool gemma4_step( ggml_backend_tensor_get(cur, out_logits.data(), 0, out_logits.size() * sizeof(float)); - cache.cur_pos = kv_len; + cache.cur_pos = kv_len_raw; ggml_free(ctx); return true; } diff --git a/dflash/src/gemma4/gemma4_internal.h b/dflash/src/gemma4/gemma4_internal.h index d34107b7..dd165a15 100644 --- a/dflash/src/gemma4/gemma4_internal.h +++ b/dflash/src/gemma4/gemma4_internal.h @@ -79,7 +79,8 @@ struct Gemma4Weights { // Global tensors ggml_tensor * tok_embd = nullptr; // [n_embd, n_vocab] ggml_tensor * out_norm = nullptr; // [n_embd] - ggml_tensor * output = nullptr; // [n_embd, n_vocab] (lm_head) + ggml_tensor * output = nullptr; // [n_embd, n_vocab] (lm_head, may be tied to tok_embd) + ggml_tensor * rope_freqs_global = nullptr; // [head_dim/2] global rope freq factors ggml_tensor * per_layer_tok_embd = nullptr; // [n_embd_per_layer * n_layer, n_vocab] ggml_tensor * per_layer_model_proj = nullptr; // [n_embd, n_embd_per_layer * n_layer] ggml_tensor * per_layer_proj_norm = nullptr; // [n_embd_per_layer * n_layer] @@ -91,8 +92,9 @@ struct Gemma4Weights { // Architecture metadata int n_layer = 0; int n_head = 0; - int n_head_kv = 0; - int head_dim = 128; + int n_head_kv = 0; // max n_head_kv (for backward compat) + int head_dim = 128; // head_dim for SWA layers (smaller) + int head_dim_full = 128; // head_dim for full-attention layers int n_embd = 0; int n_ff = 0; // dense FFN intermediate int n_ff_exp = 0; // expert FFN intermediate @@ -103,6 +105,9 @@ struct Gemma4Weights { int n_embd_per_layer = 0; // per-layer embedding dim int n_vocab = 0; + // Per-layer head counts (Gemma4 can have variable n_head_kv per layer) + std::vector n_head_kv_per_layer; + // iSWA int sliding_window = 0; std::vector swa_layers; // true = SWA, false = full attn @@ -132,6 +137,15 @@ inline 
bool gemma4_has_kv(const Gemma4Weights & w, int il) { return il < (int)w.has_kv.size() && w.has_kv[il]; } +inline int gemma4_head_dim(const Gemma4Weights & w, int il) { + return gemma4_is_swa_layer(w, il) ? w.head_dim : w.head_dim_full; +} + +inline int gemma4_n_head_kv(const Gemma4Weights & w, int il) { + if (il < (int)w.n_head_kv_per_layer.size()) return w.n_head_kv_per_layer[il]; + return w.n_head_kv; +} + // GGUF loader bool load_gemma4_gguf(const std::string & path, ggml_backend_t backend, @@ -172,11 +186,14 @@ void free_gemma4_snapshot(Gemma4Snapshot & s); // Forward: run a single step (prefill chunk or decode token). // Returns logits for last token. +// token_ids: raw token IDs needed for per-layer embedding lookup (may be nullptr +// if the model has no per-layer embeddings). bool gemma4_step( ggml_backend_t backend, const Gemma4Weights & w, Gemma4Cache & cache, const float * embed, + const int32_t * token_ids, int n_tokens, int kv_start, std::vector & out_logits); diff --git a/dflash/src/gemma4/gemma4_loader.cpp b/dflash/src/gemma4/gemma4_loader.cpp index d40db53a..62cab9c8 100644 --- a/dflash/src/gemma4/gemma4_loader.cpp +++ b/dflash/src/gemma4/gemma4_loader.cpp @@ -14,10 +14,12 @@ #include "internal.h" #include "dflash27b.h" +#include #include #include #include #include +#include #if !defined(_WIN32) #include @@ -54,12 +56,32 @@ struct Gemma4Mmap { uint32_t get_u32_or(gguf_context * g, const char * key, uint32_t def) { int64_t id = gguf_find_key(g, key); - return (id >= 0) ? gguf_get_val_u32(g, id) : def; + if (id < 0) return def; + // Handle array type: return first element + if (gguf_get_kv_type(g, id) == GGUF_TYPE_ARRAY) { + if (gguf_get_arr_n(g, id) == 0) return def; + return ((const uint32_t *)gguf_get_arr_data(g, id))[0]; + } + return gguf_get_val_u32(g, id); } float get_f32_or(gguf_context * g, const char * key, float def) { int64_t id = gguf_find_key(g, key); - return (id >= 0) ? 
gguf_get_val_f32(g, id) : def; + if (id < 0) return def; + if (gguf_get_kv_type(g, id) == GGUF_TYPE_ARRAY) { + if (gguf_get_arr_n(g, id) == 0) return def; + return ((const float *)gguf_get_arr_data(g, id))[0]; + } + return gguf_get_val_f32(g, id); +} + +// Read a 32-bit integer array key into a vector (empty if not found, not an array, or the elements are not UINT32/INT32 — reinterpreting narrower elements would read past the array). +std::vector get_u32_arr(gguf_context * g, const char * key) { + int64_t id = gguf_find_key(g, key); + if (id < 0 || gguf_get_kv_type(g, id) != GGUF_TYPE_ARRAY || (gguf_get_arr_type(g, id) != GGUF_TYPE_UINT32 && gguf_get_arr_type(g, id) != GGUF_TYPE_INT32)) return {}; + const size_t n = gguf_get_arr_n(g, id); + const uint32_t * data = (const uint32_t *)gguf_get_arr_data(g, id); + return std::vector(data, data + n); } ggml_tensor * find_tensor(ggml_context * ctx, const char * name) { @@ -96,8 +118,8 @@ bool load_gemma4_gguf(const std::string & path, const uint32_t n_ff_exp = get_u32_or(gctx, "gemma4.expert_feed_forward_length", 0); const uint32_t n_head = get_u32_or(gctx, "gemma4.attention.head_count", 0); const uint32_t n_head_kv = get_u32_or(gctx, "gemma4.attention.head_count_kv", 0); - const uint32_t head_dim = get_u32_or(gctx, "gemma4.attention.key_length", 128); - const uint32_t n_vocab = get_u32_or(gctx, "gemma4.vocab_size", 0); + const uint32_t head_dim_full = get_u32_or(gctx, "gemma4.attention.key_length", 128); + const uint32_t head_dim_swa = get_u32_or(gctx, "gemma4.attention.key_length_swa", head_dim_full); const uint32_t n_expert = get_u32_or(gctx, "gemma4.expert_count", 0); const uint32_t n_expert_used = get_u32_or(gctx, "gemma4.expert_used_count", 0); const uint32_t n_dense_lead = get_u32_or(gctx, "gemma4.leading_dense_block_count", 1); @@ -105,6 +127,16 @@ bool load_gemma4_gguf(const std::string & path, const uint32_t shared_kv = get_u32_or(gctx, "gemma4.attention.shared_kv_layers", 0); const uint32_t n_embd_pl = get_u32_or(gctx, "gemma4.embedding_length_per_layer_input", 0); + // Per-layer head_count_kv (may be array or scalar) + std::vector head_kv_arr = get_u32_arr(gctx, "gemma4.attention.head_count_kv"); + + // Get 
vocab size from token_embd tensor shape (not always in metadata) + uint32_t n_vocab = get_u32_or(gctx, "gemma4.vocab_size", 0); + if (n_vocab == 0) { + ggml_tensor * tok_embd = find_tensor(meta_ctx, "token_embd.weight"); + if (tok_embd) n_vocab = (uint32_t)tok_embd->ne[1]; + } + const float rope_base_full = get_f32_or(gctx, "gemma4.rope.freq_base", 1000000.0f); const float rope_base_swa = get_f32_or(gctx, "gemma4.rope.freq_base_swa", 10000.0f); const float norm_eps = get_f32_or(gctx, "gemma4.attention.layer_norm_rms_epsilon", 1e-6f); @@ -121,7 +153,8 @@ bool load_gemma4_gguf(const std::string & path, out.n_layer = (int)n_layer; out.n_head = (int)n_head; out.n_head_kv = (int)n_head_kv; - out.head_dim = (int)head_dim; + out.head_dim = (int)head_dim_swa; // SWA head_dim (smaller) + out.head_dim_full = (int)head_dim_full; // full-attn head_dim (larger) out.n_embd = (int)n_embd; out.n_ff = (int)n_ff; out.n_ff_exp = (int)n_ff_exp; @@ -137,6 +170,18 @@ bool load_gemma4_gguf(const std::string & path, out.final_logit_softcap = logit_softcap; out.norm_eps = norm_eps; + // Per-layer n_head_kv from array (or fill scalar) + out.n_head_kv_per_layer.resize(n_layer); + if (!head_kv_arr.empty()) { + for (uint32_t il = 0; il < n_layer; ++il) { + out.n_head_kv_per_layer[il] = (int)(il < head_kv_arr.size() ? 
head_kv_arr[il] : head_kv_arr.back()); + } + } else { + for (uint32_t il = 0; il < n_layer; ++il) { + out.n_head_kv_per_layer[il] = (int)n_head_kv; + } + } + // KV sharing: last shared_kv layers reuse earlier KV out.kv_sharing_start = (int)(n_layer - shared_kv); out.has_kv.resize(n_layer); @@ -144,26 +189,38 @@ bool load_gemma4_gguf(const std::string & path, out.has_kv[il] = (int)il < out.kv_sharing_start; } - // SWA pattern from GGUF (array of bools or compute from pattern) + // SWA pattern from GGUF (array of bools indicating SWA layers) out.swa_layers.resize(n_layer, false); { int64_t pat_id = gguf_find_key(gctx, "gemma4.attention.sliding_window_pattern"); if (pat_id >= 0 && gguf_get_kv_type(gctx, pat_id) == GGUF_TYPE_ARRAY) { const size_t n = gguf_get_arr_n(gctx, pat_id); - // Pattern repeats over layers - for (uint32_t il = 0; il < n_layer; ++il) { - // 0 = full, 1 = SWA in the pattern - // Read from array, cycling if needed - if (n > 0) { - // For gemma4 the pattern is typically stored as a boolean array - // indicating whether each layer is SWA - out.swa_layers[il] = (il % n != 0); // first in group = full + if (n > 0) { + const auto arr_type = gguf_get_arr_type(gctx, pat_id); + const void * data = gguf_get_arr_data(gctx, pat_id); + for (uint32_t il = 0; il < n_layer; ++il) { + size_t idx = il % n; // cycle if pattern shorter than n_layer + if (arr_type == GGUF_TYPE_BOOL || arr_type == GGUF_TYPE_UINT8) { + out.swa_layers[il] = ((const uint8_t *)data)[idx] != 0; + } else if (arr_type == GGUF_TYPE_INT32 || arr_type == GGUF_TYPE_UINT32) { + out.swa_layers[il] = ((const uint32_t *)data)[idx] != 0; + } } } } else { - // Default: every 4th layer is full, rest SWA (like laguna) - for (uint32_t il = 0; il < n_layer; ++il) { - out.swa_layers[il] = (il % 4 != 0); + // Fallback: infer from per-layer head_kv (small kv = full, large kv = swa) + // or use default pattern (alternating 5:1) + if (!head_kv_arr.empty() && head_kv_arr.size() >= n_layer) { + uint32_t max_kv = 
*std::max_element(head_kv_arr.begin(), head_kv_arr.end()); + for (uint32_t il = 0; il < n_layer; ++il) { + // Layers with max n_head_kv are SWA, smaller are full-attn + out.swa_layers[il] = (head_kv_arr[il] == max_kv); + } + } else { + // Default: every 6th layer is full, rest SWA (5:1 pattern) + for (uint32_t il = 0; il < n_layer; ++il) { + out.swa_layers[il] = ((il % 6) != 5); + } } } } @@ -177,8 +234,8 @@ bool load_gemma4_gguf(const std::string & path, if (out.eos_id == (int32_t)miss) out.eos_id = 1; if (out.eos_chat_id == (int32_t)miss) out.eos_chat_id = -1; - std::printf("[gemma4-loader] n_layer=%u n_embd=%u head_dim=%u n_head=%u n_head_kv=%u\n", - n_layer, n_embd, head_dim, n_head, n_head_kv); + std::printf("[gemma4-loader] n_layer=%u n_embd=%u head_dim_swa=%u head_dim_full=%u n_head=%u n_head_kv=%u\n", + n_layer, n_embd, head_dim_swa, head_dim_full, n_head, n_head_kv); std::printf("[gemma4-loader] n_expert=%u used=%u dense_lead=%u sliding_window=%u\n", n_expert, n_expert_used, n_dense_lead, sliding_win); std::printf("[gemma4-loader] kv_sharing_start=%d per_layer_embd=%u logit_softcap=%g\n", @@ -221,8 +278,8 @@ bool load_gemma4_gguf(const std::string & path, tok_embd_off = offset; tok_embd_sz = sz; tok_embd_type = t->type; - // Set data pointer for metadata but don't copy to GPU - t->data = (void *)src; + // Upload to GPU (needed for tied lm_head / output) + ggml_backend_tensor_set(t, src, 0, sz); } else { ggml_backend_tensor_set(t, src, 0, sz); } @@ -244,7 +301,11 @@ bool load_gemma4_gguf(const std::string & path, // ── Assign tensors to struct ─────────────────────────────────────── out.tok_embd = find_tensor(meta_ctx, "token_embd.weight"); out.out_norm = find_tensor(meta_ctx, "output_norm.weight"); + // Gemma4 uses tied embeddings: lm_head = token_embd out.output = find_tensor(meta_ctx, "output.weight"); + if (!out.output) out.output = out.tok_embd; + // Global rope_freqs (not per-layer in this model variant) + out.rope_freqs_global = find_tensor(meta_ctx, 
"rope_freqs.weight"); out.per_layer_tok_embd = find_tensor(meta_ctx, "per_layer_tok_embd.weight"); out.per_layer_model_proj = find_tensor(meta_ctx, "per_layer_model_proj.weight"); out.per_layer_proj_norm = find_tensor(meta_ctx, "per_layer_proj_norm.weight"); @@ -266,17 +327,17 @@ bool load_gemma4_gguf(const std::string & path, L.wo = get("attn_output.weight"); L.q_norm = get("attn_q_norm.weight"); L.k_norm = get("attn_k_norm.weight"); - L.attn_post_norm = get("attn_post_norm.weight"); + L.attn_post_norm = get("post_attention_norm.weight"); L.rope_freqs = get("rope_freqs.weight"); L.ffn_norm = get("ffn_norm.weight"); L.ffn_gate = get("ffn_gate.weight"); L.ffn_up = get("ffn_up.weight"); L.ffn_down = get("ffn_down.weight"); - L.ffn_post_norm = get("ffn_post_norm.weight"); + L.ffn_post_norm = get("post_ffw_norm.weight"); - // MoE tensors - L.ffn_norm_moe = get("ffn_norm.weight"); // same tensor for both paths + // MoE tensors (only present for MoE models) + L.ffn_norm_moe = get("ffn_norm_moe.weight"); L.ffn_gate_inp = get("ffn_gate_inp.weight"); L.ffn_gate_inp_s = get("ffn_gate_inp_shexp.weight"); L.ffn_gate_up_exps = get("ffn_gate_up_exps.weight"); @@ -294,7 +355,7 @@ bool load_gemma4_gguf(const std::string & path, L.per_layer_proj = get("per_layer_proj.weight"); L.per_layer_post_norm = get("per_layer_post_norm.weight"); - L.out_scale = get("out_scale.weight"); + L.out_scale = get("layer_output_scale.weight"); } std::printf("[gemma4-loader] loaded %d tensors, vocab=%d\n", n_tensors, (int)n_vocab); @@ -314,9 +375,6 @@ void free_gemma4_weights(Gemma4Weights & w) { bool create_gemma4_cache(ggml_backend_t backend, const Gemma4Weights & w, int max_ctx, Gemma4Cache & out) { - const int D = w.head_dim; - const int Hk = w.n_head_kv; - ggml_init_params ip{}; ip.mem_size = ggml_tensor_overhead() * (size_t)(w.n_layer * 2 + 4) + 4096; ip.no_alloc = true; @@ -331,6 +389,8 @@ bool create_gemma4_cache(ggml_backend_t backend, const Gemma4Weights & w, int last_kv_layer = -1; for (int 
il = 0; il < w.n_layer; ++il) { if (w.has_kv[il]) { + const int D = gemma4_head_dim(w, il); + const int Hk = gemma4_n_head_kv(w, il); out.k[il] = ggml_new_tensor_3d(out.ctx, GGML_TYPE_F16, D, Hk, max_ctx); out.v[il] = ggml_new_tensor_3d(out.ctx, GGML_TYPE_F16, D, Hk, max_ctx); out.kv_source[il] = il; From 1315311928352aafe052659ed9486b4c7638d69a Mon Sep 17 00:00:00 2001 From: Howard Su Date: Wed, 20 May 2026 14:28:58 +0800 Subject: [PATCH 02/18] gemma4: implement DFlashTarget for speculative decode (G4) Add Gemma4DFlashTarget class implementing the DFlashTarget interface: - verify_batch: full forward with all-token argmax via gemma4_verify_batch - snapshot_kv / restore_kv: full KV cache save/restore for rollback - embed_tokens: CPU embedder with sqrt(n_embd) scaling - project_hidden_to_tokens: lm_head projection via gemma4_project_hidden - capture_layer_ids: evenly-spaced 5 layers (1, 15, 29, 43, 57) - mask_token_id: 0 (padding token) New graph functions: - gemma4_verify_batch(): like gemma4_step but returns all-position argmax - gemma4_project_hidden(): out_norm + lm_head + softcap + argmax Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dflash/CMakeLists.txt | 1 + dflash/src/gemma4/gemma4_dflash_target.cpp | 145 ++++++++++++++ dflash/src/gemma4/gemma4_dflash_target.h | 63 +++++++ dflash/src/gemma4/gemma4_graph.cpp | 210 +++++++++++++++++++++ dflash/src/gemma4/gemma4_internal.h | 21 +++ 5 files changed, 440 insertions(+) create mode 100644 dflash/src/gemma4/gemma4_dflash_target.cpp create mode 100644 dflash/src/gemma4/gemma4_dflash_target.h diff --git a/dflash/CMakeLists.txt b/dflash/CMakeLists.txt index 2bdbb219..a69d8041 100644 --- a/dflash/CMakeLists.txt +++ b/dflash/CMakeLists.txt @@ -225,6 +225,7 @@ add_library(dflash_common STATIC src/gemma4/gemma4_graph.cpp src/gemma4/gemma4_backend.cpp src/gemma4/gemma4_daemon.cpp + src/gemma4/gemma4_dflash_target.cpp src/flashprefill_q8.cpp src/kv_cache.cpp src/kv_quant.cpp diff --git 
a/dflash/src/gemma4/gemma4_dflash_target.cpp b/dflash/src/gemma4/gemma4_dflash_target.cpp new file mode 100644 index 00000000..aa66080d --- /dev/null +++ b/dflash/src/gemma4/gemma4_dflash_target.cpp @@ -0,0 +1,145 @@ +// Gemma4DFlashTarget — DFlashTarget adapter for Gemma4 iSWA models. + +#include "gemma4_dflash_target.h" +#include "dflash27b.h" + +#include +#include +#include + +namespace dflash27b { + +Gemma4DFlashTarget::Gemma4DFlashTarget( + Gemma4Weights & w, + Gemma4Cache & cache, + ggml_backend_t backend) + : w_(w), cache_(cache), backend_(backend) { + // Evenly-spaced capture layer IDs (same formula as qwen35 loader). + const int N = DFLASH27B_DRAFT_N_TARGET_LAYERS; + capture_ids_.resize(N); + const int step = std::max(1, (w.n_layer - 2) / (N - 1)); + for (int k = 0; k < N; k++) { + capture_ids_[k] = 1 + k * step; + } +} + +Gemma4DFlashTarget::~Gemma4DFlashTarget() { + free_gemma4_snapshot(verify_snap_); +} + +bool Gemma4DFlashTarget::verify_batch( + const std::vector & tokens, + int base_pos, + int & last_tok, + std::vector * all_argmax) { + const int n_tokens = (int)tokens.size(); + if (n_tokens <= 0) return false; + + const int hidden = w_.n_embd; + + // Embed tokens + std::vector embed((size_t)n_tokens * hidden); + if (!w_.embedder.embed(tokens.data(), n_tokens, embed.data())) { + std::fprintf(stderr, "gemma4_verify_batch: embed failed\n"); + return false; + } + + // Scale by sqrt(n_embd) (Gemma4 convention) + const float scale = std::sqrt((float)hidden); + for (size_t i = 0; i < embed.size(); ++i) embed[i] *= scale; + + // Run verify (all-token argmax) + std::vector argmax_buf; + if (!gemma4_verify_batch(backend_, w_, cache_, embed.data(), + tokens.data(), n_tokens, base_pos, + argmax_buf)) { + return false; + } + + last_tok = argmax_buf[n_tokens - 1]; + if (all_argmax) { + *all_argmax = std::move(argmax_buf); + } + + return true; +} + +bool Gemma4DFlashTarget::snapshot_kv() { + // Save cur_pos and KV cache state + verify_snap_.cur_pos = 
cache_.cur_pos; + + // Allocate snapshot tensors on first use; a failed allocation is rolled back so the next call retries instead of copying into tensors that have no buffer + if (verify_snap_.k_snap.empty()) { + ggml_init_params ip{}; + ip.mem_size = ggml_tensor_overhead() * (size_t)(w_.n_layer * 2 + 4) + 4096; + ip.no_alloc = true; + verify_snap_.ctx = ggml_init(ip); + if (!verify_snap_.ctx) return false; + + verify_snap_.k_snap.resize(w_.n_layer, nullptr); + verify_snap_.v_snap.resize(w_.n_layer, nullptr); + for (int il = 0; il < w_.n_layer; ++il) { + if (cache_.k[il]) { + verify_snap_.k_snap[il] = ggml_dup_tensor(verify_snap_.ctx, cache_.k[il]); + verify_snap_.v_snap[il] = ggml_dup_tensor(verify_snap_.ctx, cache_.v[il]); + } + } + verify_snap_.buf = ggml_backend_alloc_ctx_tensors(verify_snap_.ctx, backend_); + if (!verify_snap_.buf) { ggml_free(verify_snap_.ctx); verify_snap_.ctx = nullptr; verify_snap_.k_snap.clear(); verify_snap_.v_snap.clear(); return false; } + } + + // Copy KV cache to snapshot + for (int il = 0; il < w_.n_layer; ++il) { + if (cache_.k[il] && verify_snap_.k_snap[il]) { + ggml_backend_tensor_copy(cache_.k[il], verify_snap_.k_snap[il]); + ggml_backend_tensor_copy(cache_.v[il], verify_snap_.v_snap[il]); + } + } + return true; +} + +bool Gemma4DFlashTarget::restore_kv() { + if (verify_snap_.k_snap.empty()) return false; + + // Restore KV cache from snapshot + for (int il = 0; il < w_.n_layer; ++il) { + if (cache_.k[il] && verify_snap_.k_snap[il]) { + ggml_backend_tensor_copy(verify_snap_.k_snap[il], cache_.k[il]); + ggml_backend_tensor_copy(verify_snap_.v_snap[il], cache_.v[il]); + } + } + cache_.cur_pos = verify_snap_.cur_pos; + return true; +} + +bool Gemma4DFlashTarget::is_eos(int token) const { + return token == w_.eos_id || token == w_.eos_chat_id; +} + +bool Gemma4DFlashTarget::embed_tokens(const int32_t * tokens, int n, + float * out) const { + if (!w_.embedder.embed(tokens, n, out)) return false; + // Gemma4 scales embeddings by sqrt(n_embd) + const float scale = std::sqrt((float)w_.n_embd); + const size_t total = (size_t)n * w_.n_embd; + for (size_t i = 0; i < total; ++i) out[i] *= scale; + return true; +} + +bool Gemma4DFlashTarget::project_hidden_to_tokens( + 
const float * hidden, + int n_tokens, + std::vector & tokens_out) { + return gemma4_project_hidden(backend_, w_, hidden, n_tokens, tokens_out); +} + +int Gemma4DFlashTarget::mask_token_id() const { + // Gemma4 uses token ID 0 as padding/mask + return 0; +} + +const std::vector & Gemma4DFlashTarget::capture_layer_ids() const { + return capture_ids_; +} + +} // namespace dflash27b diff --git a/dflash/src/gemma4/gemma4_dflash_target.h b/dflash/src/gemma4/gemma4_dflash_target.h new file mode 100644 index 00000000..86a0d8bf --- /dev/null +++ b/dflash/src/gemma4/gemma4_dflash_target.h @@ -0,0 +1,63 @@ +// Gemma4DFlashTarget — DFlashTarget adapter for Gemma4 iSWA models. +// +// Wraps the Gemma4 target infrastructure (Gemma4Weights, Gemma4Cache, +// gemma4_step) behind the generic DFlashTarget interface so the universal +// DFlash draft model can drive speculative decode verification. + +#pragma once + +#include "common/dflash_target.h" +#include "gemma4_internal.h" + +#include "ggml.h" +#include "ggml-backend.h" + +#include + +namespace dflash27b { + +class Gemma4DFlashTarget : public DFlashTarget { +public: + // Non-owning references — caller must ensure lifetime. 
+ Gemma4DFlashTarget(Gemma4Weights & w, + Gemma4Cache & cache, + ggml_backend_t backend); + + ~Gemma4DFlashTarget() override; + + // ── DFlashTarget interface ────────────────────────────────────── + + bool verify_batch(const std::vector & tokens, + int base_pos, + int & last_tok, + std::vector * all_argmax = nullptr) override; + + bool snapshot_kv() override; + bool restore_kv() override; + + bool is_eos(int token) const override; + + bool embed_tokens(const int32_t * tokens, int n, + float * out) const override; + + bool project_hidden_to_tokens(const float * hidden, + int n_tokens, + std::vector & tokens_out) override; + + int hidden_size() const override { return w_.n_embd; } + int mask_token_id() const override; + const std::vector & capture_layer_ids() const override; + +private: + Gemma4Weights & w_; + Gemma4Cache & cache_; + ggml_backend_t backend_; + + // Capture layer IDs (built once in constructor). + std::vector capture_ids_; + + // Snapshot for speculative verify rollback. + Gemma4Snapshot verify_snap_; +}; + +} // namespace dflash27b diff --git a/dflash/src/gemma4/gemma4_graph.cpp b/dflash/src/gemma4/gemma4_graph.cpp index e218ac60..a83b22ba 100644 --- a/dflash/src/gemma4/gemma4_graph.cpp +++ b/dflash/src/gemma4/gemma4_graph.cpp @@ -493,4 +493,214 @@ bool gemma4_step( return true; } +// ── gemma4_verify_batch ───────────────────────────────────────────────── +// Like gemma4_step but returns argmax for ALL token positions (not just last). 
+ +bool gemma4_verify_batch( + ggml_backend_t backend, + const Gemma4Weights & w, + Gemma4Cache & cache, + const float * embed, + const int32_t * token_ids, + int n_tokens, + int kv_start, + std::vector & out_argmax) +{ + ggml_init_params ip{}; + ip.mem_size = ggml_tensor_overhead() * 16384 + ggml_graph_overhead() + 16 * 1024 * 1024; + ip.no_alloc = true; + ggml_context * ctx = ggml_init(ip); + ggml_cgraph * gf = ggml_new_graph_custom(ctx, 16384, false); + + // Input tensors + ggml_tensor * ie = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, w.n_embd, n_tokens); + ggml_set_input(ie); + ggml_tensor * pp = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens); + ggml_set_input(pp); + + // Token IDs for per-layer embedding + ggml_tensor * tok_ids = nullptr; + if (token_ids && w.per_layer_tok_embd && w.per_layer_model_proj && w.n_embd_per_layer > 0) { + tok_ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens); + ggml_set_input(tok_ids); + } + + // Attention masks (padded) + const int kv_len_raw = kv_start + n_tokens; + const int kv_len_padded = (kv_len_raw + 255) & ~255; + ggml_tensor * mk_full = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv_len_padded, n_tokens, 1, 1); + ggml_set_input(mk_full); + ggml_tensor * mk_full_f16 = ggml_cast(ctx, mk_full, GGML_TYPE_F16); + ggml_tensor * mk_swa = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv_len_padded, n_tokens, 1, 1); + ggml_set_input(mk_swa); + ggml_tensor * mk_swa_f16 = ggml_cast(ctx, mk_swa, GGML_TYPE_F16); + + // Per-layer embedding computation (same as gemma4_step) + ggml_tensor * per_layer_all = nullptr; + if (tok_ids) { + const int D = w.n_embd_per_layer; + const int L = w.n_layer; + ggml_tensor * inp_pl = ggml_get_rows(ctx, w.per_layer_tok_embd, tok_ids); + inp_pl = ggml_reshape_3d(ctx, inp_pl, D, L, n_tokens); + inp_pl = ggml_scale(ctx, inp_pl, std::sqrt((float)D)); + ggml_tensor * proj = ggml_mul_mat(ctx, w.per_layer_model_proj, ie); + proj = ggml_scale(ctx, proj, 1.0f / std::sqrt((float)w.n_embd)); + proj = ggml_reshape_3d(ctx, 
proj, D, L, n_tokens); + proj = ggml_rms_norm(ctx, proj, w.norm_eps); + ggml_tensor * norm_w = ggml_reshape_2d(ctx, w.per_layer_proj_norm, D, L); + proj = ggml_mul(ctx, proj, norm_w); + per_layer_all = ggml_add(ctx, proj, inp_pl); + per_layer_all = ggml_scale(ctx, per_layer_all, 1.0f / std::sqrt(2.0f)); + per_layer_all = ggml_cont(ctx, ggml_permute(ctx, per_layer_all, 0, 2, 1, 3)); + } + + // Build graph (all layers) + ggml_tensor * cur = ie; + for (int il = 0; il < w.n_layer; ++il) { + ggml_tensor * pl_input = nullptr; + if (per_layer_all) { + pl_input = gemma4_view_2d_slice(ctx, per_layer_all, il); + } + cur = build_gemma4_layer(ctx, gf, w, cache, il, cur, pp, + mk_full_f16, mk_swa_f16, pl_input, + kv_start, n_tokens); + } + + // Final norm + cur = gemma4_rms_norm_mul(ctx, cur, w.out_norm, w.norm_eps); + + // lm_head for ALL tokens (no slicing) + cur = ggml_mul_mat(ctx, w.output, cur); // [n_vocab, n_tokens] + + // Logit softcapping + if (w.final_logit_softcap > 0.0f) { + cur = ggml_scale(ctx, cur, 1.0f / w.final_logit_softcap); + cur = ggml_tanh(ctx, cur); + cur = ggml_scale(ctx, cur, w.final_logit_softcap); + } + + // Argmax per token + cur = ggml_argmax(ctx, cur); // [n_tokens] + ggml_set_output(cur); + ggml_build_forward_expand(gf, cur); + + // Allocate + static ggml_gallocr_t galloc_verify = nullptr; + if (!galloc_verify) galloc_verify = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + if (!ggml_gallocr_alloc_graph(galloc_verify, gf)) { + std::fprintf(stderr, "gemma4_verify_batch: gallocr_alloc_graph failed\n"); + ggml_free(ctx); + return false; + } + + // Set inputs + ggml_backend_tensor_set(ie, embed, 0, ggml_nbytes(ie)); + std::vector pos((size_t)n_tokens); + for (int i = 0; i < n_tokens; ++i) pos[i] = kv_start + i; + ggml_backend_tensor_set(pp, pos.data(), 0, ggml_nbytes(pp)); + + if (tok_ids && token_ids) { + ggml_backend_tensor_set(tok_ids, token_ids, 0, (size_t)n_tokens * sizeof(int32_t)); + } + + // Masks + std::vector 
mfull((size_t)kv_len_padded * n_tokens, -INFINITY); + for (int q = 0; q < n_tokens; ++q) { + const int abs_q = kv_start + q; + for (int k = 0; k <= abs_q && k < kv_len_raw; ++k) { + mfull[(size_t)q * kv_len_padded + k] = 0.0f; + } + } + ggml_backend_tensor_set(mk_full, mfull.data(), 0, ggml_nbytes(mk_full)); + + std::vector mswa((size_t)kv_len_padded * n_tokens, -INFINITY); + const int W = w.sliding_window; + for (int q = 0; q < n_tokens; ++q) { + const int abs_q = kv_start + q; + const int win_lo = std::max(0, abs_q - W + 1); + for (int k = win_lo; k <= abs_q && k < kv_len_raw; ++k) { + mswa[(size_t)q * kv_len_padded + k] = 0.0f; + } + } + ggml_backend_tensor_set(mk_swa, mswa.data(), 0, ggml_nbytes(mk_swa)); + + // Compute + if (ggml_backend_graph_compute(backend, gf) != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "gemma4_verify_batch: graph_compute failed\n"); + ggml_free(ctx); + return false; + } + + // Read argmax + out_argmax.resize(n_tokens); + ggml_backend_tensor_get(cur, out_argmax.data(), 0, sizeof(int32_t) * n_tokens); + + cache.cur_pos = kv_len_raw; + ggml_free(ctx); + return true; +} + +// ── gemma4_project_hidden ─────────────────────────────────────────────── +// Runs out_norm + lm_head + softcap + argmax on external hidden states. 
+ +bool gemma4_project_hidden( + ggml_backend_t backend, + const Gemma4Weights & w, + const float * hidden, + int n_tokens, + std::vector & out_tokens) +{ + ggml_init_params ip{}; + ip.mem_size = ggml_tensor_overhead() * 64 + ggml_graph_overhead() + 1024 * 1024; + ip.no_alloc = true; + ggml_context * ctx = ggml_init(ip); + ggml_cgraph * gf = ggml_new_graph(ctx); + + // Input: hidden states [n_embd, n_tokens] + ggml_tensor * inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, w.n_embd, n_tokens); + ggml_set_input(inp); + + // out_norm + lm_head + ggml_tensor * cur = gemma4_rms_norm_mul(ctx, inp, w.out_norm, w.norm_eps); + cur = ggml_mul_mat(ctx, w.output, cur); // [n_vocab, n_tokens] + + // Logit softcapping + if (w.final_logit_softcap > 0.0f) { + cur = ggml_scale(ctx, cur, 1.0f / w.final_logit_softcap); + cur = ggml_tanh(ctx, cur); + cur = ggml_scale(ctx, cur, w.final_logit_softcap); + } + + // Argmax + cur = ggml_argmax(ctx, cur); // [n_tokens] + ggml_set_output(cur); + ggml_build_forward_expand(gf, cur); + + // Allocate + static ggml_gallocr_t galloc_proj = nullptr; + if (!galloc_proj) galloc_proj = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + if (!ggml_gallocr_alloc_graph(galloc_proj, gf)) { + std::fprintf(stderr, "gemma4_project_hidden: gallocr_alloc_graph failed\n"); + ggml_free(ctx); + return false; + } + + // Set input + ggml_backend_tensor_set(inp, hidden, 0, sizeof(float) * (size_t)n_tokens * w.n_embd); + + // Compute + if (ggml_backend_graph_compute(backend, gf) != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "gemma4_project_hidden: graph_compute failed\n"); + ggml_free(ctx); + return false; + } + + // Read result + out_tokens.resize(n_tokens); + ggml_backend_tensor_get(cur, out_tokens.data(), 0, sizeof(int32_t) * n_tokens); + + ggml_free(ctx); + return true; +} + } // namespace dflash::common diff --git a/dflash/src/gemma4/gemma4_internal.h b/dflash/src/gemma4/gemma4_internal.h index dd165a15..f92f8df0 100644 --- 
a/dflash/src/gemma4/gemma4_internal.h +++ b/dflash/src/gemma4/gemma4_internal.h @@ -198,4 +198,25 @@ bool gemma4_step( int kv_start, std::vector & out_logits); +// Verify batch: run forward pass returning argmax for ALL positions. +// Used by DFlash speculative decode target. +bool gemma4_verify_batch( + ggml_backend_t backend, + const Gemma4Weights & w, + Gemma4Cache & cache, + const float * embed, + const int32_t * token_ids, + int n_tokens, + int kv_start, + std::vector & out_argmax); + +// Project hidden states through lm_head (out_norm + output + softcap + argmax). +// Used by DFlash draft to convert draft hidden states to token IDs. +bool gemma4_project_hidden( + ggml_backend_t backend, + const Gemma4Weights & w, + const float * hidden, + int n_tokens, + std::vector & out_tokens); + } // namespace dflash::common From 9b26a2b170a1f85af7016e1c008ee849ada3ec30 Mon Sep 17 00:00:00 2001 From: Howard Su Date: Wed, 20 May 2026 15:17:57 +0800 Subject: [PATCH 03/18] gemma4: fix attention scale, tokenizer decode, and server integration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Critical fixes for Gemma4 model inference: - Fix kq_scale: Gemma4 uses self.scaling=1.0 (not 1/sqrt(head_dim)) because Q/K already get per-head RMS norm. This was the root cause of garbage output (repeated token generation). - Add SentencePiece tokenizer support: Gemma4 tokens are raw UTF-8 with U+2581 for space, not GPT-2 byte-level encoding. Detects mode from tokenizer.ggml.model GGUF key. Handles encode (space->▁, UTF-8 char splitting) and decode (▁->space) correctly. - Fix KV cache layout: [D, max_ctx, Hk] matching Qwen35 convention, with per-head strided snapshot save/restore. - Add Gemma4 chat template: <|turn>user\n...<turn|>\n<|turn>model\n - Map Gemma4 thinking channel (<|channel>...<channel|>) to existing <think>...</think> reasoning system for proper content separation. - Add eos_chat_id detection for <turn|> token (id 106). 
- Fix special token filtering in both streaming and non-streaming paths. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dflash/src/gemma4/gemma4_backend.cpp | 54 ++++++++--- dflash/src/gemma4/gemma4_graph.cpp | 4 +- dflash/src/gemma4/gemma4_loader.cpp | 6 +- dflash/src/server/chat_template.cpp | 37 ++++++- dflash/src/server/chat_template.h | 1 + dflash/src/server/http_server.cpp | 39 +++++++- dflash/src/server/http_server.h | 3 + dflash/src/server/server_main.cpp | 3 + dflash/src/server/tokenizer.cpp | 138 +++++++++++++++++++++------ dflash/src/server/tokenizer.h | 5 + 10 files changed, 243 insertions(+), 47 deletions(-) diff --git a/dflash/src/gemma4/gemma4_backend.cpp b/dflash/src/gemma4/gemma4_backend.cpp index 922f7051..c699675e 100644 --- a/dflash/src/gemma4/gemma4_backend.cpp +++ b/dflash/src/gemma4/gemma4_backend.cpp @@ -258,12 +258,27 @@ GenerateResult Gemma4Backend::restore_and_generate(int slot, } const auto & snap = snapshots_[slot]; - // Copy right-sized snapshot into full-size cache (position is outermost dim). + // Restore snapshot into cache per-head (cache: [D, max_ctx, Hk]). 
for (int il = 0; il < cache_.n_layer; ++il) { if (cache_.k[il] && snap.k_snap[il]) { - const size_t nbytes = ggml_nbytes(snap.k_snap[il]); - ggml_backend_tensor_set(cache_.k[il], snap.k_snap[il]->data, 0, nbytes); - ggml_backend_tensor_set(cache_.v[il], snap.v_snap[il]->data, 0, nbytes); + ggml_tensor * ck = cache_.k[il]; + const int D = (int)ck->ne[0]; + const int Hk = (int)ck->ne[2]; + const int max_ctx = (int)ck->ne[1]; + const int spos = snap.cur_pos; + const size_t elem_sz = ggml_element_size(ck); + const size_t head_bytes_src = (size_t)D * spos * elem_sz; + const size_t head_bytes_dst = (size_t)D * max_ctx * elem_sz; + const size_t copy_bytes = head_bytes_src; + + for (int h = 0; h < Hk; ++h) { + ggml_backend_tensor_set(cache_.k[il], + (const char *)snap.k_snap[il]->data + h * head_bytes_src, + h * head_bytes_dst, copy_bytes); + ggml_backend_tensor_set(cache_.v[il], + (const char *)snap.v_snap[il]->data + h * head_bytes_src, + h * head_bytes_dst, copy_bytes); + } } } @@ -378,12 +393,13 @@ bool Gemma4Backend::snapshot_save(int slot) { snap.v_snap.resize(n_layer, nullptr); for (int il = 0; il < n_layer; ++il) { if (cache_.k[il]) { - // Right-sized: [D, Hk, snap_pos] instead of [D, Hk, max_ctx] + // Cache layout: [D, max_ctx, Hk] + // Snapshot: [D, snap_pos, Hk] — same axis order, truncated positions ggml_tensor * ck = cache_.k[il]; snap.k_snap[il] = ggml_new_tensor_3d(snap.ctx, ck->type, - ck->ne[0], ck->ne[1], snap_pos); + ck->ne[0], snap_pos, ck->ne[2]); snap.v_snap[il] = ggml_new_tensor_3d(snap.ctx, ck->type, - ck->ne[0], ck->ne[1], snap_pos); + ck->ne[0], snap_pos, ck->ne[2]); } } @@ -395,12 +411,28 @@ bool Gemma4Backend::snapshot_save(int slot) { } } - // Copy first snap_pos positions (contiguous — position is outermost dim). + // Copy snap_pos positions per head. 
+ // Cache: [D, max_ctx, Hk], Snap: [D, snap_pos, Hk] + // Per head h: copy D*snap_pos elements from cache offset h*D*max_ctx to snap offset h*D*snap_pos for (int il = 0; il < n_layer; ++il) { if (cache_.k[il] && snap.k_snap[il]) { - const size_t nbytes = ggml_nbytes(snap.k_snap[il]); - ggml_backend_tensor_get(cache_.k[il], snap.k_snap[il]->data, 0, nbytes); - ggml_backend_tensor_get(cache_.v[il], snap.v_snap[il]->data, 0, nbytes); + ggml_tensor * ck = cache_.k[il]; + const int D = (int)ck->ne[0]; + const int Hk = (int)ck->ne[2]; + const int max_ctx = (int)ck->ne[1]; + const size_t elem_sz = ggml_element_size(ck); + const size_t head_bytes_src = (size_t)D * max_ctx * elem_sz; + const size_t head_bytes_dst = (size_t)D * snap_pos * elem_sz; + const size_t copy_bytes = head_bytes_dst; // D * snap_pos * elem_sz per head + + for (int h = 0; h < Hk; ++h) { + ggml_backend_tensor_get(cache_.k[il], + (char *)snap.k_snap[il]->data + h * head_bytes_dst, + h * head_bytes_src, copy_bytes); + ggml_backend_tensor_get(cache_.v[il], + (char *)snap.v_snap[il]->data + h * head_bytes_dst, + h * head_bytes_src, copy_bytes); + } } } snap.cur_pos = snap_pos; diff --git a/dflash/src/gemma4/gemma4_graph.cpp b/dflash/src/gemma4/gemma4_graph.cpp index a83b22ba..33fe591c 100644 --- a/dflash/src/gemma4/gemma4_graph.cpp +++ b/dflash/src/gemma4/gemma4_graph.cpp @@ -230,7 +230,9 @@ static ggml_tensor * build_gemma4_attn_block( head_dim, kv_len, n_head_kv, cache_v->nb[1], cache_v->nb[2], 0); - const float kq_scale = 1.0f / std::sqrt((float)head_dim); + // Gemma4 uses self.scaling = 1.0 (no QK scaling) because Q/K are already + // RMS-normed per-head. Standard 1/sqrt(head_dim) is NOT used here. + const float kq_scale = 1.0f; ggml_tensor * use_mask = is_swa ? 
attn_mask_swa : attn_mask_full; ggml_tensor * attn = ggml_flash_attn_ext(ctx, Qfa, Kfa, Vfa, use_mask, kq_scale, 0.0f, 0.0f); diff --git a/dflash/src/gemma4/gemma4_loader.cpp b/dflash/src/gemma4/gemma4_loader.cpp index 62cab9c8..dae77815 100644 --- a/dflash/src/gemma4/gemma4_loader.cpp +++ b/dflash/src/gemma4/gemma4_loader.cpp @@ -391,8 +391,10 @@ bool create_gemma4_cache(ggml_backend_t backend, const Gemma4Weights & w, if (w.has_kv[il]) { const int D = gemma4_head_dim(w, il); const int Hk = gemma4_n_head_kv(w, il); - out.k[il] = ggml_new_tensor_3d(out.ctx, GGML_TYPE_F16, D, Hk, max_ctx); - out.v[il] = ggml_new_tensor_3d(out.ctx, GGML_TYPE_F16, D, Hk, max_ctx); + // Layout: [head_dim, max_ctx, n_head_kv] — positions before heads + // (matches the view strides used in build_gemma4_attn_block) + out.k[il] = ggml_new_tensor_3d(out.ctx, GGML_TYPE_F16, D, max_ctx, Hk); + out.v[il] = ggml_new_tensor_3d(out.ctx, GGML_TYPE_F16, D, max_ctx, Hk); out.kv_source[il] = il; last_kv_layer = il; } else { diff --git a/dflash/src/server/chat_template.cpp b/dflash/src/server/chat_template.cpp index 92c46588..d7c026ab 100644 --- a/dflash/src/server/chat_template.cpp +++ b/dflash/src/server/chat_template.cpp @@ -36,7 +36,8 @@ static const char QWEN3_TOOL_SUFFIX[] = ChatFormat chat_format_for_arch(const std::string & arch) { if (arch == "laguna") return ChatFormat::LAGUNA; - // qwen35, qwen3, gemma4 all use the Qwen3/ChatML format + if (arch == "gemma4") return ChatFormat::GEMMA4; + // qwen35, qwen3 use the Qwen3/ChatML format return ChatFormat::QWEN3; } @@ -150,6 +151,40 @@ std::string render_chat_template( } break; } + + case ChatFormat::GEMMA4: { + // Gemma4 format: + // <|turn>user\n{msg}\n<|turn>model\n + // System messages are prepended to the first user message. 
+ result = ""; + std::string system_content; + size_t start_idx = 0; + if (!messages.empty() && messages[0].role == "system") { + system_content = messages[0].content; + start_idx = 1; + } + + for (size_t i = start_idx; i < messages.size(); i++) { + const auto & msg = messages[i]; + std::string role = msg.role; + if (role == "assistant") role = "model"; + + result += "<|turn>"; + result += role; + result += '\n'; + // Inject system content at the start of the first user message. + if (i == start_idx && !system_content.empty() && msg.role == "user") { + result += system_content; + result += "\n\n"; + } + result += msg.content; + result += "\n"; + } + if (add_generation_prompt) { + result += "<|turn>model\n"; + } + break; + } } return result; diff --git a/dflash/src/server/chat_template.h b/dflash/src/server/chat_template.h index 5f35f492..e1ea2134 100644 --- a/dflash/src/server/chat_template.h +++ b/dflash/src/server/chat_template.h @@ -24,6 +24,7 @@ struct ChatMessage { enum class ChatFormat { QWEN3, // <|im_start|>role\n...<|im_end|>\n LAGUNA, // <|begin_of_sentence|><|User|>...<|Assistant|> + GEMMA4, // <|turn>role\n...\n }; // Render chat messages into the model-specific prompt string. diff --git a/dflash/src/server/http_server.cpp b/dflash/src/server/http_server.cpp index 5d61da30..d1114081 100644 --- a/dflash/src/server/http_server.cpp +++ b/dflash/src/server/http_server.cpp @@ -443,7 +443,6 @@ bool HttpServer::route_request(int fd, const HttpRequest & hr) { true, enable_thinking, tools_json); req.prompt_tokens = tokenizer_.encode(rendered); - // Detect if prompt ends with (model will start in reasoning mode). if (enable_thinking) { size_t end = rendered.size(); @@ -691,11 +690,35 @@ void HttpServer::worker_loop() { // Skip EOS/EOT/special tokens — don't forward to SSE. int32_t eos = tokenizer_.eos_id(); - if (token == eos) return true; - // Also skip common Qwen3 special tokens by checking if the raw - // token text starts with '<|' (e.g. 
<|im_end|>, <|im_start|>). + int32_t eot = tokenizer_.eos_chat_id(); + if (token == eos || token == eot) return true; + const std::string & raw = tokenizer_.raw_token(token); + + // Gemma4 thinking channel: map <|channel> → , \n + if (raw == "<|channel>") { + if (req.stream) { + auto chunks = emitter.emit_token(""); + for (const auto & chunk : chunks) + if (!send_all(fd, chunk.data(), chunk.size())) { client_disconnected = true; return false; } + } + return true; + } + if (raw == "") { + if (req.stream) { + auto chunks = emitter.emit_token("\n"); + for (const auto & chunk : chunks) + if (!send_all(fd, chunk.data(), chunk.size())) { client_disconnected = true; return false; } + } + return true; + } + + // Skip other special tokens (starting with <|, or any <...> except byte-fallback) if (raw.size() >= 2 && raw[0] == '<' && raw[1] == '|') return true; + if (raw.size() >= 2 && raw[0] == '<' && raw.back() == '>') { + if (!(raw.size() == 6 && raw[1] == '0' && raw[2] == 'x')) + return true; + } std::string text = tokenizer_.token_text(token); @@ -782,7 +805,15 @@ void HttpServer::worker_loop() { for (int32_t tok : result.tokens) { const std::string & raw = tokenizer_.raw_token(tok); if (tok == tokenizer_.eos_id()) continue; + if (tok == tokenizer_.eos_chat_id()) continue; + // Gemma4 channel → think mapping + if (raw == "<|channel>") { emitter.emit_token(""); continue; } + if (raw == "") { emitter.emit_token("\n"); continue; } if (raw.size() >= 2 && raw[0] == '<' && raw[1] == '|') continue; + if (raw.size() >= 2 && raw[0] == '<' && raw.back() == '>') { + if (!(raw.size() == 6 && raw[1] == '0' && raw[2] == 'x')) + continue; + } std::string text = tokenizer_.token_text(tok); emitter.emit_token(text); } diff --git a/dflash/src/server/http_server.h b/dflash/src/server/http_server.h index 8c0ec9eb..cc060b98 100644 --- a/dflash/src/server/http_server.h +++ b/dflash/src/server/http_server.h @@ -99,6 +99,9 @@ class HttpServer { // Set the optional pflash drafter tokenizer. 
void set_drafter_tokenizer(Tokenizer * tok) { drafter_tokenizer_ = tok; } + // Set the chat template format (detected from model arch). + void set_chat_format(ChatFormat fmt) { chat_format_ = fmt; } + // Start listening. Blocks until shutdown() is called. int run(); diff --git a/dflash/src/server/server_main.cpp b/dflash/src/server/server_main.cpp index 319f97de..12f0bc4e 100644 --- a/dflash/src/server/server_main.cpp +++ b/dflash/src/server/server_main.cpp @@ -12,6 +12,7 @@ // [--max-tokens 4096] [--gpu 0] #include "http_server.h" +#include "chat_template.h" #include "common/backend_factory.h" #include "common/gguf_inspect.h" @@ -223,6 +224,7 @@ int main(int argc, char ** argv) { // Create backend. std::fprintf(stderr, "[server] creating backend...\n"); + const std::string arch = detect_arch(bargs.model_path); auto backend = create_backend(bargs); if (!backend) { std::fprintf(stderr, "[server] backend creation failed\n"); @@ -272,6 +274,7 @@ int main(int argc, char ** argv) { std::fprintf(stderr, "[server] ╰─────────────────────────────────────────────────────╯\n\n"); HttpServer server(*backend, tokenizer, sconfig); + server.set_chat_format(chat_format_for_arch(arch)); g_server = &server; std::signal(SIGTERM, signal_handler); std::signal(SIGINT, signal_handler); diff --git a/dflash/src/server/tokenizer.cpp b/dflash/src/server/tokenizer.cpp index 1f538682..5ff4b1a7 100644 --- a/dflash/src/server/tokenizer.cpp +++ b/dflash/src/server/tokenizer.cpp @@ -364,34 +364,76 @@ static std::string encode_gpt2_bpe(const std::string & text) { std::vector Tokenizer::bpe_encode_piece(const std::string & piece) const { if (piece.empty()) return {}; - // Convert raw text to GPT-2 byte encoding for vocab lookup. - // The GGUF vocab stores tokens in GPT-2's byte-to-unicode encoding, - // so " world" becomes "Ġworld" (space 0x20 → Ġ U+0120). - std::string encoded = encode_gpt2_bpe(piece); - - // Start with individual bytes/chars as initial symbols. 
- // Each symbol is a string that we look up in the vocab. std::vector symbols; - // Try to find the encoded piece as a single token first. - auto it = token_to_id_.find(encoded); - if (it != token_to_id_.end()) { - return { it->second }; - } + if (is_sentencepiece_) { + // SentencePiece: replace leading space with ▁, tokens are raw UTF-8. + std::string sp_piece; + sp_piece.reserve(piece.size() + 2); + size_t start = 0; + if (!piece.empty() && piece[0] == ' ') { + sp_piece += "\xe2\x96\x81"; // ▁ (U+2581) + start = 1; + } + sp_piece += piece.substr(start); + // Replace any remaining spaces with ▁ + std::string encoded; + encoded.reserve(sp_piece.size()); + for (char c : sp_piece) { + if (c == ' ') { + encoded += "\xe2\x96\x81"; + } else { + encoded += c; + } + } - // Split into individual GPT-2-encoded bytes as initial BPE symbols. - // Each raw byte becomes a single GPT-2 Unicode character (possibly multi-byte UTF-8). - for (size_t i = 0; i < piece.size(); i++) { - std::string sym = byte_to_gpt2_unicode((uint8_t)piece[i]); - auto sit = token_to_id_.find(sym); - if (sit != token_to_id_.end()) { - symbols.push_back(sym); - } else { - // Byte-fallback: use <0xNN> tokens - char buf[8]; - std::snprintf(buf, sizeof(buf), "<0x%02X>", - (unsigned)(uint8_t)piece[i]); - symbols.push_back(buf); + // Try whole piece as single token. + auto it = token_to_id_.find(encoded); + if (it != token_to_id_.end()) { + return { it->second }; + } + + // Split into individual UTF-8 characters as initial BPE symbols. 
+ const char * p = encoded.c_str(); + const char * end = p + encoded.size(); + while (p < end) { + int cplen; + utf8_decode(p, (size_t)(end - p), &cplen); + if (cplen <= 0) cplen = 1; + std::string sym(p, cplen); + auto sit = token_to_id_.find(sym); + if (sit != token_to_id_.end()) { + symbols.push_back(sym); + } else { + // Byte-fallback: <0xNN> + char buf[8]; + std::snprintf(buf, sizeof(buf), "<0x%02X>", (unsigned)(uint8_t)*p); + symbols.push_back(buf); + } + p += cplen; + } + } else { + // GPT-2 BPE: convert raw text to GPT-2 byte encoding for vocab lookup. + std::string encoded = encode_gpt2_bpe(piece); + + // Try to find the encoded piece as a single token first. + auto it = token_to_id_.find(encoded); + if (it != token_to_id_.end()) { + return { it->second }; + } + + // Split into individual GPT-2-encoded bytes as initial BPE symbols. + for (size_t i = 0; i < piece.size(); i++) { + std::string sym = byte_to_gpt2_unicode((uint8_t)piece[i]); + auto sit = token_to_id_.find(sym); + if (sit != token_to_id_.end()) { + symbols.push_back(sym); + } else { + char buf[8]; + std::snprintf(buf, sizeof(buf), "<0x%02X>", + (unsigned)(uint8_t)piece[i]); + symbols.push_back(buf); + } } } @@ -538,6 +580,18 @@ bool Tokenizer::load_from_gguf(const char * model_path) { added_tokens_.size()); } + // Detect tokenizer model type (sentencepiece vs bpe). + int model_key = gguf_find_key(gctx, "tokenizer.ggml.model"); + if (model_key >= 0) { + const char * model = gguf_get_val_str(gctx, model_key); + // SentencePiece models store tokens as raw UTF-8 with ▁ for space. + // GPT-2/BPE models use byte-level Unicode encoding. + if (model && (std::strcmp(model, "llama") == 0 || + std::strncmp(model, "gemma", 5) == 0)) { + is_sentencepiece_ = true; + } + } + // Detect pre-tokenizer type. 
int pre_key = gguf_find_key(gctx, "tokenizer.ggml.pre"); if (pre_key >= 0) { @@ -564,12 +618,18 @@ bool Tokenizer::load_from_gguf(const char * model_path) { auto eot = token_to_id_.find("<|im_end|>"); if (eot != token_to_id_.end()) eos_chat_id_ = eot->second; } + if (eos_chat_id_ < 0) { + // Gemma4 uses as end-of-turn. + auto eot = token_to_id_.find(""); + if (eot != token_to_id_.end()) eos_chat_id_ = eot->second; + } gguf_free(gctx); - std::fprintf(stderr, "[tokenizer] loaded vocab=%d merges=%zu bos=%d eos=%d eot=%d pre=%s\n", + std::fprintf(stderr, "[tokenizer] loaded vocab=%d merges=%zu bos=%d eos=%d eot=%d pre=%s sp=%s\n", n_vocab, merge_rank_.size(), bos_id_, eos_id_, eos_chat_id_, - pre_type_ == PreTokenizer::QWEN35 ? "qwen35" : "qwen2"); + pre_type_ == PreTokenizer::QWEN35 ? "qwen35" : "qwen2", + is_sentencepiece_ ? "yes" : "no"); return true; } @@ -686,11 +746,33 @@ std::string Tokenizer::token_text(int32_t id) const { } } - // Special tokens (e.g. <|im_start|>) — return as-is. + // Special tokens (e.g. <|im_start|>, ) — return as-is. if (!tok.empty() && tok[0] == '<' && tok.back() == '>') { return tok; } + if (is_sentencepiece_) { + // SentencePiece: tokens are raw UTF-8 with ▁ (U+2581) for space. + std::string out; + out.reserve(tok.size()); + const char * p = tok.c_str(); + const char * end = p + tok.size(); + while (p < end) { + // ▁ is 3 bytes: 0xE2 0x96 0x81 + if (end - p >= 3 && + (uint8_t)p[0] == 0xE2 && + (uint8_t)p[1] == 0x96 && + (uint8_t)p[2] == 0x81) { + out.push_back(' '); + p += 3; + } else { + out.push_back(*p); + p++; + } + } + return out; + } + // Decode GPT-2 byte-level BPE encoding → raw bytes. 
return decode_gpt2_bpe(tok); } diff --git a/dflash/src/server/tokenizer.h b/dflash/src/server/tokenizer.h index f28dfd9f..5484fa47 100644 --- a/dflash/src/server/tokenizer.h +++ b/dflash/src/server/tokenizer.h @@ -46,6 +46,7 @@ class Tokenizer { // ─── Special tokens ────────────────────────────────────────────── int32_t eos_id() const { return eos_id_; } + int32_t eos_chat_id() const { return eos_chat_id_; } int32_t bos_id() const { return bos_id_; } int32_t vocab_size() const { return (int32_t)id_to_token_.size(); } @@ -80,6 +81,10 @@ class Tokenizer { // Pre-tokenizer type enum class PreTokenizer { QWEN2, QWEN35 }; PreTokenizer pre_type_ = PreTokenizer::QWEN35; + + // Decode mode: SentencePiece tokens use UTF-8 with ▁ for space; + // GPT-2/BPE tokens use byte-level Unicode encoding. + bool is_sentencepiece_ = false; }; } // namespace dflash::common From f99ff75becf8df23527523368ddf8fa7bda22817 Mon Sep 17 00:00:00 2001 From: Howard Su Date: Wed, 20 May 2026 15:23:48 +0800 Subject: [PATCH 04/18] gemma4: implement real park/unpark for VRAM management park() now frees snapshots, KV cache, and model weights (releasing GPU memory). unpark() reloads weights from disk and recreates the KV cache. Also adds parked guards to generate(), restore_and_generate(), and snapshot_save() to prevent use while model is parked. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dflash/src/gemma4/gemma4_backend.cpp | 37 ++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/dflash/src/gemma4/gemma4_backend.cpp b/dflash/src/gemma4/gemma4_backend.cpp index c699675e..7f949ddb 100644 --- a/dflash/src/gemma4/gemma4_backend.cpp +++ b/dflash/src/gemma4/gemma4_backend.cpp @@ -69,15 +69,43 @@ void Gemma4Backend::print_ready_banner() const { bool Gemma4Backend::park(const std::string & what) { (void)what; + if (parked_) return true; + + // Free snapshots first (they reference the snap_backend buffer) + for (int i = 0; i < PREFIX_SLOTS; ++i) { + free_gemma4_snapshot(snapshots_[i]); + } + + // Free KV cache (GPU memory) + free_gemma4_cache(cache_); + + // Free model weights (GPU memory) + free_gemma4_weights(w_); + parked_ = true; - std::printf("[gemma4] parked\n"); std::fflush(stdout); + std::printf("[gemma4] parked (VRAM released)\n"); std::fflush(stdout); return true; } bool Gemma4Backend::unpark(const std::string & what) { (void)what; + if (!parked_) return true; + + // Reload weights from disk + if (!load_gemma4_gguf(cfg_.model_path, backend_, w_)) { + std::fprintf(stderr, "[gemma4] unpark: failed to reload weights\n"); + return false; + } + + // Recreate KV cache + if (!create_gemma4_cache(backend_, w_, cfg_.device.max_ctx, cache_)) { + std::fprintf(stderr, "[gemma4] unpark: failed to recreate cache\n"); + free_gemma4_weights(w_); + return false; + } + parked_ = false; - std::printf("[gemma4] unparked\n"); std::fflush(stdout); + std::printf("[gemma4] unparked (VRAM restored)\n"); std::fflush(stdout); return true; } @@ -172,6 +200,8 @@ bool Gemma4Backend::do_decode(int committed, int n_gen, GenerateResult Gemma4Backend::generate(const GenerateRequest & req, const DaemonIO & io) { GenerateResult result; + if (parked_) { result.error = "model is parked"; return result; } + DaemonIO out_io = io.with_token_callback(req.on_token); sampler_ = 
req.sampler; if (req.do_sample && sampler_.seed != 0) { @@ -249,6 +279,8 @@ GenerateResult Gemma4Backend::restore_and_generate(int slot, const GenerateRequest & req, const DaemonIO & io) { GenerateResult result; + if (parked_) { result.error = "model is parked"; return result; } + DaemonIO out_io = io.with_token_callback(req.on_token); if (slot < 0 || slot >= PREFIX_SLOTS || !snapshots_[slot].ctx) { @@ -371,6 +403,7 @@ GenerateResult Gemma4Backend::restore_and_generate(int slot, // ── Snapshots ────────────────────────────────────────────────────────── bool Gemma4Backend::snapshot_save(int slot) { + if (parked_) return false; if (slot < 0 || slot >= PREFIX_SLOTS) return false; auto & snap = snapshots_[slot]; From c4a7ba6faa2a686b2cad9cc13c5e7144721d2b37 Mon Sep 17 00:00:00 2001 From: Howard Su Date: Wed, 20 May 2026 16:01:18 +0800 Subject: [PATCH 05/18] gemma4: implement G5 SWA ring-buffer, G6 fa_window, G3 compress G5: SWA layers now allocate min(sliding_window, max_ctx) KV cache instead of full max_ctx. Ring-buffer write (kv_start % swa_size) and ring-aware attention mask enable bounded memory for sliding-window layers. Prefill chunks are capped to avoid ring wrap. G6: Added fa_window config for sparse decode. Full-attention layers limit their FA read to the last fa_window positions during decode, reducing compute at long contexts. G3: Ported PFlash compress pipeline from Qwen35. Parks target, lazy-loads Qwen3-0.6B drafter, runs score_and_compress, emits surviving tokens, unparks. Drafter stays resident (~1.4 GB). 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dflash/src/gemma4/gemma4_backend.cpp | 114 +++++++++++++++++++++------ dflash/src/gemma4/gemma4_backend.h | 6 ++ dflash/src/gemma4/gemma4_graph.cpp | 79 ++++++++++++++----- dflash/src/gemma4/gemma4_internal.h | 2 + dflash/src/gemma4/gemma4_loader.cpp | 13 ++- dflash/src/server/server_main.cpp | 1 + 6 files changed, 170 insertions(+), 45 deletions(-) diff --git a/dflash/src/gemma4/gemma4_backend.cpp b/dflash/src/gemma4/gemma4_backend.cpp index 7f949ddb..114ca158 100644 --- a/dflash/src/gemma4/gemma4_backend.cpp +++ b/dflash/src/gemma4/gemma4_backend.cpp @@ -50,6 +50,7 @@ bool Gemma4Backend::init() { std::fprintf(stderr, "[gemma4] cache alloc failed\n"); return false; } + cache_.fa_window = cfg_.fa_window; std::printf("[gemma4] init ok: %d layers, embd=%d, vocab=%d, max_ctx=%d\n", w_.n_layer, w_.n_embd, w_.n_vocab, cfg_.device.max_ctx); @@ -125,6 +126,12 @@ int Gemma4Backend::do_prefill(const std::vector & tokens, while (pos < n) { int len = std::min(chunk, n - pos); + // Limit chunk to avoid ring-buffer wrap for SWA layers + if (cache_.swa_size > 0 && cache_.swa_size < cache_.max_ctx) { + const int swa_remaining = cache_.swa_size - ((kv_offset + pos) % cache_.swa_size); + len = std::min(len, swa_remaining); + } + // Embed tokens using CPU embedder w_.embedder.embed(tokens.data() + pos, len, embed.data()); @@ -290,17 +297,17 @@ GenerateResult Gemma4Backend::restore_and_generate(int slot, } const auto & snap = snapshots_[slot]; - // Restore snapshot into cache per-head (cache: [D, max_ctx, Hk]). + // Restore snapshot into cache per-head (cache: [D, cache_len, Hk]). 
for (int il = 0; il < cache_.n_layer; ++il) { if (cache_.k[il] && snap.k_snap[il]) { ggml_tensor * ck = cache_.k[il]; const int D = (int)ck->ne[0]; const int Hk = (int)ck->ne[2]; - const int max_ctx = (int)ck->ne[1]; - const int spos = snap.cur_pos; + const int cache_len = (int)ck->ne[1]; + const int save_pos = (int)snap.k_snap[il]->ne[1]; // min(snap.cur_pos, cache_len) const size_t elem_sz = ggml_element_size(ck); - const size_t head_bytes_src = (size_t)D * spos * elem_sz; - const size_t head_bytes_dst = (size_t)D * max_ctx * elem_sz; + const size_t head_bytes_src = (size_t)D * save_pos * elem_sz; + const size_t head_bytes_dst = (size_t)D * cache_len * elem_sz; const size_t copy_bytes = head_bytes_src; for (int h = 0; h < Hk; ++h) { @@ -426,13 +433,14 @@ bool Gemma4Backend::snapshot_save(int slot) { snap.v_snap.resize(n_layer, nullptr); for (int il = 0; il < n_layer; ++il) { if (cache_.k[il]) { - // Cache layout: [D, max_ctx, Hk] - // Snapshot: [D, snap_pos, Hk] — same axis order, truncated positions ggml_tensor * ck = cache_.k[il]; + const int cache_len = (int)ck->ne[1]; + // Save min(snap_pos, cache_len) positions + const int save_pos = std::min(snap_pos, cache_len); snap.k_snap[il] = ggml_new_tensor_3d(snap.ctx, ck->type, - ck->ne[0], snap_pos, ck->ne[2]); + ck->ne[0], save_pos, ck->ne[2]); snap.v_snap[il] = ggml_new_tensor_3d(snap.ctx, ck->type, - ck->ne[0], snap_pos, ck->ne[2]); + ck->ne[0], save_pos, ck->ne[2]); } } @@ -444,19 +452,19 @@ bool Gemma4Backend::snapshot_save(int slot) { } } - // Copy snap_pos positions per head. - // Cache: [D, max_ctx, Hk], Snap: [D, snap_pos, Hk] - // Per head h: copy D*snap_pos elements from cache offset h*D*max_ctx to snap offset h*D*snap_pos + // Copy valid positions per head. 
+ // Cache: [D, cache_len, Hk], Snap: [D, save_pos, Hk] for (int il = 0; il < n_layer; ++il) { if (cache_.k[il] && snap.k_snap[il]) { ggml_tensor * ck = cache_.k[il]; const int D = (int)ck->ne[0]; const int Hk = (int)ck->ne[2]; - const int max_ctx = (int)ck->ne[1]; + const int cache_len = (int)ck->ne[1]; + const int save_pos = std::min(snap_pos, cache_len); const size_t elem_sz = ggml_element_size(ck); - const size_t head_bytes_src = (size_t)D * max_ctx * elem_sz; - const size_t head_bytes_dst = (size_t)D * snap_pos * elem_sz; - const size_t copy_bytes = head_bytes_dst; // D * snap_pos * elem_sz per head + const size_t head_bytes_src = (size_t)D * cache_len * elem_sz; + const size_t head_bytes_dst = (size_t)D * save_pos * elem_sz; + const size_t copy_bytes = head_bytes_dst; for (int h = 0; h < Hk; ++h) { ggml_backend_tensor_get(cache_.k[il], @@ -493,15 +501,76 @@ int Gemma4Backend::snapshot_cur_pos(int slot) const { bool Gemma4Backend::handle_compress(const std::string & line, const DaemonIO & io) { - (void)line; (void)io; - // Gemma4 doesn't use pflash drafter for compression (yet). - std::printf("[gemma4] compress: not supported\n"); - std::fflush(stdout); - return true; + // Check for "nopark" suffix + bool skip_park = (line.size() >= 16 && + line.compare(line.size() - 7, 7, " nopark") == 0); + + // Parse: "compress [nopark]" + char ppath[1024]; + int keep_x1000 = 0; + char drafter_path[1024] = {0}; + const int n = std::sscanf(line.c_str() + 9, "%1023s %d %1023s", + ppath, &keep_x1000, drafter_path); + if (n < 2) { + std::fprintf(stderr, "[compress] bad args\n"); + io.emit(-1); + return false; + } + + const char * dpath = (n >= 3 && drafter_path[0]) + ? drafter_path + : "/opt/lucebox/models/drafter/Qwen3-0.6B-BF16.gguf"; + + // Park target to free VRAM for the drafter (unless skip_park). 
+ const bool was_parked = parked_; + if (!skip_park && !parked_) { + park("target"); + } + + // Synchronize backend + ggml_backend_synchronize(backend_); + + // Load drafter (lazy — stays resident for subsequent calls) + if (!drafter_loaded_) { + std::fprintf(stderr, "[compress] loading drafter from %s ...\n", dpath); + if (!load_drafter(dpath, /*gpu_layers=*/999, drafter_ctx_)) { + std::fprintf(stderr, "[compress] drafter init failed: %s\n", + dflash27b_last_error()); + io.emit(-1); + if (!skip_park && !was_parked) unpark("target"); + return false; + } + drafter_loaded_ = true; + std::fprintf(stderr, "[compress] drafter ready\n"); + } + + std::vector tokens = read_int32_file(ppath); + bool ok = false; + if (!tokens.empty()) { + const float keep = (float)keep_x1000 / 1000.0f; + auto compressed = drafter_score_and_compress(drafter_ctx_, tokens, keep); + ok = !compressed.empty(); + if (ok) { + std::fprintf(stderr, "[compress] %zu -> %zu tokens\n", + tokens.size(), compressed.size()); + for (int32_t t : compressed) io.emit(t); + } + } + io.emit(-1); + + // Restore park state + if (!skip_park && !was_parked) { + unpark("target"); + } + + return ok; } void Gemma4Backend::free_drafter() { - // No drafter to free. 
+ if (drafter_loaded_) { + dflash27b::free_drafter(drafter_ctx_); + drafter_loaded_ = false; + } } bool Gemma4Backend::try_handle_command(const std::string & line, @@ -514,6 +583,7 @@ bool Gemma4Backend::try_handle_command(const std::string & line, void Gemma4Backend::shutdown() { for (int i = 0; i < PREFIX_SLOTS; ++i) snapshot_free(i); + free_drafter(); free_gemma4_cache(cache_); free_gemma4_weights(w_); free_snapshot_backend(snap_backend_, backend_); diff --git a/dflash/src/gemma4/gemma4_backend.h b/dflash/src/gemma4/gemma4_backend.h index 18c26fda..49468672 100644 --- a/dflash/src/gemma4/gemma4_backend.h +++ b/dflash/src/gemma4/gemma4_backend.h @@ -9,6 +9,7 @@ #include "common/device_placement.h" #include "gemma4_internal.h" #include "common/sampler.h" +#include "../qwen3/qwen3_drafter.h" #include "ggml.h" #include "ggml-backend.h" @@ -24,6 +25,7 @@ struct Gemma4BackendConfig { DevicePlacement device; int stream_fd = -1; int chunk = 512; + int fa_window = 0; // 0 = full attention; >0 = sparse decode window }; class Gemma4Backend : public ModelBackend { @@ -76,6 +78,10 @@ class Gemma4Backend : public ModelBackend { SamplerCfg sampler_; std::mt19937_64 sampler_rng_{std::random_device{}()}; + // PFlash drafter (compress) + DrafterContext drafter_ctx_; + bool drafter_loaded_ = false; + // Snapshots static constexpr int PREFIX_SLOTS = 64; Gemma4Snapshot snapshots_[PREFIX_SLOTS]; diff --git a/dflash/src/gemma4/gemma4_graph.cpp b/dflash/src/gemma4/gemma4_graph.cpp index 33fe591c..6359edc5 100644 --- a/dflash/src/gemma4/gemma4_graph.cpp +++ b/dflash/src/gemma4/gemma4_graph.cpp @@ -178,6 +178,7 @@ static ggml_tensor * build_gemma4_attn_block( int cache_il = cache.kv_source[il]; ggml_tensor * cache_k = cache.k[cache_il]; ggml_tensor * cache_v = cache.v[cache_il]; + const int cache_len = (int)cache_k->ne[1]; // max_ctx for full, swa_size for SWA if (has_kv) { // K/V projection + norm + RoPE + write to cache @@ -198,42 +199,60 @@ static ggml_tensor * 
build_gemma4_attn_block( 0, rope_base, 1.0f, 0.0f, 1.0f, 32.0f, 1.0f); - // Write K/V to cache + // Write K/V to cache (ring-buffer position for SWA layers) + const int write_pos = is_swa ? (kv_start % cache_len) : kv_start; ggml_tensor * Kcur_T = ggml_permute(ctx, Kcur, 0, 2, 1, 3); ggml_tensor * Vcur_T = ggml_permute(ctx, Vcur, 0, 2, 1, 3); ggml_tensor * k_slot = ggml_view_3d(ctx, cache_k, head_dim, n_tokens, n_head_kv, cache_k->nb[1], cache_k->nb[2], - cache_k->nb[1] * (size_t)kv_start); + cache_k->nb[1] * (size_t)write_pos); ggml_tensor * v_slot = ggml_view_3d(ctx, cache_v, head_dim, n_tokens, n_head_kv, cache_v->nb[1], cache_v->nb[2], - cache_v->nb[1] * (size_t)kv_start); + cache_v->nb[1] * (size_t)write_pos); ggml_build_forward_expand(gf, ggml_cpy(ctx, Kcur_T, k_slot)); ggml_build_forward_expand(gf, ggml_cpy(ctx, Vcur_T, v_slot)); } // else: KV-sharing layer — cache already written by source layer // Flash attention - // Pad kv_len to multiple of 256 for CUDA FA kernel compatibility (FATTN_KQ_STRIDE=256) - const int kv_len_raw = kv_start + n_tokens; - const int kv_len = (kv_len_raw + 255) & ~255; // round up to 256 + // For SWA layers: read entire ring buffer (cache_len positions) + // For full layers: read all positions (or windowed if fa_window > 0) + const int fa_window = cache.fa_window; + const int full_win_start = (!is_swa && fa_window > 0 && kv_start > fa_window) + ? (kv_start - fa_window) : 0; + const int kv_len_raw = is_swa ? std::min(kv_start + n_tokens, cache_len) + : (kv_start + n_tokens - full_win_start); + const int kv_len = (kv_len_raw + 255) & ~255; // pad to 256 for CUDA FA ggml_tensor * Qfa = ggml_permute(ctx, Qcur, 0, 2, 1, 3); Qfa = ggml_cont(ctx, Qfa); + const size_t cache_offset = is_swa ? 
0 : (cache_k->nb[1] * (size_t)full_win_start); ggml_tensor * Kfa = ggml_view_3d(ctx, cache_k, head_dim, kv_len, n_head_kv, - cache_k->nb[1], cache_k->nb[2], 0); + cache_k->nb[1], cache_k->nb[2], cache_offset); ggml_tensor * Vfa = ggml_view_3d(ctx, cache_v, head_dim, kv_len, n_head_kv, - cache_v->nb[1], cache_v->nb[2], 0); + cache_v->nb[1], cache_v->nb[2], cache_offset); // Gemma4 uses self.scaling = 1.0 (no QK scaling) because Q/K are already // RMS-normed per-head. Standard 1/sqrt(head_dim) is NOT used here. const float kq_scale = 1.0f; - ggml_tensor * use_mask = is_swa ? attn_mask_swa : attn_mask_full; + ggml_tensor * use_mask; + if (is_swa) { + use_mask = attn_mask_swa; + } else if (full_win_start > 0) { + // View the mask starting at full_win_start column + use_mask = ggml_view_4d(ctx, attn_mask_full, + kv_len, n_tokens, 1, 1, + attn_mask_full->nb[1], attn_mask_full->nb[2], attn_mask_full->nb[3], + (size_t)full_win_start * ggml_element_size(attn_mask_full)); + } else { + use_mask = attn_mask_full; + } ggml_tensor * attn = ggml_flash_attn_ext(ctx, Qfa, Kfa, Vfa, use_mask, kq_scale, 0.0f, 0.0f); @@ -357,13 +376,18 @@ bool gemma4_step( } // Attention masks (full + SWA) - // Pad kv_len to 256 for CUDA FA kernel compatibility (FATTN_KQ_STRIDE=256) + // Full-attention mask: covers all positions [0, kv_start+n_tokens) const int kv_len_raw = kv_start + n_tokens; const int kv_len_padded = (kv_len_raw + 255) & ~255; ggml_tensor * mk_full = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv_len_padded, n_tokens, 1, 1); ggml_set_input(mk_full); ggml_tensor * mk_full_f16 = ggml_cast(ctx, mk_full, GGML_TYPE_F16); - ggml_tensor * mk_swa = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv_len_padded, n_tokens, 1, 1); + + // SWA mask: covers the ring buffer [0, swa_size) with ring-buffer indexing + const int swa_size = cache.swa_size; + const int swa_len_raw = std::min(kv_start + n_tokens, swa_size); + const int swa_len_padded = (swa_len_raw + 255) & ~255; + ggml_tensor * mk_swa = 
ggml_new_tensor_4d(ctx, GGML_TYPE_F32, swa_len_padded, n_tokens, 1, 1); ggml_set_input(mk_swa); ggml_tensor * mk_swa_f16 = ggml_cast(ctx, mk_swa, GGML_TYPE_F16); @@ -466,15 +490,23 @@ bool gemma4_step( } ggml_backend_tensor_set(mk_full, mfull.data(), 0, ggml_nbytes(mk_full)); - // SWA mask — padded positions are masked with -inf - std::vector mswa((size_t)kv_len_padded * n_tokens, -INFINITY); + // SWA ring-buffer mask — maps cache indices to absolute positions const int W = w.sliding_window; + std::vector mswa((size_t)swa_len_padded * n_tokens, -INFINITY); for (int q = 0; q < n_tokens; ++q) { const int abs_q = kv_start + q; const int win_lo = std::max(0, abs_q - W + 1); - for (int k = win_lo; k <= abs_q && k < kv_len_raw; ++k) { - mswa[(size_t)q * kv_len_padded + k] = 0.0f; + // The ring buffer stores the most recent min(abs_q+1, swa_size) entries. + // Cache slot j holds absolute position: depends on how many tokens written. + const int total_written = abs_q + 1; // positions [0..abs_q] written so far + for (int abs_k = win_lo; abs_k <= abs_q; ++abs_k) { + // Map absolute position to ring-buffer slot + const int slot = abs_k % swa_size; + if (slot < swa_len_raw) { + mswa[(size_t)q * swa_len_padded + slot] = 0.0f; + } } + (void)total_written; } ggml_backend_tensor_set(mk_swa, mswa.data(), 0, ggml_nbytes(mk_swa)); @@ -533,7 +565,12 @@ bool gemma4_verify_batch( ggml_tensor * mk_full = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv_len_padded, n_tokens, 1, 1); ggml_set_input(mk_full); ggml_tensor * mk_full_f16 = ggml_cast(ctx, mk_full, GGML_TYPE_F16); - ggml_tensor * mk_swa = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv_len_padded, n_tokens, 1, 1); + + // SWA mask: ring-buffer sized + const int swa_size = cache.swa_size; + const int swa_len_raw = std::min(kv_start + n_tokens, swa_size); + const int swa_len_padded = (swa_len_raw + 255) & ~255; + ggml_tensor * mk_swa = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, swa_len_padded, n_tokens, 1, 1); ggml_set_input(mk_swa); ggml_tensor * 
mk_swa_f16 = ggml_cast(ctx, mk_swa, GGML_TYPE_F16); @@ -615,13 +652,17 @@ bool gemma4_verify_batch( } ggml_backend_tensor_set(mk_full, mfull.data(), 0, ggml_nbytes(mk_full)); - std::vector mswa((size_t)kv_len_padded * n_tokens, -INFINITY); + // SWA ring-buffer mask const int W = w.sliding_window; + std::vector mswa((size_t)swa_len_padded * n_tokens, -INFINITY); for (int q = 0; q < n_tokens; ++q) { const int abs_q = kv_start + q; const int win_lo = std::max(0, abs_q - W + 1); - for (int k = win_lo; k <= abs_q && k < kv_len_raw; ++k) { - mswa[(size_t)q * kv_len_padded + k] = 0.0f; + for (int abs_k = win_lo; abs_k <= abs_q; ++abs_k) { + const int slot = abs_k % swa_size; + if (slot < swa_len_raw) { + mswa[(size_t)q * swa_len_padded + slot] = 0.0f; + } } } ggml_backend_tensor_set(mk_swa, mswa.data(), 0, ggml_nbytes(mk_swa)); diff --git a/dflash/src/gemma4/gemma4_internal.h b/dflash/src/gemma4/gemma4_internal.h index f92f8df0..2576d05d 100644 --- a/dflash/src/gemma4/gemma4_internal.h +++ b/dflash/src/gemma4/gemma4_internal.h @@ -158,6 +158,8 @@ struct Gemma4Cache { int cur_pos = 0; int max_ctx = 0; int n_layer = 0; + int swa_size = 0; // ring-buffer size for SWA layers (= sliding_window) + int fa_window = 0; // sparse decode window for full-attn layers (0 = full) // Only layers where has_kv[il] == true have real K/V tensors. // KV-reuse layers reference an earlier layer's cache. diff --git a/dflash/src/gemma4/gemma4_loader.cpp b/dflash/src/gemma4/gemma4_loader.cpp index dae77815..2179ea00 100644 --- a/dflash/src/gemma4/gemma4_loader.cpp +++ b/dflash/src/gemma4/gemma4_loader.cpp @@ -385,16 +385,20 @@ bool create_gemma4_cache(ggml_backend_t backend, const Gemma4Weights & w, out.v.resize(w.n_layer, nullptr); out.kv_source.resize(w.n_layer); + // SWA layers use a ring buffer of size min(sliding_window, max_ctx). + const int swa_size = (w.sliding_window > 0 && w.sliding_window < max_ctx) + ? 
w.sliding_window : max_ctx; + // Determine KV source for each layer int last_kv_layer = -1; for (int il = 0; il < w.n_layer; ++il) { if (w.has_kv[il]) { const int D = gemma4_head_dim(w, il); const int Hk = gemma4_n_head_kv(w, il); - // Layout: [head_dim, max_ctx, n_head_kv] — positions before heads - // (matches the view strides used in build_gemma4_attn_block) - out.k[il] = ggml_new_tensor_3d(out.ctx, GGML_TYPE_F16, D, max_ctx, Hk); - out.v[il] = ggml_new_tensor_3d(out.ctx, GGML_TYPE_F16, D, max_ctx, Hk); + const bool is_swa = gemma4_is_swa_layer(w, il); + const int cache_len = is_swa ? swa_size : max_ctx; + out.k[il] = ggml_new_tensor_3d(out.ctx, GGML_TYPE_F16, D, cache_len, Hk); + out.v[il] = ggml_new_tensor_3d(out.ctx, GGML_TYPE_F16, D, cache_len, Hk); out.kv_source[il] = il; last_kv_layer = il; } else { @@ -412,6 +416,7 @@ bool create_gemma4_cache(ggml_backend_t backend, const Gemma4Weights & w, out.cur_pos = 0; out.max_ctx = max_ctx; out.n_layer = w.n_layer; + out.swa_size = swa_size; return true; } diff --git a/dflash/src/server/server_main.cpp b/dflash/src/server/server_main.cpp index 12f0bc4e..e11571fe 100644 --- a/dflash/src/server/server_main.cpp +++ b/dflash/src/server/server_main.cpp @@ -16,6 +16,7 @@ #include "common/backend_factory.h" #include "common/gguf_inspect.h" +#include #include #include #include From 1bfb72039eea81e124a16798ab8bdaa7721cee3b Mon Sep 17 00:00:00 2001 From: Howard Su Date: Wed, 20 May 2026 17:54:33 +0800 Subject: [PATCH 06/18] gemma4: wire DFlash speculative decode into Gemma4 backend - Add target_feat ring buffer to Gemma4Cache for feature capture - Add feature capture nodes to build_gemma4_layer() (both step and verify) - Add draft model loading with metadata override (GGUF has wrong dimensions) - Infer n_capture_layers from fc weight shape (6 for Gemma4, not 5 from metadata) - Port do_spec_decode() loop from qwen35 backend - Wire spec-decode into generate() and restore_and_generate() (temp==0 only) - Sync captured features to 
DraftFeatureMirror after each prefill chunk - Store last_tok during prefill for spec-decode entry - Pass draft_path/draft_gpu/draft_ctx_max through BackendArgs to Gemma4BackendConfig - Clean up draft resources in shutdown() Tested: AR decode produces correct output, spec-decode pipeline runs end-to-end with 9.1 tok/s throughput. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dflash/src/common/backend_factory.cpp | 11 +- dflash/src/gemma4/gemma4_backend.cpp | 485 +++++++++++++++++---- dflash/src/gemma4/gemma4_backend.h | 18 + dflash/src/gemma4/gemma4_dflash_target.cpp | 17 +- dflash/src/gemma4/gemma4_graph.cpp | 57 ++- dflash/src/gemma4/gemma4_internal.h | 15 + dflash/src/gemma4/gemma4_loader.cpp | 45 ++ 7 files changed, 547 insertions(+), 101 deletions(-) diff --git a/dflash/src/common/backend_factory.cpp b/dflash/src/common/backend_factory.cpp index 01804569..71ce2efc 100644 --- a/dflash/src/common/backend_factory.cpp +++ b/dflash/src/common/backend_factory.cpp @@ -88,10 +88,13 @@ std::unique_ptr create_backend(const BackendArgs & args) { } else if (arch == "gemma4") { Gemma4BackendConfig gcfg; - gcfg.model_path = args.model_path; - gcfg.device = args.device; - gcfg.stream_fd = args.stream_fd; - gcfg.chunk = args.chunk; + gcfg.model_path = args.model_path; + gcfg.draft_path = args.draft_path; + gcfg.draft_gpu = args.draft_gpu; + gcfg.draft_ctx_max = args.draft_ctx_max; + gcfg.device = args.device; + gcfg.stream_fd = args.stream_fd; + gcfg.chunk = args.chunk; auto backend = std::make_unique(gcfg); if (!backend->init()) { diff --git a/dflash/src/gemma4/gemma4_backend.cpp b/dflash/src/gemma4/gemma4_backend.cpp index 114ca158..a02365fd 100644 --- a/dflash/src/gemma4/gemma4_backend.cpp +++ b/dflash/src/gemma4/gemma4_backend.cpp @@ -8,6 +8,9 @@ #include "dflash27b.h" #include "common/sampler.h" #include "common/io_utils.h" +#include "common/dflash_feature_ring.h" +#include "common/dflash_draft_graph.h" +#include "common/step_graph.h" #include 
"ggml-cuda.h" #include "common/snapshot_backend.h" @@ -52,6 +55,81 @@ bool Gemma4Backend::init() { } cache_.fa_window = cfg_.fa_window; + // Load draft model for speculative decode + if (cfg_.draft_path) { + const int draft_gpu = (cfg_.draft_gpu >= 0) ? cfg_.draft_gpu : cfg_.device.gpu; + draft_backend_ = ggml_backend_cuda_init(draft_gpu); + if (!draft_backend_) { + std::fprintf(stderr, "[gemma4] draft CUDA init failed (gpu=%d)\n", draft_gpu); + } else { + // Load draft GGUF — pass nullptr for target (Gemma4 != TargetWeights) + if (!load_draft_gguf(cfg_.draft_path, draft_backend_, dw_, nullptr)) { + std::fprintf(stderr, "[gemma4] draft load failed: %s\n", dflash27b_last_error()); + ggml_backend_free(draft_backend_); draft_backend_ = nullptr; + } else { + // Override mask_token_id for Gemma4 (use token 0 = pad) + dw_.mask_token_id = 0; + + // Fix draft dimensions from actual tensor shapes (GGUF metadata is wrong) + // fc.weight: [fc_in, draft_hidden] + const int draft_hidden = (int)dw_.fc->ne[1]; + const int fc_in = (int)dw_.fc->ne[0]; + const int n_capture = fc_in / w_.n_embd; + + if (draft_hidden != dw_.n_embd) { + std::printf("[gemma4] draft: overriding n_embd %d -> %d (from fc weight)\n", + dw_.n_embd, draft_hidden); + dw_.n_embd = draft_hidden; + } + // Infer n_head from wq shape: wq.ne[1] = n_head * head_dim + if (dw_.n_layer > 0 && dw_.layers[0].wq) { + const int q_dim = (int)dw_.layers[0].wq->ne[1]; + const int inferred_n_head = q_dim / dw_.head_dim; + if (inferred_n_head != dw_.n_head) { + std::printf("[gemma4] draft: overriding n_head %d -> %d\n", + dw_.n_head, inferred_n_head); + dw_.n_head = inferred_n_head; + } + } + // Infer n_ff from ffn_gate shape + if (dw_.n_layer > 0 && dw_.layers[0].w_gate) { + const int inferred_ff = (int)dw_.layers[0].w_gate->ne[1]; + if (inferred_ff != dw_.n_ff) { + std::printf("[gemma4] draft: overriding n_ff %d -> %d\n", + dw_.n_ff, inferred_ff); + dw_.n_ff = inferred_ff; + } + } + // Override n_target_layers from fc shape + 
dw_.n_target_layers = n_capture; + + std::printf("[gemma4] draft loaded: fc_in=%d target_hidden=%d " + "draft_hidden=%d n_capture_layers=%d\n", + fc_in, w_.n_embd, draft_hidden, n_capture); + + // Allocate target_feat ring buffer + constexpr int TARGET_FEAT_CAP = 4096; + const int feat_cap = std::min(cfg_.device.max_ctx, TARGET_FEAT_CAP); + if (!create_gemma4_target_feat(backend_, cache_, n_capture, w_.n_embd, feat_cap)) { + std::fprintf(stderr, "[gemma4] target_feat alloc failed\n"); + } else { + // Init feature mirror on draft GPU + const int mirror_cap = std::min(cfg_.draft_ctx_max, feat_cap); + if (!draft_feature_mirror_init(feature_mirror_, draft_backend_, + draft_gpu, cfg_.device.gpu, mirror_cap, + n_capture, w_.n_embd)) { + std::fprintf(stderr, "[gemma4] feature mirror init failed\n"); + } else { + // Create DFlash target adapter + dflash_target_ = new Gemma4DFlashTarget(w_, cache_, backend_); + std::printf("[gemma4] spec-decode ready: capture_layers=%d mirror_cap=%d\n", + n_capture, mirror_cap); + } + } + } + } + } + std::printf("[gemma4] init ok: %d layers, embd=%d, vocab=%d, max_ctx=%d\n", w_.n_layer, w_.n_embd, w_.n_vocab, cfg_.device.max_ctx); std::fflush(stdout); @@ -148,6 +226,22 @@ int Gemma4Backend::do_prefill(const std::vector & tokens, pos += len; cache_.cur_pos = kv_offset + pos; + + // Store last_tok from final chunk's logits (argmax of last position) + if (pos >= n && !logits.empty()) { + int32_t best_tok = 0; + float best_val = logits[0]; + for (int j = 1; j < (int)logits.size(); ++j) { + if (logits[j] > best_val) { best_val = logits[j]; best_tok = j; } + } + cache_.last_tok = best_tok; + } + + // Sync captured features to draft mirror + if (feature_mirror_.target_feat && cache_.target_feat && !draft_parked_) { + draft_feature_mirror_sync_range(cache_.target_feat, cache_.target_feat_cap, + feature_mirror_, kv_pos, len); + } } return kv_offset + pos; @@ -202,6 +296,189 @@ bool Gemma4Backend::do_decode(int committed, int n_gen, return true; } 
+// ── Speculative Decode ───────────────────────────────────────────────── + +bool Gemma4Backend::do_spec_decode(int committed, int n_gen, + std::vector & out_tokens, + const DaemonIO & io) { + const int hidden = w_.n_embd; + int32_t last_tok = cache_.last_tok; + + DFlashTarget * target = dflash_target_; + const int q_len = dw_.block_size; + + StepGraph draft_sg; + + std::vector noise_embed((size_t)hidden * q_len); + std::vector noise_ids(q_len); + std::vector draft_tok(q_len); + std::vector target_tok(q_len); + std::vector pos_q(q_len); + std::vector pos_k; + std::vector local_hidden; + + int n_generated = 0; + int n_draft_steps = 0; + int n_accept_sum = 0; + + auto t_dec0 = std::chrono::steady_clock::now(); + + while (n_generated < n_gen) { + const int need_commit_budget = n_gen - n_generated; + + // 1. Build noise input: [last_tok, MASK, MASK, ..., MASK] + noise_ids[0] = last_tok; + for (int i = 1; i < q_len; i++) noise_ids[i] = target->mask_token_id(); + if (!target->embed_tokens(noise_ids.data(), q_len, noise_embed.data())) { + std::fprintf(stderr, "[gemma4-spec] noise embed failed\n"); + step_graph_destroy(draft_sg); + return false; + } + + // 2. Draft compute + constexpr int DRAFT_CTX_MAX_DEFAULT = 2048; + const int ring_cap = feature_mirror_.cap; + const int draft_ctx = std::min(committed, + std::min(ring_cap, std::max(DRAFT_CTX_MAX_DEFAULT, cfg_.draft_ctx_max))); + const int draft_start = committed - draft_ctx; + int mirror_slot0 = 0; + const bool use_mirror_view = + draft_feature_mirror_can_view(feature_mirror_, committed, draft_ctx, mirror_slot0); + + if (!build_draft_step(draft_sg, dw_, /*lm_head=*/nullptr, draft_backend_, + draft_ctx, use_mirror_view ? 
&feature_mirror_ : nullptr, + committed, + std::min(ring_cap, std::max(DRAFT_CTX_MAX_DEFAULT, cfg_.draft_ctx_max)))) { + std::fprintf(stderr, "[gemma4-spec] draft build failed\n"); + step_graph_destroy(draft_sg); + return false; + } + if (!use_mirror_view && + !copy_feature_ring_range_to_tensor(feature_mirror_, draft_sg.target_hidden_cat, + draft_start, draft_ctx)) { + std::fprintf(stderr, "[gemma4-spec] feature copy failed\n"); + step_graph_destroy(draft_sg); + return false; + } + ggml_backend_tensor_set(draft_sg.inp_embed, noise_embed.data(), 0, + sizeof(float) * noise_embed.size()); + pos_k.resize((size_t)draft_ctx + q_len); + for (int i = 0; i < q_len; i++) pos_q[i] = draft_ctx + i; + for (int i = 0; i < draft_ctx + q_len; i++) pos_k[i] = i; + ggml_backend_tensor_set(draft_sg.positions, pos_q.data(), 0, + sizeof(int32_t) * pos_q.size()); + ggml_backend_tensor_set(draft_sg.positions_k, pos_k.data(), 0, + sizeof(int32_t) * pos_k.size()); + + auto st = ggml_backend_graph_compute(draft_backend_, draft_sg.gf); + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "[gemma4-spec] draft compute failed\n"); + step_graph_destroy(draft_sg); + return false; + } + + // Read draft hidden states + local_hidden.resize((size_t)hidden * q_len); + ggml_backend_tensor_get(draft_sg.hidden_states, local_hidden.data(), 0, + sizeof(float) * local_hidden.size()); + + // 3. Project draft hidden → token IDs via target LM head + if (!target->project_hidden_to_tokens(local_hidden.data(), q_len, draft_tok)) { + std::fprintf(stderr, "[gemma4-spec] projection failed\n"); + step_graph_destroy(draft_sg); + return false; + } + draft_tok[0] = last_tok; + + // 4. 
Verify: snapshot KV, run target forward over draft tokens + if (!target->snapshot_kv()) { + step_graph_destroy(draft_sg); + return false; + } + + int verify_last_tok = -1; + if (!target->verify_batch(draft_tok, committed, verify_last_tok, &target_tok)) { + std::fprintf(stderr, "[gemma4-spec] verify failed\n"); + target->restore_kv(); + step_graph_destroy(draft_sg); + return false; + } + + // 5. Acceptance: longest matching prefix + int accept_n = 1; + for (int i = 0; i < q_len - 1; i++) { + if (draft_tok[i + 1] == target_tok[i]) accept_n++; + else break; + } + int bonus_tok = (accept_n < q_len) ? target_tok[accept_n - 1] : -1; + int commit_n = accept_n + (bonus_tok >= 0 ? 1 : 0); + if (commit_n > need_commit_budget) { + commit_n = need_commit_budget; + if (commit_n <= accept_n) bonus_tok = -1; + } + + // 6. Replay: roll back KV and re-run accepted tokens + if (!target->restore_kv()) { + step_graph_destroy(draft_sg); + return false; + } + + std::vector replay_tok((size_t)commit_n); + for (int i = 0; i < commit_n; i++) { + replay_tok[i] = (i < accept_n) ? draft_tok[i] : bonus_tok; + } + int replay_last_tok = -1; + if (!target->verify_batch(replay_tok, committed, replay_last_tok, nullptr)) { + std::fprintf(stderr, "[gemma4-spec] replay failed\n"); + step_graph_destroy(draft_sg); + return false; + } + last_tok = replay_last_tok; + + // 7. Sync features for replayed range + if (feature_mirror_.target_feat && cache_.target_feat) { + draft_feature_mirror_sync_range(cache_.target_feat, cache_.target_feat_cap, + feature_mirror_, committed, commit_n); + } + + // 8. 
Emit committed tokens + bool hit_eos = false; + int emitted = 0; + for (int i = 0; i < commit_n; i++) { + out_tokens.push_back(replay_tok[i]); + io.emit(replay_tok[i]); + emitted++; + if (io.cancelled) break; + if (replay_tok[i] == w_.eos_id || replay_tok[i] == w_.eos_chat_id) { + hit_eos = true; break; + } + } + committed += emitted; + cache_.cur_pos = committed; + n_generated += emitted; + n_accept_sum += std::min(accept_n, emitted); + n_draft_steps++; + if (io.cancelled) break; + if (hit_eos) break; + } + + step_graph_destroy(draft_sg); + + auto t_dec1 = std::chrono::steady_clock::now(); + const double decode_s = std::chrono::duration(t_dec1 - t_dec0).count(); + const int total_draft_pos = std::max(1, n_draft_steps * q_len); + const double accept_pct = 100.0 * (double)n_accept_sum / (double)total_draft_pos; + std::fprintf(stderr, "[gemma4-spec] tokens=%d time=%.3f s speed=%.2f tok/s " + "steps=%d accepted=%d/%d (%.1f%%) avg_commit=%.2f\n", + n_generated, decode_s, + n_generated > 0 ? n_generated / decode_s : 0.0, + n_draft_steps, n_accept_sum, total_draft_pos, accept_pct, + n_draft_steps > 0 ? 
(double)n_generated / (double)n_draft_steps : 0.0); + + io.emit(-1); + return true; +} + // ── Generate ─────────────────────────────────────────────────────────── GenerateResult Gemma4Backend::generate(const GenerateRequest & req, @@ -224,58 +501,72 @@ GenerateResult Gemma4Backend::generate(const GenerateRequest & req, } if (req.n_gen > 0) { - const int hidden = w_.n_embd; - const int vocab = w_.n_vocab; - std::vector logits; - - // Re-step last token to get logits - int32_t last_tok = req.prompt.back(); - std::vector embed_buf(hidden); - w_.embedder.embed(&last_tok, 1, embed_buf.data()); - float scale = std::sqrt((float)hidden); - for (int j = 0; j < hidden; ++j) embed_buf[j] *= scale; - - if (!gemma4_step(backend_, w_, cache_, embed_buf.data(), - &last_tok, 1, committed - 1, logits)) { - result.error = "first logits"; - return result; - } - - // Sample first token - int32_t first; - if (sampler_.temp > 0) { - first = sample_logits(logits.data(), vocab, sampler_, - result.tokens, sampler_rng_); + // Try speculative decode if draft is available and temp==0 + const bool can_spec = dflash_target_ + && !draft_parked_ + && feature_mirror_.target_feat + && sampler_.temp == 0.0f; + + if (can_spec) { + if (!do_spec_decode(committed, req.n_gen, result.tokens, out_io)) { + result.error = "spec_decode"; + return result; + } } else { - first = 0; - float best = logits[0]; - for (int j = 1; j < vocab; ++j) { - if (logits[j] > best) { best = logits[j]; first = j; } + const int hidden = w_.n_embd; + const int vocab = w_.n_vocab; + std::vector logits; + + // Re-step last token to get logits + int32_t last_tok = req.prompt.back(); + std::vector embed_buf(hidden); + w_.embedder.embed(&last_tok, 1, embed_buf.data()); + float scale = std::sqrt((float)hidden); + for (int j = 0; j < hidden; ++j) embed_buf[j] *= scale; + + if (!gemma4_step(backend_, w_, cache_, embed_buf.data(), + &last_tok, 1, committed - 1, logits)) { + result.error = "first logits"; + return result; } - } - 
result.tokens.push_back(first); - out_io.emit(first); - if (out_io.cancelled) { - out_io.emit(-1); - result.ok = true; - return result; - } - if (first == w_.eos_id || first == w_.eos_chat_id) { - out_io.emit(-1); - result.ok = true; - return result; - } + // Sample first token + int32_t first; + if (sampler_.temp > 0) { + first = sample_logits(logits.data(), vocab, sampler_, + result.tokens, sampler_rng_); + } else { + first = 0; + float best = logits[0]; + for (int j = 1; j < vocab; ++j) { + if (logits[j] > best) { best = logits[j]; first = j; } + } + } + result.tokens.push_back(first); + out_io.emit(first); + if (out_io.cancelled) { + out_io.emit(-1); + result.ok = true; + return result; + } - if (req.n_gen > 1) { - if (!do_decode(committed, req.n_gen - 1, result.tokens, out_io)) { - result.error = "decode"; + if (first == w_.eos_id || first == w_.eos_chat_id) { + out_io.emit(-1); + result.ok = true; return result; } + + if (req.n_gen > 1) { + if (!do_decode(committed, req.n_gen - 1, result.tokens, out_io)) { + result.error = "decode"; + return result; + } + } + out_io.emit(-1); } + } else { + out_io.emit(-1); } - - out_io.emit(-1); result.ok = true; return result; } @@ -351,58 +642,71 @@ GenerateResult Gemma4Backend::restore_and_generate(int slot, // Generate if (req.n_gen > 0) { - const int hidden = w_.n_embd; - const int vocab = w_.n_vocab; - std::vector logits; - - // Re-step last token to get logits for first generated token - int32_t last_tok = req.prompt.back(); - std::vector embed_buf(hidden); - w_.embedder.embed(&last_tok, 1, embed_buf.data()); - float scale = std::sqrt((float)hidden); - for (int j = 0; j < hidden; ++j) embed_buf[j] *= scale; - - if (!gemma4_step(backend_, w_, cache_, embed_buf.data(), - &last_tok, 1, committed - 1, logits)) { - result.error = "first logits"; - return result; - } - - // Sample first token - int32_t first; - if (sampler_.temp > 0) { - first = sample_logits(logits.data(), vocab, sampler_, - result.tokens, sampler_rng_); + 
const bool can_spec = dflash_target_ + && !draft_parked_ + && feature_mirror_.target_feat + && sampler_.temp == 0.0f; + + if (can_spec) { + if (!do_spec_decode(committed, req.n_gen, result.tokens, out_io)) { + result.error = "spec_decode"; + return result; + } } else { - first = 0; - float best = logits[0]; - for (int j = 1; j < vocab; ++j) { - if (logits[j] > best) { best = logits[j]; first = j; } + const int hidden = w_.n_embd; + const int vocab = w_.n_vocab; + std::vector logits; + + // Re-step last token to get logits for first generated token + int32_t last_tok = req.prompt.back(); + std::vector embed_buf(hidden); + w_.embedder.embed(&last_tok, 1, embed_buf.data()); + float scale = std::sqrt((float)hidden); + for (int j = 0; j < hidden; ++j) embed_buf[j] *= scale; + + if (!gemma4_step(backend_, w_, cache_, embed_buf.data(), + &last_tok, 1, committed - 1, logits)) { + result.error = "first logits"; + return result; } - } - result.tokens.push_back(first); - out_io.emit(first); - if (out_io.cancelled) { - out_io.emit(-1); - result.ok = true; - return result; - } - if (first == w_.eos_id || first == w_.eos_chat_id) { - out_io.emit(-1); - result.ok = true; - return result; - } + // Sample first token + int32_t first; + if (sampler_.temp > 0) { + first = sample_logits(logits.data(), vocab, sampler_, + result.tokens, sampler_rng_); + } else { + first = 0; + float best = logits[0]; + for (int j = 1; j < vocab; ++j) { + if (logits[j] > best) { best = logits[j]; first = j; } + } + } + result.tokens.push_back(first); + out_io.emit(first); + if (out_io.cancelled) { + out_io.emit(-1); + result.ok = true; + return result; + } - if (req.n_gen > 1) { - if (!do_decode(committed, req.n_gen - 1, result.tokens, out_io)) { - result.error = "decode"; + if (first == w_.eos_id || first == w_.eos_chat_id) { + out_io.emit(-1); + result.ok = true; return result; } + + if (req.n_gen > 1) { + if (!do_decode(committed, req.n_gen - 1, result.tokens, out_io)) { + result.error = "decode"; + 
return result; + } + } + out_io.emit(-1); } + } else { + out_io.emit(-1); } - - out_io.emit(-1); result.ok = true; return result; } @@ -584,6 +888,11 @@ bool Gemma4Backend::try_handle_command(const std::string & line, void Gemma4Backend::shutdown() { for (int i = 0; i < PREFIX_SLOTS; ++i) snapshot_free(i); free_drafter(); + // Clean up DFlash spec-decode resources + delete dflash_target_; dflash_target_ = nullptr; + draft_feature_mirror_free(feature_mirror_); + if (dw_.ctx) { free_draft_weights(dw_); } + if (draft_backend_) { ggml_backend_free(draft_backend_); draft_backend_ = nullptr; } free_gemma4_cache(cache_); free_gemma4_weights(w_); free_snapshot_backend(snap_backend_, backend_); diff --git a/dflash/src/gemma4/gemma4_backend.h b/dflash/src/gemma4/gemma4_backend.h index 49468672..a5727d18 100644 --- a/dflash/src/gemma4/gemma4_backend.h +++ b/dflash/src/gemma4/gemma4_backend.h @@ -7,7 +7,10 @@ #include "common/model_backend.h" #include "common/device_placement.h" +#include "common/dflash_feature_ring.h" +#include "common/dflash_draft_graph.h" #include "gemma4_internal.h" +#include "gemma4_dflash_target.h" #include "common/sampler.h" #include "../qwen3/qwen3_drafter.h" @@ -22,6 +25,9 @@ namespace dflash::common { struct Gemma4BackendConfig { const char * model_path = nullptr; + const char * draft_path = nullptr; + int draft_gpu = -1; // GPU for draft model (-1 = same as target) + int draft_ctx_max = 2048; // max context for draft feature mirror DevicePlacement device; int stream_fd = -1; int chunk = 512; @@ -78,6 +84,13 @@ class Gemma4Backend : public ModelBackend { SamplerCfg sampler_; std::mt19937_64 sampler_rng_{std::random_device{}()}; + // DFlash speculative decode + ggml_backend_t draft_backend_ = nullptr; + DraftWeights dw_{}; + DraftFeatureMirror feature_mirror_{}; + Gemma4DFlashTarget * dflash_target_ = nullptr; + bool draft_parked_ = false; + // PFlash drafter (compress) DrafterContext drafter_ctx_; bool drafter_loaded_ = false; @@ -95,6 +108,11 @@ 
class Gemma4Backend : public ModelBackend { bool do_decode(int committed, int n_gen, std::vector & out_tokens, const DaemonIO & io); + + // DFlash speculative decode loop. + bool do_spec_decode(int committed, int n_gen, + std::vector & out_tokens, + const DaemonIO & io); }; } // namespace dflash::common diff --git a/dflash/src/gemma4/gemma4_dflash_target.cpp b/dflash/src/gemma4/gemma4_dflash_target.cpp index aa66080d..09cda2f7 100644 --- a/dflash/src/gemma4/gemma4_dflash_target.cpp +++ b/dflash/src/gemma4/gemma4_dflash_target.cpp @@ -14,12 +14,17 @@ Gemma4DFlashTarget::Gemma4DFlashTarget( Gemma4Cache & cache, ggml_backend_t backend) : w_(w), cache_(cache), backend_(backend) { - // Evenly-spaced capture layer IDs (same formula as qwen35 loader). - const int N = DFLASH27B_DRAFT_N_TARGET_LAYERS; - capture_ids_.resize(N); - const int step = std::max(1, (w.n_layer - 2) / (N - 1)); - for (int k = 0; k < N; k++) { - capture_ids_[k] = 1 + k * step; + // Use capture layer IDs from cache (computed when target_feat is allocated) + if (!cache.capture_layer_ids.empty()) { + capture_ids_ = cache.capture_layer_ids; + } else { + // Fallback: evenly-spaced (legacy path) + const int N = DFLASH27B_DRAFT_N_TARGET_LAYERS; + capture_ids_.resize(N); + const int step = std::max(1, (w.n_layer - 2) / (N - 1)); + for (int k = 0; k < N; k++) { + capture_ids_[k] = 1 + k * step; + } } } diff --git a/dflash/src/gemma4/gemma4_graph.cpp b/dflash/src/gemma4/gemma4_graph.cpp index 6359edc5..3340fc87 100644 --- a/dflash/src/gemma4/gemma4_graph.cpp +++ b/dflash/src/gemma4/gemma4_graph.cpp @@ -274,7 +274,8 @@ static ggml_tensor * build_gemma4_layer( ggml_tensor * attn_mask_swa, ggml_tensor * per_layer_input, // [n_embd_per_layer, n_tokens] or nullptr int kv_start, - int n_tokens) + int n_tokens, + int capture_idx = -1) // >=0: write to target_feat at this capture slot { const Gemma4Layer & L = w.layers[il]; @@ -335,6 +336,43 @@ static ggml_tensor * build_gemma4_layer( cur = ggml_mul(ctx, cur, 
L.out_scale); } + // Feature capture for DFlash spec-decode + if (capture_idx >= 0 && cache.target_feat) { + const int hidden = w.n_embd; + const size_t elt = ggml_element_size(cache.target_feat); + const size_t col_stride = cache.target_feat->nb[1]; + const int cap = cache.target_feat_cap; + const int slot_start = kv_start % cap; + const int pre_n = std::min(n_tokens, cap - slot_start); + const int post_n = n_tokens - pre_n; + + ggml_tensor * cur_2d = ggml_reshape_2d(ctx, cur, hidden, n_tokens); + + // First slice: [slot_start..slot_start+pre_n) in the ring. + { + const size_t offset = + (size_t)slot_start * col_stride + + (size_t)capture_idx * hidden * elt; + ggml_tensor * slot = ggml_view_2d(ctx, cache.target_feat, + hidden, pre_n, col_stride, offset); + ggml_tensor * src = ggml_view_2d(ctx, cur_2d, + hidden, pre_n, cur_2d->nb[1], 0); + ggml_build_forward_expand(gf, ggml_cpy(ctx, src, slot)); + } + + // Second slice: wrap-around at [0..post_n) if needed. + if (post_n > 0) { + const size_t offset = + (size_t)capture_idx * hidden * elt; + ggml_tensor * slot = ggml_view_2d(ctx, cache.target_feat, + hidden, post_n, col_stride, offset); + ggml_tensor * src = ggml_view_2d(ctx, cur_2d, + hidden, post_n, cur_2d->nb[1], + (size_t)pre_n * cur_2d->nb[1]); + ggml_build_forward_expand(gf, ggml_cpy(ctx, src, slot)); + } + } + return cur; } @@ -432,9 +470,16 @@ bool gemma4_step( // Slice [n_embd_per_layer, n_tokens] for this layer pl_input = gemma4_view_2d_slice(ctx, per_layer_all, il); } + // Determine capture index for this layer (-1 if not a capture layer) + int cap_idx = -1; + if (cache.target_feat) { + for (int k = 0; k < cache.n_capture_layers; k++) { + if (cache.capture_layer_ids[k] == il) { cap_idx = k; break; } + } + } cur = build_gemma4_layer(ctx, gf, w, cache, il, cur, pp, mk_full_f16, mk_swa_f16, pl_input, - kv_start, n_tokens); + kv_start, n_tokens, cap_idx); } // Final norm @@ -600,9 +645,15 @@ bool gemma4_verify_batch( if (per_layer_all) { pl_input = 
gemma4_view_2d_slice(ctx, per_layer_all, il); } + int cap_idx = -1; + if (cache.target_feat) { + for (int k = 0; k < cache.n_capture_layers; k++) { + if (cache.capture_layer_ids[k] == il) { cap_idx = k; break; } + } + } cur = build_gemma4_layer(ctx, gf, w, cache, il, cur, pp, mk_full_f16, mk_swa_f16, pl_input, - kv_start, n_tokens); + kv_start, n_tokens, cap_idx); } // Final norm diff --git a/dflash/src/gemma4/gemma4_internal.h b/dflash/src/gemma4/gemma4_internal.h index 2576d05d..872bc829 100644 --- a/dflash/src/gemma4/gemma4_internal.h +++ b/dflash/src/gemma4/gemma4_internal.h @@ -160,6 +160,7 @@ struct Gemma4Cache { int n_layer = 0; int swa_size = 0; // ring-buffer size for SWA layers (= sliding_window) int fa_window = 0; // sparse decode window for full-attn layers (0 = full) + int32_t last_tok = -1; // argmax of last prefill token (for spec-decode entry) // Only layers where has_kv[il] == true have real K/V tensors. // KV-reuse layers reference an earlier layer's cache. @@ -167,14 +168,28 @@ struct Gemma4Cache { std::vector v; std::vector kv_source; // for each layer, which layer's KV to use + // DFlash feature capture ring buffer (BF16, allocated when draft is active) + ggml_tensor * target_feat = nullptr; // [fc_in, target_feat_cap] + int target_feat_cap = 0; + int n_capture_layers = 0; + std::vector capture_layer_ids; + ggml_context * ctx = nullptr; ggml_backend_buffer_t buf = nullptr; + + // Separate context/buffer for target_feat (allocated after draft load) + ggml_context * feat_ctx = nullptr; + ggml_backend_buffer_t feat_buf = nullptr; }; bool create_gemma4_cache(ggml_backend_t backend, const Gemma4Weights & w, int max_ctx, Gemma4Cache & out); void free_gemma4_cache(Gemma4Cache & c); +// Allocate target_feat ring buffer (call after draft load determines n_capture_layers). 
+bool create_gemma4_target_feat(ggml_backend_t backend, Gemma4Cache & cache, + int n_capture_layers, int hidden_size, int cap); + // Snapshot struct Gemma4Snapshot { int cur_pos = 0; diff --git a/dflash/src/gemma4/gemma4_loader.cpp b/dflash/src/gemma4/gemma4_loader.cpp index 2179ea00..8c7c674e 100644 --- a/dflash/src/gemma4/gemma4_loader.cpp +++ b/dflash/src/gemma4/gemma4_loader.cpp @@ -421,12 +421,57 @@ bool create_gemma4_cache(ggml_backend_t backend, const Gemma4Weights & w, } void free_gemma4_cache(Gemma4Cache & c) { + if (c.feat_buf) { ggml_backend_buffer_free(c.feat_buf); c.feat_buf = nullptr; } + if (c.feat_ctx) { ggml_free(c.feat_ctx); c.feat_ctx = nullptr; } + c.target_feat = nullptr; + c.target_feat_cap = 0; + c.n_capture_layers = 0; + c.capture_layer_ids.clear(); if (c.buf) { ggml_backend_buffer_free(c.buf); c.buf = nullptr; } if (c.ctx) { ggml_free(c.ctx); c.ctx = nullptr; } c.k.clear(); c.v.clear(); c.kv_source.clear(); c.cur_pos = 0; } +bool create_gemma4_target_feat(ggml_backend_t backend, Gemma4Cache & cache, + int n_capture_layers, int hidden_size, int cap) { + if (n_capture_layers <= 0 || hidden_size <= 0 || cap <= 0) return false; + + // Free existing feat allocation + if (cache.feat_buf) { ggml_backend_buffer_free(cache.feat_buf); cache.feat_buf = nullptr; } + if (cache.feat_ctx) { ggml_free(cache.feat_ctx); cache.feat_ctx = nullptr; } + + ggml_init_params ip{}; + ip.mem_size = ggml_tensor_overhead() * 4 + 4096; + ip.no_alloc = true; + cache.feat_ctx = ggml_init(ip); + if (!cache.feat_ctx) return false; + + const int fc_in = n_capture_layers * hidden_size; + cache.target_feat = ggml_new_tensor_2d(cache.feat_ctx, GGML_TYPE_BF16, fc_in, cap); + ggml_set_name(cache.target_feat, "gemma4_target_feat"); + + cache.feat_buf = ggml_backend_alloc_ctx_tensors(cache.feat_ctx, backend); + if (!cache.feat_buf) { + ggml_free(cache.feat_ctx); cache.feat_ctx = nullptr; + cache.target_feat = nullptr; + return false; + } + + cache.target_feat_cap = cap; + 
cache.n_capture_layers = n_capture_layers; + + // Compute capture layer IDs (evenly spaced across layers) + cache.capture_layer_ids.resize(n_capture_layers); + const int n_layer = cache.n_layer; + const int step = std::max(1, (n_layer - 2) / (n_capture_layers - 1)); + for (int k = 0; k < n_capture_layers; k++) { + cache.capture_layer_ids[k] = 1 + k * step; + } + + return true; +} + void free_gemma4_snapshot(Gemma4Snapshot & s) { if (s.buf) { ggml_backend_buffer_free(s.buf); s.buf = nullptr; } if (s.ctx) { ggml_free(s.ctx); s.ctx = nullptr; } From 2065995f9125b176ed1ab22296023502e75c826f Mon Sep 17 00:00:00 2001 From: Howard Su Date: Wed, 20 May 2026 17:56:59 +0800 Subject: [PATCH 07/18] prefix_cache: add Gemma family detection for chat markers Gemma4 uses <|turn> / as single-token turn delimiters. Previously it incorrectly fell through to the Laguna family check because //etc. would encode to non-empty sequences. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dflash/src/server/prefix_cache.cpp | 11 +++++++++++ dflash/src/server/prefix_cache.h | 4 ++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/dflash/src/server/prefix_cache.cpp b/dflash/src/server/prefix_cache.cpp index 72ceae72..26657b27 100644 --- a/dflash/src/server/prefix_cache.cpp +++ b/dflash/src/server/prefix_cache.cpp @@ -25,6 +25,17 @@ bool resolve_chat_markers(const Tokenizer & tok, ChatMarkers & out) { return true; } + // Try Gemma family: <|turn> (start) and (end) are single tokens. + auto turn_start = tok.encode("<|turn>"); + auto turn_end = tok.encode(""); + if (turn_start.size() == 1 && turn_end.size() == 1) { + out.family = "gemma"; + out.sys_role_prefix = {turn_start[0]}; + out.end_msg_seqs = {{turn_end[0]}}; + out.next_role_starts = {{turn_start[0]}}; + return true; + } + // Try Laguna family: XML-style markers. 
auto start_sys = tok.encode(""); auto end_sys = tok.encode(""); diff --git a/dflash/src/server/prefix_cache.h b/dflash/src/server/prefix_cache.h index cb0c551b..b9ec001f 100644 --- a/dflash/src/server/prefix_cache.h +++ b/dflash/src/server/prefix_cache.h @@ -26,14 +26,14 @@ namespace dflash::common { // ─── Chat marker detection ────────────────────────────────────────────── struct ChatMarkers { - std::string family; // "qwen" or "laguna" + std::string family; // "qwen", "gemma", or "laguna" // Token sequences for boundary detection std::vector sys_role_prefix; std::vector> end_msg_seqs; std::vector> next_role_starts; }; -// Resolve chat markers from the tokenizer (detects Qwen vs Laguna family). +// Resolve chat markers from the tokenizer (detects Qwen, Gemma, or Laguna family). bool resolve_chat_markers(const Tokenizer & tok, ChatMarkers & out); // Find all turn-boundary cut points in a token stream. From 78aaa06d3dae3c75b2aaa8bc8cdfa479e1f31a48 Mon Sep 17 00:00:00 2001 From: Howard Su Date: Wed, 20 May 2026 18:25:39 +0800 Subject: [PATCH 08/18] gemma4: fix DFlash spec-decode acceptance rate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three root-cause fixes identified from the HuggingFace model card (z-lab/gemma-4-31B-it-DFlash config.json): 1. mask_token_id: use 4 instead of 0 — the draft model was trained with token 4 as the mask/padding token. 2. capture_layer_ids: replace integer-truncation formula with floating-point linspace + rounding. For 60 layers / 6 captures: old: {1,12,23,34,45,56}, correct: {1,12,23,35,46,57}. 3. embed_tokens: remove sqrt(n_embd) scaling — the draft model expects raw unscaled embeddings (same as qwen35 convention). Also removes debug fprintf statements added during investigation. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dflash/src/gemma4/gemma4_backend.cpp | 9 +++++++-- dflash/src/gemma4/gemma4_dflash_target.cpp | 13 +++++-------- dflash/src/gemma4/gemma4_loader.cpp | 8 +++++--- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/dflash/src/gemma4/gemma4_backend.cpp b/dflash/src/gemma4/gemma4_backend.cpp index a02365fd..5f42236d 100644 --- a/dflash/src/gemma4/gemma4_backend.cpp +++ b/dflash/src/gemma4/gemma4_backend.cpp @@ -67,8 +67,8 @@ bool Gemma4Backend::init() { std::fprintf(stderr, "[gemma4] draft load failed: %s\n", dflash27b_last_error()); ggml_backend_free(draft_backend_); draft_backend_ = nullptr; } else { - // Override mask_token_id for Gemma4 (use token 0 = pad) - dw_.mask_token_id = 0; + // Override mask_token_id for Gemma4 (token 4 per model card) + dw_.mask_token_id = 4; // Fix draft dimensions from actual tensor shapes (GGUF metadata is wrong) // fc.weight: [fc_in, draft_hidden] @@ -124,6 +124,10 @@ bool Gemma4Backend::init() { dflash_target_ = new Gemma4DFlashTarget(w_, cache_, backend_); std::printf("[gemma4] spec-decode ready: capture_layers=%d mirror_cap=%d\n", n_capture, mirror_cap); + std::printf("[gemma4] capture_layer_ids:"); + for (int k = 0; k < (int)cache_.capture_layer_ids.size(); k++) + std::printf(" %d", cache_.capture_layer_ids[k]); + std::printf("\n"); } } } @@ -360,6 +364,7 @@ bool Gemma4Backend::do_spec_decode(int committed, int n_gen, step_graph_destroy(draft_sg); return false; } + ggml_backend_tensor_set(draft_sg.inp_embed, noise_embed.data(), 0, sizeof(float) * noise_embed.size()); pos_k.resize((size_t)draft_ctx + q_len); diff --git a/dflash/src/gemma4/gemma4_dflash_target.cpp b/dflash/src/gemma4/gemma4_dflash_target.cpp index 09cda2f7..e3b83744 100644 --- a/dflash/src/gemma4/gemma4_dflash_target.cpp +++ b/dflash/src/gemma4/gemma4_dflash_target.cpp @@ -123,12 +123,9 @@ bool Gemma4DFlashTarget::is_eos(int token) const { bool Gemma4DFlashTarget::embed_tokens(const 
int32_t * tokens, int n, float * out) const { - if (!w_.embedder.embed(tokens, n, out)) return false; - // Gemma4 scales embeddings by sqrt(n_embd) - const float scale = std::sqrt((float)w_.n_embd); - const size_t total = (size_t)n * w_.n_embd; - for (size_t i = 0; i < total; ++i) out[i] *= scale; - return true; + // Return raw embeddings (no sqrt(n_embd) scale) — the draft model + // was trained on unscaled embeddings, matching the qwen35 convention. + return w_.embedder.embed(tokens, n, out); } bool Gemma4DFlashTarget::project_hidden_to_tokens( @@ -139,8 +136,8 @@ bool Gemma4DFlashTarget::project_hidden_to_tokens( } int Gemma4DFlashTarget::mask_token_id() const { - // Gemma4 uses token ID 0 as padding/mask - return 0; + // Gemma4 DFlash draft uses token ID 4 as mask (per model card) + return 4; } const std::vector & Gemma4DFlashTarget::capture_layer_ids() const { diff --git a/dflash/src/gemma4/gemma4_loader.cpp b/dflash/src/gemma4/gemma4_loader.cpp index 8c7c674e..2c82c7c6 100644 --- a/dflash/src/gemma4/gemma4_loader.cpp +++ b/dflash/src/gemma4/gemma4_loader.cpp @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -461,12 +462,13 @@ bool create_gemma4_target_feat(ggml_backend_t backend, Gemma4Cache & cache, cache.target_feat_cap = cap; cache.n_capture_layers = n_capture_layers; - // Compute capture layer IDs (evenly spaced across layers) + // Compute capture layer IDs using floating-point linspace with rounding. + // This matches the training config (e.g., gemma4: [1,12,23,35,46,57]). 
cache.capture_layer_ids.resize(n_capture_layers); const int n_layer = cache.n_layer; - const int step = std::max(1, (n_layer - 2) / (n_capture_layers - 1)); for (int k = 0; k < n_capture_layers; k++) { - cache.capture_layer_ids[k] = 1 + k * step; + cache.capture_layer_ids[k] = (int)std::round( + 1.0 + k * (double)(n_layer - 4) / (n_capture_layers - 1)); } return true; From f102502cda5e0e3c7037c35db4beccf8becde5b2 Mon Sep 17 00:00:00 2001 From: Howard Su Date: Wed, 20 May 2026 18:58:36 +0800 Subject: [PATCH 09/18] gemma4 dflash: fix SWA causal masking and rope_theta - Add causal attention mask for SWA layers in the draft model (layers 0-3 are sliding-window with causal masking, layer 4 is full non-causal). The draft was trained this way; running all-non-causal let future MASK embeddings leak into earlier positions, hurting acceptance rate. - Read rope_theta from draft GGUF metadata instead of hardcoded 10M constant (Gemma4 draft uses 1M, not 10M like Qwen3.5). - Remove double-normalization: gemma4_project_hidden now skips out_norm since the draft already applies its own final norm layer. - Scale embed_tokens by sqrt(n_embd) in DFlashTarget to match Gemma4 convention. - Set swa_window=2048 and mark layers[0..3].is_swa after draft GGUF loading. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dflash/src/common/dflash_draft_graph.cpp | 64 +++++++++++++++++++++- dflash/src/draft/draft_dflash_graph.cpp | 30 ++++++---- dflash/src/draft/draft_gguf_loader.cpp | 7 +++ dflash/src/draft/draft_graph.h | 3 + dflash/src/gemma4/gemma4_backend.cpp | 10 +++- dflash/src/gemma4/gemma4_dflash_target.cpp | 10 +++- dflash/src/gemma4/gemma4_graph.cpp | 7 ++- dflash/src/internal.h | 1 + 8 files changed, 112 insertions(+), 20 deletions(-) diff --git a/dflash/src/common/dflash_draft_graph.cpp b/dflash/src/common/dflash_draft_graph.cpp index 2e60acb6..1563a9ef 100644 --- a/dflash/src/common/dflash_draft_graph.cpp +++ b/dflash/src/common/dflash_draft_graph.cpp @@ -2,11 +2,26 @@ #include "draft/draft_graph.h" // DraftGraphInputs, DraftGraphOutputs, build_draft_graph #include "ggml-alloc.h" +#include "ggml-backend.h" +#include #include +#include namespace dflash::common { +// Minimum alignment required by ggml flash_attn_ext for mask rows. +static constexpr int MASK_KV_PAD = 32; + +static inline int mask_align_up(int x, int a) { return ((x + a - 1) / a) * a; } + +// Check whether any layer in the draft is SWA. +static bool draft_has_swa_layers(const DraftWeights & dw) { + for (int i = 0; i < dw.n_layer; i++) + if (dw.layers[i].is_swa) return true; + return false; +} + // Build draft graph at a given ctx_len into sg. Does NOT touch sg.alloc. // mirror_view: if true, uses a view into mirror->target_feat at slot0. static bool build_draft_graph_internal( @@ -56,6 +71,21 @@ static bool build_draft_graph_internal( ggml_set_name(sg.positions_k, "positions_k"); ggml_set_input(sg.positions_k); + // Causal mask for SWA layers (if any). + // Shape: [kv_pad, q_len] F32; padded kv dim to MASK_KV_PAD alignment. 
+ sg.attn_mask = nullptr; + const bool has_swa = draft_has_swa_layers(dw); + if (has_swa) { + // SWA layers' effective KV length (windowed or full ctx) + const bool swa_active = dw.swa_window > 0 && ctx_len > dw.swa_window; + const int eff_ctx = swa_active ? dw.swa_window : ctx_len; + const int eff_total_k = eff_ctx + q_len; + const int kv_pad = mask_align_up(eff_total_k, MASK_KV_PAD); + sg.attn_mask = ggml_new_tensor_2d(sg.ctx, GGML_TYPE_F32, kv_pad, q_len); + ggml_set_name(sg.attn_mask, "causal_mask_swa"); + ggml_set_input(sg.attn_mask); + } + sg.gf = ggml_new_graph_custom(sg.ctx, 4096, false); DraftGraphInputs gi{}; @@ -65,6 +95,7 @@ static bool build_draft_graph_internal( gi.positions_q = sg.positions; gi.positions_k = sg.positions_k; gi.lm_head = lm_head; + gi.causal_mask_swa = sg.attn_mask; DraftGraphOutputs go = build_draft_graph(sg.ctx, dw, gi); sg.hidden_states = go.hidden_states; sg.logits = go.logits; @@ -125,7 +156,38 @@ bool build_draft_step( return false; } - return ggml_gallocr_alloc_graph(sg.alloc, sg.gf); + if (!ggml_gallocr_alloc_graph(sg.alloc, sg.gf)) { + return false; + } + + // Fill causal mask data for SWA layers (after allocation gives memory to the tensor). + if (sg.attn_mask) { + const int q_len = dw.block_size; + const bool swa_active = dw.swa_window > 0 && ctx_len > dw.swa_window; + const int eff_ctx = swa_active ? dw.swa_window : ctx_len; + const int eff_total_k = eff_ctx + q_len; + const int kv_pad = mask_align_up(eff_total_k, MASK_KV_PAD); + + // Build causal mask: query at position (eff_ctx + q) can attend to + // key at position k if k <= eff_ctx + q. + // Context keys (k < eff_ctx): always visible. + // Noise keys (k = eff_ctx + j): visible if j <= q. 
+ std::vector mask_data((size_t)kv_pad * q_len, -INFINITY); + for (int q = 0; q < q_len; q++) { + // All context positions are visible + for (int k = 0; k < eff_ctx; k++) { + mask_data[(size_t)q * kv_pad + k] = 0.0f; + } + // Noise positions: causal (only positions 0..q visible) + for (int j = 0; j <= q; j++) { + mask_data[(size_t)q * kv_pad + (eff_ctx + j)] = 0.0f; + } + } + ggml_backend_tensor_set(sg.attn_mask, mask_data.data(), 0, + sizeof(float) * mask_data.size()); + } + + return true; } } // namespace dflash::common diff --git a/dflash/src/draft/draft_dflash_graph.cpp b/dflash/src/draft/draft_dflash_graph.cpp index eddfba9a..23c7c019 100644 --- a/dflash/src/draft/draft_dflash_graph.cpp +++ b/dflash/src/draft/draft_dflash_graph.cpp @@ -1,28 +1,29 @@ // Builds a ggml compute graph for one forward pass of the DFlash draft -// (5-layer non-causal Qwen3-flavored block-diffusion model). +// (5-layer Qwen3-flavored block-diffusion model). // // Stateless: no KV cache. Each call takes: -// - noise_embed [hidden, q_len, 1] bf16 (target.tok_embd on [last_tok, MASK*15]) -// - target_hidden_cat [5*hidden, ctx_len, 1] bf16 (5 target layers concat along features) +// - noise_embed [hidden, q_len, 1] f32 (target.tok_embd on [last_tok, MASK*15]) +// - target_hidden_cat [N*hidden, ctx_len, 1] f32 (N target layers concat along features) // - positions_q [q_len] i32 values [ctx_len..ctx_len+q_len-1] // - positions_k [ctx_len+q_len] i32 values [0..ctx_len+q_len-1] +// - causal_mask_swa [kv_pad, q_len] f32 (optional; causal mask for SWA layers) // and returns: -// - hidden_states [hidden, q_len, 1] bf16 (final RMSNorm; NO lm_head here) +// - hidden_states [hidden, q_len, 1] f32 (final RMSNorm; NO lm_head here) // // The caller projects `hidden_states` through the TARGET's lm_head separately // (the draft has no lm_head of its own, it shares the target's). 
// -// Semantics match megaqwen3_27b_dflash/reference/dflash_reference.py exactly: +// Semantics: // - fc @ target_hidden_cat -> rms_norm with hidden_norm -> target_feat -// - Per layer (non-causal): +// - Per layer: // h_norm = rms_norm(h) * input_layernorm // Q = wq @ h_norm -> per-head q_norm // K_ctx/V_ctx = wk/wv @ target_feat // K_noi/V_noi = wk/wv @ h_norm // K = concat[K_ctx, K_noi] -> per-head k_norm // V = concat[V_ctx, V_noi] -// RoPE(Q, positions_q); RoPE(K, positions_k) (NEOX style, theta=10M) -// attn = flash_attn_ext(Q, K, V, mask=null, scale=1/sqrt(head_dim)) non-causal +// RoPE(Q, positions_q); RoPE(K, positions_k) (NEOX style) +// attn = flash_attn_ext(Q, K, V, mask, scale) SWA=causal, full=non-causal // h += wo @ attn // h_norm = rms_norm(h) * post_attention_layernorm // h += w_down @ (silu(w_gate @ h_norm) * (w_up @ h_norm)) @@ -46,7 +47,7 @@ DraftGraphOutputs build_draft_graph( const int n_kv = w.n_head_kv; const int head_dim = w.head_dim; const float eps = DFLASH27B_RMS_EPS; - const float rope_base = DFLASH27B_ROPE_THETA; + const float rope_base = w.rope_theta; // ── 1. Feature fusion: target_feat = rms_norm(fc @ target_hidden_cat, hidden_norm) // fc: [5*hidden, hidden] (ggml: ne[0]=5*hidden, ne[1]=hidden) @@ -60,6 +61,12 @@ DraftGraphOutputs build_draft_graph( // ── 2. Decoder layers ggml_tensor * h = in.noise_embed; // [hidden, q_len, 1] + // Pre-cast causal mask to F16 (flash_attn_ext requires F16 mask) + ggml_tensor * mask_f16 = nullptr; + if (in.causal_mask_swa) { + mask_f16 = ggml_cast(ctx, in.causal_mask_swa, GGML_TYPE_F16); + } + for (int il = 0; il < w.n_layer; il++) { const DraftLayer & L = w.layers[il]; @@ -134,9 +141,10 @@ DraftGraphOutputs build_draft_graph( V = ggml_permute(ctx, V, 0, 2, 1, 3); // [head_dim, eff_total_k, n_kv, 1] V = ggml_cont (ctx, V); - // ── 2f. Non-causal flash attention; GQA broadcast handled internally. + // ── 2f. Attention: causal for SWA layers, non-causal for full layers. 
const float scale = 1.0f / std::sqrt((float)head_dim); - ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, /*mask=*/nullptr, + ggml_tensor * mask = (L.is_swa && mask_f16) ? mask_f16 : nullptr; + ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, mask, scale, /*max_bias=*/0.0f, /*logit_softcap=*/0.0f); // attn result: [n_embd_v=head_dim, n_head, n_batch=q_len, 1] diff --git a/dflash/src/draft/draft_gguf_loader.cpp b/dflash/src/draft/draft_gguf_loader.cpp index 89f7b17c..aefdccae 100644 --- a/dflash/src/draft/draft_gguf_loader.cpp +++ b/dflash/src/draft/draft_gguf_loader.cpp @@ -159,6 +159,12 @@ bool load_draft_gguf(const std::string & path, std::snprintf(key, sizeof(key), "%s.%s", A, suffix); return get_u32_or(gctx, key, fallback); }; + auto read_f32 = [&](const char * suffix, float fallback) -> float { + std::snprintf(key, sizeof(key), "%s.%s", A, suffix); + int64_t id = gguf_find_key(gctx, key); + if (id < 0) return fallback; + return gguf_get_val_f32(gctx, id); + }; const uint32_t n_embd = read_u32("embedding_length", 0); const uint32_t n_layer = read_u32("block_count", 0); @@ -235,6 +241,7 @@ bool load_draft_gguf(const std::string & path, out.head_dim = (int)head_dim; out.n_embd = (int)n_embd; out.n_ff = (int)n_ff; + out.rope_theta = read_f32("rope.freq_base", DFLASH27B_ROPE_THETA); out.layers.assign((size_t)n_layer, DraftLayer{}); auto g = [&](const char * name) -> ggml_tensor * { diff --git a/dflash/src/draft/draft_graph.h b/dflash/src/draft/draft_graph.h index 28bc0d83..248dcd4a 100644 --- a/dflash/src/draft/draft_graph.h +++ b/dflash/src/draft/draft_graph.h @@ -18,6 +18,9 @@ struct DraftGraphInputs { // hidden states. Used for DFlash integration where the draft shares the // target's lm_head. ggml_tensor * lm_head; + // Optional: causal mask for SWA layers [kv_pad, q_len] F32 (cast to F16 in graph). + // nullptr = all layers non-causal. 
+ ggml_tensor * causal_mask_swa = nullptr; }; struct DraftGraphOutputs { diff --git a/dflash/src/gemma4/gemma4_backend.cpp b/dflash/src/gemma4/gemma4_backend.cpp index 5f42236d..6204bdf7 100644 --- a/dflash/src/gemma4/gemma4_backend.cpp +++ b/dflash/src/gemma4/gemma4_backend.cpp @@ -103,9 +103,15 @@ bool Gemma4Backend::init() { // Override n_target_layers from fc shape dw_.n_target_layers = n_capture; + // Gemma4 DFlash draft: layers 0-3 are SWA (causal), layer 4 is full (non-causal) + // (from model card: layer_types = [sliding*4, full_attention]) + dw_.swa_window = 2048; + for (int i = 0; i < dw_.n_layer - 1 && i < (int)dw_.layers.size(); i++) + dw_.layers[i].is_swa = true; + std::printf("[gemma4] draft loaded: fc_in=%d target_hidden=%d " - "draft_hidden=%d n_capture_layers=%d\n", - fc_in, w_.n_embd, draft_hidden, n_capture); + "draft_hidden=%d n_capture_layers=%d swa=%d\n", + fc_in, w_.n_embd, draft_hidden, n_capture, dw_.swa_window); // Allocate target_feat ring buffer constexpr int TARGET_FEAT_CAP = 4096; diff --git a/dflash/src/gemma4/gemma4_dflash_target.cpp b/dflash/src/gemma4/gemma4_dflash_target.cpp index e3b83744..5d044234 100644 --- a/dflash/src/gemma4/gemma4_dflash_target.cpp +++ b/dflash/src/gemma4/gemma4_dflash_target.cpp @@ -123,9 +123,13 @@ bool Gemma4DFlashTarget::is_eos(int token) const { bool Gemma4DFlashTarget::embed_tokens(const int32_t * tokens, int n, float * out) const { - // Return raw embeddings (no sqrt(n_embd) scale) — the draft model - // was trained on unscaled embeddings, matching the qwen35 convention. - return w_.embedder.embed(tokens, n, out); + if (!w_.embedder.embed(tokens, n, out)) return false; + // Scale by sqrt(n_embd) to match Gemma4's embedding convention. + // The draft was trained with Gemma4's scaled embeddings as noise input. 
+ const float scale = std::sqrt((float)w_.n_embd); + const size_t total = (size_t)n * w_.n_embd; + for (size_t i = 0; i < total; ++i) out[i] *= scale; + return true; } bool Gemma4DFlashTarget::project_hidden_to_tokens( diff --git a/dflash/src/gemma4/gemma4_graph.cpp b/dflash/src/gemma4/gemma4_graph.cpp index 3340fc87..f69cf34a 100644 --- a/dflash/src/gemma4/gemma4_graph.cpp +++ b/dflash/src/gemma4/gemma4_graph.cpp @@ -751,12 +751,13 @@ bool gemma4_project_hidden( ggml_cgraph * gf = ggml_new_graph(ctx); // Input: hidden states [n_embd, n_tokens] + // NOTE: The DFlash draft model already applies its own final RMSNorm, + // so we skip the target's out_norm and go directly to lm_head. ggml_tensor * inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, w.n_embd, n_tokens); ggml_set_input(inp); - // out_norm + lm_head - ggml_tensor * cur = gemma4_rms_norm_mul(ctx, inp, w.out_norm, w.norm_eps); - cur = ggml_mul_mat(ctx, w.output, cur); // [n_vocab, n_tokens] + // lm_head (skip out_norm — draft already normalized) + ggml_tensor * cur = ggml_mul_mat(ctx, w.output, inp); // [n_vocab, n_tokens] // Logit softcapping if (w.final_logit_softcap > 0.0f) { diff --git a/dflash/src/internal.h b/dflash/src/internal.h index 6f5666df..083f2f65 100644 --- a/dflash/src/internal.h +++ b/dflash/src/internal.h @@ -230,6 +230,7 @@ struct DraftWeights { int n_embd = DFLASH27B_TARGET_HIDDEN; // 5120 int n_ff = DFLASH27B_TARGET_INTERMEDIATE; // 17408 int swa_window = 0; // sliding window size (0 = disabled) + float rope_theta = DFLASH27B_ROPE_THETA; // RoPE frequency base // DFlash draft-specific config (populated by loader or set by caller). 
int block_size = DFLASH27B_DRAFT_BLOCK_SIZE; // tokens per draft step (16 or 10) From 106a59eb62755867869ac84bfbb0f9b90aa2f062 Mon Sep 17 00:00:00 2001 From: Howard Su Date: Wed, 20 May 2026 19:33:28 +0800 Subject: [PATCH 10/18] draft: use F16 mask directly, remove unnecessary F32 cast MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Match the pattern from attn_masks.h — create the causal mask tensor as GGML_TYPE_F16 directly and fill with uint16_t values (0x0000 for attend, 0xFC00 for -inf). This eliminates the intermediate ggml_cast op in the draft graph and reduces memory usage. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dflash/src/common/dflash_draft_graph.cpp | 29 +++++++++++------------- dflash/src/draft/draft_dflash_graph.cpp | 8 +------ 2 files changed, 14 insertions(+), 23 deletions(-) diff --git a/dflash/src/common/dflash_draft_graph.cpp b/dflash/src/common/dflash_draft_graph.cpp index 1563a9ef..c028bb88 100644 --- a/dflash/src/common/dflash_draft_graph.cpp +++ b/dflash/src/common/dflash_draft_graph.cpp @@ -4,7 +4,7 @@ #include "ggml-alloc.h" #include "ggml-backend.h" -#include +#include #include #include @@ -72,7 +72,7 @@ static bool build_draft_graph_internal( ggml_set_input(sg.positions_k); // Causal mask for SWA layers (if any). - // Shape: [kv_pad, q_len] F32; padded kv dim to MASK_KV_PAD alignment. + // Shape: [kv_pad, q_len] F16 (directly, no cast needed — matches attn_masks.h pattern). sg.attn_mask = nullptr; const bool has_swa = draft_has_swa_layers(dw); if (has_swa) { @@ -81,7 +81,7 @@ static bool build_draft_graph_internal( const int eff_ctx = swa_active ? 
dw.swa_window : ctx_len; const int eff_total_k = eff_ctx + q_len; const int kv_pad = mask_align_up(eff_total_k, MASK_KV_PAD); - sg.attn_mask = ggml_new_tensor_2d(sg.ctx, GGML_TYPE_F32, kv_pad, q_len); + sg.attn_mask = ggml_new_tensor_2d(sg.ctx, GGML_TYPE_F16, kv_pad, q_len); ggml_set_name(sg.attn_mask, "causal_mask_swa"); ggml_set_input(sg.attn_mask); } @@ -168,23 +168,20 @@ bool build_draft_step( const int eff_total_k = eff_ctx + q_len; const int kv_pad = mask_align_up(eff_total_k, MASK_KV_PAD); - // Build causal mask: query at position (eff_ctx + q) can attend to - // key at position k if k <= eff_ctx + q. + // Build causal mask in F16 directly (same pattern as attn_masks.h): // Context keys (k < eff_ctx): always visible. - // Noise keys (k = eff_ctx + j): visible if j <= q. - std::vector mask_data((size_t)kv_pad * q_len, -INFINITY); + // Noise keys (k = eff_ctx + j): visible if j <= q (causal). + static constexpr uint16_t ZERO = 0x0000; + static constexpr uint16_t NEG_INF = 0xFC00; + std::vector mask_data((size_t)kv_pad * q_len, NEG_INF); for (int q = 0; q < q_len; q++) { - // All context positions are visible - for (int k = 0; k < eff_ctx; k++) { - mask_data[(size_t)q * kv_pad + k] = 0.0f; - } - // Noise positions: causal (only positions 0..q visible) - for (int j = 0; j <= q; j++) { - mask_data[(size_t)q * kv_pad + (eff_ctx + j)] = 0.0f; - } + for (int k = 0; k < eff_ctx; k++) + mask_data[(size_t)q * kv_pad + k] = ZERO; + for (int j = 0; j <= q; j++) + mask_data[(size_t)q * kv_pad + (eff_ctx + j)] = ZERO; } ggml_backend_tensor_set(sg.attn_mask, mask_data.data(), 0, - sizeof(float) * mask_data.size()); + sizeof(uint16_t) * mask_data.size()); } return true; diff --git a/dflash/src/draft/draft_dflash_graph.cpp b/dflash/src/draft/draft_dflash_graph.cpp index 23c7c019..11f5b4c5 100644 --- a/dflash/src/draft/draft_dflash_graph.cpp +++ b/dflash/src/draft/draft_dflash_graph.cpp @@ -61,12 +61,6 @@ DraftGraphOutputs build_draft_graph( // ── 2. 
Decoder layers ggml_tensor * h = in.noise_embed; // [hidden, q_len, 1] - // Pre-cast causal mask to F16 (flash_attn_ext requires F16 mask) - ggml_tensor * mask_f16 = nullptr; - if (in.causal_mask_swa) { - mask_f16 = ggml_cast(ctx, in.causal_mask_swa, GGML_TYPE_F16); - } - for (int il = 0; il < w.n_layer; il++) { const DraftLayer & L = w.layers[il]; @@ -143,7 +137,7 @@ DraftGraphOutputs build_draft_graph( // ── 2f. Attention: causal for SWA layers, non-causal for full layers. const float scale = 1.0f / std::sqrt((float)head_dim); - ggml_tensor * mask = (L.is_swa && mask_f16) ? mask_f16 : nullptr; + ggml_tensor * mask = (L.is_swa && in.causal_mask_swa) ? in.causal_mask_swa : nullptr; ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, mask, scale, /*max_bias=*/0.0f, /*logit_softcap=*/0.0f); From 85bc4c3ae5449843ee786d941e6dac83e08a1f1e Mon Sep 17 00:00:00 2001 From: Howard Su Date: Wed, 20 May 2026 19:35:43 +0800 Subject: [PATCH 11/18] =?UTF-8?q?draft:=20rename=20draft=5Fdflash=5Fgraph.?= =?UTF-8?q?cpp=20=E2=86=92=20draft=5Fgraph.cpp=20to=20match=20header?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The implementation file now matches its header (draft_graph.h), eliminating confusion with the similarly-named common/dflash_draft_graph.cpp orchestrator. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dflash/CMakeLists.txt | 2 +- dflash/docs/SPEC_PREFILL.md | 2 +- dflash/include/dflash27b.h | 2 +- dflash/src/draft/draft_gguf_loader.cpp | 2 +- dflash/src/draft/{draft_dflash_graph.cpp => draft_graph.cpp} | 0 dflash/src/draft/draft_graph.h | 2 +- 6 files changed, 5 insertions(+), 5 deletions(-) rename dflash/src/draft/{draft_dflash_graph.cpp => draft_graph.cpp} (100%) diff --git a/dflash/CMakeLists.txt b/dflash/CMakeLists.txt index a69d8041..4f8869f1 100644 --- a/dflash/CMakeLists.txt +++ b/dflash/CMakeLists.txt @@ -215,7 +215,7 @@ add_library(dflash_common STATIC src/qwen35/qwen35_target_graph.cpp src/draft/draft_gguf_loader.cpp src/draft/draft_safetensors_loader.cpp - src/draft/draft_dflash_graph.cpp + src/draft/draft_graph.cpp src/qwen3/qwen3_drafter.cpp src/qwen3/qwen3_loader.cpp src/qwen3/qwen3_graph.cpp diff --git a/dflash/docs/SPEC_PREFILL.md b/dflash/docs/SPEC_PREFILL.md index f7dc25d5..cb2837dc 100644 --- a/dflash/docs/SPEC_PREFILL.md +++ b/dflash/docs/SPEC_PREFILL.md @@ -89,7 +89,7 @@ src/ qwen35_target_graph.cpp Qwen3.5/3.6 target graph (ggml) gguf_target_loader.cpp Qwen3.5 target GGUF loader draft/ Special DFlash draft model code - draft_dflash_graph.cpp DFlash speculative draft head + draft_graph.cpp DFlash speculative draft head draft_gguf_loader.cpp Draft GGUF loader draft_safetensors_loader.cpp Draft safetensors loader laguna/ Laguna target + daemon model code diff --git a/dflash/include/dflash27b.h b/dflash/include/dflash27b.h index 29c4d1b1..c155021e 100644 --- a/dflash/include/dflash27b.h +++ b/dflash/include/dflash27b.h @@ -24,7 +24,7 @@ extern "C" { // Qwen3.5-27B qwen35 hybrid uses 24 Q heads, 4 KV heads, 256 head_dim, which // live in `src/internal.h` (n_embd_head_k/v, N_HEAD, N_HEAD_KV). Naming is // historical — do not change without updating draft_safetensors_loader.cpp + -// draft_dflash_graph.cpp which consume these as draft-side constants. 
+// draft_graph.cpp which consume these as draft-side constants. #define DFLASH27B_TARGET_N_HEADS 32 #define DFLASH27B_TARGET_N_KV_HEADS 8 #define DFLASH27B_TARGET_HEAD_DIM 128 diff --git a/dflash/src/draft/draft_gguf_loader.cpp b/dflash/src/draft/draft_gguf_loader.cpp index aefdccae..94505911 100644 --- a/dflash/src/draft/draft_gguf_loader.cpp +++ b/dflash/src/draft/draft_gguf_loader.cpp @@ -2,7 +2,7 @@ // on the CUDA backend. // // This is the Q8_0-quantized counterpart of draft_safetensors_loader.cpp. The -// draft graph builder (draft_dflash_graph.cpp) doesn't care about tensor storage +// draft graph builder (draft_graph.cpp) doesn't care about tensor storage // types — ggml's ggml_mul_mat handles Q8_0 × F32 dequantization transparently. // // GGUF arch: "qwen35-dflash-draft" (from convert_dflash_to_gguf.py / diff --git a/dflash/src/draft/draft_dflash_graph.cpp b/dflash/src/draft/draft_graph.cpp similarity index 100% rename from dflash/src/draft/draft_dflash_graph.cpp rename to dflash/src/draft/draft_graph.cpp diff --git a/dflash/src/draft/draft_graph.h b/dflash/src/draft/draft_graph.h index 248dcd4a..30baa17b 100644 --- a/dflash/src/draft/draft_graph.h +++ b/dflash/src/draft/draft_graph.h @@ -18,7 +18,7 @@ struct DraftGraphInputs { // hidden states. Used for DFlash integration where the draft shares the // target's lm_head. ggml_tensor * lm_head; - // Optional: causal mask for SWA layers [kv_pad, q_len] F32 (cast to F16 in graph). + // Optional: causal mask for SWA layers [kv_pad, q_len] F16. // nullptr = all layers non-causal. ggml_tensor * causal_mask_swa = nullptr; }; From 03aeda5322ad41a489a62761ea44b535ab838e33 Mon Sep 17 00:00:00 2001 From: Howard Su Date: Wed, 20 May 2026 19:58:30 +0800 Subject: [PATCH 12/18] draft: remove DFLASH27B_ROPE_THETA constant, read from GGUF only The 10M default was Qwen3.5-specific and silently wrong for other models (e.g. Gemma4 uses 1M). 
Now rope_theta must come from the draft GGUF metadata; a warning is printed if the key is missing. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dflash/include/dflash27b.h | 1 - dflash/src/draft/draft_gguf_loader.cpp | 5 ++++- dflash/src/internal.h | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/dflash/include/dflash27b.h b/dflash/include/dflash27b.h index c155021e..b707b2d8 100644 --- a/dflash/include/dflash27b.h +++ b/dflash/include/dflash27b.h @@ -30,7 +30,6 @@ extern "C" { #define DFLASH27B_TARGET_HEAD_DIM 128 #define DFLASH27B_TARGET_INTERMEDIATE 17408 #define DFLASH27B_TARGET_VOCAB 248320 -#define DFLASH27B_ROPE_THETA 10000000.0f #define DFLASH27B_RMS_EPS 1e-6f #define DFLASH27B_DRAFT_LAYERS 5 diff --git a/dflash/src/draft/draft_gguf_loader.cpp b/dflash/src/draft/draft_gguf_loader.cpp index 94505911..9a6ffdf6 100644 --- a/dflash/src/draft/draft_gguf_loader.cpp +++ b/dflash/src/draft/draft_gguf_loader.cpp @@ -241,7 +241,10 @@ bool load_draft_gguf(const std::string & path, out.head_dim = (int)head_dim; out.n_embd = (int)n_embd; out.n_ff = (int)n_ff; - out.rope_theta = read_f32("rope.freq_base", DFLASH27B_ROPE_THETA); + out.rope_theta = read_f32("rope.freq_base", 0.0f); + if (out.rope_theta == 0.0f) { + fprintf(stderr, "[draft-gguf] WARNING: rope.freq_base not found in GGUF, draft RoPE will be wrong\n"); + } out.layers.assign((size_t)n_layer, DraftLayer{}); auto g = [&](const char * name) -> ggml_tensor * { diff --git a/dflash/src/internal.h b/dflash/src/internal.h index 083f2f65..36f2064f 100644 --- a/dflash/src/internal.h +++ b/dflash/src/internal.h @@ -230,7 +230,7 @@ struct DraftWeights { int n_embd = DFLASH27B_TARGET_HIDDEN; // 5120 int n_ff = DFLASH27B_TARGET_INTERMEDIATE; // 17408 int swa_window = 0; // sliding window size (0 = disabled) - float rope_theta = DFLASH27B_ROPE_THETA; // RoPE frequency base + float rope_theta = 0.0f; // RoPE frequency base (must come from GGUF) // DFlash draft-specific config 
(populated by loader or set by caller). int block_size = DFLASH27B_DRAFT_BLOCK_SIZE; // tokens per draft step (16 or 10) From 9fe0ce4a76240ddc2577244726b72bf90e5e91b8 Mon Sep 17 00:00:00 2001 From: Howard Su Date: Wed, 20 May 2026 20:10:57 +0800 Subject: [PATCH 13/18] gemma4 spec-decode: replace snapshot/replay with KV truncation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Gemma4 is a pure transformer — after verify, KV entries at accepted positions are already correct (causal masking guarantees independence from rejected tokens). Replace the expensive snapshot → verify → restore → replay pattern with: verify(16 tokens) → truncate KV → bonus(1 token) This eliminates: - 2x full KV cache copies (60 layers × K + V each direction) - The replay forward pass (~9 tokens through 60 layers) Measured ~2.2x speedup on RTX 2080 Ti (9.5 → 21 tok/s). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dflash/src/gemma4/gemma4_backend.cpp | 54 ++++++++++++++-------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/dflash/src/gemma4/gemma4_backend.cpp b/dflash/src/gemma4/gemma4_backend.cpp index 6204bdf7..6386db7c 100644 --- a/dflash/src/gemma4/gemma4_backend.cpp +++ b/dflash/src/gemma4/gemma4_backend.cpp @@ -401,16 +401,14 @@ bool Gemma4Backend::do_spec_decode(int committed, int n_gen, } draft_tok[0] = last_tok; - // 4. Verify: snapshot KV, run target forward over draft tokens - if (!target->snapshot_kv()) { - step_graph_destroy(draft_sg); - return false; - } - + // 4. Verify: run target forward over all draft tokens. + // Gemma4 is a pure transformer — after verify, KV entries at accepted + // positions are already correct (causal masking guarantees independence + // from rejected tokens at later positions). We use KV truncation instead + // of the expensive snapshot/restore/replay approach. 
int verify_last_tok = -1; if (!target->verify_batch(draft_tok, committed, verify_last_tok, &target_tok)) { std::fprintf(stderr, "[gemma4-spec] verify failed\n"); - target->restore_kv(); step_graph_destroy(draft_sg); return false; } @@ -428,25 +426,26 @@ bool Gemma4Backend::do_spec_decode(int committed, int n_gen, if (commit_n <= accept_n) bonus_tok = -1; } - // 6. Replay: roll back KV and re-run accepted tokens - if (!target->restore_kv()) { - step_graph_destroy(draft_sg); - return false; - } - - std::vector replay_tok((size_t)commit_n); - for (int i = 0; i < commit_n; i++) { - replay_tok[i] = (i < accept_n) ? draft_tok[i] : bonus_tok; - } - int replay_last_tok = -1; - if (!target->verify_batch(replay_tok, committed, replay_last_tok, nullptr)) { - std::fprintf(stderr, "[gemma4-spec] replay failed\n"); - step_graph_destroy(draft_sg); - return false; + // 6. KV truncation: discard rejected positions, keep accepted. + // Accepted positions 0..accept_n-1 already have correct KV from verify. + cache_.cur_pos = committed + accept_n; + + // If there's a bonus token, run a 1-token forward to get its KV + features. + if (bonus_tok >= 0) { + std::vector bonus_vec = {bonus_tok}; + int bonus_last = -1; + if (!target->verify_batch(bonus_vec, committed + accept_n, bonus_last, nullptr)) { + std::fprintf(stderr, "[gemma4-spec] bonus forward failed\n"); + step_graph_destroy(draft_sg); + return false; + } + last_tok = bonus_last; + } else { + last_tok = verify_last_tok; } - last_tok = replay_last_tok; - // 7. Sync features for replayed range + // 7. Sync features from verify (positions 0..accept_n-1 are correct) + // and from bonus forward (position accept_n, if present). 
if (feature_mirror_.target_feat && cache_.target_feat) { draft_feature_mirror_sync_range(cache_.target_feat, cache_.target_feat_cap, feature_mirror_, committed, commit_n); @@ -456,11 +455,12 @@ bool Gemma4Backend::do_spec_decode(int committed, int n_gen, bool hit_eos = false; int emitted = 0; for (int i = 0; i < commit_n; i++) { - out_tokens.push_back(replay_tok[i]); - io.emit(replay_tok[i]); + int tok = (i < accept_n) ? draft_tok[i] : bonus_tok; + out_tokens.push_back(tok); + io.emit(tok); emitted++; if (io.cancelled) break; - if (replay_tok[i] == w_.eos_id || replay_tok[i] == w_.eos_chat_id) { + if (tok == w_.eos_id || tok == w_.eos_chat_id) { hit_eos = true; break; } } From f854a11c2b31aaae915933d3efefb628ec90670b Mon Sep 17 00:00:00 2001 From: Howard Su Date: Thu, 21 May 2026 08:29:00 +0800 Subject: [PATCH 14/18] gemma4: add BSA sparse-FA prefill path + unified flash_prefill_forward dispatch - Implement gemma4_prefill_bsa() for per-layer BSA prefill using flash_prefill_forward for SWA layers (head_dim=128) with dense FA fallback for full-attention layers (head_dim=256). - Write KV cache during Graph A (ring-buffer aware for SWA layers). - Add GGML_ASSERT guard for swa_size > 0 before modulo operation. - Add flash_prefill_forward() unified dispatch to flashprefill.h that selects bf16/f16/q8 kernel based on compile flags + buffer type. - Simplify Qwen3 attention dispatch to use the unified function. - Remove duplicated ifdef boilerplate from both model implementations. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dflash/src/flashprefill.h | 31 ++ dflash/src/gemma4/gemma4_graph.cpp | 703 ++++++++++++++++++++++++++++ dflash/src/gemma4/gemma4_internal.h | 13 + dflash/src/qwen3/qwen3_graph.cpp | 81 +--- 4 files changed, 760 insertions(+), 68 deletions(-) diff --git a/dflash/src/flashprefill.h b/dflash/src/flashprefill.h index 1cb0f66d..fd0c64e0 100644 --- a/dflash/src/flashprefill.h +++ b/dflash/src/flashprefill.h @@ -90,6 +90,37 @@ int flash_prefill_forward_q8( ggml_type qkv_type, const FlashPrefillConfig & cfg); +// ── Unified dispatch ────────────────────────────────────────────────────────── +// Picks the best available kernel at compile time + runtime buffer type: +// BF16 buffers + sm_80 build → flash_prefill_forward_bf16 +// F16 buffers + Volta/Pascal build → flash_prefill_forward_f16 +// otherwise → flash_prefill_forward_q8 (ggml FA fallback) +// +// Callers no longer need to duplicate the ifdef/dispatch boilerplate. +inline int flash_prefill_forward( + ggml_backend_t backend, + const void * Q, const void * K, const void * V, void * O, + int batch, int seq_len, int n_q_heads, int n_k_heads, int head_dim, + float scale, + ggml_type qkv_type, + const FlashPrefillConfig & cfg) +{ +#if defined(DFLASH27B_HAVE_FLASHPREFILL) || defined(DFLASH27B_HAVE_SM80_FLASHPREFILL) + if (qkv_type == GGML_TYPE_BF16) { + return flash_prefill_forward_bf16(Q, K, V, O, + batch, seq_len, n_q_heads, n_k_heads, head_dim, scale, cfg); + } +#endif +#if defined(DFLASH27B_HAVE_VOLTA_FLASHPREFILL) || defined(DFLASH27B_HAVE_PASCAL_FLASHPREFILL) + if (qkv_type == GGML_TYPE_F16) { + return flash_prefill_forward_f16(Q, K, V, O, + batch, seq_len, n_q_heads, n_k_heads, head_dim, scale, cfg); + } +#endif + return flash_prefill_forward_q8(backend, Q, K, V, O, + batch, seq_len, n_q_heads, n_k_heads, head_dim, scale, qkv_type, cfg); +} + #ifdef DFLASH27B_HAVE_BSA // Free BSA persistent device buffers (blockmask, head_mask_type, softmax_lse).
// Safe to call any time; idempotent. Useful before unloading the drafter to diff --git a/dflash/src/gemma4/gemma4_graph.cpp b/dflash/src/gemma4/gemma4_graph.cpp index f69cf34a..41791203 100644 --- a/dflash/src/gemma4/gemma4_graph.cpp +++ b/dflash/src/gemma4/gemma4_graph.cpp @@ -17,6 +17,7 @@ #include "gemma4_internal.h" #include "dflash27b.h" +#include "flashprefill.h" #include #include @@ -544,6 +545,7 @@ bool gemma4_step( // The ring buffer stores the most recent min(abs_q+1, swa_size) entries. // Cache slot j holds absolute position: depends on how many tokens written. const int total_written = abs_q + 1; // positions [0..abs_q] written so far + GGML_ASSERT(swa_size > 0 && "SWA branch entered with uninitialised cache.swa_size"); for (int abs_k = win_lo; abs_k <= abs_q; ++abs_k) { // Map absolute position to ring-buffer slot const int slot = abs_k % swa_size; @@ -798,4 +800,705 @@ bool gemma4_project_hidden( return true; } +// ── gemma4_prefill_bsa ────────────────────────────────────────────────── +// Full-prompt BSA prefill: processes all tokens at once, layer-by-layer. +// SWA layers use flash_prefill_forward (block-sparse attention; kernel picked by buffer type). +// Full-attention layers use ggml_flash_attn_ext (dense, exact). +// The KV cache for subsequent decode is written per layer during Graph A. + +// Persistent buffer helper (same pattern as Qwen3).
+struct G4PersBuf { + ggml_context * ctx = nullptr; + ggml_backend_buffer_t buf = nullptr; + ggml_tensor * t = nullptr; +}; + +static bool g4_make_pers(ggml_backend_t backend, ggml_type type, int n_dim, + const int64_t * dims, G4PersBuf & out) { + ggml_init_params ip{}; + ip.mem_size = ggml_tensor_overhead() * 4 + 1024; + ip.no_alloc = true; + ip.mem_buffer = nullptr; + out.ctx = ggml_init(ip); + if (!out.ctx) return false; + if (n_dim == 1) out.t = ggml_new_tensor_1d(out.ctx, type, dims[0]); + else if (n_dim == 2) out.t = ggml_new_tensor_2d(out.ctx, type, dims[0], dims[1]); + else if (n_dim == 3) out.t = ggml_new_tensor_3d(out.ctx, type, dims[0], dims[1], dims[2]); + else return false; + out.buf = ggml_backend_alloc_ctx_tensors(out.ctx, backend); + return out.buf != nullptr; +} + +static void g4_free_pers(G4PersBuf & p) { + if (p.buf) { ggml_backend_buffer_free(p.buf); p.buf = nullptr; } + if (p.ctx) { ggml_free(p.ctx); p.ctx = nullptr; } + p.t = nullptr; +} + +static int g4_bsa_chunk_size() { + if (const char * e = std::getenv("DFLASH_G4_BSA_CHUNK")) { + int v = std::atoi(e); + if (v >= 512) return v; + } + return 4096; +} + +bool gemma4_prefill_bsa( + ggml_backend_t backend, + const Gemma4Weights & w, + Gemma4Cache & cache, + const float * embed, + const int32_t * token_ids, + int S, + std::vector & out_logits) +{ + const int hidden = w.n_embd; + const int n_layer = w.n_layer; + const int n_head = w.n_head; + const float eps = w.norm_eps; + + // Determine max dimensions across all layers for buffer allocation. + int max_q_dim = 0, max_kv_dim = 0; + for (int il = 0; il < n_layer; ++il) { + const int D = gemma4_head_dim(w, il); + const int Hk = gemma4_n_head_kv(w, il); + max_q_dim = std::max(max_q_dim, D * n_head); + max_kv_dim = std::max(max_kv_dim, D * Hk); + } + + // Use BF16 only for sm_80+ (native BF16 tensor cores). Volta/Turing + // use F16 with F16 WMMA kernels; other arches use F16 with ggml FA fallback. 
+ const ggml_type half_type = +#ifdef DFLASH27B_HAVE_SM80_FLASHPREFILL + GGML_TYPE_BF16; +#else + GGML_TYPE_F16; +#endif + + // Allocate persistent buffers. + G4PersBuf hidden_buf{}, Q_buf{}, K_buf{}, V_buf{}, attn_out_buf{}; + int64_t d_h[] = {(int64_t)hidden, (int64_t)S}; + int64_t d_q[] = {(int64_t)max_q_dim, (int64_t)S}; + int64_t d_kv[] = {(int64_t)max_kv_dim, (int64_t)S}; + + auto cleanup_all = [&]() { + g4_free_pers(hidden_buf); + g4_free_pers(Q_buf); + g4_free_pers(K_buf); + g4_free_pers(V_buf); + g4_free_pers(attn_out_buf); + }; + + if (!g4_make_pers(backend, GGML_TYPE_F32, 2, d_h, hidden_buf) || + !g4_make_pers(backend, half_type, 2, d_q, Q_buf) || + !g4_make_pers(backend, half_type, 2, d_kv, K_buf) || + !g4_make_pers(backend, half_type, 2, d_kv, V_buf) || + !g4_make_pers(backend, half_type, 2, d_q, attn_out_buf)) { + std::fprintf(stderr, "[gemma4-bsa] persistent buffer alloc failed\n"); + cleanup_all(); + return false; + } + + // Upload embedded+scaled input to hidden_buf. + ggml_backend_tensor_set(hidden_buf.t, embed, 0, (size_t)hidden * S * sizeof(float)); + + // Precompute per-layer embeddings on GPU if the model has them. + // per_layer_all: [n_embd_per_layer, S, n_layer] — computed once, sliced per layer. + G4PersBuf per_layer_buf{}; + if (token_ids && w.per_layer_tok_embd && w.per_layer_model_proj && w.n_embd_per_layer > 0) { + const int D_pl = w.n_embd_per_layer; + const int L_pl = n_layer; + int64_t d_pl[] = {(int64_t)D_pl, (int64_t)S, (int64_t)L_pl}; + if (!g4_make_pers(backend, GGML_TYPE_F32, 3, d_pl, per_layer_buf)) { + std::fprintf(stderr, "[gemma4-bsa] per-layer buf alloc failed\n"); + cleanup_all(); + return false; + } + + // Build a graph to compute per-layer embeddings. 
+ ggml_init_params ip{}; + ip.mem_size = ggml_tensor_overhead() * 32 + ggml_graph_overhead() + 1024 * 1024; + ip.no_alloc = true; + ggml_context * ctx = ggml_init(ip); + ggml_cgraph * gf = ggml_new_graph(ctx); + + ggml_tensor * tok = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, S); + ggml_set_input(tok); + ggml_tensor * h_in = ggml_view_2d(ctx, hidden_buf.t, hidden, S, + hidden * sizeof(float), 0); + + // get_rows(per_layer_tok_embd, tok) → [D_pl*L_pl, S] + ggml_tensor * inp_pl = ggml_get_rows(ctx, w.per_layer_tok_embd, tok); + inp_pl = ggml_reshape_3d(ctx, inp_pl, D_pl, L_pl, S); + inp_pl = ggml_scale(ctx, inp_pl, std::sqrt((float)D_pl)); + + // Project main embedding: mul_mat(per_layer_model_proj, h_in) + ggml_tensor * proj = ggml_mul_mat(ctx, w.per_layer_model_proj, h_in); + proj = ggml_scale(ctx, proj, 1.0f / std::sqrt((float)hidden)); + proj = ggml_reshape_3d(ctx, proj, D_pl, L_pl, S); + + // RMS norm on projection + proj = ggml_rms_norm(ctx, proj, eps); + ggml_tensor * norm_w = ggml_reshape_2d(ctx, w.per_layer_proj_norm, D_pl, L_pl); + proj = ggml_mul(ctx, proj, norm_w); + + // Add + scale + ggml_tensor * pl_all = ggml_add(ctx, proj, inp_pl); + pl_all = ggml_scale(ctx, pl_all, 1.0f / std::sqrt(2.0f)); + + // Permute to [D_pl, S, L_pl] and copy to persistent buffer + pl_all = ggml_cont(ctx, ggml_permute(ctx, pl_all, 0, 2, 1, 3)); + ggml_tensor * cpy = ggml_cpy(ctx, pl_all, per_layer_buf.t); + ggml_set_output(cpy); + ggml_build_forward_expand(gf, cpy); + + ggml_gallocr_t ga = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + if (!ggml_gallocr_alloc_graph(ga, gf)) { + std::fprintf(stderr, "[gemma4-bsa] per-layer graph alloc failed\n"); + ggml_gallocr_free(ga); ggml_free(ctx); + g4_free_pers(per_layer_buf); cleanup_all(); + return false; + } + ggml_backend_tensor_set(tok, token_ids, 0, (size_t)S * sizeof(int32_t)); + ggml_backend_graph_compute(backend, gf); + ggml_gallocr_free(ga); + ggml_free(ctx); + } + + // Gallocr for per-layer graphs (reused). 
+ ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + const int CHUNK = g4_bsa_chunk_size(); + + // FlashPrefill config for SWA layers. + const int block_size = 128; + const int swa_window_blocks = (w.sliding_window + block_size - 1) / block_size; + flashprefill::FlashPrefillConfig swa_cfg; + swa_cfg.block_size = block_size; + swa_cfg.attention_sink = 0; + swa_cfg.window = swa_window_blocks; + swa_cfg.last_n_full = 0; + swa_cfg.alpha = 2.0f; // > 1.0 disables dynamic block selection + + // Scale for attention: Gemma4 uses 1.0 (Q/K already RMS-normed per head). + const float kq_scale = 1.0f; + + // ── Per-layer loop ── + for (int il = 0; il < n_layer; ++il) { + const Gemma4Layer & L = w.layers[il]; + const bool is_swa = gemma4_is_swa_layer(w, il); + const bool has_kv = gemma4_has_kv(w, il); + const int D = gemma4_head_dim(w, il); + const int Hk = gemma4_n_head_kv(w, il); + const int q_dim = D * n_head; + const int kv_dim = D * Hk; + + // ── Graph A (chunked): pre_norm + Q/K/V proj + norms + RoPE → persistent bufs ── + const float rope_base = is_swa ? w.rope_freq_base_swa : w.rope_freq_base_full; + ggml_tensor * freq_factors_ref = is_swa ? nullptr : + (L.rope_freqs ? L.rope_freqs : w.rope_freqs_global); + + for (int cs = 0; cs < S; cs += CHUNK) { + const int cl = std::min(CHUNK, S - cs); + + ggml_init_params ipA{}; + ipA.mem_size = ggml_tensor_overhead() * 64 + + ggml_graph_overhead_custom(512, false) + + 128 * 1024; + ipA.no_alloc = true; + ggml_context * gA = ggml_init(ipA); + if (!gA) { std::fprintf(stderr, "[gemma4-bsa] graph A init failed\n"); ggml_gallocr_free(galloc); cleanup_all(); g4_free_pers(per_layer_buf); return false; } + ggml_cgraph * gfA = ggml_new_graph_custom(gA, 512, false); + + // View into hidden_buf for this chunk. + const size_t h_esz = sizeof(float); + ggml_tensor * h_view = ggml_view_2d(gA, hidden_buf.t, + hidden, cl, + hidden * h_esz, + (size_t)cs * hidden * h_esz); + + // Positions for RoPE. 
+ ggml_tensor * pos_t = ggml_new_tensor_1d(gA, GGML_TYPE_I32, cl); + ggml_set_input(pos_t); + + // Pre-attn norm. + ggml_tensor * h_norm = ggml_rms_norm(gA, h_view, eps); + h_norm = ggml_mul(gA, h_norm, L.attn_norm); + + // Q projection + norm + RoPE. + ggml_tensor * Q = ggml_mul_mat(gA, L.wq, h_norm); + Q = ggml_reshape_3d(gA, Q, D, n_head, cl); + if (L.q_norm) { + Q = gemma4_rms_norm_mul(gA, Q, L.q_norm, eps); + } + Q = ggml_rope_ext(gA, Q, pos_t, freq_factors_ref, D, + GGML_ROPE_TYPE_NEOX, 0, rope_base, 1.0f, + 0.0f, 1.0f, 32.0f, 1.0f); + // Reshape Q to [q_dim, cl] and copy to Q_buf. + Q = ggml_reshape_2d(gA, Q, q_dim, cl); + + const size_t q_esz = ggml_type_size(half_type); + ggml_tensor * Q_dst = ggml_view_2d(gA, Q_buf.t, q_dim, cl, + q_esz * max_q_dim, + (size_t)cs * q_esz * max_q_dim); + ggml_build_forward_expand(gfA, ggml_cpy(gA, Q, Q_dst)); + + if (has_kv) { + // K projection + norm + RoPE. + ggml_tensor * K = ggml_mul_mat(gA, L.wk, h_norm); + K = ggml_reshape_3d(gA, K, D, Hk, cl); + if (L.k_norm) { + K = gemma4_rms_norm_mul(gA, K, L.k_norm, eps); + } + K = ggml_rope_ext(gA, K, pos_t, freq_factors_ref, D, + GGML_ROPE_TYPE_NEOX, 0, rope_base, 1.0f, + 0.0f, 1.0f, 32.0f, 1.0f); + K = ggml_reshape_2d(gA, K, kv_dim, cl); + + // V projection + RMSNorm (Gemma4 specific). + ggml_tensor * V = L.wv ? ggml_mul_mat(gA, L.wv, h_norm) + : ggml_mul_mat(gA, L.wk, h_norm); + V = ggml_reshape_3d(gA, V, D, Hk, cl); + V = ggml_rms_norm(gA, V, eps); + V = ggml_reshape_2d(gA, V, kv_dim, cl); + + const size_t kv_esz = ggml_type_size(half_type); + ggml_tensor * K_dst = ggml_view_2d(gA, K_buf.t, kv_dim, cl, + kv_esz * max_kv_dim, + (size_t)cs * kv_esz * max_kv_dim); + ggml_tensor * V_dst = ggml_view_2d(gA, V_buf.t, kv_dim, cl, + kv_esz * max_kv_dim, + (size_t)cs * kv_esz * max_kv_dim); + ggml_build_forward_expand(gfA, ggml_cpy(gA, K, K_dst)); + ggml_build_forward_expand(gfA, ggml_cpy(gA, V, V_dst)); + + // Write to KV cache for subsequent decode. 
+ // K is [kv_dim, cl] = [D*Hk, cl]. Cache is [D, cache_len, Hk] F16. + // Reshape K to [D, Hk, cl] → permute to [D, cl, Hk] → copy into cache slot. + ggml_tensor * cache_k_t = cache.k[il]; + ggml_tensor * cache_v_t = cache.v[il]; + if (cache_k_t) { + const int cache_len_il = (int)cache_k_t->ne[1]; + const int ring_pos = is_swa ? (cs % cache_len_il) : cs; + + // Lambda to copy a sub-range of K/V into cache. + auto write_kv_range = [&](int src_off, int dst_ring, int n) { + if (n <= 0) return; + // K[src_off:src_off+n] → cache_k[dst_ring:dst_ring+n] + ggml_tensor * Ks = (src_off == 0 && n == cl) ? K + : ggml_view_2d(gA, K, kv_dim, n, + K->nb[1], (size_t)src_off * K->nb[1]); + ggml_tensor * K3 = ggml_reshape_3d(gA, Ks, D, Hk, n); + ggml_tensor * Kp = ggml_cont(gA, ggml_permute(gA, K3, 0, 2, 1, 3)); + ggml_tensor * k_slot = ggml_view_3d(gA, cache_k_t, + D, n, Hk, + cache_k_t->nb[1], cache_k_t->nb[2], + cache_k_t->nb[1] * (size_t)dst_ring); + ggml_build_forward_expand(gfA, ggml_cpy(gA, Kp, k_slot)); + + ggml_tensor * Vs = (src_off == 0 && n == cl) ? V + : ggml_view_2d(gA, V, kv_dim, n, + V->nb[1], (size_t)src_off * V->nb[1]); + ggml_tensor * V3 = ggml_reshape_3d(gA, Vs, D, Hk, n); + ggml_tensor * Vp = ggml_cont(gA, ggml_permute(gA, V3, 0, 2, 1, 3)); + ggml_tensor * v_slot = ggml_view_3d(gA, cache_v_t, + D, n, Hk, + cache_v_t->nb[1], cache_v_t->nb[2], + cache_v_t->nb[1] * (size_t)dst_ring); + ggml_build_forward_expand(gfA, ggml_cpy(gA, Vp, v_slot)); + }; + + if (!is_swa && ring_pos + cl > cache_len_il) { + // Full-attention layer: positions exceed cache — truncate. + const int n_fit = cache_len_il - ring_pos; + if (n_fit > 0) write_kv_range(0, ring_pos, n_fit); + } else if (is_swa && ring_pos + cl > cache_len_il) { + // SWA ring wrap — split into two writes. 
+ const int first_n = cache_len_il - ring_pos; + write_kv_range(0, ring_pos, first_n); + write_kv_range(first_n, 0, cl - first_n); + } else { + write_kv_range(0, ring_pos, cl); + } + } + } + + if (!ggml_gallocr_alloc_graph(galloc, gfA)) { + std::fprintf(stderr, "[gemma4-bsa] graph A alloc failed layer=%d cs=%d\n", il, cs); + ggml_free(gA); ggml_gallocr_free(galloc); cleanup_all(); g4_free_pers(per_layer_buf); + return false; + } + + // Set positions. + std::vector pos((size_t)cl); + for (int i = 0; i < cl; ++i) pos[i] = cs + i; + ggml_backend_tensor_set(pos_t, pos.data(), 0, (size_t)cl * sizeof(int32_t)); + + ggml_backend_graph_compute(backend, gfA); + ggml_backend_synchronize(backend); + ggml_free(gA); + } + + // ── Attention ── + // Determine which K/V to use (KV sharing). + const int kv_source_il = cache.kv_source[il]; + // If this layer reuses another layer's KV, the source layer's K/V is + // already in K_buf/V_buf from when that layer was processed. + // For BSA we need the source layer's buffers, but since we process + // layers sequentially and overwrite K_buf/V_buf each layer, we need + // to handle sharing differently. + // + // Simplification: KV-sharing layers have the same head_dim and n_head_kv + // as their source. During BSA prefill, we DON'T overwrite K_buf/V_buf + // for layers without has_kv, so they still hold the source layer's data. + // This works because kv_source[il] < il for sharing layers. + + bool used_bsa = false; + if (is_swa && D == 128) { + // ── BSA sparse-FA for SWA layers (head_dim=128) ── + const bool q_contiguous = (q_dim == max_q_dim); + const bool kv_contiguous = (kv_dim == max_kv_dim); + + int rc; + if (q_contiguous && kv_contiguous) { + rc = flashprefill::flash_prefill_forward( + backend, Q_buf.t->data, K_buf.t->data, + V_buf.t->data, attn_out_buf.t->data, + 1, S, n_head, Hk, D, kq_scale, half_type, swa_cfg); + } else { + // Non-contiguous: allocate temporary packed buffers. 
+ G4PersBuf Q_pack{}, K_pack{}, V_pack{}, O_pack{}; + int64_t dq[] = {(int64_t)q_dim, (int64_t)S}; + int64_t dk[] = {(int64_t)kv_dim, (int64_t)S}; + if (!g4_make_pers(backend, half_type, 2, dq, Q_pack) || + !g4_make_pers(backend, half_type, 2, dk, K_pack) || + !g4_make_pers(backend, half_type, 2, dk, V_pack) || + !g4_make_pers(backend, half_type, 2, dq, O_pack)) { + std::fprintf(stderr, "[gemma4-bsa] pack buf alloc failed\n"); + g4_free_pers(Q_pack); g4_free_pers(K_pack); + g4_free_pers(V_pack); g4_free_pers(O_pack); + ggml_gallocr_free(galloc); cleanup_all(); g4_free_pers(per_layer_buf); + return false; + } + + const size_t esz = ggml_type_size(half_type); + cudaMemcpy2D(Q_pack.t->data, q_dim * esz, + Q_buf.t->data, max_q_dim * esz, + q_dim * esz, S, cudaMemcpyDeviceToDevice); + cudaMemcpy2D(K_pack.t->data, kv_dim * esz, + K_buf.t->data, max_kv_dim * esz, + kv_dim * esz, S, cudaMemcpyDeviceToDevice); + cudaMemcpy2D(V_pack.t->data, kv_dim * esz, + V_buf.t->data, max_kv_dim * esz, + kv_dim * esz, S, cudaMemcpyDeviceToDevice); + + rc = flashprefill::flash_prefill_forward( + backend, Q_pack.t->data, K_pack.t->data, + V_pack.t->data, O_pack.t->data, + 1, S, n_head, Hk, D, kq_scale, half_type, swa_cfg); + + // Copy packed output back to strided attn_out_buf. + cudaMemcpy2D(attn_out_buf.t->data, max_q_dim * esz, + O_pack.t->data, q_dim * esz, + q_dim * esz, S, cudaMemcpyDeviceToDevice); + + g4_free_pers(Q_pack); g4_free_pers(K_pack); + g4_free_pers(V_pack); g4_free_pers(O_pack); + } + + if (rc != 0) { + std::fprintf(stderr, "[gemma4-bsa] flash_prefill failed layer=%d rc=%d\n", il, rc); + ggml_gallocr_free(galloc); cleanup_all(); g4_free_pers(per_layer_buf); + return false; + } + cudaDeviceSynchronize(); + used_bsa = true; + } + + if (!used_bsa) { + // Build a ggml graph for dense causal attention for this layer. + // Process the full sequence in one FA call (or chunked if too large). 
+ for (int cs = 0; cs < S; cs += CHUNK) { + const int cl = std::min(CHUNK, S - cs); + const int kv_len = cs + cl; // attend to all positions up to current + + ggml_init_params ipFA{}; + ipFA.mem_size = ggml_tensor_overhead() * 32 + + ggml_graph_overhead_custom(64, false) + + 128 * 1024; + ipFA.no_alloc = true; + ggml_context * gFA = ggml_init(ipFA); + ggml_cgraph * gfFA = ggml_new_graph_custom(gFA, 64, false); + + const size_t esz = ggml_type_size(half_type); + + // Q view: [D, n_head, cl] from Q_buf + ggml_tensor * Qfa = ggml_view_3d(gFA, Q_buf.t, + D, n_head, cl, + esz * D, esz * max_q_dim, + (size_t)cs * esz * max_q_dim); + + // K view: [D, Hk, kv_len] from K_buf + ggml_tensor * Kfa = ggml_view_3d(gFA, K_buf.t, + D, Hk, kv_len, + esz * D, esz * max_kv_dim, + 0); + + // V view: [D, Hk, kv_len] from V_buf + ggml_tensor * Vfa = ggml_view_3d(gFA, V_buf.t, + D, Hk, kv_len, + esz * D, esz * max_kv_dim, + 0); + + // Causal mask: [kv_len_padded, cl] + const int kv_len_padded = (kv_len + 255) & ~255; + ggml_tensor * mask = ggml_new_tensor_4d(gFA, GGML_TYPE_F32, + kv_len_padded, cl, 1, 1); + ggml_set_input(mask); + ggml_tensor * mask_f16 = ggml_cast(gFA, mask, GGML_TYPE_F16); + + ggml_tensor * attn = ggml_flash_attn_ext(gFA, Qfa, Kfa, Vfa, mask_f16, + kq_scale, 0.0f, 0.0f); + + // Write output to attn_out_buf: [q_dim, cl] at offset cs. + attn = ggml_reshape_2d(gFA, attn, q_dim, cl); + ggml_tensor * O_dst = ggml_view_2d(gFA, attn_out_buf.t, q_dim, cl, + esz * max_q_dim, + (size_t)cs * esz * max_q_dim); + ggml_tensor * cpy_op = ggml_cpy(gFA, attn, O_dst); + ggml_set_output(cpy_op); + ggml_build_forward_expand(gfFA, cpy_op); + + if (!ggml_gallocr_alloc_graph(galloc, gfFA)) { + std::fprintf(stderr, "[gemma4-bsa] dense FA alloc failed layer=%d\n", il); + ggml_free(gFA); ggml_gallocr_free(galloc); cleanup_all(); g4_free_pers(per_layer_buf); + return false; + } + + // Fill causal mask. 
+ std::vector m((size_t)kv_len_padded * cl, -INFINITY); + for (int q = 0; q < cl; ++q) { + const int abs_q = cs + q; + for (int k = 0; k <= abs_q && k < kv_len; ++k) { + m[(size_t)q * kv_len_padded + k] = 0.0f; + } + } + ggml_backend_tensor_set(mask, m.data(), 0, ggml_nbytes(mask)); + + ggml_backend_graph_compute(backend, gfFA); + ggml_backend_synchronize(backend); + ggml_free(gFA); + } + } + + // ── Graph B (chunked): o_proj + post_norm + residual + FFN + per_layer + scale ── + for (int cs = 0; cs < S; cs += CHUNK) { + const int cl = std::min(CHUNK, S - cs); + + ggml_init_params ipB{}; + ipB.mem_size = ggml_tensor_overhead() * 128 + + ggml_graph_overhead_custom(1024, false) + + 2 * 1024 * 1024; + ipB.no_alloc = true; + ggml_context * gB = ggml_init(ipB); + if (!gB) { std::fprintf(stderr, "[gemma4-bsa] graph B init failed\n"); ggml_gallocr_free(galloc); cleanup_all(); g4_free_pers(per_layer_buf); return false; } + ggml_cgraph * gfB = ggml_new_graph_custom(gB, 1024, false); + + const size_t h_esz = sizeof(float); + const size_t a_esz = ggml_type_size(half_type); + + // Hidden state for this chunk (residual input). + ggml_tensor * h_in = ggml_view_2d(gB, hidden_buf.t, hidden, cl, + hidden * h_esz, + (size_t)cs * hidden * h_esz); + + // Attention output for this chunk. + ggml_tensor * a_in = ggml_view_2d(gB, attn_out_buf.t, q_dim, cl, + a_esz * max_q_dim, + (size_t)cs * a_esz * max_q_dim); + + // o_proj: [q_dim, n_embd] × [q_dim, cl] → [n_embd, cl] + ggml_tensor * cur = ggml_mul_mat(gB, L.wo, a_in); + + // Post-attn norm. + if (L.attn_post_norm) { + cur = gemma4_rms_norm_mul(gB, cur, L.attn_post_norm, eps); + } + + // Residual after attention. + ggml_tensor * attn_res = ggml_add(gB, cur, h_in); + + // FFN. 
+ const bool is_moe = (L.ffn_gate_inp != nullptr && il >= w.n_layer_dense_lead); + ggml_tensor * ffn_out; + if (is_moe) { + ggml_tensor * normed = gemma4_rms_norm_mul(gB, attn_res, L.ffn_norm, eps); + ffn_out = build_gemma4_moe_block(gB, attn_res, normed, w, L, cl); + } else { + cur = gemma4_rms_norm_mul(gB, attn_res, L.ffn_norm, eps); + ffn_out = build_gemma4_dense_ffn(gB, cur, L); + } + + // FFN post-norm. + if (L.ffn_post_norm) { + ffn_out = gemma4_rms_norm_mul(gB, ffn_out, L.ffn_post_norm, eps); + } + + // Residual after FFN. + cur = ggml_add(gB, ffn_out, attn_res); + + // Per-layer embedding injection. + if (per_layer_buf.t && L.per_layer_inp_gate && L.per_layer_proj) { + const int D_pl = w.n_embd_per_layer; + // Slice per_layer_buf [D_pl, S, n_layer] → [D_pl, cl] for this layer+chunk + ggml_tensor * pl_slice = ggml_view_2d(gB, per_layer_buf.t, + D_pl, cl, + D_pl * sizeof(float), + ((size_t)il * S + cs) * D_pl * sizeof(float)); + + ggml_tensor * gate = ggml_mul_mat(gB, L.per_layer_inp_gate, cur); + gate = ggml_gelu(gB, gate); + gate = ggml_mul(gB, gate, pl_slice); + ggml_tensor * proj = ggml_mul_mat(gB, L.per_layer_proj, gate); + if (L.per_layer_post_norm) { + proj = gemma4_rms_norm_mul(gB, proj, L.per_layer_post_norm, eps); + } + cur = ggml_add(gB, cur, proj); + } + + // Output scale. + if (L.out_scale) { + cur = ggml_mul(gB, cur, L.out_scale); + } + + // Write back to hidden_buf. 
+ ggml_tensor * h_dst = ggml_view_2d(gB, hidden_buf.t, hidden, cl, + hidden * h_esz, + (size_t)cs * hidden * h_esz); + ggml_tensor * cpy = ggml_cpy(gB, cur, h_dst); + ggml_set_output(cpy); + ggml_build_forward_expand(gfB, cpy); + + if (!ggml_gallocr_alloc_graph(galloc, gfB)) { + std::fprintf(stderr, "[gemma4-bsa] graph B alloc failed layer=%d cs=%d\n", il, cs); + ggml_free(gB); ggml_gallocr_free(galloc); cleanup_all(); g4_free_pers(per_layer_buf); + return false; + } + + ggml_backend_graph_compute(backend, gfB); + ggml_backend_synchronize(backend); + ggml_free(gB); + } + + // Feature capture: write hidden states at capture layers to target_feat ring. + if (cache.target_feat) { + int cap_idx = -1; + for (int k = 0; k < cache.n_capture_layers; k++) { + if (cache.capture_layer_ids[k] == il) { cap_idx = k; break; } + } + if (cap_idx >= 0) { + const int cap = cache.target_feat_cap; + const size_t feat_col_stride = cache.target_feat->nb[1]; + const size_t feat_elt = ggml_element_size(cache.target_feat); + // Write last min(S, cap) positions into the ring buffer. + const int write_start = (S > cap) ? 
(S - cap) : 0; + const int write_n = std::min(S, cap); + for (int cs = write_start; cs < write_start + write_n; cs += CHUNK) { + const int cl = std::min(CHUNK, write_start + write_n - cs); + const int slot_start = cs % cap; + + ggml_init_params ipC{}; + ipC.mem_size = ggml_tensor_overhead() * 8 + + ggml_graph_overhead() + 64 * 1024; + ipC.no_alloc = true; + ggml_context * gC = ggml_init(ipC); + ggml_cgraph * gfC = ggml_new_graph(gC); + + ggml_tensor * h_src = ggml_view_2d(gC, hidden_buf.t, + hidden, cl, hidden * sizeof(float), + (size_t)cs * hidden * sizeof(float)); + + const size_t offset = (size_t)slot_start * feat_col_stride + + (size_t)cap_idx * hidden * feat_elt; + ggml_tensor * feat_dst = ggml_view_2d(gC, cache.target_feat, + hidden, cl, feat_col_stride, offset); + + ggml_build_forward_expand(gfC, ggml_cpy(gC, h_src, feat_dst)); + + if (ggml_gallocr_alloc_graph(galloc, gfC)) { + ggml_backend_graph_compute(backend, gfC); + } + ggml_free(gC); + } + } + } + } // end layer loop + + // ── Fill KV cache for decode ── + // KV cache was not populated during the BSA layer loop because K_buf/V_buf + // get overwritten each layer. We re-project K/V for each KV-owning layer + // from the hidden states that were stored before each layer's attention. + // + // However, we don't have the pre-norm hidden states anymore (hidden_buf has + // the final output). The correct approach is to write KV cache during the + // layer loop. Since this is a v1 implementation, we use the fallback: + // after BSA prefill returns, the caller (do_prefill) will run a single + // gemma4_step with the last chunk to populate the cache for decode. + // + // TODO: Move KV cache writes into the layer loop for zero-redundancy. + // For now, the caller handles cache population by running a trailing + // gemma4_step over the last swa_size tokens. 
+ + // ── Final norm + logits (last token only) ── + { + ggml_init_params ipF{}; + ipF.mem_size = ggml_tensor_overhead() * 16 + ggml_graph_overhead() + 1024 * 1024; + ipF.no_alloc = true; + ggml_context * gF = ggml_init(ipF); + ggml_cgraph * gfF = ggml_new_graph(gF); + + // View last token of hidden_buf. + ggml_tensor * h_last = ggml_view_2d(gF, hidden_buf.t, hidden, 1, + hidden * sizeof(float), + (size_t)(S - 1) * hidden * sizeof(float)); + + // Final RMSNorm. + ggml_tensor * normed = gemma4_rms_norm_mul(gF, h_last, w.out_norm, eps); + + // lm_head. + ggml_tensor * logits = ggml_mul_mat(gF, w.output, normed); + + // Softcapping. + if (w.final_logit_softcap > 0.0f) { + logits = ggml_scale(gF, logits, 1.0f / w.final_logit_softcap); + logits = ggml_tanh(gF, logits); + logits = ggml_scale(gF, logits, w.final_logit_softcap); + } + + ggml_set_output(logits); + ggml_build_forward_expand(gfF, logits); + + if (!ggml_gallocr_alloc_graph(galloc, gfF)) { + std::fprintf(stderr, "[gemma4-bsa] final graph alloc failed\n"); + ggml_free(gF); ggml_gallocr_free(galloc); cleanup_all(); g4_free_pers(per_layer_buf); + return false; + } + + ggml_backend_graph_compute(backend, gfF); + ggml_backend_synchronize(backend); + + out_logits.resize((size_t)w.n_vocab); + ggml_backend_tensor_get(logits, out_logits.data(), 0, + out_logits.size() * sizeof(float)); + ggml_free(gF); + } + + // Update cache position. + cache.cur_pos = S; + + ggml_gallocr_free(galloc); + cleanup_all(); + g4_free_pers(per_layer_buf); + return true; +} + } // namespace dflash::common diff --git a/dflash/src/gemma4/gemma4_internal.h b/dflash/src/gemma4/gemma4_internal.h index 872bc829..e66f3433 100644 --- a/dflash/src/gemma4/gemma4_internal.h +++ b/dflash/src/gemma4/gemma4_internal.h @@ -236,4 +236,17 @@ bool gemma4_project_hidden( int n_tokens, std::vector & out_tokens); +// BSA sparse-FA prefill: process the full prompt at once using block-sparse +// attention for SWA layers (flash_prefill_forward_bf16). 
Full-attention layers +// use dense FA. Returns logits for the last token. Populates the KV cache +// for subsequent decode. Returns false on failure. +bool gemma4_prefill_bsa( + ggml_backend_t backend, + const Gemma4Weights & w, + Gemma4Cache & cache, + const float * embed, // [n_embd, S] scaled + const int32_t * token_ids, // [S] (for per-layer embedding) + int S, // total prompt length + std::vector & out_logits); + } // namespace dflash::common diff --git a/dflash/src/qwen3/qwen3_graph.cpp b/dflash/src/qwen3/qwen3_graph.cpp index c907546f..085a2342 100644 --- a/dflash/src/qwen3/qwen3_graph.cpp +++ b/dflash/src/qwen3/qwen3_graph.cpp @@ -509,76 +509,21 @@ bool forward_qwen3_drafter_model( } // ── Attention dispatch ── - // Three paths: - // 1. BF16 WMMA (sm_80+, HIP Phase 2): flash_prefill_forward_bf16 - // 2. F16 WMMA (Volta/Turing): flash_prefill_forward_f16 - // 3. ggml flash_attn_ext: fallback for all other cases auto tF0 = std::chrono::steady_clock::now(); - const bool use_bf16_fp = (Q_buf.t->type == GGML_TYPE_BF16) -#if defined(DFLASH27B_HAVE_FLASHPREFILL) || defined(DFLASH27B_HAVE_SM80_FLASHPREFILL) - && true; -#else - && false; -#endif - const bool use_f16_fp = (Q_buf.t->type == GGML_TYPE_F16) -#if defined(DFLASH27B_HAVE_VOLTA_FLASHPREFILL) || defined(DFLASH27B_HAVE_PASCAL_FLASHPREFILL) - && true; -#else - && false; -#endif - if (use_bf16_fp) { -#if defined(DFLASH27B_HAVE_FLASHPREFILL) || defined(DFLASH27B_HAVE_SM80_FLASHPREFILL) - int rc = flashprefill::flash_prefill_forward_bf16( - Q_buf.t->data, - K_curr_v[il].t->data, - V_curr_v[il].t->data, - attn_out_buf.t->data, - 1, S, H, Hk, D, scale, fp_cfg); - if (rc != 0) { - set_last_error("flash_prefill_forward_bf16 failed at layer " + std::to_string(il)); - ggml_gallocr_free(galloc); cleanup_all(); return false; - } - cudaError_t e = cudaGetLastError(); - if (e != cudaSuccess) { - set_last_error(std::string("flash_prefill cuda error: ") + cudaGetErrorString(e)); - ggml_gallocr_free(galloc); cleanup_all(); 
return false; - } - cudaDeviceSynchronize(); -#endif - } else if (use_f16_fp) { -#if defined(DFLASH27B_HAVE_VOLTA_FLASHPREFILL) || defined(DFLASH27B_HAVE_PASCAL_FLASHPREFILL) - int rc = flashprefill::flash_prefill_forward_f16( - Q_buf.t->data, - K_curr_v[il].t->data, - V_curr_v[il].t->data, - attn_out_buf.t->data, - 1, S, H, Hk, D, scale, fp_cfg); - if (rc != 0) { - set_last_error("flash_prefill_forward_f16 failed at layer " + std::to_string(il)); - ggml_gallocr_free(galloc); cleanup_all(); return false; - } - cudaError_t e = cudaGetLastError(); - if (e != cudaSuccess) { - set_last_error(std::string("flash_prefill-f16 cuda error: ") + cudaGetErrorString(e)); - ggml_gallocr_free(galloc); cleanup_all(); return false; - } - cudaDeviceSynchronize(); -#endif - } else { - int rc = flashprefill::flash_prefill_forward_q8( - w.backend, - Q_buf.t->data, - K_curr_v[il].t->data, - V_curr_v[il].t->data, - attn_out_buf.t->data, - 1, S, H, Hk, D, scale, - Q_buf.t->type, - fp_cfg); - if (rc != 0) { - set_last_error("flash_prefill_forward_q8 failed at layer " + std::to_string(il)); - ggml_gallocr_free(galloc); cleanup_all(); return false; - } + int rc = flashprefill::flash_prefill_forward( + w.backend, + Q_buf.t->data, + K_curr_v[il].t->data, + V_curr_v[il].t->data, + attn_out_buf.t->data, + 1, S, H, Hk, D, scale, + Q_buf.t->type, + fp_cfg); + if (rc != 0) { + set_last_error("flash_prefill_forward failed at layer " + std::to_string(il)); + ggml_gallocr_free(galloc); cleanup_all(); return false; } + cudaDeviceSynchronize(); auto tF1 = std::chrono::steady_clock::now(); t_fp += std::chrono::duration(tF1 - tF0).count(); if (debug_first_layer) { From 3c433a9bfbaf9cc75a903452e5e283da007ad3cf Mon Sep 17 00:00:00 2001 From: Howard Su Date: Thu, 21 May 2026 08:47:04 +0800 Subject: [PATCH 15/18] gemma4: full feature mirror resync after prefix cache restore MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After restoring KV from a snapshot, do_prefill 
only syncs the feature mirror for the delta tokens [snap_pos..committed). The positions [0..snap_pos) in the mirror retain stale data from the previous request's decode phase (which may have diverged from the current prompt context after the ring buffer wraps). Fix: call draft_feature_mirror_sync_tail after restore to resync the entire [0..committed) feature range from cache_.target_feat to the mirror. This ensures the draft model sees consistent features and maintains a high acceptance length (AL) during speculative decoding. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> gemma4: save/restore target_feat in prefix cache snapshot Matching Qwen35's approach: save target_feat (BF16 feature ring buffer) and last_tok as part of the KV snapshot. On restore, target_feat is copied back to GPU before the delta prefill + feature mirror resync. Previously, only K/V tensors were snapshotted. After restore, the feature mirror contained stale data from the previous request's decode phase, causing the draft model to make poor predictions and halving speculative decode acceptance rate (52% → 24%). With this fix, the full feature state is correctly restored, and the subsequent draft_feature_mirror_sync_tail ensures the mirror matches. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dflash/src/gemma4/gemma4_backend.cpp | 59 ++++++++++++++++++++++++++-- dflash/src/gemma4/gemma4_internal.h | 3 ++ dflash/src/gemma4/gemma4_loader.cpp | 5 ++- 3 files changed, 63 insertions(+), 4 deletions(-) diff --git a/dflash/src/gemma4/gemma4_backend.cpp b/dflash/src/gemma4/gemma4_backend.cpp index 6386db7c..f0c89063 100644 --- a/dflash/src/gemma4/gemma4_backend.cpp +++ b/dflash/src/gemma4/gemma4_backend.cpp @@ -511,6 +511,16 @@ GenerateResult Gemma4Backend::generate(const GenerateRequest & req, return result; } + // Inline snapshot at snap_pos for prefix cache + if (req.snap_slot >= 0 && req.snap_pos > 0 && req.snap_pos <= committed) { + cache_.cur_pos = req.snap_pos; + if (snapshot_save(req.snap_slot)) { + std::fprintf(stderr, "[gemma4] inline-snap slot=%d cur_pos=%d\n", + req.snap_slot, req.snap_pos); + } + cache_.cur_pos = committed; + } + if (req.n_gen > 0) { // Try speculative decode if draft is available and temp==0 const bool can_spec = dflash_target_ @@ -623,8 +633,15 @@ GenerateResult Gemma4Backend::restore_and_generate(int slot, } } + // Restore target_feat from snapshot + if (snap.feat_snap && cache_.target_feat) { + const size_t feat_nbytes = ggml_nbytes(snap.feat_snap); + ggml_backend_tensor_set(cache_.target_feat, snap.feat_snap->data, 0, feat_nbytes); + } + const int snap_pos = snap.cur_pos; cache_.cur_pos = snap_pos; + cache_.last_tok = snap.last_tok; // Set up sampler sampler_ = req.sampler; @@ -651,6 +668,24 @@ GenerateResult Gemma4Backend::restore_and_generate(int slot, } // else: prompt_len == snap_pos → no delta, committed stays at snap_pos + // Inline snapshot at snap_pos for prefix cache (new snap from this request) + if (req.snap_slot >= 0 && req.snap_pos > 0 && req.snap_pos <= committed) { + cache_.cur_pos = req.snap_pos; + if (snapshot_save(req.snap_slot)) { + std::fprintf(stderr, "[gemma4] inline-snap slot=%d cur_pos=%d\n", + req.snap_slot, req.snap_pos); + } + 
cache_.cur_pos = committed; + } + + // Full feature mirror resync after restore: do_prefill only synced the + // delta [snap_pos..committed). Re-sync the entire [0..committed) range so + // the draft model sees correct features for the full context. + if (feature_mirror_.target_feat && cache_.target_feat && !draft_parked_ && committed > 0) { + draft_feature_mirror_sync_tail(cache_.target_feat, cache_.target_feat_cap, + feature_mirror_, committed); + } + // Generate if (req.n_gen > 0) { const bool can_spec = dflash_target_ @@ -738,8 +773,9 @@ bool Gemma4Backend::snapshot_save(int slot) { if (needs_alloc) { free_gemma4_snapshot(snap); + const int n_feat_tensors = (cache_.target_feat && cache_.target_feat_cap > 0) ? 1 : 0; ggml_init_params ip{}; - ip.mem_size = ggml_tensor_overhead() * (size_t)(n_layer * 2 + 4) + 4096; + ip.mem_size = ggml_tensor_overhead() * (size_t)(n_layer * 2 + n_feat_tensors + 4) + 4096; ip.no_alloc = true; snap.ctx = ggml_init(ip); if (!snap.ctx) return false; @@ -759,10 +795,21 @@ bool Gemma4Backend::snapshot_save(int slot) { } } + // target_feat: save min(snap_pos, target_feat_cap) positions + snap.feat_snap = nullptr; + snap.feat_cap = 0; + if (cache_.target_feat && cache_.target_feat_cap > 0) { + const int feat_len = std::min(snap_pos, cache_.target_feat_cap); + snap.feat_snap = ggml_new_tensor_2d(snap.ctx, cache_.target_feat->type, + cache_.target_feat->ne[0], feat_len); + snap.feat_cap = cache_.target_feat_cap; + } + snap.buf = ggml_backend_alloc_ctx_tensors(snap.ctx, snap_backend_); if (!snap.buf) { ggml_free(snap.ctx); snap.ctx = nullptr; snap.k_snap.clear(); snap.v_snap.clear(); + snap.feat_snap = nullptr; return false; } } @@ -792,9 +839,15 @@ bool Gemma4Backend::snapshot_save(int slot) { } } snap.cur_pos = snap_pos; + snap.last_tok = cache_.last_tok; - std::printf("[gemma4] snapshot saved slot=%d pos=%d\n", slot, snap.cur_pos); - std::fflush(stdout); + // target_feat: copy min(snap_pos, cap) positions from GPU to snapshot + if 
(snap.feat_snap && cache_.target_feat) { + const size_t feat_nbytes = ggml_nbytes(snap.feat_snap); + ggml_backend_tensor_get(cache_.target_feat, snap.feat_snap->data, 0, feat_nbytes); + } + + std::fprintf(stderr, "[gemma4] snapshot saved slot=%d pos=%d\n", slot, snap.cur_pos); return true; } diff --git a/dflash/src/gemma4/gemma4_internal.h b/dflash/src/gemma4/gemma4_internal.h index e66f3433..4c365ab7 100644 --- a/dflash/src/gemma4/gemma4_internal.h +++ b/dflash/src/gemma4/gemma4_internal.h @@ -193,8 +193,11 @@ bool create_gemma4_target_feat(ggml_backend_t backend, Gemma4Cache & cache, // Snapshot struct Gemma4Snapshot { int cur_pos = 0; + int32_t last_tok = -1; std::vector k_snap; std::vector v_snap; + ggml_tensor * feat_snap = nullptr; // [fc_in, feat_len] + int feat_cap = 0; ggml_context * ctx = nullptr; ggml_backend_buffer_t buf = nullptr; }; diff --git a/dflash/src/gemma4/gemma4_loader.cpp b/dflash/src/gemma4/gemma4_loader.cpp index 2c82c7c6..528113c5 100644 --- a/dflash/src/gemma4/gemma4_loader.cpp +++ b/dflash/src/gemma4/gemma4_loader.cpp @@ -478,7 +478,10 @@ void free_gemma4_snapshot(Gemma4Snapshot & s) { if (s.buf) { ggml_backend_buffer_free(s.buf); s.buf = nullptr; } if (s.ctx) { ggml_free(s.ctx); s.ctx = nullptr; } s.k_snap.clear(); s.v_snap.clear(); - s.cur_pos = 0; + s.feat_snap = nullptr; + s.feat_cap = 0; + s.cur_pos = 0; + s.last_tok = -1; } } // namespace dflash::common From d399796e0a899c0719dd044cb5ed7432791561af Mon Sep 17 00:00:00 2001 From: Howard Su Date: Thu, 21 May 2026 09:21:21 +0800 Subject: [PATCH 16/18] gemma4: check graph_compute return in prefill_bsa per-layer embed Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dflash/src/gemma4/gemma4_graph.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dflash/src/gemma4/gemma4_graph.cpp b/dflash/src/gemma4/gemma4_graph.cpp index 41791203..14587d48 100644 --- a/dflash/src/gemma4/gemma4_graph.cpp +++ b/dflash/src/gemma4/gemma4_graph.cpp @@ -960,7 
+960,12 @@ bool gemma4_prefill_bsa( return false; } ggml_backend_tensor_set(tok, token_ids, 0, (size_t)S * sizeof(int32_t)); - ggml_backend_graph_compute(backend, gf); + if (ggml_backend_graph_compute(backend, gf) != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "gemma4_prefill_bsa: per-layer embed graph_compute failed\n"); + ggml_gallocr_free(ga); ggml_free(ctx); + g4_free_pers(per_layer_buf); cleanup_all(); + return false; + } ggml_gallocr_free(ga); ggml_free(ctx); } From dfaf9991e7e767f895d06f9e650497e2886cd8b9 Mon Sep 17 00:00:00 2001 From: Howard Su Date: Fri, 22 May 2026 07:29:47 +0800 Subject: [PATCH 17/18] gemma4: fix namespace dflash27b -> dflash::common after rebase Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dflash/src/gemma4/gemma4_backend.cpp | 2 +- dflash/src/gemma4/gemma4_dflash_target.cpp | 4 ++-- dflash/src/gemma4/gemma4_dflash_target.h | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dflash/src/gemma4/gemma4_backend.cpp b/dflash/src/gemma4/gemma4_backend.cpp index f0c89063..6395edf1 100644 --- a/dflash/src/gemma4/gemma4_backend.cpp +++ b/dflash/src/gemma4/gemma4_backend.cpp @@ -936,7 +936,7 @@ bool Gemma4Backend::handle_compress(const std::string & line, void Gemma4Backend::free_drafter() { if (drafter_loaded_) { - dflash27b::free_drafter(drafter_ctx_); + ::dflash::common::free_drafter(drafter_ctx_); drafter_loaded_ = false; } } diff --git a/dflash/src/gemma4/gemma4_dflash_target.cpp b/dflash/src/gemma4/gemma4_dflash_target.cpp index 5d044234..aebd0b09 100644 --- a/dflash/src/gemma4/gemma4_dflash_target.cpp +++ b/dflash/src/gemma4/gemma4_dflash_target.cpp @@ -7,7 +7,7 @@ #include #include -namespace dflash27b { +namespace dflash::common { Gemma4DFlashTarget::Gemma4DFlashTarget( Gemma4Weights & w, @@ -148,4 +148,4 @@ const std::vector & Gemma4DFlashTarget::capture_layer_ids() const { return capture_ids_; } -} // namespace dflash27b +} // namespace dflash::common diff --git 
a/dflash/src/gemma4/gemma4_dflash_target.h b/dflash/src/gemma4/gemma4_dflash_target.h index 86a0d8bf..1d12079b 100644 --- a/dflash/src/gemma4/gemma4_dflash_target.h +++ b/dflash/src/gemma4/gemma4_dflash_target.h @@ -14,7 +14,7 @@ #include -namespace dflash27b { +namespace dflash::common { class Gemma4DFlashTarget : public DFlashTarget { public: @@ -60,4 +60,4 @@ class Gemma4DFlashTarget : public DFlashTarget { Gemma4Snapshot verify_snap_; }; -} // namespace dflash27b +} // namespace dflash::common From d3720b6010ff77894ede99ad93a973f0790df650 Mon Sep 17 00:00:00 2001 From: Davide Cifarelli Date: Fri, 22 May 2026 08:58:12 +0000 Subject: [PATCH 18/18] gemma4: fix MoE GELU contig + loader tensor name mismatches MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three fixes for gemma-4-26B-A4B-it (unsloth UD-Q4_K_M). 1. gemma4_graph.cpp:116 — GGML_ASSERT(ggml_is_contiguous(src0)) crash in ggml_cuda_op_gelu. gate_e and up_e are strided ggml_view_3d halves of fused gate_up_e; CUDA gelu requires contiguous src. Insert ggml_cont before ggml_gelu. 2. gemma4_loader.cpp tensor name mismatches with actual GGUF metadata (silently loaded null → MoE produced gibberish): ffn_gate_inp_shexp.weight → ffn_gate_inp.scale ffn_down_exps_s.weight → ffn_down_exps.scale ffn_pre_norm_2.weight → pre_ffw_norm_2.weight ffn_post_norm_1.weight → post_ffw_norm_1.weight ffn_post_norm_2.weight → post_ffw_norm_2.weight 3. leading_dense_block_count default 1 → 0. Gemma-4-26B-A4B GGUF does not store this key; old default skipped MoE on layer 0, running shared-expert only and corrupting downstream. Verified: 'What is 2+2?' returns '2 + 2 = 4' on lucebox2 RTX 3090. 
Co-Authored-By: WOZCODE --- dflash/src/gemma4/gemma4_graph.cpp | 2 ++ dflash/src/gemma4/gemma4_loader.cpp | 12 ++++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/dflash/src/gemma4/gemma4_graph.cpp b/dflash/src/gemma4/gemma4_graph.cpp index 14587d48..c042eb15 100644 --- a/dflash/src/gemma4/gemma4_graph.cpp +++ b/dflash/src/gemma4/gemma4_graph.cpp @@ -113,6 +113,8 @@ static ggml_tensor * build_gemma4_moe_block(ggml_context * ctx, ggml_tensor * at n_ff_exp, gate_up_e->ne[1], gate_up_e->ne[2], gate_up_e->nb[1], gate_up_e->nb[2], (size_t)n_ff_exp * ggml_element_size(gate_up_e)); + gate_e = ggml_cont(ctx, gate_e); + up_e = ggml_cont(ctx, up_e); ggml_tensor * gu = ggml_mul(ctx, ggml_gelu(ctx, gate_e), up_e); ggml_tensor * experts = ggml_mul_mat_id(ctx, L.ffn_down_exps, gu, selected); diff --git a/dflash/src/gemma4/gemma4_loader.cpp b/dflash/src/gemma4/gemma4_loader.cpp index 528113c5..077d53f9 100644 --- a/dflash/src/gemma4/gemma4_loader.cpp +++ b/dflash/src/gemma4/gemma4_loader.cpp @@ -123,7 +123,7 @@ bool load_gemma4_gguf(const std::string & path, const uint32_t head_dim_swa = get_u32_or(gctx, "gemma4.attention.key_length_swa", head_dim_full); const uint32_t n_expert = get_u32_or(gctx, "gemma4.expert_count", 0); const uint32_t n_expert_used = get_u32_or(gctx, "gemma4.expert_used_count", 0); - const uint32_t n_dense_lead = get_u32_or(gctx, "gemma4.leading_dense_block_count", 1); + const uint32_t n_dense_lead = get_u32_or(gctx, "gemma4.leading_dense_block_count", 0); const uint32_t sliding_win = get_u32_or(gctx, "gemma4.attention.sliding_window", 0); const uint32_t shared_kv = get_u32_or(gctx, "gemma4.attention.shared_kv_layers", 0); const uint32_t n_embd_pl = get_u32_or(gctx, "gemma4.embedding_length_per_layer_input", 0); @@ -340,16 +340,16 @@ bool load_gemma4_gguf(const std::string & path, // MoE tensors (only present for MoE models) L.ffn_norm_moe = get("ffn_norm_moe.weight"); L.ffn_gate_inp = get("ffn_gate_inp.weight"); - L.ffn_gate_inp_s = 
get("ffn_gate_inp_shexp.weight"); + L.ffn_gate_inp_s = get("ffn_gate_inp.scale"); L.ffn_gate_up_exps = get("ffn_gate_up_exps.weight"); L.ffn_down_exps = get("ffn_down_exps.weight"); - L.ffn_down_exps_s = get("ffn_down_exps_s.weight"); + L.ffn_down_exps_s = get("ffn_down_exps.scale"); L.ffn_gate_shexp = get("ffn_gate_shexp.weight"); L.ffn_up_shexp = get("ffn_up_shexp.weight"); L.ffn_down_shexp = get("ffn_down_shexp.weight"); - L.ffn_pre_norm_2 = get("ffn_pre_norm_2.weight"); - L.ffn_post_norm_1 = get("ffn_post_norm_1.weight"); - L.ffn_post_norm_2 = get("ffn_post_norm_2.weight"); + L.ffn_pre_norm_2 = get("pre_ffw_norm_2.weight"); + L.ffn_post_norm_1 = get("post_ffw_norm_1.weight"); + L.ffn_post_norm_2 = get("post_ffw_norm_2.weight"); // Per-layer embedding L.per_layer_inp_gate = get("per_layer_inp_gate.weight");