@@ -22,6 +22,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
2222 COMMON_SPECULATIVE_TYPE_NONE ,
2323 COMMON_SPECULATIVE_TYPE_DRAFT ,
2424 COMMON_SPECULATIVE_TYPE_EAGLE3 ,
25+ COMMON_SPECULATIVE_TYPE_DFLASH ,
2526 COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE ,
2627 COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K ,
2728 COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V ,
@@ -33,6 +34,7 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
3334 {" none" , COMMON_SPECULATIVE_TYPE_NONE },
3435 {" draft" , COMMON_SPECULATIVE_TYPE_DRAFT },
3536 {" eagle3" , COMMON_SPECULATIVE_TYPE_EAGLE3 },
37+ {" dflash" , COMMON_SPECULATIVE_TYPE_DFLASH },
3638 {" ngram_simple" , COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE },
3739 {" ngram_map_k" , COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K },
3840 {" ngram_map_k4v" , COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V },
@@ -708,6 +710,139 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
708710 }
709711};
710712
713+ struct common_speculative_state_dflash : public common_speculative_state {
714+ llama_context * ctx_tgt;
715+
716+ common_sampler * smpl;
717+
718+ llama_batch batch;
719+
720+ struct llama_context * ctx_dft_enc = nullptr ;
721+ struct llama_context * ctx_dft_dec = nullptr ;
722+
723+ int32_t dflash_n_past = 0 ;
724+
725+ // Host-side buffer: accumulated DFlash-encoded target features across all
726+ // committed prompt+drafted tokens. Grows by `n_new * n_embd` floats per draft step
727+ // and is fed to the DFlash decoder via llama_set_dflash_accumulated_target_ctx()
728+ std::vector<float > accumulated_ctx;
729+
730+ common_speculative_state_dflash (
731+ enum common_speculative_type type,
732+ llama_context * ctx_tgt,
733+ llama_context * ctx_dft_enc,
734+ llama_context * ctx_dft_dec)
735+ : common_speculative_state(type)
736+ , ctx_tgt(ctx_tgt)
737+ , ctx_dft_enc(ctx_dft_enc)
738+ , ctx_dft_dec(ctx_dft_dec)
739+ {
740+ batch = llama_batch_init (llama_n_batch (ctx_dft_dec), 0 , 1 );
741+
742+ common_params_sampling params;
743+ params.no_perf = false ;
744+ params.top_k = 1 ;
745+ params.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
746+ smpl = common_sampler_init (llama_get_model (ctx_dft_dec), params);
747+ }
748+
749+ ~common_speculative_state_dflash () override {
750+ llama_perf_context_print (ctx_dft_dec);
751+
752+ if (ctx_dft_dec) {
753+ llama_free (ctx_dft_dec);
754+ }
755+
756+ if (ctx_dft_enc) {
757+ llama_free (ctx_dft_enc);
758+ }
759+
760+ common_sampler_free (smpl);
761+ llama_batch_free (batch);
762+ }
763+
764+ void begin (const llama_tokens & prompt) override {
765+ GGML_UNUSED (prompt);
766+ }
767+
768+ void draft (
769+ const common_params_speculative & params,
770+ const llama_tokens & prompt_tgt,
771+ llama_token id_last,
772+ llama_tokens & result) override {
773+ const int n_embd = llama_model_n_embd (llama_get_model (ctx_dft_dec));
774+ // block_size is bounded by the model's trained block_size (from GGUF metadata).
775+ const int model_block_size = llama_model_dflash_block_size (llama_get_model (ctx_dft_dec));
776+ const int block_size = std::min ((int )params.n_max , model_block_size);
777+ const int n = (int )prompt_tgt.size ();
778+ const int n_new = n - dflash_n_past;
779+
780+ GGML_ASSERT (n >= 1 && " prompt_tgt is empty" );
781+ GGML_ASSERT (n_new >= 1 && " must have at least 1 new token" );
782+
783+ // Step 1: Encode new accepted tokens' features
784+ const float * features = llama_get_dflash_target_features (ctx_tgt);
785+
786+ llama_batch enc_batch = {
787+ /* .n_tokens =*/ n_new,
788+ /* .token =*/ nullptr ,
789+ /* .embd =*/ const_cast <float *>(features),
790+ /* .pos =*/ nullptr ,
791+ /* .n_seq_id =*/ nullptr ,
792+ /* .seq_id =*/ nullptr ,
793+ /* .logits =*/ nullptr ,
794+ };
795+ if (llama_encode (ctx_dft_enc, enc_batch) != 0 ) {
796+ LOG_ERR (" DFlash: encoder failed\n " );
797+ return ;
798+ }
799+
800+ const float * target_ctx_new = llama_get_embeddings (ctx_dft_enc);
801+ GGML_ASSERT (target_ctx_new && " encoder output is null" );
802+
803+ // Step 2: Append to accumulated target_ctx and set on decoder context (writes to cross.v_embd)
804+ const size_t new_size = (size_t )n_embd * n_new;
805+ accumulated_ctx.insert (accumulated_ctx.end (), target_ctx_new, target_ctx_new + new_size);
806+
807+ const int n_ctx_total = (int )(accumulated_ctx.size () / n_embd);
808+ llama_set_dflash_accumulated_target_ctx (ctx_dft_dec, accumulated_ctx.data (), n_embd, n_ctx_total);
809+
810+ // Step 3: Decode noise block
811+ const llama_token mask_token_id = llama_model_dflash_mask_token_id (llama_get_model (ctx_dft_dec));
812+
813+ common_batch_clear (batch);
814+ for (int i = 0 ; i < block_size; i++) {
815+ const llama_token tok = (i == 0 ) ? id_last : mask_token_id;
816+ common_batch_add (batch, tok, i, {0 }, true );
817+ }
818+
819+ if (llama_decode (ctx_dft_dec, batch) != 0 ) {
820+ LOG_ERR (" DFlash: noise decode failed\n " );
821+ return ;
822+ }
823+
824+ dflash_n_past = n;
825+
826+ // Step 4: Sample draft tokens from positions 1..block_size-1
827+ result.clear ();
828+ common_sampler_reset (smpl);
829+
830+ for (int i = 1 ; i < block_size; i++) {
831+ common_sampler_sample (smpl, ctx_dft_dec, i);
832+
833+ const auto * cur_p = common_sampler_get_candidates (smpl, true );
834+ const llama_token id = cur_p->data [0 ].id ;
835+
836+ common_sampler_accept (smpl, id, true );
837+ result.push_back (id);
838+ }
839+ }
840+
841+ void accept (uint16_t n_accepted) override {
842+ GGML_UNUSED (n_accepted);
843+ }
844+ };
845+
711846// state of self-speculation (simple implementation, not ngram-map)
712847struct common_speculative_state_ngram_simple : public common_speculative_state {
713848 common_ngram_simple_config config;
@@ -1057,13 +1192,13 @@ common_speculative * common_speculative_init(
10571192 llama_context * ctx_dft_dec = nullptr ;
10581193
10591194 if (params.model_dft ) {
1060- if (params.eagle3 ) {
1195+ if (params.eagle3 || params. dflash ) {
10611196 llama_context_params params_enc = params.cparams_dft ;
10621197 params_enc.target_model = nullptr ;
10631198 params_enc.embeddings = true ;
10641199 ctx_dft_enc = llama_init_from_model (params.model_dft , params_enc);
10651200 if (!ctx_dft_enc) {
1066- LOG_ERR (" failed to create EAGLE3 encoder context\n " );
1201+ LOG_ERR (" failed to create %s draft model encoder context\n " , params. eagle3 ? " EAGLE3 " : " DFlash " );
10671202 return nullptr ;
10681203 }
10691204
@@ -1072,13 +1207,13 @@ common_speculative * common_speculative_init(
10721207 params_dec.embeddings = true ;
10731208 ctx_dft_dec = llama_init_from_model (params.model_dft , params_dec);
10741209 if (!ctx_dft_dec) {
1075- LOG_ERR (" failed to create EAGLE3 decoder context\n " );
1210+ LOG_ERR (" failed to create %s draft model decoder context\n " , params. eagle3 ? " EAGLE3 " : " DFlash " );
10761211 return nullptr ;
10771212 }
10781213 } else {
10791214 ctx_dft = llama_init_from_model (params.model_dft , params.cparams_dft );
10801215 if (ctx_dft == nullptr ) {
1081- LOG_ERR (" %s " , " failed to create draft context\n " );
1216+ LOG_ERR (" failed to create draft model context\n " );
10821217 return nullptr ;
10831218 }
10841219 }
@@ -1089,6 +1224,7 @@ common_speculative * common_speculative_init(
10891224 {
10901225 bool has_draft = !params.mparams_dft .path .empty ();
10911226 bool has_draft_eagle3 = params.eagle3 ;
1227+ bool has_draft_dflash = params.dflash ;
10921228
10931229 bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE );
10941230 bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE );
@@ -1131,6 +1267,8 @@ common_speculative * common_speculative_init(
11311267 if (has_draft) {
11321268 if (has_draft_eagle3) {
11331269 configs.push_back (common_speculative_config (COMMON_SPECULATIVE_TYPE_EAGLE3 , params));
1270+ } else if (has_draft_dflash) {
1271+ configs.push_back (common_speculative_config (COMMON_SPECULATIVE_TYPE_DFLASH , params));
11341272 } else {
11351273 configs.push_back (common_speculative_config (COMMON_SPECULATIVE_TYPE_DRAFT , params));
11361274 }
@@ -1163,6 +1301,14 @@ common_speculative * common_speculative_init(
11631301 ));
11641302 break ;
11651303 }
1304+ case COMMON_SPECULATIVE_TYPE_DFLASH : {
1305+ impls.push_back (std::make_unique<common_speculative_state_dflash>(config.type ,
1306+ /* .ctx_tgt = */ ctx_tgt,
1307+ /* .ctx_dft_enc = */ ctx_dft_enc,
1308+ /* .ctx_dft_dec = */ ctx_dft_dec
1309+ ));
1310+ break ;
1311+ }
11661312 case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE : {
11671313 common_ngram_map ngram_map = get_common_ngram_map (config);
11681314
0 commit comments