diff --git a/dflash/CMakeLists.txt b/dflash/CMakeLists.txt index de0b5f2d..50857ff7 100644 --- a/dflash/CMakeLists.txt +++ b/dflash/CMakeLists.txt @@ -215,7 +215,7 @@ add_library(dflash_common STATIC src/qwen35/qwen35_target_graph.cpp src/draft/draft_gguf_loader.cpp src/draft/draft_safetensors_loader.cpp - src/draft/draft_dflash_graph.cpp + src/draft/draft_graph.cpp src/qwen3/qwen3_drafter.cpp src/qwen3/qwen3_loader.cpp src/qwen3/qwen3_graph.cpp @@ -225,6 +225,7 @@ add_library(dflash_common STATIC src/gemma4/gemma4_graph.cpp src/gemma4/gemma4_backend.cpp src/gemma4/gemma4_daemon.cpp + src/gemma4/gemma4_dflash_target.cpp src/flashprefill_q8.cpp src/kv_cache.cpp src/kv_quant.cpp diff --git a/dflash/docs/SPEC_PREFILL.md b/dflash/docs/SPEC_PREFILL.md index f7dc25d5..cb2837dc 100644 --- a/dflash/docs/SPEC_PREFILL.md +++ b/dflash/docs/SPEC_PREFILL.md @@ -89,7 +89,7 @@ src/ qwen35_target_graph.cpp Qwen3.5/3.6 target graph (ggml) gguf_target_loader.cpp Qwen3.5 target GGUF loader draft/ Special DFlash draft model code - draft_dflash_graph.cpp DFlash speculative draft head + draft_graph.cpp DFlash speculative draft head draft_gguf_loader.cpp Draft GGUF loader draft_safetensors_loader.cpp Draft safetensors loader laguna/ Laguna target + daemon model code diff --git a/dflash/include/dflash27b.h b/dflash/include/dflash27b.h index 29c4d1b1..b707b2d8 100644 --- a/dflash/include/dflash27b.h +++ b/dflash/include/dflash27b.h @@ -24,13 +24,12 @@ extern "C" { // Qwen3.5-27B qwen35 hybrid uses 24 Q heads, 4 KV heads, 256 head_dim, which // live in `src/internal.h` (n_embd_head_k/v, N_HEAD, N_HEAD_KV). Naming is // historical — do not change without updating draft_safetensors_loader.cpp + -// draft_dflash_graph.cpp which consume these as draft-side constants. +// draft_graph.cpp which consume these as draft-side constants. 
#define DFLASH27B_TARGET_N_HEADS 32 #define DFLASH27B_TARGET_N_KV_HEADS 8 #define DFLASH27B_TARGET_HEAD_DIM 128 #define DFLASH27B_TARGET_INTERMEDIATE 17408 #define DFLASH27B_TARGET_VOCAB 248320 -#define DFLASH27B_ROPE_THETA 10000000.0f #define DFLASH27B_RMS_EPS 1e-6f #define DFLASH27B_DRAFT_LAYERS 5 diff --git a/dflash/src/common/backend_factory.cpp b/dflash/src/common/backend_factory.cpp index f8323595..5e4bb0da 100644 --- a/dflash/src/common/backend_factory.cpp +++ b/dflash/src/common/backend_factory.cpp @@ -88,10 +88,13 @@ std::unique_ptr create_backend(const BackendArgs & args) { } else if (arch == "gemma4") { Gemma4BackendConfig gcfg; - gcfg.model_path = args.model_path; - gcfg.device = args.device; - gcfg.stream_fd = args.stream_fd; - gcfg.chunk = args.chunk; + gcfg.model_path = args.model_path; + gcfg.draft_path = args.draft_path; + gcfg.draft_gpu = args.draft_gpu; + gcfg.draft_ctx_max = args.draft_ctx_max; + gcfg.device = args.device; + gcfg.stream_fd = args.stream_fd; + gcfg.chunk = args.chunk; auto backend = std::make_unique(gcfg); if (!backend->init()) { diff --git a/dflash/src/common/dflash_draft_graph.cpp b/dflash/src/common/dflash_draft_graph.cpp index 2e60acb6..c028bb88 100644 --- a/dflash/src/common/dflash_draft_graph.cpp +++ b/dflash/src/common/dflash_draft_graph.cpp @@ -2,11 +2,26 @@ #include "draft/draft_graph.h" // DraftGraphInputs, DraftGraphOutputs, build_draft_graph #include "ggml-alloc.h" +#include "ggml-backend.h" +#include #include +#include namespace dflash::common { +// Minimum alignment required by ggml flash_attn_ext for mask rows. +static constexpr int MASK_KV_PAD = 32; + +static inline int mask_align_up(int x, int a) { return ((x + a - 1) / a) * a; } + +// Check whether any layer in the draft is SWA. +static bool draft_has_swa_layers(const DraftWeights & dw) { + for (int i = 0; i < dw.n_layer; i++) + if (dw.layers[i].is_swa) return true; + return false; +} + // Build draft graph at a given ctx_len into sg. Does NOT touch sg.alloc. 
// mirror_view: if true, uses a view into mirror->target_feat at slot0. static bool build_draft_graph_internal( @@ -56,6 +71,21 @@ static bool build_draft_graph_internal( ggml_set_name(sg.positions_k, "positions_k"); ggml_set_input(sg.positions_k); + // Causal mask for SWA layers (if any). + // Shape: [kv_pad, q_len] F16 (directly, no cast needed — matches attn_masks.h pattern). + sg.attn_mask = nullptr; + const bool has_swa = draft_has_swa_layers(dw); + if (has_swa) { + // SWA layers' effective KV length (windowed or full ctx) + const bool swa_active = dw.swa_window > 0 && ctx_len > dw.swa_window; + const int eff_ctx = swa_active ? dw.swa_window : ctx_len; + const int eff_total_k = eff_ctx + q_len; + const int kv_pad = mask_align_up(eff_total_k, MASK_KV_PAD); + sg.attn_mask = ggml_new_tensor_2d(sg.ctx, GGML_TYPE_F16, kv_pad, q_len); + ggml_set_name(sg.attn_mask, "causal_mask_swa"); + ggml_set_input(sg.attn_mask); + } + sg.gf = ggml_new_graph_custom(sg.ctx, 4096, false); DraftGraphInputs gi{}; @@ -65,6 +95,7 @@ static bool build_draft_graph_internal( gi.positions_q = sg.positions; gi.positions_k = sg.positions_k; gi.lm_head = lm_head; + gi.causal_mask_swa = sg.attn_mask; DraftGraphOutputs go = build_draft_graph(sg.ctx, dw, gi); sg.hidden_states = go.hidden_states; sg.logits = go.logits; @@ -125,7 +156,35 @@ bool build_draft_step( return false; } - return ggml_gallocr_alloc_graph(sg.alloc, sg.gf); + if (!ggml_gallocr_alloc_graph(sg.alloc, sg.gf)) { + return false; + } + + // Fill causal mask data for SWA layers (after allocation gives memory to the tensor). + if (sg.attn_mask) { + const int q_len = dw.block_size; + const bool swa_active = dw.swa_window > 0 && ctx_len > dw.swa_window; + const int eff_ctx = swa_active ? dw.swa_window : ctx_len; + const int eff_total_k = eff_ctx + q_len; + const int kv_pad = mask_align_up(eff_total_k, MASK_KV_PAD); + + // Build causal mask in F16 directly (same pattern as attn_masks.h): + // Context keys (k < eff_ctx): always visible. 
+ // Noise keys (k = eff_ctx + j): visible if j <= q (causal). + static constexpr uint16_t ZERO = 0x0000; + static constexpr uint16_t NEG_INF = 0xFC00; + std::vector mask_data((size_t)kv_pad * q_len, NEG_INF); + for (int q = 0; q < q_len; q++) { + for (int k = 0; k < eff_ctx; k++) + mask_data[(size_t)q * kv_pad + k] = ZERO; + for (int j = 0; j <= q; j++) + mask_data[(size_t)q * kv_pad + (eff_ctx + j)] = ZERO; + } + ggml_backend_tensor_set(sg.attn_mask, mask_data.data(), 0, + sizeof(uint16_t) * mask_data.size()); + } + + return true; } } // namespace dflash::common diff --git a/dflash/src/draft/draft_gguf_loader.cpp b/dflash/src/draft/draft_gguf_loader.cpp index 89f7b17c..9a6ffdf6 100644 --- a/dflash/src/draft/draft_gguf_loader.cpp +++ b/dflash/src/draft/draft_gguf_loader.cpp @@ -2,7 +2,7 @@ // on the CUDA backend. // // This is the Q8_0-quantized counterpart of draft_safetensors_loader.cpp. The -// draft graph builder (draft_dflash_graph.cpp) doesn't care about tensor storage +// draft graph builder (draft_graph.cpp) doesn't care about tensor storage // types — ggml's ggml_mul_mat handles Q8_0 × F32 dequantization transparently. 
// // GGUF arch: "qwen35-dflash-draft" (from convert_dflash_to_gguf.py / @@ -159,6 +159,12 @@ bool load_draft_gguf(const std::string & path, std::snprintf(key, sizeof(key), "%s.%s", A, suffix); return get_u32_or(gctx, key, fallback); }; + auto read_f32 = [&](const char * suffix, float fallback) -> float { + std::snprintf(key, sizeof(key), "%s.%s", A, suffix); + int64_t id = gguf_find_key(gctx, key); + if (id < 0) return fallback; + return gguf_get_val_f32(gctx, id); + }; const uint32_t n_embd = read_u32("embedding_length", 0); const uint32_t n_layer = read_u32("block_count", 0); @@ -235,6 +241,10 @@ bool load_draft_gguf(const std::string & path, out.head_dim = (int)head_dim; out.n_embd = (int)n_embd; out.n_ff = (int)n_ff; + out.rope_theta = read_f32("rope.freq_base", 0.0f); + if (out.rope_theta == 0.0f) { + fprintf(stderr, "[draft-gguf] WARNING: rope.freq_base not found in GGUF, draft RoPE will be wrong\n"); + } out.layers.assign((size_t)n_layer, DraftLayer{}); auto g = [&](const char * name) -> ggml_tensor * { diff --git a/dflash/src/draft/draft_dflash_graph.cpp b/dflash/src/draft/draft_graph.cpp similarity index 91% rename from dflash/src/draft/draft_dflash_graph.cpp rename to dflash/src/draft/draft_graph.cpp index eddfba9a..11f5b4c5 100644 --- a/dflash/src/draft/draft_dflash_graph.cpp +++ b/dflash/src/draft/draft_graph.cpp @@ -1,28 +1,29 @@ // Builds a ggml compute graph for one forward pass of the DFlash draft -// (5-layer non-causal Qwen3-flavored block-diffusion model). +// (5-layer Qwen3-flavored block-diffusion model). // // Stateless: no KV cache. 
Each call takes: -// - noise_embed [hidden, q_len, 1] bf16 (target.tok_embd on [last_tok, MASK*15]) -// - target_hidden_cat [5*hidden, ctx_len, 1] bf16 (5 target layers concat along features) +// - noise_embed [hidden, q_len, 1] f32 (target.tok_embd on [last_tok, MASK*15]) +// - target_hidden_cat [N*hidden, ctx_len, 1] f32 (N target layers concat along features) // - positions_q [q_len] i32 values [ctx_len..ctx_len+q_len-1] // - positions_k [ctx_len+q_len] i32 values [0..ctx_len+q_len-1] +// - causal_mask_swa [kv_pad, q_len] f16 (optional; causal mask for SWA layers) // and returns: -// - hidden_states [hidden, q_len, 1] bf16 (final RMSNorm; NO lm_head here) +// - hidden_states [hidden, q_len, 1] f32 (final RMSNorm; NO lm_head here) // // The caller projects `hidden_states` through the TARGET's lm_head separately // (the draft has no lm_head of its own, it shares the target's). // -// Semantics match megaqwen3_27b_dflash/reference/dflash_reference.py exactly: +// Semantics: // - fc @ target_hidden_cat -> rms_norm with hidden_norm -> target_feat -// - Per layer (non-causal): +// - Per layer: // h_norm = rms_norm(h) * input_layernorm // Q = wq @ h_norm -> per-head q_norm // K_ctx/V_ctx = wk/wv @ target_feat // K_noi/V_noi = wk/wv @ h_norm // K = concat[K_ctx, K_noi] -> per-head k_norm // V = concat[V_ctx, V_noi] -// RoPE(Q, positions_q); RoPE(K, positions_k) (NEOX style, theta=10M) -// attn = flash_attn_ext(Q, K, V, mask=null, scale=1/sqrt(head_dim)) non-causal +// RoPE(Q, positions_q); RoPE(K, positions_k) (NEOX style) +// attn = flash_attn_ext(Q, K, V, mask, scale) SWA=causal, full=non-causal // h += wo @ attn // h_norm = rms_norm(h) * post_attention_layernorm // h += w_down @ (silu(w_gate @ h_norm) * (w_up @ h_norm)) @@ -46,7 +47,7 @@ DraftGraphOutputs build_draft_graph( const int n_kv = w.n_head_kv; const int head_dim = w.head_dim; const float eps = DFLASH27B_RMS_EPS; - const float rope_base = DFLASH27B_ROPE_THETA; + const float rope_base = w.rope_theta; // ── 1. 
Feature fusion: target_feat = rms_norm(fc @ target_hidden_cat, hidden_norm) // fc: [5*hidden, hidden] (ggml: ne[0]=5*hidden, ne[1]=hidden) @@ -134,9 +135,10 @@ DraftGraphOutputs build_draft_graph( V = ggml_permute(ctx, V, 0, 2, 1, 3); // [head_dim, eff_total_k, n_kv, 1] V = ggml_cont (ctx, V); - // ── 2f. Non-causal flash attention; GQA broadcast handled internally. + // ── 2f. Attention: causal for SWA layers, non-causal for full layers. const float scale = 1.0f / std::sqrt((float)head_dim); - ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, /*mask=*/nullptr, + ggml_tensor * mask = (L.is_swa && in.causal_mask_swa) ? in.causal_mask_swa : nullptr; + ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, mask, scale, /*max_bias=*/0.0f, /*logit_softcap=*/0.0f); // attn result: [n_embd_v=head_dim, n_head, n_batch=q_len, 1] diff --git a/dflash/src/draft/draft_graph.h b/dflash/src/draft/draft_graph.h index 28bc0d83..30baa17b 100644 --- a/dflash/src/draft/draft_graph.h +++ b/dflash/src/draft/draft_graph.h @@ -18,6 +18,9 @@ struct DraftGraphInputs { // hidden states. Used for DFlash integration where the draft shares the // target's lm_head. ggml_tensor * lm_head; + // Optional: causal mask for SWA layers [kv_pad, q_len] F16. + // nullptr = all layers non-causal. 
+ ggml_tensor * causal_mask_swa = nullptr; }; struct DraftGraphOutputs { diff --git a/dflash/src/flashprefill.h b/dflash/src/flashprefill.h index 1cb0f66d..fd0c64e0 100644 --- a/dflash/src/flashprefill.h +++ b/dflash/src/flashprefill.h @@ -90,6 +90,37 @@ int flash_prefill_forward_q8( ggml_type qkv_type, const FlashPrefillConfig & cfg); +// ── Unified dispatch ────────────────────────────────────────────────────────── +// Picks the best available kernel at compile time + runtime buffer type: +// BF16 buffers + sm_80 build → flash_prefill_forward_bf16 +// F16 buffers + Volta/Pascal build → flash_prefill_forward_f16 +// otherwise → flash_prefill_forward_q8 (ggml FA fallback) +// +// Callers no longer need to duplicate the ifdef/dispatch boilerplate. +inline int flash_prefill_forward( + ggml_backend_t backend, + const void * Q, const void * K, const void * V, void * O, + int batch, int seq_len, int n_q_heads, int n_k_heads, int head_dim, + float scale, + ggml_type qkv_type, + const FlashPrefillConfig & cfg) +{ +#if defined(DFLASH27B_HAVE_FLASHPREFILL) || defined(DFLASH27B_HAVE_SM80_FLASHPREFILL) + if (qkv_type == GGML_TYPE_BF16) { + return flash_prefill_forward_bf16(Q, K, V, O, + batch, seq_len, n_q_heads, n_k_heads, head_dim, scale, cfg); + } +#endif +#if defined(DFLASH27B_HAVE_VOLTA_FLASHPREFILL) || defined(DFLASH27B_HAVE_PASCAL_FLASHPREFILL) + if (qkv_type == GGML_TYPE_F16) { + return flash_prefill_forward_f16(Q, K, V, O, + batch, seq_len, n_q_heads, n_k_heads, head_dim, scale, cfg); + } +#endif + return flash_prefill_forward_q8(backend, Q, K, V, O, + batch, seq_len, n_q_heads, n_k_heads, head_dim, scale, qkv_type, cfg); +} + #ifdef DFLASH27B_HAVE_BSA // Free BSA persistent device buffers (blockmask, head_mask_type, softmax_lse). // Safe to call any time; idempotent. 
Useful before unloading the drafter to diff --git a/dflash/src/gemma4/gemma4_backend.cpp b/dflash/src/gemma4/gemma4_backend.cpp index 5b6a0725..6395edf1 100644 --- a/dflash/src/gemma4/gemma4_backend.cpp +++ b/dflash/src/gemma4/gemma4_backend.cpp @@ -1,6 +1,6 @@ // Gemma4Backend implementation. // -// Uses gemma4_step() for forward pass (currently stubbed). +// Uses gemma4_step() for forward pass with per-layer embedding support. // Structure mirrors Qwen3Backend: prefill in chunks, autoregressive decode, // KV cache with layer sharing, snapshot/restore. @@ -8,6 +8,9 @@ #include "dflash27b.h" #include "common/sampler.h" #include "common/io_utils.h" +#include "common/dflash_feature_ring.h" +#include "common/dflash_draft_graph.h" +#include "common/step_graph.h" #include "ggml-cuda.h" #include "common/snapshot_backend.h" @@ -50,6 +53,92 @@ bool Gemma4Backend::init() { std::fprintf(stderr, "[gemma4] cache alloc failed\n"); return false; } + cache_.fa_window = cfg_.fa_window; + + // Load draft model for speculative decode + if (cfg_.draft_path) { + const int draft_gpu = (cfg_.draft_gpu >= 0) ? 
cfg_.draft_gpu : cfg_.device.gpu; + draft_backend_ = ggml_backend_cuda_init(draft_gpu); + if (!draft_backend_) { + std::fprintf(stderr, "[gemma4] draft CUDA init failed (gpu=%d)\n", draft_gpu); + } else { + // Load draft GGUF — pass nullptr for target (Gemma4 != TargetWeights) + if (!load_draft_gguf(cfg_.draft_path, draft_backend_, dw_, nullptr)) { + std::fprintf(stderr, "[gemma4] draft load failed: %s\n", dflash27b_last_error()); + ggml_backend_free(draft_backend_); draft_backend_ = nullptr; + } else { + // Override mask_token_id for Gemma4 (token 4 per model card) + dw_.mask_token_id = 4; + + // Fix draft dimensions from actual tensor shapes (GGUF metadata is wrong) + // fc.weight: [fc_in, draft_hidden] + const int draft_hidden = (int)dw_.fc->ne[1]; + const int fc_in = (int)dw_.fc->ne[0]; + const int n_capture = fc_in / w_.n_embd; + + if (draft_hidden != dw_.n_embd) { + std::printf("[gemma4] draft: overriding n_embd %d -> %d (from fc weight)\n", + dw_.n_embd, draft_hidden); + dw_.n_embd = draft_hidden; + } + // Infer n_head from wq shape: wq.ne[1] = n_head * head_dim + if (dw_.n_layer > 0 && dw_.layers[0].wq) { + const int q_dim = (int)dw_.layers[0].wq->ne[1]; + const int inferred_n_head = q_dim / dw_.head_dim; + if (inferred_n_head != dw_.n_head) { + std::printf("[gemma4] draft: overriding n_head %d -> %d\n", + dw_.n_head, inferred_n_head); + dw_.n_head = inferred_n_head; + } + } + // Infer n_ff from ffn_gate shape + if (dw_.n_layer > 0 && dw_.layers[0].w_gate) { + const int inferred_ff = (int)dw_.layers[0].w_gate->ne[1]; + if (inferred_ff != dw_.n_ff) { + std::printf("[gemma4] draft: overriding n_ff %d -> %d\n", + dw_.n_ff, inferred_ff); + dw_.n_ff = inferred_ff; + } + } + // Override n_target_layers from fc shape + dw_.n_target_layers = n_capture; + + // Gemma4 DFlash draft: layers 0-3 are SWA (causal), layer 4 is full (non-causal) + // (from model card: layer_types = [sliding*4, full_attention]) + dw_.swa_window = 2048; + for (int i = 0; i < dw_.n_layer - 1 
&& i < (int)dw_.layers.size(); i++) + dw_.layers[i].is_swa = true; + + std::printf("[gemma4] draft loaded: fc_in=%d target_hidden=%d " + "draft_hidden=%d n_capture_layers=%d swa=%d\n", + fc_in, w_.n_embd, draft_hidden, n_capture, dw_.swa_window); + + // Allocate target_feat ring buffer + constexpr int TARGET_FEAT_CAP = 4096; + const int feat_cap = std::min(cfg_.device.max_ctx, TARGET_FEAT_CAP); + if (!create_gemma4_target_feat(backend_, cache_, n_capture, w_.n_embd, feat_cap)) { + std::fprintf(stderr, "[gemma4] target_feat alloc failed\n"); + } else { + // Init feature mirror on draft GPU + const int mirror_cap = std::min(cfg_.draft_ctx_max, feat_cap); + if (!draft_feature_mirror_init(feature_mirror_, draft_backend_, + draft_gpu, cfg_.device.gpu, mirror_cap, + n_capture, w_.n_embd)) { + std::fprintf(stderr, "[gemma4] feature mirror init failed\n"); + } else { + // Create DFlash target adapter + dflash_target_ = new Gemma4DFlashTarget(w_, cache_, backend_); + std::printf("[gemma4] spec-decode ready: capture_layers=%d mirror_cap=%d\n", + n_capture, mirror_cap); + std::printf("[gemma4] capture_layer_ids:"); + for (int k = 0; k < (int)cache_.capture_layer_ids.size(); k++) + std::printf(" %d", cache_.capture_layer_ids[k]); + std::printf("\n"); + } + } + } + } + } std::printf("[gemma4] init ok: %d layers, embd=%d, vocab=%d, max_ctx=%d\n", w_.n_layer, w_.n_embd, w_.n_vocab, cfg_.device.max_ctx); @@ -69,22 +158,50 @@ void Gemma4Backend::print_ready_banner() const { bool Gemma4Backend::park(const std::string & what) { (void)what; + if (parked_) return true; + + // Free snapshots first (they reference the snap_backend buffer) + for (int i = 0; i < PREFIX_SLOTS; ++i) { + free_gemma4_snapshot(snapshots_[i]); + } + + // Free KV cache (GPU memory) + free_gemma4_cache(cache_); + + // Free model weights (GPU memory) + free_gemma4_weights(w_); + parked_ = true; - std::printf("[gemma4] parked\n"); std::fflush(stdout); + std::printf("[gemma4] parked (VRAM released)\n"); 
std::fflush(stdout); return true; } bool Gemma4Backend::unpark(const std::string & what) { (void)what; + if (!parked_) return true; + + // Reload weights from disk + if (!load_gemma4_gguf(cfg_.model_path, backend_, w_)) { + std::fprintf(stderr, "[gemma4] unpark: failed to reload weights\n"); + return false; + } + + // Recreate KV cache + if (!create_gemma4_cache(backend_, w_, cfg_.device.max_ctx, cache_)) { + std::fprintf(stderr, "[gemma4] unpark: failed to recreate cache\n"); + free_gemma4_weights(w_); + return false; + } + parked_ = false; - std::printf("[gemma4] unparked\n"); std::fflush(stdout); + std::printf("[gemma4] unparked (VRAM restored)\n"); std::fflush(stdout); return true; } // ── Prefill ──────────────────────────────────────────────────────────── int Gemma4Backend::do_prefill(const std::vector & tokens, - const DaemonIO & io) { + const DaemonIO & io, int kv_offset) { (void)io; const int n = (int)tokens.size(); const int hidden = w_.n_embd; @@ -97,6 +214,12 @@ int Gemma4Backend::do_prefill(const std::vector & tokens, while (pos < n) { int len = std::min(chunk, n - pos); + // Limit chunk to avoid ring-buffer wrap for SWA layers + if (cache_.swa_size > 0 && cache_.swa_size < cache_.max_ctx) { + const int swa_remaining = cache_.swa_size - ((kv_offset + pos) % cache_.swa_size); + len = std::min(len, swa_remaining); + } + // Embed tokens using CPU embedder w_.embedder.embed(tokens.data() + pos, len, embed.data()); @@ -104,16 +227,34 @@ int Gemma4Backend::do_prefill(const std::vector & tokens, float scale = std::sqrt((float)hidden); for (int i = 0; i < len * hidden; ++i) embed[i] *= scale; - if (!gemma4_step(backend_, w_, cache_, embed.data(), len, pos, logits)) { - std::fprintf(stderr, "[gemma4] prefill step failed at pos=%d\n", pos); + const int kv_pos = kv_offset + pos; + if (!gemma4_step(backend_, w_, cache_, embed.data(), + tokens.data() + pos, len, kv_pos, logits)) { + std::fprintf(stderr, "[gemma4] prefill step failed at pos=%d\n", kv_pos); return 
-1; } pos += len; - cache_.cur_pos = pos; + cache_.cur_pos = kv_offset + pos; + + // Store last_tok from final chunk's logits (argmax of last position) + if (pos >= n && !logits.empty()) { + int32_t best_tok = 0; + float best_val = logits[0]; + for (int j = 1; j < (int)logits.size(); ++j) { + if (logits[j] > best_val) { best_val = logits[j]; best_tok = j; } + } + cache_.last_tok = best_tok; + } + + // Sync captured features to draft mirror + if (feature_mirror_.target_feat && cache_.target_feat && !draft_parked_) { + draft_feature_mirror_sync_range(cache_.target_feat, cache_.target_feat_cap, + feature_mirror_, kv_pos, len); + } } - return pos; + return kv_offset + pos; } // ── Decode ───────────────────────────────────────────────────────────── @@ -134,7 +275,8 @@ bool Gemma4Backend::do_decode(int committed, int n_gen, float scale = std::sqrt((float)hidden); for (int j = 0; j < hidden; ++j) embed_buf[j] *= scale; - if (!gemma4_step(backend_, w_, cache_, embed_buf.data(), 1, committed, logits)) { + if (!gemma4_step(backend_, w_, cache_, embed_buf.data(), + &tok, 1, committed, logits)) { return false; } @@ -164,11 +306,197 @@ bool Gemma4Backend::do_decode(int committed, int n_gen, return true; } +// ── Speculative Decode ───────────────────────────────────────────────── + +bool Gemma4Backend::do_spec_decode(int committed, int n_gen, + std::vector & out_tokens, + const DaemonIO & io) { + const int hidden = w_.n_embd; + int32_t last_tok = cache_.last_tok; + + DFlashTarget * target = dflash_target_; + const int q_len = dw_.block_size; + + StepGraph draft_sg; + + std::vector noise_embed((size_t)hidden * q_len); + std::vector noise_ids(q_len); + std::vector draft_tok(q_len); + std::vector target_tok(q_len); + std::vector pos_q(q_len); + std::vector pos_k; + std::vector local_hidden; + + int n_generated = 0; + int n_draft_steps = 0; + int n_accept_sum = 0; + + auto t_dec0 = std::chrono::steady_clock::now(); + + while (n_generated < n_gen) { + const int need_commit_budget = 
n_gen - n_generated; + + // 1. Build noise input: [last_tok, MASK, MASK, ..., MASK] + noise_ids[0] = last_tok; + for (int i = 1; i < q_len; i++) noise_ids[i] = target->mask_token_id(); + if (!target->embed_tokens(noise_ids.data(), q_len, noise_embed.data())) { + std::fprintf(stderr, "[gemma4-spec] noise embed failed\n"); + step_graph_destroy(draft_sg); + return false; + } + + // 2. Draft compute + constexpr int DRAFT_CTX_MAX_DEFAULT = 2048; + const int ring_cap = feature_mirror_.cap; + const int draft_ctx = std::min(committed, + std::min(ring_cap, std::max(DRAFT_CTX_MAX_DEFAULT, cfg_.draft_ctx_max))); + const int draft_start = committed - draft_ctx; + int mirror_slot0 = 0; + const bool use_mirror_view = + draft_feature_mirror_can_view(feature_mirror_, committed, draft_ctx, mirror_slot0); + + if (!build_draft_step(draft_sg, dw_, /*lm_head=*/nullptr, draft_backend_, + draft_ctx, use_mirror_view ? &feature_mirror_ : nullptr, + committed, + std::min(ring_cap, std::max(DRAFT_CTX_MAX_DEFAULT, cfg_.draft_ctx_max)))) { + std::fprintf(stderr, "[gemma4-spec] draft build failed\n"); + step_graph_destroy(draft_sg); + return false; + } + if (!use_mirror_view && + !copy_feature_ring_range_to_tensor(feature_mirror_, draft_sg.target_hidden_cat, + draft_start, draft_ctx)) { + std::fprintf(stderr, "[gemma4-spec] feature copy failed\n"); + step_graph_destroy(draft_sg); + return false; + } + + ggml_backend_tensor_set(draft_sg.inp_embed, noise_embed.data(), 0, + sizeof(float) * noise_embed.size()); + pos_k.resize((size_t)draft_ctx + q_len); + for (int i = 0; i < q_len; i++) pos_q[i] = draft_ctx + i; + for (int i = 0; i < draft_ctx + q_len; i++) pos_k[i] = i; + ggml_backend_tensor_set(draft_sg.positions, pos_q.data(), 0, + sizeof(int32_t) * pos_q.size()); + ggml_backend_tensor_set(draft_sg.positions_k, pos_k.data(), 0, + sizeof(int32_t) * pos_k.size()); + + auto st = ggml_backend_graph_compute(draft_backend_, draft_sg.gf); + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, 
"[gemma4-spec] draft compute failed\n"); + step_graph_destroy(draft_sg); + return false; + } + + // Read draft hidden states + local_hidden.resize((size_t)hidden * q_len); + ggml_backend_tensor_get(draft_sg.hidden_states, local_hidden.data(), 0, + sizeof(float) * local_hidden.size()); + + // 3. Project draft hidden → token IDs via target LM head + if (!target->project_hidden_to_tokens(local_hidden.data(), q_len, draft_tok)) { + std::fprintf(stderr, "[gemma4-spec] projection failed\n"); + step_graph_destroy(draft_sg); + return false; + } + draft_tok[0] = last_tok; + + // 4. Verify: run target forward over all draft tokens. + // Gemma4 is a pure transformer — after verify, KV entries at accepted + // positions are already correct (causal masking guarantees independence + // from rejected tokens at later positions). We use KV truncation instead + // of the expensive snapshot/restore/replay approach. + int verify_last_tok = -1; + if (!target->verify_batch(draft_tok, committed, verify_last_tok, &target_tok)) { + std::fprintf(stderr, "[gemma4-spec] verify failed\n"); + step_graph_destroy(draft_sg); + return false; + } + + // 5. Acceptance: longest matching prefix + int accept_n = 1; + for (int i = 0; i < q_len - 1; i++) { + if (draft_tok[i + 1] == target_tok[i]) accept_n++; + else break; + } + int bonus_tok = (accept_n < q_len) ? target_tok[accept_n - 1] : -1; + int commit_n = accept_n + (bonus_tok >= 0 ? 1 : 0); + if (commit_n > need_commit_budget) { + commit_n = need_commit_budget; + if (commit_n <= accept_n) bonus_tok = -1; + } + + // 6. KV truncation: discard rejected positions, keep accepted. + // Accepted positions 0..accept_n-1 already have correct KV from verify. + cache_.cur_pos = committed + accept_n; + + // If there's a bonus token, run a 1-token forward to get its KV + features. 
+ if (bonus_tok >= 0) { + std::vector bonus_vec = {bonus_tok}; + int bonus_last = -1; + if (!target->verify_batch(bonus_vec, committed + accept_n, bonus_last, nullptr)) { + std::fprintf(stderr, "[gemma4-spec] bonus forward failed\n"); + step_graph_destroy(draft_sg); + return false; + } + last_tok = bonus_last; + } else { + last_tok = verify_last_tok; + } + + // 7. Sync features from verify (positions 0..accept_n-1 are correct) + // and from bonus forward (position accept_n, if present). + if (feature_mirror_.target_feat && cache_.target_feat) { + draft_feature_mirror_sync_range(cache_.target_feat, cache_.target_feat_cap, + feature_mirror_, committed, commit_n); + } + + // 8. Emit committed tokens + bool hit_eos = false; + int emitted = 0; + for (int i = 0; i < commit_n; i++) { + int tok = (i < accept_n) ? draft_tok[i] : bonus_tok; + out_tokens.push_back(tok); + io.emit(tok); + emitted++; + if (io.cancelled) break; + if (tok == w_.eos_id || tok == w_.eos_chat_id) { + hit_eos = true; break; + } + } + committed += emitted; + cache_.cur_pos = committed; + n_generated += emitted; + n_accept_sum += std::min(accept_n, emitted); + n_draft_steps++; + if (io.cancelled) break; + if (hit_eos) break; + } + + step_graph_destroy(draft_sg); + + auto t_dec1 = std::chrono::steady_clock::now(); + const double decode_s = std::chrono::duration(t_dec1 - t_dec0).count(); + const int total_draft_pos = std::max(1, n_draft_steps * q_len); + const double accept_pct = 100.0 * (double)n_accept_sum / (double)total_draft_pos; + std::fprintf(stderr, "[gemma4-spec] tokens=%d time=%.3f s speed=%.2f tok/s " + "steps=%d accepted=%d/%d (%.1f%%) avg_commit=%.2f\n", + n_generated, decode_s, + n_generated > 0 ? n_generated / decode_s : 0.0, + n_draft_steps, n_accept_sum, total_draft_pos, accept_pct, + n_draft_steps > 0 ? 
(double)n_generated / (double)n_draft_steps : 0.0); + + io.emit(-1); + return true; +} + // ── Generate ─────────────────────────────────────────────────────────── GenerateResult Gemma4Backend::generate(const GenerateRequest & req, const DaemonIO & io) { GenerateResult result; + if (parked_) { result.error = "model is parked"; return result; } + DaemonIO out_io = io.with_token_callback(req.on_token); sampler_ = req.sampler; if (req.do_sample && sampler_.seed != 0) { @@ -177,65 +505,89 @@ GenerateResult Gemma4Backend::generate(const GenerateRequest & req, cache_.cur_pos = 0; - const int committed = do_prefill(req.prompt, out_io); + const int committed = do_prefill(req.prompt, out_io, /*kv_offset=*/0); if (committed < 0) { result.error = "prefill"; return result; } - if (req.n_gen > 0) { - const int hidden = w_.n_embd; - const int vocab = w_.n_vocab; - std::vector logits; - - // Re-step last token to get logits - int32_t last_tok = req.prompt.back(); - std::vector embed_buf(hidden); - w_.embedder.embed(&last_tok, 1, embed_buf.data()); - float scale = std::sqrt((float)hidden); - for (int j = 0; j < hidden; ++j) embed_buf[j] *= scale; - - if (!gemma4_step(backend_, w_, cache_, embed_buf.data(), 1, - committed - 1, logits)) { - result.error = "first logits"; - return result; + // Inline snapshot at snap_pos for prefix cache + if (req.snap_slot >= 0 && req.snap_pos > 0 && req.snap_pos <= committed) { + cache_.cur_pos = req.snap_pos; + if (snapshot_save(req.snap_slot)) { + std::fprintf(stderr, "[gemma4] inline-snap slot=%d cur_pos=%d\n", + req.snap_slot, req.snap_pos); } + cache_.cur_pos = committed; + } - // Sample first token - int32_t first; - if (sampler_.temp > 0) { - first = sample_logits(logits.data(), vocab, sampler_, - result.tokens, sampler_rng_); + if (req.n_gen > 0) { + // Try speculative decode if draft is available and temp==0 + const bool can_spec = dflash_target_ + && !draft_parked_ + && feature_mirror_.target_feat + && sampler_.temp == 0.0f; + + if 
(can_spec) { + if (!do_spec_decode(committed, req.n_gen, result.tokens, out_io)) { + result.error = "spec_decode"; + return result; + } } else { - first = 0; - float best = logits[0]; - for (int j = 1; j < vocab; ++j) { - if (logits[j] > best) { best = logits[j]; first = j; } + const int hidden = w_.n_embd; + const int vocab = w_.n_vocab; + std::vector logits; + + // Re-step last token to get logits + int32_t last_tok = req.prompt.back(); + std::vector embed_buf(hidden); + w_.embedder.embed(&last_tok, 1, embed_buf.data()); + float scale = std::sqrt((float)hidden); + for (int j = 0; j < hidden; ++j) embed_buf[j] *= scale; + + if (!gemma4_step(backend_, w_, cache_, embed_buf.data(), + &last_tok, 1, committed - 1, logits)) { + result.error = "first logits"; + return result; } - } - result.tokens.push_back(first); - out_io.emit(first); - if (out_io.cancelled) { - out_io.emit(-1); - result.ok = true; - return result; - } - if (first == w_.eos_id || first == w_.eos_chat_id) { - out_io.emit(-1); - result.ok = true; - return result; - } + // Sample first token + int32_t first; + if (sampler_.temp > 0) { + first = sample_logits(logits.data(), vocab, sampler_, + result.tokens, sampler_rng_); + } else { + first = 0; + float best = logits[0]; + for (int j = 1; j < vocab; ++j) { + if (logits[j] > best) { best = logits[j]; first = j; } + } + } + result.tokens.push_back(first); + out_io.emit(first); + if (out_io.cancelled) { + out_io.emit(-1); + result.ok = true; + return result; + } - if (req.n_gen > 1) { - if (!do_decode(committed, req.n_gen - 1, result.tokens, out_io)) { - result.error = "decode"; + if (first == w_.eos_id || first == w_.eos_chat_id) { + out_io.emit(-1); + result.ok = true; return result; } + + if (req.n_gen > 1) { + if (!do_decode(committed, req.n_gen - 1, result.tokens, out_io)) { + result.error = "decode"; + return result; + } + } + out_io.emit(-1); } + } else { + out_io.emit(-1); } - - out_io.emit(-1); result.ok = true; return result; } @@ -246,29 +598,169 
@@ GenerateResult Gemma4Backend::restore_and_generate(int slot, const GenerateRequest & req, const DaemonIO & io) { GenerateResult result; + if (parked_) { result.error = "model is parked"; return result; } + + DaemonIO out_io = io.with_token_callback(req.on_token); + if (slot < 0 || slot >= PREFIX_SLOTS || !snapshots_[slot].ctx) { result.error = "bad slot"; - io.emit(-1); + out_io.emit(-1); return result; } const auto & snap = snapshots_[slot]; - // Copy right-sized snapshot into full-size cache (position is outermost dim). + // Restore snapshot into cache per-head (cache: [D, cache_len, Hk]). for (int il = 0; il < cache_.n_layer; ++il) { if (cache_.k[il] && snap.k_snap[il]) { - const size_t nbytes = ggml_nbytes(snap.k_snap[il]); - ggml_backend_tensor_set(cache_.k[il], snap.k_snap[il]->data, 0, nbytes); - ggml_backend_tensor_set(cache_.v[il], snap.v_snap[il]->data, 0, nbytes); + ggml_tensor * ck = cache_.k[il]; + const int D = (int)ck->ne[0]; + const int Hk = (int)ck->ne[2]; + const int cache_len = (int)ck->ne[1]; + const int save_pos = (int)snap.k_snap[il]->ne[1]; // min(snap.cur_pos, cache_len) + const size_t elem_sz = ggml_element_size(ck); + const size_t head_bytes_src = (size_t)D * save_pos * elem_sz; + const size_t head_bytes_dst = (size_t)D * cache_len * elem_sz; + const size_t copy_bytes = head_bytes_src; + + for (int h = 0; h < Hk; ++h) { + ggml_backend_tensor_set(cache_.k[il], + (const char *)snap.k_snap[il]->data + h * head_bytes_src, + h * head_bytes_dst, copy_bytes); + ggml_backend_tensor_set(cache_.v[il], + (const char *)snap.v_snap[il]->data + h * head_bytes_src, + h * head_bytes_dst, copy_bytes); + } } } - cache_.cur_pos = snap.cur_pos; - return generate(req, io); + // Restore target_feat from snapshot + if (snap.feat_snap && cache_.target_feat) { + const size_t feat_nbytes = ggml_nbytes(snap.feat_snap); + ggml_backend_tensor_set(cache_.target_feat, snap.feat_snap->data, 0, feat_nbytes); + } + + const int snap_pos = snap.cur_pos; + cache_.cur_pos = 
snap_pos; + cache_.last_tok = snap.last_tok; + + // Set up sampler + sampler_ = req.sampler; + if (req.do_sample && sampler_.seed != 0) { + sampler_rng_.seed(sampler_.seed); + } + + // Diff-prefill: only prefill tokens beyond the cached prefix + const int prompt_len = (int)req.prompt.size(); + int committed = snap_pos; + + if (prompt_len > snap_pos) { + // Compute delta (tokens after the snapshot) + std::vector delta(req.prompt.begin() + snap_pos, req.prompt.end()); + committed = do_prefill(delta, out_io, /*kv_offset=*/snap_pos); + if (committed < 0) { + result.error = "prefill"; + return result; + } + } else if (prompt_len > 0 && prompt_len < snap_pos) { + result.error = "snapshot_longer_than_prompt"; + out_io.emit(-1); + return result; + } + // else: prompt_len == snap_pos → no delta, committed stays at snap_pos + + // Inline snapshot at snap_pos for prefix cache (new snap from this request) + if (req.snap_slot >= 0 && req.snap_pos > 0 && req.snap_pos <= committed) { + cache_.cur_pos = req.snap_pos; + if (snapshot_save(req.snap_slot)) { + std::fprintf(stderr, "[gemma4] inline-snap slot=%d cur_pos=%d\n", + req.snap_slot, req.snap_pos); + } + cache_.cur_pos = committed; + } + + // Full feature mirror resync after restore: do_prefill only synced the + // delta [snap_pos..committed). Re-sync the entire [0..committed) range so + // the draft model sees correct features for the full context. 
+ if (feature_mirror_.target_feat && cache_.target_feat && !draft_parked_ && committed > 0) { + draft_feature_mirror_sync_tail(cache_.target_feat, cache_.target_feat_cap, + feature_mirror_, committed); + } + + // Generate + if (req.n_gen > 0) { + const bool can_spec = dflash_target_ + && !draft_parked_ + && feature_mirror_.target_feat + && sampler_.temp == 0.0f; + + if (can_spec) { + if (!do_spec_decode(committed, req.n_gen, result.tokens, out_io)) { + result.error = "spec_decode"; + return result; + } + } else { + const int hidden = w_.n_embd; + const int vocab = w_.n_vocab; + std::vector logits; + + // Re-step last token to get logits for first generated token + int32_t last_tok = req.prompt.back(); + std::vector embed_buf(hidden); + w_.embedder.embed(&last_tok, 1, embed_buf.data()); + float scale = std::sqrt((float)hidden); + for (int j = 0; j < hidden; ++j) embed_buf[j] *= scale; + + if (!gemma4_step(backend_, w_, cache_, embed_buf.data(), + &last_tok, 1, committed - 1, logits)) { + result.error = "first logits"; + return result; + } + + // Sample first token + int32_t first; + if (sampler_.temp > 0) { + first = sample_logits(logits.data(), vocab, sampler_, + result.tokens, sampler_rng_); + } else { + first = 0; + float best = logits[0]; + for (int j = 1; j < vocab; ++j) { + if (logits[j] > best) { best = logits[j]; first = j; } + } + } + result.tokens.push_back(first); + out_io.emit(first); + if (out_io.cancelled) { + out_io.emit(-1); + result.ok = true; + return result; + } + + if (first == w_.eos_id || first == w_.eos_chat_id) { + out_io.emit(-1); + result.ok = true; + return result; + } + + if (req.n_gen > 1) { + if (!do_decode(committed, req.n_gen - 1, result.tokens, out_io)) { + result.error = "decode"; + return result; + } + } + out_io.emit(-1); + } + } else { + out_io.emit(-1); + } + result.ok = true; + return result; } // ── Snapshots ────────────────────────────────────────────────────────── bool Gemma4Backend::snapshot_save(int slot) { + if (parked_) 
return false; if (slot < 0 || slot >= PREFIX_SLOTS) return false; auto & snap = snapshots_[slot]; @@ -281,8 +773,9 @@ bool Gemma4Backend::snapshot_save(int slot) { if (needs_alloc) { free_gemma4_snapshot(snap); + const int n_feat_tensors = (cache_.target_feat && cache_.target_feat_cap > 0) ? 1 : 0; ggml_init_params ip{}; - ip.mem_size = ggml_tensor_overhead() * (size_t)(n_layer * 2 + 4) + 4096; + ip.mem_size = ggml_tensor_overhead() * (size_t)(n_layer * 2 + n_feat_tensors + 4) + 4096; ip.no_alloc = true; snap.ctx = ggml_init(ip); if (!snap.ctx) return false; @@ -291,35 +784,70 @@ bool Gemma4Backend::snapshot_save(int slot) { snap.v_snap.resize(n_layer, nullptr); for (int il = 0; il < n_layer; ++il) { if (cache_.k[il]) { - // Right-sized: [D, Hk, snap_pos] instead of [D, Hk, max_ctx] ggml_tensor * ck = cache_.k[il]; + const int cache_len = (int)ck->ne[1]; + // Save min(snap_pos, cache_len) positions + const int save_pos = std::min(snap_pos, cache_len); snap.k_snap[il] = ggml_new_tensor_3d(snap.ctx, ck->type, - ck->ne[0], ck->ne[1], snap_pos); + ck->ne[0], save_pos, ck->ne[2]); snap.v_snap[il] = ggml_new_tensor_3d(snap.ctx, ck->type, - ck->ne[0], ck->ne[1], snap_pos); + ck->ne[0], save_pos, ck->ne[2]); } } + // target_feat: save min(snap_pos, target_feat_cap) positions + snap.feat_snap = nullptr; + snap.feat_cap = 0; + if (cache_.target_feat && cache_.target_feat_cap > 0) { + const int feat_len = std::min(snap_pos, cache_.target_feat_cap); + snap.feat_snap = ggml_new_tensor_2d(snap.ctx, cache_.target_feat->type, + cache_.target_feat->ne[0], feat_len); + snap.feat_cap = cache_.target_feat_cap; + } + snap.buf = ggml_backend_alloc_ctx_tensors(snap.ctx, snap_backend_); if (!snap.buf) { ggml_free(snap.ctx); snap.ctx = nullptr; snap.k_snap.clear(); snap.v_snap.clear(); + snap.feat_snap = nullptr; return false; } } - // Copy first snap_pos positions (contiguous — position is outermost dim). + // Copy valid positions per head. 
+ // Cache: [D, cache_len, Hk], Snap: [D, save_pos, Hk] for (int il = 0; il < n_layer; ++il) { if (cache_.k[il] && snap.k_snap[il]) { - const size_t nbytes = ggml_nbytes(snap.k_snap[il]); - ggml_backend_tensor_get(cache_.k[il], snap.k_snap[il]->data, 0, nbytes); - ggml_backend_tensor_get(cache_.v[il], snap.v_snap[il]->data, 0, nbytes); + ggml_tensor * ck = cache_.k[il]; + const int D = (int)ck->ne[0]; + const int Hk = (int)ck->ne[2]; + const int cache_len = (int)ck->ne[1]; + const int save_pos = std::min(snap_pos, cache_len); + const size_t elem_sz = ggml_element_size(ck); + const size_t head_bytes_src = (size_t)D * cache_len * elem_sz; + const size_t head_bytes_dst = (size_t)D * save_pos * elem_sz; + const size_t copy_bytes = head_bytes_dst; + + for (int h = 0; h < Hk; ++h) { + ggml_backend_tensor_get(cache_.k[il], + (char *)snap.k_snap[il]->data + h * head_bytes_dst, + h * head_bytes_src, copy_bytes); + ggml_backend_tensor_get(cache_.v[il], + (char *)snap.v_snap[il]->data + h * head_bytes_dst, + h * head_bytes_src, copy_bytes); + } } } snap.cur_pos = snap_pos; + snap.last_tok = cache_.last_tok; - std::printf("[gemma4] snapshot saved slot=%d pos=%d\n", slot, snap.cur_pos); - std::fflush(stdout); + // target_feat: copy min(snap_pos, cap) positions from GPU to snapshot + if (snap.feat_snap && cache_.target_feat) { + const size_t feat_nbytes = ggml_nbytes(snap.feat_snap); + ggml_backend_tensor_get(cache_.target_feat, snap.feat_snap->data, 0, feat_nbytes); + } + + std::fprintf(stderr, "[gemma4] snapshot saved slot=%d pos=%d\n", slot, snap.cur_pos); return true; } @@ -341,15 +869,76 @@ int Gemma4Backend::snapshot_cur_pos(int slot) const { bool Gemma4Backend::handle_compress(const std::string & line, const DaemonIO & io) { - (void)line; (void)io; - // Gemma4 doesn't use pflash drafter for compression (yet). 
- std::printf("[gemma4] compress: not supported\n"); - std::fflush(stdout); - return true; + // Check for "nopark" suffix + bool skip_park = (line.size() >= 16 && + line.compare(line.size() - 7, 7, " nopark") == 0); + + // Parse: "compress [nopark]" + char ppath[1024]; + int keep_x1000 = 0; + char drafter_path[1024] = {0}; + const int n = std::sscanf(line.c_str() + 9, "%1023s %d %1023s", + ppath, &keep_x1000, drafter_path); + if (n < 2) { + std::fprintf(stderr, "[compress] bad args\n"); + io.emit(-1); + return false; + } + + const char * dpath = (n >= 3 && drafter_path[0]) + ? drafter_path + : "/opt/lucebox/models/drafter/Qwen3-0.6B-BF16.gguf"; + + // Park target to free VRAM for the drafter (unless skip_park). + const bool was_parked = parked_; + if (!skip_park && !parked_) { + park("target"); + } + + // Synchronize backend + ggml_backend_synchronize(backend_); + + // Load drafter (lazy — stays resident for subsequent calls) + if (!drafter_loaded_) { + std::fprintf(stderr, "[compress] loading drafter from %s ...\n", dpath); + if (!load_drafter(dpath, /*gpu_layers=*/999, drafter_ctx_)) { + std::fprintf(stderr, "[compress] drafter init failed: %s\n", + dflash27b_last_error()); + io.emit(-1); + if (!skip_park && !was_parked) unpark("target"); + return false; + } + drafter_loaded_ = true; + std::fprintf(stderr, "[compress] drafter ready\n"); + } + + std::vector tokens = read_int32_file(ppath); + bool ok = false; + if (!tokens.empty()) { + const float keep = (float)keep_x1000 / 1000.0f; + auto compressed = drafter_score_and_compress(drafter_ctx_, tokens, keep); + ok = !compressed.empty(); + if (ok) { + std::fprintf(stderr, "[compress] %zu -> %zu tokens\n", + tokens.size(), compressed.size()); + for (int32_t t : compressed) io.emit(t); + } + } + io.emit(-1); + + // Restore park state + if (!skip_park && !was_parked) { + unpark("target"); + } + + return ok; } void Gemma4Backend::free_drafter() { - // No drafter to free. 
+ if (drafter_loaded_) { + ::dflash::common::free_drafter(drafter_ctx_); + drafter_loaded_ = false; + } } bool Gemma4Backend::try_handle_command(const std::string & line, @@ -362,6 +951,12 @@ bool Gemma4Backend::try_handle_command(const std::string & line, void Gemma4Backend::shutdown() { for (int i = 0; i < PREFIX_SLOTS; ++i) snapshot_free(i); + free_drafter(); + // Clean up DFlash spec-decode resources + delete dflash_target_; dflash_target_ = nullptr; + draft_feature_mirror_free(feature_mirror_); + if (dw_.ctx) { free_draft_weights(dw_); } + if (draft_backend_) { ggml_backend_free(draft_backend_); draft_backend_ = nullptr; } free_gemma4_cache(cache_); free_gemma4_weights(w_); free_snapshot_backend(snap_backend_, backend_); diff --git a/dflash/src/gemma4/gemma4_backend.h b/dflash/src/gemma4/gemma4_backend.h index 11c97200..35da51de 100644 --- a/dflash/src/gemma4/gemma4_backend.h +++ b/dflash/src/gemma4/gemma4_backend.h @@ -7,8 +7,12 @@ #include "common/model_backend.h" #include "placement/placement_config.h" +#include "common/dflash_feature_ring.h" +#include "common/dflash_draft_graph.h" #include "gemma4_internal.h" +#include "gemma4_dflash_target.h" #include "common/sampler.h" +#include "../qwen3/qwen3_drafter.h" #include "ggml.h" #include "ggml-backend.h" @@ -21,9 +25,13 @@ namespace dflash::common { struct Gemma4BackendConfig { const char * model_path = nullptr; + const char * draft_path = nullptr; + int draft_gpu = -1; // GPU for draft model (-1 = same as target) + int draft_ctx_max = 2048; // max context for draft feature mirror DevicePlacement device; int stream_fd = -1; int chunk = 512; + int fa_window = 0; // 0 = full attention; >0 = sparse decode window }; class Gemma4Backend : public ModelBackend { @@ -76,17 +84,35 @@ class Gemma4Backend : public ModelBackend { SamplerCfg sampler_; std::mt19937_64 sampler_rng_{std::random_device{}()}; + // DFlash speculative decode + ggml_backend_t draft_backend_ = nullptr; + DraftWeights dw_{}; + DraftFeatureMirror 
feature_mirror_{}; + Gemma4DFlashTarget * dflash_target_ = nullptr; + bool draft_parked_ = false; + + // PFlash drafter (compress) + DrafterContext drafter_ctx_; + bool drafter_loaded_ = false; + // Snapshots static constexpr int PREFIX_SLOTS = 64; Gemma4Snapshot snapshots_[PREFIX_SLOTS]; - // Prefill prompt tokens in chunks, return committed position. - int do_prefill(const std::vector & tokens, const DaemonIO & io); + // Prefill prompt tokens in chunks, return absolute committed position. + // kv_offset: starting KV cache position (0 for fresh prefill, snap_pos for restore). + int do_prefill(const std::vector & tokens, const DaemonIO & io, + int kv_offset = 0); // Autoregressive decode loop. bool do_decode(int committed, int n_gen, std::vector & out_tokens, const DaemonIO & io); + + // DFlash speculative decode loop. + bool do_spec_decode(int committed, int n_gen, + std::vector & out_tokens, + const DaemonIO & io); }; } // namespace dflash::common diff --git a/dflash/src/gemma4/gemma4_dflash_target.cpp b/dflash/src/gemma4/gemma4_dflash_target.cpp new file mode 100644 index 00000000..aebd0b09 --- /dev/null +++ b/dflash/src/gemma4/gemma4_dflash_target.cpp @@ -0,0 +1,151 @@ +// Gemma4DFlashTarget — DFlashTarget adapter for Gemma4 iSWA models. 
+ +#include "gemma4_dflash_target.h" +#include "dflash27b.h" + +#include +#include +#include + +namespace dflash::common { + +Gemma4DFlashTarget::Gemma4DFlashTarget( + Gemma4Weights & w, + Gemma4Cache & cache, + ggml_backend_t backend) + : w_(w), cache_(cache), backend_(backend) { + // Use capture layer IDs from cache (computed when target_feat is allocated) + if (!cache.capture_layer_ids.empty()) { + capture_ids_ = cache.capture_layer_ids; + } else { + // Fallback: evenly-spaced (legacy path) + const int N = DFLASH27B_DRAFT_N_TARGET_LAYERS; + capture_ids_.resize(N); + const int step = std::max(1, (w.n_layer - 2) / (N - 1)); + for (int k = 0; k < N; k++) { + capture_ids_[k] = 1 + k * step; + } + } +} + +Gemma4DFlashTarget::~Gemma4DFlashTarget() { + free_gemma4_snapshot(verify_snap_); +} + +bool Gemma4DFlashTarget::verify_batch( + const std::vector & tokens, + int base_pos, + int & last_tok, + std::vector * all_argmax) { + const int n_tokens = (int)tokens.size(); + if (n_tokens <= 0) return false; + + const int hidden = w_.n_embd; + + // Embed tokens + std::vector embed((size_t)n_tokens * hidden); + if (!w_.embedder.embed(tokens.data(), n_tokens, embed.data())) { + std::fprintf(stderr, "gemma4_verify_batch: embed failed\n"); + return false; + } + + // Scale by sqrt(n_embd) (Gemma4 convention) + const float scale = std::sqrt((float)hidden); + for (size_t i = 0; i < embed.size(); ++i) embed[i] *= scale; + + // Run verify (all-token argmax) + std::vector argmax_buf; + if (!gemma4_verify_batch(backend_, w_, cache_, embed.data(), + tokens.data(), n_tokens, base_pos, + argmax_buf)) { + return false; + } + + last_tok = argmax_buf[n_tokens - 1]; + if (all_argmax) { + *all_argmax = std::move(argmax_buf); + } + + return true; +} + +bool Gemma4DFlashTarget::snapshot_kv() { + // Save cur_pos and KV cache state + verify_snap_.cur_pos = cache_.cur_pos; + + // Allocate snapshot tensors if needed + if (verify_snap_.k_snap.empty()) { + ggml_init_params ip{}; + ip.mem_size = 
ggml_tensor_overhead() * (size_t)(w_.n_layer * 2 + 4) + 4096; + ip.no_alloc = true; + verify_snap_.ctx = ggml_init(ip); + if (!verify_snap_.ctx) return false; + + verify_snap_.k_snap.resize(w_.n_layer, nullptr); + verify_snap_.v_snap.resize(w_.n_layer, nullptr); + for (int il = 0; il < w_.n_layer; ++il) { + if (cache_.k[il]) { + verify_snap_.k_snap[il] = ggml_dup_tensor(verify_snap_.ctx, cache_.k[il]); + verify_snap_.v_snap[il] = ggml_dup_tensor(verify_snap_.ctx, cache_.v[il]); + } + } + verify_snap_.buf = ggml_backend_alloc_ctx_tensors(verify_snap_.ctx, backend_); + if (!verify_snap_.buf) return false; + } + + // Copy KV cache to snapshot + for (int il = 0; il < w_.n_layer; ++il) { + if (cache_.k[il] && verify_snap_.k_snap[il]) { + ggml_backend_tensor_copy(cache_.k[il], verify_snap_.k_snap[il]); + ggml_backend_tensor_copy(cache_.v[il], verify_snap_.v_snap[il]); + } + } + return true; +} + +bool Gemma4DFlashTarget::restore_kv() { + if (verify_snap_.k_snap.empty()) return false; + + // Restore KV cache from snapshot + for (int il = 0; il < w_.n_layer; ++il) { + if (cache_.k[il] && verify_snap_.k_snap[il]) { + ggml_backend_tensor_copy(verify_snap_.k_snap[il], cache_.k[il]); + ggml_backend_tensor_copy(verify_snap_.v_snap[il], cache_.v[il]); + } + } + cache_.cur_pos = verify_snap_.cur_pos; + return true; +} + +bool Gemma4DFlashTarget::is_eos(int token) const { + return token == w_.eos_id || token == w_.eos_chat_id; +} + +bool Gemma4DFlashTarget::embed_tokens(const int32_t * tokens, int n, + float * out) const { + if (!w_.embedder.embed(tokens, n, out)) return false; + // Scale by sqrt(n_embd) to match Gemma4's embedding convention. + // The draft was trained with Gemma4's scaled embeddings as noise input. 
+ const float scale = std::sqrt((float)w_.n_embd); + const size_t total = (size_t)n * w_.n_embd; + for (size_t i = 0; i < total; ++i) out[i] *= scale; + return true; +} + +bool Gemma4DFlashTarget::project_hidden_to_tokens( + const float * hidden, + int n_tokens, + std::vector & tokens_out) { + return gemma4_project_hidden(backend_, w_, hidden, n_tokens, tokens_out); +} + +int Gemma4DFlashTarget::mask_token_id() const { + // Gemma4 DFlash draft uses token ID 4 as mask (per model card) + return 4; +} + +const std::vector & Gemma4DFlashTarget::capture_layer_ids() const { + return capture_ids_; +} + +} // namespace dflash::common diff --git a/dflash/src/gemma4/gemma4_dflash_target.h b/dflash/src/gemma4/gemma4_dflash_target.h new file mode 100644 index 00000000..1d12079b --- /dev/null +++ b/dflash/src/gemma4/gemma4_dflash_target.h @@ -0,0 +1,63 @@ +// Gemma4DFlashTarget — DFlashTarget adapter for Gemma4 iSWA models. +// +// Wraps the Gemma4 target infrastructure (Gemma4Weights, Gemma4Cache, +// gemma4_step) behind the generic DFlashTarget interface so the universal +// DFlash draft model can drive speculative decode verification. + +#pragma once + +#include "common/dflash_target.h" +#include "gemma4_internal.h" + +#include "ggml.h" +#include "ggml-backend.h" + +#include + +namespace dflash::common { + +class Gemma4DFlashTarget : public DFlashTarget { +public: + // Non-owning references — caller must ensure lifetime. 
+ Gemma4DFlashTarget(Gemma4Weights & w, + Gemma4Cache & cache, + ggml_backend_t backend); + + ~Gemma4DFlashTarget() override; + + // ── DFlashTarget interface ────────────────────────────────────── + + bool verify_batch(const std::vector & tokens, + int base_pos, + int & last_tok, + std::vector * all_argmax = nullptr) override; + + bool snapshot_kv() override; + bool restore_kv() override; + + bool is_eos(int token) const override; + + bool embed_tokens(const int32_t * tokens, int n, + float * out) const override; + + bool project_hidden_to_tokens(const float * hidden, + int n_tokens, + std::vector & tokens_out) override; + + int hidden_size() const override { return w_.n_embd; } + int mask_token_id() const override; + const std::vector & capture_layer_ids() const override; + +private: + Gemma4Weights & w_; + Gemma4Cache & cache_; + ggml_backend_t backend_; + + // Capture layer IDs (built once in constructor). + std::vector capture_ids_; + + // Snapshot for speculative verify rollback. + Gemma4Snapshot verify_snap_; +}; + +} // namespace dflash::common diff --git a/dflash/src/gemma4/gemma4_graph.cpp b/dflash/src/gemma4/gemma4_graph.cpp index c4522edd..c042eb15 100644 --- a/dflash/src/gemma4/gemma4_graph.cpp +++ b/dflash/src/gemma4/gemma4_graph.cpp @@ -17,6 +17,7 @@ #include "gemma4_internal.h" #include "dflash27b.h" +#include "flashprefill.h" #include #include @@ -112,6 +113,8 @@ static ggml_tensor * build_gemma4_moe_block(ggml_context * ctx, ggml_tensor * at n_ff_exp, gate_up_e->ne[1], gate_up_e->ne[2], gate_up_e->nb[1], gate_up_e->nb[2], (size_t)n_ff_exp * ggml_element_size(gate_up_e)); + gate_e = ggml_cont(ctx, gate_e); + up_e = ggml_cont(ctx, up_e); ggml_tensor * gu = ggml_mul(ctx, ggml_gelu(ctx, gate_e), up_e); ggml_tensor * experts = ggml_mul_mat_id(ctx, L.ffn_down_exps, gu, selected); @@ -150,9 +153,9 @@ static ggml_tensor * build_gemma4_attn_block( int kv_start, int n_tokens) { - const int head_dim = w.head_dim; + const int head_dim = gemma4_head_dim(w, 
il); const int n_head = w.n_head; - const int n_head_kv = w.n_head_kv; + const int n_head_kv = gemma4_n_head_kv(w, il); const int q_dim = n_head * head_dim; const bool is_swa = gemma4_is_swa_layer(w, il); const bool has_kv = gemma4_has_kv(w, il); @@ -168,7 +171,7 @@ static ggml_tensor * build_gemma4_attn_block( // RoPE for Q const float rope_base = is_swa ? w.rope_freq_base_swa : w.rope_freq_base_full; - ggml_tensor * freq_factors = is_swa ? nullptr : L.rope_freqs; + ggml_tensor * freq_factors = is_swa ? nullptr : (L.rope_freqs ? L.rope_freqs : w.rope_freqs_global); Qcur = ggml_rope_ext(ctx, Qcur, positions, freq_factors, head_dim, GGML_ROPE_TYPE_NEOX, 0, rope_base, 1.0f, @@ -178,6 +181,7 @@ static ggml_tensor * build_gemma4_attn_block( int cache_il = cache.kv_source[il]; ggml_tensor * cache_k = cache.k[cache_il]; ggml_tensor * cache_v = cache.v[cache_il]; + const int cache_len = (int)cache_k->ne[1]; // max_ctx for full, swa_size for SWA if (has_kv) { // K/V projection + norm + RoPE + write to cache @@ -198,38 +202,60 @@ static ggml_tensor * build_gemma4_attn_block( 0, rope_base, 1.0f, 0.0f, 1.0f, 32.0f, 1.0f); - // Write K/V to cache + // Write K/V to cache (ring-buffer position for SWA layers) + const int write_pos = is_swa ? 
(kv_start % cache_len) : kv_start; ggml_tensor * Kcur_T = ggml_permute(ctx, Kcur, 0, 2, 1, 3); ggml_tensor * Vcur_T = ggml_permute(ctx, Vcur, 0, 2, 1, 3); ggml_tensor * k_slot = ggml_view_3d(ctx, cache_k, head_dim, n_tokens, n_head_kv, cache_k->nb[1], cache_k->nb[2], - cache_k->nb[1] * (size_t)kv_start); + cache_k->nb[1] * (size_t)write_pos); ggml_tensor * v_slot = ggml_view_3d(ctx, cache_v, head_dim, n_tokens, n_head_kv, cache_v->nb[1], cache_v->nb[2], - cache_v->nb[1] * (size_t)kv_start); + cache_v->nb[1] * (size_t)write_pos); ggml_build_forward_expand(gf, ggml_cpy(ctx, Kcur_T, k_slot)); ggml_build_forward_expand(gf, ggml_cpy(ctx, Vcur_T, v_slot)); } // else: KV-sharing layer — cache already written by source layer // Flash attention - const int kv_len = kv_start + n_tokens; + // For SWA layers: read entire ring buffer (cache_len positions) + // For full layers: read all positions (or windowed if fa_window > 0) + const int fa_window = cache.fa_window; + const int full_win_start = (!is_swa && fa_window > 0 && kv_start > fa_window) + ? (kv_start - fa_window) : 0; + const int kv_len_raw = is_swa ? std::min(kv_start + n_tokens, cache_len) + : (kv_start + n_tokens - full_win_start); + const int kv_len = (kv_len_raw + 255) & ~255; // pad to 256 for CUDA FA ggml_tensor * Qfa = ggml_permute(ctx, Qcur, 0, 2, 1, 3); Qfa = ggml_cont(ctx, Qfa); + const size_t cache_offset = is_swa ? 0 : (cache_k->nb[1] * (size_t)full_win_start); ggml_tensor * Kfa = ggml_view_3d(ctx, cache_k, head_dim, kv_len, n_head_kv, - cache_k->nb[1], cache_k->nb[2], 0); + cache_k->nb[1], cache_k->nb[2], cache_offset); ggml_tensor * Vfa = ggml_view_3d(ctx, cache_v, head_dim, kv_len, n_head_kv, - cache_v->nb[1], cache_v->nb[2], 0); - - const float kq_scale = 1.0f / std::sqrt((float)head_dim); - ggml_tensor * use_mask = is_swa ? 
attn_mask_swa : attn_mask_full; + cache_v->nb[1], cache_v->nb[2], cache_offset); + + // Gemma4 uses self.scaling = 1.0 (no QK scaling) because Q/K are already + // RMS-normed per-head. Standard 1/sqrt(head_dim) is NOT used here. + const float kq_scale = 1.0f; + ggml_tensor * use_mask; + if (is_swa) { + use_mask = attn_mask_swa; + } else if (full_win_start > 0) { + // View the mask starting at full_win_start column + use_mask = ggml_view_4d(ctx, attn_mask_full, + kv_len, n_tokens, 1, 1, + attn_mask_full->nb[1], attn_mask_full->nb[2], attn_mask_full->nb[3], + (size_t)full_win_start * ggml_element_size(attn_mask_full)); + } else { + use_mask = attn_mask_full; + } ggml_tensor * attn = ggml_flash_attn_ext(ctx, Qfa, Kfa, Vfa, use_mask, kq_scale, 0.0f, 0.0f); @@ -251,7 +277,8 @@ static ggml_tensor * build_gemma4_layer( ggml_tensor * attn_mask_swa, ggml_tensor * per_layer_input, // [n_embd_per_layer, n_tokens] or nullptr int kv_start, - int n_tokens) + int n_tokens, + int capture_idx = -1) // >=0: write to target_feat at this capture slot { const Gemma4Layer & L = w.layers[il]; @@ -312,14 +339,59 @@ static ggml_tensor * build_gemma4_layer( cur = ggml_mul(ctx, cur, L.out_scale); } + // Feature capture for DFlash spec-decode + if (capture_idx >= 0 && cache.target_feat) { + const int hidden = w.n_embd; + const size_t elt = ggml_element_size(cache.target_feat); + const size_t col_stride = cache.target_feat->nb[1]; + const int cap = cache.target_feat_cap; + const int slot_start = kv_start % cap; + const int pre_n = std::min(n_tokens, cap - slot_start); + const int post_n = n_tokens - pre_n; + + ggml_tensor * cur_2d = ggml_reshape_2d(ctx, cur, hidden, n_tokens); + + // First slice: [slot_start..slot_start+pre_n) in the ring. 
+ { + const size_t offset = + (size_t)slot_start * col_stride + + (size_t)capture_idx * hidden * elt; + ggml_tensor * slot = ggml_view_2d(ctx, cache.target_feat, + hidden, pre_n, col_stride, offset); + ggml_tensor * src = ggml_view_2d(ctx, cur_2d, + hidden, pre_n, cur_2d->nb[1], 0); + ggml_build_forward_expand(gf, ggml_cpy(ctx, src, slot)); + } + + // Second slice: wrap-around at [0..post_n) if needed. + if (post_n > 0) { + const size_t offset = + (size_t)capture_idx * hidden * elt; + ggml_tensor * slot = ggml_view_2d(ctx, cache.target_feat, + hidden, post_n, col_stride, offset); + ggml_tensor * src = ggml_view_2d(ctx, cur_2d, + hidden, post_n, cur_2d->nb[1], + (size_t)pre_n * cur_2d->nb[1]); + ggml_build_forward_expand(gf, ggml_cpy(ctx, src, slot)); + } + } + return cur; } +// Helper: get a 2D slice from a 3D tensor along ne[2] (same as llama.cpp ggml_view_2d_slice). +static ggml_tensor * gemma4_view_2d_slice(ggml_context * ctx, ggml_tensor * x, int idx) { + return ggml_view_2d(ctx, x, x->ne[0], x->ne[1], + ggml_row_size(x->type, x->ne[0]), + (size_t)idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); +} + bool gemma4_step( ggml_backend_t backend, const Gemma4Weights & w, Gemma4Cache & cache, const float * embed, + const int32_t * token_ids, int n_tokens, int kv_start, std::vector & out_logits) @@ -337,35 +409,80 @@ bool gemma4_step( ggml_tensor * pp = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens); ggml_set_input(pp); + // Token IDs input (for per-layer embedding lookup) + ggml_tensor * tok_ids = nullptr; + if (token_ids && w.per_layer_tok_embd && w.per_layer_model_proj && w.n_embd_per_layer > 0) { + tok_ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens); + ggml_set_input(tok_ids); + } + // Attention masks (full + SWA) - const int kv_len = kv_start + n_tokens; - ggml_tensor * mk_full = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv_len, n_tokens, 1, 1); + // Full-attention mask: covers all positions [0, kv_start+n_tokens) + const int kv_len_raw = kv_start + 
n_tokens; + const int kv_len_padded = (kv_len_raw + 255) & ~255; + ggml_tensor * mk_full = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv_len_padded, n_tokens, 1, 1); ggml_set_input(mk_full); ggml_tensor * mk_full_f16 = ggml_cast(ctx, mk_full, GGML_TYPE_F16); - ggml_tensor * mk_swa = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv_len, n_tokens, 1, 1); + + // SWA mask: covers the ring buffer [0, swa_size) with ring-buffer indexing + const int swa_size = cache.swa_size; + const int swa_len_raw = std::min(kv_start + n_tokens, swa_size); + const int swa_len_padded = (swa_len_raw + 255) & ~255; + ggml_tensor * mk_swa = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, swa_len_padded, n_tokens, 1, 1); ggml_set_input(mk_swa); ggml_tensor * mk_swa_f16 = ggml_cast(ctx, mk_swa, GGML_TYPE_F16); - // Per-layer embedding input (if model has per-layer embeddings) - // For simplicity, we precompute per-layer inputs for all layers at once - // Shape: [n_embd_per_layer, n_tokens, n_layer] → slice per layer - ggml_tensor * per_layer_all = nullptr; - if (w.per_layer_tok_embd && w.per_layer_model_proj && w.n_embd_per_layer > 0) { - // We need token IDs for per-layer embedding lookup, but we only have - // float embeddings at this point. Per-layer embedding requires a separate - // token ID input. For now, skip per-layer embeddings in the step function - // (they're computed on the embedding path in the backend). - // TODO: Add token ID input to gemma4_step for per-layer embedding support. + // Per-layer embedding computation (reference: gemma4-iswa.cpp build_inp_per_layer + project_per_layer_inputs) + ggml_tensor * per_layer_all = nullptr; // final shape: [n_embd_per_layer, n_tokens, n_layer] + if (tok_ids) { + const int D = w.n_embd_per_layer; + const int L = w.n_layer; + + // 1. 
Token per-layer embedding lookup + scale + // get_rows(per_layer_tok_embd[D*L, n_vocab], tok_ids) → [D*L, n_tokens] + ggml_tensor * inp_pl = ggml_get_rows(ctx, w.per_layer_tok_embd, tok_ids); + inp_pl = ggml_reshape_3d(ctx, inp_pl, D, L, n_tokens); // [D, L, n_tokens] + inp_pl = ggml_scale(ctx, inp_pl, std::sqrt((float)D)); + + // 2. Project main embedding through per_layer_model_proj + // mul_mat(per_layer_model_proj[n_embd, D*L], ie[n_embd, n_tokens]) → [D*L, n_tokens] + ggml_tensor * proj = ggml_mul_mat(ctx, w.per_layer_model_proj, ie); + proj = ggml_scale(ctx, proj, 1.0f / std::sqrt((float)w.n_embd)); + proj = ggml_reshape_3d(ctx, proj, D, L, n_tokens); // [D, L, n_tokens] + + // 3. RMS norm on projection (normalizes over ne[0]=D for each (layer, token)) + proj = ggml_rms_norm(ctx, proj, w.norm_eps); + // Reshape norm weight from [D*L] to [D, L] for broadcast mul over n_tokens + ggml_tensor * norm_w = ggml_reshape_2d(ctx, w.per_layer_proj_norm, D, L); + proj = ggml_mul(ctx, proj, norm_w); + + // 4. Add token embedding + projection, scale by 1/sqrt(2) + per_layer_all = ggml_add(ctx, proj, inp_pl); + per_layer_all = ggml_scale(ctx, per_layer_all, 1.0f / std::sqrt(2.0f)); + + // 5. 
Permute to [D, n_tokens, L] for easy per-layer slicing + per_layer_all = ggml_cont(ctx, ggml_permute(ctx, per_layer_all, 0, 2, 1, 3)); } // Build the graph ggml_tensor * cur = ie; // [n_embd, n_tokens] already scaled by sqrt(n_embd) in caller for (int il = 0; il < w.n_layer; ++il) { - ggml_tensor * pl_input = nullptr; // TODO: per-layer embedding per layer + ggml_tensor * pl_input = nullptr; + if (per_layer_all) { + // Slice [n_embd_per_layer, n_tokens] for this layer + pl_input = gemma4_view_2d_slice(ctx, per_layer_all, il); + } + // Determine capture index for this layer (-1 if not a capture layer) + int cap_idx = -1; + if (cache.target_feat) { + for (int k = 0; k < cache.n_capture_layers; k++) { + if (cache.capture_layer_ids[k] == il) { cap_idx = k; break; } + } + } cur = build_gemma4_layer(ctx, gf, w, cache, il, cur, pp, mk_full_f16, mk_swa_f16, pl_input, - kv_start, n_tokens); + kv_start, n_tokens, cap_idx); } // Final norm @@ -406,25 +523,39 @@ bool gemma4_step( for (int i = 0; i < n_tokens; ++i) pos[i] = kv_start + i; ggml_backend_tensor_set(pp, pos.data(), 0, ggml_nbytes(pp)); - // Causal mask (full attention) - std::vector mfull((size_t)kv_len * n_tokens, -INFINITY); + // Set token IDs for per-layer embedding + if (tok_ids && token_ids) { + ggml_backend_tensor_set(tok_ids, token_ids, 0, (size_t)n_tokens * sizeof(int32_t)); + } + + // Causal mask (full attention) — padded positions are masked with -inf + std::vector mfull((size_t)kv_len_padded * n_tokens, -INFINITY); for (int q = 0; q < n_tokens; ++q) { const int abs_q = kv_start + q; - for (int k = 0; k <= abs_q && k < kv_len; ++k) { - mfull[(size_t)q * kv_len + k] = 0.0f; + for (int k = 0; k <= abs_q && k < kv_len_raw; ++k) { + mfull[(size_t)q * kv_len_padded + k] = 0.0f; } } ggml_backend_tensor_set(mk_full, mfull.data(), 0, ggml_nbytes(mk_full)); - // SWA mask - std::vector mswa((size_t)kv_len * n_tokens, -INFINITY); + // SWA ring-buffer mask — maps cache indices to absolute positions const int W = 
w.sliding_window; + std::vector mswa((size_t)swa_len_padded * n_tokens, -INFINITY); for (int q = 0; q < n_tokens; ++q) { const int abs_q = kv_start + q; const int win_lo = std::max(0, abs_q - W + 1); - for (int k = win_lo; k <= abs_q && k < kv_len; ++k) { - mswa[(size_t)q * kv_len + k] = 0.0f; + // The ring buffer stores the most recent min(abs_q+1, swa_size) entries. + // Cache slot j holds absolute position: depends on how many tokens written. + const int total_written = abs_q + 1; // positions [0..abs_q] written so far + GGML_ASSERT(swa_size > 0 && "SWA branch entered with uninitialised cache.swa_size"); + for (int abs_k = win_lo; abs_k <= abs_q; ++abs_k) { + // Map absolute position to ring-buffer slot + const int slot = abs_k % swa_size; + if (slot < swa_len_raw) { + mswa[(size_t)q * swa_len_padded + slot] = 0.0f; + } } + (void)total_written; } ggml_backend_tensor_set(mk_swa, mswa.data(), 0, ggml_nbytes(mk_swa)); @@ -440,9 +571,941 @@ bool gemma4_step( ggml_backend_tensor_get(cur, out_logits.data(), 0, out_logits.size() * sizeof(float)); - cache.cur_pos = kv_len; + cache.cur_pos = kv_len_raw; + ggml_free(ctx); + return true; +} + +// ── gemma4_verify_batch ───────────────────────────────────────────────── +// Like gemma4_step but returns argmax for ALL token positions (not just last). 
+ +bool gemma4_verify_batch( + ggml_backend_t backend, + const Gemma4Weights & w, + Gemma4Cache & cache, + const float * embed, + const int32_t * token_ids, + int n_tokens, + int kv_start, + std::vector & out_argmax) +{ + ggml_init_params ip{}; + ip.mem_size = ggml_tensor_overhead() * 16384 + ggml_graph_overhead() + 16 * 1024 * 1024; + ip.no_alloc = true; + ggml_context * ctx = ggml_init(ip); + ggml_cgraph * gf = ggml_new_graph_custom(ctx, 16384, false); + + // Input tensors + ggml_tensor * ie = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, w.n_embd, n_tokens); + ggml_set_input(ie); + ggml_tensor * pp = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens); + ggml_set_input(pp); + + // Token IDs for per-layer embedding + ggml_tensor * tok_ids = nullptr; + if (token_ids && w.per_layer_tok_embd && w.per_layer_model_proj && w.n_embd_per_layer > 0) { + tok_ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens); + ggml_set_input(tok_ids); + } + + // Attention masks (padded) + const int kv_len_raw = kv_start + n_tokens; + const int kv_len_padded = (kv_len_raw + 255) & ~255; + ggml_tensor * mk_full = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv_len_padded, n_tokens, 1, 1); + ggml_set_input(mk_full); + ggml_tensor * mk_full_f16 = ggml_cast(ctx, mk_full, GGML_TYPE_F16); + + // SWA mask: ring-buffer sized + const int swa_size = cache.swa_size; + const int swa_len_raw = std::min(kv_start + n_tokens, swa_size); + const int swa_len_padded = (swa_len_raw + 255) & ~255; + ggml_tensor * mk_swa = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, swa_len_padded, n_tokens, 1, 1); + ggml_set_input(mk_swa); + ggml_tensor * mk_swa_f16 = ggml_cast(ctx, mk_swa, GGML_TYPE_F16); + + // Per-layer embedding computation (same as gemma4_step) + ggml_tensor * per_layer_all = nullptr; + if (tok_ids) { + const int D = w.n_embd_per_layer; + const int L = w.n_layer; + ggml_tensor * inp_pl = ggml_get_rows(ctx, w.per_layer_tok_embd, tok_ids); + inp_pl = ggml_reshape_3d(ctx, inp_pl, D, L, n_tokens); + inp_pl = ggml_scale(ctx, 
inp_pl, std::sqrt((float)D)); + ggml_tensor * proj = ggml_mul_mat(ctx, w.per_layer_model_proj, ie); + proj = ggml_scale(ctx, proj, 1.0f / std::sqrt((float)w.n_embd)); + proj = ggml_reshape_3d(ctx, proj, D, L, n_tokens); + proj = ggml_rms_norm(ctx, proj, w.norm_eps); + ggml_tensor * norm_w = ggml_reshape_2d(ctx, w.per_layer_proj_norm, D, L); + proj = ggml_mul(ctx, proj, norm_w); + per_layer_all = ggml_add(ctx, proj, inp_pl); + per_layer_all = ggml_scale(ctx, per_layer_all, 1.0f / std::sqrt(2.0f)); + per_layer_all = ggml_cont(ctx, ggml_permute(ctx, per_layer_all, 0, 2, 1, 3)); + } + + // Build graph (all layers) + ggml_tensor * cur = ie; + for (int il = 0; il < w.n_layer; ++il) { + ggml_tensor * pl_input = nullptr; + if (per_layer_all) { + pl_input = gemma4_view_2d_slice(ctx, per_layer_all, il); + } + int cap_idx = -1; + if (cache.target_feat) { + for (int k = 0; k < cache.n_capture_layers; k++) { + if (cache.capture_layer_ids[k] == il) { cap_idx = k; break; } + } + } + cur = build_gemma4_layer(ctx, gf, w, cache, il, cur, pp, + mk_full_f16, mk_swa_f16, pl_input, + kv_start, n_tokens, cap_idx); + } + + // Final norm + cur = gemma4_rms_norm_mul(ctx, cur, w.out_norm, w.norm_eps); + + // lm_head for ALL tokens (no slicing) + cur = ggml_mul_mat(ctx, w.output, cur); // [n_vocab, n_tokens] + + // Logit softcapping + if (w.final_logit_softcap > 0.0f) { + cur = ggml_scale(ctx, cur, 1.0f / w.final_logit_softcap); + cur = ggml_tanh(ctx, cur); + cur = ggml_scale(ctx, cur, w.final_logit_softcap); + } + + // Argmax per token + cur = ggml_argmax(ctx, cur); // [n_tokens] + ggml_set_output(cur); + ggml_build_forward_expand(gf, cur); + + // Allocate + static ggml_gallocr_t galloc_verify = nullptr; + if (!galloc_verify) galloc_verify = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + if (!ggml_gallocr_alloc_graph(galloc_verify, gf)) { + std::fprintf(stderr, "gemma4_verify_batch: gallocr_alloc_graph failed\n"); + ggml_free(ctx); + return false; + } + + // Set inputs + 
ggml_backend_tensor_set(ie, embed, 0, ggml_nbytes(ie)); + std::vector pos((size_t)n_tokens); + for (int i = 0; i < n_tokens; ++i) pos[i] = kv_start + i; + ggml_backend_tensor_set(pp, pos.data(), 0, ggml_nbytes(pp)); + + if (tok_ids && token_ids) { + ggml_backend_tensor_set(tok_ids, token_ids, 0, (size_t)n_tokens * sizeof(int32_t)); + } + + // Masks + std::vector mfull((size_t)kv_len_padded * n_tokens, -INFINITY); + for (int q = 0; q < n_tokens; ++q) { + const int abs_q = kv_start + q; + for (int k = 0; k <= abs_q && k < kv_len_raw; ++k) { + mfull[(size_t)q * kv_len_padded + k] = 0.0f; + } + } + ggml_backend_tensor_set(mk_full, mfull.data(), 0, ggml_nbytes(mk_full)); + + // SWA ring-buffer mask + const int W = w.sliding_window; + std::vector mswa((size_t)swa_len_padded * n_tokens, -INFINITY); + for (int q = 0; q < n_tokens; ++q) { + const int abs_q = kv_start + q; + const int win_lo = std::max(0, abs_q - W + 1); + for (int abs_k = win_lo; abs_k <= abs_q; ++abs_k) { + const int slot = abs_k % swa_size; + if (slot < swa_len_raw) { + mswa[(size_t)q * swa_len_padded + slot] = 0.0f; + } + } + } + ggml_backend_tensor_set(mk_swa, mswa.data(), 0, ggml_nbytes(mk_swa)); + + // Compute + if (ggml_backend_graph_compute(backend, gf) != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "gemma4_verify_batch: graph_compute failed\n"); + ggml_free(ctx); + return false; + } + + // Read argmax + out_argmax.resize(n_tokens); + ggml_backend_tensor_get(cur, out_argmax.data(), 0, sizeof(int32_t) * n_tokens); + + cache.cur_pos = kv_len_raw; + ggml_free(ctx); + return true; +} + +// ── gemma4_project_hidden ─────────────────────────────────────────────── +// Runs out_norm + lm_head + softcap + argmax on external hidden states. 
+ +bool gemma4_project_hidden( + ggml_backend_t backend, + const Gemma4Weights & w, + const float * hidden, + int n_tokens, + std::vector & out_tokens) +{ + ggml_init_params ip{}; + ip.mem_size = ggml_tensor_overhead() * 64 + ggml_graph_overhead() + 1024 * 1024; + ip.no_alloc = true; + ggml_context * ctx = ggml_init(ip); + ggml_cgraph * gf = ggml_new_graph(ctx); + + // Input: hidden states [n_embd, n_tokens] + // NOTE: The DFlash draft model already applies its own final RMSNorm, + // so we skip the target's out_norm and go directly to lm_head. + ggml_tensor * inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, w.n_embd, n_tokens); + ggml_set_input(inp); + + // lm_head (skip out_norm — draft already normalized) + ggml_tensor * cur = ggml_mul_mat(ctx, w.output, inp); // [n_vocab, n_tokens] + + // Logit softcapping + if (w.final_logit_softcap > 0.0f) { + cur = ggml_scale(ctx, cur, 1.0f / w.final_logit_softcap); + cur = ggml_tanh(ctx, cur); + cur = ggml_scale(ctx, cur, w.final_logit_softcap); + } + + // Argmax + cur = ggml_argmax(ctx, cur); // [n_tokens] + ggml_set_output(cur); + ggml_build_forward_expand(gf, cur); + + // Allocate + static ggml_gallocr_t galloc_proj = nullptr; + if (!galloc_proj) galloc_proj = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + if (!ggml_gallocr_alloc_graph(galloc_proj, gf)) { + std::fprintf(stderr, "gemma4_project_hidden: gallocr_alloc_graph failed\n"); + ggml_free(ctx); + return false; + } + + // Set input + ggml_backend_tensor_set(inp, hidden, 0, sizeof(float) * (size_t)n_tokens * w.n_embd); + + // Compute + if (ggml_backend_graph_compute(backend, gf) != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "gemma4_project_hidden: graph_compute failed\n"); + ggml_free(ctx); + return false; + } + + // Read result + out_tokens.resize(n_tokens); + ggml_backend_tensor_get(cur, out_tokens.data(), 0, sizeof(int32_t) * n_tokens); + ggml_free(ctx); return true; } +// ── gemma4_prefill_bsa ────────────────────────────────────────────────── 
+// Full-prompt BSA prefill: processes all tokens at once, layer-by-layer. +// SWA layers use flash_prefill_forward_bf16 (block-sparse attention). +// Full-attention layers use ggml_flash_attn_ext (dense, exact). +// After all layers: fills KV cache for subsequent decode. + +// Persistent buffer helper (same pattern as Qwen3). +struct G4PersBuf { + ggml_context * ctx = nullptr; + ggml_backend_buffer_t buf = nullptr; + ggml_tensor * t = nullptr; +}; + +static bool g4_make_pers(ggml_backend_t backend, ggml_type type, int n_dim, + const int64_t * dims, G4PersBuf & out) { + ggml_init_params ip{}; + ip.mem_size = ggml_tensor_overhead() * 4 + 1024; + ip.no_alloc = true; + ip.mem_buffer = nullptr; + out.ctx = ggml_init(ip); + if (!out.ctx) return false; + if (n_dim == 1) out.t = ggml_new_tensor_1d(out.ctx, type, dims[0]); + else if (n_dim == 2) out.t = ggml_new_tensor_2d(out.ctx, type, dims[0], dims[1]); + else if (n_dim == 3) out.t = ggml_new_tensor_3d(out.ctx, type, dims[0], dims[1], dims[2]); + else return false; + out.buf = ggml_backend_alloc_ctx_tensors(out.ctx, backend); + return out.buf != nullptr; +} + +static void g4_free_pers(G4PersBuf & p) { + if (p.buf) { ggml_backend_buffer_free(p.buf); p.buf = nullptr; } + if (p.ctx) { ggml_free(p.ctx); p.ctx = nullptr; } + p.t = nullptr; +} + +static int g4_bsa_chunk_size() { + if (const char * e = std::getenv("DFLASH_G4_BSA_CHUNK")) { + int v = std::atoi(e); + if (v >= 512) return v; + } + return 4096; +} + +bool gemma4_prefill_bsa( + ggml_backend_t backend, + const Gemma4Weights & w, + Gemma4Cache & cache, + const float * embed, + const int32_t * token_ids, + int S, + std::vector & out_logits) +{ + const int hidden = w.n_embd; + const int n_layer = w.n_layer; + const int n_head = w.n_head; + const float eps = w.norm_eps; + + // Determine max dimensions across all layers for buffer allocation. 
+ int max_q_dim = 0, max_kv_dim = 0; + for (int il = 0; il < n_layer; ++il) { + const int D = gemma4_head_dim(w, il); + const int Hk = gemma4_n_head_kv(w, il); + max_q_dim = std::max(max_q_dim, D * n_head); + max_kv_dim = std::max(max_kv_dim, D * Hk); + } + + // Use BF16 only for sm_80+ (native BF16 tensor cores). Volta/Turing + // use F16 with F16 WMMA kernels; other arches use F16 with ggml FA fallback. + const ggml_type half_type = +#ifdef DFLASH27B_HAVE_SM80_FLASHPREFILL + GGML_TYPE_BF16; +#else + GGML_TYPE_F16; +#endif + + // Allocate persistent buffers. + G4PersBuf hidden_buf{}, Q_buf{}, K_buf{}, V_buf{}, attn_out_buf{}; + int64_t d_h[] = {(int64_t)hidden, (int64_t)S}; + int64_t d_q[] = {(int64_t)max_q_dim, (int64_t)S}; + int64_t d_kv[] = {(int64_t)max_kv_dim, (int64_t)S}; + + auto cleanup_all = [&]() { + g4_free_pers(hidden_buf); + g4_free_pers(Q_buf); + g4_free_pers(K_buf); + g4_free_pers(V_buf); + g4_free_pers(attn_out_buf); + }; + + if (!g4_make_pers(backend, GGML_TYPE_F32, 2, d_h, hidden_buf) || + !g4_make_pers(backend, half_type, 2, d_q, Q_buf) || + !g4_make_pers(backend, half_type, 2, d_kv, K_buf) || + !g4_make_pers(backend, half_type, 2, d_kv, V_buf) || + !g4_make_pers(backend, half_type, 2, d_q, attn_out_buf)) { + std::fprintf(stderr, "[gemma4-bsa] persistent buffer alloc failed\n"); + cleanup_all(); + return false; + } + + // Upload embedded+scaled input to hidden_buf. + ggml_backend_tensor_set(hidden_buf.t, embed, 0, (size_t)hidden * S * sizeof(float)); + + // Precompute per-layer embeddings on GPU if the model has them. + // per_layer_all: [n_embd_per_layer, S, n_layer] — computed once, sliced per layer. 
+ G4PersBuf per_layer_buf{}; + if (token_ids && w.per_layer_tok_embd && w.per_layer_model_proj && w.n_embd_per_layer > 0) { + const int D_pl = w.n_embd_per_layer; + const int L_pl = n_layer; + int64_t d_pl[] = {(int64_t)D_pl, (int64_t)S, (int64_t)L_pl}; + if (!g4_make_pers(backend, GGML_TYPE_F32, 3, d_pl, per_layer_buf)) { + std::fprintf(stderr, "[gemma4-bsa] per-layer buf alloc failed\n"); + cleanup_all(); + return false; + } + + // Build a graph to compute per-layer embeddings. + ggml_init_params ip{}; + ip.mem_size = ggml_tensor_overhead() * 32 + ggml_graph_overhead() + 1024 * 1024; + ip.no_alloc = true; + ggml_context * ctx = ggml_init(ip); + ggml_cgraph * gf = ggml_new_graph(ctx); + + ggml_tensor * tok = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, S); + ggml_set_input(tok); + ggml_tensor * h_in = ggml_view_2d(ctx, hidden_buf.t, hidden, S, + hidden * sizeof(float), 0); + + // get_rows(per_layer_tok_embd, tok) → [D_pl*L_pl, S] + ggml_tensor * inp_pl = ggml_get_rows(ctx, w.per_layer_tok_embd, tok); + inp_pl = ggml_reshape_3d(ctx, inp_pl, D_pl, L_pl, S); + inp_pl = ggml_scale(ctx, inp_pl, std::sqrt((float)D_pl)); + + // Project main embedding: mul_mat(per_layer_model_proj, h_in) + ggml_tensor * proj = ggml_mul_mat(ctx, w.per_layer_model_proj, h_in); + proj = ggml_scale(ctx, proj, 1.0f / std::sqrt((float)hidden)); + proj = ggml_reshape_3d(ctx, proj, D_pl, L_pl, S); + + // RMS norm on projection + proj = ggml_rms_norm(ctx, proj, eps); + ggml_tensor * norm_w = ggml_reshape_2d(ctx, w.per_layer_proj_norm, D_pl, L_pl); + proj = ggml_mul(ctx, proj, norm_w); + + // Add + scale + ggml_tensor * pl_all = ggml_add(ctx, proj, inp_pl); + pl_all = ggml_scale(ctx, pl_all, 1.0f / std::sqrt(2.0f)); + + // Permute to [D_pl, S, L_pl] and copy to persistent buffer + pl_all = ggml_cont(ctx, ggml_permute(ctx, pl_all, 0, 2, 1, 3)); + ggml_tensor * cpy = ggml_cpy(ctx, pl_all, per_layer_buf.t); + ggml_set_output(cpy); + ggml_build_forward_expand(gf, cpy); + + ggml_gallocr_t ga = 
ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + if (!ggml_gallocr_alloc_graph(ga, gf)) { + std::fprintf(stderr, "[gemma4-bsa] per-layer graph alloc failed\n"); + ggml_gallocr_free(ga); ggml_free(ctx); + g4_free_pers(per_layer_buf); cleanup_all(); + return false; + } + ggml_backend_tensor_set(tok, token_ids, 0, (size_t)S * sizeof(int32_t)); + if (ggml_backend_graph_compute(backend, gf) != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "gemma4_prefill_bsa: per-layer embed graph_compute failed\n"); + ggml_gallocr_free(ga); ggml_free(ctx); + g4_free_pers(per_layer_buf); cleanup_all(); + return false; + } + ggml_gallocr_free(ga); + ggml_free(ctx); + } + + // Gallocr for per-layer graphs (reused). + ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + const int CHUNK = g4_bsa_chunk_size(); + + // FlashPrefill config for SWA layers. + const int block_size = 128; + const int swa_window_blocks = (w.sliding_window + block_size - 1) / block_size; + flashprefill::FlashPrefillConfig swa_cfg; + swa_cfg.block_size = block_size; + swa_cfg.attention_sink = 0; + swa_cfg.window = swa_window_blocks; + swa_cfg.last_n_full = 0; + swa_cfg.alpha = 2.0f; // > 1.0 disables dynamic block selection + + // Scale for attention: Gemma4 uses 1.0 (Q/K already RMS-normed per head). + const float kq_scale = 1.0f; + + // ── Per-layer loop ── + for (int il = 0; il < n_layer; ++il) { + const Gemma4Layer & L = w.layers[il]; + const bool is_swa = gemma4_is_swa_layer(w, il); + const bool has_kv = gemma4_has_kv(w, il); + const int D = gemma4_head_dim(w, il); + const int Hk = gemma4_n_head_kv(w, il); + const int q_dim = D * n_head; + const int kv_dim = D * Hk; + + // ── Graph A (chunked): pre_norm + Q/K/V proj + norms + RoPE → persistent bufs ── + const float rope_base = is_swa ? w.rope_freq_base_swa : w.rope_freq_base_full; + ggml_tensor * freq_factors_ref = is_swa ? nullptr : + (L.rope_freqs ? 
L.rope_freqs : w.rope_freqs_global); + + for (int cs = 0; cs < S; cs += CHUNK) { + const int cl = std::min(CHUNK, S - cs); + + ggml_init_params ipA{}; + ipA.mem_size = ggml_tensor_overhead() * 64 + + ggml_graph_overhead_custom(512, false) + + 128 * 1024; + ipA.no_alloc = true; + ggml_context * gA = ggml_init(ipA); + if (!gA) { std::fprintf(stderr, "[gemma4-bsa] graph A init failed\n"); ggml_gallocr_free(galloc); cleanup_all(); g4_free_pers(per_layer_buf); return false; } + ggml_cgraph * gfA = ggml_new_graph_custom(gA, 512, false); + + // View into hidden_buf for this chunk. + const size_t h_esz = sizeof(float); + ggml_tensor * h_view = ggml_view_2d(gA, hidden_buf.t, + hidden, cl, + hidden * h_esz, + (size_t)cs * hidden * h_esz); + + // Positions for RoPE. + ggml_tensor * pos_t = ggml_new_tensor_1d(gA, GGML_TYPE_I32, cl); + ggml_set_input(pos_t); + + // Pre-attn norm. + ggml_tensor * h_norm = ggml_rms_norm(gA, h_view, eps); + h_norm = ggml_mul(gA, h_norm, L.attn_norm); + + // Q projection + norm + RoPE. + ggml_tensor * Q = ggml_mul_mat(gA, L.wq, h_norm); + Q = ggml_reshape_3d(gA, Q, D, n_head, cl); + if (L.q_norm) { + Q = gemma4_rms_norm_mul(gA, Q, L.q_norm, eps); + } + Q = ggml_rope_ext(gA, Q, pos_t, freq_factors_ref, D, + GGML_ROPE_TYPE_NEOX, 0, rope_base, 1.0f, + 0.0f, 1.0f, 32.0f, 1.0f); + // Reshape Q to [q_dim, cl] and copy to Q_buf. + Q = ggml_reshape_2d(gA, Q, q_dim, cl); + + const size_t q_esz = ggml_type_size(half_type); + ggml_tensor * Q_dst = ggml_view_2d(gA, Q_buf.t, q_dim, cl, + q_esz * max_q_dim, + (size_t)cs * q_esz * max_q_dim); + ggml_build_forward_expand(gfA, ggml_cpy(gA, Q, Q_dst)); + + if (has_kv) { + // K projection + norm + RoPE. 
+ ggml_tensor * K = ggml_mul_mat(gA, L.wk, h_norm); + K = ggml_reshape_3d(gA, K, D, Hk, cl); + if (L.k_norm) { + K = gemma4_rms_norm_mul(gA, K, L.k_norm, eps); + } + K = ggml_rope_ext(gA, K, pos_t, freq_factors_ref, D, + GGML_ROPE_TYPE_NEOX, 0, rope_base, 1.0f, + 0.0f, 1.0f, 32.0f, 1.0f); + K = ggml_reshape_2d(gA, K, kv_dim, cl); + + // V projection + RMSNorm (Gemma4 specific). + ggml_tensor * V = L.wv ? ggml_mul_mat(gA, L.wv, h_norm) + : ggml_mul_mat(gA, L.wk, h_norm); + V = ggml_reshape_3d(gA, V, D, Hk, cl); + V = ggml_rms_norm(gA, V, eps); + V = ggml_reshape_2d(gA, V, kv_dim, cl); + + const size_t kv_esz = ggml_type_size(half_type); + ggml_tensor * K_dst = ggml_view_2d(gA, K_buf.t, kv_dim, cl, + kv_esz * max_kv_dim, + (size_t)cs * kv_esz * max_kv_dim); + ggml_tensor * V_dst = ggml_view_2d(gA, V_buf.t, kv_dim, cl, + kv_esz * max_kv_dim, + (size_t)cs * kv_esz * max_kv_dim); + ggml_build_forward_expand(gfA, ggml_cpy(gA, K, K_dst)); + ggml_build_forward_expand(gfA, ggml_cpy(gA, V, V_dst)); + + // Write to KV cache for subsequent decode. + // K is [kv_dim, cl] = [D*Hk, cl]. Cache is [D, cache_len, Hk] F16. + // Reshape K to [D, Hk, cl] → permute to [D, cl, Hk] → copy into cache slot. + ggml_tensor * cache_k_t = cache.k[il]; + ggml_tensor * cache_v_t = cache.v[il]; + if (cache_k_t) { + const int cache_len_il = (int)cache_k_t->ne[1]; + const int ring_pos = is_swa ? (cs % cache_len_il) : cs; + + // Lambda to copy a sub-range of K/V into cache. + auto write_kv_range = [&](int src_off, int dst_ring, int n) { + if (n <= 0) return; + // K[src_off:src_off+n] → cache_k[dst_ring:dst_ring+n] + ggml_tensor * Ks = (src_off == 0 && n == cl) ? 
K + : ggml_view_2d(gA, K, kv_dim, n, + K->nb[1], (size_t)src_off * K->nb[1]); + ggml_tensor * K3 = ggml_reshape_3d(gA, Ks, D, Hk, n); + ggml_tensor * Kp = ggml_cont(gA, ggml_permute(gA, K3, 0, 2, 1, 3)); + ggml_tensor * k_slot = ggml_view_3d(gA, cache_k_t, + D, n, Hk, + cache_k_t->nb[1], cache_k_t->nb[2], + cache_k_t->nb[1] * (size_t)dst_ring); + ggml_build_forward_expand(gfA, ggml_cpy(gA, Kp, k_slot)); + + ggml_tensor * Vs = (src_off == 0 && n == cl) ? V + : ggml_view_2d(gA, V, kv_dim, n, + V->nb[1], (size_t)src_off * V->nb[1]); + ggml_tensor * V3 = ggml_reshape_3d(gA, Vs, D, Hk, n); + ggml_tensor * Vp = ggml_cont(gA, ggml_permute(gA, V3, 0, 2, 1, 3)); + ggml_tensor * v_slot = ggml_view_3d(gA, cache_v_t, + D, n, Hk, + cache_v_t->nb[1], cache_v_t->nb[2], + cache_v_t->nb[1] * (size_t)dst_ring); + ggml_build_forward_expand(gfA, ggml_cpy(gA, Vp, v_slot)); + }; + + if (!is_swa && ring_pos + cl > cache_len_il) { + // Full-attention layer: positions exceed cache — truncate. + const int n_fit = cache_len_il - ring_pos; + if (n_fit > 0) write_kv_range(0, ring_pos, n_fit); + } else if (is_swa && ring_pos + cl > cache_len_il) { + // SWA ring wrap — split into two writes. + const int first_n = cache_len_il - ring_pos; + write_kv_range(0, ring_pos, first_n); + write_kv_range(first_n, 0, cl - first_n); + } else { + write_kv_range(0, ring_pos, cl); + } + } + } + + if (!ggml_gallocr_alloc_graph(galloc, gfA)) { + std::fprintf(stderr, "[gemma4-bsa] graph A alloc failed layer=%d cs=%d\n", il, cs); + ggml_free(gA); ggml_gallocr_free(galloc); cleanup_all(); g4_free_pers(per_layer_buf); + return false; + } + + // Set positions. + std::vector pos((size_t)cl); + for (int i = 0; i < cl; ++i) pos[i] = cs + i; + ggml_backend_tensor_set(pos_t, pos.data(), 0, (size_t)cl * sizeof(int32_t)); + + ggml_backend_graph_compute(backend, gfA); + ggml_backend_synchronize(backend); + ggml_free(gA); + } + + // ── Attention ── + // Determine which K/V to use (KV sharing). 
+ const int kv_source_il = cache.kv_source[il]; + // If this layer reuses another layer's KV, the source layer's K/V is + // already in K_buf/V_buf from when that layer was processed. + // For BSA we need the source layer's buffers, but since we process + // layers sequentially and overwrite K_buf/V_buf each layer, we need + // to handle sharing differently. + // + // Simplification: KV-sharing layers have the same head_dim and n_head_kv + // as their source. During BSA prefill, we DON'T overwrite K_buf/V_buf + // for layers without has_kv, so they still hold the source layer's data. + // This works because kv_source[il] < il for sharing layers. + + bool used_bsa = false; + if (is_swa && D == 128) { + // ── BSA sparse-FA for SWA layers (head_dim=128) ── + const bool q_contiguous = (q_dim == max_q_dim); + const bool kv_contiguous = (kv_dim == max_kv_dim); + + int rc; + if (q_contiguous && kv_contiguous) { + rc = flashprefill::flash_prefill_forward( + backend, Q_buf.t->data, K_buf.t->data, + V_buf.t->data, attn_out_buf.t->data, + 1, S, n_head, Hk, D, kq_scale, half_type, swa_cfg); + } else { + // Non-contiguous: allocate temporary packed buffers. 
+ G4PersBuf Q_pack{}, K_pack{}, V_pack{}, O_pack{}; + int64_t dq[] = {(int64_t)q_dim, (int64_t)S}; + int64_t dk[] = {(int64_t)kv_dim, (int64_t)S}; + if (!g4_make_pers(backend, half_type, 2, dq, Q_pack) || + !g4_make_pers(backend, half_type, 2, dk, K_pack) || + !g4_make_pers(backend, half_type, 2, dk, V_pack) || + !g4_make_pers(backend, half_type, 2, dq, O_pack)) { + std::fprintf(stderr, "[gemma4-bsa] pack buf alloc failed\n"); + g4_free_pers(Q_pack); g4_free_pers(K_pack); + g4_free_pers(V_pack); g4_free_pers(O_pack); + ggml_gallocr_free(galloc); cleanup_all(); g4_free_pers(per_layer_buf); + return false; + } + + const size_t esz = ggml_type_size(half_type); + cudaMemcpy2D(Q_pack.t->data, q_dim * esz, + Q_buf.t->data, max_q_dim * esz, + q_dim * esz, S, cudaMemcpyDeviceToDevice); + cudaMemcpy2D(K_pack.t->data, kv_dim * esz, + K_buf.t->data, max_kv_dim * esz, + kv_dim * esz, S, cudaMemcpyDeviceToDevice); + cudaMemcpy2D(V_pack.t->data, kv_dim * esz, + V_buf.t->data, max_kv_dim * esz, + kv_dim * esz, S, cudaMemcpyDeviceToDevice); + + rc = flashprefill::flash_prefill_forward( + backend, Q_pack.t->data, K_pack.t->data, + V_pack.t->data, O_pack.t->data, + 1, S, n_head, Hk, D, kq_scale, half_type, swa_cfg); + + // Copy packed output back to strided attn_out_buf. + cudaMemcpy2D(attn_out_buf.t->data, max_q_dim * esz, + O_pack.t->data, q_dim * esz, + q_dim * esz, S, cudaMemcpyDeviceToDevice); + + g4_free_pers(Q_pack); g4_free_pers(K_pack); + g4_free_pers(V_pack); g4_free_pers(O_pack); + } + + if (rc != 0) { + std::fprintf(stderr, "[gemma4-bsa] flash_prefill failed layer=%d rc=%d\n", il, rc); + ggml_gallocr_free(galloc); cleanup_all(); g4_free_pers(per_layer_buf); + return false; + } + cudaDeviceSynchronize(); + used_bsa = true; + } + + if (!used_bsa) { + // Build a ggml graph for dense causal attention for this layer. + // Process the full sequence in one FA call (or chunked if too large). 
+ for (int cs = 0; cs < S; cs += CHUNK) { + const int cl = std::min(CHUNK, S - cs); + const int kv_len = cs + cl; // attend to all positions up to current + + ggml_init_params ipFA{}; + ipFA.mem_size = ggml_tensor_overhead() * 32 + + ggml_graph_overhead_custom(64, false) + + 128 * 1024; + ipFA.no_alloc = true; + ggml_context * gFA = ggml_init(ipFA); + ggml_cgraph * gfFA = ggml_new_graph_custom(gFA, 64, false); + + const size_t esz = ggml_type_size(half_type); + + // Q view: [D, n_head, cl] from Q_buf + ggml_tensor * Qfa = ggml_view_3d(gFA, Q_buf.t, + D, n_head, cl, + esz * D, esz * max_q_dim, + (size_t)cs * esz * max_q_dim); + + // K view: [D, Hk, kv_len] from K_buf + ggml_tensor * Kfa = ggml_view_3d(gFA, K_buf.t, + D, Hk, kv_len, + esz * D, esz * max_kv_dim, + 0); + + // V view: [D, Hk, kv_len] from V_buf + ggml_tensor * Vfa = ggml_view_3d(gFA, V_buf.t, + D, Hk, kv_len, + esz * D, esz * max_kv_dim, + 0); + + // Causal mask: [kv_len_padded, cl] + const int kv_len_padded = (kv_len + 255) & ~255; + ggml_tensor * mask = ggml_new_tensor_4d(gFA, GGML_TYPE_F32, + kv_len_padded, cl, 1, 1); + ggml_set_input(mask); + ggml_tensor * mask_f16 = ggml_cast(gFA, mask, GGML_TYPE_F16); + + ggml_tensor * attn = ggml_flash_attn_ext(gFA, Qfa, Kfa, Vfa, mask_f16, + kq_scale, 0.0f, 0.0f); + + // Write output to attn_out_buf: [q_dim, cl] at offset cs. + attn = ggml_reshape_2d(gFA, attn, q_dim, cl); + ggml_tensor * O_dst = ggml_view_2d(gFA, attn_out_buf.t, q_dim, cl, + esz * max_q_dim, + (size_t)cs * esz * max_q_dim); + ggml_tensor * cpy_op = ggml_cpy(gFA, attn, O_dst); + ggml_set_output(cpy_op); + ggml_build_forward_expand(gfFA, cpy_op); + + if (!ggml_gallocr_alloc_graph(galloc, gfFA)) { + std::fprintf(stderr, "[gemma4-bsa] dense FA alloc failed layer=%d\n", il); + ggml_free(gFA); ggml_gallocr_free(galloc); cleanup_all(); g4_free_pers(per_layer_buf); + return false; + } + + // Fill causal mask. 
+ std::vector m((size_t)kv_len_padded * cl, -INFINITY); + for (int q = 0; q < cl; ++q) { + const int abs_q = cs + q; + for (int k = 0; k <= abs_q && k < kv_len; ++k) { + m[(size_t)q * kv_len_padded + k] = 0.0f; + } + } + ggml_backend_tensor_set(mask, m.data(), 0, ggml_nbytes(mask)); + + ggml_backend_graph_compute(backend, gfFA); + ggml_backend_synchronize(backend); + ggml_free(gFA); + } + } + + // ── Graph B (chunked): o_proj + post_norm + residual + FFN + per_layer + scale ── + for (int cs = 0; cs < S; cs += CHUNK) { + const int cl = std::min(CHUNK, S - cs); + + ggml_init_params ipB{}; + ipB.mem_size = ggml_tensor_overhead() * 128 + + ggml_graph_overhead_custom(1024, false) + + 2 * 1024 * 1024; + ipB.no_alloc = true; + ggml_context * gB = ggml_init(ipB); + if (!gB) { std::fprintf(stderr, "[gemma4-bsa] graph B init failed\n"); ggml_gallocr_free(galloc); cleanup_all(); g4_free_pers(per_layer_buf); return false; } + ggml_cgraph * gfB = ggml_new_graph_custom(gB, 1024, false); + + const size_t h_esz = sizeof(float); + const size_t a_esz = ggml_type_size(half_type); + + // Hidden state for this chunk (residual input). + ggml_tensor * h_in = ggml_view_2d(gB, hidden_buf.t, hidden, cl, + hidden * h_esz, + (size_t)cs * hidden * h_esz); + + // Attention output for this chunk. + ggml_tensor * a_in = ggml_view_2d(gB, attn_out_buf.t, q_dim, cl, + a_esz * max_q_dim, + (size_t)cs * a_esz * max_q_dim); + + // o_proj: [q_dim, n_embd] × [q_dim, cl] → [n_embd, cl] + ggml_tensor * cur = ggml_mul_mat(gB, L.wo, a_in); + + // Post-attn norm. + if (L.attn_post_norm) { + cur = gemma4_rms_norm_mul(gB, cur, L.attn_post_norm, eps); + } + + // Residual after attention. + ggml_tensor * attn_res = ggml_add(gB, cur, h_in); + + // FFN. 
+ const bool is_moe = (L.ffn_gate_inp != nullptr && il >= w.n_layer_dense_lead); + ggml_tensor * ffn_out; + if (is_moe) { + ggml_tensor * normed = gemma4_rms_norm_mul(gB, attn_res, L.ffn_norm, eps); + ffn_out = build_gemma4_moe_block(gB, attn_res, normed, w, L, cl); + } else { + cur = gemma4_rms_norm_mul(gB, attn_res, L.ffn_norm, eps); + ffn_out = build_gemma4_dense_ffn(gB, cur, L); + } + + // FFN post-norm. + if (L.ffn_post_norm) { + ffn_out = gemma4_rms_norm_mul(gB, ffn_out, L.ffn_post_norm, eps); + } + + // Residual after FFN. + cur = ggml_add(gB, ffn_out, attn_res); + + // Per-layer embedding injection. + if (per_layer_buf.t && L.per_layer_inp_gate && L.per_layer_proj) { + const int D_pl = w.n_embd_per_layer; + // Slice per_layer_buf [D_pl, S, n_layer] → [D_pl, cl] for this layer+chunk + ggml_tensor * pl_slice = ggml_view_2d(gB, per_layer_buf.t, + D_pl, cl, + D_pl * sizeof(float), + ((size_t)il * S + cs) * D_pl * sizeof(float)); + + ggml_tensor * gate = ggml_mul_mat(gB, L.per_layer_inp_gate, cur); + gate = ggml_gelu(gB, gate); + gate = ggml_mul(gB, gate, pl_slice); + ggml_tensor * proj = ggml_mul_mat(gB, L.per_layer_proj, gate); + if (L.per_layer_post_norm) { + proj = gemma4_rms_norm_mul(gB, proj, L.per_layer_post_norm, eps); + } + cur = ggml_add(gB, cur, proj); + } + + // Output scale. + if (L.out_scale) { + cur = ggml_mul(gB, cur, L.out_scale); + } + + // Write back to hidden_buf. 
+ ggml_tensor * h_dst = ggml_view_2d(gB, hidden_buf.t, hidden, cl, + hidden * h_esz, + (size_t)cs * hidden * h_esz); + ggml_tensor * cpy = ggml_cpy(gB, cur, h_dst); + ggml_set_output(cpy); + ggml_build_forward_expand(gfB, cpy); + + if (!ggml_gallocr_alloc_graph(galloc, gfB)) { + std::fprintf(stderr, "[gemma4-bsa] graph B alloc failed layer=%d cs=%d\n", il, cs); + ggml_free(gB); ggml_gallocr_free(galloc); cleanup_all(); g4_free_pers(per_layer_buf); + return false; + } + + ggml_backend_graph_compute(backend, gfB); + ggml_backend_synchronize(backend); + ggml_free(gB); + } + + // Feature capture: write hidden states at capture layers to target_feat ring. + if (cache.target_feat) { + int cap_idx = -1; + for (int k = 0; k < cache.n_capture_layers; k++) { + if (cache.capture_layer_ids[k] == il) { cap_idx = k; break; } + } + if (cap_idx >= 0) { + const int cap = cache.target_feat_cap; + const size_t feat_col_stride = cache.target_feat->nb[1]; + const size_t feat_elt = ggml_element_size(cache.target_feat); + // Write last min(S, cap) positions into the ring buffer. + const int write_start = (S > cap) ? 
(S - cap) : 0; + const int write_n = std::min(S, cap); + for (int cs = write_start; cs < write_start + write_n; cs += CHUNK) { + const int cl = std::min(CHUNK, write_start + write_n - cs); + const int slot_start = cs % cap; + + ggml_init_params ipC{}; + ipC.mem_size = ggml_tensor_overhead() * 8 + + ggml_graph_overhead() + 64 * 1024; + ipC.no_alloc = true; + ggml_context * gC = ggml_init(ipC); + ggml_cgraph * gfC = ggml_new_graph(gC); + + ggml_tensor * h_src = ggml_view_2d(gC, hidden_buf.t, + hidden, cl, hidden * sizeof(float), + (size_t)cs * hidden * sizeof(float)); + + const size_t offset = (size_t)slot_start * feat_col_stride + + (size_t)cap_idx * hidden * feat_elt; + ggml_tensor * feat_dst = ggml_view_2d(gC, cache.target_feat, + hidden, cl, feat_col_stride, offset); + + ggml_build_forward_expand(gfC, ggml_cpy(gC, h_src, feat_dst)); + + if (ggml_gallocr_alloc_graph(galloc, gfC)) { + ggml_backend_graph_compute(backend, gfC); + } + ggml_free(gC); + } + } + } + } // end layer loop + + // ── Fill KV cache for decode ── + // KV cache was not populated during the BSA layer loop because K_buf/V_buf + // get overwritten each layer. We re-project K/V for each KV-owning layer + // from the hidden states that were stored before each layer's attention. + // + // However, we don't have the pre-norm hidden states anymore (hidden_buf has + // the final output). The correct approach is to write KV cache during the + // layer loop. Since this is a v1 implementation, we use the fallback: + // after BSA prefill returns, the caller (do_prefill) will run a single + // gemma4_step with the last chunk to populate the cache for decode. + // + // TODO: Move KV cache writes into the layer loop for zero-redundancy. + // For now, the caller handles cache population by running a trailing + // gemma4_step over the last swa_size tokens. 
+ + // ── Final norm + logits (last token only) ── + { + ggml_init_params ipF{}; + ipF.mem_size = ggml_tensor_overhead() * 16 + ggml_graph_overhead() + 1024 * 1024; + ipF.no_alloc = true; + ggml_context * gF = ggml_init(ipF); + ggml_cgraph * gfF = ggml_new_graph(gF); + + // View last token of hidden_buf. + ggml_tensor * h_last = ggml_view_2d(gF, hidden_buf.t, hidden, 1, + hidden * sizeof(float), + (size_t)(S - 1) * hidden * sizeof(float)); + + // Final RMSNorm. + ggml_tensor * normed = gemma4_rms_norm_mul(gF, h_last, w.out_norm, eps); + + // lm_head. + ggml_tensor * logits = ggml_mul_mat(gF, w.output, normed); + + // Softcapping. + if (w.final_logit_softcap > 0.0f) { + logits = ggml_scale(gF, logits, 1.0f / w.final_logit_softcap); + logits = ggml_tanh(gF, logits); + logits = ggml_scale(gF, logits, w.final_logit_softcap); + } + + ggml_set_output(logits); + ggml_build_forward_expand(gfF, logits); + + if (!ggml_gallocr_alloc_graph(galloc, gfF)) { + std::fprintf(stderr, "[gemma4-bsa] final graph alloc failed\n"); + ggml_free(gF); ggml_gallocr_free(galloc); cleanup_all(); g4_free_pers(per_layer_buf); + return false; + } + + ggml_backend_graph_compute(backend, gfF); + ggml_backend_synchronize(backend); + + out_logits.resize((size_t)w.n_vocab); + ggml_backend_tensor_get(logits, out_logits.data(), 0, + out_logits.size() * sizeof(float)); + ggml_free(gF); + } + + // Update cache position. 
+ cache.cur_pos = S; + + ggml_gallocr_free(galloc); + cleanup_all(); + g4_free_pers(per_layer_buf); + return true; +} + } // namespace dflash::common diff --git a/dflash/src/gemma4/gemma4_internal.h b/dflash/src/gemma4/gemma4_internal.h index d34107b7..4c365ab7 100644 --- a/dflash/src/gemma4/gemma4_internal.h +++ b/dflash/src/gemma4/gemma4_internal.h @@ -79,7 +79,8 @@ struct Gemma4Weights { // Global tensors ggml_tensor * tok_embd = nullptr; // [n_embd, n_vocab] ggml_tensor * out_norm = nullptr; // [n_embd] - ggml_tensor * output = nullptr; // [n_embd, n_vocab] (lm_head) + ggml_tensor * output = nullptr; // [n_embd, n_vocab] (lm_head, may be tied to tok_embd) + ggml_tensor * rope_freqs_global = nullptr; // [head_dim/2] global rope freq factors ggml_tensor * per_layer_tok_embd = nullptr; // [n_embd_per_layer * n_layer, n_vocab] ggml_tensor * per_layer_model_proj = nullptr; // [n_embd, n_embd_per_layer * n_layer] ggml_tensor * per_layer_proj_norm = nullptr; // [n_embd_per_layer * n_layer] @@ -91,8 +92,9 @@ struct Gemma4Weights { // Architecture metadata int n_layer = 0; int n_head = 0; - int n_head_kv = 0; - int head_dim = 128; + int n_head_kv = 0; // max n_head_kv (for backward compat) + int head_dim = 128; // head_dim for SWA layers (smaller) + int head_dim_full = 128; // head_dim for full-attention layers int n_embd = 0; int n_ff = 0; // dense FFN intermediate int n_ff_exp = 0; // expert FFN intermediate @@ -103,6 +105,9 @@ struct Gemma4Weights { int n_embd_per_layer = 0; // per-layer embedding dim int n_vocab = 0; + // Per-layer head counts (Gemma4 can have variable n_head_kv per layer) + std::vector n_head_kv_per_layer; + // iSWA int sliding_window = 0; std::vector swa_layers; // true = SWA, false = full attn @@ -132,6 +137,15 @@ inline bool gemma4_has_kv(const Gemma4Weights & w, int il) { return il < (int)w.has_kv.size() && w.has_kv[il]; } +inline int gemma4_head_dim(const Gemma4Weights & w, int il) { + return gemma4_is_swa_layer(w, il) ? 
w.head_dim : w.head_dim_full; +} + +inline int gemma4_n_head_kv(const Gemma4Weights & w, int il) { + if (il < (int)w.n_head_kv_per_layer.size()) return w.n_head_kv_per_layer[il]; + return w.n_head_kv; +} + // GGUF loader bool load_gemma4_gguf(const std::string & path, ggml_backend_t backend, @@ -144,6 +158,9 @@ struct Gemma4Cache { int cur_pos = 0; int max_ctx = 0; int n_layer = 0; + int swa_size = 0; // ring-buffer size for SWA layers (= sliding_window) + int fa_window = 0; // sparse decode window for full-attn layers (0 = full) + int32_t last_tok = -1; // argmax of last prefill token (for spec-decode entry) // Only layers where has_kv[il] == true have real K/V tensors. // KV-reuse layers reference an earlier layer's cache. @@ -151,19 +168,36 @@ struct Gemma4Cache { std::vector v; std::vector kv_source; // for each layer, which layer's KV to use + // DFlash feature capture ring buffer (BF16, allocated when draft is active) + ggml_tensor * target_feat = nullptr; // [fc_in, target_feat_cap] + int target_feat_cap = 0; + int n_capture_layers = 0; + std::vector capture_layer_ids; + ggml_context * ctx = nullptr; ggml_backend_buffer_t buf = nullptr; + + // Separate context/buffer for target_feat (allocated after draft load) + ggml_context * feat_ctx = nullptr; + ggml_backend_buffer_t feat_buf = nullptr; }; bool create_gemma4_cache(ggml_backend_t backend, const Gemma4Weights & w, int max_ctx, Gemma4Cache & out); void free_gemma4_cache(Gemma4Cache & c); +// Allocate target_feat ring buffer (call after draft load determines n_capture_layers). 
+bool create_gemma4_target_feat(ggml_backend_t backend, Gemma4Cache & cache, + int n_capture_layers, int hidden_size, int cap); + // Snapshot struct Gemma4Snapshot { int cur_pos = 0; + int32_t last_tok = -1; std::vector k_snap; std::vector v_snap; + ggml_tensor * feat_snap = nullptr; // [fc_in, feat_len] + int feat_cap = 0; ggml_context * ctx = nullptr; ggml_backend_buffer_t buf = nullptr; }; @@ -172,13 +206,50 @@ void free_gemma4_snapshot(Gemma4Snapshot & s); // Forward: run a single step (prefill chunk or decode token). // Returns logits for last token. +// token_ids: raw token IDs needed for per-layer embedding lookup (may be nullptr +// if the model has no per-layer embeddings). bool gemma4_step( ggml_backend_t backend, const Gemma4Weights & w, Gemma4Cache & cache, const float * embed, + const int32_t * token_ids, + int n_tokens, + int kv_start, + std::vector & out_logits); + +// Verify batch: run forward pass returning argmax for ALL positions. +// Used by DFlash speculative decode target. +bool gemma4_verify_batch( + ggml_backend_t backend, + const Gemma4Weights & w, + Gemma4Cache & cache, + const float * embed, + const int32_t * token_ids, int n_tokens, int kv_start, + std::vector & out_argmax); + +// Project hidden states through lm_head (out_norm + output + softcap + argmax). +// Used by DFlash draft to convert draft hidden states to token IDs. +bool gemma4_project_hidden( + ggml_backend_t backend, + const Gemma4Weights & w, + const float * hidden, + int n_tokens, + std::vector & out_tokens); + +// BSA sparse-FA prefill: process the full prompt at once using block-sparse +// attention for SWA layers (flash_prefill_forward_bf16). Full-attention layers +// use dense FA. Returns logits for the last token. Populates the KV cache +// for subsequent decode. Returns false on failure. 
+bool gemma4_prefill_bsa( + ggml_backend_t backend, + const Gemma4Weights & w, + Gemma4Cache & cache, + const float * embed, // [n_embd, S] scaled + const int32_t * token_ids, // [S] (for per-layer embedding) + int S, // total prompt length std::vector & out_logits); } // namespace dflash::common diff --git a/dflash/src/gemma4/gemma4_loader.cpp b/dflash/src/gemma4/gemma4_loader.cpp index d40db53a..077d53f9 100644 --- a/dflash/src/gemma4/gemma4_loader.cpp +++ b/dflash/src/gemma4/gemma4_loader.cpp @@ -14,10 +14,13 @@ #include "internal.h" #include "dflash27b.h" +#include #include +#include #include #include #include +#include #if !defined(_WIN32) #include @@ -54,12 +57,32 @@ struct Gemma4Mmap { uint32_t get_u32_or(gguf_context * g, const char * key, uint32_t def) { int64_t id = gguf_find_key(g, key); - return (id >= 0) ? gguf_get_val_u32(g, id) : def; + if (id < 0) return def; + // Handle array type: return first element + if (gguf_get_kv_type(g, id) == GGUF_TYPE_ARRAY) { + if (gguf_get_arr_n(g, id) == 0) return def; + return ((const uint32_t *)gguf_get_arr_data(g, id))[0]; + } + return gguf_get_val_u32(g, id); } float get_f32_or(gguf_context * g, const char * key, float def) { int64_t id = gguf_find_key(g, key); - return (id >= 0) ? gguf_get_val_f32(g, id) : def; + if (id < 0) return def; + if (gguf_get_kv_type(g, id) == GGUF_TYPE_ARRAY) { + if (gguf_get_arr_n(g, id) == 0) return def; + return ((const float *)gguf_get_arr_data(g, id))[0]; + } + return gguf_get_val_f32(g, id); +} + +// Read a u32 array key into a vector (empty if not found or not an array). 
+std::vector get_u32_arr(gguf_context * g, const char * key) { + int64_t id = gguf_find_key(g, key); + if (id < 0 || gguf_get_kv_type(g, id) != GGUF_TYPE_ARRAY) return {}; + const size_t n = gguf_get_arr_n(g, id); + const uint32_t * data = (const uint32_t *)gguf_get_arr_data(g, id); + return std::vector(data, data + n); } ggml_tensor * find_tensor(ggml_context * ctx, const char * name) { @@ -96,15 +119,25 @@ bool load_gemma4_gguf(const std::string & path, const uint32_t n_ff_exp = get_u32_or(gctx, "gemma4.expert_feed_forward_length", 0); const uint32_t n_head = get_u32_or(gctx, "gemma4.attention.head_count", 0); const uint32_t n_head_kv = get_u32_or(gctx, "gemma4.attention.head_count_kv", 0); - const uint32_t head_dim = get_u32_or(gctx, "gemma4.attention.key_length", 128); - const uint32_t n_vocab = get_u32_or(gctx, "gemma4.vocab_size", 0); + const uint32_t head_dim_full = get_u32_or(gctx, "gemma4.attention.key_length", 128); + const uint32_t head_dim_swa = get_u32_or(gctx, "gemma4.attention.key_length_swa", head_dim_full); const uint32_t n_expert = get_u32_or(gctx, "gemma4.expert_count", 0); const uint32_t n_expert_used = get_u32_or(gctx, "gemma4.expert_used_count", 0); - const uint32_t n_dense_lead = get_u32_or(gctx, "gemma4.leading_dense_block_count", 1); + const uint32_t n_dense_lead = get_u32_or(gctx, "gemma4.leading_dense_block_count", 0); const uint32_t sliding_win = get_u32_or(gctx, "gemma4.attention.sliding_window", 0); const uint32_t shared_kv = get_u32_or(gctx, "gemma4.attention.shared_kv_layers", 0); const uint32_t n_embd_pl = get_u32_or(gctx, "gemma4.embedding_length_per_layer_input", 0); + // Per-layer head_count_kv (may be array or scalar) + std::vector head_kv_arr = get_u32_arr(gctx, "gemma4.attention.head_count_kv"); + + // Get vocab size from token_embd tensor shape (not always in metadata) + uint32_t n_vocab = get_u32_or(gctx, "gemma4.vocab_size", 0); + if (n_vocab == 0) { + ggml_tensor * tok_embd = find_tensor(meta_ctx, "token_embd.weight"); + 
if (tok_embd) n_vocab = (uint32_t)tok_embd->ne[1]; + } + const float rope_base_full = get_f32_or(gctx, "gemma4.rope.freq_base", 1000000.0f); const float rope_base_swa = get_f32_or(gctx, "gemma4.rope.freq_base_swa", 10000.0f); const float norm_eps = get_f32_or(gctx, "gemma4.attention.layer_norm_rms_epsilon", 1e-6f); @@ -121,7 +154,8 @@ bool load_gemma4_gguf(const std::string & path, out.n_layer = (int)n_layer; out.n_head = (int)n_head; out.n_head_kv = (int)n_head_kv; - out.head_dim = (int)head_dim; + out.head_dim = (int)head_dim_swa; // SWA head_dim (smaller) + out.head_dim_full = (int)head_dim_full; // full-attn head_dim (larger) out.n_embd = (int)n_embd; out.n_ff = (int)n_ff; out.n_ff_exp = (int)n_ff_exp; @@ -137,6 +171,18 @@ bool load_gemma4_gguf(const std::string & path, out.final_logit_softcap = logit_softcap; out.norm_eps = norm_eps; + // Per-layer n_head_kv from array (or fill scalar) + out.n_head_kv_per_layer.resize(n_layer); + if (!head_kv_arr.empty()) { + for (uint32_t il = 0; il < n_layer; ++il) { + out.n_head_kv_per_layer[il] = (int)(il < head_kv_arr.size() ? 
head_kv_arr[il] : head_kv_arr.back()); + } + } else { + for (uint32_t il = 0; il < n_layer; ++il) { + out.n_head_kv_per_layer[il] = (int)n_head_kv; + } + } + // KV sharing: last shared_kv layers reuse earlier KV out.kv_sharing_start = (int)(n_layer - shared_kv); out.has_kv.resize(n_layer); @@ -144,26 +190,38 @@ bool load_gemma4_gguf(const std::string & path, out.has_kv[il] = (int)il < out.kv_sharing_start; } - // SWA pattern from GGUF (array of bools or compute from pattern) + // SWA pattern from GGUF (array of bools indicating SWA layers) out.swa_layers.resize(n_layer, false); { int64_t pat_id = gguf_find_key(gctx, "gemma4.attention.sliding_window_pattern"); if (pat_id >= 0 && gguf_get_kv_type(gctx, pat_id) == GGUF_TYPE_ARRAY) { const size_t n = gguf_get_arr_n(gctx, pat_id); - // Pattern repeats over layers - for (uint32_t il = 0; il < n_layer; ++il) { - // 0 = full, 1 = SWA in the pattern - // Read from array, cycling if needed - if (n > 0) { - // For gemma4 the pattern is typically stored as a boolean array - // indicating whether each layer is SWA - out.swa_layers[il] = (il % n != 0); // first in group = full + if (n > 0) { + const auto arr_type = gguf_get_arr_type(gctx, pat_id); + const void * data = gguf_get_arr_data(gctx, pat_id); + for (uint32_t il = 0; il < n_layer; ++il) { + size_t idx = il % n; // cycle if pattern shorter than n_layer + if (arr_type == GGUF_TYPE_BOOL || arr_type == GGUF_TYPE_UINT8) { + out.swa_layers[il] = ((const uint8_t *)data)[idx] != 0; + } else if (arr_type == GGUF_TYPE_INT32 || arr_type == GGUF_TYPE_UINT32) { + out.swa_layers[il] = ((const uint32_t *)data)[idx] != 0; + } } } } else { - // Default: every 4th layer is full, rest SWA (like laguna) - for (uint32_t il = 0; il < n_layer; ++il) { - out.swa_layers[il] = (il % 4 != 0); + // Fallback: infer from per-layer head_kv (small kv = full, large kv = swa) + // or use default pattern (alternating 5:1) + if (!head_kv_arr.empty() && head_kv_arr.size() >= n_layer) { + uint32_t max_kv = 
*std::max_element(head_kv_arr.begin(), head_kv_arr.end()); + for (uint32_t il = 0; il < n_layer; ++il) { + // Layers with max n_head_kv are SWA, smaller are full-attn + out.swa_layers[il] = (head_kv_arr[il] == max_kv); + } + } else { + // Default: every 6th layer is full, rest SWA (5:1 pattern) + for (uint32_t il = 0; il < n_layer; ++il) { + out.swa_layers[il] = ((il % 6) != 5); + } } } } @@ -177,8 +235,8 @@ bool load_gemma4_gguf(const std::string & path, if (out.eos_id == (int32_t)miss) out.eos_id = 1; if (out.eos_chat_id == (int32_t)miss) out.eos_chat_id = -1; - std::printf("[gemma4-loader] n_layer=%u n_embd=%u head_dim=%u n_head=%u n_head_kv=%u\n", - n_layer, n_embd, head_dim, n_head, n_head_kv); + std::printf("[gemma4-loader] n_layer=%u n_embd=%u head_dim_swa=%u head_dim_full=%u n_head=%u n_head_kv=%u\n", + n_layer, n_embd, head_dim_swa, head_dim_full, n_head, n_head_kv); std::printf("[gemma4-loader] n_expert=%u used=%u dense_lead=%u sliding_window=%u\n", n_expert, n_expert_used, n_dense_lead, sliding_win); std::printf("[gemma4-loader] kv_sharing_start=%d per_layer_embd=%u logit_softcap=%g\n", @@ -221,8 +279,8 @@ bool load_gemma4_gguf(const std::string & path, tok_embd_off = offset; tok_embd_sz = sz; tok_embd_type = t->type; - // Set data pointer for metadata but don't copy to GPU - t->data = (void *)src; + // Upload to GPU (needed for tied lm_head / output) + ggml_backend_tensor_set(t, src, 0, sz); } else { ggml_backend_tensor_set(t, src, 0, sz); } @@ -244,7 +302,11 @@ bool load_gemma4_gguf(const std::string & path, // ── Assign tensors to struct ─────────────────────────────────────── out.tok_embd = find_tensor(meta_ctx, "token_embd.weight"); out.out_norm = find_tensor(meta_ctx, "output_norm.weight"); + // Gemma4 uses tied embeddings: lm_head = token_embd out.output = find_tensor(meta_ctx, "output.weight"); + if (!out.output) out.output = out.tok_embd; + // Global rope_freqs (not per-layer in this model variant) + out.rope_freqs_global = find_tensor(meta_ctx, 
"rope_freqs.weight"); out.per_layer_tok_embd = find_tensor(meta_ctx, "per_layer_tok_embd.weight"); out.per_layer_model_proj = find_tensor(meta_ctx, "per_layer_model_proj.weight"); out.per_layer_proj_norm = find_tensor(meta_ctx, "per_layer_proj_norm.weight"); @@ -266,35 +328,35 @@ bool load_gemma4_gguf(const std::string & path, L.wo = get("attn_output.weight"); L.q_norm = get("attn_q_norm.weight"); L.k_norm = get("attn_k_norm.weight"); - L.attn_post_norm = get("attn_post_norm.weight"); + L.attn_post_norm = get("post_attention_norm.weight"); L.rope_freqs = get("rope_freqs.weight"); L.ffn_norm = get("ffn_norm.weight"); L.ffn_gate = get("ffn_gate.weight"); L.ffn_up = get("ffn_up.weight"); L.ffn_down = get("ffn_down.weight"); - L.ffn_post_norm = get("ffn_post_norm.weight"); + L.ffn_post_norm = get("post_ffw_norm.weight"); - // MoE tensors - L.ffn_norm_moe = get("ffn_norm.weight"); // same tensor for both paths + // MoE tensors (only present for MoE models) + L.ffn_norm_moe = get("ffn_norm_moe.weight"); L.ffn_gate_inp = get("ffn_gate_inp.weight"); - L.ffn_gate_inp_s = get("ffn_gate_inp_shexp.weight"); + L.ffn_gate_inp_s = get("ffn_gate_inp.scale"); L.ffn_gate_up_exps = get("ffn_gate_up_exps.weight"); L.ffn_down_exps = get("ffn_down_exps.weight"); - L.ffn_down_exps_s = get("ffn_down_exps_s.weight"); + L.ffn_down_exps_s = get("ffn_down_exps.scale"); L.ffn_gate_shexp = get("ffn_gate_shexp.weight"); L.ffn_up_shexp = get("ffn_up_shexp.weight"); L.ffn_down_shexp = get("ffn_down_shexp.weight"); - L.ffn_pre_norm_2 = get("ffn_pre_norm_2.weight"); - L.ffn_post_norm_1 = get("ffn_post_norm_1.weight"); - L.ffn_post_norm_2 = get("ffn_post_norm_2.weight"); + L.ffn_pre_norm_2 = get("pre_ffw_norm_2.weight"); + L.ffn_post_norm_1 = get("post_ffw_norm_1.weight"); + L.ffn_post_norm_2 = get("post_ffw_norm_2.weight"); // Per-layer embedding L.per_layer_inp_gate = get("per_layer_inp_gate.weight"); L.per_layer_proj = get("per_layer_proj.weight"); L.per_layer_post_norm = 
get("per_layer_post_norm.weight"); - L.out_scale = get("out_scale.weight"); + L.out_scale = get("layer_output_scale.weight"); } std::printf("[gemma4-loader] loaded %d tensors, vocab=%d\n", n_tensors, (int)n_vocab); @@ -314,9 +376,6 @@ void free_gemma4_weights(Gemma4Weights & w) { bool create_gemma4_cache(ggml_backend_t backend, const Gemma4Weights & w, int max_ctx, Gemma4Cache & out) { - const int D = w.head_dim; - const int Hk = w.n_head_kv; - ggml_init_params ip{}; ip.mem_size = ggml_tensor_overhead() * (size_t)(w.n_layer * 2 + 4) + 4096; ip.no_alloc = true; @@ -327,12 +386,20 @@ bool create_gemma4_cache(ggml_backend_t backend, const Gemma4Weights & w, out.v.resize(w.n_layer, nullptr); out.kv_source.resize(w.n_layer); + // SWA layers use a ring buffer of size min(sliding_window, max_ctx). + const int swa_size = (w.sliding_window > 0 && w.sliding_window < max_ctx) + ? w.sliding_window : max_ctx; + // Determine KV source for each layer int last_kv_layer = -1; for (int il = 0; il < w.n_layer; ++il) { if (w.has_kv[il]) { - out.k[il] = ggml_new_tensor_3d(out.ctx, GGML_TYPE_F16, D, Hk, max_ctx); - out.v[il] = ggml_new_tensor_3d(out.ctx, GGML_TYPE_F16, D, Hk, max_ctx); + const int D = gemma4_head_dim(w, il); + const int Hk = gemma4_n_head_kv(w, il); + const bool is_swa = gemma4_is_swa_layer(w, il); + const int cache_len = is_swa ? 
swa_size : max_ctx; + out.k[il] = ggml_new_tensor_3d(out.ctx, GGML_TYPE_F16, D, cache_len, Hk); + out.v[il] = ggml_new_tensor_3d(out.ctx, GGML_TYPE_F16, D, cache_len, Hk); out.kv_source[il] = il; last_kv_layer = il; } else { @@ -350,21 +417,71 @@ bool create_gemma4_cache(ggml_backend_t backend, const Gemma4Weights & w, out.cur_pos = 0; out.max_ctx = max_ctx; out.n_layer = w.n_layer; + out.swa_size = swa_size; return true; } void free_gemma4_cache(Gemma4Cache & c) { + if (c.feat_buf) { ggml_backend_buffer_free(c.feat_buf); c.feat_buf = nullptr; } + if (c.feat_ctx) { ggml_free(c.feat_ctx); c.feat_ctx = nullptr; } + c.target_feat = nullptr; + c.target_feat_cap = 0; + c.n_capture_layers = 0; + c.capture_layer_ids.clear(); if (c.buf) { ggml_backend_buffer_free(c.buf); c.buf = nullptr; } if (c.ctx) { ggml_free(c.ctx); c.ctx = nullptr; } c.k.clear(); c.v.clear(); c.kv_source.clear(); c.cur_pos = 0; } +bool create_gemma4_target_feat(ggml_backend_t backend, Gemma4Cache & cache, + int n_capture_layers, int hidden_size, int cap) { + if (n_capture_layers <= 0 || hidden_size <= 0 || cap <= 0) return false; + + // Free existing feat allocation + if (cache.feat_buf) { ggml_backend_buffer_free(cache.feat_buf); cache.feat_buf = nullptr; } + if (cache.feat_ctx) { ggml_free(cache.feat_ctx); cache.feat_ctx = nullptr; } + + ggml_init_params ip{}; + ip.mem_size = ggml_tensor_overhead() * 4 + 4096; + ip.no_alloc = true; + cache.feat_ctx = ggml_init(ip); + if (!cache.feat_ctx) return false; + + const int fc_in = n_capture_layers * hidden_size; + cache.target_feat = ggml_new_tensor_2d(cache.feat_ctx, GGML_TYPE_BF16, fc_in, cap); + ggml_set_name(cache.target_feat, "gemma4_target_feat"); + + cache.feat_buf = ggml_backend_alloc_ctx_tensors(cache.feat_ctx, backend); + if (!cache.feat_buf) { + ggml_free(cache.feat_ctx); cache.feat_ctx = nullptr; + cache.target_feat = nullptr; + return false; + } + + cache.target_feat_cap = cap; + cache.n_capture_layers = n_capture_layers; + + // Compute capture 
layer IDs using floating-point linspace with rounding. + // This matches the training config (e.g., gemma4: [1,12,23,35,46,57]). + cache.capture_layer_ids.resize(n_capture_layers); + const int n_layer = cache.n_layer; + for (int k = 0; k < n_capture_layers; k++) { + cache.capture_layer_ids[k] = (int)std::round( + 1.0 + k * (double)(n_layer - 4) / (n_capture_layers - 1)); + } + + return true; +} + void free_gemma4_snapshot(Gemma4Snapshot & s) { if (s.buf) { ggml_backend_buffer_free(s.buf); s.buf = nullptr; } if (s.ctx) { ggml_free(s.ctx); s.ctx = nullptr; } s.k_snap.clear(); s.v_snap.clear(); - s.cur_pos = 0; + s.feat_snap = nullptr; + s.feat_cap = 0; + s.cur_pos = 0; + s.last_tok = -1; } } // namespace dflash::common diff --git a/dflash/src/internal.h b/dflash/src/internal.h index 6f5666df..36f2064f 100644 --- a/dflash/src/internal.h +++ b/dflash/src/internal.h @@ -230,6 +230,7 @@ struct DraftWeights { int n_embd = DFLASH27B_TARGET_HIDDEN; // 5120 int n_ff = DFLASH27B_TARGET_INTERMEDIATE; // 17408 int swa_window = 0; // sliding window size (0 = disabled) + float rope_theta = 0.0f; // RoPE frequency base (must come from GGUF) // DFlash draft-specific config (populated by loader or set by caller). int block_size = DFLASH27B_DRAFT_BLOCK_SIZE; // tokens per draft step (16 or 10) diff --git a/dflash/src/qwen3/qwen3_graph.cpp b/dflash/src/qwen3/qwen3_graph.cpp index c907546f..085a2342 100644 --- a/dflash/src/qwen3/qwen3_graph.cpp +++ b/dflash/src/qwen3/qwen3_graph.cpp @@ -509,76 +509,21 @@ bool forward_qwen3_drafter_model( } // ── Attention dispatch ── - // Three paths: - // 1. BF16 WMMA (sm_80+, HIP Phase 2): flash_prefill_forward_bf16 - // 2. F16 WMMA (Volta/Turing): flash_prefill_forward_f16 - // 3. 
ggml flash_attn_ext: fallback for all other cases auto tF0 = std::chrono::steady_clock::now(); - const bool use_bf16_fp = (Q_buf.t->type == GGML_TYPE_BF16) -#if defined(DFLASH27B_HAVE_FLASHPREFILL) || defined(DFLASH27B_HAVE_SM80_FLASHPREFILL) - && true; -#else - && false; -#endif - const bool use_f16_fp = (Q_buf.t->type == GGML_TYPE_F16) -#if defined(DFLASH27B_HAVE_VOLTA_FLASHPREFILL) || defined(DFLASH27B_HAVE_PASCAL_FLASHPREFILL) - && true; -#else - && false; -#endif - if (use_bf16_fp) { -#if defined(DFLASH27B_HAVE_FLASHPREFILL) || defined(DFLASH27B_HAVE_SM80_FLASHPREFILL) - int rc = flashprefill::flash_prefill_forward_bf16( - Q_buf.t->data, - K_curr_v[il].t->data, - V_curr_v[il].t->data, - attn_out_buf.t->data, - 1, S, H, Hk, D, scale, fp_cfg); - if (rc != 0) { - set_last_error("flash_prefill_forward_bf16 failed at layer " + std::to_string(il)); - ggml_gallocr_free(galloc); cleanup_all(); return false; - } - cudaError_t e = cudaGetLastError(); - if (e != cudaSuccess) { - set_last_error(std::string("flash_prefill cuda error: ") + cudaGetErrorString(e)); - ggml_gallocr_free(galloc); cleanup_all(); return false; - } - cudaDeviceSynchronize(); -#endif - } else if (use_f16_fp) { -#if defined(DFLASH27B_HAVE_VOLTA_FLASHPREFILL) || defined(DFLASH27B_HAVE_PASCAL_FLASHPREFILL) - int rc = flashprefill::flash_prefill_forward_f16( - Q_buf.t->data, - K_curr_v[il].t->data, - V_curr_v[il].t->data, - attn_out_buf.t->data, - 1, S, H, Hk, D, scale, fp_cfg); - if (rc != 0) { - set_last_error("flash_prefill_forward_f16 failed at layer " + std::to_string(il)); - ggml_gallocr_free(galloc); cleanup_all(); return false; - } - cudaError_t e = cudaGetLastError(); - if (e != cudaSuccess) { - set_last_error(std::string("flash_prefill-f16 cuda error: ") + cudaGetErrorString(e)); - ggml_gallocr_free(galloc); cleanup_all(); return false; - } - cudaDeviceSynchronize(); -#endif - } else { - int rc = flashprefill::flash_prefill_forward_q8( - w.backend, - Q_buf.t->data, - K_curr_v[il].t->data, - 
V_curr_v[il].t->data, - attn_out_buf.t->data, - 1, S, H, Hk, D, scale, - Q_buf.t->type, - fp_cfg); - if (rc != 0) { - set_last_error("flash_prefill_forward_q8 failed at layer " + std::to_string(il)); - ggml_gallocr_free(galloc); cleanup_all(); return false; - } + int rc = flashprefill::flash_prefill_forward( + w.backend, + Q_buf.t->data, + K_curr_v[il].t->data, + V_curr_v[il].t->data, + attn_out_buf.t->data, + 1, S, H, Hk, D, scale, + Q_buf.t->type, + fp_cfg); + if (rc != 0) { + set_last_error("flash_prefill_forward failed at layer " + std::to_string(il)); + ggml_gallocr_free(galloc); cleanup_all(); return false; } + cudaDeviceSynchronize(); auto tF1 = std::chrono::steady_clock::now(); t_fp += std::chrono::duration(tF1 - tF0).count(); if (debug_first_layer) { diff --git a/dflash/src/server/chat_template.cpp b/dflash/src/server/chat_template.cpp index 92c46588..d7c026ab 100644 --- a/dflash/src/server/chat_template.cpp +++ b/dflash/src/server/chat_template.cpp @@ -36,7 +36,8 @@ static const char QWEN3_TOOL_SUFFIX[] = ChatFormat chat_format_for_arch(const std::string & arch) { if (arch == "laguna") return ChatFormat::LAGUNA; - // qwen35, qwen3, gemma4 all use the Qwen3/ChatML format + if (arch == "gemma4") return ChatFormat::GEMMA4; + // qwen35, qwen3 use the Qwen3/ChatML format return ChatFormat::QWEN3; } @@ -150,6 +151,40 @@ std::string render_chat_template( } break; } + + case ChatFormat::GEMMA4: { + // Gemma4 format: + // <|turn>user\n{msg}\n<|turn>model\n + // System messages are prepended to the first user message. 
+ result = ""; + std::string system_content; + size_t start_idx = 0; + if (!messages.empty() && messages[0].role == "system") { + system_content = messages[0].content; + start_idx = 1; + } + + for (size_t i = start_idx; i < messages.size(); i++) { + const auto & msg = messages[i]; + std::string role = msg.role; + if (role == "assistant") role = "model"; + + result += "<|turn>"; + result += role; + result += '\n'; + // Inject system content at the start of the first user message. + if (i == start_idx && !system_content.empty() && msg.role == "user") { + result += system_content; + result += "\n\n"; + } + result += msg.content; + result += "\n"; + } + if (add_generation_prompt) { + result += "<|turn>model\n"; + } + break; + } } return result; diff --git a/dflash/src/server/chat_template.h b/dflash/src/server/chat_template.h index 5f35f492..e1ea2134 100644 --- a/dflash/src/server/chat_template.h +++ b/dflash/src/server/chat_template.h @@ -24,6 +24,7 @@ struct ChatMessage { enum class ChatFormat { QWEN3, // <|im_start|>role\n...<|im_end|>\n LAGUNA, // <|begin_of_sentence|><|User|>...<|Assistant|> + GEMMA4, // <|turn>role\n...\n }; // Render chat messages into the model-specific prompt string. diff --git a/dflash/src/server/http_server.cpp b/dflash/src/server/http_server.cpp index 8188fd07..8be6ec7c 100644 --- a/dflash/src/server/http_server.cpp +++ b/dflash/src/server/http_server.cpp @@ -443,7 +443,6 @@ bool HttpServer::route_request(int fd, const HttpRequest & hr) { true, enable_thinking, tools_json); req.prompt_tokens = tokenizer_.encode(rendered); - // Detect if prompt ends with (model will start in reasoning mode). if (enable_thinking) { size_t end = rendered.size(); @@ -691,11 +690,35 @@ void HttpServer::worker_loop() { // Skip EOS/EOT/special tokens — don't forward to SSE. int32_t eos = tokenizer_.eos_id(); - if (token == eos) return true; - // Also skip common Qwen3 special tokens by checking if the raw - // token text starts with '<|' (e.g. 
<|im_end|>, <|im_start|>). + int32_t eot = tokenizer_.eos_chat_id(); + if (token == eos || token == eot) return true; + const std::string & raw = tokenizer_.raw_token(token); + + // Gemma4 thinking channel: map <|channel> → , \n + if (raw == "<|channel>") { + if (req.stream) { + auto chunks = emitter.emit_token(""); + for (const auto & chunk : chunks) + if (!send_all(fd, chunk.data(), chunk.size())) { client_disconnected = true; return false; } + } + return true; + } + if (raw == "") { + if (req.stream) { + auto chunks = emitter.emit_token("\n"); + for (const auto & chunk : chunks) + if (!send_all(fd, chunk.data(), chunk.size())) { client_disconnected = true; return false; } + } + return true; + } + + // Skip other special tokens (starting with <|, or any <...> except byte-fallback) if (raw.size() >= 2 && raw[0] == '<' && raw[1] == '|') return true; + if (raw.size() >= 2 && raw[0] == '<' && raw.back() == '>') { + if (!(raw.size() == 6 && raw[1] == '0' && raw[2] == 'x')) + return true; + } std::string text = tokenizer_.token_text(token); @@ -797,7 +820,15 @@ void HttpServer::worker_loop() { for (int32_t tok : result.tokens) { const std::string & raw = tokenizer_.raw_token(tok); if (tok == tokenizer_.eos_id()) continue; + if (tok == tokenizer_.eos_chat_id()) continue; + // Gemma4 channel → think mapping + if (raw == "<|channel>") { emitter.emit_token(""); continue; } + if (raw == "") { emitter.emit_token("\n"); continue; } if (raw.size() >= 2 && raw[0] == '<' && raw[1] == '|') continue; + if (raw.size() >= 2 && raw[0] == '<' && raw.back() == '>') { + if (!(raw.size() == 6 && raw[1] == '0' && raw[2] == 'x')) + continue; + } std::string text = tokenizer_.token_text(tok); emitter.emit_token(text); } diff --git a/dflash/src/server/http_server.h b/dflash/src/server/http_server.h index 24d075d8..832e7602 100644 --- a/dflash/src/server/http_server.h +++ b/dflash/src/server/http_server.h @@ -100,6 +100,9 @@ class HttpServer { // Set the optional pflash drafter tokenizer. 
void set_drafter_tokenizer(Tokenizer * tok) { drafter_tokenizer_ = tok; } + // Set the chat template format (detected from model arch). + void set_chat_format(ChatFormat fmt) { chat_format_ = fmt; } + // Start listening. Blocks until shutdown() is called. int run(); diff --git a/dflash/src/server/prefix_cache.cpp b/dflash/src/server/prefix_cache.cpp index 72ceae72..26657b27 100644 --- a/dflash/src/server/prefix_cache.cpp +++ b/dflash/src/server/prefix_cache.cpp @@ -25,6 +25,17 @@ bool resolve_chat_markers(const Tokenizer & tok, ChatMarkers & out) { return true; } + // Try Gemma family: <|turn> (start) and (end) are single tokens. + auto turn_start = tok.encode("<|turn>"); + auto turn_end = tok.encode(""); + if (turn_start.size() == 1 && turn_end.size() == 1) { + out.family = "gemma"; + out.sys_role_prefix = {turn_start[0]}; + out.end_msg_seqs = {{turn_end[0]}}; + out.next_role_starts = {{turn_start[0]}}; + return true; + } + // Try Laguna family: XML-style markers. auto start_sys = tok.encode(""); auto end_sys = tok.encode(""); diff --git a/dflash/src/server/prefix_cache.h b/dflash/src/server/prefix_cache.h index cb0c551b..b9ec001f 100644 --- a/dflash/src/server/prefix_cache.h +++ b/dflash/src/server/prefix_cache.h @@ -26,14 +26,14 @@ namespace dflash::common { // ─── Chat marker detection ────────────────────────────────────────────── struct ChatMarkers { - std::string family; // "qwen" or "laguna" + std::string family; // "qwen", "gemma", or "laguna" // Token sequences for boundary detection std::vector sys_role_prefix; std::vector> end_msg_seqs; std::vector> next_role_starts; }; -// Resolve chat markers from the tokenizer (detects Qwen vs Laguna family). +// Resolve chat markers from the tokenizer (detects Qwen, Gemma, or Laguna family). bool resolve_chat_markers(const Tokenizer & tok, ChatMarkers & out); // Find all turn-boundary cut points in a token stream. 
diff --git a/dflash/src/server/server_main.cpp b/dflash/src/server/server_main.cpp index 357c32a0..69cfb752 100644 --- a/dflash/src/server/server_main.cpp +++ b/dflash/src/server/server_main.cpp @@ -12,10 +12,12 @@ // [--max-tokens 4096] [--target-device auto:0] #include "http_server.h" +#include "chat_template.h" #include "common/backend_factory.h" #include "common/gguf_inspect.h" #include "common/peer_access.h" +#include #include #include #include @@ -331,6 +333,7 @@ int main(int argc, char ** argv) { // Create backend. g_peer_access_opt_in = bargs.device.peer_access; std::fprintf(stderr, "[server] creating backend...\n"); + const std::string arch = detect_arch(bargs.model_path); auto backend = create_backend(bargs); if (!backend) { std::fprintf(stderr, "[server] backend creation failed\n"); @@ -395,6 +398,7 @@ int main(int argc, char ** argv) { std::fprintf(stderr, "[server] ╰─────────────────────────────────────────────────────╯\n\n"); HttpServer server(*backend, tokenizer, sconfig); + server.set_chat_format(chat_format_for_arch(arch)); g_server = &server; std::signal(SIGTERM, signal_handler); std::signal(SIGINT, signal_handler); diff --git a/dflash/src/server/tokenizer.cpp b/dflash/src/server/tokenizer.cpp index 1f538682..5ff4b1a7 100644 --- a/dflash/src/server/tokenizer.cpp +++ b/dflash/src/server/tokenizer.cpp @@ -364,34 +364,76 @@ static std::string encode_gpt2_bpe(const std::string & text) { std::vector Tokenizer::bpe_encode_piece(const std::string & piece) const { if (piece.empty()) return {}; - // Convert raw text to GPT-2 byte encoding for vocab lookup. - // The GGUF vocab stores tokens in GPT-2's byte-to-unicode encoding, - // so " world" becomes "Ġworld" (space 0x20 → Ġ U+0120). - std::string encoded = encode_gpt2_bpe(piece); - - // Start with individual bytes/chars as initial symbols. - // Each symbol is a string that we look up in the vocab. std::vector symbols; - // Try to find the encoded piece as a single token first. 
- auto it = token_to_id_.find(encoded); - if (it != token_to_id_.end()) { - return { it->second }; - } + if (is_sentencepiece_) { + // SentencePiece: replace leading space with ▁, tokens are raw UTF-8. + std::string sp_piece; + sp_piece.reserve(piece.size() + 2); + size_t start = 0; + if (!piece.empty() && piece[0] == ' ') { + sp_piece += "\xe2\x96\x81"; // ▁ (U+2581) + start = 1; + } + sp_piece += piece.substr(start); + // Replace any remaining spaces with ▁ + std::string encoded; + encoded.reserve(sp_piece.size()); + for (char c : sp_piece) { + if (c == ' ') { + encoded += "\xe2\x96\x81"; + } else { + encoded += c; + } + } - // Split into individual GPT-2-encoded bytes as initial BPE symbols. - // Each raw byte becomes a single GPT-2 Unicode character (possibly multi-byte UTF-8). - for (size_t i = 0; i < piece.size(); i++) { - std::string sym = byte_to_gpt2_unicode((uint8_t)piece[i]); - auto sit = token_to_id_.find(sym); - if (sit != token_to_id_.end()) { - symbols.push_back(sym); - } else { - // Byte-fallback: use <0xNN> tokens - char buf[8]; - std::snprintf(buf, sizeof(buf), "<0x%02X>", - (unsigned)(uint8_t)piece[i]); - symbols.push_back(buf); + // Try whole piece as single token. + auto it = token_to_id_.find(encoded); + if (it != token_to_id_.end()) { + return { it->second }; + } + + // Split into individual UTF-8 characters as initial BPE symbols. + const char * p = encoded.c_str(); + const char * end = p + encoded.size(); + while (p < end) { + int cplen; + utf8_decode(p, (size_t)(end - p), &cplen); + if (cplen <= 0) cplen = 1; + std::string sym(p, cplen); + auto sit = token_to_id_.find(sym); + if (sit != token_to_id_.end()) { + symbols.push_back(sym); + } else { + // Byte-fallback: <0xNN> + char buf[8]; + std::snprintf(buf, sizeof(buf), "<0x%02X>", (unsigned)(uint8_t)*p); + symbols.push_back(buf); + } + p += cplen; + } + } else { + // GPT-2 BPE: convert raw text to GPT-2 byte encoding for vocab lookup. 
+ std::string encoded = encode_gpt2_bpe(piece); + + // Try to find the encoded piece as a single token first. + auto it = token_to_id_.find(encoded); + if (it != token_to_id_.end()) { + return { it->second }; + } + + // Split into individual GPT-2-encoded bytes as initial BPE symbols. + for (size_t i = 0; i < piece.size(); i++) { + std::string sym = byte_to_gpt2_unicode((uint8_t)piece[i]); + auto sit = token_to_id_.find(sym); + if (sit != token_to_id_.end()) { + symbols.push_back(sym); + } else { + char buf[8]; + std::snprintf(buf, sizeof(buf), "<0x%02X>", + (unsigned)(uint8_t)piece[i]); + symbols.push_back(buf); + } } } @@ -538,6 +580,18 @@ bool Tokenizer::load_from_gguf(const char * model_path) { added_tokens_.size()); } + // Detect tokenizer model type (sentencepiece vs bpe). + int model_key = gguf_find_key(gctx, "tokenizer.ggml.model"); + if (model_key >= 0) { + const char * model = gguf_get_val_str(gctx, model_key); + // SentencePiece models store tokens as raw UTF-8 with ▁ for space. + // GPT-2/BPE models use byte-level Unicode encoding. + if (model && (std::strcmp(model, "llama") == 0 || + std::strncmp(model, "gemma", 5) == 0)) { + is_sentencepiece_ = true; + } + } + // Detect pre-tokenizer type. int pre_key = gguf_find_key(gctx, "tokenizer.ggml.pre"); if (pre_key >= 0) { @@ -564,12 +618,18 @@ bool Tokenizer::load_from_gguf(const char * model_path) { auto eot = token_to_id_.find("<|im_end|>"); if (eot != token_to_id_.end()) eos_chat_id_ = eot->second; } + if (eos_chat_id_ < 0) { + // Gemma4 uses as end-of-turn. + auto eot = token_to_id_.find(""); + if (eot != token_to_id_.end()) eos_chat_id_ = eot->second; + } gguf_free(gctx); - std::fprintf(stderr, "[tokenizer] loaded vocab=%d merges=%zu bos=%d eos=%d eot=%d pre=%s\n", + std::fprintf(stderr, "[tokenizer] loaded vocab=%d merges=%zu bos=%d eos=%d eot=%d pre=%s sp=%s\n", n_vocab, merge_rank_.size(), bos_id_, eos_id_, eos_chat_id_, - pre_type_ == PreTokenizer::QWEN35 ? 
"qwen35" : "qwen2"); + pre_type_ == PreTokenizer::QWEN35 ? "qwen35" : "qwen2", + is_sentencepiece_ ? "yes" : "no"); return true; } @@ -686,11 +746,33 @@ std::string Tokenizer::token_text(int32_t id) const { } } - // Special tokens (e.g. <|im_start|>) — return as-is. + // Special tokens (e.g. <|im_start|>, ) — return as-is. if (!tok.empty() && tok[0] == '<' && tok.back() == '>') { return tok; } + if (is_sentencepiece_) { + // SentencePiece: tokens are raw UTF-8 with ▁ (U+2581) for space. + std::string out; + out.reserve(tok.size()); + const char * p = tok.c_str(); + const char * end = p + tok.size(); + while (p < end) { + // ▁ is 3 bytes: 0xE2 0x96 0x81 + if (end - p >= 3 && + (uint8_t)p[0] == 0xE2 && + (uint8_t)p[1] == 0x96 && + (uint8_t)p[2] == 0x81) { + out.push_back(' '); + p += 3; + } else { + out.push_back(*p); + p++; + } + } + return out; + } + // Decode GPT-2 byte-level BPE encoding → raw bytes. return decode_gpt2_bpe(tok); } diff --git a/dflash/src/server/tokenizer.h b/dflash/src/server/tokenizer.h index f28dfd9f..5484fa47 100644 --- a/dflash/src/server/tokenizer.h +++ b/dflash/src/server/tokenizer.h @@ -46,6 +46,7 @@ class Tokenizer { // ─── Special tokens ────────────────────────────────────────────── int32_t eos_id() const { return eos_id_; } + int32_t eos_chat_id() const { return eos_chat_id_; } int32_t bos_id() const { return bos_id_; } int32_t vocab_size() const { return (int32_t)id_to_token_.size(); } @@ -80,6 +81,10 @@ class Tokenizer { // Pre-tokenizer type enum class PreTokenizer { QWEN2, QWEN35 }; PreTokenizer pre_type_ = PreTokenizer::QWEN35; + + // Decode mode: SentencePiece tokens use UTF-8 with ▁ for space; + // GPT-2/BPE tokens use byte-level Unicode encoding. + bool is_sentencepiece_ = false; }; } // namespace dflash::common