Skip to content

spec: support eagle3 for qwen3.5 & 3.6#24593

Merged
ggerganov merged 6 commits into
ggml-org:masterfrom
ruixiang63:eagle3_qwen3.6
Jun 19, 2026
Merged

spec: support eagle3 for qwen3.5 & 3.6#24593
ggerganov merged 6 commits into
ggml-org:masterfrom
ruixiang63:eagle3_qwen3.6

Conversation

@ruixiang63

@ruixiang63 ruixiang63 commented Jun 13, 2026

Copy link
Copy Markdown
Contributor

Overview

Support third-party eagle3 draft models for Qwen3.5 & 3.6

Fix issue: #24541

Running steps:

  • Compile based on this PR
cmake -B build -DGGML_CUDA=ON
cmake --build build --config Release
  • Convert HF models to GGUF
TARGET_MODEL_HF="Qwen/Qwen3.6-27B"
TARGET_MODEL_GGUF="Qwen3.6-27B.gguf"
EAGLE3_MODEL_HF="Qwen3.6-27B-PRISM-EAGLE3/full"
EAGLE3_MODEL_GGUF="Qwen3.6-27B-PRISM-EAGLE3.gguf"


python convert_hf_to_gguf.py \
   "${TARGET_MODEL_HF}" \
   --outtype bf16 \
   --outfile "${TARGET_MODEL_GGUF}"

python convert_hf_to_gguf.py \
   "${EAGLE3_MODEL_HF}" \
   --outtype bf16 \
   --target-model-dir "${TARGET_MODEL_HF}" \
   --outfile "${EAGLE3_MODEL_GGUF}"
  • Start llama-server, e.g.
./build/bin/llama-server \
   -m Qwen3.6-27B.gguf_Q4_K_M.gguf \
   -md Qwen3.6-27B-PRISM-EAGLE3.gguf_Q4_K_M.gguf \
   --spec-type draft-eagle3 \
   --spec-draft-n-max 3 \
   --spec-draft-p-min 0.0 \
   -np 1 \
   -c 92288 --port 8080 -ngl 99 -fa on \
   --jinja

Performance on DGX Spark with SpeedBench:

python tools/server/bench/speed-bench/speed_bench.py \
  --url localhost:8080 \
  --bench qualitative \
  --category all \
  --osl 256 \
  --concurrency 1 \
  --limit 2 \
  • Qwen3.6-27B (Q4_K_M) Baseline
Summary (elapsed=559.27s)
category       samples  avg_prompt_t/s  avg_pred_t/s  avg_latency  accept_rate
-------------  -------  --------------  ------------  -----------  -----------
coding         2        423.46          12.59         20.760s      n/a        
humanities     2        147.33          12.56         31.123s      n/a        
math           2        44.41           12.56         20.521s      n/a        
qa             2        92.88           12.56         19.320s      n/a        
rag            2        676.74          12.54         31.599s      n/a        
reasoning      2        74.70           12.56         20.644s      n/a        
stem           2        44.25           12.56         20.490s      n/a        
writing        2        561.37          12.52         21.757s      n/a        
multilingual   2        180.25          12.56         20.751s      n/a        
summarization  2        144.48          12.56         20.753s      n/a        
roleplay       2        322.04          12.56         51.916s      n/a        
overall        22       246.54          12.56         25.421s      n/a     
  • with Qwen3.6-27B-PRISM-EAGLE3 (Q4_K_M)
Summary (elapsed=327.68s)
category       samples  avg_prompt_t/s  avg_pred_t/s  avg_latency  accept_rate
-------------  -------  --------------  ------------  -----------  -----------
coding         2        172.88          23.30         12.086s      0.5350     
humanities     2        131.91          21.75         18.888s      0.4550     
math           2        41.62           23.63         10.991s      0.5579     
qa             2        80.33           20.40         12.965s      0.4301     
rag            2        420.24          24.32         17.154s      0.6101     
reasoning      2        64.81           23.64         11.117s      0.5579     
stem           2        41.84           23.61         10.951s      0.5579     
writing        2        487.46          21.59         13.348s      0.4895     
multilingual   2        158.60          17.08         15.405s      0.3192     
summarization  2        128.43          23.52         11.276s      0.5475     
roleplay       2        286.61          22.55         29.652s      0.5090     
overall        22       183.16          22.31         14.894s      0.5015   
  • Comparison
python tools/server/bench/speed-bench/speed_bench_compare.py \
  --baseline Qwen3.6-27B-baseline.json --speculative Qwen3.6-27B-eagle3.json

Comparison: baseline=Qwen3.6-27B-baseline.json speculative=Qwen3.6-27B-eagle3.json

category       base_avg_pred_t/s  spec_avg_pred_t/s  decode_speedup  base_avg_latency  spec_avg_latency  latency_speedup  accept_rate
-------------  -----------------  -----------------  --------------  ----------------  ----------------  ---------------  -----------
coding         12.59              23.30              1.85x           20.760s           12.086s           1.72x            0.5350     
humanities     12.56              21.75              1.73x           31.123s           18.888s           1.65x            0.4550     
math           12.56              23.63              1.88x           20.521s           10.991s           1.87x            0.5579     
qa             12.56              20.40              1.62x           19.320s           12.965s           1.49x            0.4301     
rag            12.54              24.32              1.94x           31.599s           17.154s           1.84x            0.6101     
reasoning      12.56              23.64              1.88x           20.644s           11.117s           1.86x            0.5579     
stem           12.56              23.61              1.88x           20.490s           10.951s           1.87x            0.5579     
writing        12.52              21.59              1.72x           21.757s           13.348s           1.63x            0.4895     
multilingual   12.56              17.08              1.36x           20.751s           15.405s           1.35x            0.3192     
summarization  12.56              23.52              1.87x           20.753s           11.276s           1.84x            0.5475     
roleplay       12.56              22.55              1.80x           51.916s           29.652s           1.75x            0.5090     
overall        12.56              22.31              1.78x           25.421s           14.894s           1.71x            0.5015    

Deferred boundary in eagle3 across context checkpoints for hybrid models

Eagle's draft trails the target by one position (input at P is (token[P+1], g_embd[P])), and g_embd from eagle3 encoder is a transient target activation not stored in any KV cache. On recurrent/hybrid targets a checkpoint is single-position, so on restore the draft is at pos_max-1, the target resumes at pos_max+1, and g_embd[pos_max] is lost → the draft can't fill pos_maxllama_decode(ctx_dft) fails (rc=-1).

Solution: stash that one g_embd[pos_max] row with the checkpoint and restore it on load, so the existing bridge fills the boundary and the draft keeps full context.

It's the cheapest fix: recomputing g_embd needs an extra target decode and eagle3 encoder (and is impossible on a restored recurrent state), full re-processing re-runs the prefix, and re-seeding the draft loses context — whereas stashing one row (~20 KB/checkpoint, recurrent/hybrid only) adds no decode and no quality loss.

Example (pos_max = 13)

At checkpoint creation:

  • Target KV covers positions 0..13.
  • Draft KV covers 0..12 — one behind, because decoding draft pos 13 needs token[14], which isn't available yet (the deferred boundary). g_embd[13] exists only as a transient activation at this moment.

On restore:

  • Draft comes back at 12; target resumes at 14 (pos_max+1, since the checkpoint already holds 13). Reprocessing 14, 15, … produces g_embd[14], g_embd[15], … — but not g_embd[13] (13 isn't reprocessed, and it was never saved).
  • Draft at 12 tries to decode 13: it has token[14] but not g_embd[13] → can't. So the draft batch jumps 12 → 14, leaving a hole at 13 → llama_decode(ctx_dft) returns rc=-1.
  • With the fix: we stashed g_embd[13] in the checkpoint, restore it, and the bridge decodes draft pos 13 from (token[14], g_embd[13]) → draft goes 12 → 13 → 14 …, contiguous, full context preserved.

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: YES, mainly for code review and debugging.

@AbdulrahmanHashem

Copy link
Copy Markdown

GGUFs are in demand XD

@ruixiang63

Copy link
Copy Markdown
Contributor Author

GGUFs are in demand XD

You can convert it using the convert_hf_to_gguf.py script, similar to other eagle3 models.

Comment thread src/llama-context.cpp Outdated
@MikeLP

MikeLP commented Jun 13, 2026

Copy link
Copy Markdown

How much faster is this method than MTP?

@AbdulrahmanHashem

AbdulrahmanHashem commented Jun 13, 2026

Copy link
Copy Markdown

i converted the Ex0bit/Qwen3.6-27B-PRISM-EAGLE3 to GGUF and ran it with Qwen3.6-27B-Q4_K_M_MTP.gguf
and

./build_9626/bin/llama-server \
--verbosity 4 \
-m /home/abdulrahman/Personal/Programs/llama/models/Qwen/Qwen3.6-27B-Q4_K_M_MTP.gguf \
-md /home/abdulrahman/Personal/Programs/llama/models/Qwen/Qwen3.6-27B-EAGLE3-PRISM.gguf \
--temperature 0.6 --top-p 0.95 --top-k 20 --min-p 0.00 --presence-penalty 0.0 --repeat-penalty 1.0 \
--jinja -fa on \
--spec-type draft-eagle3 \
--spec-draft-n-max 8 \
--spec-draft-p-min 0.5 \
-c 100000 -np 1 -ngl 999

and i'm getting about 16 t/s knowing that base is 22.5 t/s
my guess is that model and draft aren't vibing together
but i thought to say something maybe it's a bug.

i'll be trying rjmalagon/specdrift-qwen3.6-27b-eagle3 shortly.

Update: couldn't convert rjmalagon/specdrift-qwen3.6-27b-eagle3
as it stands there is no compatible eagle3 for unsloth/Qwen3.6-27B-Q4_K_M_MTP.gguf

@ruixiang63

Copy link
Copy Markdown
Contributor Author

How much faster is this method than MTP?

It’s hard to define “how much faster,” since speculative decoding speedups depend heavily on the use cases and prompts. Different models and methods perform better in different scenarios.

We have introduced SpeedBench(#23869) in llama.cpp to systematically measure speedups and compare different approaches. Feel free to try it and benchmark it :)

@ruixiang63

Copy link
Copy Markdown
Contributor Author

"specdrift-qwen3.6-27b-eagle3" is a different eagle3 model AFAIK, see https://vllm.ai/blog/2026-05-26-eagle-3-1

@AbdulrahmanHashem

Copy link
Copy Markdown

"specdrift-qwen3.6-27b-eagle3" is a different eagle3 model AFAIK, see https://vllm.ai/blog/2026-05-26-eagle-3-1

yes i just discovered that it won't convert and as of now i cannot find a compatible eagle3 for unsloth/Qwen3.6-27B-Q4_K_M_MTP.gguf
if you find one please let me know, and thanks for the amazing work ^_^.

@ggerganov ggerganov self-assigned this Jun 14, 2026
@gelim

gelim commented Jun 14, 2026

Copy link
Copy Markdown
Contributor

thanks, gave a try to this branch with Ex0bit GGUF generated as per EAGLE3 PR doc. Text generation crashes instantly like reported in #24541

After patching qwen35.cpp as advised, only single turn generation working but with very bad perf like mentioned before. Sorry no nice table from SpeedBench, but on my Tesla V100, Qwen3.6 27B performs (without MTP) at tg 27 tok/sec usually, but here I get from 15 to 17 tok/sec (promtp "generate pi estimator code in python").

if writing something as second turn, this crash occurs:

3.21.236.241 I slot launch_slot_: id  0 | task 47 | processing task, is_child = 0
3.21.236.272 I slot update_slots: id  0 | task 47 | Checking checkpoint with [10, 10] against 21...
3.21.289.034 W slot update_slots: id  0 | task 47 | restored context checkpoint (pos_min = 10, pos_max = 10, n_tokens = 11, n_past = 11, size = 149.665 MiB)
3.21.407.379 E init: the tokens of sequence 0 in the input batch have inconsistent sequence positions:
 - the last position stored in the memory module of the context (i.e. the KV cache) for sequence 0 is X = 9
 - the tokens for sequence 0 in the input batch have a starting position of Y = 11
 it is required that the sequence positions remain consecutive: Y = X + 1
3.21.407.388 E decode: failed to initialize batch
3.21.407.390 E llama_decode: failed to decode, ret = -1
3.21.407.392 E process: llama_decode(ctx_dft) failed rc=-1 (n_tokens=20, ubatch_pos[0]=11)
3.21.407.396 E srv  update_slots: failed to process speculative batch
3.21.504.498 E init: the tokens of sequence 0 in the input batch have inconsistent sequence positions:
 - the last position stored in the memory module of the context (i.e. the KV cache) for sequence 0 is X = 9
 - the tokens for sequence 0 in the input batch have a starting position of Y = 31
 it is required that the sequence positions remain consecutive: Y = X + 1
3.21.504.506 E decode: failed to initialize batch
3.21.504.507 E llama_decode: failed to decode, ret = -1
3.21.504.509 E process: llama_decode(ctx_dft) failed rc=-1 (n_tokens=21, ubatch_pos[0]=32)
3.21.504.512 E srv  update_slots: failed to process speculative batch
/home/user/llama.cpp-eagl3_qwen3.6/tools/server/server-context.cpp:3275: fatal error - please provide logs and repro in https://github.com/ggml-org/llama.cpp/pull/20277

3.21.565.303 E init: the tokens of sequence 0 in the input batch have inconsistent sequence positions:

@ruixiang63

Copy link
Copy Markdown
Contributor Author

After patching qwen35.cpp as advised, only single turn generation working but with very bad perf like mentioned before. Sorry no nice table from SpeedBench, but on my Tesla V100, Qwen3.6 27B performs (without MTP) at tg 27 tok/sec usually, but here I get from 15 to 17 tok/sec (promtp "generate pi estimator code in python").

Sorry, I can't reproduce this issue. Did you build on top of this PR directly? If you only patch qwen35.cpp, this is not enough. This PR introduced fixes for this.
image

@gelim

gelim commented Jun 14, 2026

Copy link
Copy Markdown
Contributor

On top of your branch yes

Comment thread src/llama-context.cpp Outdated
Comment thread tools/server/server-context.cpp Outdated
@ruixiang63 ruixiang63 marked this pull request as draft June 15, 2026 15:44
@ruixiang63 ruixiang63 force-pushed the eagle3_qwen3.6 branch 3 times, most recently from ebf46af to fd50e23 Compare June 17, 2026 20:10
@ruixiang63 ruixiang63 marked this pull request as ready for review June 18, 2026 14:19
@ruixiang63 ruixiang63 requested a review from a team as a code owner June 18, 2026 14:19
Comment thread common/speculative.h Outdated
ruixiang63 and others added 2 commits June 18, 2026 17:34
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
@ruixiang63

Copy link
Copy Markdown
Contributor Author

@ggerganov Fixed, let me know if any other changes are needed :)

Comment thread common/common.h Outdated
@ggerganov ggerganov added the merge ready A maintainer can use this label to indicate that they consider the changes final and ready to merge. label Jun 18, 2026
@ruixiang63

Copy link
Copy Markdown
Contributor Author

Looks all checks passed. @ggerganov @CISC

Comment thread common/speculative.cpp
@ggerganov ggerganov merged commit b14e3fb into ggml-org:master Jun 19, 2026
1 check passed
@cb88

cb88 commented Jun 19, 2026

Copy link
Copy Markdown

Is this expected to work with -sm tensor like MTP now does?

I get an assert on 2xMI50

0.22.383.898 W set_sampler: backend sampling not supported with SPLIT_MODE_TENSOR; using CPU
0.22.383.900 W common_speculative_impl_draft_eagle3: backend offload failed for seq_id=0; using CPU sampler
0.22.383.904 W common_speculative_impl_draft_eagle3: backend offload failed for seq_id=1; using CPU sampler
0.22.383.904 W common_speculative_impl_draft_eagle3: backend offload failed for seq_id=2; using CPU sampler
0.22.383.905 W common_speculative_impl_draft_eagle3: backend offload failed for seq_id=3; using CPU sampler
/home/cb88/llama.cpp/ggml/src/ggml-backend-meta.cpp:728: GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_1) failed

Note that with -sm layer I do get a speed up fro about 16t/s to 24t/s or so with 2xMI50 and Q8_K_M model + Q8 draft

[cb88@MZ31AR0 ~]$ cat llama-eagle.sh
HIP_VISIBLE_DEVICES=0,1  ./llama.cpp/build/bin/llama-server --host 0.0.0.0 \
-fa on -dio --api-key llama \
-sm tensor \
--spec-type draft-eagle3 \
 --spec-draft-n-max 3 \
--spec-draft-p-min 0.0 \
-c 131072 \
-lv 4 \
--jinja \
-m $@

@ruixiang63

Copy link
Copy Markdown
Contributor Author

Is this expected to work with -sm tensor like MTP now does?

I get an assert on 2xMI50

0.22.383.898 W set_sampler: backend sampling not supported with SPLIT_MODE_TENSOR; using CPU 0.22.383.900 W common_speculative_impl_draft_eagle3: backend offload failed for seq_id=0; using CPU sampler 0.22.383.904 W common_speculative_impl_draft_eagle3: backend offload failed for seq_id=1; using CPU sampler 0.22.383.904 W common_speculative_impl_draft_eagle3: backend offload failed for seq_id=2; using CPU sampler 0.22.383.905 W common_speculative_impl_draft_eagle3: backend offload failed for seq_id=3; using CPU sampler /home/cb88/llama.cpp/ggml/src/ggml-backend-meta.cpp:728: GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_1) failed

Note that with -sm layer I do get a speed up fro about 16t/s to 24t/s or so with 2xMI50 and Q8_K_M model + Q8 draft

Check this comment #18039 (comment)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

examples merge ready A maintainer can use this label to indicate that they consider the changes final and ready to merge. model Model specific server

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants