Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
d5f32bf
gemma4: fix loader + graph for actual GGUF format
howard0su May 20, 2026
1315311
gemma4: implement DFlashTarget for speculative decode (G4)
howard0su May 20, 2026
9b26a2b
gemma4: fix attention scale, tokenizer decode, and server integration
howard0su May 20, 2026
f99ff75
gemma4: implement real park/unpark for VRAM management
howard0su May 20, 2026
c4a7ba6
gemma4: implement G5 SWA ring-buffer, G6 fa_window, G3 compress
howard0su May 20, 2026
1bfb720
gemma4: wire DFlash speculative decode into Gemma4 backend
howard0su May 20, 2026
2065995
prefix_cache: add Gemma family detection for chat markers
howard0su May 20, 2026
78aaa06
gemma4: fix DFlash spec-decode acceptance rate
howard0su May 20, 2026
f102502
gemma4 dflash: fix SWA causal masking and rope_theta
howard0su May 20, 2026
106a59e
draft: use F16 mask directly, remove unnecessary F32 cast
howard0su May 20, 2026
85bc4c3
draft: rename draft_dflash_graph.cpp → draft_graph.cpp to match header
howard0su May 20, 2026
03aeda5
draft: remove DFLASH27B_ROPE_THETA constant, read from GGUF only
howard0su May 20, 2026
9fe0ce4
gemma4 spec-decode: replace snapshot/replay with KV truncation
howard0su May 20, 2026
f854a11
gemma4: add BSA sparse-FA prefill path + unified flash_prefill_forwar…
howard0su May 21, 2026
3c433a9
gemma4: full feature mirror resync after prefix cache restore
howard0su May 21, 2026
d399796
gemma4: check graph_compute return in prefill_bsa per-layer embed
howard0su May 21, 2026
dfaf999
gemma4: fix namespace dflash27b -> dflash::common after rebase
howard0su May 21, 2026
d3720b6
gemma4: fix MoE GELU contig + loader tensor name mismatches
May 22, 2026
a013349
Merge branch 'main' into gemma4
davide221 May 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dflash/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ add_library(dflash_common STATIC
src/qwen35/qwen35_target_graph.cpp
src/draft/draft_gguf_loader.cpp
src/draft/draft_safetensors_loader.cpp
src/draft/draft_dflash_graph.cpp
src/draft/draft_graph.cpp
src/qwen3/qwen3_drafter.cpp
src/qwen3/qwen3_loader.cpp
src/qwen3/qwen3_graph.cpp
Expand All @@ -225,6 +225,7 @@ add_library(dflash_common STATIC
src/gemma4/gemma4_graph.cpp
src/gemma4/gemma4_backend.cpp
src/gemma4/gemma4_daemon.cpp
src/gemma4/gemma4_dflash_target.cpp
src/flashprefill_q8.cpp
src/kv_cache.cpp
src/kv_quant.cpp
Expand Down
2 changes: 1 addition & 1 deletion dflash/docs/SPEC_PREFILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ src/
qwen35_target_graph.cpp Qwen3.5/3.6 target graph (ggml)
gguf_target_loader.cpp Qwen3.5 target GGUF loader
draft/ Special DFlash draft model code
draft_dflash_graph.cpp DFlash speculative draft head
draft_graph.cpp DFlash speculative draft head
draft_gguf_loader.cpp Draft GGUF loader
draft_safetensors_loader.cpp Draft safetensors loader
laguna/ Laguna target + daemon model code
Expand Down
3 changes: 1 addition & 2 deletions dflash/include/dflash27b.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,12 @@ extern "C" {
// Qwen3.5-27B qwen35 hybrid uses 24 Q heads, 4 KV heads, 256 head_dim, which
// live in `src/internal.h` (n_embd_head_k/v, N_HEAD, N_HEAD_KV). Naming is
// historical — do not change without updating draft_safetensors_loader.cpp +
// draft_dflash_graph.cpp which consume these as draft-side constants.
// draft_graph.cpp which consume these as draft-side constants.
#define DFLASH27B_TARGET_N_HEADS 32
#define DFLASH27B_TARGET_N_KV_HEADS 8
#define DFLASH27B_TARGET_HEAD_DIM 128
#define DFLASH27B_TARGET_INTERMEDIATE 17408
#define DFLASH27B_TARGET_VOCAB 248320
#define DFLASH27B_ROPE_THETA 10000000.0f
#define DFLASH27B_RMS_EPS 1e-6f

#define DFLASH27B_DRAFT_LAYERS 5
Expand Down
11 changes: 7 additions & 4 deletions dflash/src/common/backend_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,13 @@ std::unique_ptr<ModelBackend> create_backend(const BackendArgs & args) {

} else if (arch == "gemma4") {
Gemma4BackendConfig gcfg;
gcfg.model_path = args.model_path;
gcfg.device = args.device;
gcfg.stream_fd = args.stream_fd;
gcfg.chunk = args.chunk;
gcfg.model_path = args.model_path;
gcfg.draft_path = args.draft_path;
gcfg.draft_gpu = args.draft_gpu;
gcfg.draft_ctx_max = args.draft_ctx_max;
gcfg.device = args.device;
gcfg.stream_fd = args.stream_fd;
gcfg.chunk = args.chunk;

auto backend = std::make_unique<Gemma4Backend>(gcfg);
if (!backend->init()) {
Expand Down
61 changes: 60 additions & 1 deletion dflash/src/common/dflash_draft_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,26 @@
#include "draft/draft_graph.h" // DraftGraphInputs, DraftGraphOutputs, build_draft_graph

#include "ggml-alloc.h"
#include "ggml-backend.h"

#include <cstdint>
#include <cstdio>
#include <vector>

namespace dflash::common {

// Minimum alignment required by ggml flash_attn_ext for mask rows.
static constexpr int MASK_KV_PAD = 32;

static inline int mask_align_up(int x, int a) { return ((x + a - 1) / a) * a; }

// Check whether any layer in the draft is SWA.
static bool draft_has_swa_layers(const DraftWeights & dw) {
for (int i = 0; i < dw.n_layer; i++)
if (dw.layers[i].is_swa) return true;
return false;
}

// Build draft graph at a given ctx_len into sg. Does NOT touch sg.alloc.
// mirror_view: if true, uses a view into mirror->target_feat at slot0.
static bool build_draft_graph_internal(
Expand Down Expand Up @@ -56,6 +71,21 @@ static bool build_draft_graph_internal(
ggml_set_name(sg.positions_k, "positions_k");
ggml_set_input(sg.positions_k);

// Causal mask for SWA layers (if any).
// Shape: [kv_pad, q_len] F16 (directly, no cast needed — matches attn_masks.h pattern).
sg.attn_mask = nullptr;
const bool has_swa = draft_has_swa_layers(dw);
if (has_swa) {
// SWA layers' effective KV length (windowed or full ctx)
const bool swa_active = dw.swa_window > 0 && ctx_len > dw.swa_window;
const int eff_ctx = swa_active ? dw.swa_window : ctx_len;
const int eff_total_k = eff_ctx + q_len;
const int kv_pad = mask_align_up(eff_total_k, MASK_KV_PAD);
sg.attn_mask = ggml_new_tensor_2d(sg.ctx, GGML_TYPE_F16, kv_pad, q_len);
ggml_set_name(sg.attn_mask, "causal_mask_swa");
ggml_set_input(sg.attn_mask);
}

sg.gf = ggml_new_graph_custom(sg.ctx, 4096, false);

DraftGraphInputs gi{};
Expand All @@ -65,6 +95,7 @@ static bool build_draft_graph_internal(
gi.positions_q = sg.positions;
gi.positions_k = sg.positions_k;
gi.lm_head = lm_head;
gi.causal_mask_swa = sg.attn_mask;
DraftGraphOutputs go = build_draft_graph(sg.ctx, dw, gi);
sg.hidden_states = go.hidden_states;
sg.logits = go.logits;
Expand Down Expand Up @@ -125,7 +156,35 @@ bool build_draft_step(
return false;
}

return ggml_gallocr_alloc_graph(sg.alloc, sg.gf);
if (!ggml_gallocr_alloc_graph(sg.alloc, sg.gf)) {
return false;
}

// Fill causal mask data for SWA layers (after allocation gives memory to the tensor).
if (sg.attn_mask) {
const int q_len = dw.block_size;
const bool swa_active = dw.swa_window > 0 && ctx_len > dw.swa_window;
const int eff_ctx = swa_active ? dw.swa_window : ctx_len;
const int eff_total_k = eff_ctx + q_len;
const int kv_pad = mask_align_up(eff_total_k, MASK_KV_PAD);

// Build causal mask in F16 directly (same pattern as attn_masks.h):
// Context keys (k < eff_ctx): always visible.
// Noise keys (k = eff_ctx + j): visible if j <= q (causal).
static constexpr uint16_t ZERO = 0x0000;
static constexpr uint16_t NEG_INF = 0xFC00;
std::vector<uint16_t> mask_data((size_t)kv_pad * q_len, NEG_INF);
for (int q = 0; q < q_len; q++) {
for (int k = 0; k < eff_ctx; k++)
mask_data[(size_t)q * kv_pad + k] = ZERO;
for (int j = 0; j <= q; j++)
mask_data[(size_t)q * kv_pad + (eff_ctx + j)] = ZERO;
}
ggml_backend_tensor_set(sg.attn_mask, mask_data.data(), 0,
sizeof(uint16_t) * mask_data.size());
}

return true;
}

} // namespace dflash::common
12 changes: 11 additions & 1 deletion dflash/src/draft/draft_gguf_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// on the CUDA backend.
//
// This is the Q8_0-quantized counterpart of draft_safetensors_loader.cpp. The
// draft graph builder (draft_dflash_graph.cpp) doesn't care about tensor storage
// draft graph builder (draft_graph.cpp) doesn't care about tensor storage
// types — ggml's ggml_mul_mat handles Q8_0 × F32 dequantization transparently.
//
// GGUF arch: "qwen35-dflash-draft" (from convert_dflash_to_gguf.py /
Expand Down Expand Up @@ -159,6 +159,12 @@ bool load_draft_gguf(const std::string & path,
std::snprintf(key, sizeof(key), "%s.%s", A, suffix);
return get_u32_or(gctx, key, fallback);
};
auto read_f32 = [&](const char * suffix, float fallback) -> float {
std::snprintf(key, sizeof(key), "%s.%s", A, suffix);
int64_t id = gguf_find_key(gctx, key);
if (id < 0) return fallback;
return gguf_get_val_f32(gctx, id);
};

const uint32_t n_embd = read_u32("embedding_length", 0);
const uint32_t n_layer = read_u32("block_count", 0);
Expand Down Expand Up @@ -235,6 +241,10 @@ bool load_draft_gguf(const std::string & path,
out.head_dim = (int)head_dim;
out.n_embd = (int)n_embd;
out.n_ff = (int)n_ff;
out.rope_theta = read_f32("rope.freq_base", 0.0f);
if (out.rope_theta == 0.0f) {
fprintf(stderr, "[draft-gguf] WARNING: rope.freq_base not found in GGUF, draft RoPE will be wrong\n");
}
Comment thread
howard0su marked this conversation as resolved.
out.layers.assign((size_t)n_layer, DraftLayer{});

auto g = [&](const char * name) -> ggml_tensor * {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,28 +1,29 @@
// Builds a ggml compute graph for one forward pass of the DFlash draft
// (5-layer non-causal Qwen3-flavored block-diffusion model).
// (5-layer Qwen3-flavored block-diffusion model).
//
// Stateless: no KV cache. Each call takes:
// - noise_embed [hidden, q_len, 1] bf16 (target.tok_embd on [last_tok, MASK*15])
// - target_hidden_cat [5*hidden, ctx_len, 1] bf16 (5 target layers concat along features)
// - noise_embed [hidden, q_len, 1] f32 (target.tok_embd on [last_tok, MASK*15])
// - target_hidden_cat [N*hidden, ctx_len, 1] f32 (N target layers concat along features)
// - positions_q [q_len] i32 values [ctx_len..ctx_len+q_len-1]
// - positions_k [ctx_len+q_len] i32 values [0..ctx_len+q_len-1]
// - causal_mask_swa [kv_pad, q_len] f32 (optional; causal mask for SWA layers)
// and returns:
// - hidden_states [hidden, q_len, 1] bf16 (final RMSNorm; NO lm_head here)
// - hidden_states [hidden, q_len, 1] f32 (final RMSNorm; NO lm_head here)
//
// The caller projects `hidden_states` through the TARGET's lm_head separately
// (the draft has no lm_head of its own, it shares the target's).
//
// Semantics match megaqwen3_27b_dflash/reference/dflash_reference.py exactly:
// Semantics:
// - fc @ target_hidden_cat -> rms_norm with hidden_norm -> target_feat
// - Per layer (non-causal):
// - Per layer:
// h_norm = rms_norm(h) * input_layernorm
// Q = wq @ h_norm -> per-head q_norm
// K_ctx/V_ctx = wk/wv @ target_feat
// K_noi/V_noi = wk/wv @ h_norm
// K = concat[K_ctx, K_noi] -> per-head k_norm
// V = concat[V_ctx, V_noi]
// RoPE(Q, positions_q); RoPE(K, positions_k) (NEOX style, theta=10M)
// attn = flash_attn_ext(Q, K, V, mask=null, scale=1/sqrt(head_dim)) non-causal
// RoPE(Q, positions_q); RoPE(K, positions_k) (NEOX style)
// attn = flash_attn_ext(Q, K, V, mask, scale) SWA=causal, full=non-causal
// h += wo @ attn
// h_norm = rms_norm(h) * post_attention_layernorm
// h += w_down @ (silu(w_gate @ h_norm) * (w_up @ h_norm))
Expand All @@ -46,7 +47,7 @@ DraftGraphOutputs build_draft_graph(
const int n_kv = w.n_head_kv;
const int head_dim = w.head_dim;
const float eps = DFLASH27B_RMS_EPS;
const float rope_base = DFLASH27B_ROPE_THETA;
const float rope_base = w.rope_theta;

// ── 1. Feature fusion: target_feat = rms_norm(fc @ target_hidden_cat, hidden_norm)
// fc: [5*hidden, hidden] (ggml: ne[0]=5*hidden, ne[1]=hidden)
Expand Down Expand Up @@ -134,9 +135,10 @@ DraftGraphOutputs build_draft_graph(
V = ggml_permute(ctx, V, 0, 2, 1, 3); // [head_dim, eff_total_k, n_kv, 1]
V = ggml_cont (ctx, V);

// ── 2f. Non-causal flash attention; GQA broadcast handled internally.
// ── 2f. Attention: causal for SWA layers, non-causal for full layers.
const float scale = 1.0f / std::sqrt((float)head_dim);
ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, /*mask=*/nullptr,
ggml_tensor * mask = (L.is_swa && in.causal_mask_swa) ? in.causal_mask_swa : nullptr;
ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, mask,
scale, /*max_bias=*/0.0f,
/*logit_softcap=*/0.0f);
// attn result: [n_embd_v=head_dim, n_head, n_batch=q_len, 1]
Expand Down
3 changes: 3 additions & 0 deletions dflash/src/draft/draft_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ struct DraftGraphInputs {
// hidden states. Used for DFlash integration where the draft shares the
// target's lm_head.
ggml_tensor * lm_head;
// Optional: causal mask for SWA layers [kv_pad, q_len] F16.
// nullptr = all layers non-causal.
ggml_tensor * causal_mask_swa = nullptr;
};

struct DraftGraphOutputs {
Expand Down
31 changes: 31 additions & 0 deletions dflash/src/flashprefill.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,37 @@ int flash_prefill_forward_q8(
ggml_type qkv_type,
const FlashPrefillConfig & cfg);

// ── Unified dispatch ──────────────────────────────────────────────────────────
// Picks the best available kernel at compile time + runtime buffer type:
// BF16 buffers + sm_80 build → flash_prefill_forward_bf16
// F16 buffers + Volta build → flash_prefill_forward_f16
// otherwise → flash_prefill_forward_q8 (ggml FA fallback)
//
// Callers no longer need to duplicate the ifdef/dispatch boilerplate.
inline int flash_prefill_forward(
ggml_backend_t backend,
const void * Q, const void * K, const void * V, void * O,
int batch, int seq_len, int n_q_heads, int n_k_heads, int head_dim,
float scale,
ggml_type qkv_type,
const FlashPrefillConfig & cfg)
{
#if defined(DFLASH27B_HAVE_FLASHPREFILL) || defined(DFLASH27B_HAVE_SM80_FLASHPREFILL)
if (qkv_type == GGML_TYPE_BF16) {
return flash_prefill_forward_bf16(Q, K, V, O,
batch, seq_len, n_q_heads, n_k_heads, head_dim, scale, cfg);
}
#endif
#if defined(DFLASH27B_HAVE_VOLTA_FLASHPREFILL) || defined(DFLASH27B_HAVE_PASCAL_FLASHPREFILL)
if (qkv_type == GGML_TYPE_F16) {
return flash_prefill_forward_f16(Q, K, V, O,
batch, seq_len, n_q_heads, n_k_heads, head_dim, scale, cfg);
}
#endif
return flash_prefill_forward_q8(backend, Q, K, V, O,
batch, seq_len, n_q_heads, n_k_heads, head_dim, scale, qkv_type, cfg);
}

#ifdef DFLASH27B_HAVE_BSA
// Free BSA persistent device buffers (blockmask, head_mask_type, softmax_lse).
// Safe to call any time; idempotent. Useful before unloading the drafter to
Expand Down
Loading
Loading