Skip to content

Commit 0724d66

Browse files
committed
dflash: first working POC
1 parent 91b03e4 commit 0724d66

23 files changed

Lines changed: 816 additions & 13 deletions

common/arg.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3474,6 +3474,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
34743474
params.speculative.eagle3 = true;
34753475
}
34763476
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_CLI}));
3477+
add_opt(common_arg(
3478+
{"--dflash"},
3479+
"use DFlash speculative decoding with the draft model",
3480+
[](common_params & params) {
3481+
params.speculative.dflash = true;
3482+
}
3483+
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_CLI}));
34773484
add_opt(common_arg(
34783485
{"-cd", "--ctx-size-draft"}, "N",
34793486
string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx),

common/common.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ enum common_speculative_type {
159159
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
160160
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
161161
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
162+
COMMON_SPECULATIVE_TYPE_DFLASH, // dflash draft model
162163
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
163164
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
164165
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
@@ -328,6 +329,7 @@ struct common_params_speculative {
328329
llama_context_params cparams_dft; // these are the parameters for the draft llama_context
329330

330331
bool eagle3 = false; // use EAGLE3 speculative decoding
332+
bool dflash = false; // use DFlash speculative decoding
331333

332334
int32_t n_ctx = 0; // draft context size
333335
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)

common/speculative.cpp

Lines changed: 150 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
2222
COMMON_SPECULATIVE_TYPE_NONE,
2323
COMMON_SPECULATIVE_TYPE_DRAFT,
2424
COMMON_SPECULATIVE_TYPE_EAGLE3,
25+
COMMON_SPECULATIVE_TYPE_DFLASH,
2526
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
2627
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
2728
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
@@ -33,6 +34,7 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
3334
{"none", COMMON_SPECULATIVE_TYPE_NONE},
3435
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
3536
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
37+
{"dflash", COMMON_SPECULATIVE_TYPE_DFLASH},
3638
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
3739
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
3840
{"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
@@ -708,6 +710,139 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
708710
}
709711
};
710712

713+
struct common_speculative_state_dflash : public common_speculative_state {
714+
llama_context * ctx_tgt;
715+
716+
common_sampler * smpl;
717+
718+
llama_batch batch;
719+
720+
struct llama_context * ctx_dft_enc = nullptr;
721+
struct llama_context * ctx_dft_dec = nullptr;
722+
723+
int32_t dflash_n_past = 0;
724+
725+
// Host-side buffer: accumulated DFlash-encoded target features across all
726+
// committed prompt+drafted tokens. Grows by `n_new * n_embd` floats per draft step
727+
// and is fed to the DFlash decoder via llama_set_dflash_accumulated_target_ctx()
728+
std::vector<float> accumulated_ctx;
729+
730+
common_speculative_state_dflash(
731+
enum common_speculative_type type,
732+
llama_context * ctx_tgt,
733+
llama_context * ctx_dft_enc,
734+
llama_context * ctx_dft_dec)
735+
: common_speculative_state(type)
736+
, ctx_tgt(ctx_tgt)
737+
, ctx_dft_enc(ctx_dft_enc)
738+
, ctx_dft_dec(ctx_dft_dec)
739+
{
740+
batch = llama_batch_init(llama_n_batch(ctx_dft_dec), 0, 1);
741+
742+
common_params_sampling params;
743+
params.no_perf = false;
744+
params.top_k = 1;
745+
params.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
746+
smpl = common_sampler_init(llama_get_model(ctx_dft_dec), params);
747+
}
748+
749+
~common_speculative_state_dflash() override {
750+
llama_perf_context_print(ctx_dft_dec);
751+
752+
if (ctx_dft_dec) {
753+
llama_free(ctx_dft_dec);
754+
}
755+
756+
if (ctx_dft_enc) {
757+
llama_free(ctx_dft_enc);
758+
}
759+
760+
common_sampler_free(smpl);
761+
llama_batch_free(batch);
762+
}
763+
764+
void begin(const llama_tokens & prompt) override {
765+
GGML_UNUSED(prompt);
766+
}
767+
768+
void draft(
769+
const common_params_speculative & params,
770+
const llama_tokens & prompt_tgt,
771+
llama_token id_last,
772+
llama_tokens & result) override {
773+
const int n_embd = llama_model_n_embd(llama_get_model(ctx_dft_dec));
774+
// block_size is bounded by the model's trained block_size (from GGUF metadata).
775+
const int model_block_size = llama_model_dflash_block_size(llama_get_model(ctx_dft_dec));
776+
const int block_size = std::min((int)params.n_max, model_block_size);
777+
const int n = (int)prompt_tgt.size();
778+
const int n_new = n - dflash_n_past;
779+
780+
GGML_ASSERT(n >= 1 && "prompt_tgt is empty");
781+
GGML_ASSERT(n_new >= 1 && "must have at least 1 new token");
782+
783+
// Step 1: Encode new accepted tokens' features
784+
const float * features = llama_get_dflash_target_features(ctx_tgt);
785+
786+
llama_batch enc_batch = {
787+
/*.n_tokens =*/ n_new,
788+
/*.token =*/ nullptr,
789+
/*.embd =*/ const_cast<float*>(features),
790+
/*.pos =*/ nullptr,
791+
/*.n_seq_id =*/ nullptr,
792+
/*.seq_id =*/ nullptr,
793+
/*.logits =*/ nullptr,
794+
};
795+
if (llama_encode(ctx_dft_enc, enc_batch) != 0) {
796+
LOG_ERR("DFlash: encoder failed\n");
797+
return;
798+
}
799+
800+
const float * target_ctx_new = llama_get_embeddings(ctx_dft_enc);
801+
GGML_ASSERT(target_ctx_new && "encoder output is null");
802+
803+
// Step 2: Append to accumulated target_ctx and set on decoder context (writes to cross.v_embd)
804+
const size_t new_size = (size_t)n_embd * n_new;
805+
accumulated_ctx.insert(accumulated_ctx.end(), target_ctx_new, target_ctx_new + new_size);
806+
807+
const int n_ctx_total = (int)(accumulated_ctx.size() / n_embd);
808+
llama_set_dflash_accumulated_target_ctx(ctx_dft_dec, accumulated_ctx.data(), n_embd, n_ctx_total);
809+
810+
// Step 3: Decode noise block
811+
const llama_token mask_token_id = llama_model_dflash_mask_token_id(llama_get_model(ctx_dft_dec));
812+
813+
common_batch_clear(batch);
814+
for (int i = 0; i < block_size; i++) {
815+
const llama_token tok = (i == 0) ? id_last : mask_token_id;
816+
common_batch_add(batch, tok, i, {0}, true);
817+
}
818+
819+
if (llama_decode(ctx_dft_dec, batch) != 0) {
820+
LOG_ERR("DFlash: noise decode failed\n");
821+
return;
822+
}
823+
824+
dflash_n_past = n;
825+
826+
// Step 4: Sample draft tokens from positions 1..block_size-1
827+
result.clear();
828+
common_sampler_reset(smpl);
829+
830+
for (int i = 1; i < block_size; i++) {
831+
common_sampler_sample(smpl, ctx_dft_dec, i);
832+
833+
const auto * cur_p = common_sampler_get_candidates(smpl, true);
834+
const llama_token id = cur_p->data[0].id;
835+
836+
common_sampler_accept(smpl, id, true);
837+
result.push_back(id);
838+
}
839+
}
840+
841+
void accept(uint16_t n_accepted) override {
842+
GGML_UNUSED(n_accepted);
843+
}
844+
};
845+
711846
// state of self-speculation (simple implementation, not ngram-map)
712847
struct common_speculative_state_ngram_simple : public common_speculative_state {
713848
common_ngram_simple_config config;
@@ -1057,13 +1192,13 @@ common_speculative * common_speculative_init(
10571192
llama_context * ctx_dft_dec = nullptr;
10581193

10591194
if (params.model_dft) {
1060-
if (params.eagle3) {
1195+
if (params.eagle3 || params.dflash) {
10611196
llama_context_params params_enc = params.cparams_dft;
10621197
params_enc.target_model = nullptr;
10631198
params_enc.embeddings = true;
10641199
ctx_dft_enc = llama_init_from_model(params.model_dft, params_enc);
10651200
if (!ctx_dft_enc) {
1066-
LOG_ERR("failed to create EAGLE3 encoder context\n");
1201+
LOG_ERR("failed to create %s draft model encoder context\n", params.eagle3 ? "EAGLE3" : "DFlash");
10671202
return nullptr;
10681203
}
10691204

@@ -1072,13 +1207,13 @@ common_speculative * common_speculative_init(
10721207
params_dec.embeddings = true;
10731208
ctx_dft_dec = llama_init_from_model(params.model_dft, params_dec);
10741209
if (!ctx_dft_dec) {
1075-
LOG_ERR("failed to create EAGLE3 decoder context\n");
1210+
LOG_ERR("failed to create %s draft model decoder context\n", params.eagle3 ? "EAGLE3" : "DFlash");
10761211
return nullptr;
10771212
}
10781213
} else {
10791214
ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft);
10801215
if (ctx_dft == nullptr) {
1081-
LOG_ERR("%s", "failed to create draft context\n");
1216+
LOG_ERR("failed to create draft model context\n");
10821217
return nullptr;
10831218
}
10841219
}
@@ -1089,6 +1224,7 @@ common_speculative * common_speculative_init(
10891224
{
10901225
bool has_draft = !params.mparams_dft.path.empty();
10911226
bool has_draft_eagle3 = params.eagle3;
1227+
bool has_draft_dflash = params.dflash;
10921228

10931229
bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
10941230
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
@@ -1131,6 +1267,8 @@ common_speculative * common_speculative_init(
11311267
if (has_draft) {
11321268
if (has_draft_eagle3) {
11331269
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params));
1270+
} else if (has_draft_dflash) {
1271+
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DFLASH, params));
11341272
} else {
11351273
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params));
11361274
}
@@ -1163,6 +1301,14 @@ common_speculative * common_speculative_init(
11631301
));
11641302
break;
11651303
}
1304+
case COMMON_SPECULATIVE_TYPE_DFLASH: {
1305+
impls.push_back(std::make_unique<common_speculative_state_dflash>(config.type,
1306+
/* .ctx_tgt = */ ctx_tgt,
1307+
/* .ctx_dft_enc = */ ctx_dft_enc,
1308+
/* .ctx_dft_dec = */ ctx_dft_dec
1309+
));
1310+
break;
1311+
}
11661312
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
11671313
common_ngram_map ngram_map = get_common_ngram_map(config);
11681314

convert_hf_to_gguf.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4887,6 +4887,47 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
48874887
yield from super().modify_tensors(data_torch, name, bid)
48884888

48894889

4890+
@ModelBase.register("DFlashDraftModel")
4891+
class DFlashModel(Qwen3Model):
4892+
model_arch = gguf.MODEL_ARCH.DFLASH
4893+
4894+
def set_vocab(self):
4895+
if self.target_model_dir is None:
4896+
raise ValueError(
4897+
"DFlash draft model requires --target-model-dir to be specified. "
4898+
"Please provide the path to the target model directory containing the tokenizer."
4899+
)
4900+
logger.info(f"DFLASH: Using tokenizer from target model: {self.target_model_dir}")
4901+
original_dir = self.dir_model
4902+
self.dir_model = self.target_model_dir
4903+
super().set_vocab()
4904+
self.dir_model = original_dir
4905+
4906+
def set_gguf_parameters(self):
4907+
super().set_gguf_parameters()
4908+
block_size = self.hparams.get("block_size", 16)
4909+
self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.block_size", block_size)
4910+
dflash_config = self.hparams.get("dflash_config", {})
4911+
target_layer_ids = dflash_config.get("target_layer_ids", [])
4912+
if target_layer_ids:
4913+
extract_layer_ids = [i + 1 for i in target_layer_ids]
4914+
self.gguf_writer.add_array(f"{self.gguf_writer.arch}.target_layer_ids", extract_layer_ids)
4915+
mask_token_id = dflash_config.get("mask_token_id", None)
4916+
if mask_token_id is not None:
4917+
self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.mask_token_id", mask_token_id)
4918+
4919+
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
4920+
if name == "fc.weight":
4921+
yield (name, data_torch)
4922+
return
4923+
if name == "hidden_norm.weight":
4924+
yield ("hidden_norm.weight", data_torch)
4925+
return
4926+
if not name.startswith("model."):
4927+
name = "model." + name
4928+
yield from super().modify_tensors(data_torch, name, bid)
4929+
4930+
48904931
@ModelBase.register("Qwen3MoeForCausalLM")
48914932
class Qwen3MoeModel(Qwen2MoeModel):
48924933
model_arch = gguf.MODEL_ARCH.QWEN3MOE

0 commit comments

Comments
 (0)