spec: support MTP #6

Open
am17an wants to merge 4 commits into gg/spec-parallel from gg-mtp-rebase

Conversation

@am17an (Owner) commented May 11, 2026

I have removed the partial rollback changes and isolated the changes to just the Qwen models. Things to work out:

  • generic MTP loading (support both separate GGUF + grafted onto same GGUF?)
  • vision inputs
  • n_seq > 1
  • partial rollback
  • unaccounted memory
  • prefill speeds

Note that partial rollback is extremely important for the speed-up here; for the MoE model there is actually a slowdown with MTP on this branch.

@ggerganov

After the refactoring, all the state management of the draft context is performed outside of common/speculative - i.e. in the server_context. So all the logic for llama_memory_seq_rm can be removed - it is already taken into account in the server:

diff --git a/common/speculative.cpp b/common/speculative.cpp
index ef13edd34..95329b8a6 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -592,19 +592,6 @@ struct common_speculative_state_mtp : public common_speculative_impl {
         auto & draft_tokens = *dp.result;
         draft_tokens.clear();
 
-        if (last_n_drafted[seq_id] > 0) {
-            const int32_t n_to_drop = (int32_t) last_n_drafted[seq_id] - 1;
-            if (n_to_drop > 0) {
-                const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
-                if (pos_max >= 0) {
-                    const llama_pos drop_from = pos_max - n_to_drop + 1;
-                    llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1);
-                }
-            }
-            last_n_drafted[seq_id]  = 0;
-            last_n_accepted[seq_id] = 0;
-        }
-
         // Effective draft length: min(global cap, per-sequence override).
         int32_t n_max = std::max(1, params.n_max);
         if (dp.n_max > 0) {
@@ -673,32 +660,9 @@ struct common_speculative_state_mtp : public common_speculative_impl {
             cond_tok = best;
             ++pos;
         }
-
-        last_n_drafted[seq_id] = (uint16_t) draft_tokens.size();
     }
 
     void accept(llama_seq_id seq_id, uint16_t n_accepted) override {
-        GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < last_n_drafted.size());
-
-        auto * ctx_dft = this->params.ctx_dft;
-
-        const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
-        const int32_t   n_drafted_last = (int32_t) last_n_drafted[seq_id];
-
-        const int32_t n_to_drop = std::max(0, n_drafted_last - (int32_t) n_accepted - 1);
-
-        if (pos_max < 0) {
-            last_n_accepted[seq_id] = (int32_t) n_accepted;
-            return;
-        }
-
-        if (n_to_drop > 0) {
-            const llama_pos drop_from = pos_max - n_to_drop + 1;
-            llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1);
-        }
-
-        last_n_drafted [seq_id] = 0;
-        last_n_accepted[seq_id] = (int32_t) n_accepted;
     }
 };
 

@ggerganov

Give me ~1 hour and I'll open a PR here to simplify (WIP: https://github.com/ggml-org/llama.cpp/tree/gg/spec-mtp-experiments)

Comment thread src/models/qwen35-mtp.cpp
@ggerganov

note that partial rollback is extremely important for the speed-up here

In the partial rollback implementation, the accepted batch is not re-evaluated with the draft context, correct? I think this will narrow the difference a bit, though not very sure by how much.

Here are the mtp-bench.py results on M2 Ultra:

Qwen3.6-27B:

  • No MTP: "wall_s_total": 71.63
  • MTP (ggml-org:gg/spec-mtp-experiments): "wall_s_total": 43.11
  • MTP (am17an:mtp-clean): "wall_s_total": 41.67

Qwen3.6-35B-A3B:

  • No MTP: "wall_s_total": 20.3
  • MTP (ggml-org:gg/spec-mtp-experiments): "wall_s_total": 16.15
  • MTP (am17an:mtp-clean): "wall_s_total": 16.22

@am17an (Owner, Author) commented May 11, 2026

On my DGX Spark (patched to add a draft acceptance loop):

Qwen3.5-35B-A3B (using spec-draft-n-max 2):
No-MTP: "wall_s_total": 27.68
gg-mtp-rebase: "wall_s_total": 26.05
mtp-clean: "wall_s_total": 22.19

Qwen3.5-27B (using spec-draft-n-max 3):
No-MTP: "wall_s_total": 201.10
gg-mtp-rebase: 97.83
mtp-clean: 81.23

@am17an (Owner, Author) commented May 11, 2026

Another thing: mtp-clean doesn't use pinned memory like this PR does, so its wall time might be slightly inflated.

@am17an (Owner, Author) commented May 11, 2026

Basically, at low acceptance rates (< 0.5) the speed difference is going to be much larger. From anecdotal usage, with this PR I sometimes drop to 9 toks/sec when doing real coding work, whereas with partial rollback I never drop below 14 toks/sec even when acceptance is low. You can perhaps try it out; I felt the difference is quite real.

@ggerganov

Did you use this branch or #7 ?

@am17an (Owner, Author) commented May 11, 2026

I used this branch, just saw #7

@am17an (Owner, Author) commented May 11, 2026

Just tried #7 as well,

Qwen3.6 27B - "wall_s_total": 100.33
Qwen3.6 35BA3B - "wall_s_total": 26.82

Somehow acceptance rates are suspiciously high, maybe some accounting error

  code_python        pred= 192 draft= 139 acc= 138 rate=0.993 tok/s=19.5
  code_cpp           pred= 192 draft= 129 acc= 127 rate=0.985 tok/s=16.7
  explain_concept    pred= 192 draft= 118 acc= 117 rate=0.992 tok/s=13.7
  summarize          pred=  55 draft=  35 acc=  35 rate=1.000 tok/s=16.0
  qa_factual         pred= 178 draft= 109 acc= 107 rate=0.982 tok/s=13.8
  translation        pred=  23 draft=  13 acc=  12 rate=0.923 tok/s=12.4
  creative_short     pred= 192 draft= 109 acc= 105 rate=0.963 tok/s=12.9
  stepwise_math      pred= 192 draft= 130 acc= 130 rate=1.000 tok/s=16.6
  long_code_review   pred= 192 draft= 119 acc= 115 rate=0.966 tok/s=13.1

For reference, in mtp-clean they are:

  code_python        pred= 192 draft= 153 acc= 140 rate=0.915 tok/s=21.0
  code_cpp           pred= 192 draft= 171 acc= 134 rate=0.784 tok/s=17.8
  explain_concept    pred= 192 draft= 180 acc= 130 rate=0.722 tok/s=17.3
  summarize          pred=  55 draft=  54 acc=  36 rate=0.667 tok/s=15.9
  qa_factual         pred= 177 draft= 180 acc= 116 rate=0.644 tok/s=16.5
  translation        pred=  22 draft=  24 acc=  13 rate=0.542 tok/s=15.3
  creative_short     pred= 192 draft= 195 acc= 126 rate=0.646 tok/s=16.5
  stepwise_math      pred= 192 draft= 162 acc= 137 rate=0.846 tok/s=19.9
  long_code_review   pred= 192 draft= 186 acc= 129 rate=0.694 tok/s=15.8

@ggerganov

Somehow acceptance rates are suspiciously high, maybe some accounting error

With the p_min logic that I added, we don't draft low-prob tokens, so the acceptance is very high.
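
For illustration only, a rough sketch of what such a cutoff looks like (the helper name is made up, this is not the actual common/speculative code): drafting stops as soon as the draft model's top-token probability drops below p_min, so only high-confidence tokens are ever proposed to the target model, which is why the measured acceptance rate sits close to 1.0.

// Minimal sketch, assuming a hypothetical helper draft_sample_top() that returns the
// draft model's most likely next token together with its probability.
std::vector<llama_token> draft;
for (int i = 0; i < n_draft_max; ++i) {
    float p_top = 0.0f;
    const llama_token tok = draft_sample_top(ctx_dft, &p_top); // hypothetical helper

    // stop drafting as soon as the draft model is no longer confident;
    // only high-probability tokens are ever proposed, hence the ~1.0 acceptance rate
    if (p_top < p_min) {
        break;
    }
    draft.push_back(tok);
}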

@ggerganov

You can observe the accepted drafts with the LLAMA_TRACE=1 env variable.

@am17an (Owner, Author) commented May 11, 2026

I think the p_min logic will also sample at every step, causing a D2H logit transfer - so it may increase overall time (since the draft model is extremely lightweight). Not sure if this is right, but p_min does add some time.
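
To make the concern concrete, a rough sketch (batch_one is a placeholder variable, not the actual code path) of why a per-step p_min check implies the logits reaching the host for every drafted token:

// Sketch only: evaluating p_min per drafted token means the logits for that token
// have to be available on the host after every draft decode (a device-to-host copy per step).
for (int i = 0; i < n_draft_max; ++i) {
    llama_decode(ctx_dft, batch_one);                         // one-token draft step on the device
    const float * logits = llama_get_logits_ith(ctx_dft, 0);  // logits for the drafted token, now on the host
    // ... host-side softmax + comparison against p_min ...
    (void) logits;
}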

@ggerganov

not sure if this is right, but p_min does add some time

Yes, I'm also not sure. On Mac it is always useful for some reason. On CUDA it sometimes helps, sometimes not. In any case, it can be adjusted with the --spec-draft-p-min argument.

Regarding the partial rollback - it does bring a noticeable benefit on CUDA. But I still don't see a good way to support it cleanly. Among other drawbacks, the compute graph is also no longer static. The logic is not compatible with ngram speculative decoding because it uses long drafts of ~64 tokens which still need to be checkpointed. And for some reason that I still don't understand, it does not seem to help much on Mac.

Comment thread common/speculative.cpp
Comment on lines +480 to +484
// TODO: how to make it work with vision tokens?
if (batch_in.token == nullptr || batch_in.embd != nullptr) {
pending_pos[seq_id] = -1;
return true;
}

I'm not really sure what the correct way is to process the image embeddings with the MTP context. In any case, vision MTP seems to already work to a good extent.

Here I ask it to OCR 100 random integers without speculative decoding and with MTP:

  • Without spec decoding: (screenshot)
  • With MTP: (screenshot)

With MTP it is ~2x faster, which means the MTP context "knows" about the integers in some way. But at the same time, I'm pretty sure that the current way of processing is not 100% correct, because the inp->tokens tensor in the MTP graph is being used with stale data when the input batch has image embeddings and no tokens.

I think we will figure this out later - not super important atm.

@ggerganov

@am17an I think the changes are good overall.

On my end, I will continue on top of ggml-org#22838 to support specifying multiple speculative decoding types like this:

--spec-type ngram-mod,mtp

Should be a simple change, and when ready, I will proceed with merging ggml-org#22838.
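
For illustration, the argument-parsing side of such a change could look roughly like this (a sketch under assumptions, not the actual patch):

// Illustrative sketch: split a comma-separated --spec-type value into individual types.
#include <sstream>
#include <string>
#include <vector>

std::vector<std::string> parse_spec_types(const std::string & value) {
    std::vector<std::string> types;                 // e.g. {"ngram-mod", "mtp"}
    std::stringstream ss(value);
    for (std::string t; std::getline(ss, t, ','); ) {
        types.push_back(t);
    }
    return types;
}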

@am17an (Owner, Author) commented May 11, 2026

Regarding the partial rollback - it does bring a noticeable benefit on CUDA. But I still don't see a good way to support it cleanly. Among other drawbacks, the compute graph is also no longer static. The logic is not compatible with ngram speculative decoding because it uses long drafts of ~64 tokens which still need to be checkpointed. And for some reason that I still don't understand, it does not seem to help much on Mac.

We can perhaps just enable this option when MTP is enabled as a spec mode for hybrid models. I think we can also make the compute graph static by only doing the rollback when drafts.size() == n_rs_seq, as sketched below. I'm pushing for this change because it's how all other frameworks have implemented MTP for Qwen models, so I don't want llama.cpp to be slower in this regard when we already have the code for it.
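
To make the suggestion concrete, a rough sketch of that guard (the helper names are hypothetical, not llama.cpp symbols): the rollback path is only taken when the draft length matches the fixed number of stored recurrent states, so that path always has the same shape and the compute graph stays static.

// Illustrative sketch only - rollback_recurrent_state() and reeval_accepted() are hypothetical.
if (drafts.size() == (size_t) n_rs_seq) {
    rollback_recurrent_state(ctx, seq_id, n_accepted);   // fixed-shape fast path
} else {
    reeval_accepted(ctx, seq_id, accepted_tokens);       // fall back to re-evaluating accepted tokens
}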

@ggerganov

We can iterate on it, but I don't think we can merge MTP directly with the partial rollback changes. These changes have to be in a follow-up PR because they affect a lot of logic: ggml, llama.cpp recurrent state, server logic, backend code. We have to merge something that is solid and works across all hardware, so we can in parallel continue to add other speculative decoding approaches. The partial rollback will be a potential optimization if we figure out how to do it cleanly.

@am17an (Owner, Author) commented May 11, 2026

Yes, agreed - for other models it is not even required, so first we should get MTP into master and make it stable. As it stands, there are issues with GGUF loading/unloading and general memory issues that need to be fixed. I will keep the partial rollback branch up-to-date so people are free to use it.

So the plan is that you merge ggml-org#22838, and then I rebase github.com/ggml-org/pull/22673 on top of that with the changes here. And then we can probably have another round of review regarding the other parts of the code?

@ggerganov

I will keep the partial rollback branch up-to-date so people are free to use it.

Ok sounds good.

So the plan is that you merge ggml-org#22838, and then I rebase github.com/ggml-org#22673 on top of that with the changes here. And then we can probably have another round of review regarding the other parts of the code?

Yes. I haven't looked at the prompt prefill at all yet, so I'm not sure what the status is there. I think this branch here should perform a bit better thanks to pinned memory. The GGUF loading is probably the most important thing to figure out, in terms of making it user friendly.

@Zorgonatis

Commenting just on the GGUF MTP approach: as a user I believe it would be best to align with the same packaging principles as mmproj and other spec decoding implementations (eagle, dflash, etc.) - keep optional model features in their own external GGUF for maximum flexibility at runtime.

@syzhizhu

--split-mode tensors becomes invalid and affects MTP speed. Removing --split-mode tensors restores normal MTP speed.

However, previously --split-mode tensors was available.

@ggerganov

The argument is --split-mode tensor, not --split-mode tensors

@syzhizhu

The argument is --split-mode tensor, not --split-mode tensors
Typo, but the problem still exists.

@homemdesgraca commented May 11, 2026

Sorry if someone already talked about this, but:

  • The model, along with its MTP draft model, loads mostly fine with some -fitt headroom (even though it wasn't needed on the old MTP branch).
  • llama-server crashes between interactions if it tries to use checkpoints, but works fine without checkpoints (--ctx-checkpoints 0).
  • The speed on CUDA is better than on the older MTP branch (Qwen3.6-35B-A3B; about 18% improvement).

Flags used to reproduce llama-server checkpoint crash (fixed with --ctx-checkpoints 0):

./llama.cpp/build/bin/llama-server --host 0.0.0.0 --port 5000 \
                  -m ./Models/Qwen3.6-35B-A3B-MTP-UD-Q3_K_XL.gguf --fit on \
                  -a "Qwen3.6-35B-A3B" -c 150000 \
                  --top-k 20 --top-p 0.95 --min-p 0 --repeat-penalty 1.0 --presence-penalty 0.0 \
                  -fa on --temp 0.6 -ctk q4_0 -ctv q4_0 --batch-size 4096 \
                  --reasoning off --no-mmap --jinja -t 6 -fitt 1024 \
                  --spec-draft-n-max 2 --spec-type mtp --ctx-checkpoints 32
Logs of llama-server crashing with checkpoints
↪ fish mtp-qwen-35b.fish 
ggml_cuda_init: found 1 CUDA devices (Total VRAM: 11902 MiB):
  Device 0: NVIDIA GeForce RTX 3060, compute capability 8.6, VMM: yes, VRAM: 11902 MiB
main: n_parallel is set to auto, using n_parallel = 4 and kv_unified = true
build_info: b9103-bc2db938f
system_info: n_threads = 6 (n_threads_batch = 6) / 12 | CUDA : ARCHS = 860 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 | 
Running without SSL
init: using 11 threads for HTTP server
start: binding port with default address family
main: loading model
srv    load_model: loading model './Models/Qwen3.6-35B-A3B-MTP-UD-Q3_K_XL.gguf'
common_init_result: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on
common_params_fit_impl: getting device memory data for initial parameters:
common_memory_breakdown_print: | memory breakdown [MiB] | total    free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 3060)   | 11902 = 10598 + (17963 = 16395 +    1075 +     493) +      -16659 |
common_memory_breakdown_print: |   - Host               |                    816 =   515 +       0 +     301                |
common_params_fit_impl: projected to use 17963 MiB of device memory vs. 10598 MiB of free device memory
common_params_fit_impl: cannot meet free memory target of 1024 MiB, need to reduce device memory by 8388 MiB
common_params_fit_impl: context size set by user to 150000 -> no change
common_params_fit_impl: getting device memory data with all MoE tensors moved to system memory:
common_memory_breakdown_print: | memory breakdown [MiB] | total    free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 3060)   | 11902 = 10598 + ( 3629 =  1961 +    1075 +     592) +       -2324 |
common_memory_breakdown_print: |   - Host               |                  15250 = 14949 +       0 +     301                |
common_params_fit_impl: with only dense weights in device memory there is a total surplus of 5945 MiB
common_params_fit_impl: id=0, target=9574 MiB
common_memory_breakdown_print: | memory breakdown [MiB] | total    free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 3060)   | 11902 = 10852 + (  894 =     0 +       0 +     894) +         155 |
common_memory_breakdown_print: |   - Host               |                  18299 = 16910 +    1075 +     313                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer= 0, n_part= 0, overflow_type=4, mem=   894 MiB
common_params_fit_impl: filling dense-only layers back-to-front:
common_memory_breakdown_print: | memory breakdown [MiB] | total    free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 3060)   | 11902 = 10600 + ( 3961 =  2293 +    1075 +     592) +       -2658 |
common_memory_breakdown_print: |   - Host               |                  14918 = 14617 +       0 +     301                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer=42, n_part=41, overflow_type=4, mem=  3961 MiB
common_params_fit_impl: set ngl_per_device[0].n_layer=42
common_params_fit_impl:   - CUDA0 (NVIDIA GeForce RTX 3060): 42 layers,   3961 MiB used,   6637 MiB free
common_params_fit_impl: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:
common_memory_breakdown_print: | memory breakdown [MiB] | total    free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 3060)   | 11902 = 10603 + (17963 = 16395 +    1075 +     493) +      -16664 |
common_memory_breakdown_print: |   - Host               |                    816 =   515 +       0 +     301                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer=42, n_part= 0, overflow_type=4, mem= 17963 MiB
common_memory_breakdown_print: | memory breakdown [MiB] | total    free    self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 3060)   | 11902 = 10603 + (9293 =  7721 +    1075 +     497) +       -7994 |
common_memory_breakdown_print: |   - Host               |                  9490 =  9189 +       0 +     301                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer=42, n_part=25, overflow_type=4, mem=  9293 MiB
common_params_fit_impl: set ngl_per_device[0].(n_layer, n_part)=(42, 25), id_dense_start=0
common_memory_breakdown_print: | memory breakdown [MiB] | total    free    self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 3060)   | 11902 = 10603 + (9625 =  8053 +    1075 +     497) +       -8326 |
common_memory_breakdown_print: |   - Host               |                  9158 =  8857 +       0 +     301                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer=42, n_part=24, overflow_type=4, mem=  9625 MiB
common_params_fit_impl: set ngl_per_device_high[0].(n_layer, n_part)=(42, 24), id_dense_start_high=0
common_params_fit_impl: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP
common_memory_breakdown_print: | memory breakdown [MiB] | total    free    self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 3060)   | 11902 = 10603 + (9387 =  7815 +    1075 +     497) +       -8088 |
common_memory_breakdown_print: |   - Host               |                  9396 =  9095 +       0 +     301                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer=42, n_part=25, overflow_type=2, mem=  9387 MiB
common_params_fit_impl: set ngl_per_device[0].(n_layer, n_part, overflow_type)=(42, 25, UP), id_dense_start=0
common_params_fit_impl: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE
common_memory_breakdown_print: | memory breakdown [MiB] | total    free    self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 3060)   | 11902 = 10603 + (9488 =  7916 +    1075 +     497) +       -8189 |
common_memory_breakdown_print: |   - Host               |                  9295 =  8994 +       0 +     301                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer=42, n_part=25, overflow_type=3, mem=  9488 MiB
common_params_fit_impl: set ngl_per_device[0].(n_layer, n_part, overflow_type)=(42, 25, GATE), id_dense_start=0
common_params_fit_impl:   - CUDA0 (NVIDIA GeForce RTX 3060): 42 layers (25 overflowing),   9488 MiB used,   1110 MiB free
common_fit_params: successfully fit params to free device memory
common_fit_params: fitting params to free memory took 5.10 seconds
llama_model_loader: loaded meta data with 55 key-value pairs and 753 tensors from ./Models/Qwen3.6-35B-A3B-MTP-UD-Q3_K_XL.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = qwen35moe
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                     general.sampling.top_k i32              = 20
llama_model_loader: - kv   3:                     general.sampling.top_p f32              = 0.950000
llama_model_loader: - kv   4:                      general.sampling.temp f32              = 1.000000
llama_model_loader: - kv   5:                               general.name str              = Qwen3.6-35B-A3B
llama_model_loader: - kv   6:                           general.basename str              = Qwen3.6-35B-A3B
llama_model_loader: - kv   7:                       general.quantized_by str              = Unsloth
llama_model_loader: - kv   8:                         general.size_label str              = 35B-A3B
llama_model_loader: - kv   9:                            general.license str              = apache-2.0
llama_model_loader: - kv  10:                       general.license.link str              = https://huggingface.co/Qwen/Qwen3.6-3...
llama_model_loader: - kv  11:                           general.repo_url str              = https://huggingface.co/unsloth
llama_model_loader: - kv  12:                   general.base_model.count u32              = 1
llama_model_loader: - kv  13:                  general.base_model.0.name str              = Qwen3.6 35B A3B
llama_model_loader: - kv  14:          general.base_model.0.organization str              = Qwen
llama_model_loader: - kv  15:              general.base_model.0.repo_url str              = https://huggingface.co/Qwen/Qwen3.6-3...
llama_model_loader: - kv  16:                               general.tags arr[str,3]       = ["qwen3_5_moe", "qwen", "image-text-t...
llama_model_loader: - kv  17:                   qwen35moe.context_length u32              = 262144
llama_model_loader: - kv  18:                 qwen35moe.embedding_length u32              = 2048
llama_model_loader: - kv  19:             qwen35moe.attention.head_count u32              = 16
llama_model_loader: - kv  20:          qwen35moe.attention.head_count_kv u32              = 2
llama_model_loader: - kv  21:          qwen35moe.rope.dimension_sections arr[i32,4]       = [11, 11, 10, 0]
llama_model_loader: - kv  22:                   qwen35moe.rope.freq_base f32              = 10000000.000000
llama_model_loader: - kv  23: qwen35moe.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  24:                     qwen35moe.expert_count u32              = 256
llama_model_loader: - kv  25:                qwen35moe.expert_used_count u32              = 8
llama_model_loader: - kv  26:             qwen35moe.attention.key_length u32              = 256
llama_model_loader: - kv  27:           qwen35moe.attention.value_length u32              = 256
llama_model_loader: - kv  28:       qwen35moe.expert_feed_forward_length u32              = 512
llama_model_loader: - kv  29: qwen35moe.expert_shared_feed_forward_length u32              = 512
llama_model_loader: - kv  30:                  qwen35moe.ssm.conv_kernel u32              = 4
llama_model_loader: - kv  31:                   qwen35moe.ssm.state_size u32              = 128
llama_model_loader: - kv  32:                  qwen35moe.ssm.group_count u32              = 16
llama_model_loader: - kv  33:               qwen35moe.ssm.time_step_rank u32              = 32
llama_model_loader: - kv  34:                   qwen35moe.ssm.inner_size u32              = 4096
llama_model_loader: - kv  35:          qwen35moe.full_attention_interval u32              = 4
llama_model_loader: - kv  36:             qwen35moe.rope.dimension_count u32              = 64
llama_model_loader: - kv  37:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  38:                         tokenizer.ggml.pre str              = qwen35
llama_model_loader: - kv  39:                      tokenizer.ggml.tokens arr[str,248320]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  40:                  tokenizer.ggml.token_type arr[i32,248320]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  41:                      tokenizer.ggml.merges arr[str,247587]  = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
llama_model_loader: - kv  42:                tokenizer.ggml.eos_token_id u32              = 248046
llama_model_loader: - kv  43:            tokenizer.ggml.padding_token_id u32              = 248055
llama_model_loader: - kv  44:                tokenizer.ggml.bos_token_id u32              = 248044
llama_model_loader: - kv  45:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  46:                    tokenizer.chat_template str              = {%- set image_count = namespace(value...
llama_model_loader: - kv  47:               general.quantization_version u32              = 2
llama_model_loader: - kv  48:                          general.file_type u32              = 12
llama_model_loader: - kv  49:                      quantize.imatrix.file str              = Qwen3.6-35B-A3B-GGUF/imatrix_unsloth....
llama_model_loader: - kv  50:                   quantize.imatrix.dataset str              = unsloth_calibration_Qwen3.6-35B-A3B.txt
llama_model_loader: - kv  51:             quantize.imatrix.entries_count u32              = 510
llama_model_loader: - kv  52:              quantize.imatrix.chunks_count u32              = 76
llama_model_loader: - kv  53:                      qwen35moe.block_count u32              = 41
llama_model_loader: - kv  54:             qwen35moe.nextn_predict_layers u32              = 1
llama_model_loader: - type  f32:  368 tensors
llama_model_loader: - type q8_0:  264 tensors
llama_model_loader: - type q5_K:    1 tensors
llama_model_loader: - type q6_K:    4 tensors
llama_model_loader: - type iq3_xxs:   78 tensors
llama_model_loader: - type iq4_xs:   38 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = Q3_K - Medium
print_info: file size   = 16.51 GiB (4.00 BPW) 
llama_prepare_model_devices: using device CUDA0 (NVIDIA GeForce RTX 3060) (0000:10:00.0) - 10855 MiB free
load: 0 unused tokens
load: printing all EOG tokens:
load:   - 248044 ('<|endoftext|>')
load:   - 248046 ('<|im_end|>')
load:   - 248063 ('<|fim_pad|>')
load:   - 248064 ('<|repo_name|>')
load:   - 248065 ('<|file_sep|>')
load: special tokens cache size = 33
load: token to piece cache size = 1.7581 MB
print_info: arch                  = qwen35moe
print_info: vocab_only            = 0
print_info: no_alloc              = 0
print_info: n_ctx_train           = 262144
print_info: n_embd                = 2048
print_info: n_embd_inp            = 2048
print_info: n_layer               = 41
print_info: n_head                = 16
print_info: n_head_kv             = 2
print_info: n_rot                 = 64
print_info: n_swa                 = 0
print_info: is_swa_any            = 0
print_info: n_embd_head_k         = 256
print_info: n_embd_head_v         = 256
print_info: n_gqa                 = 8
print_info: n_embd_k_gqa          = 512
print_info: n_embd_v_gqa          = 512
print_info: f_norm_eps            = 0.0e+00
print_info: f_norm_rms_eps        = 1.0e-06
print_info: f_clamp_kqv           = 0.0e+00
print_info: f_max_alibi_bias      = 0.0e+00
print_info: f_logit_scale         = 0.0e+00
print_info: f_attn_scale          = 0.0e+00
print_info: f_attn_value_scale    = 0.0000
print_info: n_ff                  = 0
print_info: n_expert              = 256
print_info: n_expert_used         = 8
print_info: n_expert_groups       = 0
print_info: n_group_used          = 0
print_info: causal attn           = 1
print_info: pooling type          = -1
print_info: rope type             = 40
print_info: rope scaling          = linear
print_info: freq_base_train       = 10000000.0
print_info: freq_scale_train      = 1
print_info: n_ctx_orig_yarn       = 262144
print_info: rope_yarn_log_mul     = 0.0000
print_info: rope_finetuned        = unknown
print_info: mrope sections        = [11, 11, 10, 0]
print_info: ssm_d_conv            = 4
print_info: ssm_d_inner           = 4096
print_info: ssm_d_state           = 128
print_info: ssm_dt_rank           = 32
print_info: ssm_n_group           = 16
print_info: ssm_dt_b_c_rms        = 0
print_info: model type            = 35B.A3B
print_info: model params          = 35.51 B
print_info: general.name          = Qwen3.6-35B-A3B
print_info: vocab type            = BPE
print_info: n_vocab               = 248320
print_info: n_merges              = 247587
print_info: BOS token             = 248044 '<|endoftext|>'
print_info: EOS token             = 248046 '<|im_end|>'
print_info: EOT token             = 248046 '<|im_end|>'
print_info: PAD token             = 248055 '<|vision_pad|>'
print_info: LF token              = 198 'Ċ'
print_info: FIM PRE token         = 248060 '<|fim_prefix|>'
print_info: FIM SUF token         = 248062 '<|fim_suffix|>'
print_info: FIM MID token         = 248061 '<|fim_middle|>'
print_info: FIM PAD token         = 248063 '<|fim_pad|>'
print_info: FIM REP token         = 248064 '<|repo_name|>'
print_info: FIM SEP token         = 248065 '<|file_sep|>'
print_info: EOG token             = 248044 '<|endoftext|>'
print_info: EOG token             = 248046 '<|im_end|>'
print_info: EOG token             = 248063 '<|fim_pad|>'
print_info: EOG token             = 248064 '<|repo_name|>'
print_info: EOG token             = 248065 '<|file_sep|>'
print_info: max token length      = 256
load_tensors: loading model tensors, this can take a while... (mmap = false, direct_io = false)
load_tensors: offloading output layer to GPU
load_tensors: offloading 40 repeating layers to GPU
load_tensors: offloaded 42/42 layers to GPU
load_tensors:        CUDA0 model buffer size =  7916.16 MiB
load_tensors:    CUDA_Host model buffer size =  8994.38 MiB
..............................................................................................
common_init_result: added <|endoftext|> logit bias = -inf
common_init_result: added <|im_end|> logit bias = -inf
common_init_result: added <|fim_pad|> logit bias = -inf
common_init_result: added <|repo_name|> logit bias = -inf
common_init_result: added <|file_sep|> logit bias = -inf
llama_context: constructing llama_context
llama_context: n_seq_max     = 4
llama_context: n_ctx         = 150016
llama_context: n_ctx_seq     = 150016
llama_context: n_batch       = 4096
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = enabled
llama_context: kv_unified    = true
llama_context: freq_base     = 10000000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_seq (150016) < n_ctx_train (262144) -- the full capacity of the model will not be utilized
llama_context:  CUDA_Host  output buffer size =     3.79 MiB
llama_kv_cache:      CUDA0 KV buffer size =   824.06 MiB
llama_kv_cache: size =  824.06 MiB (150016 cells,  10 layers,  4/1 seqs), K (q4_0):  412.03 MiB, V (q4_0):  412.03 MiB
llama_kv_cache: attn_rot_k = 1, n_embd_head_k_all = 256
llama_kv_cache: attn_rot_v = 1, n_embd_head_k_all = 256
llama_memory_recurrent:      CUDA0 RS buffer size =   251.25 MiB
llama_memory_recurrent: size =  251.25 MiB (     4 cells,  41 layers,  4 seqs), R (f32):   11.25 MiB, S (f32):  240.00 MiB
sched_reserve: reserving ...
sched_reserve: resolving fused Gated Delta Net support:
sched_reserve: fused Gated Delta Net (autoregressive) enabled
sched_reserve: fused Gated Delta Net (chunked) enabled
sched_reserve:      CUDA0 compute buffer size =   497.00 MiB
sched_reserve:  CUDA_Host compute buffer size =   301.29 MiB
sched_reserve: graph nodes  = 3849
sched_reserve: graph splits = 70 (with bs=512), 50 (with bs=1)
sched_reserve: reserve took 135.98 ms, sched copies = 1
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
srv    load_model: loading MTP head from './Models/Qwen3.6-35B-A3B-MTP-UD-Q3_K_XL.gguf' (override_arch=qwen35moe_mtp)
llama_model_loader: loaded meta data with 55 key-value pairs and 753 tensors from ./Models/Qwen3.6-35B-A3B-MTP-UD-Q3_K_XL.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = qwen35moe
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                     general.sampling.top_k i32              = 20
llama_model_loader: - kv   3:                     general.sampling.top_p f32              = 0.950000
llama_model_loader: - kv   4:                      general.sampling.temp f32              = 1.000000
llama_model_loader: - kv   5:                               general.name str              = Qwen3.6-35B-A3B
llama_model_loader: - kv   6:                           general.basename str              = Qwen3.6-35B-A3B
llama_model_loader: - kv   7:                       general.quantized_by str              = Unsloth
llama_model_loader: - kv   8:                         general.size_label str              = 35B-A3B
llama_model_loader: - kv   9:                            general.license str              = apache-2.0
llama_model_loader: - kv  10:                       general.license.link str              = https://huggingface.co/Qwen/Qwen3.6-3...
llama_model_loader: - kv  11:                           general.repo_url str              = https://huggingface.co/unsloth
llama_model_loader: - kv  12:                   general.base_model.count u32              = 1
llama_model_loader: - kv  13:                  general.base_model.0.name str              = Qwen3.6 35B A3B
llama_model_loader: - kv  14:          general.base_model.0.organization str              = Qwen
llama_model_loader: - kv  15:              general.base_model.0.repo_url str              = https://huggingface.co/Qwen/Qwen3.6-3...
llama_model_loader: - kv  16:                               general.tags arr[str,3]       = ["qwen3_5_moe", "qwen", "image-text-t...
llama_model_loader: - kv  17:                   qwen35moe.context_length u32              = 262144
llama_model_loader: - kv  18:                 qwen35moe.embedding_length u32              = 2048
llama_model_loader: - kv  19:             qwen35moe.attention.head_count u32              = 16
llama_model_loader: - kv  20:          qwen35moe.attention.head_count_kv u32              = 2
llama_model_loader: - kv  21:          qwen35moe.rope.dimension_sections arr[i32,4]       = [11, 11, 10, 0]
llama_model_loader: - kv  22:                   qwen35moe.rope.freq_base f32              = 10000000.000000
llama_model_loader: - kv  23: qwen35moe.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  24:                     qwen35moe.expert_count u32              = 256
llama_model_loader: - kv  25:                qwen35moe.expert_used_count u32              = 8
llama_model_loader: - kv  26:             qwen35moe.attention.key_length u32              = 256
llama_model_loader: - kv  27:           qwen35moe.attention.value_length u32              = 256
llama_model_loader: - kv  28:       qwen35moe.expert_feed_forward_length u32              = 512
llama_model_loader: - kv  29: qwen35moe.expert_shared_feed_forward_length u32              = 512
llama_model_loader: - kv  30:                  qwen35moe.ssm.conv_kernel u32              = 4
llama_model_loader: - kv  31:                   qwen35moe.ssm.state_size u32              = 128
llama_model_loader: - kv  32:                  qwen35moe.ssm.group_count u32              = 16
llama_model_loader: - kv  33:               qwen35moe.ssm.time_step_rank u32              = 32
llama_model_loader: - kv  34:                   qwen35moe.ssm.inner_size u32              = 4096
llama_model_loader: - kv  35:          qwen35moe.full_attention_interval u32              = 4
llama_model_loader: - kv  36:             qwen35moe.rope.dimension_count u32              = 64
llama_model_loader: - kv  37:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  38:                         tokenizer.ggml.pre str              = qwen35
llama_model_loader: - kv  39:                      tokenizer.ggml.tokens arr[str,248320]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  40:                  tokenizer.ggml.token_type arr[i32,248320]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  41:                      tokenizer.ggml.merges arr[str,247587]  = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
llama_model_loader: - kv  42:                tokenizer.ggml.eos_token_id u32              = 248046
llama_model_loader: - kv  43:            tokenizer.ggml.padding_token_id u32              = 248055
llama_model_loader: - kv  44:                tokenizer.ggml.bos_token_id u32              = 248044
llama_model_loader: - kv  45:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  46:                    tokenizer.chat_template str              = {%- set image_count = namespace(value...
llama_model_loader: - kv  47:               general.quantization_version u32              = 2
llama_model_loader: - kv  48:                          general.file_type u32              = 12
llama_model_loader: - kv  49:                      quantize.imatrix.file str              = Qwen3.6-35B-A3B-GGUF/imatrix_unsloth....
llama_model_loader: - kv  50:                   quantize.imatrix.dataset str              = unsloth_calibration_Qwen3.6-35B-A3B.txt
llama_model_loader: - kv  51:             quantize.imatrix.entries_count u32              = 510
llama_model_loader: - kv  52:              quantize.imatrix.chunks_count u32              = 76
llama_model_loader: - kv  53:                      qwen35moe.block_count u32              = 41
llama_model_loader: - kv  54:             qwen35moe.nextn_predict_layers u32              = 1
llama_model_loader: - type  f32:  368 tensors
llama_model_loader: - type q8_0:  264 tensors
llama_model_loader: - type q5_K:    1 tensors
llama_model_loader: - type q6_K:    4 tensors
llama_model_loader: - type iq3_xxs:   78 tensors
llama_model_loader: - type iq4_xs:   38 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = Q3_K - Medium
print_info: file size   = 16.51 GiB (4.00 BPW) 
llama_model_create: overriding architecture qwen35moe -> qwen35moe_mtp
llama_prepare_model_devices: using device CUDA0 (NVIDIA GeForce RTX 3060) (0000:10:00.0) - 1309 MiB free
load: 0 unused tokens
load: printing all EOG tokens:
load:   - 248044 ('<|endoftext|>')
load:   - 248046 ('<|im_end|>')
load:   - 248063 ('<|fim_pad|>')
load:   - 248064 ('<|repo_name|>')
load:   - 248065 ('<|file_sep|>')
load: special tokens cache size = 33
load: token to piece cache size = 1.7581 MB
print_info: arch                  = qwen35moe_mtp
print_info: vocab_only            = 0
print_info: no_alloc              = 0
print_info: n_ctx_train           = 262144
print_info: n_embd                = 2048
print_info: n_embd_inp            = 2048
print_info: n_layer               = 41
print_info: n_head                = 16
print_info: n_head_kv             = 2
print_info: n_rot                 = 64
print_info: n_swa                 = 0
print_info: is_swa_any            = 0
print_info: n_embd_head_k         = 256
print_info: n_embd_head_v         = 256
print_info: n_gqa                 = 8
print_info: n_embd_k_gqa          = 512
print_info: n_embd_v_gqa          = 512
print_info: f_norm_eps            = 0.0e+00
print_info: f_norm_rms_eps        = 1.0e-06
print_info: f_clamp_kqv           = 0.0e+00
print_info: f_max_alibi_bias      = 0.0e+00
print_info: f_logit_scale         = 0.0e+00
print_info: f_attn_scale          = 0.0e+00
print_info: f_attn_value_scale    = 0.0000
print_info: n_ff                  = 0
print_info: n_expert              = 256
print_info: n_expert_used         = 8
print_info: n_expert_groups       = 0
print_info: n_group_used          = 0
print_info: causal attn           = 1
print_info: pooling type          = -1
print_info: rope type             = 40
print_info: rope scaling          = linear
print_info: freq_base_train       = 10000000.0
print_info: freq_scale_train      = 1
print_info: n_ctx_orig_yarn       = 262144
print_info: rope_yarn_log_mul     = 0.0000
print_info: rope_finetuned        = unknown
print_info: mrope sections        = [11, 11, 10, 0]
print_info: model type            = ?B
print_info: model params          = 35.51 B
print_info: general.name          = Qwen3.6-35B-A3B
print_info: vocab type            = BPE
print_info: n_vocab               = 248320
print_info: n_merges              = 247587
print_info: BOS token             = 248044 '<|endoftext|>'
print_info: EOS token             = 248046 '<|im_end|>'
print_info: EOT token             = 248046 '<|im_end|>'
print_info: PAD token             = 248055 '<|vision_pad|>'
print_info: LF token              = 198 'Ċ'
print_info: FIM PRE token         = 248060 '<|fim_prefix|>'
print_info: FIM SUF token         = 248062 '<|fim_suffix|>'
print_info: FIM MID token         = 248061 '<|fim_middle|>'
print_info: FIM PAD token         = 248063 '<|fim_pad|>'
print_info: FIM REP token         = 248064 '<|repo_name|>'
print_info: FIM SEP token         = 248065 '<|file_sep|>'
print_info: EOG token             = 248044 '<|endoftext|>'
print_info: EOG token             = 248046 '<|im_end|>'
print_info: EOG token             = 248063 '<|fim_pad|>'
print_info: EOG token             = 248064 '<|repo_name|>'
print_info: EOG token             = 248065 '<|file_sep|>'
print_info: max token length      = 256
load_tensors: loading model tensors, this can take a while... (mmap = false, direct_io = false)
done_getting_tensors: partial load — used 23 of 753 tensors in the file (rest belong to a sibling model on the same .gguf)
load_tensors: offloading output layer to GPU
load_tensors: offloading 40 repeating layers to GPU
load_tensors: offloaded 42/42 layers to GPU
load_tensors:        CUDA0 model buffer size =   437.75 MiB
load_tensors:    CUDA_Host model buffer size =  1331.31 MiB
....llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 150016
llama_context: n_ctx_seq     = 150016
llama_context: n_batch       = 4096
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = enabled
llama_context: kv_unified    = true
llama_context: freq_base     = 10000000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_seq (150016) < n_ctx_train (262144) -- the full capacity of the model will not be utilized
llama_context:  CUDA_Host  output buffer size =     0.95 MiB
llama_kv_cache:      CUDA0 KV buffer size =    82.41 MiB
llama_kv_cache: size =   82.41 MiB (150016 cells,   1 layers,  1/1 seqs), K (q4_0):   41.20 MiB, V (q4_0):   41.20 MiB
llama_kv_cache: attn_rot_k = 1, n_embd_head_k_all = 256
llama_kv_cache: attn_rot_v = 1, n_embd_head_k_all = 256
sched_reserve: reserving ...
sched_reserve: resolving fused Gated Delta Net support:
sched_reserve: fused Gated Delta Net (autoregressive) enabled
sched_reserve: fused Gated Delta Net (chunked) enabled
sched_reserve:      CUDA0 compute buffer size =   497.00 MiB
sched_reserve:  CUDA_Host compute buffer size =   301.28 MiB
sched_reserve: graph nodes  = 98
sched_reserve: graph splits = 5 (with bs=512), 4 (with bs=1)
sched_reserve: reserve took 131.37 ms, sched copies = 1
srv    load_model: initializing slots, n_slots = 4
common_context_can_seq_rm: the context does not support partial sequence removal
srv    load_model: speculative decoding will use checkpoints
srv    load_model: speculative decoding context initialized
slot   load_model: id  0 | task -1 | new slot, n_ctx = 150016
slot   load_model: id  1 | task -1 | new slot, n_ctx = 150016
slot   load_model: id  2 | task -1 | new slot, n_ctx = 150016
slot   load_model: id  3 | task -1 | new slot, n_ctx = 150016
srv    load_model: prompt cache is enabled, size limit: 8192 MiB
srv    load_model: use `--cache-ram 0` to disable the prompt cache
srv    load_model: for more info see https://github.com/ggml-org/llama.cpp/pull/16391
srv          init: init: idle slots will be saved to prompt cache and cleared upon starting a new task
init: chat template, example_format: '<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there<|im_end|>
<|im_start|>user
How are you?<|im_end|>
<|im_start|>assistant
<think>

</think>

'
srv          init: init: chat template, thinking = 0
main: model loaded
main: server is listening on http://0.0.0.0:5000
main: starting the main loop...
srv  update_slots: all slots are idle
srv  params_from_: Chat format: peg-native
slot get_availabl: id  3 | task -1 | selected slot by LRU, t_last = -1
srv  get_availabl: updating prompt cache
srv          load:  - looking for better prompt, base f_keep = -1.000, sim = 0.000
srv        update:  - cache state: 0 prompts, 0.000 MiB (limits: 8192.000 MiB, 150016 tokens, 8589934592 est)
srv  get_availabl: prompt cache update took 0.01 ms
reasoning-budget: activated, budget=2147483647 tokens
reasoning-budget: deactivated (natural end)
slot launch_slot_: id  3 | task -1 | sampler chain: logits -> ?penalties -> ?dry -> ?top-n-sigma -> top-k -> ?typical -> top-p -> ?min-p -> ?xtc -> temp-ext -> dist 
slot launch_slot_: id  3 | task 0 | processing task, is_child = 0
slot update_slots: id  3 | task 0 | new prompt, n_ctx_slot = 150016, n_keep = 0, task.n_tokens = 3258
slot update_slots: id  3 | task 0 | n_tokens = 0, memory_seq_rm [0, end)
slot update_slots: id  3 | task 0 | prompt processing progress, n_tokens = 2742, batch.n_tokens = 2742, progress = 0.841621
slot update_slots: id  3 | task 0 | n_tokens = 2742, memory_seq_rm [2742, end)
slot update_slots: id  3 | task 0 | prompt processing progress, n_tokens = 3254, batch.n_tokens = 512, progress = 0.998772
slot create_check: id  3 | task 0 | created context checkpoint 1 of 32 (pos_min = 2741, pos_max = 2741, n_tokens = 2742, size = 64.361 MiB)
slot update_slots: id  3 | task 0 | n_tokens = 3254, memory_seq_rm [3254, end)
slot init_sampler: id  3 | task 0 | init sampler, took 0.41 ms, tokens: text = 3258, total = 3258
slot update_slots: id  3 | task 0 | prompt processing done, n_tokens = 3258, batch.n_tokens = 4
slot create_check: id  3 | task 0 | created context checkpoint 2 of 32 (pos_min = 3253, pos_max = 3253, n_tokens = 3254, size = 64.650 MiB)
srv  log_server_r: done request: POST /v1/chat/completions 127.0.0.1 200
~llama_io_write_device: allocated 'CUDA0' buffer 62.812 MiB
slot print_timing: id  3 | task 0 | 
prompt eval time =    8584.47 ms /  3258 tokens (    2.63 ms per token,   379.52 tokens per second)
       eval time =    1287.69 ms /    57 tokens (   22.59 ms per token,    44.27 tokens per second)
      total time =    9872.17 ms /  3315 tokens
draft acceptance rate = 1.00000 (   36 accepted /    36 generated)
statistics mtp: #calls(b,g,a) = 1 20 18, #gen drafts = 18, #acc drafts = 18, #gen tokens = 36, #acc tokens = 36, dur(b,g,a) = 0.004, 143.316, 0.009 ms
slot      release: id  3 | task 0 | stop processing: n_tokens = 3314, truncated = 0
srv  update_slots: all slots are idle
srv  params_from_: Chat format: peg-native
slot get_availabl: id  3 | task -1 | selected slot by LCP similarity, sim_best = 0.959 (> 0.100 thold), f_keep = 1.000
reasoning-budget: activated, budget=2147483647 tokens
reasoning-budget: deactivated (natural end)
slot launch_slot_: id  3 | task -1 | sampler chain: logits -> ?penalties -> ?dry -> ?top-n-sigma -> top-k -> ?typical -> top-p -> ?min-p -> ?xtc -> temp-ext -> dist 
slot launch_slot_: id  3 | task 25 | processing task, is_child = 0
slot update_slots: id  3 | task 25 | new prompt, n_ctx_slot = 150016, n_keep = 0, task.n_tokens = 3454
slot update_slots: id  3 | task 25 | n_tokens = 3314, memory_seq_rm [3314, end)
slot update_slots: id  3 | task 25 | prompt processing progress, n_tokens = 3450, batch.n_tokens = 136, progress = 0.998842
slot update_slots: id  3 | task 25 | n_tokens = 3450, memory_seq_rm [3450, end)
slot init_sampler: id  3 | task 25 | init sampler, took 0.43 ms, tokens: text = 3454, total = 3454
slot update_slots: id  3 | task 25 | prompt processing done, n_tokens = 3454, batch.n_tokens = 4
slot create_check: id  3 | task 25 | created context checkpoint 3 of 32 (pos_min = 3449, pos_max = 3449, n_tokens = 3450, size = 64.761 MiB)
srv  log_server_r: done request: POST /v1/chat/completions 127.0.0.1 200
slot print_timing: id  3 | task 25 | 
prompt eval time =     909.06 ms /   140 tokens (    6.49 ms per token,   154.01 tokens per second)
       eval time =    1745.13 ms /    75 tokens (   23.27 ms per token,    42.98 tokens per second)
      total time =    2654.19 ms /   215 tokens
draft acceptance rate = 1.00000 (   43 accepted /    43 generated)
statistics mtp: #calls(b,g,a) = 2 51 40, #gen drafts = 40, #acc drafts = 40, #gen tokens = 79, #acc tokens = 79, dur(b,g,a) = 0.006, 337.486, 0.017 ms
slot      release: id  3 | task 25 | stop processing: n_tokens = 3528, truncated = 0
srv  update_slots: all slots are idle
srv  params_from_: Chat format: peg-native
slot get_availabl: id  3 | task -1 | selected slot by LCP similarity, sim_best = 0.966 (> 0.100 thold), f_keep = 1.000
reasoning-budget: activated, budget=2147483647 tokens
reasoning-budget: deactivated (natural end)
slot launch_slot_: id  3 | task -1 | sampler chain: logits -> ?penalties -> ?dry -> ?top-n-sigma -> top-k -> ?typical -> top-p -> ?min-p -> ?xtc -> temp-ext -> dist 
slot launch_slot_: id  3 | task 60 | processing task, is_child = 0
slot update_slots: id  3 | task 60 | new prompt, n_ctx_slot = 150016, n_keep = 0, task.n_tokens = 3652
slot update_slots: id  3 | task 60 | n_tokens = 3528, memory_seq_rm [3528, end)
slot update_slots: id  3 | task 60 | prompt processing progress, n_tokens = 3648, batch.n_tokens = 120, progress = 0.998905
slot create_check: id  3 | task 60 | created context checkpoint 4 of 32 (pos_min = 3527, pos_max = 3527, n_tokens = 3528, size = 64.805 MiB)
slot update_slots: id  3 | task 60 | n_tokens = 3648, memory_seq_rm [3648, end)
slot init_sampler: id  3 | task 60 | init sampler, took 0.45 ms, tokens: text = 3652, total = 3652
slot update_slots: id  3 | task 60 | prompt processing done, n_tokens = 3652, batch.n_tokens = 4
slot create_check: id  3 | task 60 | created context checkpoint 5 of 32 (pos_min = 3647, pos_max = 3647, n_tokens = 3648, size = 64.873 MiB)
srv  log_server_r: done request: POST /v1/chat/completions 127.0.0.1 200
slot print_timing: id  3 | task 60 | 
prompt eval time =     896.19 ms /   124 tokens (    7.23 ms per token,   138.36 tokens per second)
       eval time =    1836.62 ms /    81 tokens (   22.67 ms per token,    44.10 tokens per second)
      total time =    2732.81 ms /   205 tokens
draft acceptance rate = 1.00000 (   50 accepted /    50 generated)
statistics mtp: #calls(b,g,a) = 3 81 68, #gen drafts = 68, #acc drafts = 68, #gen tokens = 129, #acc tokens = 129, dur(b,g,a) = 0.008, 546.364, 0.026 ms
slot      release: id  3 | task 60 | stop processing: n_tokens = 3732, truncated = 0
srv  update_slots: all slots are idle
srv  params_from_: Chat format: peg-native
slot get_availabl: id  3 | task -1 | selected slot by LCP similarity, sim_best = 0.720 (> 0.100 thold), f_keep = 1.000
reasoning-budget: activated, budget=2147483647 tokens
reasoning-budget: deactivated (natural end)
slot launch_slot_: id  3 | task -1 | sampler chain: logits -> ?penalties -> ?dry -> ?top-n-sigma -> top-k -> ?typical -> top-p -> ?min-p -> ?xtc -> temp-ext -> dist 
slot launch_slot_: id  3 | task 94 | processing task, is_child = 0
slot update_slots: id  3 | task 94 | new prompt, n_ctx_slot = 150016, n_keep = 0, task.n_tokens = 5182
slot update_slots: id  3 | task 94 | n_tokens = 3732, memory_seq_rm [3732, end)
slot update_slots: id  3 | task 94 | prompt processing progress, n_tokens = 4666, batch.n_tokens = 934, progress = 0.900425
slot update_slots: id  3 | task 94 | n_tokens = 4666, memory_seq_rm [4666, end)
slot update_slots: id  3 | task 94 | prompt processing progress, n_tokens = 5178, batch.n_tokens = 512, progress = 0.999228
slot create_check: id  3 | task 94 | created context checkpoint 6 of 32 (pos_min = 4665, pos_max = 4665, n_tokens = 4666, size = 65.448 MiB)
slot update_slots: id  3 | task 94 | n_tokens = 5178, memory_seq_rm [5178, end)
slot init_sampler: id  3 | task 94 | init sampler, took 0.63 ms, tokens: text = 5182, total = 5182
slot update_slots: id  3 | task 94 | prompt processing done, n_tokens = 5182, batch.n_tokens = 4
slot create_check: id  3 | task 94 | created context checkpoint 7 of 32 (pos_min = 5177, pos_max = 5177, n_tokens = 5178, size = 65.737 MiB)
srv  log_server_r: done request: POST /v1/chat/completions 127.0.0.1 200
slot print_timing: id  3 | task 94 | 
prompt eval time =    3223.94 ms /  1450 tokens (    2.22 ms per token,   449.76 tokens per second)
       eval time =    7558.78 ms /   281 tokens (   26.90 ms per token,    37.18 tokens per second)
      total time =   10782.72 ms /  1731 tokens
draft acceptance rate = 0.98684 (  150 accepted /   152 generated)
statistics mtp: #calls(b,g,a) = 4 211 159, #gen drafts = 159, #acc drafts = 159, #gen tokens = 281, #acc tokens = 279, dur(b,g,a) = 0.010, 1368.242, 0.058 ms
slot      release: id  3 | task 94 | stop processing: n_tokens = 5462, truncated = 0
srv  update_slots: all slots are idle
srv  params_from_: Chat format: peg-native
slot get_availabl: id  3 | task -1 | selected slot by LCP similarity, sim_best = 0.996 (> 0.100 thold), f_keep = 0.591
reasoning-budget: activated, budget=2147483647 tokens
reasoning-budget: deactivated (natural end)
slot launch_slot_: id  3 | task -1 | sampler chain: logits -> ?penalties -> ?dry -> ?top-n-sigma -> top-k -> ?typical -> top-p -> ?min-p -> ?xtc -> temp-ext -> dist 
slot launch_slot_: id  3 | task 236 | processing task, is_child = 0
slot update_slots: id  3 | task 236 | new prompt, n_ctx_slot = 150016, n_keep = 0, task.n_tokens = 3243
slot update_slots: id  3 | task 236 | n_past = 3229, slot.prompt.tokens.size() = 5462, seq_id = 3, pos_min = 5461, n_swa = 0
slot update_slots: id  3 | task 236 | Checking checkpoint with [5177, 5177] against 3229...
slot update_slots: id  3 | task 236 | Checking checkpoint with [4665, 4665] against 3229...
slot update_slots: id  3 | task 236 | Checking checkpoint with [3647, 3647] against 3229...
slot update_slots: id  3 | task 236 | Checking checkpoint with [3527, 3527] against 3229...
slot update_slots: id  3 | task 236 | Checking checkpoint with [3449, 3449] against 3229...
slot update_slots: id  3 | task 236 | Checking checkpoint with [3253, 3253] against 3229...
slot update_slots: id  3 | task 236 | Checking checkpoint with [2741, 2741] against 3229...
/home/homemdesgraca/Misc/llama.cpp/common/common.cpp:2059: checkpoint size mismatch: expected 1623312, got 0

state_read_meta: invalid seq_id-agnostic kv cell
state_seq_set_data: error loading state: failed to restore kv cache
[New LWP 50384]
[New LWP 50383]
[New LWP 50382]
[New LWP 50381]
[New LWP 50380]
[New LWP 50076]
[New LWP 50075]
[New LWP 50074]
[New LWP 50073]
[New LWP 50072]
[New LWP 50071]
[New LWP 50070]
[New LWP 50069]
[New LWP 50068]
[New LWP 50067]
[New LWP 50066]
[New LWP 50065]
[New LWP 50064]
[New LWP 50063]
[New LWP 50062]
[New LWP 50060]

This GDB supports auto-downloading debuginfo from the following URLs:
  <https://debuginfod.archlinux.org>
  <https://debuginfod.cachyos.org>
Enable debuginfod for this session? (y or [n]) [answered N; input not from terminal]
Debuginfod has been disabled.
To make this setting permanent, add 'set debuginfod enabled off' to .gdbinit.
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/usr/lib/libthread_db.so.1".
0x00007fb4584b4e22 in ?? () from /usr/lib/libc.so.6
#0  0x00007fb4584b4e22 in ?? () from /usr/lib/libc.so.6
#1  0x00007fb4584a8178 in ?? () from /usr/lib/libc.so.6
#2  0x00007fb45852fa6b in wait4 () from /usr/lib/libc.so.6
#3  0x00007fb462efd37b in ggml_print_backtrace () from /home/homemdesgraca/Misc/llama.cpp/build/bin/libggml-base.so.0
#4  0x00007fb462efd50e in ggml_abort () from /home/homemdesgraca/Misc/llama.cpp/build/bin/libggml-base.so.0
#5  0x00007fb462c1a6ee in common_prompt_checkpoint::load_dft(llama_context*, int, unsigned int) const () from /home/homemdesgraca/Misc/llama.cpp/build/bin/libllama-common.so.0
#6  0x00005644b54e29d9 in server_context_impl::update_slots() ()
#7  0x00005644b5582ab1 in server_queue::start_loop(long) ()
#8  0x00005644b5438c52 in main ()
[Inferior 1 (process 50059) detached]

Btw, thanks ggerganov and am17an for the insane work being done here :]

@ggerganov

If you increase the -fitt to 2048 does it still crash?

@homemdesgraca commented May 11, 2026

If you increase the -fitt to 2048 does it still crash?

-fitt = 512 -> Loads main model then fails loading MTP
-fitt = 1024 -> Loads both but crashes on first inference
-fitt = 1500 -> Loads and inferences correctly (if ctx-checkpoints is set to 0)

@ggerganov

This branch starts with -np 4 by default, which requires slightly more memory. In your case, you should stick to -np 1 and I think it should work OK.

@homemdesgraca commented May 11, 2026

This branch starts with -np 4 by default, which requires slightly more memory. In your case, you should stick to -np 1 and I think it should work OK.

Still fails to load with -fitt 512 (which worked fine on am17an's older MTP branch):

llama_context:  CUDA_Host  output buffer size =     0.95 MiB
llama_kv_cache:      CUDA0 KV buffer size =    82.41 MiB
llama_kv_cache: size =   82.41 MiB (150016 cells,   1 layers,  1/1 seqs), K (q4_0):   41.20 MiB, V (q4_0):   41.20 MiB
llama_kv_cache: attn_rot_k = 1, n_embd_head_k_all = 256
llama_kv_cache: attn_rot_v = 1, n_embd_head_k_all = 256
sched_reserve: reserving ...
sched_reserve: resolving fused Gated Delta Net support:
sched_reserve: fused Gated Delta Net (autoregressive) enabled
sched_reserve: fused Gated Delta Net (chunked) enabled
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 497.00 MiB on device 0: cudaMalloc failed: out of memory
ggml_gallocr_reserve_n_impl: failed to allocate CUDA0 buffer of size 521142272
graph_reserve: failed to allocate compute buffers
llama_init_from_model: failed to initialize the context: failed to allocate compute pp buffers
srv    load_model: failed to create MTP context
srv    operator(): operator(): cleaning up before exit...
main: exiting due to model loading error

--fit probably just isn't accounting for the MTP weights and context, which the older branch did.


Also, I think the main thing is that it crashes with checkpoints, mainly this line (with --ctx-checkpoints > 0):
.../llama.cpp/common/common.cpp:2059: checkpoint size mismatch: expected 1623312, got 0
Maybe the main model context and the draft model contexts don't match and it gets confused? Not sure.

@ggerganov

--fit prob just isn't accounting for the MTP weights and context, which, on the older branch, it did.

Ah, I missed that there was such logic. In that case, it is better to wait for the updated branch that will likely include this logic again. Here we are mainly prototyping the speculative architecture.

@AbdulrahmanHashem

For testing I load it like so:

./build/bin/llama-server \
    -m /models/Qwen/Qwen3.6-35B-A3B-MTP-UD-Q4_K_XL.gguf \
    --chat_template_kwargs '{"preserve_thinking": "True"}' \
    --jinja \
    -fa on \
    --main-gpu 0 \
    --temperature 0.7 --top-p 0.95 --top-k 20 --min-p 0.00 --presence-penalty 0.0 --repeat-penalty 1.0 \
    -t 10 -tb 10 \
    --host 0.0.0.0 \
    --port 8080 \
    -c 0 -fitc 202144 \
    --cache-type-k q8_0 --cache-type-v q4_0 \
    --spec-type mtp --spec-draft-n-max 3 \
    --fit-target 768,1280 \
    --mlock \
    -ub 300

(you'll need to change --fit-target depending on how many GPUs you have)

This works.
