Skip to content

Commit d1b3425

Browse files
ruixiang63 and kashif authored
spec : add DFlash support (#22105)
* spec: add DFlash v2 support
* dflash: support sliding window attention per layer_types
* docs: add dflash section

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
1 parent c1a1c8e commit d1b3425

14 files changed

Lines changed: 712 additions & 9 deletions

File tree

common/common.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ enum common_speculative_type {
169169
COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, // standalone draft model speculative decoding
170170
COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, // Eagle3 speculative decoding
171171
COMMON_SPECULATIVE_TYPE_DRAFT_MTP, // Multi-token prediction
172+
COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH, // DFlash speculative decoding
172173
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding based on n-grams
173174
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
174175
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
@@ -384,7 +385,7 @@ struct common_params_speculative {
384385

385386
uint32_t need_n_rs_seq() const {
386387
bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) {
387-
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3;
388+
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3 || t == COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH;
388389
});
389390

390391
return needs_rs_seq ? draft.n_max : 0u;

common/speculative.cpp

Lines changed: 302 additions & 1 deletion
Large diffs are not rendered by default.

conversion/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
"DeepseekV2ForCausalLM": "deepseek",
5151
"DeepseekV3ForCausalLM": "deepseek",
5252
"DeepseekV32ForCausalLM": "deepseek",
53+
"DFlashDraftModel": "qwen",
5354
"DistilBertForMaskedLM": "bert",
5455
"DistilBertForSequenceClassification": "bert",
5556
"DistilBertModel": "bert",

conversion/qwen.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,3 +625,55 @@ class Qwen3_5TextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReor
625625
@ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM")
626626
class Qwen3_5MoeTextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase):
627627
model_arch = gguf.MODEL_ARCH.QWEN35MOE
628+
629+
630+
@ModelBase.register("DFlashDraftModel")
631+
class DFlashModel(Qwen3Model):
632+
model_arch = gguf.MODEL_ARCH.DFLASH
633+
634+
def set_vocab(self):
635+
if self.target_model_dir is None:
636+
raise ValueError(
637+
"DFlash draft model requires --target-model-dir to be specified. "
638+
"Please provide the path to the target model directory containing the tokenizer."
639+
)
640+
logger.info(f"DFlash: Using tokenizer from target model: {self.target_model_dir}")
641+
original_dir = self.dir_model
642+
self.dir_model = self.target_model_dir
643+
super().set_vocab()
644+
self.dir_model = original_dir
645+
646+
def set_gguf_parameters(self):
647+
super().set_gguf_parameters()
648+
649+
block_size = self.hparams.get("block_size", 16)
650+
self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.block_size", block_size)
651+
dflash_config = self.hparams.get("dflash_config", {})
652+
653+
target_layer_ids = dflash_config.get("target_layer_ids", [])
654+
if target_layer_ids:
655+
extract_layer_ids = [i + 1 for i in target_layer_ids]
656+
self.gguf_writer.add_array(f"{self.gguf_writer.arch}.target_layers", extract_layer_ids)
657+
658+
mask_token_id = dflash_config.get("mask_token_id", None)
659+
if mask_token_id is not None:
660+
self.gguf_writer.add_mask_token_id(mask_token_id)
661+
662+
use_sliding_window = self.hparams.get("use_sliding_window", False)
663+
sliding_window = self.hparams.get("sliding_window")
664+
layer_types = self.hparams.get("layer_types")
665+
if use_sliding_window and sliding_window and layer_types:
666+
is_swa = [lt == "sliding_attention" for lt in layer_types]
667+
self.gguf_writer.add_sliding_window(sliding_window)
668+
self.gguf_writer.add_sliding_window_pattern(is_swa)
669+
670+
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
671+
if name == "fc.weight":
672+
yield (name, data_torch)
673+
return
674+
if name == "hidden_norm.weight":
675+
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ENC_OUTPUT_NORM), data_torch)
676+
return
677+
if not name.startswith("model."):
678+
name = "model." + name
679+
yield from super().modify_tensors(data_torch, name, bid)

docs/speculative.md

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,32 @@ Supported EAGLE-3 draft models include:
5252

5353
For the full and up-to-date list of supported models, see #18039.
5454

55+
### DFlash (`draft-dflash`)
56+
57+
DFlash produces an entire block of draft tokens in a single forward pass (block diffusion) and
58+
injects the target model's hidden states into the draft model's attention, instead of drafting one
59+
token at a time. This keeps the draft model small while making drafting GPU-friendly. Unlike EAGLE-3
60+
(a single-layer autoregressive draft), the DFlash draft uses several transformer layers but emits a
61+
whole block per draft step.
62+
63+
The draft is a small block-diffusion model trained for a specific target (for example
64+
`z-lab/Qwen3-4B-DFlash` for `Qwen/Qwen3-4B`). Convert it with `--target-model-dir` so it inherits the
65+
target's tokenizer and token embeddings:
66+
67+
```bash
68+
python convert_hf_to_gguf.py z-lab/Qwen3-4B-DFlash \
69+
--target-model-dir Qwen/Qwen3-4B --outtype bf16 --outfile Qwen3-4B-DFlash.gguf
70+
71+
llama-server -m Qwen3-4B.gguf -md Qwen3-4B-DFlash.gguf \
72+
--spec-type draft-dflash --spec-draft-n-max 15 -fa on --jinja
73+
```
74+
75+
`--spec-draft-n-max` is clamped to the draft model's trained block size.
76+
77+
See:
78+
79+
- #22105
80+
5581
### n-gram Cache (`ngram-cache`)
5682

5783
An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences.
@@ -147,7 +173,7 @@ If a draft model is combined with a draftless decoding the draftless decoding ha
147173
### General Speculative Parameters
148174

149175
```
150-
--spec-type [none|draft-simple|draft-eagle3|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
176+
--spec-type [none|draft-simple|draft-eagle3|draft-dflash|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
151177
comma-separated list of types of speculative decoding to use
152178
(default: none)
153179
(env: LLAMA_ARG_SPEC_TYPE)
@@ -287,6 +313,7 @@ Specifies a comma-separated list of speculative decoding types to use.
287313
| `none` | No speculative decoding (default) |
288314
| `draft-simple` | Use a simple draft model for speculation |
289315
| `draft-eagle3` | Use an EAGLE-3 draft model that reads the target's hidden states |
316+
| `draft-dflash` | Use a DFlash block-diffusion draft model that emits a block per step |
290317
| `draft-mtp` | Use Multi Token Prediction (MTP) heads from the main model |
291318
| `ngram-cache` | Use n-gram cache lookup |
292319
| `ngram-simple` | Use simple n-gram pattern matching |

gguf-py/gguf/constants.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,7 @@ class MODEL_ARCH(IntEnum):
517517
PANGU_EMBED = auto()
518518
MISTRAL3 = auto()
519519
EAGLE3 = auto()
520+
DFLASH = auto()
520521
MISTRAL4 = auto()
521522
PADDLEOCR = auto()
522523
MIMO2 = auto()
@@ -1074,6 +1075,7 @@ class MODEL_TENSOR(IntEnum):
10741075
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
10751076
MODEL_ARCH.MISTRAL3: "mistral3",
10761077
MODEL_ARCH.EAGLE3: "eagle3",
1078+
MODEL_ARCH.DFLASH: "dflash",
10771079
MODEL_ARCH.MISTRAL4: "mistral4",
10781080
MODEL_ARCH.PADDLEOCR: "paddleocr",
10791081
MODEL_ARCH.MIMO2: "mimo2",
@@ -4086,6 +4088,22 @@ class MODEL_TENSOR(IntEnum):
40864088
MODEL_TENSOR.FC,
40874089
MODEL_TENSOR.D2T,
40884090
],
4091+
MODEL_ARCH.DFLASH: [
4092+
MODEL_TENSOR.OUTPUT_NORM,
4093+
MODEL_TENSOR.ATTN_NORM,
4094+
MODEL_TENSOR.ATTN_Q,
4095+
MODEL_TENSOR.ATTN_K,
4096+
MODEL_TENSOR.ATTN_V,
4097+
MODEL_TENSOR.ATTN_OUT,
4098+
MODEL_TENSOR.ATTN_Q_NORM,
4099+
MODEL_TENSOR.ATTN_K_NORM,
4100+
MODEL_TENSOR.FFN_NORM,
4101+
MODEL_TENSOR.FFN_GATE,
4102+
MODEL_TENSOR.FFN_DOWN,
4103+
MODEL_TENSOR.FFN_UP,
4104+
MODEL_TENSOR.FC,
4105+
MODEL_TENSOR.ENC_OUTPUT_NORM,
4106+
],
40894107
MODEL_ARCH.MISTRAL4: [
40904108
MODEL_TENSOR.TOKEN_EMBD,
40914109
MODEL_TENSOR.OUTPUT_NORM,

src/llama-arch.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
129129
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
130130
{ LLM_ARCH_MISTRAL3, "mistral3" },
131131
{ LLM_ARCH_EAGLE3, "eagle3" },
132+
{ LLM_ARCH_DFLASH, "dflash" },
132133
{ LLM_ARCH_MISTRAL4, "mistral4" },
133134
{ LLM_ARCH_PADDLEOCR, "paddleocr" },
134135
{ LLM_ARCH_MIMO2, "mimo2" },

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ enum llm_arch {
143143
LLM_ARCH_TALKIE,
144144
LLM_ARCH_MELLUM,
145145
LLM_ARCH_EAGLE3,
146+
LLM_ARCH_DFLASH,
146147
LLM_ARCH_UNKNOWN,
147148
};
148149

src/llama-context.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,10 @@ llama_context::llama_context(
100100
cparams.ctx_other = params.ctx_other;
101101
}
102102

103-
if (model.arch == LLM_ARCH_EAGLE3) {
103+
if (model.arch == LLM_ARCH_EAGLE3 || model.arch == LLM_ARCH_DFLASH) {
104104
if (model.tok_embd == nullptr || model.output == nullptr) {
105105
if (params.ctx_other == nullptr) {
106-
throw std::runtime_error("EAGLE3 requires ctx_other to be set (this warning is normal during memory fitting)");
106+
throw std::runtime_error(model.arch_name() + " requires ctx_other to be set (this warning is normal during memory fitting)");
107107
}
108108
cparams.ctx_other = params.ctx_other;
109109
}

src/llama-graph.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,11 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
486486
mctx->set_input_k_idxs(self_k_idxs, ubatch);
487487
mctx->set_input_v_idxs(self_v_idxs, ubatch);
488488

489-
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
489+
// the mask is left unallocated when the graph only stores K/V without attending
490+
// (e.g. DFlash's KV-injection pass)
491+
if (self_kq_mask && self_kq_mask->buffer) {
492+
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
493+
}
490494

491495
if (self_k_rot) {
492496
mctx->set_input_k_rot(self_k_rot);
@@ -904,6 +908,7 @@ void llm_graph_result::reset() {
904908
t_logits = nullptr;
905909
t_embd = nullptr;
906910
t_embd_pooled = nullptr;
911+
t_h_nextn = nullptr;
907912

908913
t_layer_inp.resize(LLAMA_MAX_LAYERS);
909914
std::fill(t_layer_inp.begin(), t_layer_inp.end(), nullptr);

0 commit comments

Comments
 (0)