Gemma4: Full DFlash Integration (Speculative Decode + BSA Prefill + Prefix Cache)#232
Conversation
There was a problem hiding this comment.
3 issues found across 25 files
Reply with feedback, questions, or to request a fix.
Re-trigger cubic
|
Thanks for the PR @howard0su . Gemma4 never actually enters the BSA sparse-FA path right now. Every layer falls through the guard and uses dense WMMA instead. Is that expected for this PR? In case it is, no worries. It can be added later. Only if needed here something to consider:
Cheap fix: Both findings reproduced against PR head |
|
good catch. I will check it. I only have a 2080ti which doesn't support BSA. Will try to find a env to debug. |
There was a problem hiding this comment.
1 issue found across 4 files (changes from recent commits).
Reply with feedback, questions, or to request a fix.
Re-trigger cubic
Loader fixes: - Handle array-typed metadata (head_count_kv is per-layer array) - Fallback n_vocab from token_embd.weight tensor shape - Default missing keys (expert_count, etc.) to 0 - Separate head_dim_full (512) and head_dim_swa (256) - Per-layer n_head_kv_per_layer vector from GGUF array - SWA pattern: read bool/uint8 array or infer from head_kv - Tied embeddings: output = tok_embd when output.weight absent - Tensor name mapping: post_attention_norm, post_ffw_norm, layer_output_scale - Global rope_freqs_global tensor support Graph fixes: - Per-layer head_dim and n_head_kv via helper functions - FA mask padding to 256 (FATTN_KQ_STRIDE) for CUDA compat - Use global rope_freqs for full-attn layers Cache: - Per-layer KV allocation with correct dimensions Validated: load + prefill + decode + snapshot + restore all pass on gemma-4-31B-it-Q4_K_M.gguf (RTX 2080 Ti). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Add Gemma4DFlashTarget class implementing the DFlashTarget interface: - verify_batch: full forward with all-token argmax via gemma4_verify_batch - snapshot_kv / restore_kv: full KV cache save/restore for rollback - embed_tokens: CPU embedder with sqrt(n_embd) scaling - project_hidden_to_tokens: lm_head projection via gemma4_project_hidden - capture_layer_ids: evenly-spaced 5 layers (1, 15, 29, 43, 57) - mask_token_id: 0 (padding token) New graph functions: - gemma4_verify_batch(): like gemma4_step but returns all-position argmax - gemma4_project_hidden(): out_norm + lm_head + softcap + argmax Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Critical fixes for Gemma4 model inference: - Fix kq_scale: Gemma4 uses self.scaling=1.0 (not 1/sqrt(head_dim)) because Q/K already get per-head RMS norm. This was the root cause of garbage output (repeated token generation). - Add SentencePiece tokenizer support: Gemma4 tokens are raw UTF-8 with U+2581 for space, not GPT-2 byte-level encoding. Detects mode from tokenizer.ggml.model GGUF key. Handles encode (space->▁, UTF-8 char splitting) and decode (▁->space) correctly. - Fix KV cache layout: [D, max_ctx, Hk] matching Qwen35 convention, with per-head strided snapshot save/restore. - Add Gemma4 chat template: <bos><|turn>user\n...<turn|>\n<|turn>model\n - Map Gemma4 thinking channel (<|channel>...<channel|>) to existing <think>...</think> reasoning system for proper content separation. - Add eos_chat_id detection for <turn|> token (id 106). - Fix special token filtering in both streaming and non-streaming paths. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
park() now frees snapshots, KV cache, and model weights (releasing GPU memory). unpark() reloads weights from disk and recreates the KV cache. Also adds parked guards to generate(), restore_and_generate(), and snapshot_save() to prevent use while model is parked. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
G5: SWA layers now allocate min(sliding_window, max_ctx) KV cache instead of full max_ctx. Ring-buffer write (kv_start % swa_size) and ring-aware attention mask enable bounded memory for sliding-window layers. Prefill chunks are capped to avoid ring wrap. G6: Added fa_window config for sparse decode. Full-attention layers limit their FA read to the last fa_window positions during decode, reducing compute at long contexts. G3: Ported PFlash compress pipeline from Qwen35. Parks target, lazy-loads Qwen3-0.6B drafter, runs score_and_compress, emits surviving tokens, unparks. Drafter stays resident (~1.4 GB). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Add target_feat ring buffer to Gemma4Cache for feature capture - Add feature capture nodes to build_gemma4_layer() (both step and verify) - Add draft model loading with metadata override (GGUF has wrong dimensions) - Infer n_capture_layers from fc weight shape (6 for Gemma4, not 5 from metadata) - Port do_spec_decode() loop from qwen35 backend - Wire spec-decode into generate() and restore_and_generate() (temp==0 only) - Sync captured features to DraftFeatureMirror after each prefill chunk - Store last_tok during prefill for spec-decode entry - Pass draft_path/draft_gpu/draft_ctx_max through BackendArgs to Gemma4BackendConfig - Clean up draft resources in shutdown() Tested: AR decode produces correct output, spec-decode pipeline runs end-to-end with 9.1 tok/s throughput. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Gemma4 uses <|turn> / <turn|> as single-token turn delimiters. Previously it incorrectly fell through to the Laguna family check because <system>/<user>/etc. would encode to non-empty sequences. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Three root-cause fixes identified from the HuggingFace model card
(z-lab/gemma-4-31B-it-DFlash config.json):
1. mask_token_id: use 4 instead of 0 — the draft model was trained
with token 4 as the mask/padding token.
2. capture_layer_ids: replace integer-truncation formula with
floating-point linspace + rounding. For 60 layers / 6 captures:
old: {1,12,23,34,45,56}, correct: {1,12,23,35,46,57}.
3. embed_tokens: remove sqrt(n_embd) scaling — the draft model
expects raw unscaled embeddings (same as qwen35 convention).
Also removes debug fprintf statements added during investigation.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Add causal attention mask for SWA layers in the draft model (layers 0-3 are sliding-window with causal masking, layer 4 is full non-causal). The draft was trained this way; running all-non-causal let future MASK embeddings leak into earlier positions, hurting acceptance rate. - Read rope_theta from draft GGUF metadata instead of hardcoded 10M constant (Gemma4 draft uses 1M, not 10M like Qwen3.5). - Remove double-normalization: gemma4_project_hidden now skips out_norm since the draft already applies its own final norm layer. - Scale embed_tokens by sqrt(n_embd) in DFlashTarget to match Gemma4 convention. - Set swa_window=2048 and mark layers[0..3].is_swa after draft GGUF loading. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Match the pattern from attn_masks.h — create the causal mask tensor as GGML_TYPE_F16 directly and fill with uint16_t values (0x0000 for attend, 0xFC00 for -inf). This eliminates the intermediate ggml_cast op in the draft graph and reduces memory usage. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The implementation file now matches its header (draft_graph.h), eliminating confusion with the similarly-named common/dflash_draft_graph.cpp orchestrator. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The 10M default was Qwen3.5-specific and silently wrong for other models (e.g. Gemma4 uses 1M). Now rope_theta must come from the draft GGUF metadata; a warning is printed if the key is missing. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Gemma4 is a pure transformer — after verify, KV entries at accepted positions are already correct (causal masking guarantees independence from rejected tokens). Replace the expensive snapshot → verify → restore → replay pattern with: verify(16 tokens) → truncate KV → bonus(1 token) This eliminates: - 2x full KV cache copies (60 layers × K + V each direction) - The replay forward pass (~9 tokens through 60 layers) Measured ~2.2x speedup on RTX 2080 Ti (9.5 → 21 tok/s). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…d dispatch - Implement gemma4_prefill_bsa() for per-layer BSA prefill using flash_prefill_forward for SWA layers (head_dim=128) with dense FA fallback for full-attention layers (head_dim=256). - Write KV cache during Graph A (ring-buffer aware for SWA layers). - Add GGML_ASSERT guard for swa_size > 0 before modulo operation. - Add flash_prefill_forward() unified dispatch to flashprefill.h that selects bf16/f16/q8 kernel based on compile flags + buffer type. - Simplify Qwen3 attention dispatch to use the unified function. - Remove duplicated ifdef boilerplate from both model implementations. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
After restoring KV from a snapshot, do_prefill only syncs the feature mirror for the delta tokens [snap_pos..committed). The positions [0..snap_pos) in the mirror retain stale data from the previous request's decode phase (which may have diverged from the current prompt context after the ring buffer wraps). Fix: call draft_feature_mirror_sync_tail after restore to resync the entire [0..committed) feature range from cache_.target_feat to the mirror. This ensures the draft model sees consistent features and maintains high acceptance rate (AL) during speculative decoding. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> gemma4: save/restore target_feat in prefix cache snapshot Matching Qwen35's approach: save target_feat (BF16 feature ring buffer) and last_tok as part of the KV snapshot. On restore, target_feat is copied back to GPU before the delta prefill + feature mirror resync. Previously, only K/V tensors were snapshotted. After restore, the feature mirror contained stale data from the previous request's decode phase, causing the draft model to make poor predictions and halving speculative decode acceptance rate (52% → 24%). With this fix, the full feature state is correctly restored, and the subsequent draft_feature_mirror_sync_tail ensures the mirror matches. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Three fixes for gemma-4-26B-A4B-it (unsloth UD-Q4_K_M).
1. gemma4_graph.cpp:116 — GGML_ASSERT(ggml_is_contiguous(src0))
crash in ggml_cuda_op_gelu. gate_e and up_e are strided
ggml_view_3d halves of fused gate_up_e; CUDA gelu requires
contiguous src. Insert ggml_cont before ggml_gelu.
2. gemma4_loader.cpp tensor name mismatches with actual GGUF
metadata (silently loaded null → MoE produced gibberish):
ffn_gate_inp_shexp.weight → ffn_gate_inp.scale
ffn_down_exps_s.weight → ffn_down_exps.scale
ffn_pre_norm_2.weight → pre_ffw_norm_2.weight
ffn_post_norm_1.weight → post_ffw_norm_1.weight
ffn_post_norm_2.weight → post_ffw_norm_2.weight
3. leading_dense_block_count default 1 → 0. Gemma-4-26B-A4B GGUF
does not store this key; old default skipped MoE on layer 0,
running shared-expert only and corrupting downstream.
Verified: 'What is 2+2?' returns '2 + 2 = 4' on lucebox2 RTX 3090.
Co-Authored-By: WOZCODE <contact@withwoz.com>
# Conflicts: # dflash/src/gemma4/gemma4_backend.h
PR Luce-Org#236 (placement refactor) replaced BackendArgs::draft_gpu with BackendArgs::draft_device but only updated the qwen35 caller. PR Luce-Org#232 (gemma4 DFlash spec decode) merged after Luce-Org#236 and re-introduced args.draft_gpu in the gemma4 branch, breaking compilation of dflash_common on main. Caught by PR Luce-Org#252 CI build.
Adds complete Gemma4 support to the DFlash inference pipeline — speculative decode, BSA sparse-FA prefill, SWA ring-buffer KV cache, and prefix cache with correct feature snapshot/restore.
Key Changes
Speculative Decode (DFlash target)
BSA Sparse-FA Prefill
Prefix Cache
SWA Ring-Buffer & Architecture
Misc