dflash/src/draft/draft_gguf_loader.cpp (52 additions, 24 deletions)
@@ -6,12 +6,10 @@
 // types — ggml's ggml_mul_mat handles Q8_0 × F32 dequantization transparently.
 //
 // GGUF arch: "qwen35-dflash-draft" (from convert_dflash_to_gguf.py /
-// quantize_draft_q8.py) or the newer "dflash-draft" export. Tensor naming
-// convention:
+// quantize_draft_q8.py). Tensor naming convention:
 //
-//   dflash.fc.weight / dflash_fc.weight   [5*hidden, hidden]   Q8_0 / F16
-//   dflash.hidden_norm.weight /
-//   dflash_hidden_norm.weight             [hidden]             F32
+//   dflash.fc.weight            [5*hidden, hidden]   Q8_0 / F16
+//   dflash.hidden_norm.weight   [hidden]             F32
 //   output_norm.weight          [hidden]             F32
 //   blk.<i>.attn_norm.weight    [hidden]             F32
 //   blk.<i>.ffn_norm.weight     [hidden]             F32
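
The header comment's claim about mixed-type matmuls is worth one concrete illustration. Below is a minimal sketch (not part of this diff; the helper name and shapes are hypothetical) of how a consumer would multiply the loaded Q8_0 fc weight against F32 activations, relying on ggml_mul_mat's transparent handling of a quantized first operand:

#include "ggml.h"

// fc_w: loaded dflash.fc.weight, ne = {hidden, 5*hidden}, Q8_0 or F16.
// h:    F32 activations, ne = {hidden, n_tokens}.
// ggml dequantizes the quantized operand internally, so no explicit
// dequantize pass is needed before the matmul.
static ggml_tensor * draft_fc_forward(ggml_context * ctx,
                                      ggml_tensor * fc_w,
                                      ggml_tensor * h) {
    return ggml_mul_mat(ctx, fc_w, h);  // F32 result, ne = {5*hidden, n_tokens}
}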
@@ -108,6 +106,14 @@ uint32_t get_u32_or(const gguf_context * g, const char * key, uint32_t fallback)
     return gguf_get_val_u32(g, id);
 }
 
+int count_swa_layers(const DraftWeights & w) {
+    int n_swa = 0;
+    for (const DraftLayer & layer : w.layers) {
+        if (layer.is_swa) n_swa++;
+    }
+    return n_swa;
+}
+
 } // namespace
 
 bool load_draft_gguf(const std::string & path,
@@ -126,6 +132,7 @@ bool load_draft_gguf(const std::string & path,
     }
 
     // Validate arch
+    std::string arch_s;
     {
         int64_t arch_id = gguf_find_key(gctx, "general.architecture");
         if (arch_id < 0) {
@@ -134,8 +141,8 @@ bool load_draft_gguf(const std::string & path,
             return false;
         }
         const char * arch = gguf_get_val_str(gctx, arch_id);
-        if (std::string(arch) != "qwen35-dflash-draft" &&
-            std::string(arch) != "dflash-draft") {
+        arch_s = arch;
+        if (arch_s != "qwen35-dflash-draft" && arch_s != "dflash-draft") {
             set_last_error(std::string("unexpected draft arch: ") + arch +
                            " (expected qwen35-dflash-draft or dflash-draft)");
             gguf_free(gctx);
@@ -144,8 +151,7 @@ bool load_draft_gguf(const std::string & path,
     }
 
     // Read dimensions from GGUF metadata
-    int64_t arch_id = gguf_find_key(gctx, "general.architecture");
-    const char * A = gguf_get_val_str(gctx, arch_id);
+    const char * A = arch_s.c_str();
     char key[128];
 
     auto read_u32 = [&](const char * suffix, uint32_t fallback) -> uint32_t {
@@ -162,16 +168,17 @@ bool load_draft_gguf(const std::string & path,
     const uint32_t block_sz = read_u32("dflash.block_size", 0);
     uint32_t n_tgt_lay = read_u32("dflash.n_target_layers", 0);
     if (n_tgt_lay == 0) {
-        const uint32_t n_tgt_feat = read_u32("dflash.n_target_features", 0);
-        if (n_tgt_feat != 0 && n_embd != 0 && (n_tgt_feat % n_embd) == 0) {
-            n_tgt_lay = n_tgt_feat / n_embd;
+        std::snprintf(key, sizeof(key), "%s.%s", A, "dflash.target_layer_ids");
+        const int64_t target_ids_id = gguf_find_key(gctx, key);
+        if (target_ids_id >= 0 &&
+            gguf_get_kv_type(gctx, target_ids_id) == GGUF_TYPE_ARRAY) {
+            n_tgt_lay = (uint32_t)gguf_get_arr_n(gctx, target_ids_id);
         }
     }
-    if (n_tgt_lay == 0) {
-        std::snprintf(key, sizeof(key), "%s.%s", A, "dflash.target_layer_ids");
-        int64_t id = gguf_find_key(gctx, key);
-        if (id >= 0) {
-            n_tgt_lay = (uint32_t)gguf_get_arr_n(gctx, id);
+    if (n_tgt_lay == 0 && n_embd != 0) {
+        const uint32_t n_target_features = read_u32("dflash.n_target_features", 0);
+        if (n_target_features != 0 && (n_target_features % n_embd) == 0) {
+            n_tgt_lay = n_target_features / n_embd;
         }
     }
 
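For readers tracing the metadata plumbing: read_u32 prefixes every key with the architecture string, so under the "dflash-draft" arch the probes above resolve, in order, to "dflash-draft.dflash.n_target_layers", then "dflash-draft.dflash.target_layer_ids" (array length), then "dflash-draft.dflash.n_target_features" (divided by n_embd, e.g. 10240 / 2048 = 5 layers). A compact restatement of that chain as a standalone sketch (the helper name is invented; the gguf accessors are the same ones used in the hunk above, C++17):

// Resolve the draft's target-layer count from GGUF metadata.
static uint32_t resolve_n_target_layers(const gguf_context * g,
                                        const char * arch, uint32_t n_embd) {
    char key[128];
    // 1. Explicit count, if present and non-zero.
    std::snprintf(key, sizeof(key), "%s.dflash.n_target_layers", arch);
    if (int64_t id = gguf_find_key(g, key); id >= 0)
        if (uint32_t n = gguf_get_val_u32(g, id)) return n;
    // 2. Length of the target-layer-id array.
    std::snprintf(key, sizeof(key), "%s.dflash.target_layer_ids", arch);
    if (int64_t id = gguf_find_key(g, key);
        id >= 0 && gguf_get_kv_type(g, id) == GGUF_TYPE_ARRAY)
        if (uint32_t n = (uint32_t)gguf_get_arr_n(g, id)) return n;
    // 3. Derive from the flattened feature width.
    std::snprintf(key, sizeof(key), "%s.dflash.n_target_features", arch);
    if (int64_t id = gguf_find_key(g, key); id >= 0 && n_embd != 0) {
        const uint32_t feat = gguf_get_val_u32(g, id);
        if (feat % n_embd == 0) return feat / n_embd;
    }
    return 0;
}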
@@ -240,17 +247,17 @@ bool load_draft_gguf(const std::string & path,
     auto g = [&](const char * name) -> ggml_tensor * {
         return ggml_get_tensor(meta_ctx, name);
     };
-
-    auto first = [](ggml_tensor * a, ggml_tensor * b) -> ggml_tensor * {
-        return a ? a : b;
+    auto g_any = [&](const char * a, const char * b) -> ggml_tensor * {
+        if (ggml_tensor * t = g(a)) return t;
+        return g(b);
     };
 
-    out.fc = first(g("dflash.fc.weight"), g("dflash_fc.weight"));
-    out.hidden_norm = first(g("dflash.hidden_norm.weight"), g("dflash_hidden_norm.weight"));
+    out.fc = g_any("dflash.fc.weight", "dflash_fc.weight");
+    out.hidden_norm = g_any("dflash.hidden_norm.weight", "dflash_hidden_norm.weight");
     out.out_norm = g("output_norm.weight");
     if (!out.fc || !out.hidden_norm || !out.out_norm) {
         set_last_error("draft GGUF: missing top-level tensors "
-                       "(dflash fc / dflash hidden norm / output_norm)");
+                       "(dflash.fc|dflash_fc / dflash.hidden_norm|dflash_hidden_norm / output_norm)");
         gguf_free(gctx);
         return false;
     }
@@ -263,7 +270,8 @@ bool load_draft_gguf(const std::string & path,
         };
         DraftLayer & L = out.layers[il];
         L.attn_norm = fnd("attn_norm.weight");
-        L.ffn_norm = first(fnd("ffn_norm.weight"), fnd("post_attention_norm.weight"));
+        L.ffn_norm = fnd("ffn_norm.weight");
+        if (!L.ffn_norm) L.ffn_norm = fnd("post_attention_norm.weight");
         L.wq = fnd("attn_q.weight");
         L.wk = fnd("attn_k.weight");
         L.wv = fnd("attn_v.weight");
@@ -283,6 +291,26 @@ bool load_draft_gguf(const std::string & path,
         }
     }
 
+    // GGUF Qwen3.6 drafters carry SWA metadata emitted by the converter:
+    //   dflash-draft.attention.sliding_window         = 2048
+    //   dflash-draft.attention.sliding_window_pattern = [true,true,true,true,false]
+    out.swa_window = (int)read_u32("attention.sliding_window", 0);
+    std::snprintf(key, sizeof(key), "%s.%s", A, "attention.sliding_window_pattern");
+    int64_t swp_id = gguf_find_key(gctx, key);
+    if (swp_id >= 0 && gguf_get_kv_type(gctx, swp_id) == GGUF_TYPE_ARRAY &&
+        gguf_get_arr_type(gctx, swp_id) == GGUF_TYPE_BOOL) {
+        const size_t n = gguf_get_arr_n(gctx, swp_id);
+        const bool * pattern = static_cast<const bool *>(gguf_get_arr_data(gctx, swp_id));
+        for (size_t il = 0; il < n && il < out.layers.size(); il++) {
+            out.layers[il].is_swa = pattern[il];
+        }
+    }
+    const int n_swa = count_swa_layers(out);
+    if (n_swa > 0) {
+        std::fprintf(stderr, "[draft GGUF] SWA layers: %d/%d (window=%d)\n",
+                     n_swa, out.n_layer, out.swa_window);
+    }
+
     // ── 3. Allocate CUDA buffer for all tensors ──────────────────────────
     out.buf = ggml_backend_alloc_ctx_tensors(meta_ctx, backend);
     if (!out.buf) {
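
A closing note on what the loaded flags mean downstream: with the example metadata above, the pattern [true,true,true,true,false] marks layers 0-3 as sliding-window and layer 4 as full attention. A minimal sketch (hypothetical helper, not in this PR) of the visibility rule the per-layer is_swa flag implies at attention time:

// May query position i attend to key/value position j in this layer?
static bool attn_visible(int i, int j, bool is_swa, int swa_window) {
    if (j > i)   return false;      // causal mask: never attend to the future
    if (!is_swa) return true;       // full-attention layer sees all of history
    return (i - j) < swa_window;    // windowed layer, e.g. swa_window = 2048
}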