From 819cc8b8c01b879b3668208c372db15eefb4b03e Mon Sep 17 00:00:00 2001 From: Leechael Yim Date: Wed, 6 May 2026 01:56:13 +0800 Subject: [PATCH] llama: castle DFlash drafter + DDTree spec-decode (full feature stack) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit All llama-level changes for DFlash + DDTree, on top of luce-org ggml stack. This is the entire castle implementation — DFlash drafter, DDTree builder & verifier, persist-rollback recurrent cache, server slot integration, tests, plus the LLAMA_DDTREE_* fast-path env knobs. Major areas: - batch: parent_id field + llama_batch_init_tree + tree-mode ubatch propagation - graph: build_inp_tree + ancestor mask in kq_mask + read_only_tree reuse - kv-cache: llama_kv_cache_seq_compact_tree - memory-recurrent: snapshot/restore/release API (note: replaceable by upstream ggml-org#19493 checkpoint mechanism) - context: dflash persist buffers + rollback + recurrent_tail_pos + capture_hidden + dflash_draft_top_k + target_feat injection - model: LLM_ARCH_DFLASH_DRAFT load path + capture_layers hparams - src/models/dflash-draft.cpp: 5-layer draft graph (3 modes: full / fuse_only / kv_update_only), shared lm_head with target - src/models/qwen35.cpp: hidden-state capture for 5 layers + tree-mode dispatch - common/speculative-tree.{h,cpp}: ddtree builder + verifier walk + visibility mask - common/speculative-tree-driver.{h,cpp}: spec-decode coordinator with chain validate, fast batched, fast rollback, AR fallback, snapshot replay paths - common/speculative-draft-backend.{h,cpp}: draft backend abstraction - common/sampling: common_sampler_grammar_token_valid for grammar-aware verify - tools/server/server-context.cpp: DDTree slot lifecycle + grammar-aware verify_cbs - examples/speculative-tree/main.cpp: standalone CLI - tests/test-speculative-tree*.cpp + test-qwen35-*.cpp + test-dflash-draft.cpp: unit + acceptance suites Track A: production castle stack. Stacked on track-a/ggml. 
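Example server invocation (illustrative only; the model paths are
placeholders for the Qwen3.5 target and DFlash draft GGUFs, the flags are
the ones added in common/arg.cpp by this patch):

    llama-server -m qwen3.5-27b.gguf -md qwen3.5-27b-dflash-draft.gguf \
        --speculative-mode ddtree --ddtree-budget 22 --ddtree-top-k 0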
Tracking: #3 (Phase 1, Track A, llama layer) --- README.md | 6 + common/CMakeLists.txt | 6 + common/arg.cpp | 54 + common/common.h | 7 + common/sampling.cpp | 47 +- common/sampling.h | 6 + common/speculative-draft-backend.cpp | 327 +++ common/speculative-draft-backend.h | 69 + common/speculative-tree-driver.cpp | 1194 +++++++++ common/speculative-tree-driver.h | 124 + common/speculative-tree.cpp | 296 +++ common/speculative-tree.h | 116 + docs/ddtree-dataset-eval-plan.md | 229 ++ examples/CMakeLists.txt | 3 + examples/speculative-tree/CMakeLists.txt | 7 + examples/speculative-tree/main.cpp | 318 +++ include/llama.h | 163 +- src/CMakeLists.txt | 1 + src/llama-arch.cpp | 9 + src/llama-arch.h | 5 + src/llama-batch.cpp | 56 +- src/llama-batch.h | 2 + src/llama-context.cpp | 2255 ++++++++++++++++- src/llama-context.h | 197 +- src/llama-graph.cpp | 211 +- src/llama-graph.h | 163 +- src/llama-hparams.h | 7 + src/llama-kv-cache.cpp | 169 +- src/llama-kv-cache.h | 12 +- src/llama-memory-recurrent.cpp | 144 +- src/llama-memory-recurrent.h | 26 + src/llama-model-loader.cpp | 3 + src/llama-model.cpp | 124 + src/llama-model.h | 4 + src/models/delta-net-base.cpp | 51 + src/models/dflash-draft.cpp | 317 +++ src/models/models.h | 19 + src/models/qwen35.cpp | 136 +- tests/.gitignore | 2 + tests/CMakeLists.txt | 58 + tests/fixtures/ddtree/README.md | 332 +++ .../ddtree/dflash_draft_metadata_smoke.json | 11 + tests/fixtures/ddtree/make_short_prompt.py | 77 + tests/fixtures/ddtree/short_prompt.bin | Bin 0 -> 64 bytes tests/fixtures/ddtree/tree_5node.json | 11 + tests/test-dflash-draft.cpp | 350 +++ tests/test-qwen35-chain-capture.cpp | 354 +++ tests/test-qwen35-root-vs-chain.cpp | 635 +++++ tests/test-qwen35-tree-rollback.cpp | 299 +++ tests/test-qwen35-tree.cpp | 311 +++ tests/test-speculative-draft-backend.cpp | 99 + tests/test-speculative-tree-e2e.cpp | 765 ++++++ tests/test-speculative-tree.cpp | 336 +++ tools/server/README.md | 36 + tools/server/server-context.cpp | 411 ++- 55 files changed, 10857 insertions(+), 113 deletions(-) create mode 100644 common/speculative-draft-backend.cpp create mode 100644 common/speculative-draft-backend.h create mode 100644 common/speculative-tree-driver.cpp create mode 100644 common/speculative-tree-driver.h create mode 100644 common/speculative-tree.cpp create mode 100644 common/speculative-tree.h create mode 100644 docs/ddtree-dataset-eval-plan.md create mode 100644 examples/speculative-tree/CMakeLists.txt create mode 100644 examples/speculative-tree/main.cpp create mode 100644 src/models/dflash-draft.cpp create mode 100644 tests/fixtures/ddtree/README.md create mode 100644 tests/fixtures/ddtree/dflash_draft_metadata_smoke.json create mode 100755 tests/fixtures/ddtree/make_short_prompt.py create mode 100644 tests/fixtures/ddtree/short_prompt.bin create mode 100644 tests/fixtures/ddtree/tree_5node.json create mode 100644 tests/test-dflash-draft.cpp create mode 100644 tests/test-qwen35-chain-capture.cpp create mode 100644 tests/test-qwen35-root-vs-chain.cpp create mode 100644 tests/test-qwen35-tree-rollback.cpp create mode 100644 tests/test-qwen35-tree.cpp create mode 100644 tests/test-speculative-draft-backend.cpp create mode 100644 tests/test-speculative-tree-e2e.cpp create mode 100644 tests/test-speculative-tree.cpp diff --git a/README.md b/README.md index be23abcea67..1b261fea0c6 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,12 @@ LLM inference in C/C++ +## DFlash / DDTree local notes + +This fork carries experimental Qwen3.5 DFlash / DDTree work. 
Dataset benchmark +setup, Castle commands, correctness caveats, and current llama.cpp-vs-Python +numbers are tracked in [docs/ddtree-dataset-eval-plan.md](docs/ddtree-dataset-eval-plan.md). + ## Recent API changes - [Changelog for `libllama` API](https://github.com/ggml-org/llama.cpp/issues/9289) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index b313a7320e5..b9878b34182 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -90,6 +90,12 @@ add_library(${TARGET} STATIC sampling.h speculative.cpp speculative.h + speculative-draft-backend.cpp + speculative-draft-backend.h + speculative-tree.cpp + speculative-tree.h + speculative-tree-driver.cpp + speculative-tree-driver.h unicode.cpp unicode.h jinja/lexer.cpp diff --git a/common/arg.cpp b/common/arg.cpp index bf8a5304501..61fcc398a76 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -601,6 +601,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context throw std::invalid_argument("error: --model is required\n"); } + // DDTree mode requires a draft model + if (params.speculative.ddtree_mode && !params.speculative.has_dft()) { + throw std::invalid_argument("error: --speculative-mode ddtree requires -md/--model-draft\n"); + } + if (params.escape) { string_process_escapes(params.prompt); string_process_escapes(params.input_prefix); @@ -3554,6 +3559,55 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.ngram_min_hits = value; } ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--speculative-mode"}, "[chain|ddtree]", + "speculative decoding mode: 'chain' = standard draft-model chain (default), " + "'ddtree' = DDTree dflash-draft speculative decoding (requires -md)", + [](common_params & params, const std::string & value) { + if (value == "chain") { + params.speculative.ddtree_mode = false; + } else if (value == "ddtree") { + params.speculative.ddtree_mode = true; + } else { + throw std::invalid_argument("--speculative-mode must be 'chain' or 'ddtree'"); + } + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SPECULATIVE_MODE")); + add_opt(common_arg( + {"--ddtree-budget"}, "N", + string_format("DDTree: total tree node budget per spec step (default: %d)", params.speculative.ddtree_budget), + [](common_params & params, int value) { + if (value < 1) { + throw std::invalid_argument("--ddtree-budget must be >= 1"); + } + params.speculative.ddtree_budget = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_DDTREE_BUDGET")); + add_opt(common_arg( + {"--ddtree-temp"}, "F", + string_format("DDTree: temperature for draft log-prob extraction (default: %.1f)", (double)params.speculative.ddtree_temp), + [](common_params & params, const std::string & value) { + params.speculative.ddtree_temp = std::stof(value); + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_DDTREE_TEMP")); + add_opt(common_arg( + {"--ddtree-no-chain-seed"}, + "DDTree: disable chain-seed greedy initialization (enabled by default)", + [](common_params & params) { + params.speculative.ddtree_chain_seed = false; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--ddtree-top-k"}, "N", + string_format("DDTree: per-position draft top-K width (default: %d, 0 = standalone-compatible auto)", + params.speculative.ddtree_top_k), + [](common_params & params, int value) { + if (value < 0) { + throw std::invalid_argument("--ddtree-top-k must be >= 0"); + } + params.speculative.ddtree_top_k = value; + } + 
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_DDTREE_TOP_K")); add_opt(common_arg( {"-ctkd", "--cache-type-k-draft"}, "TYPE", string_format( diff --git a/common/common.h b/common/common.h index 020b6a721ff..5d2fdb48806 100644 --- a/common/common.h +++ b/common/common.h @@ -355,6 +355,13 @@ struct common_params_speculative { bool has_dft() const { return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty(); } + + // DDTree speculative decoding parameters (Phase 5) + bool ddtree_mode = false; // true when --speculative-mode ddtree is set + int32_t ddtree_budget = 22; // non-root tree node budget (matches dflash default) + float ddtree_temp = 1.0f; // temperature for draft log-prob extraction + bool ddtree_chain_seed = true; // seed the tree heap with greedy chain (recommended) + int32_t ddtree_top_k = 0; // per-position draft top-K width; 0 = standalone-compatible auto }; struct common_params_vocoder { diff --git a/common/sampling.cpp b/common/sampling.cpp index 526f036ff98..fa3561328d4 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -429,6 +429,22 @@ static bool grammar_should_apply(struct common_sampler * gsmpl) { return true; } +bool common_sampler_grammar_token_valid(struct common_sampler * gsmpl, llama_token token) { + if (!gsmpl) { + return false; + } + if (!grammar_should_apply(gsmpl)) { + return true; + } + // Apply the grammar sampler to a single-token candidate array. The grammar + // sampler masks invalid tokens to -INFINITY; we only inspect the result and + // do not advance any sampler state. + llama_token_data single = { token, 1.0f, 0.0f }; + llama_token_data_array single_array = { &single, 1, -1, false }; + llama_sampler_apply(gsmpl->grmr, &single_array); + return single_array.data[0].logit != -INFINITY; +} + void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) { if (!gsmpl) { return; @@ -524,7 +540,36 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) { } llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { - llama_synchronize(ctx); + // Optional sub-timing instrumentation (LLAMA_DDTREE_PROFILE_CB=1). + // Splits the cb cost into the GPU sync wait vs the host-side sampler + // work so we can tell whether prefetching logits would help. 
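+    // When enabled, every sample prints one stderr line; the values below are
+    // illustrative, not measured:
+    //   cb_timing: sync=0.412 work=0.085 total=0.497 ms
+    // A dominant `sync` share means we are still waiting on the device graph,
+    // while a dominant `work` share means host-side sampling is the bottleneck.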
+    static const bool s_profile_cb = []{
+        const char * e = std::getenv("LLAMA_DDTREE_PROFILE_CB");
+        return e && e[0] == '1';
+    }();
+    int64_t cb_sync_us       = 0;
+    int64_t cb_work_start_us = 0;
+    if (s_profile_cb) {
+        const int64_t t0 = ggml_time_us();
+        llama_synchronize(ctx);
+        cb_sync_us       = ggml_time_us() - t0;
+        cb_work_start_us = ggml_time_us();
+    } else {
+        llama_synchronize(ctx);
+    }
+    struct cb_timing_guard {
+        bool    active;
+        int64_t sync_us;
+        int64_t work_start_us;
+        ~cb_timing_guard() {
+            if (active) {
+                const double work_ms = (ggml_time_us() - work_start_us) * 1e-3;
+                fprintf(stderr, "cb_timing: sync=%.3f work=%.3f total=%.3f ms\n",
+                        sync_us * 1e-3, work_ms, sync_us * 1e-3 + work_ms);
+            }
+        }
+    } cb_guard{ s_profile_cb, cb_sync_us, cb_work_start_us };
+    (void) cb_guard;
 
     // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
     const auto tm = gsmpl->tm();
diff --git a/common/sampling.h b/common/sampling.h
index 5b57ad65811..c42de51d90b 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -64,6 +64,12 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
 //
 llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
 
+// Cheap single-token grammar check used by speculative decoding short-circuits.
+// Returns true when the token would be accepted by the grammar (or when no
+// grammar is active / not currently applicable). Does NOT advance any sampler
+// state; safe to call from inside a verify-walk callback.
+bool common_sampler_grammar_token_valid(struct common_sampler * gsmpl, llama_token token);
+
 // generalized version of common_sampler_sample
 //
 // will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
diff --git a/common/speculative-draft-backend.cpp b/common/speculative-draft-backend.cpp
new file mode 100644
index 00000000000..aa9850e53b1
--- /dev/null
+++ b/common/speculative-draft-backend.cpp
@@ -0,0 +1,327 @@
+#include "speculative-draft-backend.h"
+
+#include "log.h"
+
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <cstdlib>
+#include <cstring>
+#include <vector>
+
+using ddtree_draft_clock = std::chrono::steady_clock;
+
+static double draft_elapsed_ms(ddtree_draft_clock::time_point t0) {
+    return std::chrono::duration<double, std::milli>(ddtree_draft_clock::now() - t0).count();
+}
+
+int llama_speculative_draft_top_k_width(int block_size, const llama_ddtree_params & params) {
+    if (params.top_k > 0) {
+        return params.top_k;
+    }
+    // Standalone DFlash only asks the draft for top-K branches when the DDTree
+    // budget can grow beyond the greedy chain; it uses K=8 in that case.
+    return (params.budget > std::max(0, block_size - 1)) ? 8 : 1;
+}
+
+bool llama_speculative_draft_pack_target_feat(const llama_speculative_draft_target_feat_view & view,
+                                              std::vector<float> & out,
+                                              int64_t & ctx_len) {
+    ctx_len = 0;
+    if (view.ring == nullptr || view.n_committed <= 0 || view.cap <= 0 || view.n_embd_fc <= 0) {
+        return false;
+    }
+
+    ctx_len = std::min(view.n_committed, view.cap);
+    const int64_t ring_start = view.n_committed - ctx_len;
+
+    out.resize((size_t) view.n_embd_fc * ctx_len);
+    for (int64_t t = 0; t < ctx_len; ++t) {
+        const int64_t ring_col = (ring_start + t) % view.cap;
+        const float * src = view.ring + ring_col * view.n_embd_fc;
+        float       * dst = out.data() + t * view.n_embd_fc;
+        memcpy(dst, src, (size_t) view.n_embd_fc * sizeof(float));
+    }
+    return true;
+}
+
+class llama_speculative_llama_draft_backend final : public llama_speculative_draft_backend {
+  public:
+    llama_speculative_llama_draft_backend(llama_context * draft_ctx,
+                                          const llama_model * target_model,
+                                          int64_t n_embd,
+                                          int64_t n_vocab,
+                                          int64_t block_size,
+                                          llama_token mask_token_id,
+                                          const llama_ddtree_params & params) :
+            draft_ctx(draft_ctx),
+            target_model(target_model),
+            n_embd(n_embd),
+            n_vocab(n_vocab),
+            block_size(block_size),
+            mask_token_id(mask_token_id),
+            params(params) {
+        mask_embd.resize((size_t) n_embd);
+        noise_embd.resize((size_t) block_size * n_embd);
+        pos.resize((size_t) block_size);
+        n_seq_id.assign((size_t) block_size, 1);
+        seq_id_values.assign((size_t) block_size, 0);
+        seq_id_ptrs.resize((size_t) block_size);
+        logits.assign((size_t) block_size, 1);
+        for (int64_t i = 0; i < block_size; ++i) {
+            seq_id_ptrs[(size_t) i] = &seq_id_values[(size_t) i];
+        }
+    }
+
+    bool init() {
+        if (draft_ctx == nullptr || target_model == nullptr || n_embd <= 0 || n_vocab <= 0 || block_size <= 1) {
+            return false;
+        }
+        if (llama_model_token_embd_lookup(target_model, mask_token_id, mask_embd.data(), n_embd) != 0) {
+            LOG_ERR("%s: token_embd_lookup failed for mask_token=%d\n", __func__, (int) mask_token_id);
+            return false;
+        }
+        for (int64_t i = 1; i < block_size; ++i) {
+            memcpy(noise_embd.data() + i * n_embd, mask_embd.data(), (size_t) n_embd * sizeof(float));
+        }
+        llama_set_dflash_draft_top_k(draft_ctx,
+                                     std::min<int64_t>(llama_speculative_draft_top_k_width((int) block_size, params),
+                                                       n_vocab));
+        return true;
+    }
+
+    const char * name() const override { return "dflash-topk"; }
+
+    bool ingest_target_capture(llama_context * target_ctx,
+                               const int32_t * dfs_indices,
+                               int32_t n_dfs,
+                               int64_t first_pos,
+                               int64_t cap,
+                               double & elapsed_ms) override {
+        const auto t0 = ddtree_draft_clock::now();
+        elapsed_ms = 0.0;
+        if (target_ctx == nullptr || n_dfs <= 0 || cap <= 0) {
+            return false;
+        }
+        const int ret = llama_dflash_draft_update_fused_cache_from_capture(draft_ctx, target_ctx, dfs_indices,
+                                                                           n_dfs, first_pos, cap);
+        elapsed_ms = draft_elapsed_ms(t0);
+        if (ret != 0) {
+            return false;
+        }
+        fused_target_feat_cap         = cap;
+        fused_target_feat_n_embd      = n_embd;
+        fused_target_feat_n_committed = first_pos + n_dfs;
+        return true;
+    }
+
+    bool decode_topk(llama_token root_token,
+                     llama_pos committed_pos,
+                     const llama_speculative_draft_target_feat_view & target_feat,
+                     std::vector<float> & top_log_probs,
+                     std::vector<llama_token> & top_token_ids,
+                     llama_speculative_draft_decode_info & info) override {
+        info   = {};
+        info.L = (int) block_size - 1;
+        info.K = (int) std::min<int64_t>(llama_speculative_draft_top_k_width((int) block_size, params), n_vocab);
+
+        if (target_feat.ring == nullptr || target_feat.n_committed <= 0 || target_feat.cap <= 0 ||
+            target_feat.n_embd_fc <= 0 || target_feat.n_embd_fc % 5 != 0) {
+            LOG_ERR(
+                "%s: target_feat ring is empty; call llama_speculative_tree_driver_ingest_prompt_capture first\n",
+                __func__);
+            return false;
+        }
+
+        info.ctx_len = std::min(target_feat.n_committed, target_feat.cap);
+        const int64_t ring_start = target_feat.n_committed - info.ctx_len;
+
+        {
+            const auto t0 = ddtree_draft_clock::now();
+
+            if (llama_model_token_embd_lookup(target_model, root_token, noise_embd.data(), n_embd) != 0) {
+                LOG_ERR("%s: token_embd_lookup failed for root_token=%d\n", __func__, (int) root_token);
+                info.t_draft_decode_ms += draft_elapsed_ms(t0);
+                return false;
+            }
+            for (int32_t i = 0; i < (int32_t) block_size; ++i) {
+                pos[(size_t) i] = committed_pos + i;
+            }
+
+            llama_batch draft_batch{};
+            draft_batch.n_tokens  = (int32_t) block_size;
+            draft_batch.token     = nullptr;
+            draft_batch.embd      = noise_embd.data();
+            draft_batch.pos       = pos.data();
+            draft_batch.n_seq_id  = n_seq_id.data();
+            draft_batch.seq_id    = seq_id_ptrs.data();
+            draft_batch.logits    = logits.data();
+            draft_batch.parent_id = nullptr;
+
+            const int64_t fused_n_embd = target_feat.n_embd_fc / 5;
+            if (!ensure_fused_target_feat(target_feat, fused_n_embd)) {
+                info.t_draft_decode_ms += draft_elapsed_ms(t0);
+                return false;
+            }
+            const int ret = llama_dflash_draft_encode_top_k_cached(draft_ctx, draft_batch,
+                                                                   fused_n_embd, info.ctx_len,
+                                                                   ring_start, target_feat.cap,
+                                                                   committed_pos, info.K);
+            if (ret != 0) {
+                LOG_ERR("%s: dflash draft encode-topK failed: %d\n", __func__, ret);
+                info.t_draft_decode_ms += draft_elapsed_ms(t0);
+                return false;
+            }
+            info.t_draft_decode_ms += draft_elapsed_ms(t0);
+        }
+
+        top_log_probs.resize((size_t) info.L * info.K);
+        top_token_ids.resize((size_t) info.L * info.K);
+
+        {
+            const auto t0 = ddtree_draft_clock::now();
+
+            float proposal_temp = params.temp;
+            if (const char * e = std::getenv("LLAMA_DDTREE_PROPOSAL_TEMP")) {
+                char * end = nullptr;
+                const float v = std::strtof(e, &end);
+                if (end != e && v > 0.0f) {
+                    proposal_temp = v;
+                }
+            }
+            const float inv_t = 1.0f / std::max(1e-6f, proposal_temp);
+
+            const float       * draft_top_logits = nullptr;
+            const llama_token * draft_top_tokens = nullptr;
+            int32_t top_rows = 0;
+            int32_t top_k    = 0;
+            if (!llama_get_dflash_draft_top_k(draft_ctx, &draft_top_logits, &draft_top_tokens, &top_rows, &top_k) ||
+                draft_top_logits == nullptr || draft_top_tokens == nullptr || top_rows < (int32_t) block_size ||
+                top_k < info.K) {
+                LOG_ERR("%s: dflash draft top-K unavailable\n", __func__);
+                info.t_topk_ms += draft_elapsed_ms(t0);
+                return false;
+            }
+
+            struct Entry {
+                float       logit;
+                llama_token token;
+            };
+            std::vector<Entry> row_top((size_t) info.K);
+
+            for (int i = 0; i < info.L; ++i) {
+                const int row_idx = i + 1;
+                for (int k = 0; k < info.K; ++k) {
+                    row_top[(size_t) k] = {
+                        draft_top_logits[(size_t) row_idx * top_k + k],
+                        draft_top_tokens[(size_t) row_idx * top_k + k],
+                    };
+                }
+                std::sort(row_top.begin(), row_top.end(), [](const Entry & a, const Entry & b) {
+                    return a.logit > b.logit;
+                });
+
+                if (std::abs(proposal_temp - 1.0f) < 1e-6f) {
+                    for (int k = 0; k < info.K; ++k) {
+                        top_log_probs[(size_t) i * info.K + k] = row_top[(size_t) k].logit;
+                        top_token_ids[(size_t) i * info.K + k] = row_top[(size_t) k].token;
+                    }
+                    continue;
+                }
+
+                const float row_best = row_top[0].logit * inv_t;
+                float sum_exp_top = 0.0f;
+                for (int k = 0; k < info.K; ++k) {
+                    sum_exp_top += std::exp(row_top[(size_t) k].logit * inv_t - row_best);
+                }
+                const float log_z_approx = row_best + std::log(sum_exp_top);
+                for (int k = 0; k < info.K; ++k) {
+                    top_log_probs[(size_t) i * info.K + k] = row_top[(size_t) k].logit * inv_t - log_z_approx;
+                    top_token_ids[(size_t) i * info.K + k] = row_top[(size_t) k].token;
+                }
+            }
+            info.t_topk_ms += draft_elapsed_ms(t0);
+        }
+
+        return true;
+    }
+
+  private:
+    bool ensure_fused_target_feat(const llama_speculative_draft_target_feat_view & target_feat,
+                                  int64_t fused_n_embd) {
+        const int64_t ctx_len    = std::min(target_feat.n_committed, target_feat.cap);
+        const int64_t ring_start = target_feat.n_committed - ctx_len;
+
+        if (fused_target_feat_cap != target_feat.cap || fused_target_feat_n_embd != fused_n_embd) {
+            fused_target_feat_cap         = target_feat.cap;
+            fused_target_feat_n_embd      = fused_n_embd;
+            fused_target_feat_n_committed = ring_start;
+        }
+
+        if (fused_target_feat_n_committed < ring_start || fused_target_feat_n_committed > target_feat.n_committed) {
+            fused_target_feat_n_committed = ring_start;
+        }
+
+        const int64_t missing = target_feat.n_committed - fused_target_feat_n_committed;
+        if (missing <= 0) {
+            return true;
+        }
+
+        raw_fuse_buf.resize((size_t) target_feat.n_embd_fc * missing);
+        for (int64_t t = 0; t < missing; ++t) {
+            const int64_t logical_col = fused_target_feat_n_committed + t;
+            const int64_t ring_col    = logical_col % target_feat.cap;
+            const float * src = target_feat.ring + ring_col * target_feat.n_embd_fc;
+            float       * dst = raw_fuse_buf.data() + t * target_feat.n_embd_fc;
+            memcpy(dst, src, (size_t) target_feat.n_embd_fc * sizeof(float));
+        }
+
+        const int ret = llama_dflash_draft_update_fused_cache(draft_ctx, raw_fuse_buf.data(), target_feat.n_embd_fc,
+                                                              missing, fused_target_feat_n_committed,
+                                                              target_feat.cap);
+        if (ret != 0) {
+            LOG_ERR("%s: dflash target_feat cache update failed: %d\n", __func__, ret);
+            return false;
+        }
+
+        fused_target_feat_n_committed = target_feat.n_committed;
+        return true;
+    }
+
+    llama_context * draft_ctx = nullptr;
+    const llama_model * target_model = nullptr;
+    int64_t n_embd     = 0;
+    int64_t n_vocab    = 0;
+    int64_t block_size = 0;
+    llama_token mask_token_id = 0;
+    llama_ddtree_params params;
+
+    std::vector<float> target_feat_buf;
+    std::vector<float> raw_fuse_buf;
+    int64_t fused_target_feat_n_committed = 0;
+    int64_t fused_target_feat_n_embd     = 0;
+    int64_t fused_target_feat_cap        = 0;
+    std::vector<float>          mask_embd;
+    std::vector<float>          noise_embd;
+    std::vector<llama_pos>      pos;
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id>   seq_id_values;
+    std::vector<llama_seq_id *> seq_id_ptrs;
+    std::vector<int8_t>         logits;
+};
+
+std::unique_ptr<llama_speculative_draft_backend> llama_speculative_draft_backend_init_llama(
+        llama_context * draft_ctx,
+        const llama_model * target_model,
+        int64_t n_embd,
+        int64_t n_vocab,
+        int64_t block_size,
+        llama_token mask_token_id,
+        const llama_ddtree_params & params) {
+    auto backend = std::make_unique<llama_speculative_llama_draft_backend>(draft_ctx, target_model, n_embd, n_vocab,
+                                                                           block_size, mask_token_id, params);
+    if (!backend->init()) {
+        return nullptr;
+    }
+    return backend;
+}
diff --git a/common/speculative-draft-backend.h b/common/speculative-draft-backend.h
new file mode 100644
index 00000000000..dbc732d0d01
--- /dev/null
+++ b/common/speculative-draft-backend.h
@@ -0,0 +1,69 @@
+#pragma once
+
+#include "llama.h"
+#include "speculative-tree.h"
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+struct llama_speculative_draft_target_feat_view {
+    const float * ring = nullptr;
+    int64_t n_committed = 0;
+    int64_t cap         = 0;
+    int64_t n_embd_fc   = 0;
+};
+
+struct llama_speculative_draft_decode_info {
+    int L = 0;
+    int K = 0;
+    int64_t ctx_len = 0;
+
+    double t_target_feat_pack_ms = 0.0;
+    double t_draft_decode_ms     = 0.0;
+    double t_topk_ms             = 0.0;
+};
+
+int llama_speculative_draft_top_k_width(int block_size, const llama_ddtree_params & params);
+
+bool llama_speculative_draft_pack_target_feat(const llama_speculative_draft_target_feat_view & view,
+                                              std::vector<float> & out,
+                                              int64_t & ctx_len);
+
+class llama_speculative_draft_backend {
+  public:
+    virtual ~llama_speculative_draft_backend() = default;
+
+    virtual const char * name() const = 0;
+
+    virtual bool ingest_target_capture(llama_context * target_ctx,
+                                       const int32_t * dfs_indices,
+                                       int32_t n_dfs,
+                                       int64_t first_pos,
+                                       int64_t cap,
+                                       double & elapsed_ms) {
+        GGML_UNUSED(target_ctx);
+        GGML_UNUSED(dfs_indices);
+        GGML_UNUSED(n_dfs);
+        GGML_UNUSED(first_pos);
+        GGML_UNUSED(cap);
+        elapsed_ms = 0.0;
+        return false;
+    }
+
+    virtual bool decode_topk(llama_token root_token,
+                             llama_pos committed_pos,
+                             const llama_speculative_draft_target_feat_view & target_feat,
+                             std::vector<float> & top_log_probs,
+                             std::vector<llama_token> & top_token_ids,
+                             llama_speculative_draft_decode_info & info) = 0;
+};
+
+std::unique_ptr<llama_speculative_draft_backend> llama_speculative_draft_backend_init_llama(
+    llama_context * draft_ctx,
+    const llama_model * target_model,
+    int64_t n_embd,
+    int64_t n_vocab,
+    int64_t block_size,
+    llama_token mask_token_id,
+    const llama_ddtree_params & params);
diff --git a/common/speculative-tree-driver.cpp b/common/speculative-tree-driver.cpp
new file mode 100644
index 00000000000..5cc514026b2
--- /dev/null
+++ b/common/speculative-tree-driver.cpp
@@ -0,0 +1,1194 @@
+// speculative-tree-driver.cpp — Phase 4 DDTree spec-decode step coordinator.
+//
+// Mirrors the main loop in test_dflash.cpp:1070-1500 using the llama.cpp public API.
+//
+// Target_feat layout (from qwen35.cpp capture):
+//   The hidden capture tensor is [n_embd, 5*n_tokens] (column-major ggml / row-major C).
+//   Layer k's hidden for all decoded positions occupies columns [k*n_total .. (k+1)*n_total).
+//   For the draft window, we take the most recent ctx_len positions per layer and pack
+//   them into [5*n_embd, ctx_len]:
+//     row l*n_embd .. (l+1)*n_embd - 1 = layer l's hidden across ctx_len positions.
+//   This matches what dflash-draft.cpp's fc projection expects.
+//
+// SSM rollback strategy (Phase 2.4):
+//   If the accepted node is DFS-last, the live recurrent state already points at the accepted
+//   state and copying from persist buffers only adds risk. Non-DFS-last persist rollback is
+//   not yet proven correct for DFlash, so the driver falls back to
+//   snapshot+restore+chain-replay for that case.

+#include "speculative-tree-driver.h"
+#include "speculative-draft-backend.h"
+#include "speculative-tree.h"
+#include "log.h"
+
+#include "llama.h"
+
+#include <algorithm>
+#include <chrono>
+#include <cstdlib>
+#include <cstring>
+#include <memory>
+#include <utility>
+#include <vector>
+
+// Maximum target-context window that the draft can attend over.
+// Matches test_dflash.cpp:1086 DRAFT_CTX_MAX. The server-port default is
+// smaller; raise it with LLAMA_DDTREE_TARGET_FEAT_CTX when needed.
+static constexpr int DRAFT_CTX_MAX     = 2048;
+static constexpr int DRAFT_CTX_DEFAULT = 128;
+
+// EOS token for Qwen3.5 family.
+static constexpr llama_token QWEN35_EOS = 248045;
+
+using ddtree_clock = std::chrono::steady_clock;
+
+static double elapsed_ms(ddtree_clock::time_point t0) {
+    return std::chrono::duration<double, std::milli>(ddtree_clock::now() - t0).count();
+}
+
+struct llama_speculative_tree_driver {
+    llama_context * target_ctx = nullptr;
+    llama_context * draft_ctx  = nullptr;
+    llama_ddtree_params params;
+    std::unique_ptr<llama_speculative_draft_backend> draft_backend;
+
+    // n_embd from the target model (for hidden capture slicing).
+    int64_t n_embd = 0;
+    // vocabulary size (for logit indexing).
+    int64_t n_vocab = 0;
+    // draft block_size (number of noise tokens per step, typically 16).
+    int64_t block_size = 0;
+    // mask token id used to fill noise positions in the draft input.
+    llama_token mask_token_id = 0;
+
+    // Scratch buffer for packed target_feat: [5*n_embd, ctx_len]
+    std::vector<float> target_feat_buf;
+
+    // Cumulative target_feat sliding ring buffer, [5*n_embd, target_feat_cap]
+    // Stored column-major: column t = position t, rows = [l*n_embd .. (l+1)*n_embd) for layer l.
+    // i.e. ring[(logical_col % cap) * target_feat_n_embd_fc + l*n_embd .. +n_embd] =
+    //      layer l at committed pos logical_col.
+    std::vector<float> target_feat_ring;         // size = target_feat_n_embd_fc * target_feat_cap
+    int64_t target_feat_n_committed = 0;         // total committed positions appended to the ring, not capped
+    int64_t target_feat_n_embd_fc   = 0;         // = 5 * n_embd
+    int64_t target_feat_cap = DRAFT_CTX_DEFAULT; // target feature context retained for draft
+
+    // Scratch buffers
+    std::vector<float>       top_log_probs;  // [block_size-1, K]
+    std::vector<llama_token> top_token_ids;  // [block_size-1, K]
+    std::vector<float>       noise_embd_buf; // [block_size * n_embd]
+    std::vector<int32_t>     posterior;      // [N] argmax per tree node
+
+    llama_speculative_tree_driver_stats stats;
+    bool fast_rollback_unavailable = false;
+};
+
+enum class ddtree_verifier_mode {
+    exact,
+    paper,
+};
+
+static ddtree_verifier_mode ddtree_get_verifier_mode() {
+    const char * e = std::getenv("LLAMA_DDTREE_VERIFIER");
+    if (e != nullptr) {
+        if (std::strcmp(e, "exact") == 0 || std::strcmp(e, "chain") == 0) {
+            return ddtree_verifier_mode::exact;
+        }
+        if (std::strcmp(e, "paper") == 0 || std::strcmp(e, "tree") == 0) {
+            return ddtree_verifier_mode::paper;
+        }
+    }
+
+    const char * exact = std::getenv("LLAMA_DDTREE_EXACT_VALIDATION");
+    if (exact != nullptr && exact[0] == '1') {
+        return ddtree_verifier_mode::exact;
+    }
+
+    return ddtree_verifier_mode::paper;
+}
+
+static bool ddtree_paper_verifier_enabled() {
+    return ddtree_get_verifier_mode() == ddtree_verifier_mode::paper;
+}
+
+static bool ddtree_fast_batched_enabled() {
+    const char * e = std::getenv("LLAMA_DDTREE_FAST_BATCHED");
+    return e != nullptr && e[0] == '1';
+}
+
+static bool ddtree_target_top1_enabled() {
+    const char * e = std::getenv("LLAMA_DDTREE_TARGET_TOP1");
+    return e != nullptr && e[0] == '1';
+}
+
+static bool ddtree_fast_rollback_enabled() {
+    const char * e = std::getenv("LLAMA_DDTREE_FAST_ROLLBACK");
+    return e != nullptr && e[0] == '1';
+}
+
+static bool ddtree_snapshot_fallback_enabled() {
+    const char * e = std::getenv("LLAMA_DDTREE_SNAPSHOT_FALLBACK");
+    return e == nullptr || e[0] != '0';
+}
+
+static bool ddtree_capture_direct_enabled() {
+    const char * e = std::getenv("LLAMA_DDTREE_CAPTURE_DIRECT");
+    return e != nullptr && e[0] == '1';
+}
+
+static bool ddtree_trust_batched_posterior() {
+    const char * e = std::getenv("LLAMA_DDTREE_UNSAFE_TRUST_BATCHED");
+    return e != nullptr && e[0] == '1';
+}
+
+static bool ddtree_diag_batched_enabled() {
+    const char * e = std::getenv("LLAMA_DDTREE_DIAG_BATCHED");
+    return e != nullptr && e[0] == '1';
+}
+
+static bool ddtree_exact_ar_fallback_enabled() {
+    const char * e = std::getenv("LLAMA_DDTREE_EXACT_AR_FALLBACK");
+    return e == nullptr || e[0] != '0';
+}
+
+static bool ddtree_unsafe_fast_tree_state_enabled() {
+    const char * e = std::getenv("LLAMA_DDTREE_UNSAFE_FAST_TREE_STATE");
+    return e != nullptr && e[0] == '1';
+}
+
+static int64_t ddtree_target_feat_cap() {
+    const char * e = std::getenv("LLAMA_DDTREE_TARGET_FEAT_CTX");
+    if (!e || e[0] == '\0') {
+        return DRAFT_CTX_DEFAULT;
+    }
+
+    char * end = nullptr;
+    const long v = std::strtol(e, &end, 10);
+    if (end == e || v <= 0) {
+        return DRAFT_CTX_DEFAULT;
+    }
+
+    return std::min<int64_t>(DRAFT_CTX_MAX, std::max<int64_t>(1, v));
+}
+
+llama_speculative_tree_driver * llama_speculative_tree_driver_init(
+        llama_context * target_ctx,
+        llama_context * draft_ctx,
+        const llama_ddtree_params & params) {
+
+    const llama_model * target_model = llama_get_model(target_ctx);
+
+    if (!target_model || !llama_get_model(draft_ctx)) {
+        LOG_ERR("%s: null model pointer\n", __func__);
+        return nullptr;
+    }
+
+    auto * d = new llama_speculative_tree_driver;
+    d->target_ctx = target_ctx;
+    d->draft_ctx  = draft_ctx;
+    d->params     = params;
+
+    d->n_embd = llama_model_n_embd(target_model);
+    // n_vocab: use target vocab; draft shares the same lm_head.
+    const llama_vocab * target_vocab = llama_model_get_vocab(target_model);
+    d->n_vocab = (target_vocab != nullptr) ? llama_vocab_n_tokens(target_vocab) : 0;
+    if (d->n_embd <= 0 || d->n_vocab <= 0) {
+        LOG_ERR("%s: invalid model dimensions n_embd=%lld n_vocab=%lld\n",
+                __func__, (long long) d->n_embd, (long long) d->n_vocab);
+        delete d;
+        return nullptr;
+    }
+    // block_size and mask_token_id are constants baked into the dflash-draft checkpoint.
+    // Qwen3.5-27B-DFlash always uses block_size=16 and mask_token_id=248070.
+    d->block_size    = (int64_t) params.block_size; // from llama_ddtree_params (default 16)
+    d->mask_token_id = 248070;                      // dflash-draft MASK token
+
+    // Initialize cumulative target_feat ring buffer.
+    d->target_feat_n_embd_fc   = 5 * d->n_embd;
+    d->target_feat_n_committed = 0;
+    d->target_feat_cap         = ddtree_target_feat_cap();
+    d->target_feat_ring.assign((size_t) d->target_feat_n_embd_fc * d->target_feat_cap, 0.0f);
+
+    if (ddtree_paper_verifier_enabled()) {
+        llama_dflash_ensure_persist_capacity(target_ctx, params.budget);
+    }
+
+    d->draft_backend = llama_speculative_draft_backend_init_llama(
+        draft_ctx, target_model, d->n_embd, d->n_vocab, d->block_size, d->mask_token_id, d->params);
+    if (!d->draft_backend) {
+        LOG_ERR("%s: failed to initialize draft backend\n", __func__);
+        delete d;
+        return nullptr;
+    }
+
+    return d;
+}
+
+void llama_speculative_tree_driver_free(llama_speculative_tree_driver * d) {
+    delete d;
+}
+
+llama_speculative_tree_driver_stats llama_speculative_tree_driver_get_stats(
+        const llama_speculative_tree_driver * d) {
+    return d ? d->stats : llama_speculative_tree_driver_stats{};
+}
+
+int32_t llama_speculative_tree_driver_context_window() {
+    return (int32_t) ddtree_target_feat_cap();
+}
+
+// Pack the hidden capture buffer into [5*n_embd, ctx_len] F32.
+//   capture : [n_embd, 5*n_total] — ggml column-major, row-major in C means ne[0]=n_embd.
+//   n_total : total positions in the capture tensor (= ne[1] / 5).
+//   ctx_len : number of most-recent positions to pack (ctx_len <= n_total).
+//   out     : caller-allocated [5*n_embd * ctx_len] F32.
+static void pack_target_feat(
+        const float * capture, int64_t n_embd, int64_t n_total, int64_t ctx_len, float * out) {
+    const int64_t start = n_total - ctx_len; // first column to include
+    for (int64_t l = 0; l < 5; ++l) {
+        for (int64_t t = 0; t < ctx_len; ++t) {
+            // Source: capture column (start+t) in layer l's block.
+            // In C memory (n_embd fastest): capture[(l*n_total + start + t) * n_embd .. +n_embd)
+            const float * src = capture + (l * n_total + start + t) * n_embd;
+            // Destination: row l*n_embd in the output, column t.
+ // Output is [5*n_embd, ctx_len] → in C: out[t * 5*n_embd + l*n_embd .. +n_embd) + float * dst = out + (size_t)t * 5 * n_embd + (size_t)l * n_embd; + memcpy(dst, src, (size_t)n_embd * sizeof(float)); + } + } +} + +// Append hidden capture data from target_ctx into the driver's ring buffer. +// dfs_indices: if non-NULL, selects which capture columns to ingest (the DFS accepted indices). +// if NULL, ingest the first n_dfs columns linearly (prompt prefill path). +// n_dfs: number of columns to ingest. +enum class ingest_source { + prompt, + tree, + replay, +}; + +static int32_t driver_ingest_capture(llama_speculative_tree_driver * d, + const int32_t * dfs_indices, + int32_t n_dfs, + ingest_source source) { + const auto t0 = ddtree_clock::now(); + ggml_tensor * t_capture = llama_get_hidden_capture(d->target_ctx); + int64_t ne0 = t_capture != nullptr ? t_capture->ne[0] : 0; + int64_t ne1 = t_capture != nullptr ? t_capture->ne[1] : 0; + if (t_capture == nullptr || ne0 == 0 || ne1 == 0) { + LOG_ERR("%s: no hidden capture data available\n", __func__); + d->stats.t_ingest_capture_ms += elapsed_ms(t0); + return 0; + } + // capture layout: [n_embd, 5*n_tokens] → ne0=n_embd, ne1=5*n_tokens + const int64_t n_embd = ne0; + const int64_t n_tokens = ne1 / 5; // number of decoded positions in this capture + + if (n_embd != d->n_embd) { + LOG_ERR("%s: capture n_embd=%lld != driver n_embd=%lld\n", + __func__, (long long)n_embd, (long long)d->n_embd); + d->stats.t_ingest_capture_ms += elapsed_ms(t0); + return 0; + } + + // Clamp n_dfs to what the capture actually contains. The server may call + // ingest_prompt_capture(slot.prompt_size) when only the new (uncached) tail + // of the prompt actually went through llama_decode — the capture only holds + // the most recent decode's columns. Out-of-range reads here would be UB. + int32_t n_to_ingest = n_dfs; + if (dfs_indices == nullptr && n_to_ingest > (int32_t)n_tokens) { + d->stats.n_capture_clamps++; + LOG_WRN("%s: requested n_dfs=%d but capture only has n_tokens=%lld; clamping (ring will be incomplete)\n", + __func__, n_dfs, (long long)n_tokens); + n_to_ingest = (int32_t)n_tokens; + } + + if (ddtree_capture_direct_enabled() && d->draft_backend) { + double direct_ms = 0.0; + if (d->draft_backend->ingest_target_capture(d->target_ctx, dfs_indices, n_to_ingest, + d->target_feat_n_committed, d->target_feat_cap, + direct_ms)) { + d->target_feat_n_committed += (int64_t)n_to_ingest; + switch (source) { + case ingest_source::prompt: + d->stats.n_prompt_ingest_calls++; + d->stats.n_prompt_ingested_tokens += n_to_ingest; + d->stats.t_prompt_ingest_ms += direct_ms; + break; + case ingest_source::tree: + d->stats.n_tree_ingested_tokens += n_to_ingest; + d->stats.t_tree_ingest_ms += direct_ms; + break; + case ingest_source::replay: + d->stats.n_replay_ingested_tokens += n_to_ingest; + d->stats.t_replay_ingest_ms += direct_ms; + break; + } + d->stats.t_ingest_capture_ms += direct_ms; + return n_to_ingest; + } + } + + const float * capture = llama_get_hidden_capture_data(d->target_ctx, &ne0, &ne1); + if (!capture || ne0 == 0 || ne1 == 0) { + LOG_ERR("%s: no hidden capture data available after direct-ingest fallback\n", __func__); + d->stats.t_ingest_capture_ms += elapsed_ms(t0); + return 0; + } + + for (int32_t i = 0; i < n_to_ingest; ++i) { + // Source column index in the capture buffer (within each layer's block). + const int64_t src_col = (dfs_indices != nullptr) ? 
(int64_t)dfs_indices[i] : (int64_t)i; + if (src_col < 0 || src_col >= n_tokens) { + LOG_ERR("%s: src_col=%lld out of capture range [0, %lld)\n", + __func__, (long long)src_col, (long long)n_tokens); + break; + } + + const int64_t logical_col = d->target_feat_n_committed + (int64_t)i; + const int64_t dst_col = logical_col % d->target_feat_cap; + + for (int64_t l = 0; l < 5; ++l) { + // Source: layer l's block starts at column l*n_tokens; pick column src_col within it. + const float * src = capture + (l * n_tokens + src_col) * n_embd; + // Destination: ring column dst_col, row l*n_embd. + float * dst = d->target_feat_ring.data() + dst_col * d->target_feat_n_embd_fc + l * n_embd; + memcpy(dst, src, (size_t)n_embd * sizeof(float)); + } + } + + d->target_feat_n_committed += (int64_t)n_to_ingest; + const double ingest_ms = elapsed_ms(t0); + switch (source) { + case ingest_source::prompt: + d->stats.n_prompt_ingest_calls++; + d->stats.n_prompt_ingested_tokens += n_to_ingest; + d->stats.t_prompt_ingest_ms += ingest_ms; + break; + case ingest_source::tree: + d->stats.n_tree_ingested_tokens += n_to_ingest; + d->stats.t_tree_ingest_ms += ingest_ms; + break; + case ingest_source::replay: + d->stats.n_replay_ingested_tokens += n_to_ingest; + d->stats.t_replay_ingest_ms += ingest_ms; + break; + } + d->stats.t_ingest_capture_ms += ingest_ms; + return n_to_ingest; +} + +static bool replay_committed_chain(llama_speculative_tree_driver * d, + const llama_ddtree & tree, + const int32_t * accepted_dfs, + int32_t commit_n, + llama_pos committed_pos) { + if (commit_n <= 0) { + return true; + } + + llama_memory_t mem = llama_get_memory(d->target_ctx); + if (!llama_memory_seq_rm(mem, /*seq_id=*/0, committed_pos, /*p1=*/-1)) { + LOG_ERR("%s: failed to remove tree KV/recurrent range at pos >= %d\n", + __func__, (int)committed_pos); + return false; + } + + llama_batch replay = llama_batch_init(commit_n, /*embd=*/0, /*n_seq_max=*/1); + replay.n_tokens = commit_n; + for (int32_t i = 0; i < commit_n; ++i) { + replay.token[i] = tree.nodes[accepted_dfs[i]].token_id; + replay.pos[i] = committed_pos + i; + replay.n_seq_id[i] = 1; + replay.seq_id[i][0] = 0; + replay.logits[i] = 0; + } + + const int ret = llama_decode(d->target_ctx, replay); + llama_batch_free(replay); + if (ret != 0) { + LOG_ERR("%s: chain replay llama_decode failed: %d\n", __func__, ret); + return false; + } + + driver_ingest_capture(d, nullptr, commit_n, ingest_source::replay); + return true; +} + +static int32_t find_child_token(const llama_ddtree & tree, int32_t parent, llama_token token) { + for (int32_t i = 1; i < (int32_t) tree.nodes.size(); ++i) { + if (tree.nodes[i].parent_idx == parent && tree.nodes[i].token_id == token) { + return i; + } + } + return -1; +} + +static llama_token pick_current_logits(llama_speculative_tree_driver * d, + const llama_speculative_tree_verify_cbs * verify_cbs) { + if (verify_cbs != nullptr && verify_cbs->sample_cb != nullptr) { + // exact-chain mode: no precomputed batched argmax for this row, signal + // the cb to do a full sample with LLAMA_TOKEN_NULL. 
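+        // Hedged sketch of the expected callback shape (`slot_state`, `smpl`
+        // and `ctx` are hypothetical caller-side names, not part of this API):
+        //
+        //   llama_token sample_cb(void * ud, int32_t row, llama_token batched_pick) {
+        //       auto * s = (slot_state *) ud;
+        //       if (batched_pick != LLAMA_TOKEN_NULL &&
+        //           common_sampler_grammar_token_valid(s->smpl, batched_pick)) {
+        //           return batched_pick; // trust the precomputed batched argmax
+        //       }
+        //       return common_sampler_sample(s->smpl, s->ctx, row); // full sample
+        //   }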
+        return (llama_token) verify_cbs->sample_cb(verify_cbs->user_data,
+                                                   /*logits_row_idx=*/0,
+                                                   /*batched_pick=*/LLAMA_TOKEN_NULL);
+    }
+
+    const float * row = llama_get_logits_ith(d->target_ctx, 0);
+    if (!row) {
+        return LLAMA_TOKEN_NULL;
+    }
+    int32_t best = 0;
+    float best_val = row[0];
+    for (int64_t v = 1; v < d->n_vocab; ++v) {
+        if (row[v] > best_val) {
+            best_val = row[v];
+            best = (int32_t) v;
+        }
+    }
+    return (llama_token) best;
+}
+
+static bool validate_tree_with_chain(llama_speculative_tree_driver * d,
+                                     const llama_ddtree & tree,
+                                     llama_pos committed_pos,
+                                     const llama_speculative_tree_verify_cbs * verify_cbs,
+                                     std::vector<int32_t> & accepted_dfs,
+                                     llama_token & next_token) {
+    accepted_dfs.clear();
+    accepted_dfs.push_back(0);
+    next_token = LLAMA_TOKEN_NULL;
+
+    llama_memory_t mem = llama_get_memory(d->target_ctx);
+    if (!llama_memory_seq_rm(mem, /*seq_id=*/0, committed_pos, /*p1=*/-1)) {
+        LOG_ERR("%s: failed to remove tree KV/recurrent range at pos >= %d\n",
+                __func__, (int) committed_pos);
+        return false;
+    }
+
+    int32_t current = 0;
+    for (int32_t depth = 0; depth < (int32_t) tree.nodes.size(); ++depth) {
+        llama_batch b = llama_batch_init(1, /*embd=*/0, /*n_seq_max=*/1);
+        b.n_tokens     = 1;
+        b.token[0]     = tree.nodes[current].token_id;
+        b.pos[0]       = committed_pos + depth;
+        b.n_seq_id[0]  = 1;
+        b.seq_id[0][0] = 0;
+        b.logits[0]    = 1;
+
+        const auto t_decode0 = ddtree_clock::now();
+        const int ret = llama_decode(d->target_ctx, b);
+        d->stats.t_exact_decode_ms += elapsed_ms(t_decode0);
+        d->stats.n_exact_validate_nodes++;
+        llama_batch_free(b);
+        if (ret != 0) {
+            LOG_ERR("%s: chain validation llama_decode failed at depth %d: %d\n",
+                    __func__, (int) depth, ret);
+            return false;
+        }
+
+        driver_ingest_capture(d, nullptr, 1, ingest_source::replay);
+
+        const auto t_sample0 = ddtree_clock::now();
+        const llama_token picked = pick_current_logits(d, verify_cbs);
+        d->stats.t_exact_sample_ms += elapsed_ms(t_sample0);
+        if (picked == LLAMA_TOKEN_NULL) {
+            LOG_ERR("%s: failed to pick from chain validation logits\n", __func__);
+            return false;
+        }
+
+        const int32_t child = find_child_token(tree, current, picked);
+        if (child < 0) {
+            next_token = picked;
+            return true;
+        }
+
+        if (verify_cbs != nullptr && verify_cbs->advance_cb != nullptr) {
+            const auto t_advance0 = ddtree_clock::now();
+            verify_cbs->advance_cb(verify_cbs->user_data, picked);
+            d->stats.t_exact_advance_ms += elapsed_ms(t_advance0);
+        }
+
+        accepted_dfs.push_back(child);
+        current = child;
+    }
+
+    const auto t_sample0 = ddtree_clock::now();
+    const llama_token picked = pick_current_logits(d, verify_cbs);
+    d->stats.t_exact_sample_ms += elapsed_ms(t_sample0);
+    if (picked == LLAMA_TOKEN_NULL) {
+        LOG_ERR("%s: failed to pick final chain validation token\n", __func__);
+        return false;
+    }
+    next_token = picked;
+    return true;
+}
+
+static std::vector<llama_token> decode_exact_ar_fallback_step(
+        llama_speculative_tree_driver * d,
+        llama_token root_token,
+        llama_pos committed_pos,
+        ddtree_clock::time_point t_step0) {
+    llama_memory_t mem = llama_get_memory(d->target_ctx);
+    if (!llama_memory_seq_rm(mem, /*seq_id=*/0, committed_pos, /*p1=*/-1)) {
+        LOG_ERR("%s: failed to clear target future range before AR fallback at pos %d\n",
+                __func__, (int) committed_pos);
+        return {};
+    }
+
+    llama_batch b = llama_batch_init(1, /*embd=*/0, /*n_seq_max=*/1);
+    b.n_tokens     = 1;
+    b.token[0]     = root_token;
+    b.pos[0]       = committed_pos;
+    b.n_seq_id[0]  = 1;
+    b.seq_id[0][0] = 0;
+    b.logits[0]    = 1;
+
+    const auto t_decode0 = ddtree_clock::now();
+    const int
ret = llama_decode(d->target_ctx, b); + d->stats.t_exact_decode_ms += elapsed_ms(t_decode0); + d->stats.n_exact_validate_nodes++; + llama_batch_free(b); + if (ret != 0) { + LOG_ERR("%s: AR fallback llama_decode failed: %d\n", __func__, ret); + return {}; + } + + driver_ingest_capture(d, nullptr, 1, ingest_source::replay); + + const auto t_sample0 = ddtree_clock::now(); + const llama_token next_token = pick_current_logits(d, nullptr); + d->stats.t_exact_sample_ms += elapsed_ms(t_sample0); + if (next_token == LLAMA_TOKEN_NULL) { + LOG_ERR("%s: failed to pick AR fallback token\n", __func__); + return {}; + } + + d->stats.n_steps++; + d->stats.n_committed_tokens++; + d->stats.max_committed_tokens_per_step = + std::max(d->stats.max_committed_tokens_per_step, 1); + d->stats.t_step_ms += elapsed_ms(t_step0); + + return { root_token, next_token }; +} + +static llama_token diagnose_chain_root_argmax(llama_speculative_tree_driver * d, + llama_token root_token, + llama_pos committed_pos) { + llama_mem_snapshot_id snap = llama_seq_snapshot(d->target_ctx, /*seq_id=*/0); + if (snap == LLAMA_MEM_SNAPSHOT_INVALID) { + return LLAMA_TOKEN_NULL; + } + + llama_batch b = llama_batch_init(1, /*embd=*/0, /*n_seq_max=*/1); + b.n_tokens = 1; + b.token[0] = root_token; + b.pos[0] = committed_pos; + b.n_seq_id[0] = 1; + b.seq_id[0][0] = 0; + b.logits[0] = 1; + + llama_token best = LLAMA_TOKEN_NULL; + if (llama_decode(d->target_ctx, b) == 0) { + const float * row = llama_get_logits_ith(d->target_ctx, 0); + if (row) { + best = 0; + float best_val = row[0]; + for (int64_t v = 1; v < d->n_vocab; ++v) { + if (row[v] > best_val) { best_val = row[v]; best = (llama_token)v; } + } + } + } + llama_batch_free(b); + + llama_memory_t mem = llama_get_memory(d->target_ctx); + llama_memory_seq_rm(mem, /*seq_id=*/0, committed_pos, /*p1=*/-1); + llama_seq_restore(d->target_ctx, snap); + llama_seq_release(d->target_ctx, snap); + return best; +} + +void llama_speculative_tree_driver_ingest_prompt_capture( + llama_speculative_tree_driver * d, + int32_t n_prompt_tokens) { + // Prompt prefill capture is laid out linearly; ingest columns 0..n_prompt_tokens-1. 
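+    // Note: with server-side prompt caching only the uncached tail of the
+    // prompt goes through llama_decode, so the capture holds just those
+    // columns; driver_ingest_capture clamps (and warns) when asked for more
+    // columns than the capture actually contains.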
+    driver_ingest_capture(d, nullptr, n_prompt_tokens, ingest_source::prompt);
+}
+
+std::vector<llama_token> llama_speculative_tree_driver_step(
+        llama_speculative_tree_driver * d,
+        llama_token root_token,
+        llama_pos committed_pos,
+        const llama_speculative_tree_verify_cbs * verify_cbs) {
+
+    if (!d) {
+        return {};
+    }
+    const auto t_step0 = ddtree_clock::now();
+
+    const int64_t n_vocab = d->n_vocab;
+
+    if (verify_cbs == nullptr &&
+        ddtree_paper_verifier_enabled() &&
+        !ddtree_trust_batched_posterior() &&
+        !ddtree_diag_batched_enabled() &&
+        ddtree_exact_ar_fallback_enabled()) {
+        return decode_exact_ar_fallback_step(d, root_token, committed_pos, t_step0);
+    }
+
+    llama_speculative_draft_decode_info draft_info;
+    llama_speculative_draft_target_feat_view target_feat_view {
+        d->target_feat_ring.data(),
+        d->target_feat_n_committed,
+        d->target_feat_cap,
+        d->target_feat_n_embd_fc,
+    };
+    if (!d->draft_backend->decode_topk(
+            root_token, committed_pos, target_feat_view,
+            d->top_log_probs, d->top_token_ids, draft_info)) {
+        return {};
+    }
+    d->stats.t_target_feat_pack_ms += draft_info.t_target_feat_pack_ms;
+    d->stats.t_draft_decode_ms     += draft_info.t_draft_decode_ms;
+    d->stats.t_topk_ms             += draft_info.t_topk_ms;
+
+    const int L = draft_info.L;
+    const int K = draft_info.K;
+    const int64_t ctx_len = draft_info.ctx_len;
+
+    if (std::getenv("LLAMA_DDTREE_DUMP_DRAFT_TOP") != nullptr && d->stats.n_steps == 0) {
+        LOG_INF("draft_top port: step=%lld committed=%d ctx_len=%lld root=%d K=%d backend=%s\n",
+                (long long) d->stats.n_steps,
+                (int) committed_pos,
+                (long long) ctx_len,
+                (int) root_token,
+                K,
+                d->draft_backend->name());
+        LOG_INF("draft_top port: top1:");
+        for (int i = 0; i < L; ++i) {
+            LOG_INF(" %d", (int) d->top_token_ids[(size_t) i * K]);
+        }
+        LOG_INF("\n");
+        if (K > 1) {
+            const int rows = std::min(4, L);
+            for (int r = 0; r < rows; ++r) {
+                LOG_INF("draft_top port: row%d:", r + 1);
+                for (int k = 0; k < K; ++k) {
+                    LOG_INF(" %d", (int) d->top_token_ids[(size_t) r * K + k]);
+                }
+                LOG_INF("\n");
+            }
+        }
+    }
+
+    // ── Step 5: build DDTree ──────────────────────────────────────────────────
+    llama_ddtree tree;
+    {
+        const auto t0 = ddtree_clock::now();
+        tree = build_ddtree(
+            d->top_log_probs.data(), d->top_token_ids.data(),
+            L, K, root_token, d->params);
+        d->stats.t_build_tree_ms += elapsed_ms(t0);
+    }
+
+    const int N = (int) tree.nodes.size(); // includes root node at index 0
+    d->stats.n_steps++;
+    d->stats.n_tree_verifies++;
+    d->stats.n_tree_nodes_total += N;
+    d->stats.max_tree_nodes = std::max(d->stats.max_tree_nodes, N);
+
+    const bool paper_verifier = ddtree_paper_verifier_enabled();
+    const bool trust_batched  = ddtree_trust_batched_posterior();
+    const bool diag_batched   = ddtree_diag_batched_enabled();
+    // Q4 KV batched/tree logits are not AR-equivalent. In paper mode, keep the
+    // safe exact verifier as the default and only pay for tree decode when the
+    // caller explicitly opts into unsafe trust or diagnostic comparison.
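+    // Illustrative summary of how the knobs combine (defaults: paper verifier,
+    // no trust/diag/trace envs set):
+    //   paper + no verify_cbs                -> plain AR step (handled earlier, by default)
+    //   paper + verify_cbs                   -> exact per-node chain validation below
+    //   LLAMA_DDTREE_UNSAFE_TRUST_BATCHED=1  -> commit the batched tree posterior directly
+    //   LLAMA_DDTREE_DIAG_BATCHED=1          -> run batched + exact, count (dis)agreement
+    //   LLAMA_DDTREE_VERIFIER=exact          -> force the chain verifier (tree decode only
+    //                                           with LLAMA_DDTREE_FAST_BATCHED=1)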
+    const bool paper_needs_tree  = paper_verifier && (trust_batched || diag_batched);
+    const bool fast_batched      = paper_needs_tree || ddtree_fast_batched_enabled();
+    const bool trace_batched     = std::getenv("LLAMA_DDTREE_TRACE") != nullptr ||
+                                   std::getenv("LLAMA_DDTREE_TRACE_CHAIN_ROOT") != nullptr;
+    const bool need_batched_tree = fast_batched || trace_batched;
+    const bool use_target_top1   = ddtree_target_top1_enabled() &&
+                                   trust_batched &&
+                                   verify_cbs == nullptr &&
+                                   !diag_batched &&
+                                   !trace_batched;
+
+    if (!need_batched_tree) {
+        std::vector<int32_t> accepted_dfs;
+        llama_token next_token = LLAMA_TOKEN_NULL;
+        {
+            const auto t0 = ddtree_clock::now();
+            if (!validate_tree_with_chain(d, tree, committed_pos, verify_cbs, accepted_dfs, next_token)) {
+                return {};
+            }
+            d->stats.t_exact_validate_ms += elapsed_ms(t0);
+        }
+
+        const int commit_n = (int) accepted_dfs.size();
+        d->stats.n_committed_tokens += commit_n;
+        d->stats.max_committed_tokens_per_step =
+            std::max(d->stats.max_committed_tokens_per_step, commit_n);
+
+        std::vector<llama_token> result;
+        result.reserve(commit_n + 1);
+        for (int i = 0; i < commit_n; ++i) {
+            result.push_back(tree.nodes[accepted_dfs[i]].token_id);
+        }
+        result.push_back(next_token);
+
+        d->stats.t_step_ms += elapsed_ms(t_step0);
+        return result;
+    }
+
+    const bool fast_rollback = fast_batched && (paper_verifier || ddtree_fast_rollback_enabled()) &&
+                               !d->fast_rollback_unavailable;
+    const bool exact_gate_batched = paper_verifier && N > 1 && !trust_batched;
+    const bool unsafe_fast_tree_state = ddtree_unsafe_fast_tree_state_enabled();
+    const bool keep_snapshot = exact_gate_batched ||
+                               (paper_verifier && fast_batched && fast_rollback && !unsafe_fast_tree_state) ||
+                               (!paper_verifier &&
+                                (!fast_batched || !fast_rollback || ddtree_snapshot_fallback_enabled()));
+
+    // ── Step 6: snapshot before target verify ────────────────────────────────
+    // By default fast-rollback mode still keeps a snapshot as a safety net.
+    // Set LLAMA_DDTREE_SNAPSHOT_FALLBACK=0 to remove this per-step host bounce
+    // after validating that persist allocation succeeds in the target runtime.
+    llama_mem_snapshot_id snap = LLAMA_MEM_SNAPSHOT_INVALID;
+    auto release_snap = [&]() {
+        if (snap != LLAMA_MEM_SNAPSHOT_INVALID) {
+            llama_seq_release(d->target_ctx, snap);
+            snap = LLAMA_MEM_SNAPSHOT_INVALID;
+        }
+    };
+    if (N > 1) {
+        const auto t0 = ddtree_clock::now();
+        llama_memory_t mem = llama_get_memory(d->target_ctx);
+        if (!llama_memory_seq_rm(mem, /*seq_id=*/0, committed_pos, /*p1=*/-1)) {
+            LOG_ERR("%s: failed to clear target future range before tree verify at pos %d\n",
+                    __func__, (int) committed_pos);
+            return {};
+        }
+
+        if (keep_snapshot) {
+            snap = llama_seq_snapshot(d->target_ctx, /*seq_id=*/0);
+            if (snap == LLAMA_MEM_SNAPSHOT_INVALID) {
+                LOG_ERR("%s: llama_seq_snapshot failed before tree verify\n", __func__);
+                return {};
+            }
+        }
+        d->stats.t_snapshot_ms += elapsed_ms(t0);
+    }
+
+    llama_token diag_chain_root = LLAMA_TOKEN_NULL;
+    if (std::getenv("LLAMA_DDTREE_TRACE_CHAIN_ROOT") != nullptr) {
+        diag_chain_root = diagnose_chain_root_argmax(d, root_token, committed_pos);
+    }
+
+    // ── Step 7: target verify (tree-mode forward) ─────────────────────────────
+    // Build a tree batch of N tokens and run target decode.
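+    // Worked example (illustrative shape): a 5-node tree where the root has
+    // two children and node 1 has two children of its own is laid out as:
+    //   idx:        0   1    2    3    4     (DFS order from build_ddtree)
+    //   parent_id: -1   0    1    1    0
+    //   pos:        p  p+1  p+2  p+2  p+1    (p = committed_pos, pos = p + depth)
+    // parent_id drives the ancestor visibility mask, so siblings (e.g. 2 and 3)
+    // never attend to each other and all branches verify in one forward pass.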
+    {
+        const auto t0 = ddtree_clock::now();
+        if (use_target_top1) {
+            llama_set_dflash_draft_top_k(d->target_ctx, 1);
+        }
+        llama_batch tree_batch = llama_batch_init_tree(N, 0, 1);
+        tree_batch.n_tokens = N;
+        for (int i = 0; i < N; ++i) {
+            tree_batch.token[i]     = tree.nodes[i].token_id;
+            tree_batch.pos[i]       = committed_pos + tree.nodes[i].depth;
+            tree_batch.n_seq_id[i]  = 1;
+            tree_batch.seq_id[i][0] = 0;
+            tree_batch.logits[i]    = 1;                        // output logits for all nodes
+            tree_batch.parent_id[i] = tree.nodes[i].parent_idx; // -1 for root
+        }
+        int ret = llama_decode(d->target_ctx, tree_batch);
+        if (use_target_top1) {
+            llama_set_dflash_draft_top_k(d->target_ctx, 0);
+        }
+        llama_batch_free(tree_batch);
+        if (ret != 0) {
+            LOG_ERR("%s: target tree llama_decode failed: %d\n", __func__, ret);
+            release_snap();
+            return {};
+        }
+        d->stats.t_target_tree_decode_ms += elapsed_ms(t0);
+    }
+
+    // ── Step 8: pick verify chain ─────────────────────────────────────────────
+    // Keep the batched tree posterior for diagnostics, but do not trust it for
+    // final acceptance. Quantized batched tree logits can drift from one-token
+    // AR logits enough to flip argmax on close rows.
+    d->posterior.resize(N);
+    std::vector<float> posterior_margins;
+    posterior_margins.resize(N);
+    {
+        const auto t0 = ddtree_clock::now();
+        const float       * top_logits = nullptr;
+        const llama_token * top_ids    = nullptr;
+        int32_t top_rows = 0;
+        int32_t top_k    = 0;
+        if (use_target_top1 &&
+            llama_get_dflash_draft_top_k(d->target_ctx, &top_logits, &top_ids, &top_rows, &top_k)) {
+            if (top_k < 1 || top_rows < N || top_ids == nullptr) {
+                LOG_ERR("%s: target top1 output has invalid shape: rows=%d k=%d expected rows>=%d k>=1\n",
+                        __func__, (int) top_rows, (int) top_k, N);
+                release_snap();
+                return {};
+            }
+            for (int i = 0; i < N; ++i) {
+                d->posterior[i]      = (int32_t) top_ids[(size_t) i * top_k];
+                posterior_margins[i] = 0.0f;
+            }
+        } else {
+            if (use_target_top1) {
+                LOG_ERR("%s: target top1 output unavailable\n", __func__);
+                release_snap();
+                return {};
+            }
+            for (int i = 0; i < N; ++i) {
+                const float * row = llama_get_logits_ith(d->target_ctx, i);
+                if (!row) {
+                    LOG_ERR("%s: target logits[%d] unavailable\n", __func__, i);
+                    release_snap();
+                    return {};
+                }
+                int32_t best = 0;
+                float best_val = row[0];
+                int32_t second = 0;
+                float second_val = row[0];
+                if (n_vocab > 1) {
+                    second     = 1;
+                    second_val = row[1];
+                    if (second_val > best_val) {
+                        std::swap(best, second);
+                        std::swap(best_val, second_val);
+                    }
+                }
+                for (int64_t v = 2; v < n_vocab; ++v) {
+                    const float val = row[v];
+                    if (val > best_val) {
+                        second     = best;
+                        second_val = best_val;
+                        best       = (int32_t) v;
+                        best_val   = val;
+                    } else if (val > second_val) {
+                        second     = (int32_t) v;
+                        second_val = val;
+                    }
+                }
+                d->posterior[i]      = best;
+                posterior_margins[i] = best_val - second_val;
+            }
+        }
+        d->stats.t_posterior_scan_ms += elapsed_ms(t0);
+    }
+
+    std::vector<int32_t> batched_accepted_dfs;
+    llama_token batched_next_token = LLAMA_TOKEN_NULL;
+    {
+        const auto t0 = ddtree_clock::now();
+        follow_verified_tree(tree, d->posterior.data(), batched_accepted_dfs, batched_next_token);
+        d->stats.t_accept_path_ms += elapsed_ms(t0);
+    }
+
+    const int batched_commit_n = (int) batched_accepted_dfs.size();
+    d->stats.n_batched_posterior_committed_tokens += batched_commit_n;
+    d->stats.max_batched_posterior_committed_tokens_per_step =
+        std::max(d->stats.max_batched_posterior_committed_tokens_per_step, batched_commit_n);
+
+    std::vector<int32_t> accepted_dfs;
+    llama_token next_token = LLAMA_TOKEN_NULL;
+
+    if (fast_batched) {
+        if (verify_cbs != nullptr && verify_cbs->sample_cb != nullptr) {
+            const auto t0 = ddtree_clock::now();
+            follow_verified_tree_cb(
+                    tree,
+                    d->posterior,
+                    verify_cbs->sample_cb,
+                    verify_cbs->advance_cb,
+                    verify_cbs->user_data,
+                    accepted_dfs,
+                    next_token);
+            d->stats.t_accept_path_ms += elapsed_ms(t0);
+            d->stats.n_fast_batched_callback_steps++;
+        } else {
+            accepted_dfs = batched_accepted_dfs;
+            next_token   = batched_next_token;
+        }
+
+        bool did_commit_state = false;
+        if (exact_gate_batched && verify_cbs == nullptr) {
+            if (snap == LLAMA_MEM_SNAPSHOT_INVALID || !llama_seq_restore(d->target_ctx, snap)) {
+                LOG_ERR("%s: exact gate failed to restore snapshot before chain validation\n", __func__);
+                release_snap();
+                return {};
+            }
+            d->stats.n_snapshot_replays++;
+
+            std::vector<int32_t> exact_accepted_dfs;
+            llama_token exact_next_token = LLAMA_TOKEN_NULL;
+            {
+                const auto t0 = ddtree_clock::now();
+                if (!validate_tree_with_chain(d, tree, committed_pos, verify_cbs,
+                                              exact_accepted_dfs, exact_next_token)) {
+                    release_snap();
+                    return {};
+                }
+                d->stats.t_exact_validate_ms += elapsed_ms(t0);
+            }
+
+            if (batched_accepted_dfs == exact_accepted_dfs && batched_next_token == exact_next_token) {
+                d->stats.n_batched_exact_same++;
+            } else {
+                d->stats.n_batched_exact_diff++;
+            }
+            if (batched_commit_n > (int)exact_accepted_dfs.size()) {
+                d->stats.n_batched_exact_longer++;
+            } else if (batched_commit_n < (int)exact_accepted_dfs.size()) {
+                d->stats.n_batched_exact_shorter++;
+            }
+            if (std::getenv("LLAMA_DDTREE_TRACE") != nullptr) {
+                auto min_margin_for = [&](const std::vector<int32_t> & path) {
+                    float min_margin = 0.0f;
+                    for (int i = 0; i < (int)path.size(); ++i) {
+                        const int32_t idx = path[i];
+                        const float margin = (idx >= 0 && idx < (int)posterior_margins.size()) ? posterior_margins[idx] : 0.0f;
+                        min_margin = (i == 0) ? margin : std::min(min_margin, margin);
+                    }
+                    return min_margin;
+                };
+                LOG_INF("ddtree_trace_fast: step=%lld pos=%d same=%d exact_commit_n=%d batched_commit_n=%d exact_next=%d batched_next=%d exact_min_margin=%.6g batched_min_margin=%.6g\n",
+                        (long long)d->stats.n_steps,
+                        (int)committed_pos,
+                        batched_accepted_dfs == exact_accepted_dfs && batched_next_token == exact_next_token,
+                        (int)exact_accepted_dfs.size(),
+                        batched_commit_n,
+                        (int)exact_next_token,
+                        (int)batched_next_token,
+                        (double)min_margin_for(exact_accepted_dfs),
+                        (double)min_margin_for(batched_accepted_dfs));
+            }
+
+            accepted_dfs = std::move(exact_accepted_dfs);
+            next_token   = exact_next_token;
+            did_commit_state = true;
+        }
+
+        const int accept_depth = (int)accepted_dfs.size(); // includes root node (index 0)
+        const int commit_n     = accept_depth;
+        const bool fast_path_state_safe = unsafe_fast_tree_state;
+
+        if (!did_commit_state && fast_rollback && fast_path_state_safe && N > 1) {
+            {
+                const auto t0 = ddtree_clock::now();
+                llama_kv_cache_seq_compact_tree(
+                        d->target_ctx,
+                        /*seq_id=*/0,
+                        accepted_dfs.data(),
+                        (int32_t)accepted_dfs.size(),
+                        commit_n,
+                        (int32_t)committed_pos);
+                d->stats.t_kv_compact_ms += elapsed_ms(t0);
+            }
+
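+            // Persist-based fast commit: roll the recurrent (SSM) state back to
+            // what was captured at the deepest accepted DFS node, then repair
+            // the recurrent tail-position bookkeeping; if either step fails,
+            // fall through to the snapshot-replay path below.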
+            const int32_t rollback_node = commit_n > 0 ? accepted_dfs[commit_n - 1] : 0;
+            bool rollback_ok = false;
+            {
+                const auto t0 = ddtree_clock::now();
+                rollback_ok = llama_dflash_rollback_ssm_to_dfs(d->target_ctx, /*seq_id=*/0, rollback_node);
+                d->stats.t_ssm_rollback_ms += elapsed_ms(t0);
+            }
+            if (!rollback_ok) {
+                LOG_WRN("%s: fast rollback failed at DFS node %d; falling back to snapshot replay\n",
+                        __func__, (int)rollback_node);
+                d->fast_rollback_unavailable = true;
+            } else {
+                const llama_pos recurrent_tail_pos = committed_pos + commit_n - 1;
+                bool tail_ok = false;
+                {
+                    const auto t0 = ddtree_clock::now();
+                    tail_ok = llama_dflash_set_recurrent_tail_pos(d->target_ctx, /*seq_id=*/0, recurrent_tail_pos);
+                    d->stats.t_ssm_rollback_ms += elapsed_ms(t0);
+                }
+                if (!tail_ok) {
+                    LOG_WRN("%s: failed to set recurrent tail pos to %d after fast rollback; falling back to snapshot replay\n",
+                            __func__, (int)recurrent_tail_pos);
+                    d->fast_rollback_unavailable = true;
+                } else {
+                    driver_ingest_capture(d, accepted_dfs.data(), commit_n, ingest_source::tree);
+                    d->stats.n_fast_rollback_steps++;
+                    did_commit_state = true;
+                }
+            }
+        }
+
+        if (N > 1 && !did_commit_state) {
+            if (snap == LLAMA_MEM_SNAPSHOT_INVALID || !llama_seq_restore(d->target_ctx, snap)) {
+                LOG_ERR("%s: fast rollback failed and snapshot fallback is unavailable\n", __func__);
+                release_snap();
+                return {};
+            }
+            d->stats.n_snapshot_replays++;
+            d->stats.n_fast_batched_replays++;
+            {
+                const auto t0 = ddtree_clock::now();
+                if (!replay_committed_chain(d, tree, accepted_dfs.data(), commit_n, committed_pos)) {
+                    release_snap();
+                    return {};
+                }
+                d->stats.t_replay_ms += elapsed_ms(t0);
+            }
+        } else if (!did_commit_state) {
+            // Root-only verify is already a normal one-token forward. Keep its
+            // live target state and ingest its hidden capture for the next draft.
+            driver_ingest_capture(d, nullptr, commit_n, ingest_source::replay);
+        }
+
+        d->stats.n_committed_tokens += commit_n;
+        d->stats.max_committed_tokens_per_step =
+            std::max(d->stats.max_committed_tokens_per_step, commit_n);
+
+        if (std::getenv("LLAMA_DDTREE_DUMP_STEP") != nullptr) {
+            const int32_t rollback_node = (commit_n > 0) ? accepted_dfs[commit_n - 1] : 0;
+            LOG_INF("ddtree_dump: step=%lld pos=%d root=%d N=%d budget=%d top_k=%d commit_n=%d next=%d rollback_node=%d\n",
+                    (long long)d->stats.n_steps,
+                    (int)committed_pos,
+                    (int)root_token,
+                    N,
+                    d->params.budget,
+                    d->params.top_k,
+                    commit_n,
+                    (int)next_token,
+                    (int)rollback_node);
+            LOG_INF("ddtree_dump: accepted_dfs=");
+            for (int i = 0; i < (int)accepted_dfs.size(); ++i) {
+                LOG_INF("%s%d", i == 0 ? "" : ",", (int)accepted_dfs[i]);
+            }
+            LOG_INF("\n");
+            for (int i = 0; i < N; ++i) {
+                const int32_t post = (i < (int)d->posterior.size()) ? d->posterior[i] : -1;
+                LOG_INF("ddtree_dump: node=%d parent=%d depth=%d tok=%d posterior=%d\n",
+                        i,
+                        (int)tree.nodes[i].parent_idx,
+                        (int)tree.nodes[i].depth,
+                        (int)tree.nodes[i].token_id,
+                        (int)post);
+            }
+        }
+
+        release_snap();
+
+        std::vector<llama_token> result;
+        result.reserve(commit_n + 1);
+        for (int i = 0; i < commit_n; ++i) {
+            result.push_back(tree.nodes[accepted_dfs[i]].token_id);
+        }
+        result.push_back(next_token);
+
+        d->stats.t_step_ms += elapsed_ms(t_step0);
+        return result;
+    }
+
+    if (snap == LLAMA_MEM_SNAPSHOT_INVALID || !llama_seq_restore(d->target_ctx, snap)) {
+        LOG_ERR("%s: llama_seq_restore failed before exact chain validation\n", __func__);
+        release_snap();
+        return {};
+    }
+    d->stats.n_snapshot_replays++;
+    {
+        const auto t0 = ddtree_clock::now();
+        if (!validate_tree_with_chain(d, tree, committed_pos, verify_cbs, accepted_dfs, next_token)) {
+            release_snap();
+            return {};
+        }
+        d->stats.t_exact_validate_ms += elapsed_ms(t0);
+    }
+
+    const int accept_depth = (int)accepted_dfs.size(); // includes root node (index 0)
+    const int commit_n     = accept_depth;             // root is always committed
+    if (batched_accepted_dfs == accepted_dfs && batched_next_token == next_token) {
+        d->stats.n_batched_exact_same++;
+    } else {
+        d->stats.n_batched_exact_diff++;
+    }
+    if (batched_commit_n > commit_n) {
+        d->stats.n_batched_exact_longer++;
+    } else if (batched_commit_n < commit_n) {
+        d->stats.n_batched_exact_shorter++;
+    }
+    d->stats.n_committed_tokens += commit_n;
+    d->stats.max_committed_tokens_per_step =
+        std::max(d->stats.max_committed_tokens_per_step, commit_n);
+
+    if (std::getenv("LLAMA_DDTREE_TRACE") != nullptr) {
+        const int32_t rollback_node = (commit_n > 0) ? accepted_dfs[commit_n - 1] : 0;
+        const int32_t posterior0 = d->posterior.empty() ? -1 : d->posterior[0];
+        float exact_min_margin = 0.0f;
+        for (int i = 0; i < (int)accepted_dfs.size(); ++i) {
+            const int32_t idx = accepted_dfs[i];
+            const float margin = (idx >= 0 && idx < (int)posterior_margins.size()) ? posterior_margins[idx] : 0.0f;
+            exact_min_margin = (i == 0) ? margin : std::min(exact_min_margin, margin);
+        }
+        float batched_min_margin = 0.0f;
+        for (int i = 0; i < (int)batched_accepted_dfs.size(); ++i) {
+            const int32_t idx = batched_accepted_dfs[i];
+            const float margin = (idx >= 0 && idx < (int)posterior_margins.size()) ? posterior_margins[idx] : 0.0f;
+            batched_min_margin = (i == 0) ? margin : std::min(batched_min_margin, margin);
+        }
+        LOG_INF("ddtree_trace: step=%lld pos=%d root=%d N=%d budget=%d posterior0=%d next=%d commit_n=%d batched_commit_n=%d batched_next=%d exact_min_margin=%.6g batched_min_margin=%.6g rollback_node=%d\n",
+                (long long)d->stats.n_steps,
+                (int)committed_pos,
+                (int)root_token,
+                N,
+                d->params.budget,
+                (int)posterior0,
+                (int)next_token,
+                commit_n,
+                batched_commit_n,
+                (int)batched_next_token,
+                (double)exact_min_margin,
+                (double)batched_min_margin,
+                (int)rollback_node);
+        if (diag_chain_root != LLAMA_TOKEN_NULL) {
+            LOG_INF("ddtree_trace: chain_pre_argmax=%d tree_root_argmax=%d\n",
+                    (int)diag_chain_root, (int)posterior0);
+        }
+        LOG_INF("ddtree_trace: accepted=");
+        for (int i = 0; i < (int)accepted_dfs.size(); ++i) {
+            LOG_INF("%s%d", i == 0 ? "" : ",", (int)accepted_dfs[i]);
+        }
+        LOG_INF("\n");
+        for (int i = 0; i < N; ++i) {
+            const int32_t post = (i < (int)d->posterior.size()) ? d->posterior[i] : -1;
+            LOG_INF("ddtree_trace: node=%d parent=%d depth=%d tok=%d posterior=%d\n",
+                    i,
+                    (int)tree.nodes[i].parent_idx,
+                    (int)tree.nodes[i].depth,
+                    (int)tree.nodes[i].token_id,
+                    (int)post);
+        }
+    }
+
+    // Step 9 is handled by validate_tree_with_chain(): after restoring the
+    // snapshot, it decodes the exact accepted path one token at a time and
+    // ingests the corresponding hidden captures into the draft feature ring.
+    release_snap();
+
+    // ── Step 10: assemble output ──────────────────────────────────────────────
+    // accepted[0]                 = root_token (always, the input token echoed back).
+    // accepted[1..accept_depth-1] = newly accepted draft tokens.
+    // accepted[accept_depth]      = bonus token (next_token).
+    std::vector<llama_token> result;
+    result.reserve(commit_n + 1);
+    for (int i = 0; i < commit_n; ++i) {
+        result.push_back(tree.nodes[accepted_dfs[i]].token_id);
+    }
+    result.push_back(next_token);
+
+    d->stats.t_step_ms += elapsed_ms(t_step0);
+    return result;
+}
diff --git a/common/speculative-tree-driver.h b/common/speculative-tree-driver.h
new file mode 100644
index 00000000000..676d2a03b0e
--- /dev/null
+++ b/common/speculative-tree-driver.h
@@ -0,0 +1,124 @@
+#pragma once
+
+// speculative-tree-driver.h — Phase 4 DDTree spec-decode coordinator.
+//
+// Binds a target context (Qwen3.5-27B with capture_hidden) and a draft context
+// (LLM_ARCH_DFLASH_DRAFT) and implements one spec-decode step per call.
+//
+// Lifecycle:
+//   driver = llama_speculative_tree_driver_init(target_ctx, draft_ctx, params)
+//   while (not done):
+//       accepted = llama_speculative_tree_driver_step(driver, root_token, committed_pos)
+//       ... append accepted to output ...
+//       root_token = accepted.back(); committed_pos += accepted.size() - 1
+//   llama_speculative_tree_driver_free(driver)
+
+#include "llama.h"
+#include "speculative-tree.h"
+
+#include <cstdint>
+#include <vector>
+
+struct llama_speculative_tree_driver;
+
+struct llama_speculative_tree_driver_stats {
+    int64_t n_steps            = 0;
+    int64_t n_tree_verifies    = 0;
+    int64_t n_tree_nodes_total = 0;
+    int32_t max_tree_nodes     = 0;
+    int64_t n_dfs_last_commits = 0;
+    int64_t n_snapshot_replays = 0;
+    int64_t n_committed_tokens = 0;
+    int32_t max_committed_tokens_per_step = 0;
+    int64_t n_batched_posterior_committed_tokens = 0;
+    int32_t max_batched_posterior_committed_tokens_per_step = 0;
+    int64_t n_batched_exact_same    = 0;
+    int64_t n_batched_exact_diff    = 0;
+    int64_t n_batched_exact_longer  = 0;
+    int64_t n_batched_exact_shorter = 0;
+    int64_t n_fast_batched_replays        = 0;
+    int64_t n_fast_batched_callback_steps = 0;
+    int64_t n_fast_rollback_steps    = 0;
+    int64_t n_prompt_ingest_calls    = 0;
+    int64_t n_prompt_ingested_tokens = 0;
+    int64_t n_tree_ingested_tokens   = 0;
+    int64_t n_replay_ingested_tokens = 0;
+    int64_t n_capture_clamps         = 0;
+    int64_t n_exact_validate_nodes   = 0;
+
+    double t_step_ms               = 0.0;
+    double t_target_feat_pack_ms   = 0.0;
+    double t_draft_decode_ms       = 0.0;
+    double t_topk_ms               = 0.0;
+    double t_build_tree_ms         = 0.0;
+    double t_snapshot_ms           = 0.0;
+    double t_target_tree_decode_ms = 0.0;
+    double t_posterior_scan_ms     = 0.0;
+    double t_accept_path_ms        = 0.0;
+    double t_kv_compact_ms         = 0.0;
+    double t_ssm_rollback_ms       = 0.0;
+    double t_ingest_capture_ms     = 0.0;
+    double t_prompt_ingest_ms      = 0.0;
+    double t_tree_ingest_ms        = 0.0;
+    double t_replay_ingest_ms      = 0.0;
+    double t_replay_ms             = 0.0;
+    double t_exact_validate_ms     = 0.0;
+    double t_exact_decode_ms       = 0.0;
+    double t_exact_sample_ms       = 0.0;
+    double t_exact_advance_ms      = 0.0;
+};
+
+// Allocate a driver. target_ctx must have capture_hidden enabled before any
+// llama_decode() calls that prime the context. draft_ctx must use the
+// LLM_ARCH_DFLASH_DRAFT architecture.
+llama_speculative_tree_driver * llama_speculative_tree_driver_init(
+        llama_context * target_ctx,
+        llama_context * draft_ctx,
+        const llama_ddtree_params & params);
+
+void llama_speculative_tree_driver_free(llama_speculative_tree_driver * d);
+
+// Run one spec-decode step.
+//
+// root_token    — the last committed token (bonus token from the previous step,
+//                 or the last prompt token on the very first call).
+// committed_pos — number of KV positions committed in the target context so far
+//                 (i.e. seq_pos_max + 1 for the next token to be placed).
+//
+// Returns accepted tokens in chronological order (length >= 1):
+//   accepted[0]     = root_token (the input, echoed for convenience)
+//   accepted[1..]   = newly accepted draft tokens
+//   accepted.back() = bonus token from the target (the next root for the next step)
+//
+// The KV cache of target_ctx is compacted to hold only the accepted path after
+// each step. The SSM/conv state is snapshotted before the step and restored on
+// mismatch.
+//
+// Optional verify callbacks. If non-null, the driver picks the next token at
+// each verify-chain step via sample_cb instead of internal argmax. advance_cb
+// is invoked whenever the chain accepts a child, so callers can advance their
+// sampler/grammar state to mirror the chain.
+struct llama_speculative_tree_verify_cbs {
+    llama_speculative_pick_cb    sample_cb  = nullptr;
+    llama_speculative_advance_cb advance_cb = nullptr;
+    void * user_data = nullptr;
+};
+
+// Returns an empty vector on internal failure.
+std::vector<llama_token> llama_speculative_tree_driver_step(
+        llama_speculative_tree_driver * d,
+        llama_token root_token,
+        llama_pos   committed_pos,
+        const llama_speculative_tree_verify_cbs * verify_cbs = nullptr);
+
+// Ingest the most recent target_ctx capture as the initial ring contents.
+// Call this AFTER the chain-mode prompt prefill that primed target capture,
+// BEFORE the first spec step.
+// n_prompt_tokens: number of tokens in the prompt that were decoded in the prefill batch.
+void llama_speculative_tree_driver_ingest_prompt_capture(
+        llama_speculative_tree_driver * d,
+        int32_t n_prompt_tokens);
+
+int32_t llama_speculative_tree_driver_context_window();
+
+llama_speculative_tree_driver_stats llama_speculative_tree_driver_get_stats(
+        const llama_speculative_tree_driver * d);
diff --git a/common/speculative-tree.cpp b/common/speculative-tree.cpp
new file mode 100644
index 00000000000..16684f1bf92
--- /dev/null
+++ b/common/speculative-tree.cpp
@@ -0,0 +1,296 @@
+#include "speculative-tree.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <queue>
+#include <unordered_map>
+
+// ---------------------------------------------------------------------------
+// build_ddtree
+// ---------------------------------------------------------------------------
+
+llama_ddtree build_ddtree(
+        const float   * top_log_probs,
+        const int32_t * top_token_ids,
+        int             L,
+        int             K,
+        llama_token     root_token,
+        const llama_ddtree_params & p) {
+
+    llama_ddtree tree;
+
+    // Node 0 is always the root (last committed token).
+    tree.nodes.push_back({root_token, /*parent_idx*/ -1, /*depth*/ 0});
+
+    // Match standalone DFlash: --ddtree-budget counts non-root nodes, while
+    // flat tree storage also includes slot 0 = root.
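+    // e.g. budget=22 -> a flat tree of 23 slots: slot 0 (root) plus up to 22
+    // drafted nodes.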
+    const int budget = (p.budget < 0) ? 1 : p.budget + 1;
+
+    if (budget <= 1 || L <= 0) {
+        tree.visibility.assign(1, 1);
+        return tree;
+    }
+
+    // child_maps[flat_index] maps token_id → child flat index.
+    std::vector<std::unordered_map<llama_token, int32_t>> child_maps;
+    child_maps.emplace_back(); // root's children (index 0)
+
+    // Heap entry: a candidate node waiting to be inserted.
+    struct HeapEntry {
+        float neg_logw;     // stored as negative so max-heap by logw
+        int   parent_index; // flat index of already-inserted parent
+        int   depth;        // absolute depth (1..L)
+        int   rank;         // rank within top_token_ids row (depth-1)
+        float logw;         // cumulative path log-prob from root to candidate
+    };
+    struct HeapCmp {
+        bool operator()(const HeapEntry & a, const HeapEntry & b) const {
+            return a.neg_logw > b.neg_logw; // pop smallest neg_logw = highest logw
+        }
+    };
+    std::priority_queue<HeapEntry, std::vector<HeapEntry>, HeapCmp> heap;
+
+    if (p.chain_seed) {
+        // Pre-insert the top-1 greedy chain to depth min(L, budget-1).
+        // Guarantees the tree always contains at least the greedy chain path.
+        const int chain_depth = std::min(L, budget - 1);
+        float cum_logw = 0.0f;
+        int   prev_idx = 0;
+
+        for (int d = 1; d <= chain_depth; d++) {
+            const int32_t tok_id = top_token_ids[(size_t)(d - 1) * K + 0];
+            cum_logw += top_log_probs[(size_t)(d - 1) * K + 0];
+
+            const int cur_idx = (int)tree.nodes.size();
+            tree.nodes.push_back({tok_id, prev_idx, d});
+            child_maps.emplace_back();
+            child_maps[prev_idx][tok_id] = cur_idx;
+
+            // Queue rank-1 sibling so best-first can branch off the chain.
+            if (K > 1) {
+                const float sib_logw = cum_logw
+                    - top_log_probs[(size_t)(d - 1) * K + 0]
+                    + top_log_probs[(size_t)(d - 1) * K + 1];
+                heap.push({-sib_logw, prev_idx, d, 1, sib_logw});
+            }
+
+            prev_idx = cur_idx;
+        }
+    } else {
+        // Pure best-first: seed with depth-1 top-1 candidate only.
+        const float logw0 = top_log_probs[0];
+        heap.push({-logw0, 0, 1, 0, logw0});
+    }
+
+    // Expand candidates in log-prob order until budget is reached.
+    while (!heap.empty() && (int)tree.nodes.size() < budget) {
+        const HeapEntry top = heap.top();
+        heap.pop();
+
+        const int     dm1    = top.depth - 1;
+        const int     rank   = top.rank;
+        const int32_t tok_id = top_token_ids[(size_t)dm1 * K + rank];
+
+        // Skip duplicates (chain_seed may have already inserted this token
+        // under the same parent).
+        if (child_maps[top.parent_index].count(tok_id)) {
+            continue;
+        }
+
+        const int cur_idx = (int)tree.nodes.size();
+        tree.nodes.push_back({tok_id, top.parent_index, top.depth});
+        child_maps.emplace_back();
+        child_maps[top.parent_index][tok_id] = cur_idx;
+
+        // Next sibling (same depth, rank+1).
+        if (rank + 1 < K) {
+            const float sib_logw = top.logw
+                - top_log_probs[(size_t)dm1 * K + rank]
+                + top_log_probs[(size_t)dm1 * K + rank + 1];
+            heap.push({-sib_logw, top.parent_index, top.depth, rank + 1, sib_logw});
+        }
+
+        // First child (depth+1, top-1 under this node).
+        if (top.depth < L) {
+            const float child_logw = top.logw
+                + top_log_probs[(size_t)top.depth * K + 0];
+            heap.push({-child_logw, cur_idx, top.depth + 1, 0, child_logw});
+        }
+    }
+
+    // Build ancestor-only visibility mask for attention masking.
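+    // Example: for a DFS-ordered 4-node tree with edges 0->1, 1->2 and 0->3,
+    // row 2 of the mask is {1,1,1,0}: node 2 attends to the root, its parent
+    // (node 1), and itself, but not to the sibling branch at node 3.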
+
+    const int N = (int)tree.nodes.size();
+    tree.visibility.assign((size_t)N * N, 0);
+    build_tree_visibility(tree.nodes, tree.visibility.data());
+
+    return tree;
+}
+
+// ---------------------------------------------------------------------------
+// follow_verified_tree
+// ---------------------------------------------------------------------------
+
+void follow_verified_tree(
+        const llama_ddtree   & tree,
+        const int32_t        * posterior,
+        std::vector<int32_t> & accepted,
+        llama_token          & next_token) {
+
+    const int N = (int)tree.nodes.size();
+
+    // Build per-node child maps from parent_idx links.
+    std::vector<std::unordered_map<llama_token, int32_t>> child_maps(N);
+    for (int i = 1; i < N; i++) {
+        const int p = tree.nodes[i].parent_idx;
+        child_maps[p][tree.nodes[i].token_id] = i;
+    }
+
+    accepted.clear();
+    accepted.reserve(N);
+    accepted.push_back(0); // root is always accepted
+
+    int current = 0;
+    while (true) {
+        // posterior[current] is the target model's argmax at this tree position.
+        const auto it = child_maps[current].find((llama_token)posterior[current]);
+        if (it == child_maps[current].end()) {
+            break;
+        }
+        current = it->second;
+        accepted.push_back(current);
+    }
+
+    // Bonus token: the target's argmax at the deepest accepted node.
+    next_token = (llama_token)posterior[current];
+}
+
+// follow_verified_tree_cb: same chain-walk semantics as follow_verified_tree,
+// but the picked token at each step comes from sample_cb (caller-side
+// grammar/sampler), and chain advances notify the caller via advance_cb.
+void follow_verified_tree_cb(
+        const llama_ddtree & tree,
+        const std::vector<int32_t>   & posterior,
+        llama_speculative_pick_cb      sample_cb,
+        llama_speculative_advance_cb   advance_cb,
+        void * user_data,
+        std::vector<int32_t> & accepted,
+        llama_token          & next_token) {
+    const int N = (int)tree.nodes.size();
+
+    // Build per-node child maps from parent_idx links.
+    std::vector<std::unordered_map<llama_token, int32_t>> child_maps(N);
+    for (int i = 1; i < N; i++) {
+        const int p = tree.nodes[i].parent_idx;
+        child_maps[p][tree.nodes[i].token_id] = i;
+    }
+
+    accepted.clear();
+    accepted.reserve(N);
+    accepted.push_back(0); // root is always accepted
+
+    int current = 0;
+    while (true) {
+        const llama_token batched_pick = (current < (int)posterior.size())
+            ? (llama_token)posterior[current]
+            : LLAMA_TOKEN_NULL;
+        const int32_t picked = sample_cb(user_data, current, batched_pick);
+        const auto it = child_maps[current].find(picked);
+        if (it == child_maps[current].end()) {
+            next_token = (llama_token)picked;
+            break;
+        }
+        if (advance_cb != nullptr) {
+            advance_cb(user_data, (llama_token)tree.nodes[it->second].token_id);
+        }
+        current = it->second;
+        accepted.push_back(current);
+    }
+}
+
+// ---------------------------------------------------------------------------
+// build_tree_visibility
+// ---------------------------------------------------------------------------
+
+void build_tree_visibility(
+        const std::vector<llama_ddtree_node> & nodes,
+        uint8_t * dst) {
+
+    const int N = (int)nodes.size();
+
+    // Root only sees itself.
+    dst[0 * N + 0] = 1;
+
+    for (int i = 1; i < N; i++) {
+        const int p = nodes[i].parent_idx;
+        // DFS order guarantees p < i, so row p is already complete.
+        // Inherit the parent's visibility row, then mark self.
+
+        for (int j = 0; j < i; j++) {
+            dst[(size_t)i * N + j] = dst[(size_t)p * N + j];
+        }
+        dst[(size_t)i * N + i] = 1;
+    }
+}
+
+// ---------------------------------------------------------------------------
+// extract_top_k_logprobs
+// ---------------------------------------------------------------------------
+
+void extract_top_k_logprobs(
+        const float * logits,
+        int           L,
+        int           V,
+        int           K,
+        float         temp,
+        float   * out_log_probs,
+        int32_t * out_token_ids) {

+    const float inv_t = 1.0f / std::max(1e-6f, temp);
+
+    struct Entry {
+        float   logit; // temperature-scaled
+        int32_t id;
+    };
+    // Min-heap: smallest scaled logit at top, evicted when a larger one arrives.
+    auto cmp_min = [](const Entry & a, const Entry & b) {
+        return a.logit > b.logit;
+    };
+
+    for (int i = 0; i < L; i++) {
+        const float * row = logits + (size_t)i * V;
+
+        std::vector<Entry> heap;
+        heap.reserve(K + 1);
+
+        // Single pass: top-K min-heap. Approximate row normalization from
+        // the retained top-K only to avoid a full-vocab exp/logsumexp pass.
+        // (Assumes V >= K, so the heap always fills to exactly K entries.)
+        for (int j = 0; j < V; j++) {
+            const float l = row[j] * inv_t;
+
+            // Maintain top-K min-heap.
+            if ((int)heap.size() < K) {
+                heap.push_back({l, (int32_t)j});
+                std::push_heap(heap.begin(), heap.end(), cmp_min);
+            } else if (l > heap.front().logit) {
+                std::pop_heap(heap.begin(), heap.end(), cmp_min);
+                heap.back() = {l, (int32_t)j};
+                std::push_heap(heap.begin(), heap.end(), cmp_min);
+            }
+        }
+
+        // sort_heap with a greater-than comparator (cmp_min) produces descending
+        // order — same as std::sort with std::greater — so no reversal needed.
+        std::sort_heap(heap.begin(), heap.end(), cmp_min);
+
+        const float row_best = heap.empty() ? 0.0f : heap[0].logit;
+        float sum_exp_top = 0.0f;
+        for (int k = 0; k < K; ++k) {
+            sum_exp_top += std::exp(heap[k].logit - row_best);
+        }
+        const float log_z_approx = row_best + std::log(sum_exp_top);
+        for (int k = 0; k < K; k++) {
+            out_log_probs[(size_t)i * K + k] = heap[k].logit - log_z_approx;
+            out_token_ids[(size_t)i * K + k] = heap[k].id;
+        }
+    }
+}
diff --git a/common/speculative-tree.h b/common/speculative-tree.h
new file mode 100644
index 00000000000..67721bf7149
--- /dev/null
+++ b/common/speculative-tree.h
@@ -0,0 +1,116 @@
+#pragma once
+
+#include "llama.h"
+
+#include <cstdint>
+#include <vector>
+
+struct llama_ddtree_params {
+    int   budget     = 22;   // non-root tree node cap; flat tree has root + budget
+    float temp       = 1.0f; // temperature for log-prob computation
+    bool  chain_seed = true; // seed heap with greedy chain (recommended)
+    int   block_size = 16;   // matches dflash draft block_size
+    int   top_k      = 0;    // 0 = auto (standalone-compatible: 8 when branching)
+};
+
+struct llama_ddtree_node {
+    llama_token token_id;
+    int32_t     parent_idx; // -1 for root (index 0)
+    int32_t     depth;      // root = 0, root's children = 1, etc.
+};
+
+struct llama_ddtree {
+    // nodes[0] is always the root (last committed token).
+    // nodes[1..N-1] are DFS-ordered tree branches.
+    std::vector<llama_ddtree_node> nodes;
+
+    // visibility[i*N + j] = 1 iff node j is an ancestor of node i (inclusive).
+    // Row i of this matrix is the attention mask row for tree position i.
+    std::vector<uint8_t> visibility;
+};
+
+// Build a DDTree from per-position top-K log-probabilities.
+//
+// top_log_probs   [L, K]  draft top-K log-probabilities, descending per row
+// top_token_ids   [L, K]  matching token ids
+// L               number of draft positions (depth extent of the tree)
+// K               top-K width per position
+// root_token      the root token (last committed token)
+// p               build parameters
+llama_ddtree build_ddtree(
+        const float   * top_log_probs,
+        const int32_t * top_token_ids,
+        int             L,
+        int             K,
+        llama_token     root_token,
+        const llama_ddtree_params & p);
+
+// Walk the tree greedily following the target's per-node argmax (posterior).
+//
+// Starting at the root (index 0), at each step the walk looks for a child
+// whose token_id matches posterior[current_index]. The walk stops when no
+// matching child exists. The root is always in the accepted list.
+//
+// posterior   [N] target argmax token at each tree node position
+// accepted    output: flat node indices of the accepted path (starts with 0)
+// next_token  output: target argmax at the deepest accepted node (bonus token)
+void follow_verified_tree(
+        const llama_ddtree   & tree,
+        const int32_t        * posterior,
+        std::vector<int32_t> & accepted,
+        llama_token          & next_token);
+
+// Variant of follow_verified_tree that pulls the picked token at each chain
+// step from caller-provided callbacks instead of a precomputed posterior[].
+// Lets callers (server) plug in grammar-aware sampling so the chain only
+// accepts tokens the sampler+grammar would have produced.
+//
+// sample_cb (ud, logits_row_idx, batched_pick) -> picked token at this row
+// advance_cb(ud, accepted_token) -> caller must advance its sampler/grammar
+//
+// batched_pick is the driver's already-computed full-vocab argmax for this
+// row (LLAMA_TOKEN_NULL when not available, e.g. exact-chain mode). The
+// callback may short-circuit by returning batched_pick when the caller's
+// grammar/sampler would accept it, avoiding a full re-sample over n_vocab.
+//
+// advance_cb is invoked every time the chain accepts a child (= the picked
+// token matched a child of `current`). It is NOT invoked for the bonus token.
+typedef int32_t (*llama_speculative_pick_cb)   (void * user_data, int32_t logits_row_idx, llama_token batched_pick);
+typedef void    (*llama_speculative_advance_cb)(void * user_data, llama_token accepted_token);
+
+void follow_verified_tree_cb(
+        const llama_ddtree & tree,
+        const std::vector<int32_t>   & posterior,
+        llama_speculative_pick_cb      sample_cb,
+        llama_speculative_advance_cb   advance_cb,
+        void * user_data,
+        std::vector<int32_t> & accepted,
+        llama_token          & next_token);
+
+// Compute the [N, N] ancestor visibility mask from nodes[].parent_idx.
+// dst must point to an N*N uint8 buffer (caller-allocated).
+// Row i: dst[i*N + j] = 1 iff node j is an ancestor of i (inclusive).
+void build_tree_visibility(
+        const std::vector<llama_ddtree_node> & nodes,
+        uint8_t * dst);
+
+// Extract per-position top-K log-probabilities from a [L, V] logits matrix.
+//
+// Uses a size-K min-heap for a single-pass O(L*V) scan; per-row normalization
+// uses an approximate logsumexp computed over the retained top-K entries only.
+// Output rows are sorted descending by log-probability (rank 0 = argmax).
+// +// logits [L, V] row-major F32 +// L number of rows (draft positions) +// V vocabulary size +// K top-K width +// temp temperature: logits are divided by temp before softmax +// out_log_probs [L, K] caller-allocated output +// out_token_ids [L, K] caller-allocated output +void extract_top_k_logprobs( + const float * logits, + int L, + int V, + int K, + float temp, + float * out_log_probs, + int32_t * out_token_ids); diff --git a/docs/ddtree-dataset-eval-plan.md b/docs/ddtree-dataset-eval-plan.md new file mode 100644 index 00000000000..598303c53ae --- /dev/null +++ b/docs/ddtree-dataset-eval-plan.md @@ -0,0 +1,229 @@ +# DDTree Dataset Eval Plan + +Goal: measure the llama.cpp DDTree port on the original DFlash public +benchmarks: HumanEval, GSM8K, and Math500. The previous numbers were produced +by `repo/dflash/scripts/bench_llm.py`; this plan measures the llama.cpp path on +Castle through `test-speculative-tree-e2e`. + +## Scope + +- Datasets: 10 prompts per dataset, `datasets.shuffle(seed=42)`, matching the + original DFlash `bench_llm.py` sampling policy. +- Generation: greedy, `n_gen=256`. +- Primary Python-compatible config: `budget=22` non-root DDTree nodes, + `top_k=0` (auto => K=8 when branching), `proposal_temp=1`, + `target_feat_ctx=2048`, `q8_0` KV, `prompt_chunk=8`. +- Harness: `build-server/bin/test-speculative-tree-e2e`. +- Correctness gate: greedy token trajectory must be bit-equal between chain and + spec output for every sample. + +Current status after the paper verifier fix: the default correctness path does +not trust batched/tree logits. It skips the redundant tree verifier and gates +final acceptance through exact one-token chain validation. The Python +`bench_llm.py` headline uses the standalone fast batched path with Q8_0 KV and +does not check bit-equal token trajectories; direct standalone testing on +HumanEval_01 shows this fast path diverges from standalone AR at generated token +34 for `n_gen=64`. + +Castle 4090 llama.cpp exact-gated result with the Python-compatible parameters, +Qwen3.5-27B Q4_K_M target + DFlash draft, `n_gen=256`, 10 prompts per dataset: + +| dataset | AR tok/s | DFlash tok/s | AL | speedup | bit-equal | +|---|---:|---:|---:|---:|---:| +| HumanEval | 46.32 | 40.23 | 8.34 | 0.87x | 10/10 | +| GSM8K | 46.31 | 40.12 | 6.72 | 0.87x | 10/10 | +| Math500 | 46.33 | 40.05 | 7.30 | 0.86x | 10/10 | + +The exact-gated path is slower than AR because it still runs one exact target +decode per generated token, then adds the draft pass. On the 30-prompt run the +target AR cost is ~21.59 ms/token; exact-gated DDTree costs ~24.9 ms/token +after draft overhead is amortized. + +For throughput experiments, the unsafe fast batched path with +`LLAMA_DDTREE_UNSAFE_TRUST_BATCHED=1` is not a correctness row: + +| dataset | AR tok/s | DFlash tok/s | AL | speedup | bit-equal | +|---|---:|---:|---:|---:|---:| +| HumanEval | 46.33 | 145.82 | 8.14 | 3.15x | 4/10 | +| GSM8K | 46.31 | 120.68 | 6.57 | 2.61x | 2/10 | +| Math500 | 46.31 | 131.32 | 7.20 | 2.84x | 5/10 | + +After separating prompt ingest chunking from runtime `n_ubatch`, the +`q8_0` run should use `--prompt-chunk 8` to keep prompt prefill AR-equivalent +while preserving `n_batch/n_ubatch=64` for tree verify. 
Castle 4090 result: + +| dataset | AR tok/s | DFlash tok/s | AL | speedup | bit-equal | +|---|---:|---:|---:|---:|---:| +| HumanEval | 46.38 | 155.64 | 8.82 | 3.36x | 6/10 | +| GSM8K | 46.38 | 125.39 | 6.91 | 2.70x | 3/10 | +| Math500 | 46.38 | 129.60 | 7.15 | 2.79x | 5/10 | + +The matching Python standalone reference run +`/tmp/dflash_python_bitequal_gen256_b22_q8/results.json` reported: + +| dataset | AR tok/s | DFlash tok/s | AL | bit-equal | +|---|---:|---:|---:|---:| +| HumanEval | 42.59 | 146.30 | 8.01 | 3/10 | +| GSM8K | 42.56 | 126.55 | 6.89 | 3/10 | +| Math500 | 42.58 | 131.34 | 7.12 | 3/10 | + +Under the same non-correctness-gated comparison, llama.cpp is now close to the +Python implementation: GSM8K and Math500 are within roughly 1-2%, and HumanEval +is faster but has a different bit-equal pass count. This is not a 10/10 +correctness row; it is the apples-to-apples comparison with the Python fast +batched condition. + +## Megakernel-style optimization notes + +The existing `repo/megakernel` implementation is a Qwen3.5-0.8B BF16, +batch-size-1 autoregressive decode proof of concept. It is not directly +integrable into the current Qwen3.5-27B Q4_K_M DDTree target-tree verifier. +The useful idea to borrow is to reduce graph/kernel boundaries and redundant +state traffic in the target tree path. + +Two low-risk changes were applied on 2026-05-05: + +- `11a119d77 Skip Qwen35 tree live state writes`: in tree mode with persist + rollback available, skip writing live recurrent state that will be overwritten + by rollback. +- `4f6760fe5 Skip read-only recurrent state maintenance in Qwen35 tree`: skip + recurrent zero/copy-extra maintenance for read-only tree state loads. + +Castle single-prompt GSM-style smoke, Qwen3.5-27B Q4_K_M target + DFlash draft, +`gen=128`, `budget=22`, `q8_0`, `prompt_chunk=8`, `n_batch=n_ubatch=64`: + +| variant | graph nodes | cpy ops | target_tree avg | bit-equal | +|---|---:|---:|---:|---:| +| before skip-live | 3671 | 193 | 36.78 ms | pass | +| skip live writes | 3383 | 97 | 36.43 ms | pass | +| read-only recurrent state | 2902 | 1 | 36.36 ms | pass | + +This confirms the redundant-state path exists, but it is not the main runtime +bottleneck. The remaining tree graph still has about 497 `mul_mat`, 48 +`gated_delta_net`, 48 `ssm_conv`, and 16 attention ops for a 23-node tree. +The next meaningful megakernel-style step is a larger recurrent-layer fusion, +for example combining tree conv, SiLU, q/k/v normalization, gated delta net, +and persist writes behind one Qwen35 tree op. Further small graph-maintenance +cleanup is unlikely to produce a large speedup. + +For performance experiments, `LLAMA_DDTREE_UNSAFE_TRUST_BATCHED=1` restores the +fast batched posterior behavior. Rows where `bit_equal` is not 10/10 must be +treated as non-correctness-gated throughput rows, matching the limitation of the +standalone Python benchmark. `LLAMA_DDTREE_DIAG_BATCHED=1` restores the +diagnostic batched+exact path. + +## Metrics + +Primary metrics: + +- `chain_decode_tps`: target-only greedy chain decode throughput from the same + llama.cpp binary. +- `spec_decode_tps`: DDTree decode throughput from per-step timing, excluding + model load and prompt prefill. +- `spec_acceptance`: exact average committed tokens per DDTree step. +- `speedup_decode`: `spec_decode_tps / chain_decode_tps`. +- `bit_equal_pass`: required for every sample. 
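+
+As a sanity check, the primary numbers can be recomputed from the driver stats.
+A minimal sketch (field names from `common/speculative-tree-driver.h`;
+`chain_decode_tps` is measured separately by the harness, whose exact
+accounting may differ):
+
+```cpp
+llama_speculative_tree_driver_stats st = llama_speculative_tree_driver_get_stats(driver);
+
+// Decode-only throughput: committed tokens over summed per-step time.
+const double spec_decode_tps = (double) st.n_committed_tokens / (st.t_step_ms / 1000.0);
+// Exact average committed tokens per DDTree step.
+const double spec_acceptance = (double) st.n_committed_tokens / (double) st.n_steps;
+const double speedup_decode  = spec_decode_tps / chain_decode_tps;
+```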
+ +Secondary metrics: + +- `spec_e2e_tps`: includes prompt ingest and harness context setup; useful for + sanity only, not directly comparable with the original DFlash headline. +- `step_ms`, `draft_ms`, `topk_ms`, `exact_ms`: diagnose where the llama.cpp + port loses time or acceptance. + +## Commands + +From local Mac, sync the bench script to Castle if needed: + +```sh +cd /Users/leechael/workshop/playgrounds/luceboxhub-castle/repo/dflash/deps/llama.cpp +rsync -az scripts/bench_dflash_datasets_llamacpp.py \ + castle.local:/home/leechael/workshop/lucebox-hub/dflash/deps/llama.cpp/scripts/ +``` + +Build the harness on Castle: + +```sh +/usr/bin/ssh castle.local \ + 'cd /home/leechael/workshop/lucebox-hub/dflash/deps/llama.cpp && \ + cmake --build build-server -j 16 --target test-speculative-tree-e2e' +``` + +Run a one-prompt smoke first: + +```sh +/usr/bin/ssh castle.local \ + 'cd /home/leechael/workshop/lucebox-hub/dflash/deps/llama.cpp && \ + /home/leechael/workshop/lucebox-hub/sglang/.venv/bin/python \ + scripts/bench_dflash_datasets_llamacpp.py \ + --datasets HumanEval --n-sample 1 --gen 64 \ + --out-dir /tmp/llamacpp_dflash_dataset_smoke' +``` + +Before the full run, a short horizon check is useful after verifier changes: + +```sh +for n in 32 64 128 256; do + /usr/bin/ssh castle.local \ + "cd /home/leechael/workshop/lucebox-hub/dflash/deps/llama.cpp && \ + /home/leechael/workshop/lucebox-hub/sglang/.venv/bin/python \ + scripts/bench_dflash_datasets_llamacpp.py \ + --datasets HumanEval --n-sample 1 --gen $n \ + --out-dir /tmp/llamacpp_dflash_horizon_$n" || true +done +``` + +Run the full 30-prompt bench: + +```sh +/usr/bin/ssh castle.local \ + 'cd /home/leechael/workshop/lucebox-hub/dflash/deps/llama.cpp && \ + /home/leechael/workshop/lucebox-hub/sglang/.venv/bin/python \ + scripts/bench_dflash_datasets_llamacpp.py \ + --datasets HumanEval,GSM8K,Math500 \ + --n-sample 10 --gen 256 \ + --out-dir /tmp/llamacpp_dflash_dataset_full' +``` + +For faster trend data, use the short gated run: + +```sh +/usr/bin/ssh castle.local \ + 'cd /home/leechael/workshop/lucebox-hub/dflash/deps/llama.cpp && \ + /home/leechael/workshop/lucebox-hub/sglang/.venv/bin/python \ + scripts/bench_dflash_datasets_llamacpp.py \ + --datasets HumanEval,GSM8K,Math500 \ + --n-sample 10 --gen 32 \ + --out-dir /tmp/llamacpp_dflash_dataset_short_30_gen32' +``` + +For a reference-budget comparison, rerun with: + +```sh +--budget 22 --top-k 22 --out-dir /tmp/llamacpp_dflash_dataset_budget22 +``` + +For exact-path diagnostics, rerun a sample with: + +```sh +--verifier exact --exact-validation --out-dir /tmp/llamacpp_dflash_dataset_exact_diag +``` + +## Result Files + +Each run writes: + +- `summary.md`: dataset-level table. +- `results.csv`: per-prompt flat table. +- `results.json`: full arguments and per-prompt metrics. +- `logs/*.log`: raw harness output for audit and parser repair. +- `tokens/*.bin`: chain/spec generated token files. + +## Interpretation + +Use `speedup_decode` and `spec_acceptance` for direct comparison against the +original DFlash benchmark only for rows where `bit_equal_pass` is true. If +long-generation rows fail the bit-equal gate, treat those rows as correctness +failures first and performance data second. Use server/API TPS only as a +follow-up after this harness confirms acceptance and decode speed, because API +TPS includes server queueing, prompt handling, and response streaming. 
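+
+For completeness, the bit-equal gate itself is a plain element-wise comparison
+of the chain and spec token trajectories; a sketch (helper name illustrative):
+
+```cpp
+#include "llama.h"
+
+#include <algorithm>
+#include <vector>
+
+// True iff the two greedy trajectories match token-for-token.
+static bool ddtree_bit_equal(const std::vector<llama_token> & chain,
+                             const std::vector<llama_token> & spec) {
+    return chain.size() == spec.size() &&
+           std::equal(chain.begin(), chain.end(), spec.begin());
+}
+```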
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index a29dc707c3d..0283ca3a8be 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -32,6 +32,9 @@ else()
     add_subdirectory(simple-chat)
     add_subdirectory(speculative)
     add_subdirectory(speculative-simple)
+    # DDTree spec-decode driver; off by default to keep CI fast.
+    option(LLAMA_BUILD_EXAMPLES_SPECULATIVE_TREE "Build the llama-speculative-tree DDTree example" OFF)
+    add_subdirectory(speculative-tree)
     add_subdirectory(gen-docs)
     add_subdirectory(training)
     add_subdirectory(diffusion)
diff --git a/examples/speculative-tree/CMakeLists.txt b/examples/speculative-tree/CMakeLists.txt
new file mode 100644
index 00000000000..0599b9bf3ae
--- /dev/null
+++ b/examples/speculative-tree/CMakeLists.txt
@@ -0,0 +1,7 @@
+if (LLAMA_BUILD_EXAMPLES_SPECULATIVE_TREE)
+    set(TARGET llama-speculative-tree)
+    add_executable(${TARGET} main.cpp)
+    install(TARGETS ${TARGET} RUNTIME)
+    target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+    target_compile_features(${TARGET} PRIVATE cxx_std_17)
+endif()
diff --git a/examples/speculative-tree/main.cpp b/examples/speculative-tree/main.cpp
new file mode 100644
index 00000000000..a80d4849348
--- /dev/null
+++ b/examples/speculative-tree/main.cpp
@@ -0,0 +1,318 @@
+// examples/speculative-tree/main.cpp — DDTree spec-decode CLI driver.
+//
+// End-to-end command-line tool that loads a target Qwen3.5-27B model and a
+// dflash-draft companion model, tokenizes a prompt, runs DDTree speculative
+// decoding, and prints the generated text (and optionally timing statistics).
+//
+// Usage:
+//   llama-speculative-tree \
+//       -m TARGET_PATH -md DRAFT_PATH \
+//       -p PROMPT [--gen N] [--ddtree-budget N] [--temp F] \
+//       [--n-gpu-layers N] [--n-ctx N] [--bench] [--out-tokens PATH]
+
+#include "speculative-tree-driver.h"
+#include "llama.h"
+#include "log.h"
+
+#include <chrono>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <fstream>
+#include <string>
+#include <vector>
+
+// Qwen3.5 EOS token id.
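+// (Hardcoded for Qwen3.5; for a model-agnostic tool, llama_vocab_eos() on the
+// target vocab would be the usual way to obtain this.)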
+static constexpr llama_token QWEN35_EOS = 248045;
+
+struct cli_params {
+    std::string model_target;
+    std::string model_draft;
+    std::string prompt;
+    std::string prompt_tokens_path;
+    std::string out_tokens_path;
+    int   gen           = 64;
+    int   ddtree_budget = 22;
+    bool  ddtree_chain  = true;
+    float temp          = 1.0f;
+    int   n_gpu_layers  = 99;
+    int   n_ctx         = 4096;
+    bool  bench         = false;
+};
+
+static void print_usage(const char * prog) {
+    fprintf(stderr,
+        "Usage: %s -m PATH -md PATH [-p TEXT | --prompt-tokens PATH]\n"
+        "          [--gen N] [--ddtree-budget N] [--ddtree-no-chain-seed]\n"
+        "          [--temp F] [--n-gpu-layers N] [--n-ctx N]\n"
+        "          [--out-tokens PATH] [--bench]\n", prog);
+}
+
+static cli_params parse_args(int argc, char ** argv) {
+    cli_params p;
+    for (int i = 1; i < argc; ++i) {
+        std::string a = argv[i];
+        auto next = [&]() -> std::string {
+            if (++i >= argc) { fprintf(stderr, "missing value for %s\n", a.c_str()); exit(1); }
+            return argv[i];
+        };
+        if      (a == "-m")  p.model_target = next();
+        else if (a == "-md") p.model_draft  = next();
+        else if (a == "-p")  p.prompt       = next();
+        else if (a == "--prompt-tokens") p.prompt_tokens_path = next();
+        else if (a == "--gen") p.gen = std::stoi(next());
+        else if (a == "--ddtree-budget") p.ddtree_budget = std::stoi(next());
+        else if (a == "--ddtree-no-chain-seed") p.ddtree_chain = false;
+        else if (a == "--temp") p.temp = std::stof(next());
+        else if (a == "--n-gpu-layers") p.n_gpu_layers = std::stoi(next());
+        else if (a == "--n-ctx") p.n_ctx = std::stoi(next());
+        else if (a == "--out-tokens") p.out_tokens_path = next();
+        else if (a == "--bench") p.bench = true;
+        else { fprintf(stderr, "unknown option: %s\n", a.c_str()); print_usage(argv[0]); exit(1); }
+    }
+    if (p.model_target.empty() || p.model_draft.empty()) {
+        fprintf(stderr, "error: -m and -md are required\n");
+        print_usage(argv[0]);
+        exit(1);
+    }
+    if (p.prompt.empty() && p.prompt_tokens_path.empty()) {
+        fprintf(stderr, "error: -p or --prompt-tokens is required\n");
+        print_usage(argv[0]);
+        exit(1);
+    }
+    return p;
+}
+
+// Write int32 LE binary.
+static void write_tokens_bin(const std::string & path, const std::vector<llama_token> & toks) {
+    std::ofstream f(path, std::ios::binary);
+    for (llama_token t : toks) {
+        int32_t v = (int32_t)t;
+        f.write(reinterpret_cast<const char *>(&v), 4);
+    }
+}
+
+// Read int32 LE binary.
+static std::vector<llama_token> read_tokens_bin(const std::string & path) {
+    std::ifstream f(path, std::ios::binary);
+    std::vector<llama_token> toks;
+    int32_t v;
+    while (f.read(reinterpret_cast<char *>(&v), 4)) {
+        toks.push_back((llama_token)v);
+    }
+    return toks;
+}
+
+int main(int argc, char ** argv) {
+    cli_params cli = parse_args(argc, argv);
+
+    llama_backend_init();
+
+    // Load target model.
+    llama_model_params mparams_tgt = llama_model_default_params();
+    mparams_tgt.n_gpu_layers = cli.n_gpu_layers;
+    llama_model * model_tgt = llama_model_load_from_file(cli.model_target.c_str(), mparams_tgt);
+    if (!model_tgt) {
+        fprintf(stderr, "error: failed to load target model: %s\n", cli.model_target.c_str());
+        return 1;
+    }
+
+    // Load draft model.
+    llama_model_params mparams_dft = llama_model_default_params();
+    mparams_dft.n_gpu_layers = cli.n_gpu_layers;
+    llama_model * model_dft = llama_model_load_from_file(cli.model_draft.c_str(), mparams_dft);
+    if (!model_dft) {
+        fprintf(stderr, "error: failed to load draft model: %s\n", cli.model_draft.c_str());
+        llama_model_free(model_tgt);
+        return 1;
+    }
+
+    // Create target context.
+    llama_context_params cparams_tgt = llama_context_default_params();
+    cparams_tgt.n_ctx   = (uint32_t)cli.n_ctx;
+    cparams_tgt.n_batch = 512;
+    llama_context * ctx_tgt = llama_init_from_model(model_tgt, cparams_tgt);
+    if (!ctx_tgt) {
+        fprintf(stderr, "error: failed to create target context\n");
+        return 1;
+    }
+
+    // Enable hidden capture on target so the draft can read its features.
+    llama_set_capture_hidden(ctx_tgt, true);
+
+    // Create draft context.
+    // Draft uses small n_ctx (= DRAFT_CTX_MAX + block_size) since the dflash-draft
+    // model doesn't have a KV cache (it reuses target features directly).
+    llama_context_params cparams_dft = llama_context_default_params();
+    cparams_dft.n_ctx   = 2048 + 16; // DRAFT_CTX_MAX + block_size
+    cparams_dft.n_batch = 16;        // one block per decode
+    llama_context * ctx_dft = llama_init_from_model(model_dft, cparams_dft);
+    if (!ctx_dft) {
+        fprintf(stderr, "error: failed to create draft context\n");
+        return 1;
+    }
+
+    // Tokenize prompt.
+    std::vector<llama_token> prompt_tokens;
+    if (!cli.prompt_tokens_path.empty()) {
+        prompt_tokens = read_tokens_bin(cli.prompt_tokens_path);
+    } else {
+        const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
+        int n_prompt = llama_tokenize(vocab, cli.prompt.c_str(),
+                                      (int32_t)cli.prompt.size(),
+                                      nullptr, 0, /*add_special=*/true, /*parse_special=*/false);
+        // With a null buffer, llama_tokenize reports the required size as a
+        // negative count.
+        if (n_prompt < 0) {
+            n_prompt = -n_prompt;
+        }
+        prompt_tokens.resize(n_prompt);
+        if (llama_tokenize(vocab, cli.prompt.c_str(), (int32_t)cli.prompt.size(),
+                           prompt_tokens.data(), n_prompt, true, false) < 0) {
+            fprintf(stderr, "error: tokenize failed\n");
+            return 1;
+        }
+    }
+
+    if (prompt_tokens.empty()) {
+        fprintf(stderr, "error: empty prompt\n");
+        return 1;
+    }
+
+    // ── Prompt prefill (chain decode on target) ───────────────────────────────
+    // Decode the prompt in one batch to fill the target KV cache and
+    // populate hidden_capture with the last token's layer features.
+    {
+        const int n_prompt = (int)prompt_tokens.size();
+        llama_batch batch = llama_batch_init(n_prompt, 0, 1);
+        batch.n_tokens = n_prompt;
+        for (int i = 0; i < n_prompt; ++i) {
+            batch.token[i]     = prompt_tokens[i];
+            batch.pos[i]       = (llama_pos)i;
+            batch.n_seq_id[i]  = 1;
+            batch.seq_id[i][0] = 0;
+            batch.logits[i]    = (i == n_prompt - 1) ? 1 : 0;
+        }
+        int ret = llama_decode(ctx_tgt, batch);
+        llama_batch_free(batch);
+        if (ret != 0) {
+            fprintf(stderr, "error: prompt prefill decode failed: %d\n", ret);
+            return 1;
+        }
+    }
+
+    // Greedy sample from last prompt token to get the first generated token.
+    // This becomes the root token for spec-decode step 0.
+    llama_token root_token;
+    {
+        const float * logits = llama_get_logits_ith(ctx_tgt, 0);
+        if (!logits) { fprintf(stderr, "error: no logits after prefill\n"); return 1; }
+        const int n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model_tgt));
+        root_token = 0;
+        float best = logits[0];
+        for (int v = 1; v < n_vocab; ++v) {
+            if (logits[v] > best) { best = logits[v]; root_token = (llama_token)v; }
+        }
+    }
+
+    llama_pos committed_pos = (llama_pos)prompt_tokens.size();
+
+    // ── Build DDTree driver ───────────────────────────────────────────────────
+    // Must be done before calling ingest_prompt_capture so the driver is initialized.
+    llama_ddtree_params ddparams;
+    ddparams.budget     = cli.ddtree_budget;
+    ddparams.temp       = cli.temp;
+    ddparams.chain_seed = cli.ddtree_chain;
+    ddparams.block_size = 16; // dflash-draft block size
+
+    llama_speculative_tree_driver * driver =
+        llama_speculative_tree_driver_init(ctx_tgt, ctx_dft, ddparams);
+    if (!driver) {
+        fprintf(stderr, "error: failed to init speculative tree driver\n");
+        return 1;
+    }
+
+    // Ingest the prompt prefill capture into the driver's ring buffer.
+    llama_speculative_tree_driver_ingest_prompt_capture(driver, (int32_t)prompt_tokens.size());
+
+    // ── Generation loop ───────────────────────────────────────────────────────
+    std::vector<llama_token> generated;
+    generated.reserve((size_t)cli.gen + 16);
+
+    int64_t total_steps  = 0;
+    int64_t total_accept = 0; // sum of commit_n per step (for accept rate)
+
+    auto t_start = std::chrono::high_resolution_clock::now();
+
+    const llama_vocab * target_vocab = llama_model_get_vocab(model_tgt);
+
+    while ((int)generated.size() < cli.gen && root_token != QWEN35_EOS) {
+        std::vector<llama_token> accepted =
+            llama_speculative_tree_driver_step(driver, root_token, committed_pos);
+
+        if (accepted.empty()) {
+            fprintf(stderr, "error: driver step returned empty result\n");
+            break;
+        }
+
+        // accepted[0]      = root_token (echoed)
+        // accepted[1..n-2] = newly accepted draft tokens
+        // accepted[n-1]    = bonus token (next root)
+
+        // Commit all tokens (root + draft accepted); hold bonus as new root.
+        // accepted: [root, draft_1, ..., draft_k, bonus]
+        // n_new = accept_depth = number of KV positions consumed this step.
+        const int n_new = (int)accepted.size() - 1; // excludes bonus
+        root_token      = accepted.back();          // bonus = new root for next step
+        committed_pos  += (llama_pos)n_new;         // advance past all committed slots
+
+        // The first accepted token == root_token == the one we greedy-sampled from prompt
+        // OR was returned as bonus from the prior step. Either way, we count it.
+        for (int i = 0; i < n_new && (int)generated.size() < cli.gen; ++i) {
+            generated.push_back(accepted[i]);
+            if (accepted[i] == QWEN35_EOS) {
+                root_token = QWEN35_EOS;
+                break;
+            }
+        }
+
+        total_steps++;
+        total_accept += n_new;
+
+        // Stream token text to stdout.
+        if (!cli.bench) {
+            for (int i = 0; i < n_new; ++i) {
+                char buf[256] = {0};
+                int len = llama_token_to_piece(target_vocab, accepted[i], buf, sizeof(buf)-1,
+                                               /*lstrip=*/0, /*special=*/false);
+                if (len > 0) { buf[len] = '\0'; fputs(buf, stdout); fflush(stdout); }
+            }
+        }
+    }
+
+    auto t_end = std::chrono::high_resolution_clock::now();
+    double elapsed_s = std::chrono::duration<double>(t_end - t_start).count();
+
+    printf("\n");
+
+    if (cli.bench) {
+        const double tps = (double)generated.size() / elapsed_s;
+        const double accept_rate = total_steps > 0 ? (double)total_accept / (double)total_steps : 0.0;
+        printf("[bench] generated=%d tokens, elapsed=%.2fs, tokens/s=%.1f, "
+               "accept_rate=%.2f tokens/step, steps=%lld\n",
+               (int)generated.size(), elapsed_s, tps, accept_rate,
+               (long long)total_steps);
+    }
+
+    if (!cli.out_tokens_path.empty()) {
+        write_tokens_bin(cli.out_tokens_path, generated);
+    }
+
+    llama_speculative_tree_driver_free(driver);
+    llama_free(ctx_dft);
+    llama_free(ctx_tgt);
+    llama_model_free(model_dft);
+    llama_model_free(model_tgt);
+    llama_backend_free();
+
+    return 0;
+}
diff --git a/include/llama.h b/include/llama.h
index ac267b5089a..22c5639f254 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -48,6 +48,8 @@
 #define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
 #define LLAMA_STATE_SEQ_VERSION 2
 
+#define LLAMA_MEM_SNAPSHOT_INVALID -1
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -69,6 +71,9 @@ extern "C" {
     typedef int32_t llama_token;
     typedef int32_t llama_seq_id;
 
+    // opaque handle returned by llama_seq_snapshot / used by llama_seq_restore and llama_seq_release
+    typedef int32_t llama_mem_snapshot_id;
+
     enum llama_vocab_type {
         LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
         LLAMA_VOCAB_TYPE_SPM  = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
@@ -241,6 +246,10 @@ extern "C" {
         int32_t      *  n_seq_id;
        llama_seq_id ** seq_id;
        int8_t       *  logits; // TODO: rename this to "output"
+
+        // tree-mode parent indices: parent_id[i] is the index of token i's parent in the batch
+        // -1 means root (no parent). NULL means chain mode (default behavior unchanged).
+        int32_t      *  parent_id;
     } llama_batch;
 
     enum llama_model_kv_override_type {
@@ -290,6 +299,9 @@ extern "C" {
         // NULL-terminated list of buffer types to use for tensors that match a pattern
         const struct llama_model_tensor_buft_override * tensor_buft_overrides;
 
+        // optional target model for auxiliary models that share target tensors
+        const struct llama_model * target_model;
+
         int32_t n_gpu_layers; // number of layers to store in VRAM, a negative value means all layers
         enum llama_split_mode split_mode; // how to split the model across multiple GPUs
@@ -642,6 +654,16 @@ extern "C" {
     // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
     LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
 
+    // Copy one token-embedding row from model->tok_embd into caller-supplied buffer.
+    // out_n must be >= model->hparams.n_embd. The embedding is returned as F32
+    // regardless of the on-disk storage type (conversion happens on the backend).
+    // Returns 0 on success, -1 if token is out of range or tok_embd is unavailable.
+    LLAMA_API int llama_model_token_embd_lookup(
+            const struct llama_model * model,
+                         llama_token   token,
+                               float * out,
+                             int64_t   out_n);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
@@ -781,6 +803,28 @@ extern "C" {
     // Check if the memory supports shifting
     LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
 
+    // Snapshot/restore the recurrent state (SSM + conv) for seq_id.
+    // snapshot() allocates per-layer backup buffers and copies the current state into them.
+    // restore() copies the backed-up state back; release() frees the backup buffers.
+    // These are no-ops on non-recurrent memory types (returns LLAMA_MEM_SNAPSHOT_INVALID).
+    // The caller is responsible for calling release() after each snapshot.
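+    // Illustrative usage sketch around a tree-verify decode:
+    //     llama_mem_snapshot_id snap = llama_seq_snapshot(ctx, seq_id);
+    //     /* ... tree-mode llama_decode() + verification ... */
+    //     if (snap != LLAMA_MEM_SNAPSHOT_INVALID && need_rollback) {
+    //         llama_seq_restore(ctx, snap);
+    //     }
+    //     llama_seq_release(ctx, snap);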
+ LLAMA_API llama_mem_snapshot_id llama_seq_snapshot(struct llama_context * ctx, llama_seq_id seq_id); + LLAMA_API bool llama_seq_restore (struct llama_context * ctx, llama_mem_snapshot_id snap_id); + LLAMA_API void llama_seq_release (struct llama_context * ctx, llama_mem_snapshot_id snap_id); + + // Compact the KV cache after a tree-verify forward pass. + // The tree was placed at slots [spine_start, spine_start+N); after this call + // the accepted spine occupies slots [spine_start, spine_start+commit_n) in + // DFS order, and prompt cells (slots < spine_start) are untouched. + // No-op on non-KV memory types (e.g. pure SSM models). + LLAMA_API void llama_kv_cache_seq_compact_tree( + struct llama_context * ctx, + llama_seq_id seq_id, + const int32_t * accepted_dfs, + int32_t n_accepted, + int32_t commit_n, + int32_t spine_start); + // // State / sessions // @@ -932,7 +976,14 @@ extern "C" { int32_t embd, int32_t n_seq_max); - // Frees a batch of tokens allocated with llama_batch_init() + // Like llama_batch_init but also allocates parent_id[n_tokens], filled with -1 (tree roots). + // Callers must free with llama_batch_free(). + LLAMA_API struct llama_batch llama_batch_init_tree( + int32_t n_tokens, + int32_t embd, + int32_t n_seq_max); + + // Frees a batch of tokens allocated with llama_batch_init() or llama_batch_init_tree() LLAMA_API void llama_batch_free(struct llama_batch batch); // Process a batch of tokens. @@ -984,6 +1035,107 @@ extern "C" { // If true, all model tensors are activated during llama_decode() to load and cache their weights. LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup); + // dflash hidden capture: when enabled, qwen35 forward writes per-layer hidden states + // for the 5 dflash target capture layers into an output tensor readable via + // llama_get_hidden_capture() after llama_decode(). Toggling this triggers a graph + // reserve. If disabled (default), behavior is byte-for-byte identical to baseline. + LLAMA_API void llama_set_capture_hidden(struct llama_context * ctx, bool enable); + LLAMA_API struct ggml_tensor * llama_get_hidden_capture(struct llama_context * ctx); + + // Host-side accessor: returns a pointer into a context-owned CPU buffer. + // The device-side capture tensor is synchronized lazily only when this is called. + // Returns NULL when capture is disabled or no decode has run yet. + // out_ne0 / out_ne1 receive the tensor dimensions. + LLAMA_API const float * llama_get_hidden_capture_data(struct llama_context * ctx, + int64_t * out_ne0, + int64_t * out_ne1); + + // dflash draft target_feat injection (Task 1 Phase 4 gap fix). + LLAMA_API void llama_set_dflash_draft_top_k(struct llama_context * ctx, int32_t k); + + // Must be called on the draft context before llama_decode() when running a + // dflash-draft (LLM_ARCH_DFLASH_DRAFT) model. The driver supplies a packed + // [5*n_embd, ctx_len] F32 host buffer with per-layer hidden captures from the + // target model. committed_pos is the number of tokens already committed in the + // target context; it drives the RoPE position indices for Q and K in the draft. + // The data pointer is non-owning; it must remain valid until llama_decode() returns. 
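+    // (Assumed layout: the 5 capture layers are stacked along the first
+    // dimension, so the features of capture layer l for token t occupy rows
+    // [l*n_embd, (l+1)*n_embd) of column t.)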
+ LLAMA_API void llama_set_target_feat_raw(struct llama_context * ctx, + const float * data, + int64_t n_embd_fc, + int64_t ctx_len, + int64_t committed_pos); + + LLAMA_API int llama_dflash_draft_fuse_target_feat(struct llama_context * ctx, + const float * target_feat_raw, + int64_t n_embd_fc, + int64_t ctx_len, + float * target_feat_fused); + + LLAMA_API int llama_dflash_draft_encode_top_k(struct llama_context * ctx, + struct llama_batch batch, + const float * target_feat_raw, + int64_t n_embd_fc, + int64_t ctx_len, + int64_t committed_pos, + int32_t top_k); + + LLAMA_API int llama_dflash_draft_encode_top_k_fused(struct llama_context * ctx, + struct llama_batch batch, + const float * target_feat_fused, + int64_t n_embd, + int64_t ctx_len, + int64_t committed_pos, + int32_t top_k); + + LLAMA_API int llama_dflash_draft_update_fused_cache(struct llama_context * ctx, + const float * target_feat_raw, + int64_t n_embd_fc, + int64_t n_new, + int64_t first_pos, + int64_t cap); + + LLAMA_API int llama_dflash_draft_update_fused_cache_from_capture(struct llama_context * draft_ctx, + struct llama_context * target_ctx, + const int32_t * dfs_indices, + int32_t n_dfs, + int64_t first_pos, + int64_t cap); + + LLAMA_API int llama_dflash_draft_encode_top_k_cached(struct llama_context * ctx, + struct llama_batch batch, + int64_t n_embd, + int64_t ctx_len, + int64_t ring_start, + int64_t cap, + int64_t committed_pos, + int32_t top_k); + + // dflash Phase 2.4: persist-based SSM rollback after tree verify. + LLAMA_API void llama_dflash_ensure_persist_capacity( + struct llama_context * ctx, + int64_t n_tokens); + + // After llama_kv_cache_seq_compact_tree(), call this to copy the SSM state + // captured at DFS node accepted_dfs_node from the persist buffer back into + // the live recurrent cache for seq_id, replacing the snapshot/restore/replay path. + // Must be called after the tree-mode llama_decode() and before the next decode. + // Returns true on success, false if persist buffers are unavailable. + // KNOWN LIMITATION: conv state is NOT rolled back (see Phase 2.4 Task 4 — option b). + // Conv-state divergence decays within ~K_conv tokens; the chain-vs-spec test may + // diverge by a few tokens at each tree boundary before reconverging. + LLAMA_API bool llama_dflash_rollback_ssm_to_dfs( + struct llama_context * ctx, + llama_seq_id seq_id, + int32_t accepted_dfs_node); + + // After persist-based rollback, adjust the recurrent cache bookkeeping so + // seq_pos_max() reflects the accepted chain position rather than the DFS + // tree position left by the tree-mode forward. + LLAMA_API bool llama_dflash_set_recurrent_tail_pos( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos pos); + // Set abort callback LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data); @@ -1054,6 +1206,15 @@ extern "C" { LLAMA_API llama_token * llama_get_sampled_candidates_ith (struct llama_context * ctx, int32_t i); LLAMA_API uint32_t llama_get_sampled_candidates_count_ith(struct llama_context * ctx, int32_t i); + // DFlash draft graph top-K tensors. Returns false when the last eval was not a + // dflash-draft graph with top-K output. Layout is row-major [n_rows, k]. 
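+    // Illustrative read-back sketch (assumes a preceding draft decode with
+    // top-K output enabled):
+    //     const float * lp; const llama_token * ids; int32_t rows, k;
+    //     if (llama_get_dflash_draft_top_k(ctx, &lp, &ids, &rows, &k)) {
+    //         llama_token best_row0 = ids[0 * k + 0]; // rank-0 token of row 0
+    //     }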
+ LLAMA_API bool llama_get_dflash_draft_top_k( + struct llama_context * ctx, + const float ** logits, + const llama_token ** token_ids, + int32_t * n_rows, + int32_t * k); + // // Vocab // diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 121c21fed95..f1be2a7ec06 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -55,6 +55,7 @@ add_library(llama models/command-r.cpp models/dbrx.cpp models/deci.cpp + models/dflash-draft.cpp models/deepseek.cpp models/deepseek2.cpp models/delta-net-base.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 6904b9c1a64..721d94ca299 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -132,6 +132,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_LLAMA_EMBED, "llama-embed" }, { LLM_ARCH_MAINCODER, "maincoder" }, { LLM_ARCH_KIMI_LINEAR, "kimi-linear" }, + { LLM_ARCH_DFLASH_DRAFT, "dflash-draft" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -545,6 +546,10 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" }, { LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" }, { LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" }, + // dflash-draft top-level tensors + { LLM_TENSOR_DFLASH_FC, "fc" }, + { LLM_TENSOR_DFLASH_HIDDEN_NORM, "hidden_norm" }, + { LLM_TENSOR_DFLASH_OUT_NORM, "out_norm" }, }; // declare information about the model weight tensors: @@ -765,6 +770,10 @@ static const std::map LLM_TENSOR_INFOS = { // Nemotron 3 Super {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + // dflash-draft + {LLM_TENSOR_DFLASH_FC, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DFLASH_HIDDEN_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_DFLASH_OUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-arch.h b/src/llama-arch.h index c4aabab7e0c..15183f18d88 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -136,6 +136,7 @@ enum llm_arch { LLM_ARCH_LLAMA_EMBED, LLM_ARCH_MAINCODER, LLM_ARCH_KIMI_LINEAR, + LLM_ARCH_DFLASH_DRAFT, LLM_ARCH_UNKNOWN, }; @@ -552,6 +553,10 @@ enum llm_tensor { LLM_TENSOR_NEXTN_HNORM, LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + // dflash-draft top-level tensors + LLM_TENSOR_DFLASH_FC, // "fc" [5*hidden, hidden] + LLM_TENSOR_DFLASH_HIDDEN_NORM, // "hidden_norm" [hidden] + LLM_TENSOR_DFLASH_OUT_NORM, // "out_norm" [hidden] }; enum llm_tensor_layer { diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 6bf76939cdd..8aac5c7cce2 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -224,6 +224,7 @@ bool llama_batch_allocr::init( /*.seq_id_unq =*/ this->seq_id_unq.data(), /*.seq_idx =*/ this->seq_idx.data(), /*.output =*/ batch.logits, + /*.parent_id =*/ batch.parent_id, /*.data =*/ {}, }; @@ -428,6 +429,7 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t /*.seq_id_unq =*/ udata->seq_id_unq.data(), /*.seq_idx =*/ udata->seq_idx.data(), /*.output =*/ udata->output.data(), + /*.parent_id =*/ nullptr, /*.data =*/ std::move(udata), }; @@ -683,6 +685,11 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u assert(n_tokens%n_seqs == 0); + // tree-mode parent_ids must not be split across ubatches: the ubatch must cover + // every token of the public batch in a single emission. 
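+    // For example, a root with two children, one of which has a child of its
+    // own, could arrive as tokens [r, a, a1, b, b1] with parent_id = {-1, 0, 1, 0, 3}.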
+    GGML_ASSERT((batch.parent_id == nullptr || (size_t) n_tokens == (size_t) batch.n_tokens) &&
+            "tree-mode batch with parent_id must fit in a single ubatch");
+
     auto udata = std::make_shared<llama_ubatch::data_t>();

     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
@@ -722,6 +729,10 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
         udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
         udata->output[i]   = batch.logits[idxs[i]];

+        if (batch.parent_id) {
+            udata->parent_id.push_back(batch.parent_id[idxs[i]]);
+        }
+
         for (int s = 0; s < udata->n_seq_id[i]; ++s) {
             const llama_seq_id seq_id = batch.seq_id[idxs[i]][s];

@@ -747,6 +758,8 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
         }
     }

+    int32_t * parent_id_ptr = udata->parent_id.empty() ? nullptr : udata->parent_id.data();
+
     llama_ubatch res {
         /*.b_equal_seqs =*/ equal_seqs,
         /*.n_tokens     =*/ n_tokens,
@@ -763,6 +776,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
         /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
         /*.seq_idx      =*/ udata->seq_idx.data(),
         /*.output       =*/ udata->output.data(),
+        /*.parent_id    =*/ parent_id_ptr,
         /*.data         =*/ std::move(udata),
     };
@@ -864,25 +878,27 @@ struct llama_batch llama_batch_get_one(
         llama_token * tokens,
         int32_t       n_tokens) {
     return {
-        /*n_tokens =*/ n_tokens,
-        /*tokens   =*/ tokens,
-        /*embd     =*/ nullptr,
-        /*pos      =*/ nullptr,
-        /*n_seq_id =*/ nullptr,
-        /*seq_id   =*/ nullptr,
-        /*logits   =*/ nullptr,
+        /*n_tokens  =*/ n_tokens,
+        /*tokens    =*/ tokens,
+        /*embd      =*/ nullptr,
+        /*pos       =*/ nullptr,
+        /*n_seq_id  =*/ nullptr,
+        /*seq_id    =*/ nullptr,
+        /*logits    =*/ nullptr,
+        /*parent_id =*/ nullptr,
     };
 }

 struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
     llama_batch batch = {
-        /*n_tokens =*/ 0,
-        /*tokens   =*/ nullptr,
-        /*embd     =*/ nullptr,
-        /*pos      =*/ nullptr,
-        /*n_seq_id =*/ nullptr,
-        /*seq_id   =*/ nullptr,
-        /*logits   =*/ nullptr,
+        /*n_tokens  =*/ 0,
+        /*tokens    =*/ nullptr,
+        /*embd      =*/ nullptr,
+        /*pos       =*/ nullptr,
+        /*n_seq_id  =*/ nullptr,
+        /*seq_id    =*/ nullptr,
+        /*logits    =*/ nullptr,
+        /*parent_id =*/ nullptr,
     };

     if (embd) {
@@ -904,6 +920,17 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
     return batch;
 }

+struct llama_batch llama_batch_init_tree(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
+    llama_batch batch = llama_batch_init(n_tokens_alloc, embd, n_seq_max);
+
+    batch.parent_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
+    for (int i = 0; i < n_tokens_alloc; ++i) {
+        batch.parent_id[i] = -1;
+    }
+
+    return batch;
+}
+
 void llama_batch_free(struct llama_batch batch) {
     if (batch.token)  free(batch.token);
     if (batch.embd)   free(batch.embd);
@@ -916,4 +943,5 @@ void llama_batch_free(struct llama_batch batch) {
         free(batch.seq_id);
     }
     if (batch.logits) free(batch.logits);
+    if (batch.parent_id) free(batch.parent_id);
 }
diff --git a/src/llama-batch.h b/src/llama-batch.h
index f77520e86c3..6319170db00 100644
--- a/src/llama-batch.h
+++ b/src/llama-batch.h
@@ -50,6 +50,7 @@ struct llama_ubatch {
     llama_seq_id * seq_id_unq; // [n_seqs_unq]    | s | seq_id
     int32_t      * seq_idx;    // [LLAMA_MAX_SEQ] | - | seq_idx
     int8_t       * output;     // [n_tokens]      | i | -
+    int32_t      * parent_id;  // [n_tokens]      | i | parent index, -1 = root; NULL = chain mode

     struct data_t {
         std::vector<llama_token> token;
@@ -60,6 +61,7 @@ struct llama_ubatch {
         std::vector<llama_seq_id> seq_id_unq;
         std::vector<int32_t>      seq_idx;
         std::vector<int8_t>       output;
+        std::vector<int32_t>      parent_id;

         std::vector<llama_seq_id> seq_id_data;
     };
diff --git a/src/llama-context.cpp
b/src/llama-context.cpp index ee0c29235cd..0ca058198b9 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -6,6 +6,9 @@ #include "llama-batch.h" #include "llama-io.h" #include "llama-memory.h" +#include "llama-kv-cache.h" +#include "llama-memory-recurrent.h" +#include "llama-memory-hybrid.h" #include "llama-mmap.h" #include "llama-model.h" #include "llama-ext.h" @@ -13,10 +16,27 @@ #include #include +#include #include #include #include +static bool llama_dflash_fast_rollback_enabled() { + const char * verifier = std::getenv("LLAMA_DDTREE_VERIFIER"); + if (verifier != nullptr && (std::strcmp(verifier, "paper") == 0 || std::strcmp(verifier, "tree") == 0)) { + return true; + } + const char * exact = std::getenv("LLAMA_DDTREE_EXACT_VALIDATION"); + if (exact != nullptr && exact[0] == '1') { + return false; + } + const char * e = std::getenv("LLAMA_DDTREE_FAST_ROLLBACK"); + if (e != nullptr && e[0] != '\0') { + return e[0] == '1'; + } + return true; +} + // // llama_context // @@ -156,6 +176,12 @@ llama_context::llama_context( cparams.fused_gdn_ar = true; cparams.fused_gdn_ch = true; cparams.auto_fgdn = true; + if (const char * e = std::getenv("LLAMA_FUSED_GDN_AR")) { + cparams.fused_gdn_ar = std::atoi(e) != 0; + } + if (const char * e = std::getenv("LLAMA_FUSED_GDN_CH")) { + cparams.fused_gdn_ch = std::atoi(e) != 0; + } // with causal attention, the batch size is limited by the context size cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; @@ -990,6 +1016,27 @@ size_t llama_context::get_sampled_probs_count(int32_t idx) { } } +bool llama_context::get_dflash_draft_top_k( + const float ** top_logits, + const llama_token ** top_token_ids, + int32_t * n_rows, + int32_t * k) { + if (top_logits) { + *top_logits = dflash_draft_top_logits.empty() ? nullptr : dflash_draft_top_logits.data(); + } + if (top_token_ids) { + *top_token_ids = dflash_draft_top_token_ids.empty() ? 
nullptr : dflash_draft_top_token_ids.data(); + } + if (n_rows) { + *n_rows = dflash_draft_top_rows; + } + if (k) { + *k = dflash_draft_top_k; + } + return dflash_draft_top_rows > 0 && dflash_draft_top_k > 0 && + !dflash_draft_top_token_ids.empty(); +} + void llama_context::attach_threadpool( ggml_threadpool_t threadpool, @@ -1052,6 +1099,238 @@ void llama_context::set_causal_attn(bool value) { sched_need_reserve = true; } +void llama_context::set_capture_hidden(bool enable) { + LLAMA_LOG_DEBUG("%s: enable = %d\n", __func__, enable); + if (capture_hidden == enable) { + return; + } + capture_hidden = enable; + sched_need_reserve = true; // graph topology changes when capture is toggled +} + +ggml_tensor * llama_context::get_hidden_capture() const { + if (gf_res_prev && gf_res_prev->t_hidden_capture) { + return gf_res_prev->t_hidden_capture; + } + return nullptr; +} + +void llama_context::set_dflash_draft_top_k(int32_t k) { + k = std::max(0, k); + if (dflash_draft_top_k_req == k) { + return; + } + dflash_draft_top_k_req = k; + sched_need_reserve = true; +} + +const float * llama_context::get_hidden_capture_data(int64_t * out_ne0, int64_t * out_ne1) const { + ggml_tensor * t_cap = get_hidden_capture(); + if (t_cap == nullptr || t_cap->buffer == nullptr) { + if (out_ne0) *out_ne0 = 0; + if (out_ne1) *out_ne1 = 0; + return nullptr; + } + + const size_t cap_n = ggml_nelements(t_cap); + if (!hidden_capture_host_valid || hidden_capture_host.size() < cap_n || + hidden_capture_ne0 != t_cap->ne[0] || hidden_capture_ne1 != t_cap->ne[1]) { + if (hidden_capture_host.size() < cap_n) { + hidden_capture_host.resize(cap_n); + } + hidden_capture_ne0 = t_cap->ne[0]; + hidden_capture_ne1 = t_cap->ne[1]; + ggml_backend_tensor_get(t_cap, hidden_capture_host.data(), 0, cap_n * sizeof(float)); + hidden_capture_host_valid = true; + } + + if (out_ne0) *out_ne0 = hidden_capture_ne0; + if (out_ne1) *out_ne1 = hidden_capture_ne1; + return hidden_capture_host.data(); +} + +ggml_tensor * llama_context::dflash_get_persist_inter(int32_t il) const { + if (il < 0 || il >= (int32_t)dflash_persist_inter_l.size()) { + return nullptr; + } + return dflash_persist_inter_l[il]; +} + +ggml_tensor * llama_context::dflash_get_persist_conv(int32_t il) const { + if (il < 0 || il >= (int32_t)dflash_persist_conv_l.size()) { + return nullptr; + } + return dflash_persist_conv_l[il]; +} + +void llama_context::ensure_dflash_persist_capacity(int64_t n_tokens) { + if (model.arch != LLM_ARCH_QWEN35) { + return; // only Qwen3.5 uses delta-net recurrent layers + } + if (n_tokens <= dflash_persist_max_n_tokens) { + return; // already large enough + } + if (n_tokens <= dflash_persist_failed_n_tokens) { + return; // allocation already failed for this size in this context + } + + // Derive SSM dimensions from hparams (same as build_layer_attn_linear). + const auto & hparams = model.hparams; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t num_v_heads = hparams.ssm_dt_rank; // H_v + const int64_t head_v_dim = d_inner / num_v_heads; // S_v + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t conv_channels = d_inner + 2 * (int64_t)hparams.ssm_n_group * (int64_t)hparams.ssm_d_state; + const int32_t n_layer = (int32_t)hparams.n_layer; + + // Release existing allocation before reallocating. 
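+    // (each dflash_persist_ctxs_bufs entry owns a ggml context plus its backend
+    // buffer, so clearing below frees any previously allocated per-layer tensors)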
+    dflash_persist_inter_l.clear();
+    dflash_persist_conv_l.clear();
+    dflash_persist_inter_buf.reset();
+    dflash_persist_inter_ctx.reset();
+    dflash_persist_ctxs_bufs.clear();
+
+    dflash_persist_inter_l.resize(n_layer, nullptr);
+    dflash_persist_conv_l.resize(n_layer, nullptr);
+
+    // Persist tensors have to live next to each layer's recurrent state. With
+    // partial offload, CPU and CUDA recurrent layers coexist; one shared buffer
+    // would make CUDA layers write persist state into CPU memory or vice versa.
+    auto * raw_mem  = memory.get();
+    auto * mem_recr = dynamic_cast<llama_memory_recurrent *>(raw_mem);
+    if (!mem_recr) {
+        if (auto * hyb = dynamic_cast<llama_memory_hybrid *>(raw_mem)) {
+            mem_recr = hyb->get_mem_recr();
+        }
+    }
+
+    struct ggml_backend_buft_comparator {
+        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
+            return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
+        }
+    };
+
+    // dflash Phase 2.4 fix: allocate each layer's persist tensors in a separate
+    // context so they get separate backend buffers. This lets small (~26 MiB)
+    // per-layer allocations fit into fragmented GPU memory where one large
+    // (~1.7 GiB) contiguous block would fail.
+    size_t total_bytes      = 0;
+    size_t total_ssm_bytes  = 0;
+    size_t total_conv_bytes = 0;
+    for (int il = 0; il < n_layer; ++il) {
+        if (!hparams.is_recurrent(il)) {
+            continue; // full-attn layer — no persist buffer needed
+        }
+
+        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
+        if (mem_recr && il < (int)mem_recr->s_l.size() &&
+            mem_recr->s_l[il] != nullptr && mem_recr->s_l[il]->buffer != nullptr) {
+            buft = ggml_backend_buffer_get_type(mem_recr->s_l[il]->buffer);
+        }
+
+        struct ggml_init_params init_params = {
+            /* mem_size   = */ ggml_tensor_overhead() * 2 + 1024,
+            /* mem_buffer = */ nullptr,
+            /* no_alloc   = */ true,
+        };
+        ggml_context * ctx = ggml_init(init_params);
+        if (!ctx) {
+            LLAMA_LOG_ERROR("%s: failed to create ggml context for persist buffers\n", __func__);
+            dflash_persist_inter_l.clear();
+            dflash_persist_conv_l.clear();
+            dflash_persist_ctxs_bufs.clear();
+            dflash_persist_failed_n_tokens = std::max(dflash_persist_failed_n_tokens, n_tokens);
+            return;
+        }
+
+        // SSM persist: [S_v, S_v, H_v, n_tokens]. Keep F32 by default so fast
+        // rollback restores the same recurrent state precision as chain decode.
+        // The unsafe DDTree path mirrors standalone DFlash's F16 intermediates;
+        // use LLAMA_DDTREE_PERSIST_S_F32=1 to force strict storage there.
+        const bool persist_s_f16 = []{
+            const char * force_f32 = getenv("LLAMA_DDTREE_PERSIST_S_F32");
+            if (force_f32 && force_f32[0] == '1') {
+                return false;
+            }
+            const char * force_f16 = getenv("LLAMA_DDTREE_PERSIST_S_F16");
+            if (force_f16 && force_f16[0] == '1') {
+                return true;
+            }
+            const char * unsafe = getenv("LLAMA_DDTREE_UNSAFE_FAST_TREE_STATE");
+            return unsafe && unsafe[0] == '1';
+        }();
+        const ggml_type persist_s_type = persist_s_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+        ggml_tensor * ts = ggml_new_tensor_4d(ctx,
+                persist_s_type, head_v_dim, head_v_dim, num_v_heads, n_tokens);
+        ggml_format_name(ts, "dflash_persist_il%d", il);
+        dflash_persist_inter_l[il] = ts;
+
+        // Conv persist: [K-1, conv_channels, n_tokens] F32 — matches the live
+        // r_l[il] layout (K-1 fastest, then conv_channels) per token.
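+        // For example, with d_conv = 4 each token's snapshot is a
+        // [3, conv_channels] slab holding the three most recent conv inputs
+        // per channel.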
+ ggml_tensor * tc = ggml_new_tensor_3d(ctx, + GGML_TYPE_F32, d_conv - 1, conv_channels, n_tokens); + ggml_format_name(tc, "dflash_persist_conv_il%d", il); + dflash_persist_conv_l[il] = tc; + + const size_t layer_ssm_bytes = ggml_nbytes(ts); + const size_t layer_conv_bytes = ggml_nbytes(tc); + total_ssm_bytes += layer_ssm_bytes; + total_conv_bytes += layer_conv_bytes; + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + LLAMA_LOG_ERROR("%s: failed to allocate persist buffer for layer %d (n_tokens=%lld, layer_ssm=%.2f MiB, layer_conv=%.2f MiB, total_ssm=%.2f MiB, total_conv=%.2f MiB)\n", + __func__, il, (long long)n_tokens, + (double)layer_ssm_bytes / (1024.0 * 1024.0), + (double)layer_conv_bytes / (1024.0 * 1024.0), + (double)total_ssm_bytes / (1024.0 * 1024.0), + (double)total_conv_bytes / (1024.0 * 1024.0)); + dflash_persist_inter_l.clear(); + dflash_persist_conv_l.clear(); + dflash_persist_ctxs_bufs.clear(); + dflash_persist_failed_n_tokens = std::max(dflash_persist_failed_n_tokens, n_tokens); + return; + } + ggml_backend_buffer_clear(buf, 0); + total_bytes += ggml_backend_buffer_get_size(buf); + dflash_persist_ctxs_bufs.emplace_back(ggml_context_ptr(ctx), buf); + } + dflash_persist_max_n_tokens = n_tokens; + dflash_persist_failed_n_tokens = 0; + + LLAMA_LOG_INFO("%s: allocated dflash persist buffers: %d layers, %lld tokens, %.2f MiB across %zu backend buffers (ssm=%.2f MiB, conv=%.2f MiB)\n", + __func__, n_layer, (long long)n_tokens, + (double)total_bytes / (1024.0 * 1024.0), + dflash_persist_ctxs_bufs.size(), + (double)total_ssm_bytes / (1024.0 * 1024.0), + (double)total_conv_bytes / (1024.0 * 1024.0)); +} + +void llama_context::set_target_feat_raw(const float * data, int64_t n_embd_fc, int64_t ctx_len, + int64_t committed_pos) { + // Stash non-owning pointer and dims; read by llm_graph_input_target_feat::set_input(). 
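+    // No copy happens here; per the header contract the caller must keep `data`
+    // alive until the next llama_decode() on this context returns.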
+ pending_target_feat_raw = data; + pending_target_feat_n_embd_fc = n_embd_fc; + pending_target_feat_ctx_len = ctx_len; + pending_draft_committed_pos = committed_pos; + pending_target_feat_fused = false; + pending_dflash_fuse_only = false; + pending_dflash_kv_update_only = false; + pending_target_feat_tensor = nullptr; +} + +void llama_context::set_target_feat_fused(const float * data, int64_t n_embd, int64_t ctx_len, + int64_t committed_pos) { + pending_target_feat_raw = data; + pending_target_feat_n_embd_fc = n_embd; + pending_target_feat_ctx_len = ctx_len; + pending_draft_committed_pos = committed_pos; + pending_target_feat_fused = true; + pending_dflash_fuse_only = false; + pending_dflash_kv_update_only = false; + pending_target_feat_tensor = nullptr; +} + void llama_context::set_warmup(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -1168,21 +1447,92 @@ bool llama_context::set_adapter_cvec( return res; } +static std::map build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset); +static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map & samplers); +static void copy_tensor_async_ints( + const std::map & tensor_map, + const buffer_view & sampled, + const std::map & seq_to_row, + ggml_backend_sched_t sched); +static void copy_tensor_async_floats( + const std::map & tensor_map, + const buffer_view & dst, + size_t stride, + std::vector & counts, + const std::map & seq_to_row, + ggml_backend_sched_t sched); +static void copy_tensor_async_candidates( + const std::map & tensor_map, + const buffer_view & dst, + size_t stride, + std::vector & counts, + const std::map & seq_to_row, + ggml_backend_sched_t sched); + llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { + const bool profile_dflash_draft = + model.arch == LLM_ARCH_DFLASH_DRAFT && + std::getenv("LLAMA_DDTREE_PROFILE") != nullptr; + const bool profile_dflash_tree = + model.arch == LLM_ARCH_QWEN35 && + ubatch.parent_id != nullptr && + std::getenv("LLAMA_DDTREE_PROFILE") != nullptr; + const bool profile_dflash = profile_dflash_draft || profile_dflash_tree; + + const int64_t t_total_start_us = profile_dflash ? ggml_time_us() : 0; + int64_t t_apply_us = 0; + int64_t t_build_alloc_us = 0; + int64_t t_set_inputs_us = 0; + int64_t t_compute_us = 0; + bool reused_graph = false; + + int64_t t0_us = profile_dflash ? 
ggml_time_us() : 0; if (mctx && !mctx->apply()) { LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); ret = GGML_STATUS_FAILED; return nullptr; } + if (profile_dflash) { + t_apply_us = ggml_time_us() - t0_us; + } + + auto ensure_dflash_runtime = [&](llm_graph_result_ptr & res_ptr, ggml_backend_sched_ptr & sched_ptr) { + if (!res_ptr) { + res_ptr.reset(new llm_graph_result(this->graph_max_nodes(std::min(cparams.n_ctx, cparams.n_ubatch)))); + } + if (!sched_ptr) { + sched_ptr.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), + res_ptr->get_max_nodes(), cparams.pipeline_parallel, + cparams.op_offload)); + } + }; - auto * res = gf_res_prev.get(); + llm_graph_result * res = gf_res_prev.get(); + ggml_backend_sched_t sched_use = sched.get(); + if (model.arch == LLM_ARCH_DFLASH_DRAFT) { + if (pending_dflash_kv_update_only) { + ensure_dflash_runtime(dflash_res_kv, dflash_sched_kv); + res = dflash_res_kv.get(); + sched_use = dflash_sched_kv.get(); + } else if (pending_dflash_fuse_only) { + ensure_dflash_runtime(dflash_res_fuse, dflash_sched_fuse); + res = dflash_res_fuse.get(); + sched_use = dflash_sched_fuse.get(); + } else if (dflash_draft_top_k_req > 0 && pending_target_feat_tensor != nullptr) { + ensure_dflash_runtime(dflash_res_draft, dflash_sched_draft); + res = dflash_res_draft.get(); + sched_use = dflash_sched_draft.get(); + } + } auto * gf = res->get_gf(); // the new graph parameters // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters - const auto gparams = graph_params(res, ubatch, mctx, gtype); + const auto gparams = graph_params(res, ubatch, mctx, gtype, sched_use); - if (!graph_reuse_disable && res->can_reuse(gparams)) { + const bool force_rebuild = model.arch == LLM_ARCH_DFLASH_DRAFT && + (pending_dflash_kv_update_only || pending_dflash_fuse_only); + if (!force_rebuild && !graph_reuse_disable && res->can_reuse(gparams)) { //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); // with pipeline parallelism, the previous graph_compute_async may still be running @@ -1193,11 +1543,13 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } n_reused++; + reused_graph = true; } else { + t0_us = profile_dflash ? 
ggml_time_us() : 0; res->reset(); - ggml_backend_sched_reset(sched.get()); - ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); + ggml_backend_sched_reset(sched_use); + ggml_backend_sched_set_eval_callback(sched_use, cparams.cb_eval, cparams.cb_eval_user_data); //const auto t_start_us = ggml_time_us(); @@ -1211,24 +1563,57 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll return nullptr; } - if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { + if (!ggml_backend_sched_alloc_graph(sched_use, gf)) { LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); ret = GGML_STATUS_ALLOC_FAILED; return nullptr; } + if (profile_dflash) { + t_build_alloc_us = ggml_time_us() - t0_us; + } + + if (profile_dflash_tree) { + int op_counts[GGML_OP_COUNT] = {}; + for (int i = 0; i < ggml_graph_n_nodes(gf); ++i) { + ggml_tensor * node = ggml_graph_node(gf, i); + if ((int)node->op >= 0 && node->op < GGML_OP_COUNT) { + op_counts[node->op]++; + } + } + LLAMA_LOG_INFO("dflash_tree_graph: tokens=%u outputs=%d nodes=%d " + "mul_mat=%d fgdn=%d ssm_conv=%d flash_attn=%d cpy=%d rms_norm=%d l2_norm=%d unary=%d soft_max=%d\n", + ubatch.n_tokens, + n_outputs, + ggml_graph_n_nodes(gf), + op_counts[GGML_OP_MUL_MAT], + op_counts[GGML_OP_GATED_DELTA_NET], + op_counts[GGML_OP_SSM_CONV], + op_counts[GGML_OP_FLASH_ATTN_EXT], + op_counts[GGML_OP_CPY], + op_counts[GGML_OP_RMS_NORM], + op_counts[GGML_OP_L2_NORM], + op_counts[GGML_OP_UNARY], + op_counts[GGML_OP_SOFT_MAX]); + } } // set the input data for the input tensors { - //const auto t_start_us = ggml_time_us(); + t0_us = profile_dflash ? ggml_time_us() : 0; // FIXME this call causes a crash if any model inputs were not used in the graph and were therefore not allocated res->set_inputs(&ubatch); - //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); + if (profile_dflash) { + t_set_inputs_us = ggml_time_us() - t0_us; + } } - const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1); + t0_us = profile_dflash ? ggml_time_us() : 0; + const auto status = graph_compute(sched_use, res->get_gf(), ubatch.n_tokens > 1); + if (profile_dflash) { + t_compute_us = ggml_time_us() - t0_us; + } if (status != GGML_STATUS_SUCCESS) { LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status); ret = status; @@ -1237,6 +1622,29 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll ret = GGML_STATUS_SUCCESS; + if (profile_dflash_draft) { + LLAMA_LOG_INFO("dflash_draft_ubatch_timing: tokens=%u outputs=%d ctx_len=%" PRId64 " reused=%d apply=%.3f build_alloc=%.3f set_inputs=%.3f compute=%.3f total=%.3f ms\n", + ubatch.n_tokens, + n_outputs, + pending_target_feat_ctx_len, + reused_graph ? 1 : 0, + t_apply_us / 1000.0, + t_build_alloc_us / 1000.0, + t_set_inputs_us / 1000.0, + t_compute_us / 1000.0, + (ggml_time_us() - t_total_start_us) / 1000.0); + } else if (profile_dflash_tree) { + LLAMA_LOG_INFO("dflash_tree_ubatch_timing: tokens=%u outputs=%d reused=%d apply=%.3f build_alloc=%.3f set_inputs=%.3f compute=%.3f total=%.3f ms\n", + ubatch.n_tokens, + n_outputs, + reused_graph ? 1 : 0, + t_apply_us / 1000.0, + t_build_alloc_us / 1000.0, + t_set_inputs_us / 1000.0, + t_compute_us / 1000.0, + (ggml_time_us() - t_total_start_us) / 1000.0); + } + return res; } @@ -1316,7 +1724,7 @@ int llama_context::encode(const llama_batch & batch_inp) { auto * t_embd = res->get_embd_pooled() ? 
res->get_embd_pooled() : res->get_embd(); // extract logits - if (logits.data && t_logits) { + if (logits.data && t_logits && dflash_draft_top_k_req <= 0 && needs_raw_logits(ubatch, sampling.samplers)) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(logits.data != nullptr); @@ -1324,6 +1732,70 @@ int llama_context::encode(const llama_batch & batch_inp) { ggml_backend_tensor_get_async(backend_res, t_logits, logits.data, 0, n_tokens*n_vocab*sizeof(float)); } + // Copy backend sampling output if this ubatch produced any sampling tensors. + if (!sampling.samplers.empty() && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || + !res->t_sampled_logits.empty() || !res->t_candidates.empty())) { + const auto seq_to_output_row = build_seq_to_output_row(ubatch, 0); + const auto stride = n_vocab; + + copy_tensor_async_ints(res->t_sampled, sampling.sampled, seq_to_output_row, sched.get()); + copy_tensor_async_floats(res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, + seq_to_output_row, sched.get()); + copy_tensor_async_floats(res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, + seq_to_output_row, sched.get()); + copy_tensor_async_candidates(res->t_candidates, sampling.candidates, stride, sampling.candidates_count, + seq_to_output_row, sched.get()); + } + + // dflash-draft top-K output: pull compact [K, n_tokens] tensors into host memory. + if (model.arch == LLM_ARCH_DFLASH_DRAFT && res->get_dflash_top_logits() != nullptr && + res->get_dflash_top_ids() != nullptr) { + ggml_tensor * t_top_logits = res->get_dflash_top_logits(); + ggml_tensor * t_top_ids = res->get_dflash_top_ids(); + ggml_backend_t backend_logits = ggml_backend_sched_get_tensor_backend(sched.get(), t_top_logits); + ggml_backend_t backend_ids = ggml_backend_sched_get_tensor_backend(sched.get(), t_top_ids); + GGML_ASSERT(backend_logits != nullptr); + GGML_ASSERT(backend_ids != nullptr); + + const int64_t top_k = t_top_logits->ne[0]; + const int64_t rows = t_top_logits->ne[1]; + GGML_ASSERT(t_top_ids->ne[0] == top_k && t_top_ids->ne[1] == rows); + + dflash_draft_top_k = (int32_t) top_k; + dflash_draft_top_rows = (int32_t) rows; + dflash_draft_top_logits.resize((size_t) top_k * rows); + dflash_draft_top_token_ids.resize((size_t) top_k * rows); + + ggml_backend_tensor_get_async(backend_logits, t_top_logits, dflash_draft_top_logits.data(), 0, + ggml_nbytes(t_top_logits)); + ggml_backend_tensor_get_async(backend_ids, t_top_ids, dflash_draft_top_token_ids.data(), 0, + ggml_nbytes(t_top_ids)); + } else { + dflash_draft_top_k = 0; + dflash_draft_top_rows = 0; + dflash_draft_top_logits.clear(); + dflash_draft_top_token_ids.clear(); + } + + if (capture_hidden && res->t_hidden_capture != nullptr) { + ggml_tensor * t_cap = res->t_hidden_capture; + hidden_capture_ne0 = t_cap->ne[0]; + hidden_capture_ne1 = t_cap->ne[1]; + const char * direct = std::getenv("LLAMA_DDTREE_CAPTURE_DIRECT"); + if (direct != nullptr && direct[0] == '1') { + hidden_capture_host_valid = false; + } else { + ggml_backend_t backend_cap = ggml_backend_sched_get_tensor_backend(sched.get(), t_cap); + GGML_ASSERT(backend_cap != nullptr); + const size_t cap_n = ggml_nelements(t_cap); + if (hidden_capture_host.size() < cap_n) { + hidden_capture_host.resize(cap_n); + } + ggml_backend_tensor_get_async(backend_cap, t_cap, hidden_capture_host.data(), 0, cap_n * sizeof(float)); + hidden_capture_host_valid = true; + } + } + // extract embeddings if 
(embd.data && t_embd) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); @@ -1362,49 +1834,1162 @@ int llama_context::encode(const llama_batch & batch_inp) { // extract the rerank score - n_cls_out floats per sequence auto & embd_seq_out = embd_seq; - const uint32_t n_cls_out = hparams.n_cls_out; + const uint32_t n_cls_out = hparams.n_cls_out; + + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = ubatch.seq_id_unq[s]; + const int32_t seq_idx = ubatch.seq_idx[seq_id]; + + embd_seq_out[seq_id].resize(n_cls_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ABORT("unknown pooling type"); + } + } + } + + // TODO: hacky solution + if (model.arch == LLM_ARCH_T5 && t_embd) { + //cross.t_embd = t_embd; + + synchronize(); + + cross.n_embd = t_embd->ne[0]; + cross.n_enc = t_embd->ne[1]; + cross.v_embd.resize(cross.n_embd*cross.n_enc); + memcpy(cross.v_embd.data(), embd.data, ggml_nbytes(t_embd)); + + const auto & batch = balloc->get_batch(); + + // remember the sequence ids used during the encoding - needed for cross attention later + cross.seq_ids_enc.resize(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + cross.seq_ids_enc[i].clear(); + + for (int s = 0; s < batch.n_seq_id[i]; s++) { + const llama_seq_id seq_id = batch.seq_id[i][s]; + + cross.seq_ids_enc[i].insert(seq_id); + } + } + } + + return 0; +} + +int llama_context::dflash_draft_fuse_target_feat( + const float * target_feat_raw, + int64_t n_embd_fc, + int64_t ctx_len, + float * target_feat_fused) { + if (model.arch != LLM_ARCH_DFLASH_DRAFT || target_feat_raw == nullptr || target_feat_fused == nullptr || + n_embd_fc <= 0 || ctx_len <= 0) { + return -1; + } + + set_target_feat_raw(target_feat_raw, n_embd_fc, ctx_len, 0); + pending_dflash_fuse_only = true; + set_dflash_draft_top_k(0); + + const auto & hparams = model.hparams; + const int64_t n_embd = hparams.n_embd_inp(); + std::vector dummy_embd((size_t) n_embd, 0.0f); + llama_pos pos = 0; + int32_t n_seq_id = 1; + llama_seq_id seq_id_value = 0; + llama_seq_id * seq_id = &seq_id_value; + int8_t output = 1; + + llama_batch batch{}; + batch.n_tokens = 1; + batch.token = nullptr; + batch.embd = dummy_embd.data(); + batch.pos = &pos; + batch.n_seq_id = &n_seq_id; + batch.seq_id = &seq_id; + batch.logits = &output; + + if (!balloc->init(batch, model.vocab, nullptr, n_embd, cparams.kv_unified ? 
LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { + LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); + pending_dflash_fuse_only = false; + return -1; + } + + const uint32_t n_tokens = balloc->get_n_tokens(); + const llama_ubatch ubatch = balloc->split_simple(n_tokens); + + if (t_compute_start_us == 0) { + t_compute_start_us = ggml_time_us(); + } + + embd_seq.clear(); + sched_reserve(); + n_queued_tokens += n_tokens; + n_outputs = n_tokens; + + const bool causal_attn_org = cparams.causal_attn; + cparams.causal_attn = false; + + ggml_status status; + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); + + cparams.causal_attn = causal_attn_org; + pending_dflash_fuse_only = false; + + if (!res) { + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); + } + } + + ggml_tensor * t_fused = res->get_embd(); + if (t_fused == nullptr || t_fused->ne[0] != n_embd || t_fused->ne[1] != ctx_len) { + return -3; + } + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_fused); + GGML_ASSERT(backend != nullptr); + ggml_backend_tensor_get_async(backend, t_fused, target_feat_fused, 0, ggml_nbytes(t_fused)); + synchronize(); + + return 0; +} + +static ggml_backend_t dflash_tensor_backend( + const std::vector & backends, + ggml_tensor * t) { + if (t == nullptr || t->buffer == nullptr) { + return nullptr; + } + + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(t->buffer); + ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); + + for (const auto & backend : backends) { + ggml_backend_t be = backend.get(); + if (be != nullptr && ggml_backend_get_device(be) == dev && ggml_backend_supports_buft(be, buft)) { + return be; + } + } + + return nullptr; +} + +static ggml_backend_buffer_type_t dflash_preferred_cache_buft( + const std::vector & backends, + ggml_backend_buffer_type_t fallback) { + if (!ggml_backend_buft_is_host(fallback)) { + return fallback; + } + + for (const auto & backend : backends) { + ggml_backend_t be = backend.get(); + if (be == nullptr) { + continue; + } + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(be); + if (buft != nullptr && !ggml_backend_buft_is_host(buft)) { + return buft; + } + } + + return fallback; +} + +static bool dflash_draft_kv_cache_enabled() { + const char * e = getenv("LLAMA_DFLASH_DRAFT_KV_CACHE"); + return e == nullptr || e[0] != '0'; +} + +static bool dflash_graph_copy_1d( + ggml_backend_t backend, + ggml_tensor * src, + ggml_tensor * dst, + int64_t ne, + size_t src_off, + size_t dst_off) { + struct ggml_init_params params = { + /* mem_size = */ ggml_tensor_overhead() * 8 + ggml_graph_overhead_custom(8, false), + /* mem_buffer = */ nullptr, + /* no_alloc = */ true, + }; + ggml_context_ptr ctx { ggml_init(params) }; + if (!ctx) { + return false; + } + + ggml_cgraph * gf = ggml_new_graph_custom(ctx.get(), 8, false); + ggml_tensor * src_view = ggml_view_1d(ctx.get(), src, ne, src_off); + ggml_tensor * dst_view = ggml_view_1d(ctx.get(), dst, ne, dst_off); + ggml_tensor * out = ggml_cpy(ctx.get(), src_view, dst_view); + ggml_build_forward_expand(gf, out); + + return ggml_backend_graph_compute(backend, gf) == GGML_STATUS_SUCCESS; +} + +static bool dflash_graph_pack_capture( + ggml_backend_t backend, + ggml_tensor * src, + ggml_tensor * dst, + int64_t n_embd, + int64_t n_tokens, + const int32_t * dfs_indices, + int32_t done, + 
int64_t valid, + int64_t width) { + const int64_t n_copies = dfs_indices == nullptr ? 5 : 5 * valid; + const size_t src_elt = ggml_element_size(src); + const size_t dst_elt = ggml_element_size(dst); + struct ggml_init_params params = { + /* mem_size = */ ggml_tensor_overhead() * (size_t)(3 * n_copies + 8) + + ggml_graph_overhead_custom((size_t)(n_copies + 8), false), + /* mem_buffer = */ nullptr, + /* no_alloc = */ true, + }; + ggml_context_ptr ctx { ggml_init(params) }; + if (!ctx) { + return false; + } + + ggml_cgraph * gf = ggml_new_graph_custom(ctx.get(), (size_t)(n_copies + 8), false); + if (dfs_indices == nullptr) { + for (int64_t l = 0; l < 5; ++l) { + const size_t src_off = (size_t)(l * n_tokens + done) * n_embd * src_elt; + const size_t dst_off = (size_t)l * width * n_embd * dst_elt; + ggml_tensor * src_view = ggml_view_1d(ctx.get(), src, n_embd * valid, src_off); + ggml_tensor * dst_view = ggml_view_1d(ctx.get(), dst, n_embd * valid, dst_off); + ggml_build_forward_expand(gf, ggml_cpy(ctx.get(), src_view, dst_view)); + } + } else { + for (int64_t i = 0; i < valid; ++i) { + const int64_t src_col = (int64_t)dfs_indices[(int64_t)done + i]; + if (src_col < 0 || src_col >= n_tokens) { + return false; + } + for (int64_t l = 0; l < 5; ++l) { + const size_t src_off = (size_t)(l * n_tokens + src_col) * n_embd * src_elt; + const size_t dst_off = ((size_t)l * width + (size_t)i) * n_embd * dst_elt; + ggml_tensor * src_view = ggml_view_1d(ctx.get(), src, n_embd, src_off); + ggml_tensor * dst_view = ggml_view_1d(ctx.get(), dst, n_embd, dst_off); + ggml_build_forward_expand(gf, ggml_cpy(ctx.get(), src_view, dst_view)); + } + } + } + + return ggml_backend_graph_compute(backend, gf) == GGML_STATUS_SUCCESS; +} + +bool llama_context::dflash_draft_ensure_fused_cache_tensor( + int64_t n_embd, + int64_t cap, + ggml_backend_buffer_type_t buft) { + if (dflash_fused_cache != nullptr && dflash_fused_cache_n_embd == n_embd && + dflash_fused_cache_cap == cap && dflash_fused_cache->buffer != nullptr && + ggml_backend_buffer_get_type(dflash_fused_cache->buffer) == buft) { + return true; + } + + dflash_fused_cache = nullptr; + dflash_fused_cache_ctx.reset(); + dflash_fused_cache_buf.reset(); + + struct ggml_init_params params = { + /* mem_size = */ ggml_tensor_overhead() + 1024, + /* mem_buffer = */ nullptr, + /* no_alloc = */ true, + }; + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return false; + } + + ggml_tensor * cache = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, cap); + ggml_set_name(cache, "dflash_fused_target_feat_cache"); + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + ggml_free(ctx); + return false; + } + ggml_backend_buffer_clear(buf, 0); + + dflash_fused_cache_ctx.reset(ctx); + dflash_fused_cache_buf.reset(buf); + dflash_fused_cache = cache; + dflash_fused_cache_n_embd = n_embd; + dflash_fused_cache_cap = cap; + return true; +} + +bool llama_context::dflash_draft_ensure_packed_target_feat_tensor( + int64_t n_embd, + int64_t ctx_len, + ggml_backend_buffer_type_t buft) { + if (dflash_packed_target_feat != nullptr && dflash_packed_target_feat_n_embd == n_embd && + dflash_packed_target_feat_ctx_len == ctx_len && dflash_packed_target_feat->buffer != nullptr && + ggml_backend_buffer_get_type(dflash_packed_target_feat->buffer) == buft) { + return true; + } + + dflash_packed_target_feat = nullptr; + dflash_packed_target_feat_ctx.reset(); + dflash_packed_target_feat_buf.reset(); + + struct ggml_init_params params = { + /* mem_size = */ 
ggml_tensor_overhead() + 1024, + /* mem_buffer = */ nullptr, + /* no_alloc = */ true, + }; + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return false; + } + + ggml_tensor * packed = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ctx_len); + ggml_set_name(packed, "dflash_packed_target_feat"); + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + ggml_free(ctx); + return false; + } + ggml_backend_buffer_clear(buf, 0); + + dflash_packed_target_feat_ctx.reset(ctx); + dflash_packed_target_feat_buf.reset(buf); + dflash_packed_target_feat = packed; + dflash_packed_target_feat_n_embd = n_embd; + dflash_packed_target_feat_ctx_len = ctx_len; + return true; +} + +bool llama_context::dflash_draft_ensure_kv_cache_tensors( + int64_t n_embd_head, + int64_t n_head_kv, + int64_t cap, + ggml_backend_buffer_type_t buft) { + const int64_t n_layer = model.hparams.n_layer; + if ((int64_t)dflash_k_cache_l.size() == n_layer && dflash_kv_cache_head_dim == n_embd_head && + dflash_kv_cache_n_head_kv == n_head_kv && dflash_kv_cache_cap == cap && + !dflash_k_cache_l.empty() && dflash_k_cache_l[0] != nullptr && dflash_k_cache_l[0]->buffer != nullptr && + ggml_backend_buffer_get_type(dflash_k_cache_l[0]->buffer) == buft) { + return true; + } + + dflash_k_cache_l.clear(); + dflash_v_cache_l.clear(); + dflash_kv_cache_ctx.reset(); + dflash_kv_cache_buf.reset(); + + struct ggml_init_params params = { + /* mem_size = */ ggml_tensor_overhead() * (size_t)n_layer * 2 + 1024, + /* mem_buffer = */ nullptr, + /* no_alloc = */ true, + }; + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return false; + } + + dflash_k_cache_l.resize((size_t)n_layer, nullptr); + dflash_v_cache_l.resize((size_t)n_layer, nullptr); + for (int64_t il = 0; il < n_layer; ++il) { + dflash_k_cache_l[(size_t)il] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head_kv, cap); + dflash_v_cache_l[(size_t)il] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head_kv, cap); + ggml_format_name(dflash_k_cache_l[(size_t)il], "dflash_k_cache_%lld", (long long)il); + ggml_format_name(dflash_v_cache_l[(size_t)il], "dflash_v_cache_%lld", (long long)il); + } + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + ggml_free(ctx); + dflash_k_cache_l.clear(); + dflash_v_cache_l.clear(); + return false; + } + ggml_backend_buffer_clear(buf, 0); + + dflash_kv_cache_ctx.reset(ctx); + dflash_kv_cache_buf.reset(buf); + dflash_kv_cache_head_dim = n_embd_head; + dflash_kv_cache_n_head_kv = n_head_kv; + dflash_kv_cache_cap = cap; + return true; +} + +bool llama_context::dflash_draft_ensure_packed_kv_tensors( + int64_t n_embd_head, + int64_t n_head_kv, + int64_t ctx_len, + ggml_backend_buffer_type_t buft) { + const int64_t n_layer = model.hparams.n_layer; + if ((int64_t)dflash_k_packed_l.size() == n_layer && dflash_kv_packed_head_dim == n_embd_head && + dflash_kv_packed_n_head_kv == n_head_kv && dflash_kv_packed_ctx_len == ctx_len && + !dflash_k_packed_l.empty() && dflash_k_packed_l[0] != nullptr && dflash_k_packed_l[0]->buffer != nullptr && + ggml_backend_buffer_get_type(dflash_k_packed_l[0]->buffer) == buft) { + return true; + } + + dflash_k_packed_l.clear(); + dflash_v_packed_l.clear(); + dflash_kv_packed_ctx.reset(); + dflash_kv_packed_buf.reset(); + + struct ggml_init_params params = { + /* mem_size = */ ggml_tensor_overhead() * (size_t)n_layer * 2 + 1024, + /* mem_buffer = */ nullptr, + /* no_alloc = */ true, + }; + ggml_context * ctx = ggml_init(params); + 
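+    // (no_alloc context: per-layer K/V tensor storage is attached below via
+    // ggml_backend_alloc_ctx_tensors_from_buft)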
if (!ctx) { + return false; + } + + dflash_k_packed_l.resize((size_t)n_layer, nullptr); + dflash_v_packed_l.resize((size_t)n_layer, nullptr); + for (int64_t il = 0; il < n_layer; ++il) { + dflash_k_packed_l[(size_t)il] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head_kv, ctx_len); + dflash_v_packed_l[(size_t)il] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head_kv, ctx_len); + ggml_format_name(dflash_k_packed_l[(size_t)il], "dflash_k_packed_%lld", (long long)il); + ggml_format_name(dflash_v_packed_l[(size_t)il], "dflash_v_packed_%lld", (long long)il); + } + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + ggml_free(ctx); + dflash_k_packed_l.clear(); + dflash_v_packed_l.clear(); + return false; + } + ggml_backend_buffer_clear(buf, 0); + + dflash_kv_packed_ctx.reset(ctx); + dflash_kv_packed_buf.reset(buf); + dflash_kv_packed_head_dim = n_embd_head; + dflash_kv_packed_n_head_kv = n_head_kv; + dflash_kv_packed_ctx_len = ctx_len; + return true; +} + +bool llama_context::dflash_draft_ensure_top_output_tensors( + int64_t top_k, + int64_t rows, + ggml_backend_buffer_type_t buft) { + if (dflash_top_logits_fixed != nullptr && dflash_top_ids_fixed != nullptr && + dflash_top_output_k == top_k && dflash_top_output_rows == rows && + dflash_top_logits_fixed->buffer != nullptr && dflash_top_ids_fixed->buffer != nullptr && + ggml_backend_buffer_get_type(dflash_top_logits_fixed->buffer) == buft) { + return true; + } + + dflash_top_logits_fixed = nullptr; + dflash_top_ids_fixed = nullptr; + dflash_top_output_ctx.reset(); + dflash_top_output_buf.reset(); + + struct ggml_init_params params = { + /* mem_size = */ ggml_tensor_overhead() * 2 + 1024, + /* mem_buffer = */ nullptr, + /* no_alloc = */ true, + }; + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return false; + } + + ggml_tensor * top_logits = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, top_k, rows); + ggml_tensor * top_ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, top_k, rows); + ggml_set_name(top_logits, "dflash_top_logits_fixed"); + ggml_set_name(top_ids, "dflash_top_ids_fixed"); + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + ggml_free(ctx); + return false; + } + ggml_backend_buffer_clear(buf, 0); + + dflash_top_output_ctx.reset(ctx); + dflash_top_output_buf.reset(buf); + dflash_top_logits_fixed = top_logits; + dflash_top_ids_fixed = top_ids; + dflash_top_output_k = top_k; + dflash_top_output_rows = rows; + return true; +} + +int llama_context::dflash_draft_update_fused_cache( + const float * target_feat_raw, + int64_t n_embd_fc, + int64_t n_new, + int64_t first_pos, + int64_t cap) { + if (model.arch != LLM_ARCH_DFLASH_DRAFT || target_feat_raw == nullptr || + n_embd_fc <= 0 || n_new <= 0 || cap <= 0 || n_embd_fc % 5 != 0) { + return -1; + } + + if (n_new > cap) { + const int64_t skip = n_new - cap; + target_feat_raw += (size_t)skip * n_embd_fc; + first_pos += skip; + n_new = cap; + } + + set_target_feat_raw(target_feat_raw, n_embd_fc, n_new, 0); + pending_dflash_fuse_only = true; + set_dflash_draft_top_k(0); + + const int64_t n_embd = n_embd_fc / 5; + std::vector dummy_embd((size_t) n_embd, 0.0f); + llama_pos pos = 0; + int32_t n_seq_id = 1; + llama_seq_id seq_id_value = 0; + llama_seq_id * seq_id = &seq_id_value; + int8_t output = 1; + + llama_batch batch{}; + batch.n_tokens = 1; + batch.token = nullptr; + batch.embd = dummy_embd.data(); + batch.pos = &pos; + batch.n_seq_id = &n_seq_id; + batch.seq_id = &seq_id; + 
batch.logits = &output; + + if (!balloc->init(batch, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { + pending_dflash_fuse_only = false; + LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); + return -1; + } + + const uint32_t n_tokens = balloc->get_n_tokens(); + const llama_ubatch ubatch = balloc->split_simple(n_tokens); + + if (t_compute_start_us == 0) { + t_compute_start_us = ggml_time_us(); + } + + embd_seq.clear(); + sched_reserve(); + n_queued_tokens += n_tokens; + n_outputs = n_tokens; + + const bool causal_attn_org = cparams.causal_attn; + cparams.causal_attn = false; + + ggml_status status; + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); + + cparams.causal_attn = causal_attn_org; + pending_dflash_fuse_only = false; + + if (!res) { + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); + } + } + + ggml_tensor * t_fused = res->get_embd(); + if (t_fused == nullptr || t_fused->ne[0] != n_embd || t_fused->ne[1] != n_new || t_fused->buffer == nullptr) { + return -3; + } + + ggml_backend_buffer_type_t buft = dflash_preferred_cache_buft(backends, ggml_backend_buffer_get_type(t_fused->buffer)); + if (!dflash_draft_ensure_fused_cache_tensor(n_embd, cap, buft)) { + LLAMA_LOG_ERROR("%s: failed to allocate fused target_feat cache\n", __func__); + return -2; + } + + ggml_backend_t src_backend = dflash_tensor_backend(backends, t_fused); + ggml_backend_t dst_backend = dflash_tensor_backend(backends, dflash_fused_cache); + std::vector bounce; + + int64_t copied = 0; + while (copied < n_new) { + const int64_t dst_col = (first_pos + copied) % cap; + const int64_t n_part = std::min(n_new - copied, cap - dst_col); + const size_t src_off = (size_t) copied * n_embd * sizeof(float); + const size_t dst_off = (size_t) dst_col * n_embd * sizeof(float); + const size_t nbytes = (size_t) n_embd * n_part * sizeof(float); + if (src_backend != nullptr && src_backend == dst_backend) { + if (!dflash_graph_copy_1d(src_backend, t_fused, dflash_fused_cache, n_embd * n_part, src_off, dst_off)) { + LLAMA_LOG_ERROR("%s: failed to copy fused target_feat into device cache\n", __func__); + return -3; + } + } else { + bounce.resize((size_t) n_embd * n_part); + ggml_backend_tensor_get(t_fused, bounce.data(), src_off, nbytes); + ggml_backend_tensor_set(dflash_fused_cache, bounce.data(), dst_off, nbytes); + } + copied += n_part; + } + + if (dflash_draft_kv_cache_enabled() && + dflash_draft_update_kv_cache(nullptr, n_embd, n_new, first_pos, cap) != 0) { + LLAMA_LOG_ERROR("%s: failed to update draft K/V cache\n", __func__); + return -3; + } + + return 0; +} + +int llama_context::dflash_draft_update_fused_cache_from_capture( + llama_context * target_ctx, + const int32_t * dfs_indices, + int32_t n_dfs, + int64_t first_pos, + int64_t cap) { + if (model.arch != LLM_ARCH_DFLASH_DRAFT || target_ctx == nullptr || n_dfs <= 0 || cap <= 0) { + return -1; + } + + ggml_tensor * t_cap = target_ctx->get_hidden_capture(); + if (t_cap == nullptr || t_cap->buffer == nullptr || t_cap->ne[0] <= 0 || t_cap->ne[1] <= 0 || t_cap->ne[1] % 5 != 0) { + return -1; + } + + const int64_t n_embd = t_cap->ne[0]; + const int64_t n_tokens = t_cap->ne[1] / 5; + const int64_t n_embd_fc = 5 * n_embd; + const int64_t update_chunk = std::min(16, cap); + + ggml_backend_buffer_type_t src_buft = 
ggml_backend_buffer_get_type(t_cap->buffer); + ggml_backend_buffer_type_t buft = dflash_preferred_cache_buft(backends, src_buft); + if (!dflash_draft_ensure_packed_target_feat_tensor(n_embd, 5*update_chunk, buft)) { + return -2; + } + if (!dflash_draft_ensure_fused_cache_tensor(n_embd, cap, buft)) { + return -2; + } + + ggml_backend_t src_backend = dflash_tensor_backend(target_ctx->backends, t_cap); + ggml_backend_t dst_backend = dflash_tensor_backend(backends, dflash_packed_target_feat); + if (dst_backend == nullptr) { + return -3; + } + + std::vector zeros((size_t)n_embd * 5 * update_chunk, 0.0f); + std::vector bounce((size_t)n_embd); + + int32_t done = 0; + while (done < n_dfs) { + const int64_t dst_col = (first_pos + done) % cap; + const int64_t valid = std::min({(int64_t)n_dfs - done, update_chunk, cap - dst_col}); + const int64_t width = valid; + if (valid <= 0 || width <= 0) { + return -3; + } + if (!dflash_draft_ensure_packed_target_feat_tensor(n_embd, 5*width, buft)) { + return -2; + } + dst_backend = dflash_tensor_backend(backends, dflash_packed_target_feat); + if (dst_backend == nullptr) { + return -3; + } + + if (width > valid) { + ggml_backend_tensor_set(dflash_packed_target_feat, zeros.data(), 0, + (size_t)n_embd * 5 * width * sizeof(float)); + } + + if (dfs_indices == nullptr) { + if (src_backend != nullptr && src_backend == dst_backend) { + if (!dflash_graph_pack_capture(src_backend, t_cap, dflash_packed_target_feat, + n_embd, n_tokens, nullptr, done, valid, width)) { + return -3; + } + } else { + for (int64_t l = 0; l < 5; ++l) { + const size_t src_off = (size_t)(l * n_tokens + done) * n_embd * sizeof(float); + const size_t dst_off = (size_t)l * width * n_embd * sizeof(float); + const size_t nbytes = (size_t)n_embd * valid * sizeof(float); + std::vector layer_bounce((size_t)n_embd * valid); + ggml_backend_tensor_get(t_cap, layer_bounce.data(), src_off, nbytes); + ggml_backend_tensor_set(dflash_packed_target_feat, layer_bounce.data(), dst_off, nbytes); + } + } + } else { + if (src_backend != nullptr && src_backend == dst_backend) { + if (!dflash_graph_pack_capture(src_backend, t_cap, dflash_packed_target_feat, + n_embd, n_tokens, dfs_indices, done, valid, width)) { + return -3; + } + } else { + for (int64_t i = 0; i < valid; ++i) { + const int64_t src_col = (int64_t)dfs_indices[done + i]; + if (src_col < 0 || src_col >= n_tokens) { + return -1; + } + for (int64_t l = 0; l < 5; ++l) { + const size_t src_off = (size_t)(l * n_tokens + src_col) * n_embd * sizeof(float); + const size_t dst_off = ((size_t)l * width + (size_t)i) * n_embd * sizeof(float); + ggml_backend_tensor_get(t_cap, bounce.data(), src_off, (size_t)n_embd * sizeof(float)); + ggml_backend_tensor_set(dflash_packed_target_feat, bounce.data(), dst_off, (size_t)n_embd * sizeof(float)); + } + } + } + } + + pending_target_feat_raw = nullptr; + pending_target_feat_n_embd_fc = n_embd_fc; + pending_target_feat_ctx_len = width; + pending_draft_committed_pos = 0; + pending_target_feat_fused = false; + pending_dflash_fuse_only = true; + pending_dflash_kv_update_only = false; + pending_target_feat_tensor = dflash_packed_target_feat; + set_dflash_draft_top_k(0); + + std::vector dummy_embd((size_t)n_embd, 0.0f); + llama_pos pos = 0; + int32_t n_seq_id = 1; + llama_seq_id seq_id_value = 0; + llama_seq_id * seq_id = &seq_id_value; + int8_t output = 1; + + llama_batch batch{}; + batch.n_tokens = 1; + batch.token = nullptr; + batch.embd = dummy_embd.data(); + batch.pos = &pos; + batch.n_seq_id = &n_seq_id; + batch.seq_id = 
&seq_id; + batch.logits = &output; + + if (!balloc->init(batch, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { + pending_dflash_fuse_only = false; + pending_target_feat_tensor = nullptr; + return -1; + } + + const uint32_t n_batch_tokens = balloc->get_n_tokens(); + const llama_ubatch ubatch = balloc->split_simple(n_batch_tokens); + if (t_compute_start_us == 0) { + t_compute_start_us = ggml_time_us(); + } + embd_seq.clear(); + sched_reserve(); + n_queued_tokens += n_batch_tokens; + n_outputs = n_batch_tokens; + + const bool causal_attn_org = cparams.causal_attn; + cparams.causal_attn = false; + ggml_status status; + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); + cparams.causal_attn = causal_attn_org; + pending_dflash_fuse_only = false; + pending_target_feat_tensor = nullptr; + if (!res) { + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); + } + } + + ggml_tensor * t_fused = res->get_embd(); + if (t_fused == nullptr || t_fused->ne[0] != n_embd || t_fused->ne[1] != width || t_fused->buffer == nullptr) { + return -3; + } + + ggml_backend_t fused_backend = dflash_tensor_backend(backends, t_fused); + ggml_backend_t cache_backend = dflash_tensor_backend(backends, dflash_fused_cache); + if (fused_backend != nullptr && fused_backend == cache_backend) { + if (!dflash_graph_copy_1d(fused_backend, t_fused, dflash_fused_cache, + n_embd * valid, 0, (size_t)dst_col * n_embd * sizeof(float))) { + return -3; + } + } else { + std::vector fused_bounce((size_t)n_embd * valid); + ggml_backend_tensor_get(t_fused, fused_bounce.data(), 0, (size_t)n_embd * valid * sizeof(float)); + ggml_backend_tensor_set(dflash_fused_cache, fused_bounce.data(), + (size_t)dst_col * n_embd * sizeof(float), + (size_t)n_embd * valid * sizeof(float)); + } + + if (dflash_draft_kv_cache_enabled() && + dflash_draft_update_kv_cache(nullptr, n_embd, valid, first_pos + done, cap) != 0) { + return -3; + } + + done += (int32_t)valid; + } + + return 0; +} + +int llama_context::dflash_draft_update_kv_cache( + const float * target_feat_fused, + int64_t n_embd, + int64_t n_new, + int64_t first_pos, + int64_t cap) { + GGML_UNUSED(target_feat_fused); + + const int64_t n_embd_head = model.hparams.n_embd_head_k(); + const int64_t n_head_kv = model.hparams.n_head_kv(); + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(dflash_fused_cache->buffer); + if (!dflash_draft_ensure_kv_cache_tensors(n_embd_head, n_head_kv, cap, buft)) { + LLAMA_LOG_ERROR("%s: failed to allocate draft K/V cache\n", __func__); + return -2; + } + + const int64_t update_chunk = std::min(16, cap); + int64_t done = 0; + while (done < n_new) { + const int64_t src_col = (first_pos + done) % cap; + const int64_t valid = std::min({n_new - done, update_chunk, cap - src_col}); + const int64_t width = std::min(update_chunk, cap - src_col); + if (valid <= 0 || width <= 0) { + return -3; + } + if (!dflash_draft_ensure_packed_target_feat_tensor(n_embd, width, buft)) { + return -2; + } + + ggml_backend_t backend = dflash_tensor_backend(backends, dflash_fused_cache); + if (backend == nullptr || backend != dflash_tensor_backend(backends, dflash_packed_target_feat)) { + return -3; + } + if (!dflash_graph_copy_1d(backend, dflash_fused_cache, dflash_packed_target_feat, + n_embd * width, + (size_t)src_col * n_embd * sizeof(float), 0)) { + return -3; + } 
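+
+        // Stage a one-token kv_update_only decode below: the graph consumes the
+        // packed window prepared above, and the K/V rows it produces land in the
+        // draft cache starting at pending_dflash_kv_update_dst_pos.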
+ + pending_target_feat_raw = nullptr; + pending_target_feat_n_embd_fc = n_embd; + pending_target_feat_ctx_len = width; + pending_draft_committed_pos = 0; + pending_target_feat_fused = true; + pending_dflash_fuse_only = false; + pending_dflash_kv_update_only = true; + pending_dflash_kv_update_dst_pos = src_col; + pending_target_feat_tensor = dflash_packed_target_feat; + set_dflash_draft_top_k(0); + + std::vector dummy_embd((size_t) n_embd, 0.0f); + llama_pos pos = 0; + int32_t n_seq_id = 1; + llama_seq_id seq_id_value = 0; + llama_seq_id * seq_id = &seq_id_value; + int8_t output = 1; + llama_batch batch{}; + batch.n_tokens = 1; + batch.token = nullptr; + batch.embd = dummy_embd.data(); + batch.pos = &pos; + batch.n_seq_id = &n_seq_id; + batch.seq_id = &seq_id; + batch.logits = &output; + + if (!balloc->init(batch, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { + pending_dflash_kv_update_only = false; + return -1; + } + + const uint32_t n_tokens = balloc->get_n_tokens(); + const llama_ubatch ubatch = balloc->split_simple(n_tokens); + if (t_compute_start_us == 0) { + t_compute_start_us = ggml_time_us(); + } + embd_seq.clear(); + sched_reserve(); + n_queued_tokens += n_tokens; + n_outputs = n_tokens; + + const bool causal_attn_org = cparams.causal_attn; + cparams.causal_attn = false; + ggml_status status; + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); + cparams.causal_attn = causal_attn_org; + pending_dflash_kv_update_only = false; + pending_target_feat_tensor = nullptr; + if (!res) { + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); + } + } + + done += valid; + } + + return 0; +} + +int llama_context::dflash_draft_pack_kv_cache( + int64_t n_embd_head, + int64_t n_head_kv, + int64_t ctx_len, + int64_t ring_start, + int64_t cap) { + if (dflash_k_cache_l.empty() || dflash_v_cache_l.empty()) { + return -1; + } + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(dflash_k_cache_l[0]->buffer); + if (!dflash_draft_ensure_packed_kv_tensors(n_embd_head, n_head_kv, ctx_len, buft)) { + return -2; + } + ggml_backend_t backend = dflash_tensor_backend(backends, dflash_k_cache_l[0]); + if (backend == nullptr || backend != dflash_tensor_backend(backends, dflash_k_packed_l[0])) { + return -3; + } + + const int64_t per_tok = n_embd_head * n_head_kv; + for (size_t il = 0; il < dflash_k_cache_l.size(); ++il) { + int64_t copied = 0; + while (copied < ctx_len) { + const int64_t src_col = (ring_start + copied) % cap; + const int64_t n_part = std::min(ctx_len - copied, cap - src_col); + const size_t src_off = (size_t) src_col * per_tok * sizeof(float); + const size_t dst_off = (size_t) copied * per_tok * sizeof(float); + const int64_t ne = per_tok * n_part; + if (!dflash_graph_copy_1d(backend, dflash_k_cache_l[il], dflash_k_packed_l[il], ne, src_off, dst_off) || + !dflash_graph_copy_1d(backend, dflash_v_cache_l[il], dflash_v_packed_l[il], ne, src_off, dst_off)) { + return -3; + } + copied += n_part; + } + } + return 0; +} + +int llama_context::dflash_draft_encode_top_k_cached( + const llama_batch & batch_inp, + int64_t n_embd, + int64_t ctx_len, + int64_t ring_start, + int64_t cap, + int64_t committed_pos, + int32_t top_k) { + if (model.arch != LLM_ARCH_DFLASH_DRAFT || dflash_fused_cache == nullptr || + n_embd <= 0 || ctx_len <= 0 || cap <= 0 || top_k <= 0 || + n_embd 
+        return -1;
+    }
+
+    ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(dflash_fused_cache->buffer);
+    if (!dflash_draft_ensure_packed_target_feat_tensor(n_embd, ctx_len, buft)) {
+        LLAMA_LOG_ERROR("%s: failed to allocate packed target_feat tensor\n", __func__);
+        return -2;
+    }
+
+    ggml_backend_t backend = dflash_tensor_backend(backends, dflash_fused_cache);
+    if (backend == nullptr || backend != dflash_tensor_backend(backends, dflash_packed_target_feat)) {
+        LLAMA_LOG_ERROR("%s: packed target_feat backend mismatch\n", __func__);
+        return -3;
+    }
+
+    int64_t copied = 0;
+    while (copied < ctx_len) {
+        const int64_t src_col = (ring_start + copied) % cap;
+        const int64_t n_part = std::min(ctx_len - copied, cap - src_col);
+        const size_t src_off = (size_t) src_col * n_embd * sizeof(float);
+        const size_t dst_off = (size_t) copied  * n_embd * sizeof(float);
+        if (!dflash_graph_copy_1d(backend, dflash_fused_cache, dflash_packed_target_feat,
+                n_embd * n_part, src_off, dst_off)) {
+            LLAMA_LOG_ERROR("%s: failed to pack fused target_feat device window\n", __func__);
+            return -3;
+        }
+        copied += n_part;
+    }
+
+    if (dflash_draft_kv_cache_enabled()) {
+        const int64_t n_embd_head = model.hparams.n_embd_head_k();
+        const int64_t n_head_kv   = model.hparams.n_head_kv();
+        if (dflash_draft_pack_kv_cache(n_embd_head, n_head_kv, ctx_len, ring_start, cap) != 0) {
+            LLAMA_LOG_ERROR("%s: failed to pack draft K/V cache\n", __func__);
+            return -3;
+        }
+    } else {
+        dflash_kv_packed_ctx.reset();
+        dflash_kv_packed_buf.reset();
+        dflash_k_packed_l.clear();
+        dflash_v_packed_l.clear();
+        dflash_kv_packed_head_dim = 0;
+        dflash_kv_packed_n_head_kv = 0;
+        dflash_kv_packed_ctx_len = 0;
+    }
+
+    pending_target_feat_raw = nullptr;
+    pending_target_feat_n_embd_fc = n_embd;
+    pending_target_feat_ctx_len = ctx_len;
+    pending_draft_committed_pos = committed_pos;
+    pending_target_feat_fused = true;
+    pending_dflash_fuse_only = false;
+    pending_dflash_kv_update_only = false;
+    pending_target_feat_tensor = dflash_packed_target_feat;
+
+    return dflash_draft_encode_top_k_pending(batch_inp, top_k);
+}
+
+int llama_context::dflash_draft_encode_top_k(
+        const llama_batch & batch_inp,
+        const float * target_feat_raw,
+        int64_t n_embd_fc,
+        int64_t ctx_len,
+        int64_t committed_pos,
+        int32_t top_k) {
+    if (model.arch != LLM_ARCH_DFLASH_DRAFT || target_feat_raw == nullptr || top_k <= 0) {
+        return -1;
+    }
+
+    set_target_feat_raw(target_feat_raw, n_embd_fc, ctx_len, committed_pos);
+    return dflash_draft_encode_top_k_pending(batch_inp, top_k);
+}
+
+int llama_context::dflash_draft_encode_top_k_fused(
+        const llama_batch & batch_inp,
+        const float * target_feat_fused,
+        int64_t n_embd,
+        int64_t ctx_len,
+        int64_t committed_pos,
+        int32_t top_k) {
+    if (model.arch != LLM_ARCH_DFLASH_DRAFT || target_feat_fused == nullptr || top_k <= 0) {
+        return -1;
+    }
+
+    set_target_feat_fused(target_feat_fused, n_embd, ctx_len, committed_pos);
+    return dflash_draft_encode_top_k_pending(batch_inp, top_k);
+}
+
+int llama_context::dflash_draft_encode_top_k_pending(
+        const llama_batch & batch_inp,
+        int32_t top_k) {
+    if (model.arch != LLM_ARCH_DFLASH_DRAFT ||
+        (pending_target_feat_raw == nullptr && pending_target_feat_tensor == nullptr) || top_k <= 0) {
+        return -1;
+    }
+
+    set_dflash_draft_top_k(top_k);
+
+    const auto & hparams = model.hparams;
+    const int64_t n_embd_model = hparams.n_embd_inp();
+
+    if (pending_target_feat_fused && pending_target_feat_n_embd_fc != n_embd_model) {
+        LLAMA_LOG_ERROR("%s: fused target_feat width mismatch: got %lld expected %lld\n",
+                __func__, (long long)pending_target_feat_n_embd_fc, (long long)n_embd_model);
+        return -1;
+    }
+
+    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd_model, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
+        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+        return -1;
+    }
+
+    const uint32_t n_tokens = balloc->get_n_tokens();
+    const llama_ubatch ubatch = balloc->split_simple(n_tokens);

-            for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
-                const llama_seq_id seq_id = ubatch.seq_id_unq[s];
-                const int32_t seq_idx = ubatch.seq_idx[seq_id];
+    if (cparams.n_ubatch < n_tokens) {
+        LLAMA_LOG_ERROR("%s: encoder requires n_ubatch >= n_tokens\n", __func__);
+        return -1;
+    }

-                embd_seq_out[seq_id].resize(n_cls_out);
-                ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
-            }
-        } break;
-    case LLAMA_POOLING_TYPE_UNSPECIFIED:
-        {
-            GGML_ABORT("unknown pooling type");
-        }
-    }
+    if (t_compute_start_us == 0) {
+        t_compute_start_us = ggml_time_us();
     }

-    // TODO: hacky solution
-    if (model.arch == LLM_ARCH_T5 && t_embd) {
-        //cross.t_embd = t_embd;
+    embd_seq.clear();
+    sched_reserve();
+    n_queued_tokens += n_tokens;
+    n_outputs = n_tokens;

-        synchronize();
+    const bool causal_attn_org = cparams.causal_attn;
+    cparams.causal_attn = false;

-        cross.n_embd = t_embd->ne[0];
-        cross.n_enc = t_embd->ne[1];
-        cross.v_embd.resize(cross.n_embd*cross.n_enc);
-        memcpy(cross.v_embd.data(), embd.data, ggml_nbytes(t_embd));
+    ggml_status status;
+    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);

-        const auto & batch = balloc->get_batch();
+    cparams.causal_attn = causal_attn_org;

-        // remember the sequence ids used during the encoding - needed for cross attention later
-        cross.seq_ids_enc.resize(n_tokens);
-        for (uint32_t i = 0; i < n_tokens; i++) {
-            cross.seq_ids_enc[i].clear();
+    if (!res) {
+        switch (status) {
+            case GGML_STATUS_ABORTED:      return 2;
+            case GGML_STATUS_ALLOC_FAILED: return -2;
+            case GGML_STATUS_FAILED:       return -3;
+            case GGML_STATUS_SUCCESS:      GGML_ABORT("should not happen");
+        }
+    }

-            for (int s = 0; s < batch.n_seq_id[i]; s++) {
-                const llama_seq_id seq_id = batch.seq_id[i][s];
+    if (res->get_dflash_top_logits() == nullptr || res->get_dflash_top_ids() == nullptr) {
+        dflash_draft_top_k = 0;
+        dflash_draft_top_rows = 0;
+        dflash_draft_top_logits.clear();
+        dflash_draft_top_token_ids.clear();
+        return -3;
+    }

-                cross.seq_ids_enc[i].insert(seq_id);
-            }
+    ggml_tensor * t_top_logits = res->get_dflash_top_logits();
+    ggml_tensor * t_top_ids    = res->get_dflash_top_ids();
+    auto tensor_backend_from_any_sched = [&](ggml_tensor * t) -> ggml_backend_t {
+        ggml_backend_t be = ggml_backend_sched_get_tensor_backend(sched.get(), t);
+        if (be == nullptr && dflash_sched_draft) {
+            be = ggml_backend_sched_get_tensor_backend(dflash_sched_draft.get(), t);
+        }
+        if (be == nullptr && dflash_sched_fuse) {
+            be = ggml_backend_sched_get_tensor_backend(dflash_sched_fuse.get(), t);
+        }
+        if (be == nullptr && dflash_sched_kv) {
+            be = ggml_backend_sched_get_tensor_backend(dflash_sched_kv.get(), t);
+        }
+        return be;
+    };
+    ggml_backend_t backend_logits = tensor_backend_from_any_sched(t_top_logits);
+    ggml_backend_t backend_ids    = tensor_backend_from_any_sched(t_top_ids);
+    GGML_ASSERT(backend_logits != nullptr);
+    GGML_ASSERT(backend_ids != nullptr);
+
+    const int64_t graph_top_k = t_top_logits->ne[0];
+    const int64_t rows        = t_top_logits->ne[1];
+    GGML_ASSERT(t_top_ids->ne[0] == graph_top_k && t_top_ids->ne[1] == rows);
+
+    ggml_backend_buffer_type_t out_buft = ggml_backend_buffer_get_type(t_top_logits->buffer);
+    if (!dflash_draft_ensure_top_output_tensors(graph_top_k, rows, out_buft)) {
+        return -2;
+    }
+
+    ggml_backend_t backend_fixed_logits = dflash_tensor_backend(backends, dflash_top_logits_fixed);
+    ggml_backend_t backend_fixed_ids    = dflash_tensor_backend(backends, dflash_top_ids_fixed);
+    if (backend_fixed_logits == nullptr || backend_fixed_ids == nullptr) {
+        return -3;
+    }
+    if (backend_logits != nullptr && backend_logits == backend_fixed_logits) {
+        ggml_backend_tensor_copy_async(backend_logits, backend_fixed_logits, t_top_logits, dflash_top_logits_fixed);
+    } else {
+        std::vector<float> top_bounce((size_t)graph_top_k * rows);
+        ggml_backend_tensor_get(t_top_logits, top_bounce.data(), 0, ggml_nbytes(t_top_logits));
+        ggml_backend_tensor_set(dflash_top_logits_fixed, top_bounce.data(), 0, ggml_nbytes(t_top_logits));
+    }
+    if (backend_ids != nullptr && backend_ids == backend_fixed_ids) {
+        ggml_backend_tensor_copy_async(backend_ids, backend_fixed_ids, t_top_ids, dflash_top_ids_fixed);
+    } else {
+        std::vector<llama_token> id_bounce((size_t)graph_top_k * rows);
+        ggml_backend_tensor_get(t_top_ids, id_bounce.data(), 0, ggml_nbytes(t_top_ids));
+        ggml_backend_tensor_set(dflash_top_ids_fixed, id_bounce.data(), 0, ggml_nbytes(t_top_ids));
+    }
+
+    dflash_draft_top_k    = (int32_t) graph_top_k;
+    dflash_draft_top_rows = (int32_t) rows;
+    dflash_draft_top_logits.resize((size_t) graph_top_k * rows);
+    dflash_draft_top_token_ids.resize((size_t) graph_top_k * rows);
+
+    ggml_backend_tensor_get_async(backend_fixed_logits, dflash_top_logits_fixed, dflash_draft_top_logits.data(), 0,
+            ggml_nbytes(dflash_top_logits_fixed));
+    ggml_backend_tensor_get_async(backend_fixed_ids, dflash_top_ids_fixed, dflash_draft_top_token_ids.data(), 0,
+            ggml_nbytes(dflash_top_ids_fixed));
+
+    return 0;
+}
 
@@ -1538,6 +3123,19 @@ int llama_context::decode(const llama_batch & batch_inp) {
         return encode(batch_inp);
     }
 
+    // tree-mode batches are only supported for the Qwen3.5 hybrid architecture
+    if (batch_inp.parent_id != nullptr && model.arch != LLM_ARCH_QWEN35) {
+        LLAMA_LOG_ERROR("%s: parent_id (tree-mode batch) is only supported for LLM_ARCH_QWEN35, got arch=%d\n",
+                __func__, (int) model.arch);
+        return -1;
+    }
+
+    // Phase 2.4: ensure SSM persist buffers are large enough for this tree batch.
+    // Must happen before graph_params() so the pointer is valid when building the graph.
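+    // When the fast-rollback knob is off, no persist capacity is reserved here;
+    // verification is then expected to go through the recurrent snapshot/restore
+    // path (llama_seq_snapshot / llama_seq_restore) instead.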
+    if (batch_inp.parent_id != nullptr && llama_dflash_fast_rollback_enabled()) {
+        ensure_dflash_persist_capacity((int64_t)batch_inp.n_tokens);
+    }
+
     if (batch_inp.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
@@ -1735,7 +3333,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
     }
 
     // extract logits
-    if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) {
+    if (logits.data && t_logits && n_outputs > 0 && dflash_draft_top_k_req <= 0 &&
+        needs_raw_logits(ubatch, sampling.samplers)) {
         ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
         GGML_ASSERT(backend_res != nullptr);
         GGML_ASSERT(logits.data != nullptr);
@@ -1749,6 +3348,61 @@ int llama_context::decode(const llama_batch & batch_inp) {
         }
     }
 
+    if (dflash_draft_top_k_req > 0 && res->get_dflash_top_ids() != nullptr) {
+        ggml_tensor * t_top_logits = res->get_dflash_top_logits();
+        ggml_tensor * t_top_ids    = res->get_dflash_top_ids();
+        ggml_backend_t backend_ids = ggml_backend_sched_get_tensor_backend(sched.get(), t_top_ids);
+        ggml_backend_t backend_logits = t_top_logits != nullptr
+            ? ggml_backend_sched_get_tensor_backend(sched.get(), t_top_logits)
+            : nullptr;
+        GGML_ASSERT(backend_ids != nullptr);
+        GGML_ASSERT(t_top_logits == nullptr || backend_logits != nullptr);
+
+        const int64_t top_k = t_top_ids->ne[0];
+        const int64_t rows  = t_top_ids->ne[1];
+        GGML_ASSERT(t_top_logits == nullptr || (t_top_logits->ne[0] == top_k && t_top_logits->ne[1] == rows));
+        GGML_ASSERT(rows == n_outputs);
+
+        if (n_outputs_prev == 0) {
+            dflash_draft_top_k    = (int32_t) top_k;
+            dflash_draft_top_rows = (int32_t) n_outputs_all;
+            dflash_draft_top_logits.resize((size_t) top_k * n_outputs_all);
+            dflash_draft_top_token_ids.resize((size_t) top_k * n_outputs_all);
+        }
+
+        GGML_ASSERT(dflash_draft_top_k == (int32_t) top_k);
+        float       * top_logits_out = dflash_draft_top_logits.data()    + (size_t) top_k * n_outputs_prev;
+        llama_token * top_ids_out    = dflash_draft_top_token_ids.data() + (size_t) top_k * n_outputs_prev;
+        if (t_top_logits != nullptr) {
+            ggml_backend_tensor_get_async(backend_logits, t_top_logits, top_logits_out, 0, ggml_nbytes(t_top_logits));
+        }
+        ggml_backend_tensor_get_async(backend_ids, t_top_ids, top_ids_out, 0, ggml_nbytes(t_top_ids));
+    } else if (dflash_draft_top_k_req > 0 && n_outputs_prev == 0) {
+        dflash_draft_top_k = 0;
+        dflash_draft_top_rows = 0;
+        dflash_draft_top_logits.clear();
+        dflash_draft_top_token_ids.clear();
+    }
+
+    if (capture_hidden && res->t_hidden_capture != nullptr) {
+        ggml_tensor * t_cap = res->t_hidden_capture;
+        hidden_capture_ne0 = t_cap->ne[0];
+        hidden_capture_ne1 = t_cap->ne[1];
+        const char * direct = std::getenv("LLAMA_DDTREE_CAPTURE_DIRECT");
+        if (direct != nullptr && direct[0] == '1') {
+            hidden_capture_host_valid = false;
+        } else {
+            ggml_backend_t backend_cap = ggml_backend_sched_get_tensor_backend(sched.get(), t_cap);
+            GGML_ASSERT(backend_cap != nullptr);
+            const size_t cap_n = ggml_nelements(t_cap);
+            if (hidden_capture_host.size() < cap_n) {
+                hidden_capture_host.resize(cap_n);
+            }
+            ggml_backend_tensor_get_async(backend_cap, t_cap, hidden_capture_host.data(), 0, cap_n * sizeof(float));
+            hidden_capture_host_valid = true;
+        }
+    }
+
     // extract embeddings
     if (embd.data && t_embd && n_outputs > 0) {
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
@@ -2149,29 +3803,61 @@ llm_graph_params llama_context::graph_params(
         llm_graph_result * res,
         const llama_ubatch & ubatch,
         const llama_memory_context_i * mctx,
-        llm_graph_type gtype) const {
+        llm_graph_type gtype,
+        ggml_backend_sched_t sched_use) const {
+    if (sched_use == nullptr) {
+        sched_use = sched.get();
+    }
     return {
         /*.arch        =*/ model.arch,
         /*.hparams     =*/ model.hparams,
         /*.cparams     =*/ cparams,
         /*.ubatch      =*/ ubatch,
         /*.gtype       =*/ gtype,
-        /*.sched       =*/ sched.get(),
+        /*.sched       =*/ sched_use,
         /*.backend_cpu =*/ backend_cpu,
         /*.cvec        =*/ cvec.get(),
         /*.loras       =*/ loras.get(),
         /*.mctx        =*/ mctx,
         /*.cross       =*/ &cross,
         /*.samplers    =*/ sampling.samplers,
-        /*.n_outputs   =*/ n_outputs,
-        /*.cb          =*/ graph_get_cb(),
-        /*.res         =*/ res,
+        /*.n_outputs   =*/ n_outputs,
+        /*.cb          =*/ graph_get_cb(),
+        /*.res         =*/ res,
+        /*.capture_hidden =*/ capture_hidden,
+        // Wire pending_target_feat pointers so build_inp_target_feat() can read them.
+        // These are non-null only when the caller invoked llama_set_target_feat_raw().
+        /*.pending_target_feat_raw_ptr       =*/ &pending_target_feat_raw,
+        /*.pending_target_feat_n_embd_fc_ptr =*/ &pending_target_feat_n_embd_fc,
+        /*.pending_target_feat_ctx_len_ptr   =*/ &pending_target_feat_ctx_len,
+        /*.pending_draft_committed_pos_ptr   =*/ &pending_draft_committed_pos,
+        /*.pending_target_feat_tensor_ptr    =*/ &pending_target_feat_tensor,
+        /*.dflash_kv_cache_k_l     =*/ pending_dflash_kv_update_only ? &dflash_k_cache_l : &dflash_k_packed_l,
+        /*.dflash_kv_cache_v_l     =*/ pending_dflash_kv_update_only ? &dflash_v_cache_l : &dflash_v_packed_l,
+        /*.dflash_kv_cache_dst_pos =*/ pending_dflash_kv_update_dst_pos,
+        /*.dflash_target_feat_fused =*/ pending_target_feat_fused,
+        /*.dflash_kv_update_only    =*/ pending_dflash_kv_update_only,
+        /*.dflash_fuse_only         =*/ pending_dflash_fuse_only,
+        /*.dflash_draft_top_k       =*/ dflash_draft_top_k_req,
+        // Phase 2.4: pass persist buffer vector when in tree mode (parent_id is set).
+        // Non-null only after ensure_dflash_persist_capacity() ran in decode().
+        /*.dflash_persist_inter_l =*/ (!dflash_persist_inter_l.empty() && ubatch.parent_id != nullptr)
+            ? &dflash_persist_inter_l : nullptr,
+        /*.dflash_persist_conv_l  =*/ (!dflash_persist_conv_l.empty() && ubatch.parent_id != nullptr)
+            ? &dflash_persist_conv_l : nullptr,
     };
 }
 
 ggml_status llama_context::graph_compute(
             ggml_cgraph * gf,
             bool batched) {
+    return graph_compute(sched.get(), gf, batched);
+}
+
+ggml_status llama_context::graph_compute(
+        ggml_backend_sched_t sched_use,
+        ggml_cgraph * gf,
+        bool batched) {
     int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads;
     ggml_threadpool_t tp = batched ? threadpool_batch : threadpool;
@@ -2188,7 +3874,7 @@ ggml_status llama_context::graph_compute(
         set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
     }
 
-    auto status = ggml_backend_sched_graph_compute_async(sched.get(), gf);
+    auto status = ggml_backend_sched_graph_compute_async(sched_use, gf);
     if (status != GGML_STATUS_SUCCESS) {
         LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
     }
@@ -3083,6 +4769,132 @@ void llama_set_warmup(llama_context * ctx, bool warmup) {
     ctx->set_warmup(warmup);
 }
 
+void llama_set_capture_hidden(llama_context * ctx, bool enable) {
+    ctx->set_capture_hidden(enable);
+}
+
+void llama_set_dflash_draft_top_k(llama_context * ctx, int32_t k) {
+    ctx->set_dflash_draft_top_k(k);
+}
+
+ggml_tensor * llama_get_hidden_capture(llama_context * ctx) {
+    ctx->synchronize();
+    return ctx->get_hidden_capture();
+}
+
+const float * llama_get_hidden_capture_data(llama_context * ctx, int64_t * out_ne0, int64_t * out_ne1) {
+    ctx->synchronize();
+    return ctx->get_hidden_capture_data(out_ne0, out_ne1);
+}
+
+void llama_set_target_feat_raw(llama_context * ctx,
+        const float * data,
+        int64_t n_embd_fc,
+        int64_t ctx_len,
+        int64_t committed_pos) {
+    ctx->set_target_feat_raw(data, n_embd_fc, ctx_len, committed_pos);
+}
+
+int llama_dflash_draft_fuse_target_feat(llama_context * ctx,
+        const float * target_feat_raw,
+        int64_t n_embd_fc,
+        int64_t ctx_len,
+        float * target_feat_fused) {
+    return ctx->dflash_draft_fuse_target_feat(target_feat_raw, n_embd_fc, ctx_len, target_feat_fused);
+}
+
+int llama_dflash_draft_encode_top_k(llama_context * ctx,
+        llama_batch batch,
+        const float * target_feat_raw,
+        int64_t n_embd_fc,
+        int64_t ctx_len,
+        int64_t committed_pos,
+        int32_t top_k) {
+    return ctx->dflash_draft_encode_top_k(batch, target_feat_raw, n_embd_fc, ctx_len, committed_pos, top_k);
+}
+
+int llama_dflash_draft_encode_top_k_fused(llama_context * ctx,
+        llama_batch batch,
+        const float * target_feat_fused,
+        int64_t n_embd,
+        int64_t ctx_len,
+        int64_t committed_pos,
+        int32_t top_k) {
+    return ctx->dflash_draft_encode_top_k_fused(batch, target_feat_fused, n_embd, ctx_len, committed_pos, top_k);
+}
+
+int llama_dflash_draft_update_fused_cache(llama_context * ctx,
+        const float * target_feat_raw,
+        int64_t n_embd_fc,
+        int64_t n_new,
+        int64_t first_pos,
+        int64_t cap) {
+    return ctx->dflash_draft_update_fused_cache(target_feat_raw, n_embd_fc, n_new, first_pos, cap);
+}
+
+int llama_dflash_draft_update_fused_cache_from_capture(llama_context * draft_ctx,
+        llama_context * target_ctx,
+        const int32_t * dfs_indices,
+        int32_t n_dfs,
+        int64_t first_pos,
+        int64_t cap) {
+    return draft_ctx->dflash_draft_update_fused_cache_from_capture(target_ctx, dfs_indices, n_dfs, first_pos, cap);
+}
+
+int llama_dflash_draft_encode_top_k_cached(llama_context * ctx,
+        llama_batch batch,
+        int64_t n_embd,
+        int64_t ctx_len,
+        int64_t ring_start,
+        int64_t cap,
+        int64_t committed_pos,
+        int32_t top_k) {
+    return ctx->dflash_draft_encode_top_k_cached(batch, n_embd, ctx_len, ring_start, cap, committed_pos, top_k);
+}
+
+void llama_dflash_ensure_persist_capacity(struct llama_context * ctx, int64_t n_tokens) {
+    ctx->ensure_dflash_persist_capacity(n_tokens);
+}
+
+bool llama_dflash_rollback_ssm_to_dfs(
+        struct llama_context * ctx,
+        llama_seq_id seq_id,
+        int32_t accepted_dfs_node) {
+    ctx->synchronize(); // ensure the tree-mode decode kernel has completed
+    return ctx->dflash_rollback_ssm_to_dfs(seq_id, accepted_dfs_node);
+}
+
+bool llama_dflash_set_recurrent_tail_pos(
+        struct llama_context * ctx,
+        llama_seq_id seq_id,
+        llama_pos pos) {
+    if (ctx == nullptr) {
+        return false;
+    }
+
+    auto * raw_mem = ctx->get_memory();
+    auto * mem_recr = dynamic_cast<llama_memory_recurrent *>(raw_mem);
+    if (!mem_recr) {
+        if (auto * hyb = dynamic_cast<llama_memory_hybrid *>(raw_mem)) {
+            mem_recr = hyb->get_mem_recr();
+        }
+    }
+    if (!mem_recr || seq_id < 0 || seq_id >= (llama_seq_id) mem_recr->cells.size()) {
+        return false;
+    }
+
+    const int32_t cell_id = mem_recr->cells[seq_id].tail;
+    if (cell_id < 0 || cell_id >= (int32_t) mem_recr->cells.size()) {
+        return false;
+    }
+    if (!mem_recr->cells[cell_id].has_seq_id(seq_id)) {
+        return false;
+    }
+
+    mem_recr->cells[cell_id].pos = pos;
+    return true;
+}
+
 void llama_synchronize(llama_context * ctx) {
     ctx->synchronize();
 }
@@ -3171,6 +4983,16 @@ uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) {
     return static_cast<uint32_t>(ctx->get_sampled_probs_count(i));
 }
 
+bool llama_get_dflash_draft_top_k(
+        llama_context * ctx,
+        const float ** logits,
+        const llama_token ** token_ids,
+        int32_t * n_rows,
+        int32_t * k) {
+    ctx->synchronize();
+    return ctx->get_dflash_draft_top_k(logits, token_ids, n_rows, k);
+}
+
 struct ggml_cgraph * llama_graph_reserve(
         struct llama_context * ctx,
         uint32_t n_tokens,
@@ -3317,6 +5139,345 @@ bool llama_memory_can_shift(llama_memory_t mem) {
     return mem->get_can_shift();
 }
 
+// snapshot/restore for recurrent (SSM + conv) state
+
+llama_mem_snapshot_id llama_seq_snapshot(struct llama_context * ctx, llama_seq_id seq_id) {
+    auto * raw_mem = ctx->get_memory();
+    auto * mem = dynamic_cast<llama_memory_recurrent *>(raw_mem);
+    if (!mem) {
+        if (auto * hyb = dynamic_cast<llama_memory_hybrid *>(raw_mem)) {
+            mem = hyb->get_mem_recr();
+        }
+    }
+    if (!mem) {
+        return LLAMA_MEM_SNAPSHOT_INVALID;
+    }
+    return mem->snapshot(seq_id);
+}
+
+bool llama_seq_restore(struct llama_context * ctx, llama_mem_snapshot_id snap_id) {
+    auto * raw_mem = ctx->get_memory();
+    auto * mem = dynamic_cast<llama_memory_recurrent *>(raw_mem);
+    if (!mem) {
+        if (auto * hyb = dynamic_cast<llama_memory_hybrid *>(raw_mem)) {
+            mem = hyb->get_mem_recr();
+        }
+    }
+    if (!mem) {
+        return false;
+    }
+    return mem->restore(snap_id);
+}
+
+void llama_seq_release(struct llama_context * ctx, llama_mem_snapshot_id snap_id) {
+    auto * raw_mem = ctx->get_memory();
+    auto * mem = dynamic_cast<llama_memory_recurrent *>(raw_mem);
+    if (!mem) {
+        if (auto * hyb = dynamic_cast<llama_memory_hybrid *>(raw_mem)) {
+            mem = hyb->get_mem_recr();
+        }
+    }
+    if (mem) {
+        mem->release(snap_id);
+    }
+}
+
+bool llama_context::dflash_rollback_ssm_to_dfs(llama_seq_id seq_id, int32_t accepted_dfs_node) {
+    if (dflash_persist_inter_l.empty()) {
+        LLAMA_LOG_WARN("%s: persist buffers not allocated (no tree-mode decode has run)\n", __func__);
+        return false;
+    }
+
+    // Resolve the recurrent memory module.
+    auto * raw_mem = memory.get();
+    auto * mem_recr = dynamic_cast<llama_memory_recurrent *>(raw_mem);
+    if (!mem_recr) {
+        if (auto * hyb = dynamic_cast<llama_memory_hybrid *>(raw_mem)) {
+            mem_recr = hyb->get_mem_recr();
+        }
+    }
+    if (!mem_recr) {
+        LLAMA_LOG_WARN("%s: no recurrent memory module; rollback is a no-op\n", __func__);
+        return false;
+    }
+
+    const auto & hparams = model.hparams;
+    const int32_t n_layer = (int32_t)hparams.n_layer;
+    const int32_t cell_id = (seq_id >= 0 && seq_id < (int32_t)mem_recr->cells.size())
+        ? mem_recr->cells[seq_id].tail : -1;
+    if (cell_id < 0) {
+        LLAMA_LOG_WARN("%s: seq_id=%d has no tail cell; rollback skipped\n", __func__, (int)seq_id);
+        return false;
+    }
+
+    const int64_t n_embd_s = (int64_t)hparams.n_embd_s();
+    const bool skip_s_rollback = []{
+        const char * e = getenv("LLAMA_DDTREE_ROLLBACK_SKIP_S");
+        return e && e[0] == '1';
+    }();
+
+    // Fast path: execute the rollback as a tiny backend graph so CUDA layers do
+    // not bounce every persist column through host memory. Keep the host path
+    // below as the exact fallback for mixed/offloaded or unsupported layouts.
+    const bool graph_rollback_enabled = []{
+        const char * e = getenv("LLAMA_DDTREE_ROLLBACK_GRAPH");
+        return !e || e[0] != '0';
+    }();
+
+    const bool skip_conv_rollback = []{
+        const char * e = getenv("LLAMA_DDTREE_ROLLBACK_SKIP_CONV");
+        return e && e[0] == '1';
+    }();
+
+    struct dflash_rollback_copy {
+        ggml_tensor * src;
+        ggml_tensor * dst;
+        int64_t ne;
+        size_t src_off;
+        size_t dst_off;
+    };
+
+    auto tensor_backend = [&](const ggml_tensor * t) -> ggml_backend_t {
+        if (t == nullptr || t->buffer == nullptr) {
+            return nullptr;
+        }
+
+        ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(t->buffer);
+        ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft);
+
+        for (auto & backend : backends) {
+            ggml_backend_t be = backend.get();
+            if (be != nullptr && ggml_backend_get_device(be) == dev && ggml_backend_supports_buft(be, buft)) {
+                return be;
+            }
+        }
+
+        return nullptr;
+    };
+
+    auto try_graph_rollback = [&]() -> bool {
+        if (!graph_rollback_enabled) {
+            return false;
+        }
+
+        std::vector<dflash_rollback_copy> copies;
+        copies.reserve((size_t)n_layer * 2);
+
+        ggml_backend_t graph_backend = nullptr;
+        auto add_copy = [&](ggml_tensor * src, ggml_tensor * dst, int64_t ne, size_t src_off, size_t dst_off) -> bool {
+            ggml_backend_t src_backend = tensor_backend(src);
+            ggml_backend_t dst_backend = tensor_backend(dst);
+            if (src_backend == nullptr || src_backend != dst_backend) {
+                return false;
+            }
+            if (graph_backend == nullptr) {
+                graph_backend = src_backend;
+            } else if (graph_backend != src_backend) {
+                return false;
+            }
+
+            copies.push_back({ src, dst, ne, src_off, dst_off });
+            return true;
+        };
+
+        if (!skip_s_rollback) {
+            for (int il = 0; il < n_layer; ++il) {
+                if (!hparams.is_recurrent(il)) { continue; }
+                ggml_tensor * persist = dflash_persist_inter_l[il];
+                ggml_tensor * s_state = (il < (int32_t)mem_recr->s_l.size()) ? mem_recr->s_l[il] : nullptr;
+                if (!persist || !s_state) { continue; }
+                if (accepted_dfs_node >= dflash_persist_max_n_tokens) { return false; }
+
+                if (!((persist->type == GGML_TYPE_F32 || persist->type == GGML_TYPE_F16) &&
+                      (s_state->type == GGML_TYPE_F32 || s_state->type == GGML_TYPE_F16))) {
+                    return false;
+                }
+
+                const size_t src_col_bytes = ggml_row_size(persist->type, n_embd_s);
+                const size_t dst_row_bytes = ggml_row_size(s_state->type, n_embd_s);
+                if (!add_copy(persist, s_state, n_embd_s,
+                        (size_t)accepted_dfs_node * src_col_bytes,
+                        (size_t)cell_id * dst_row_bytes)) {
+                    return false;
+                }
+            }
+        }
+
+        if (!skip_conv_rollback && !dflash_persist_conv_l.empty()) {
+            const int64_t n_embd_r = (int64_t)hparams.n_embd_r();
+            for (int il = 0; il < n_layer; ++il) {
+                if (!hparams.is_recurrent(il)) { continue; }
+                ggml_tensor * persist_conv = (il < (int32_t)dflash_persist_conv_l.size())
+                    ? dflash_persist_conv_l[il] : nullptr;
+                ggml_tensor * r_state = (il < (int32_t)mem_recr->r_l.size())
+                    ? mem_recr->r_l[il] : nullptr;
+                if (!persist_conv || !r_state) { continue; }
+                if (accepted_dfs_node >= dflash_persist_max_n_tokens) { return false; }
+                if (persist_conv->type != GGML_TYPE_F32 || r_state->type != GGML_TYPE_F32) {
+                    return false;
+                }
+
+                const size_t conv_col_bytes = (size_t)n_embd_r * sizeof(float);
+                const size_t r_row_bytes = ggml_row_size(r_state->type, n_embd_r);
+                if (!add_copy(persist_conv, r_state, n_embd_r,
+                        (size_t)accepted_dfs_node * conv_col_bytes,
+                        (size_t)cell_id * r_row_bytes)) {
+                    return false;
+                }
+            }
+        }
+
+        if (copies.empty()) {
+            return true;
+        }
+
+        const size_t graph_size = copies.size() * 4 + 16;
+        struct ggml_init_params params = {
+            /* mem_size   = */ ggml_tensor_overhead() * (copies.size() * 4 + 16) +
+                               ggml_graph_overhead_custom(graph_size, false),
+            /* mem_buffer = */ nullptr,
+            /* no_alloc   = */ true,
+        };
+        ggml_context_ptr ctx { ggml_init(params) };
+        if (!ctx) {
+            return false;
+        }
+
+        ggml_cgraph * gf = ggml_new_graph_custom(ctx.get(), graph_size, false);
+        for (const auto & copy : copies) {
+            ggml_tensor * src = ggml_view_1d(ctx.get(), copy.src, copy.ne, copy.src_off);
+            ggml_tensor * dst = ggml_view_1d(ctx.get(), copy.dst, copy.ne, copy.dst_off);
+            ggml_tensor * out = ggml_cpy(ctx.get(), src, dst);
+            ggml_build_forward_expand(gf, out);
+        }
+
+        ggml_backend_sched_synchronize(sched.get());
+        const ggml_status status = ggml_backend_graph_compute(graph_backend, gf);
+        if (status != GGML_STATUS_SUCCESS) {
+            LLAMA_LOG_WARN("%s: graph rollback failed with status %d; falling back to host copy\n",
+                    __func__, (int)status);
+            return false;
+        }
+
+        return true;
+    };
+
+    if (try_graph_rollback()) {
+        return true;
+    }
+
+    // Persist tensor may be F32 (correctness baseline) or F16 (memory-saving
+    // variant). s_state may be F32 (Qwen3.5 hybrid stores SSM in F32).
+    std::vector<ggml_fp16_t> bounce_f16((size_t)n_embd_s);
+    std::vector<float>       bounce_f32((size_t)n_embd_s);
+
+    if (!skip_s_rollback) {
+        for (int il = 0; il < n_layer; ++il) {
+            if (!hparams.is_recurrent(il)) { continue; }
+            ggml_tensor * persist = dflash_persist_inter_l[il];
+            ggml_tensor * s_state = (il < (int32_t)mem_recr->s_l.size()) ? mem_recr->s_l[il] : nullptr;
+            if (!persist || !s_state) { continue; }
+
+            if (accepted_dfs_node >= dflash_persist_max_n_tokens) {
+                LLAMA_LOG_WARN("%s: accepted_dfs_node=%d >= persist capacity=%lld at il=%d\n",
+                        __func__, (int)accepted_dfs_node,
+                        (long long)dflash_persist_max_n_tokens, il);
+                continue;
+            }
+
+            const size_t state_row_bytes = ggml_row_size(s_state->type, n_embd_s);
+            const size_t state_offset    = (size_t)cell_id * state_row_bytes;
+
+            if (persist->type == GGML_TYPE_F32) {
+                const size_t persist_col_bytes = (size_t)n_embd_s * sizeof(float);
+                const size_t persist_offset = (size_t)accepted_dfs_node * persist_col_bytes;
+                ggml_backend_tensor_get(persist, bounce_f32.data(), persist_offset, persist_col_bytes);
+
+                if (s_state->type == GGML_TYPE_F32) {
+                    ggml_backend_tensor_set(s_state, bounce_f32.data(), state_offset, state_row_bytes);
+                } else if (s_state->type == GGML_TYPE_F16) {
+                    ggml_fp32_to_fp16_row(bounce_f32.data(), bounce_f16.data(), n_embd_s);
+                    ggml_backend_tensor_set(s_state, bounce_f16.data(), state_offset, state_row_bytes);
+                } else {
+                    GGML_ABORT("dflash_rollback_ssm_to_dfs: unsupported s_state type");
+                }
+            } else if (persist->type == GGML_TYPE_F16) {
+                const size_t persist_col_bytes = (size_t)n_embd_s * sizeof(ggml_fp16_t);
+                const size_t persist_offset = (size_t)accepted_dfs_node * persist_col_bytes;
+                ggml_backend_tensor_get(persist, bounce_f16.data(), persist_offset, persist_col_bytes);
+
+                if (s_state->type == GGML_TYPE_F16) {
+                    ggml_backend_tensor_set(s_state, bounce_f16.data(), state_offset, state_row_bytes);
+                } else if (s_state->type == GGML_TYPE_F32) {
+                    ggml_fp16_to_fp32_row(bounce_f16.data(), bounce_f32.data(), n_embd_s);
+                    ggml_backend_tensor_set(s_state, bounce_f32.data(), state_offset, state_row_bytes);
+                } else {
+                    GGML_ABORT("dflash_rollback_ssm_to_dfs: unsupported s_state type");
+                }
+            } else {
+                GGML_ABORT("dflash_rollback_ssm_to_dfs: unsupported persist type");
+            }
+        }
+    }
+
+    // Phase 5 fix: also roll the conv state (r_l[il]) back to accepted_dfs_node.
+    // Without this, the conv window stays at the DFS-last node and pollutes the
+    // root forward of the next spec step.
+    if (!skip_conv_rollback && !dflash_persist_conv_l.empty()) {
+        const int64_t n_embd_r = (int64_t)hparams.n_embd_r();
+        std::vector<float> bounce_conv((size_t)n_embd_r);
+        for (int il = 0; il < n_layer; ++il) {
+            if (!hparams.is_recurrent(il)) { continue; }
+            ggml_tensor * persist_conv = (il < (int32_t)dflash_persist_conv_l.size())
+                ? dflash_persist_conv_l[il] : nullptr;
+            ggml_tensor * r_state = (il < (int32_t)mem_recr->r_l.size())
+                ? mem_recr->r_l[il] : nullptr;
+            if (!persist_conv || !r_state) { continue; }
+            if (accepted_dfs_node >= dflash_persist_max_n_tokens) { continue; }
+
+            GGML_ASSERT(persist_conv->type == GGML_TYPE_F32);
+            // persist_conv layout: [K-1, conv_channels, n_tokens] F32 contiguous;
+            // each token col is exactly n_embd_r elements (K-1 * conv_channels).
+            const size_t conv_col_bytes = (size_t)n_embd_r * sizeof(float);
+            const size_t conv_off_src = (size_t)accepted_dfs_node * conv_col_bytes;
+            ggml_backend_tensor_get(persist_conv, bounce_conv.data(), conv_off_src, conv_col_bytes);
+
+            // Live r_state row layout: ggml_row_size(type, n_embd_r) per cell.
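+            // Illustrative arithmetic (hypothetical sizes, not from hparams):
+            // with K_conv = 4 and 128 conv channels, n_embd_r = (4-1)*128 = 384
+            // floats, so DFS column i begins at byte offset i*384*sizeof(float)
+            // in persist_conv.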
+            const size_t r_row_bytes = ggml_row_size(r_state->type, n_embd_r);
+            const size_t r_off_dst = (size_t)cell_id * r_row_bytes;
+            if (r_state->type == GGML_TYPE_F32) {
+                ggml_backend_tensor_set(r_state, bounce_conv.data(), r_off_dst, r_row_bytes);
+            } else {
+                GGML_ABORT("dflash_rollback_ssm_to_dfs: unsupported r_state type");
+            }
+        }
+    }
+
+    return true;
+}
+
+void llama_kv_cache_seq_compact_tree(
+        struct llama_context * ctx,
+        llama_seq_id seq_id,
+        const int32_t * accepted_dfs,
+        int32_t n_accepted,
+        int32_t commit_n,
+        int32_t spine_start) {
+    auto * raw_mem = ctx->get_memory();
+    llama_kv_cache * kv = dynamic_cast<llama_kv_cache *>(raw_mem);
+    if (!kv) {
+        if (auto * hyb = dynamic_cast<llama_memory_hybrid *>(raw_mem)) {
+            kv = hyb->get_mem_attn();
+        }
+    }
+    if (!kv) {
+        // non-KV memory (pure SSM) — no cache compaction needed
+        return;
+    }
+    std::vector<int32_t> dfs_vec(accepted_dfs, accepted_dfs + n_accepted);
+    kv->seq_compact_tree(seq_id, dfs_vec, commit_n, spine_start);
+}
+
 // llama state API
 
 // deprecated
diff --git a/src/llama-context.h b/src/llama-context.h
index e0d0085c1c3..3005d3987f7 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -91,6 +91,12 @@ struct llama_context {
     const llama_token * get_sampled_candidates_ith(int32_t idx);
     size_t get_sampled_candidates_count(int32_t idx);
 
+    bool get_dflash_draft_top_k(
+            const float ** top_logits,
+            const llama_token ** top_token_ids,
+            int32_t * n_rows,
+            int32_t * k);
+
     void attach_threadpool(
             ggml_threadpool_t threadpool,
             ggml_threadpool_t threadpool_batch);
@@ -105,6 +111,92 @@ struct llama_context {
     void set_causal_attn(bool value);
     void set_warmup(bool value);
 
+    // dflash hidden capture API
+    void set_capture_hidden(bool enable);
+    ggml_tensor * get_hidden_capture() const;
+    void set_dflash_draft_top_k(int32_t k);
+
+    // dflash Phase 2.4: persist-based SSM rollback.
+    // Copies the SSM state stored in dflash_persist_inter_l[il] at DFS column
+    // accepted_dfs_node back into the live s_l[il] tensor at the seq's tail cell.
+    // Returns false if the context has no recurrent memory or buffers are unallocated.
+    bool dflash_rollback_ssm_to_dfs(llama_seq_id seq_id, int32_t accepted_dfs_node);
+
+    // Host-side accessor: returns pointer into hidden_capture_host (always CPU).
+    // Returns nullptr if capture is disabled or no decode has run.
+    const float * get_hidden_capture_data(int64_t * out_ne0, int64_t * out_ne1) const;
+
+    // dflash draft target_feat injection API (Task 1).
+    // Stashes host pointer + dims so the next llama_decode() on this draft context
+    // can copy the data into the dflash_target_feat_raw GGML input tensor.
+    // committed_pos is the number of tokens committed in the target context so far.
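+    // Usage sketch (hypothetical dims; with the default 5 capture layers of a
+    // 5120-wide target, n_embd_fc = 5*5120):
+    //   llama_set_target_feat_raw(ctx, feat, /*n_embd_fc=*/5*5120,
+    //                             /*ctx_len=*/512, /*committed_pos=*/512);
+    //   // the next llama_decode() on this draft context consumes the stash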
+    void set_target_feat_raw(const float * data, int64_t n_embd_fc, int64_t ctx_len,
+            int64_t committed_pos);
+    void set_target_feat_fused(const float * data, int64_t n_embd, int64_t ctx_len,
+            int64_t committed_pos);
+    int dflash_draft_fuse_target_feat(const float * target_feat_raw,
+            int64_t n_embd_fc,
+            int64_t ctx_len,
+            float * target_feat_fused);
+    int dflash_draft_update_fused_cache(const float * target_feat_raw,
+            int64_t n_embd_fc,
+            int64_t n_new,
+            int64_t first_pos,
+            int64_t cap);
+    int dflash_draft_update_fused_cache_from_capture(llama_context * target_ctx,
+            const int32_t * dfs_indices,
+            int32_t n_dfs,
+            int64_t first_pos,
+            int64_t cap);
+    int dflash_draft_encode_top_k_cached(const llama_batch & batch_inp,
+            int64_t n_embd,
+            int64_t ctx_len,
+            int64_t ring_start,
+            int64_t cap,
+            int64_t committed_pos,
+            int32_t top_k);
+    bool dflash_draft_ensure_fused_cache_tensor(int64_t n_embd,
+            int64_t cap,
+            ggml_backend_buffer_type_t buft);
+    bool dflash_draft_ensure_packed_target_feat_tensor(int64_t n_embd,
+            int64_t ctx_len,
+            ggml_backend_buffer_type_t buft);
+    bool dflash_draft_ensure_kv_cache_tensors(int64_t n_embd_head,
+            int64_t n_head_kv,
+            int64_t cap,
+            ggml_backend_buffer_type_t buft);
+    bool dflash_draft_ensure_packed_kv_tensors(int64_t n_embd_head,
+            int64_t n_head_kv,
+            int64_t ctx_len,
+            ggml_backend_buffer_type_t buft);
+    bool dflash_draft_ensure_top_output_tensors(int64_t top_k,
+            int64_t rows,
+            ggml_backend_buffer_type_t buft);
+    int dflash_draft_update_kv_cache(const float * target_feat_fused,
+            int64_t n_embd,
+            int64_t n_new,
+            int64_t first_pos,
+            int64_t cap);
+    int dflash_draft_pack_kv_cache(int64_t n_embd_head,
+            int64_t n_head_kv,
+            int64_t ctx_len,
+            int64_t ring_start,
+            int64_t cap);
+    int dflash_draft_encode_top_k(const llama_batch & batch_inp,
+            const float * target_feat_raw,
+            int64_t n_embd_fc,
+            int64_t ctx_len,
+            int64_t committed_pos,
+            int32_t top_k);
+    int dflash_draft_encode_top_k_fused(const llama_batch & batch_inp,
+            const float * target_feat_fused,
+            int64_t n_embd,
+            int64_t ctx_len,
+            int64_t committed_pos,
+            int32_t top_k);
+    int dflash_draft_encode_top_k_pending(const llama_batch & batch_inp,
+            int32_t top_k);
+
     void set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales);
     bool adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales);
@@ -227,6 +319,7 @@ struct llama_context {
     // returns the result of ggml_backend_sched_graph_compute_async execution
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);
+    ggml_status graph_compute(ggml_backend_sched_t sched_use, ggml_cgraph * gf, bool batched);
 
     // reserve a graph with a dummy ubatch of the specified size
     ggml_cgraph * graph_reserve(
@@ -234,12 +327,16 @@ struct llama_context {
 
     bool set_sampler(llama_seq_id seq_id, llama_sampler * sampler);
 
+    // Ensure the DDTree persist buffers can hold n_tokens columns; reallocates if needed.
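+    // Called from decode() before graph_params() whenever batch.parent_id is set
+    // and the fast-rollback path is enabled, so the graph builder sees a valid
+    // persist pointer (see llama-context.cpp).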
+    void ensure_dflash_persist_capacity(int64_t n_tokens);
+
 private:
     llm_graph_params graph_params(
             llm_graph_result * res,
             const llama_ubatch & ubatch,
             const llama_memory_context_i * mctx,
-            llm_graph_type gtype) const;
+            llm_graph_type gtype,
+            ggml_backend_sched_t sched_use = nullptr) const;
 
     llm_graph_cb graph_get_cb() const;
@@ -291,6 +388,12 @@ struct llama_context {
 
     sampling_info sampling;
 
+    int32_t dflash_draft_top_k_req = 0;
+    std::vector<float> dflash_draft_top_logits;
+    std::vector<llama_token> dflash_draft_top_token_ids;
+    int32_t dflash_draft_top_rows = 0;
+    int32_t dflash_draft_top_k = 0;
+
     // sequence embeddings output (map of [n_embd] vectors)
     // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
     std::map<llama_seq_id, std::vector<float>> embd_seq;
@@ -334,12 +437,104 @@ struct llama_context {
 
     llm_graph_result_ptr gf_res_prev;
     llm_graph_result_ptr gf_res_reserve;
+    llm_graph_result_ptr dflash_res_fuse;
+    llm_graph_result_ptr dflash_res_kv;
+    llm_graph_result_ptr dflash_res_draft;
+    ggml_backend_sched_ptr dflash_sched_fuse;
+    ggml_backend_sched_ptr dflash_sched_kv;
+    ggml_backend_sched_ptr dflash_sched_draft;
 
     // host buffer for the model output (logits and embeddings)
     ggml_backend_buffer_ptr buf_output;
 
     bool has_evaluated_once = false;
 
+    // dflash hidden capture: when true, qwen35 forward writes captured hidden states
+    // into a graph output tensor; accessible via get_hidden_capture() after decode.
+    bool capture_hidden = false;
+
+    // dflash Phase 2.4: per-layer SSM intermediate-state persist buffers.
+    // Allocated on first tree-mode decode; one tensor per delta-net layer (nullptr for
+    // full-attn layers). Shape: [S_v, S_v, H_v, n_tokens] F16, contiguous.
+    // After tree verify, llama_dflash_rollback_ssm_to_dfs() copies column[accepted_dfs_node]
+    // back into the live SSM state, replacing the snapshot/restore/replay path.
+    std::vector<ggml_tensor *> dflash_persist_inter_l;  // [n_layer], nullptr for non-recurrent
+    ggml_context_ptr        dflash_persist_inter_ctx;   // ggml context owning the tensors
+    ggml_backend_buffer_ptr dflash_persist_inter_buf;   // backend buffer owning the data
+    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>>
+        dflash_persist_ctxs_bufs;                       // per backend buffer type for mixed CPU/GPU offload
+    int64_t dflash_persist_max_n_tokens = 0;            // current capacity
+    int64_t dflash_persist_failed_n_tokens = 0;         // suppress repeated OOM retries
+
+    // dflash Phase 5 fix: per-token conv post-state persist buffer used by
+    // ggml_ssm_conv_tree_persist. One tensor per delta-net layer; shape
+    // [K_conv-1, conv_channels, n_tokens] F32. Read by dflash_rollback_ssm_to_dfs
+    // to roll the live conv state (r_l[il]) back to the accepted DFS node.
+    std::vector<ggml_tensor *> dflash_persist_conv_l;   // [n_layer], nullptr for non-recurrent
+
+    // Returns the per-layer SSM/conv persist tensors for layer il, or nullptr if not
+    // a recurrent layer or the buffers have not yet been allocated.
+    ggml_tensor * dflash_get_persist_inter(int32_t il) const;
+    ggml_tensor * dflash_get_persist_conv (int32_t il) const;
+
+    // host-side mirror of t_hidden_capture, populated via ggml_backend_tensor_get_async
+    // after each decode. get_hidden_capture_data() returns into this buffer so callers
+    // don't dereference device pointers.
+    mutable std::vector<float> hidden_capture_host;
+    mutable int64_t hidden_capture_ne0 = 0;
+    mutable int64_t hidden_capture_ne1 = 0;
+    mutable bool hidden_capture_host_valid = false;
+
+    // dflash draft target_feat injection: stashed by llama_set_target_feat_raw() before
+    // llama_decode() on the draft context. The dflash-draft graph input reads from these
+    // fields in set_input() and copies them into the host-pinned GGML input tensors.
+    // Non-owning pointer — caller (speculative-tree-driver) owns the lifetime.
+    // Mutable so graph_params() (a const method) can take their address for the param struct.
+    mutable const float * pending_target_feat_raw = nullptr;
+    mutable int64_t pending_target_feat_n_embd_fc = 0;
+    mutable int64_t pending_target_feat_ctx_len = 0;
+    mutable int64_t pending_draft_committed_pos = 0;
+    mutable bool pending_target_feat_fused = false;
+    mutable bool pending_dflash_fuse_only = false;
+    mutable bool pending_dflash_kv_update_only = false;
+    mutable int64_t pending_dflash_kv_update_dst_pos = 0;
+    mutable ggml_tensor * pending_target_feat_tensor = nullptr;
+
+    ggml_context_ptr dflash_fused_cache_ctx;
+    ggml_backend_buffer_ptr dflash_fused_cache_buf;
+    ggml_tensor * dflash_fused_cache = nullptr;
+    int64_t dflash_fused_cache_n_embd = 0;
+    int64_t dflash_fused_cache_cap = 0;
+
+    ggml_context_ptr dflash_packed_target_feat_ctx;
+    ggml_backend_buffer_ptr dflash_packed_target_feat_buf;
+    ggml_tensor * dflash_packed_target_feat = nullptr;
+    int64_t dflash_packed_target_feat_n_embd = 0;
+    int64_t dflash_packed_target_feat_ctx_len = 0;
+
+    ggml_context_ptr dflash_kv_cache_ctx;
+    ggml_backend_buffer_ptr dflash_kv_cache_buf;
+    std::vector<ggml_tensor *> dflash_k_cache_l;
+    std::vector<ggml_tensor *> dflash_v_cache_l;
+    int64_t dflash_kv_cache_head_dim = 0;
+    int64_t dflash_kv_cache_n_head_kv = 0;
+    int64_t dflash_kv_cache_cap = 0;
+
+    ggml_context_ptr dflash_kv_packed_ctx;
+    ggml_backend_buffer_ptr dflash_kv_packed_buf;
+    std::vector<ggml_tensor *> dflash_k_packed_l;
+    std::vector<ggml_tensor *> dflash_v_packed_l;
+    int64_t dflash_kv_packed_head_dim = 0;
+    int64_t dflash_kv_packed_n_head_kv = 0;
+    int64_t dflash_kv_packed_ctx_len = 0;
+
+    ggml_context_ptr dflash_top_output_ctx;
+    ggml_backend_buffer_ptr dflash_top_output_buf;
+    ggml_tensor * dflash_top_logits_fixed = nullptr;
+    ggml_tensor * dflash_top_ids_fixed = nullptr;
+    int64_t dflash_top_output_k = 0;
+    int64_t dflash_top_output_rows = 0;
+
     // env: LLAMA_GRAPH_REUSE_DISABLE
     bool graph_reuse_disable = false;
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 8e2b6ab8e7e..128b1a1034b 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -326,8 +326,14 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
     res &= s_copy_main->ne[0]  == params.ubatch.n_seqs;
     res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
 
-    res &= head == mctx->get_head();
-    res &= rs_z == mctx->get_rs_z();
+    const bool read_only_tree_compatible =
+        read_only_tree &&
+        params.ubatch.parent_id != nullptr &&
+        params.ubatch.n_tokens > 1;
+    if (!read_only_tree_compatible) {
+        res &= head == mctx->get_head();
+        res &= rs_z == mctx->get_rs_z();
+    }
 
     return res;
 }
@@ -613,8 +619,14 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
     res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
     res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
 
-    res &= inp_rs->head == mctx->get_recr()->get_head();
-    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
+    const bool read_only_tree_compatible =
+        inp_rs->read_only_tree &&
+        params.ubatch.parent_id != nullptr &&
+        params.ubatch.n_tokens > 1;
+    if (!read_only_tree_compatible) {
+        res &= inp_rs->head == mctx->get_recr()->get_head();
+        res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
+    }
 
     return res;
 }
@@ -783,6 +795,17 @@ bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
     return true;
 }
 
+void llm_graph_input_tree::set_input(const llama_ubatch * ubatch) {
+    GGML_ASSERT(ubatch->parent_id != nullptr);
+    GGML_ASSERT(inp_parent_ids != nullptr);
+
+    const int32_t n_tokens = (int32_t) ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(inp_parent_ids->buffer));
+    int32_t * data = (int32_t *) inp_parent_ids->data;
+    memcpy(data, ubatch->parent_id, n_tokens * sizeof(int32_t));
+}
+
 //
 // llm_graph_result
 //
@@ -801,9 +824,12 @@ int64_t llm_graph_result::get_max_nodes() const {
 void llm_graph_result::reset() {
     t_inp_tokens = nullptr;
     t_inp_embd   = nullptr;
-    t_logits      = nullptr;
-    t_embd        = nullptr;
-    t_embd_pooled = nullptr;
+    t_logits            = nullptr;
+    t_embd              = nullptr;
+    t_embd_pooled       = nullptr;
+    t_hidden_capture    = nullptr;
+    t_dflash_top_logits = nullptr;
+    t_dflash_top_ids    = nullptr;
 
     t_sampled.clear();
     t_sampled_probs.clear();
     t_sampled_logits.clear();
@@ -842,6 +868,15 @@ void llm_graph_result::set_outputs() {
     if (t_embd_pooled != nullptr) {
         ggml_set_output(t_embd_pooled);
     }
+    if (t_hidden_capture != nullptr) {
+        ggml_set_output(t_hidden_capture);
+    }
+    if (t_dflash_top_logits != nullptr) {
+        ggml_set_output(t_dflash_top_logits);
+    }
+    if (t_dflash_top_ids != nullptr) {
+        ggml_set_output(t_dflash_top_ids);
+    }
     for (auto & [seq_id, t] : t_sampled) {
         if (t != nullptr) {
             ggml_set_output(t);
@@ -896,6 +931,81 @@ bool llm_graph_result::can_reuse(const llm_graph_params & params) {
     return res;
 }
 
+void llm_graph_input_target_feat::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
+
+    // host_data, n_embd_fc, ctx_len are stashed by llama_set_target_feat_raw() before decode.
+    const float * data    = *host_data_ptr;
+    const int64_t fc      = *n_embd_fc_ptr;
+    const int64_t ctx_len = *ctx_len_ptr;
+
+    if (inp_target_feat_raw && inp_target_feat_raw->buffer != nullptr && data != nullptr) {
+        GGML_ASSERT(inp_target_feat_raw->ne[0] == fc);
+        GGML_ASSERT(inp_target_feat_raw->ne[1] == ctx_len);
+        ggml_backend_tensor_set(inp_target_feat_raw, data, 0, (size_t)fc * ctx_len * sizeof(float));
+    }
+
+    // pos_q is local to the draft attention window, not the target's global
+    // sequence position. The draft attends over target_feat[0..ctx_len) plus
+    // the block's noise tokens, matching standalone DFlash's draft_ctx+i.
+    if (inp_pos_q && inp_pos_q->buffer != nullptr) {
+        const int64_t block_size = inp_pos_q->ne[0];
+        std::vector<int32_t> pos_q(block_size);
+        for (int64_t i = 0; i < block_size; ++i) {
+            pos_q[i] = (int32_t)(ctx_len + i);
+        }
+        ggml_backend_tensor_set(inp_pos_q, pos_q.data(), 0, block_size * sizeof(int32_t));
+    }
+
+    // pos_k: [0 .. ctx_len + block_size)
+    if (inp_pos_k && inp_pos_k->buffer != nullptr) {
+        const int64_t total_k = inp_pos_k->ne[0];
+        std::vector<int32_t> pos_k(total_k);
+        for (int64_t i = 0; i < total_k; ++i) {
+            pos_k[i] = (int32_t)i;
+        }
+        ggml_backend_tensor_set(inp_pos_k, pos_k.data(), 0, total_k * sizeof(int32_t));
+    }
+}
+
+bool llm_graph_input_target_feat::can_reuse(const llm_graph_params & params) {
+    if (params.pending_target_feat_raw_ptr == nullptr ||
+        params.pending_target_feat_n_embd_fc_ptr == nullptr ||
+        params.pending_target_feat_ctx_len_ptr == nullptr) {
+        return false;
+    }
+
+    const int64_t fc       = *params.pending_target_feat_n_embd_fc_ptr;
+    const int64_t ctx_len  = *params.pending_target_feat_ctx_len_ptr;
+    const int64_t n_tokens = params.ubatch.n_tokens;
+
+    if (fc <= 0 || ctx_len <= 0 || n_tokens <= 0) {
+        return false;
+    }
+
+    bool res = true;
+    res &= inp_target_feat_raw != nullptr;
+    res &= inp_pos_q != nullptr;
+    res &= inp_pos_k != nullptr;
+
+    if (inp_target_feat_raw) {
+        res &= inp_target_feat_raw->ne[0] == fc;
+        res &= inp_target_feat_raw->ne[1] == ctx_len;
+    }
+    if (inp_pos_q) {
+        res &= inp_pos_q->ne[0] == n_tokens;
+    }
+    if (inp_pos_k) {
+        res &= inp_pos_k->ne[0] == ctx_len + n_tokens;
+    }
+
+    if (debug > 1) {
+        LLAMA_LOG_DEBUG("%s: can reuse dflash target_feat graph input = %d\n", __func__, res);
+    }
+
+    return res;
+}
+
 llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
     inputs.emplace_back(std::move(input));
     return inputs.back().get();
@@ -948,6 +1058,21 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     samplers          (params.samplers),
     cb_func           (params.cb),
     res               (params.res),
+    capture_hidden    (params.capture_hidden),
+    dflash_persist_inter_l(params.dflash_persist_inter_l),
+    dflash_target_feat_fused(params.dflash_target_feat_fused),
+    dflash_kv_update_only(params.dflash_kv_update_only),
+    dflash_fuse_only  (params.dflash_fuse_only),
+    dflash_draft_top_k(params.dflash_draft_top_k),
+    dflash_persist_conv_l (params.dflash_persist_conv_l),
+    pending_target_feat_raw_ptr      (params.pending_target_feat_raw_ptr),
+    pending_target_feat_n_embd_fc_ptr(params.pending_target_feat_n_embd_fc_ptr),
+    pending_target_feat_ctx_len_ptr  (params.pending_target_feat_ctx_len_ptr),
+    pending_draft_committed_pos_ptr  (params.pending_draft_committed_pos_ptr),
+    pending_target_feat_tensor_ptr   (params.pending_target_feat_tensor_ptr),
+    dflash_kv_cache_k_l    (params.dflash_kv_cache_k_l),
+    dflash_kv_cache_v_l    (params.dflash_kv_cache_v_l),
+    dflash_kv_cache_dst_pos(params.dflash_kv_cache_dst_pos),
     ctx0              (res->get_ctx()),
     gf                (res->get_gf()) {
     res->set_params(params);
@@ -1719,6 +1844,35 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
     return cur;
 }
 
+llm_graph_input_target_feat * llm_graph_context::build_inp_target_feat(int64_t n_embd_fc, int64_t ctx_len) const {
+    // The graph context holds non-owning pointers into the llama_context's pending_target_feat
+    // fields, propagated via llm_graph_params. The llama_context outlives every graph invocation,
+    // so the pointer lifetime is safe within a single decode call.
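+    // Shapes built below (block_size == n_tokens of the draft ubatch):
+    //   target_feat_raw : [n_embd_fc, ctx_len]    F32
+    //   pos_q           : [block_size]            I32, filled with ctx_len .. ctx_len+block_size
+    //   pos_k           : [ctx_len + block_size]  I32, filled with 0 .. ctx_len+block_size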
+    GGML_ASSERT(pending_target_feat_raw_ptr != nullptr &&
+            "build_inp_target_feat called without pending pointers wired in llm_graph_params");
+    auto inp = std::make_unique<llm_graph_input_target_feat>(
+            pending_target_feat_raw_ptr,
+            pending_target_feat_n_embd_fc_ptr,
+            pending_target_feat_ctx_len_ptr,
+            pending_draft_committed_pos_ptr);
+
+    const int64_t block_size = n_tokens; // == dflash_block_size at draft invocation time
+
+    inp->inp_target_feat_raw = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_fc, ctx_len);
+    ggml_set_name(inp->inp_target_feat_raw, "dflash_target_feat_raw");
+    ggml_set_input(inp->inp_target_feat_raw);
+
+    inp->inp_pos_q = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, block_size);
+    ggml_set_name(inp->inp_pos_q, "dflash_pos_q");
+    ggml_set_input(inp->inp_pos_q);
+
+    inp->inp_pos_k = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctx_len + block_size);
+    ggml_set_name(inp->inp_pos_k, "dflash_pos_k");
+    ggml_set_input(inp->inp_pos_k);
+
+    return (llm_graph_input_target_feat *) res->add_input(std::move(inp));
+}
+
 ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
     auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset);
@@ -1805,6 +1959,23 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     return cur;
 }
 
+void llm_graph_context::build_inp_tree() const {
+    auto inp = std::make_unique<llm_graph_input_tree>();
+
+    inp->inp_parent_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+    ggml_set_input(inp->inp_parent_ids);
+    ggml_set_name(inp->inp_parent_ids, "parent_ids");
+
+    // The ancestor-only attention mask is written directly into the standard
+    // kq_mask buffer in llama_kv_cache::set_input_kq_mask when ubatch->parent_id
+    // is set, so we don't allocate a separate tree_mask graph input here.
+    // TODO: phase-1 leaves pos as 1D — M-RoPE 4-axis is UNKNOWN-3 in roadmap
+
+    const_cast<llm_graph_context *>(this)->parent_ids = inp->inp_parent_ids;
+
+    res->add_input(std::move(inp));
+}
+
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
     auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);
@@ -2431,10 +2602,13 @@ ggml_tensor * llm_graph_context::build_rs(
 
     ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
 
-    // Clear a single state which will then be copied to the other cleared states.
-    // Note that this is a no-op when the view is zero-sized.
-    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
-    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
+    const bool read_only_tree = parent_ids != nullptr && ubatch.parent_id != nullptr && ubatch.n_tokens > 1;
+    if (!read_only_tree) {
+        // Clear a single state which will then be copied to the other cleared states.
+        // Note that this is a no-op when the view is zero-sized.
+        ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+        ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
+    }
 
     // copy states
     // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
@@ -2442,12 +2616,14 @@ ggml_tensor * llm_graph_context::build_rs(
     ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
     ggml_build_forward_expand(gf, output_states);
 
-    // copy extra states which won't be changed further (between n_seqs and n_rs)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
-    ggml_build_forward_expand(gf,
-        ggml_cpy(ctx0,
-            states_extra,
-            ggml_view_2d(ctx0, s, state_size, (n_rs - n_seqs), s->nb[1], (rs_head + n_seqs)*s->nb[1])));
+    if (!read_only_tree) {
+        // copy extra states which won't be changed further (between n_seqs and n_rs)
+        ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
+        ggml_build_forward_expand(gf,
+            ggml_cpy(ctx0,
+                states_extra,
+                ggml_view_2d(ctx0, s, state_size, (n_rs - n_seqs), s->nb[1], (rs_head + n_seqs)*s->nb[1])));
+    }
 
     return output_states;
 }
@@ -2470,6 +2646,7 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
 
     inp->head = mctx_cur->get_head();
     inp->rs_z = mctx_cur->get_rs_z();
+    inp->read_only_tree = ubatch.parent_id != nullptr && ubatch.n_tokens > 1;
 
     return inp;
 }
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 29e78451fbb..f99673a084f 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -243,6 +243,7 @@ class llm_graph_input_rs : public llm_graph_input_i {
     // used in view offsets, need to match for valid graph reuse
     uint32_t head;
     int32_t  rs_z;
+    bool read_only_tree = false;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -511,6 +512,52 @@ class llm_graph_input_sampling : public llm_graph_input_i {
     std::map<llama_seq_id, llama_sampler *> samplers;
 };
 
+// Input for tree-mode forward. Holds parent_ids; ancestor mask is written
+// into the standard kq_mask by llama_kv_cache::set_input_kq_mask when
+// ubatch->parent_id is non-null, so no separate mask tensor is allocated here.
+class llm_graph_input_tree : public llm_graph_input_i {
+public:
+    llm_graph_input_tree() = default;
+    virtual ~llm_graph_input_tree() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * inp_parent_ids = nullptr; // I32 [n_tokens]
+};
+
+// Graph input class for the dflash draft model's target_feat_raw, pos_q, and pos_k tensors.
+// The host side stashes the data via llama_set_target_feat_raw() before calling llama_decode()
+// on the draft context. set_input() memcpy's the stashed data into the GGML input tensors.
+class llm_graph_input_target_feat : public llm_graph_input_i {
+public:
+    // host_data, n_embd_fc, ctx_len are non-owning; they point into llama_context's pending fields.
+    // committed_pos is the number of tokens already committed before this draft step.
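+    // The double indirection below is deliberate: a built graph can be reused
+    // across decodes (see can_reuse), so set_input() must read the values
+    // stashed for the current decode through these pointers rather than values
+    // captured at graph-build time.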
+    llm_graph_input_target_feat(
+            const float ** host_data_ptr,       // pointer to context's pending_target_feat_raw field
+            const int64_t * n_embd_fc_ptr,      // pointer to context's pending_target_feat_n_embd_fc
+            const int64_t * ctx_len_ptr,        // pointer to context's pending_target_feat_ctx_len
+            const int64_t * committed_pos_ptr)  // pointer to context's pending_draft_committed_pos
+        : host_data_ptr(host_data_ptr), n_embd_fc_ptr(n_embd_fc_ptr),
+          ctx_len_ptr(ctx_len_ptr), committed_pos_ptr(committed_pos_ptr) {}
+    virtual ~llm_graph_input_target_feat() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+    bool can_reuse(const llm_graph_params & params) override;
+
+    // [5*n_embd, ctx_len] F32 — stacked hidden captures from target layers
+    ggml_tensor * inp_target_feat_raw = nullptr;
+    // [block_size] I32 — Q positions: [committed_pos .. committed_pos + block_size)
+    ggml_tensor * inp_pos_q = nullptr;
+    // [ctx_len + block_size] I32 — K positions: [0 .. ctx_len + block_size)
+    ggml_tensor * inp_pos_k = nullptr;
+
+private:
+    const float  ** host_data_ptr;
+    const int64_t * n_embd_fc_ptr;
+    const int64_t * ctx_len_ptr;
+    const int64_t * committed_pos_ptr;
+};
+
 //
 // llm_graph_result
 //
@@ -567,6 +614,39 @@ struct llm_graph_params {
 
     llm_graph_result * res;
 
+    // If true, qwen35 forward writes hidden states at dflash_target_capture_layers
+    // into t_hidden_capture on the result. No-op (zero overhead) when false.
+    bool capture_hidden = false;
+
+    // dflash draft target_feat injection (Task 1).
+    // Non-owning pointers into llama_context's pending_target_feat fields.
+    // Non-null only when running the dflash-draft graph; graph inputs use them in set_input().
+    const float ** pending_target_feat_raw_ptr = nullptr;
+    const int64_t * pending_target_feat_n_embd_fc_ptr = nullptr;
+    const int64_t * pending_target_feat_ctx_len_ptr = nullptr;
+    const int64_t * pending_draft_committed_pos_ptr = nullptr;
+    ggml_tensor * const * pending_target_feat_tensor_ptr = nullptr;
+    const std::vector<ggml_tensor *> * dflash_kv_cache_k_l = nullptr;
+    const std::vector<ggml_tensor *> * dflash_kv_cache_v_l = nullptr;
+    int64_t dflash_kv_cache_dst_pos = 0;
+
+    bool dflash_target_feat_fused = false;
+    bool dflash_kv_update_only = false;
+    bool dflash_fuse_only = false;
+    int32_t dflash_draft_top_k = 0;
+
+    // dflash Phase 2.4: per-layer SSM intermediate-state persist buffers.
+    // Non-owning pointer into llama_context::dflash_persist_inter_l. Null when not in
+    // tree mode or when the buffers have not yet been allocated. Graph builder reads
+    // (*dflash_persist_inter_l)[il] for each recurrent layer and passes it to
+    // build_delta_net_tree() as the persist_inter argument.
+    const std::vector<ggml_tensor *> * dflash_persist_inter_l = nullptr;
+
+    // dflash Phase 5: per-layer conv post-state persist buffers (paired with
+    // dflash_persist_inter_l). Read by ggml_ssm_conv_tree_persist; rolled back
+    // into r_l[il] after spec verify.
+ const std::vector * dflash_persist_conv_l = nullptr; + // return true if the "other" params would result in a graph with the same topology as with the current params // having the same topology allows us to reuse the graph in some cases bool allow_reuse(const llm_graph_params & other) const { @@ -625,11 +705,18 @@ struct llm_graph_params { return cparams.embeddings == other.cparams.embeddings && cparams.causal_attn == other.cparams.causal_attn && - arch == other.arch && - gtype == other.gtype && - cvec == other.cvec && - loras == other.loras && - cross == other.cross; + arch == other.arch && + gtype == other.gtype && + cvec == other.cvec && + loras == other.loras && + cross == other.cross && + capture_hidden == other.capture_hidden && + dflash_target_feat_fused == other.dflash_target_feat_fused && + dflash_kv_update_only == other.dflash_kv_update_only && + dflash_fuse_only == other.dflash_fuse_only && + dflash_draft_top_k == other.dflash_draft_top_k && + (dflash_persist_inter_l != nullptr) == (other.dflash_persist_inter_l != nullptr) && + (dflash_persist_conv_l != nullptr) == (other.dflash_persist_conv_l != nullptr); } }; @@ -640,9 +727,12 @@ class llm_graph_result { virtual ~llm_graph_result() = default; ggml_tensor * get_inp_tokens() const { return t_inp_tokens; } - ggml_tensor * get_logits() const { return t_logits; } - ggml_tensor * get_embd() const { return t_embd; } - ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } + ggml_tensor * get_logits() const { return t_logits; } + ggml_tensor * get_embd() const { return t_embd; } + ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } + ggml_tensor * get_hidden_capture() const { return t_hidden_capture; } + ggml_tensor * get_dflash_top_logits() const { return t_dflash_top_logits; } + ggml_tensor * get_dflash_top_ids() const { return t_dflash_top_ids; } ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -666,11 +756,18 @@ class llm_graph_result { void set_params(const llm_graph_params & params); // important graph nodes - ggml_tensor * t_inp_tokens = nullptr; - ggml_tensor * t_inp_embd = nullptr; // [n_embd_inp, n_tokens] - ggml_tensor * t_logits = nullptr; - ggml_tensor * t_embd = nullptr; - ggml_tensor * t_embd_pooled = nullptr; + ggml_tensor * t_inp_tokens = nullptr; + ggml_tensor * t_inp_embd = nullptr; // [n_embd_inp, n_tokens] + ggml_tensor * t_logits = nullptr; + ggml_tensor * t_embd = nullptr; + ggml_tensor * t_embd_pooled = nullptr; + // dflash hidden capture: [5*n_embd, n_tokens] F32, populated when capture_hidden=true in graph_params + ggml_tensor * t_hidden_capture = nullptr; + + // dflash-draft top-K graph outputs: [K, n_tokens] + // t_dflash_top_logits stores full-vocab-normalized log-probs for the selected ids. + ggml_tensor * t_dflash_top_logits = nullptr; + ggml_tensor * t_dflash_top_ids = nullptr; std::map t_sampled_logits; std::map t_candidates; @@ -758,9 +855,41 @@ struct llm_graph_context { llm_graph_result * res; + // dflash hidden capture: propagated from llm_graph_params::capture_hidden + const bool capture_hidden; + + // dflash Phase 2.4: per-layer SSM intermediate-state persist buffer pointers. + // Non-owning pointer into llama_context::dflash_persist_inter_l (via graph_params). + // Null when not in tree mode. Indexed by layer index il. 
+ const std::vector * dflash_persist_inter_l; + + bool dflash_target_feat_fused; + bool dflash_kv_update_only; + bool dflash_fuse_only; + int32_t dflash_draft_top_k; + + // dflash Phase 5: per-layer conv post-state persist buffer pointers + // (paired with dflash_persist_inter_l). + const std::vector * dflash_persist_conv_l; + + // dflash draft target_feat injection: propagated from llm_graph_params. + // Non-owning; valid only for the dflash-draft graph builder. + const float ** pending_target_feat_raw_ptr; + const int64_t * pending_target_feat_n_embd_fc_ptr; + const int64_t * pending_target_feat_ctx_len_ptr; + const int64_t * pending_draft_committed_pos_ptr; + ggml_tensor * const * pending_target_feat_tensor_ptr; + const std::vector * dflash_kv_cache_k_l; + const std::vector * dflash_kv_cache_v_l; + int64_t dflash_kv_cache_dst_pos; + ggml_context * ctx0 = nullptr; ggml_cgraph * gf = nullptr; + // tree-mode field: non-null when batch.parent_id is set; consumed by + // ggml_ssm_conv_tree / ggml_gated_delta_net_tree on hybrid layers. + ggml_tensor * parent_ids = nullptr; // [n_tokens] i32 + llm_graph_context(const llm_graph_params & params); virtual ~llm_graph_context() = default; @@ -866,9 +995,17 @@ struct llm_graph_context { ggml_tensor * build_inp_mean() const; ggml_tensor * build_inp_cls() const; + // build tree-mode input tensors (parent_ids + tree_mask); sets this->parent_ids and this->tree_mask + void build_inp_tree() const; + ggml_tensor * build_inp_cross_embd() const; ggml_tensor * build_inp_pos_bucket_enc() const; ggml_tensor * build_inp_pos_bucket_dec() const; + + // Build input tensors for the dflash draft model (target_feat_raw, pos_q, pos_k). + // Returns the target_feat_raw tensor (already registered as graph input). + // pos_q and pos_k are accessible via the returned llm_graph_input_target_feat*. + llm_graph_input_target_feat * build_inp_target_feat(int64_t n_embd_fc, int64_t ctx_len) const; ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const; // diff --git a/src/llama-hparams.h b/src/llama-hparams.h index c2000c77c37..965e5549c1b 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -212,6 +212,13 @@ struct llama_hparams { // gemma4 per-layer embedding uint32_t n_embd_per_layer = 0; + // dflash-draft specific (27B-specific for now: 5-layer draft targeting Qwen3.5-27B) + // target_capture_layers: indices into target model layers to capture hidden states from + std::array dflash_target_capture_layers = {1, 16, 31, 46, 61}; + uint32_t dflash_target_n_embd = 5120; // target hidden dim (n_embd of target model) + uint32_t dflash_mask_token_id = 248070; // mask token id for noise embedding lookup + uint32_t dflash_block_size = 16; // number of noise tokens per spec step + // needed by encoder-decoder models (e.g. 
T5, FLAN-T5) // ref: https://github.com/ggml-org/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 6b41d6c6ea2..8dbdbdb530b 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -608,6 +608,109 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const { return cells.seq_pos_max(seq_id); } +void llama_kv_cache::seq_compact_tree( + llama_seq_id seq_id, + const std::vector & accepted_dfs, + int32_t commit_n, + int32_t spine_start) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + GGML_ASSERT(commit_n >= 0 && commit_n <= (int32_t) accepted_dfs.size()); + GGML_ASSERT(spine_start >= 0); + + if (commit_n == 0) { + return; + } + + const uint32_t strm = seq_to_stream[seq_id]; + auto & cells = v_cells[strm]; + + // accepted_dfs[i] is the tree-local DFS index of the i-th accepted node. + // The tree was placed at slots [spine_start, spine_start + N), so absolute + // src/dst slot indices need spine_start added. + // + // Copy K/V rows via ggml_backend_tensor_get/set with explicit offsets + // (ggml_backend_tensor_copy doesn't follow view_src->buffer). + std::vector bounce; + + for (int32_t i = 0; i < commit_n; ++i) { + const int32_t src_slot = spine_start + accepted_dfs[i]; + const int32_t dst_slot = spine_start + i; + + if (src_slot == dst_slot) { + continue; + } + + GGML_ASSERT(src_slot >= 0 && (uint32_t) src_slot < cells.size()); + GGML_ASSERT(dst_slot >= 0 && (uint32_t) dst_slot < cells.size()); + + for (auto & layer : layers) { + if (layer.k) { + const size_t k_row_bytes = ggml_row_size(layer.k->type, layer.k->ne[0]); + const size_t k_row_stride = layer.k->nb[1]; + const size_t k_stride_stream = layer.k->nb[2]; + + if (bounce.size() < k_row_bytes) bounce.resize(k_row_bytes); + + const size_t src_off = strm * k_stride_stream + (size_t) src_slot * k_row_stride; + const size_t dst_off = strm * k_stride_stream + (size_t) dst_slot * k_row_stride; + + ggml_backend_tensor_get(layer.k, bounce.data(), src_off, k_row_bytes); + ggml_backend_tensor_set(layer.k, bounce.data(), dst_off, k_row_bytes); + } + + if (layer.v && !v_trans) { + const size_t v_row_bytes = ggml_row_size(layer.v->type, layer.v->ne[0]); + const size_t v_row_stride = layer.v->nb[1]; + const size_t v_stride_stream = layer.v->nb[2]; + + if (bounce.size() < v_row_bytes) bounce.resize(v_row_bytes); + + const size_t src_off = strm * v_stride_stream + (size_t) src_slot * v_row_stride; + const size_t dst_off = strm * v_stride_stream + (size_t) dst_slot * v_row_stride; + + ggml_backend_tensor_get(layer.v, bounce.data(), src_off, v_row_bytes); + ggml_backend_tensor_set(layer.v, bounce.data(), dst_off, v_row_bytes); + } + } + } + + // Update cell metadata: only touch the tree region [spine_start, spine_start+N). + // Past prompt cells (slots < spine_start) are left untouched. + // + // Snapshot positions from the accepted source slots first to avoid aliasing + // when src_slot < dst_slot. + std::vector accepted_pos(commit_n); + for (int32_t i = 0; i < commit_n; ++i) { + const uint32_t src = (uint32_t) (spine_start + accepted_dfs[i]); + accepted_pos[i] = cells.is_empty(src) ? -1 : cells.pos_get(src); + } + + // Clear all tree slots used by seq_id (i.e. cells with slot >= spine_start + // and pos >= the tree start position). 
The exact tree width N is not known here, so
+    // be conservative: clear every cell that belongs to this seq at
+    // slot >= spine_start (the caller guarantees the tree region starts at
+    // spine_start); the accepted spine is rewritten just below.
+    const uint32_t kv_size = cells.size();
+    for (uint32_t slot = (uint32_t) spine_start; slot < kv_size; ++slot) {
+        if (!cells.is_empty(slot) && cells.seq_has(slot, seq_id)) {
+            cells.rm(slot);
+        }
+    }
+
+    // Set the spine slots [spine_start, spine_start+commit_n) with accepted positions
+    for (int32_t i = 0; i < commit_n; ++i) {
+        if (accepted_pos[i] >= 0) {
+            const uint32_t slot = (uint32_t) (spine_start + i);
+            cells.pos_set(slot, accepted_pos[i]);
+            cells.seq_add(slot, seq_id);
+        }
+    }
+
+    // Search head moves to just past the spine.
+    v_heads[strm] = (uint32_t) (spine_start + commit_n);
+}
+
 std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
     std::map<ggml_backend_buffer_type_t, size_t> ret;
     for (const auto & [ctx, buf] : ctxs_bufs) {
@@ -1607,7 +1710,7 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float *
     }
 }

-void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
+void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn, const slot_info & sinfo) const {
     const uint32_t n_tokens = ubatch->n_tokens;

     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
@@ -1621,6 +1724,68 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
     // n_tps == n_tokens_per_stream
     const int64_t n_tps = n_tokens/n_stream;

+    // Tree-mode mask: each query node attends to all past (committed) KV cells
+    // unconditionally, plus its exact tree ancestors in the current ubatch.
+    // Do not match tree nodes by position: siblings share the same depth/pos.
+    if (ubatch->parent_id != nullptr) {
+        GGML_ASSERT(n_stream == 1 && "tree-mode requires n_stream == 1 in Phase 4");
+        GGML_ASSERT(sinfo.n_stream() == 1 && sinfo.size() == n_tokens);
+
+        // Find the boundary between past KV and the current tree ubatch.
+        llama_pos tree_min_pos = std::numeric_limits<llama_pos>::max();
+        for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
+            tree_min_pos = std::min(tree_min_pos, ubatch->pos[i]);
+        }
+
+        const llama_seq_id seq0 = ubatch->seq_id[0][0];
+        const auto & cells = v_cells.at(seq_to_stream[seq0]);
+
+        // Pass 1: classify every KV cell once (O(n_kv)). past_visible[j] = 1
+        // means cell j belongs to a committed token and is unconditionally
+        // visible to every tree query. Tree-region cells stay 0 here and are
+        // turned on per-query in pass 3.
+        std::vector<uint8_t> past_visible(n_kv, 0);
+        for (int64_t j = 0; j < n_kv; ++j) {
+            if (cells.is_empty(j) || !cells.seq_has(j, seq0)) {
+                continue;
+            }
+            if (cells.pos_get(j) < tree_min_pos) {
+                past_visible[j] = 1;
+            }
+        }
+
+        // Pass 2: for each query i, fill the row with -INF, then write 0.0f
+        // for every past_visible cell. Tree-region cells remain -INF until
+        // pass 3 marks the query's own ancestors. Total cost is O(N * n_kv)
+        // for fills + O(N * depth) for ancestor walks instead of the previous
+        // O(N * n_kv * 64) where the inner 64-element ancestor scan was
+        // duplicated against every KV cell.
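+        // Worked example (illustrative): parent_id = {-1, 0, 0} is a root with
+        // two sibling children at the same pos. Each sibling's row gets the
+        // past cells + the root + itself; the siblings never see each other,
+        // which a purely position-based causal mask would get wrong.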
+        for (int64_t i = 0; i < (int64_t) n_tokens; ++i) {
+            float * row = data + i * n_kv;
+            std::fill(row, row + n_kv, -INFINITY);
+            for (int64_t j = 0; j < n_kv; ++j) {
+                if (past_visible[j]) {
+                    row[j] = 0.0f;
+                }
+            }
+
+            // Pass 3: walk up the parent chain to the root and mark each
+            // ancestor's KV slot visible. Bounded by tree depth (<= L+1 < 64
+            // in practice) and only touches ancestors of this query.
+            int32_t cur = (int32_t) i;
+            while (cur >= 0) {
+                const uint32_t slot = sinfo.idxs[0][cur];
+                if (slot < (uint32_t) n_kv && !cells.is_empty(slot) && cells.seq_has(slot, seq0)) {
+                    row[slot] = 0.0f;
+                }
+                const int32_t p = ubatch->parent_id[cur];
+                if (p < 0) break;
+                cur = p;
+            }
+        }
+
+        return;
+    }
+
     //const int64_t t_start = ggml_time_us();

     const args_set_input_kq_mask args = {
@@ -2490,7 +2655,7 @@ void llama_kv_cache_context::set_input_v_idxs(ggml_tensor * dst, const llama_uba
 }

 void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
-    kv->set_input_kq_mask(dst, ubatch, causal_attn);
+    kv->set_input_kq_mask(dst, ubatch, causal_attn, sinfos[i_cur]);
 }

 void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index 0b62dc7b232..ad7311cb282 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -179,6 +179,16 @@ class llama_kv_cache : public llama_memory_i {

     bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info);

+    // After a tree-verify forward fills KV slots [spine_start, spine_start+N),
+    // compact the accepted spine: K/V rows from slot (spine_start + accepted_dfs[i])
+    // are moved to slot (spine_start + i) for i in [0, commit_n). Cells below
+    // spine_start are untouched (preserves the prompt prefill); tree cells past
+    // the accepted spine are cleared. Only operates on the stream assigned to seq_id.
+    void seq_compact_tree(llama_seq_id seq_id,
+                          const std::vector<int32_t> & accepted_dfs,
+                          int32_t commit_n,
+                          int32_t spine_start);
+
     // find a slot of kv cells that can hold the ubatch
     // if cont == true, then the slot must be continuous
     // return empty slot_info on failure
@@ -202,7 +212,7 @@ class llama_kv_cache : public llama_memory_i {

     void set_input_k_shift(ggml_tensor * dst) const;

-    void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
+    void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn, const slot_info & sinfo) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

     void set_input_k_rot(ggml_tensor * dst) const;
diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index 9287fe45e96..77c6bca5e5d 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -12,6 +12,7 @@
 #include
 #include
 #include
+#include <cstring>

 //
 // llama_memory_recurrent
@@ -604,7 +605,7 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
             const int32_t cell_id = s + min;
             auto & cell = cells[cell_id];

-            if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
+            if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens && ubatch.parent_id == nullptr) {
                 // What should happen when the pos backtracks or skips a value?
                 // Clearing the state mid-batch would require special-casing which isn't done.
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
@@ -666,6 +667,147 @@ bool llama_memory_recurrent::get_can_shift() const {
     return true;
 }

+llama_mem_snapshot_id llama_memory_recurrent::snapshot(llama_seq_id seq_id) {
+    if (seq_id < 0 || (uint32_t) seq_id >= size) {
+        return LLAMA_MEM_SNAPSHOT_INVALID;
+    }
+
+    const int32_t n_layer = (int32_t) r_l.size();
+
+    snapshot_entry entry;
+    entry.seq_id = seq_id;
+    entry.r_backup.resize(n_layer, nullptr);
+    entry.s_backup.resize(n_layer, nullptr);
+
+    // group backup tensors by buffer type, matching the main cache allocation pattern
+    struct ggml_backend_buft_comparator {
+        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
+            return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
+        }
+    };
+    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
+
+    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
+        auto it = ctx_map.find(buft);
+        if (it == ctx_map.end()) {
+            ggml_init_params params = {
+                /*.mem_size   =*/ size_t(2u * n_layer * ggml_tensor_overhead()),
+                /*.mem_buffer =*/ NULL,
+                /*.no_alloc   =*/ true,
+            };
+            ggml_context * ctx = ggml_init(params);
+            if (!ctx) { return nullptr; }
+            ctx_map.emplace(buft, ctx);
+            return ctx;
+        }
+        return it->second.get();
+    };
+
+    for (int il = 0; il < n_layer; ++il) {
+        if (r_l[il] == nullptr) { continue; }
+
+        ggml_backend_buffer_t main_buf = r_l[il]->buffer;
+        ggml_backend_buffer_type_t buft = main_buf
+            ? ggml_backend_buffer_get_type(main_buf)
+            : ggml_backend_cpu_buffer_type();
+
+        ggml_context * ctx = ctx_for_buft(buft);
+        if (!ctx) { return LLAMA_MEM_SNAPSHOT_INVALID; }
+
+        // one cell worth of r and s
+        ggml_tensor * rb = ggml_new_tensor_1d(ctx, r_l[il]->type, hparams.n_embd_r());
+        ggml_tensor * sb = ggml_new_tensor_1d(ctx, s_l[il]->type, hparams.n_embd_s());
+        entry.r_backup[il] = rb;
+        entry.s_backup[il] = sb;
+    }
+
+    for (auto & [buft, ctx] : ctx_map) {
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
+        if (!buf) { return LLAMA_MEM_SNAPSHOT_INVALID; }
+        ggml_backend_buffer_clear(buf, 0);
+        entry.ctxs_bufs.emplace_back(std::move(ctx), buf);
+    }
+
+    // copy current cell state into the backup tensors via a host bounce buffer.
+    // ggml_backend_tensor_copy doesn't follow view_src->buffer, but tensor_get/set do.
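+    // Intended pairing (a sketch; the spec-decode driver owns the exact
+    // sequencing): snapshot() before the tree forward mutates the recurrent
+    // state, restore() if verification rejects the speculated tokens, and
+    // release() in either case to free the backup buffers.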
+    const int32_t cell_id = cells[seq_id].tail;
+    entry.cell_id = cell_id;
+    if (cell_id >= 0) {
+        entry.cell_pos = cells[cell_id].pos;
+        entry.cell_src = cells[cell_id].src;
+
+        std::vector<uint8_t> bounce;
+        for (int il = 0; il < n_layer; ++il) {
+            if (r_l[il] == nullptr) { continue; }
+
+            const size_t r_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
+            const size_t s_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
+
+            if (bounce.size() < std::max(r_row, s_row)) {
+                bounce.resize(std::max(r_row, s_row));
+            }
+
+            ggml_backend_tensor_get(r_l[il], bounce.data(), (size_t) cell_id * r_row, r_row);
+            ggml_backend_tensor_set(entry.r_backup[il], bounce.data(), 0, r_row);
+
+            ggml_backend_tensor_get(s_l[il], bounce.data(), (size_t) cell_id * s_row, s_row);
+            ggml_backend_tensor_set(entry.s_backup[il], bounce.data(), 0, s_row);
+        }
+    }
+
+    llama_mem_snapshot_id snap_id = next_snap_id++;
+    snapshots.emplace(snap_id, std::move(entry));
+    return snap_id;
+}
+
+bool llama_memory_recurrent::restore(llama_mem_snapshot_id snap_id) {
+    auto it = snapshots.find(snap_id);
+    if (it == snapshots.end()) { return false; }
+
+    const snapshot_entry & entry = it->second;
+    const llama_seq_id seq_id = entry.seq_id;
+
+    if (seq_id < 0 || (uint32_t) seq_id >= size) { return false; }
+
+    // restore the seq's tail to the cell that was snapshotted
+    const int32_t cell_id = entry.cell_id;
+    if (cell_id < 0) {
+        return true; // no live cell at snapshot time — nothing to restore
+    }
+
+    cells[seq_id].tail = cell_id;
+    cells[cell_id].pos = entry.cell_pos;
+    cells[cell_id].src = entry.cell_src;
+
+    const int32_t n_layer = (int32_t) r_l.size();
+
+    std::vector<uint8_t> bounce;
+    for (int il = 0; il < n_layer; ++il) {
+        if (r_l[il] == nullptr) { continue; }
+        if (entry.r_backup[il] == nullptr) { continue; }
+
+        const size_t r_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
+        const size_t s_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
+
+        if (bounce.size() < std::max(r_row, s_row)) {
+            bounce.resize(std::max(r_row, s_row));
+        }
+
+        ggml_backend_tensor_get(entry.r_backup[il], bounce.data(), 0, r_row);
+        ggml_backend_tensor_set(r_l[il], bounce.data(), (size_t) cell_id * r_row, r_row);
+
+        ggml_backend_tensor_get(entry.s_backup[il], bounce.data(), 0, s_row);
+        ggml_backend_tensor_set(s_l[il], bounce.data(), (size_t) cell_id * s_row, s_row);
+    }
+
+    return true;
+}
+
+void llama_memory_recurrent::release(llama_mem_snapshot_id snap_id) {
+    snapshots.erase(snap_id);
+}
+
 size_t llama_memory_recurrent::total_size() const {
     size_t size = 0;
     for (const auto & [_, buf] : ctxs_bufs) {
diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h
index 47f01d73912..81ed749aed7 100644
--- a/src/llama-memory-recurrent.h
+++ b/src/llama-memory-recurrent.h
@@ -6,6 +6,7 @@
 #include
 #include
+#include <unordered_map>
 #include

 //
@@ -60,6 +61,12 @@ class llama_memory_recurrent : public llama_memory_i {

     bool get_can_shift() const override;

+    // snapshot/restore API for recurrent state (SSM + conv)
+    // snap_id < 0 indicates failure
+    llama_mem_snapshot_id snapshot(llama_seq_id seq_id);
+    bool                  restore (llama_mem_snapshot_id snap_id);
+    void                  release (llama_mem_snapshot_id snap_id);
+
     // state write/load
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
@@ -117,6 +124,25 @@
     size_t size_r_bytes() const;
     size_t size_s_bytes() const;

+    // per-snapshot: the seq_id that was snapshotted, plus backup
tensors for each layer
+    // and the cell metadata at snapshot time so restore can roll back the position counter.
+    struct snapshot_entry {
+        llama_seq_id seq_id;
+        // backup tensors: one per layer, same type/shape as r_l[il] / s_l[il] for that seq cell
+        // r_backup[il] and s_backup[il] are null for filtered (null) layers
+        std::vector<ggml_tensor *> r_backup; // [n_layer]
+        std::vector<ggml_tensor *> s_backup; // [n_layer]
+        // ggml contexts and backend buffers that own the backup tensors
+        std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
+        // cell bookkeeping captured at snapshot time
+        int32_t   cell_id  = -1;
+        llama_pos cell_pos = -1;
+        int32_t   cell_src = -1;
+    };
+
+    llama_mem_snapshot_id next_snap_id = 0;
+    std::unordered_map<llama_mem_snapshot_id, snapshot_entry> snapshots;
+
     void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
     void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
index 4e65a45a50d..6a0abaccf75 100644
--- a/src/llama-model-loader.cpp
+++ b/src/llama-model-loader.cpp
@@ -394,6 +394,9 @@ namespace GGUFMeta {

     template bool llama_model_loader::get_arr>(enum llm_kv kid, std::vector & result, bool required);

+    // Explicit instantiation for dflash-draft: target_capture_layers is a fixed 5-element array.
+    template bool llama_model_loader::get_arr(const std::string & key, std::array<uint32_t, 5> & result, bool required);
+
     template bool llama_model_loader::get_key(const std::string & key, T & result, bool required) {
         auto it = kv_overrides.find(key);
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index b265394ef73..234a8a681a8 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -16,6 +16,7 @@
 #include "models/models.h"

 #include "ggml.h"
+#include "ggml-backend.h"
 #include "ggml-cpp.h"

 #include
@@ -2933,6 +2934,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 default: type = LLM_TYPE_UNKNOWN;
             }
         } break;
+        case LLM_ARCH_DFLASH_DRAFT:
+            {
+                // Standard transformer hparams are read from GGUF normally (n_layer=5,
+                // n_embd=5120, n_head=32, n_head_kv=8, n_embd_head=128, n_ff=2048).
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                // dflash-draft specific metadata keys
+                ml.get_arr("dflash_draft.target_capture_layers",
+                           hparams.dflash_target_capture_layers, false);
+                ml.get_key("dflash_draft.target_n_embd", hparams.dflash_target_n_embd, false);
+                ml.get_key("dflash_draft.mask_token_id", hparams.dflash_mask_token_id, false);
+                ml.get_key("dflash_draft.block_size",    hparams.dflash_block_size,    false);
+
+                type = LLM_TYPE_UNKNOWN; // draft is a small auxiliary model, no standard type
+            } break;
         default: throw std::runtime_error("unsupported model architecture: " + arch_name());
     }
@@ -7971,6 +7987,56 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
                 }
             } break;
+        case LLM_ARCH_DFLASH_DRAFT:
+            {
+                // dflash-draft: 5-layer non-causal speculative decoder.
+                // token_embd is NOT loaded — token embeddings are looked up from the
+                // target model at runtime via llama_model_token_embd_lookup.
+                // lm_head can be shared from the target model. The draft GGUF may
+                // still contain output.weight for standalone tests, but in server
+                // mode loading another copy costs about 1 GiB on Qwen3.5-27B.
+                // out_norm maps to model.output_norm; fc and hidden_norm are stored in dflash_fc / dflash_hidden_norm.
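+                // Runtime note: with token_embd absent, the host materializes
+                // draft input embeddings one row at a time through
+                // llama_model_token_embd_lookup() on the target model (see the
+                // helper added near the end of this file).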
+ const int64_t n_draft_fc_in = (int64_t)5 * n_embd; // 5 * hidden = 25600 for 27B + + // Top-level: feature-fusion projection and norms + dflash_fc = create_tensor(tn(LLM_TENSOR_DFLASH_FC, "weight"), {n_draft_fc_in, n_embd}, 0); + dflash_hidden_norm = create_tensor(tn(LLM_TENSOR_DFLASH_HIDDEN_NORM, "weight"), {n_embd}, 0); + output_norm = create_tensor(tn(LLM_TENSOR_DFLASH_OUT_NORM, "weight"), {n_embd}, 0); + + const llama_model * target_model = params.target_model; + if (target_model != nullptr && target_model->output != nullptr && + target_model->output->ne[0] == n_embd && + target_model->output->ne[1] == n_vocab) { + (void) create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_SKIP); + output = target_model->output; + LLAMA_LOG_INFO("%s: dflash-draft: sharing target output.weight; skipped draft lm_head allocation\n", __func__); + } else { + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // Per-layer pre-norms + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // Attention projections; Q projects to n_head * n_embd_head_k + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // Per-head Q/K norms (qwen3-style) + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + + // SwiGLU FFN; intermediate = n_ff (2048 for 27B draft) + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -8629,6 +8695,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_LLADA: case LLM_ARCH_LLADA_MOE: case LLM_ARCH_RND1: + case LLM_ARCH_DFLASH_DRAFT: // non-causal draft: no autoregressive KV cache needed { res = nullptr; } break; @@ -9259,6 +9326,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_DFLASH_DRAFT: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -9289,6 +9360,7 @@ llama_model_params llama_model_default_params() { llama_model_params result = { /*.devices =*/ nullptr, /*.tensor_buft_overrides =*/ nullptr, + /*.target_model =*/ nullptr, /*.n_gpu_layers =*/ -1, /*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER, /*.main_gpu =*/ 0, @@ -9511,6 +9583,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_QWEN3NEXT: case LLM_ARCH_MIMO2: case LLM_ARCH_STEP35: + case LLM_ARCH_DFLASH_DRAFT: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: @@ -9658,6 +9731,57 @@ bool llama_model_is_diffusion(const llama_model * model) { return llm_arch_is_diffusion(model->arch); } +int llama_model_token_embd_lookup( 
+ const llama_model * model, + llama_token token, + float * out, + int64_t out_n) { + const ggml_tensor * t = model->tok_embd; + if (!t) { + return -1; + } + + const int64_t n_embd = t->ne[0]; // rows in ggml layout = embedding dim + const int64_t n_vocab = t->ne[1]; + + if (token < 0 || (int64_t)token >= n_vocab) { + return -1; + } + if (out_n < n_embd) { + return -1; + } + + const ggml_type dtype = t->type; + + if (dtype == GGML_TYPE_F32) { + const size_t row_bytes = (size_t)n_embd * sizeof(float); + ggml_backend_tensor_get(t, out, (size_t)token * row_bytes, row_bytes); + return 0; + } + + if (dtype == GGML_TYPE_F16) { + const size_t row_bytes = (size_t)n_embd * sizeof(ggml_fp16_t); + std::vector tmp(n_embd); + ggml_backend_tensor_get(t, tmp.data(), (size_t)token * row_bytes, row_bytes); + ggml_fp16_to_fp32_row(tmp.data(), out, n_embd); + return 0; + } + + // Quantized rows: fetch the raw bytes for one row and dequantize via the + // type's to_float trait. Required for Q4_K_M / Q5_K / etc. target models + // where tok_embd is stored quantized. + const auto * traits = ggml_get_type_traits(dtype); + if (traits != nullptr && traits->to_float != nullptr) { + const size_t row_bytes = ggml_row_size(dtype, n_embd); + std::vector tmp(row_bytes); + ggml_backend_tensor_get(t, tmp.data(), (size_t)token * row_bytes, row_bytes); + traits->to_float(tmp.data(), out, n_embd); + return 0; + } + + return -1; +} + const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { return model->tensors_by_name; } diff --git a/src/llama-model.h b/src/llama-model.h index bba70012e11..5273c977e8b 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -546,6 +546,10 @@ struct llama_model { struct ggml_tensor * conv1d = nullptr; struct ggml_tensor * conv1d_b = nullptr; + // dflash-draft top-level tensors + struct ggml_tensor * dflash_fc = nullptr; // "fc" [5*n_embd, n_embd] + struct ggml_tensor * dflash_hidden_norm = nullptr; // "hidden_norm" [n_embd] + // gemma3n altup struct ggml_tensor * altup_proj = nullptr; struct ggml_tensor * altup_unembd_proj = nullptr; diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index 6bc989c9509..cbc823bbff5 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -428,6 +428,17 @@ std::pair llm_build_delta_net_base::build_delta_ne ggml_tensor * b, ggml_tensor * s, int il) { + // dispatch to tree variant when parent_ids are available + if (parent_ids != nullptr) { + // Phase 2.4: fetch per-layer persist buffer from graph context if allocated + ggml_tensor * persist_inter = nullptr; + if (dflash_persist_inter_l != nullptr && il >= 0 && + il < (int32_t)dflash_persist_inter_l->size()) { + persist_inter = (*dflash_persist_inter_l)[il]; + } + return build_delta_net_tree(q, k, v, g, b, s, parent_ids, persist_inter, il); + } + const int64_t n_seq_tokens = q->ne[2]; if (n_seq_tokens == 1) { @@ -443,3 +454,43 @@ std::pair llm_build_delta_net_base::build_delta_ne return build_delta_net_chunking(q, k, v, g, b, s, il); } + +std::pair llm_build_delta_net_base::build_delta_net_tree( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + ggml_tensor * par_ids, + ggml_tensor * persist_inter, + int il) { + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + const int64_t n_tokens = v->ne[2]; + const int64_t n_seqs = v->ne[3]; + + // ggml_gated_delta_net_tree has the same packed output layout as ggml_gated_delta_net + ggml_tensor * result; + if (persist_inter 
!= nullptr) { + result = ggml_gated_delta_net_tree_persist(ctx0, q, k, v, g, b, s, par_ids, persist_inter); + } else { + result = ggml_gated_delta_net_tree(ctx0, q, k, v, g, b, s, par_ids); + } + cb(result, "fgdn_tree", il); + + ggml_tensor * output = ggml_view_4d(ctx0, result, + S_v, H_v, n_tokens, n_seqs, + ggml_row_size(result->type, S_v), + ggml_row_size(result->type, S_v * H_v), + ggml_row_size(result->type, S_v * H_v * n_tokens), 0); + + ggml_tensor * new_state = ggml_view_4d(ctx0, result, + S_v, S_v, H_v, n_seqs, + ggml_row_size(result->type, S_v), + ggml_row_size(result->type, S_v * S_v), + ggml_row_size(result->type, S_v * S_v * H_v), + ggml_row_size(result->type, S_v * H_v * n_tokens * n_seqs)); + + return {output, new_state}; +} diff --git a/src/models/dflash-draft.cpp b/src/models/dflash-draft.cpp new file mode 100644 index 00000000000..075cbda8d9d --- /dev/null +++ b/src/models/dflash-draft.cpp @@ -0,0 +1,317 @@ +// dflash-draft.cpp — Graph builder for the z-lab/Qwen3.5-27B-DFlash speculative draft model. +// +// Architecture: 5-layer non-causal transformer. +// Inputs (host-provided, set via ggml_set_input): +// - noise_embed : [n_embd, block_size] — pre-looked-up rows from target tok_embd +// - target_feat_raw: [5*n_embd, ctx_len] — stacked hidden captures from target layers +// Forward pass (mirrors qwen3_dflash_graph.cpp:53-164): +// 1. fc(target_feat_raw) → rms_norm(hidden_norm) → target_feat [n_embd, ctx_len] +// 2. For each of 5 layers: +// - Q: from noise_embed only (attn_norm, wq, q_norm, RoPE-NEOX) +// - K/V: from concat(target_feat, noise), then wk/wv, k_norm, RoPE-NEOX +// - Non-causal FlashAttention (mask=nullptr), GQA 32:8 +// - SwiGLU FFN +// 3. out_norm + shared lm_head → logits [vocab, block_size] + +#include "models.h" +#include "llama-impl.h" // LLAMA_TENSOR_NAME_FATTN +#include "llama-graph.h" // llm_graph_input_target_feat, build_inp_target_feat + +#include + +llm_build_dflash_draft::llm_build_dflash_draft( + const llama_model & model, + const llm_graph_params & params) : llm_graph_context(params) { + + const int64_t n_embd_head = hparams.n_embd_head_k(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_v()); + + // draft constants derived from hparams + const int64_t n_embd_fc = (int64_t)5 * n_embd; // 5*hidden for fc input + + // rope_theta = 10M for draft (matches DFLASH27B_ROPE_THETA) + const float draft_rope_theta = 10000000.0f; + const float scale = 1.0f / sqrtf((float)n_embd_head); + + // ── Draft-specific target-feature input ─────────────────────────────────── + // target_feat_raw / pos_q / pos_k: registered as graph inputs via build_inp_target_feat. + // The host stashes data with llama_set_target_feat_raw() before llama_decode(); the + // graph input class copies it into these GGML tensors at set_input() time. + // + // ctx_len is read directly from the pending context length stashed by the driver before + // llama_decode(). At reservation time (graph preheating), the pointer may be null/zero, + // in which case we fall back to n_ctx as the worst-case upper bound. + const int64_t ctx_len = (pending_target_feat_ctx_len_ptr && *pending_target_feat_ctx_len_ptr > 0) + ? *pending_target_feat_ctx_len_ptr + : n_ctx; + ggml_tensor * cached_target_feat = (pending_target_feat_tensor_ptr != nullptr) + ? *pending_target_feat_tensor_ptr + : nullptr; + const bool use_cached_target_feat = cached_target_feat != nullptr; + const int64_t target_feat_width = dflash_target_feat_fused ? 
n_embd : n_embd_fc; + llm_graph_input_target_feat * inp_tf = build_inp_target_feat(target_feat_width, ctx_len); + + ggml_tensor * target_feat_in = use_cached_target_feat ? cached_target_feat : inp_tf->inp_target_feat_raw; + ggml_tensor * pos_q = inp_tf->inp_pos_q; + ggml_tensor * pos_k = inp_tf->inp_pos_k; + + // ── Step 1: feature fusion ──────────────────────────────────────────────── + // target_feat = rms_norm(fc @ target_feat_raw, hidden_norm) + // The dedicated draft runtime can pass a cached fused target_feat directly. + ggml_tensor * target_feat = target_feat_in; + if (!dflash_target_feat_fused) { + if (use_cached_target_feat && target_feat_in->ne[0] == n_embd && target_feat_in->ne[1] == 5*ctx_len) { + ggml_tensor * packed = nullptr; + for (int l = 0; l < 5; ++l) { + ggml_tensor * layer = ggml_view_2d(ctx0, target_feat_in, n_embd, ctx_len, + (size_t)n_embd * ggml_element_size(target_feat_in), + (size_t)l * ctx_len * n_embd * ggml_element_size(target_feat_in)); + packed = packed == nullptr ? layer : ggml_concat(ctx0, packed, layer, 0); + } + target_feat_in = packed; + } + // fc: [n_embd_fc, n_embd] (ggml: ne[0]=n_embd_fc, ne[1]=n_embd) + // target_feat_raw: [n_embd_fc, ctx_len] + // Result: [n_embd, ctx_len] + target_feat = ggml_mul_mat(ctx0, model.dflash_fc, target_feat_in); + cb(target_feat, "dflash_fc_out", -1); + + target_feat = ggml_rms_norm(ctx0, target_feat, hparams.f_norm_rms_eps); + target_feat = ggml_mul(ctx0, target_feat, model.dflash_hidden_norm); + } + GGML_ASSERT(target_feat->ne[0] == n_embd); + GGML_ASSERT(target_feat->ne[1] == ctx_len); + cb(target_feat, "dflash_target_feat", -1); + + if (dflash_fuse_only && !dflash_kv_update_only) { + res->t_embd = target_feat; + ggml_build_forward_expand(gf, target_feat); + return; + } + + if (dflash_kv_update_only) { + GGML_ASSERT(dflash_kv_cache_k_l != nullptr); + GGML_ASSERT(dflash_kv_cache_v_l != nullptr); + GGML_ASSERT((int64_t) dflash_kv_cache_k_l->size() >= n_layer); + GGML_ASSERT((int64_t) dflash_kv_cache_v_l->size() >= n_layer); + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers[il]; + + ggml_tensor * K = ggml_mul_mat(ctx0, layer.wk, target_feat); + K = ggml_reshape_3d(ctx0, K, n_embd_head, n_head_kv, ctx_len); + K = ggml_rms_norm(ctx0, K, hparams.f_norm_rms_eps); + K = ggml_mul(ctx0, K, layer.attn_k_norm); + cb(K, "dflash_k_cache_update", il); + + ggml_tensor * V = ggml_mul_mat(ctx0, layer.wv, target_feat); + V = ggml_reshape_3d(ctx0, V, n_embd_head, n_head_kv, ctx_len); + cb(V, "dflash_v_cache_update", il); + + ggml_tensor * Kdst = (*dflash_kv_cache_k_l)[il]; + ggml_tensor * Vdst = (*dflash_kv_cache_v_l)[il]; + GGML_ASSERT(Kdst != nullptr && Vdst != nullptr); + const int64_t n_el = n_embd_head * n_head_kv * ctx_len; + const size_t dst_off = (size_t) dflash_kv_cache_dst_pos * n_embd_head * n_head_kv * sizeof(float); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, K, n_el, 0), + ggml_view_1d(ctx0, Kdst, n_el, dst_off))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, V, n_el, 0), + ggml_view_1d(ctx0, Vdst, n_el, dst_off))); + } + res->t_embd = target_feat; + return; + } + + // noise_embed: pre-computed embedding rows [n_embd, block_size] — host fills this + // through ubatch.embd. It must be registered as a graph input; merely calling + // ggml_set_input() is not enough for llama_decode() to populate it. 
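+    // Host-side sketch (illustrative; the real code lives in the spec-decode
+    // driver): fill batch.embd with block_size rows of n_embd floats, each
+    // looked up via llama_model_token_embd_lookup() on the target model using
+    // hparams.dflash_mask_token_id, then call llama_decode();
+    // llm_graph_input_embd::set_input() copies ubatch.embd into the tensor below.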
+ auto inp_noise = std::make_unique(n_embd); + inp_noise->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + ggml_set_name(inp_noise->embd, "dflash_noise_embed"); + ggml_set_input(inp_noise->embd); + ggml_tensor * noise_embed = inp_noise->embd; + res->add_input(std::move(inp_noise)); + + // ── Step 2: position tensors ────────────────────────────────────────────── + // Q positions: [ctx_len .. ctx_len + block_size) in draft-window-local + // coordinates, matching standalone DFlash. + // K positions: [0 .. ctx_len + block_size) + // Both tensors were created and registered by build_inp_target_feat() above. + // set_input() fills them from pending_draft_committed_pos before each decode. + const int64_t total_k = ctx_len + n_tokens; + + // ── Step 3: 5-layer decoder ─────────────────────────────────────────────── + ggml_tensor * h = noise_embed; // [n_embd, block_size] + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers[il]; + + // -- Attention pre-norm on noise hidden state + ggml_tensor * hn = ggml_rms_norm(ctx0, h, hparams.f_norm_rms_eps); + hn = ggml_mul(ctx0, hn, layer.attn_norm); // layer.attn_norm: [n_embd] + cb(hn, "attn_norm", il); + + // -- Q from noise only: wq [n_embd, n_head*n_embd_head], reshaped, q_norm, RoPE + ggml_tensor * Q = ggml_mul_mat(ctx0, layer.wq, hn); // [n_head*n_embd_head, block_size] + Q = ggml_reshape_3d(ctx0, Q, n_embd_head, n_head, n_tokens); // [n_embd_head, n_head, block_size] + Q = ggml_rms_norm(ctx0, Q, hparams.f_norm_rms_eps); // per-head rms_norm along n_embd_head + Q = ggml_mul(ctx0, Q, layer.attn_q_norm); // broadcast [n_embd_head] + cb(Q, "Q_normed", il); + + // Q RoPE-NEOX + Q = ggml_rope_ext(ctx0, Q, pos_q, nullptr, + (int)n_embd_head, + (int)LLAMA_ROPE_TYPE_NEOX, + /*n_ctx_orig=*/0, + draft_rope_theta, + /*freq_scale=*/1.0f, + /*ext_factor=*/0.0f, + /*attn_factor=*/1.0f, + /*beta_fast=*/0.0f, + /*beta_slow=*/0.0f); + cb(Q, "Q_rope", il); + + // -- K and V from concat(target_feat, noise) + ggml_tensor * K = nullptr; + ggml_tensor * V = nullptr; + const bool use_cached_kv = dflash_kv_cache_k_l != nullptr && dflash_kv_cache_v_l != nullptr && + (int64_t) dflash_kv_cache_k_l->size() > il && + (int64_t) dflash_kv_cache_v_l->size() > il && + (*dflash_kv_cache_k_l)[il] != nullptr && + (*dflash_kv_cache_v_l)[il] != nullptr; + if (use_cached_kv) { + ggml_tensor * Kctx = (*dflash_kv_cache_k_l)[il]; + ggml_tensor * Vctx = (*dflash_kv_cache_v_l)[il]; + GGML_ASSERT(Kctx->ne[0] == n_embd_head && Kctx->ne[1] == n_head_kv && Kctx->ne[2] == ctx_len); + GGML_ASSERT(Vctx->ne[0] == n_embd_head && Vctx->ne[1] == n_head_kv && Vctx->ne[2] == ctx_len); + + ggml_tensor * Kn = ggml_mul_mat(ctx0, layer.wk, hn); + Kn = ggml_reshape_3d(ctx0, Kn, n_embd_head, n_head_kv, n_tokens); + Kn = ggml_rms_norm(ctx0, Kn, hparams.f_norm_rms_eps); + Kn = ggml_mul(ctx0, Kn, layer.attn_k_norm); + + ggml_tensor * Vn = ggml_mul_mat(ctx0, layer.wv, hn); + Vn = ggml_reshape_3d(ctx0, Vn, n_embd_head, n_head_kv, n_tokens); + + K = ggml_concat(ctx0, Kctx, Kn, 2); + V = ggml_concat(ctx0, Vctx, Vn, 2); + cb(K, "K_normed", il); + } else { + // First compute K/V from target_feat (ctx_len tokens) + ggml_tensor * Kctx = ggml_mul_mat(ctx0, layer.wk, target_feat); // [n_head_kv*n_embd_head, ctx_len] + ggml_tensor * Vctx = ggml_mul_mat(ctx0, layer.wv, target_feat); + + // Then from noise (block_size tokens) + ggml_tensor * Kn = ggml_mul_mat(ctx0, layer.wk, hn); // [n_head_kv*n_embd_head, block_size] + ggml_tensor * Vn = ggml_mul_mat(ctx0, layer.wv, hn); + + // 
Concat along sequence dimension (ne[1]) + K = ggml_concat(ctx0, Kctx, Kn, 1); // [n_head_kv*n_embd_head, total_k] + V = ggml_concat(ctx0, Vctx, Vn, 1); + + // Per-head K norm + K = ggml_reshape_3d(ctx0, K, n_embd_head, n_head_kv, total_k); + K = ggml_rms_norm(ctx0, K, hparams.f_norm_rms_eps); + K = ggml_mul(ctx0, K, layer.attn_k_norm); + cb(K, "K_normed", il); + + V = ggml_reshape_3d(ctx0, V, n_embd_head, n_head_kv, total_k); + } + + // K RoPE-NEOX + K = ggml_rope_ext(ctx0, K, pos_k, nullptr, + (int)n_embd_head, + (int)LLAMA_ROPE_TYPE_NEOX, + 0, + draft_rope_theta, + 1.0f, 0.0f, 1.0f, 0.0f, 0.0f); + cb(K, "K_rope", il); + + // Permute into flash_attn_ext layout + // Q: [n_embd_head, n_head, block_size, 1] + // K: [n_embd_head, n_head_kv, total_k, 1] + // V: [n_embd_head, n_head_kv, total_k, 1] (not transposed) + Q = ggml_permute(ctx0, Q, 0, 2, 1, 3); + Q = ggml_cont(ctx0, Q); + K = ggml_permute(ctx0, K, 0, 2, 1, 3); + K = ggml_cont(ctx0, K); + V = ggml_permute(ctx0, V, 0, 2, 1, 3); + V = ggml_cont(ctx0, V); + + // Non-causal flash attention; mask=nullptr, GQA broadcast handled internally. + ggml_tensor * attn = ggml_flash_attn_ext(ctx0, Q, K, V, + /*mask=*/nullptr, + scale, + /*max_bias=*/0.0f, + /*logit_softcap=*/0.0f); + // Name the FA tensor so sched_reserve's auto_fa name-prefix assert passes. + cb(attn, LLAMA_TENSOR_NAME_FATTN, il); + // attn: [n_embd_head, n_head, block_size, 1] + attn = ggml_reshape_2d(ctx0, attn, n_embd_head * n_head, n_tokens); + cb(attn, "attn_out", il); + + // Output projection + residual + ggml_tensor * attn_proj = ggml_mul_mat(ctx0, layer.wo, attn); + h = ggml_add(ctx0, h, attn_proj); + cb(h, "attn_residual", il); + + // -- FFN pre-norm + ggml_tensor * hf = ggml_rms_norm(ctx0, h, hparams.f_norm_rms_eps); + hf = ggml_mul(ctx0, hf, layer.ffn_norm); + cb(hf, "ffn_norm", il); + + // SwiGLU: down(silu(gate(x)) * up(x)) + ggml_tensor * g = ggml_mul_mat(ctx0, layer.ffn_gate, hf); + g = ggml_silu(ctx0, g); + ggml_tensor * u = ggml_mul_mat(ctx0, layer.ffn_up, hf); + ggml_tensor * gu = ggml_mul(ctx0, g, u); + ggml_tensor * ffn_out = ggml_mul_mat(ctx0, layer.ffn_down, gu); + cb(ffn_out, "ffn_out", il); + + h = ggml_add(ctx0, h, ffn_out); + cb(h, "l_out", il); + } + + // ── Step 4: final norm + lm_head ───────────────────────────────────────── + ggml_tensor * out = ggml_rms_norm(ctx0, h, hparams.f_norm_rms_eps); + out = ggml_mul(ctx0, out, model.output_norm); // model.output_norm == dflash out_norm + cb(out, "result_norm", -1); + res->t_embd = out; + + // lm_head is shared from target model — it must be provided via model.output. + // If not yet wired (Phase 3 test mode), output is the hidden state only. 
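+    // Naming note (clarification, not a behavior change): with top-k enabled,
+    // the branch below stores full-vocab softmax log-probs, not raw logits,
+    // in t_dflash_top_logits; this matches the contract documented on
+    // llm_graph_result in llama-graph.h.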
+    if (model.output != nullptr) {
+        ggml_tensor * logits = ggml_mul_mat(ctx0, model.output, out);
+        cb(logits, "result_output", -1);
+
+        if (dflash_draft_top_k > 0) {
+            const int top_k = (int) std::min((int64_t) dflash_draft_top_k, logits->ne[0]);
+
+            ggml_tensor * top_ids = ggml_top_k(ctx0, logits, top_k);
+            cb(top_ids, "dflash_top_ids", -1);
+
+            ggml_tensor * logits_rows = ggml_reshape_3d(ctx0, logits, 1, logits->ne[0], n_tokens);
+            ggml_tensor * top_logits = ggml_get_rows(ctx0, logits_rows, top_ids);
+            top_logits = ggml_reshape_2d(ctx0, top_logits, top_k, n_tokens);
+            cb(top_logits, "dflash_top_logits", -1);
+
+            ggml_tensor * probs = ggml_soft_max(ctx0, logits);
+            cb(probs, "dflash_probs", -1);
+            ggml_tensor * probs_rows = ggml_reshape_3d(ctx0, probs, 1, probs->ne[0], n_tokens);
+            ggml_tensor * top_log_probs = ggml_get_rows(ctx0, probs_rows, top_ids);
+            top_log_probs = ggml_log(ctx0, top_log_probs);
+            top_log_probs = ggml_reshape_2d(ctx0, top_log_probs, top_k, n_tokens);
+            cb(top_log_probs, "dflash_top_log_probs", -1);
+
+            res->t_dflash_top_ids    = top_ids;
+            res->t_dflash_top_logits = top_log_probs;
+            ggml_build_forward_expand(gf, top_ids);
+            ggml_build_forward_expand(gf, top_log_probs);
+        } else {
+            res->t_logits = logits;
+            ggml_build_forward_expand(gf, logits);
+        }
+    } else {
+        ggml_build_forward_expand(gf, out);
+    }
+}
diff --git a/src/models/models.h b/src/models/models.h
index a6682ebb287..e2d4e2cc630 100644
--- a/src/models/models.h
+++ b/src/models/models.h
@@ -56,6 +56,7 @@ struct llm_build_delta_net_base : public llm_graph_context {
         int il);

     // choose one of two implementations above based on the number of tokens
+    // if parent_ids != nullptr, dispatches to build_delta_net_tree
     std::pair<ggml_tensor *, ggml_tensor *> build_delta_net(
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,
         ggml_tensor * g,
         ggml_tensor * b,
         ggml_tensor * s,
         int il);
+
+    // tree-mode variant: uses ggml_ssm_conv_tree + ggml_gated_delta_net_tree
+    // When persist_inter != nullptr, calls ggml_gated_delta_net_tree_persist to write
+    // per-token intermediate states to the provided external buffer.
+    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_tree(
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * b,
+        ggml_tensor * s,
+        ggml_tensor * par_ids,       // [n_tokens] i32
+        ggml_tensor * persist_inter, // optional, may be null
+        int il);
 };

 struct llm_build_rwkv6_base : public llm_graph_context {
@@ -182,6 +197,10 @@ struct llm_build_dbrx : public llm_graph_context {
     llm_build_dbrx(const llama_model & model, const llm_graph_params & params);
 };

+struct llm_build_dflash_draft : public llm_graph_context {
+    llm_build_dflash_draft(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_deci : public llm_graph_context {
     llm_build_deci(const llama_model & model, const llm_graph_params & params);
 };
diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp
index 28df353050b..6039b358c89 100644
--- a/src/models/qwen35.cpp
+++ b/src/models/qwen35.cpp
@@ -14,10 +14,34 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa
     ggml_tensor * cur;
     ggml_tensor * inpL;

+    // dflash hidden capture: per-slot tensors collected during the forward pass.
+    // Slots are ggml_concat'd into a single [n_embd, 5*n_tokens] tensor AFTER
+    // the layer loop and registered as t_hidden_capture (OUTPUT).
This ensures + // the concat node is a real compute graph output that gallocr / the sched + // execute and sync back to the host — avoiding the INPUT-leaf + cpy-to-view + // pattern which silently produces all-zeros on GPU (cpy dst is CPU-pinned + // while the src lives on the device backend). + ggml_tensor * cap_slots[5] = {nullptr, nullptr, nullptr, nullptr, nullptr}; + inpL = build_inp_embd(model.tok_embd); cb(inpL, "model.input_embed", -1); + // build tree-mode inputs when parent_ids are present in the ubatch. + // LLAMA_DDTREE_FORCE_CHAIN_KERNEL=1 skips the tree input wiring; downstream + // conv/delta-net dispatch then falls back to the chain kernel (parent_ids + // member stays null). Diagnostic only — sibling/cousin tokens are wrong, + // root token stays equivalent to chain. + if (ubatch.parent_id != nullptr) { + static const bool s_ddtree_force_chain_kernel = []{ + const char * e = getenv("LLAMA_DDTREE_FORCE_CHAIN_KERNEL"); + return e && e[0] == '1'; + }(); + if (!s_ddtree_force_chain_kernel) { + build_inp_tree(); + } + } + auto * inp = build_inp_mem_hybrid(); ggml_tensor * inp_pos = build_inp_pos(); @@ -64,6 +88,21 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa cur = ggml_add(ctx0, cur, ffn_residual); cb(cur, "post_ffn", il); + // dflash hidden capture: stash cur in cap_slots[k] for later concat. + // Critical invariant: this block is NOT entered when capture_hidden==false, + // so the baseline qwen35 forward is byte-for-byte unchanged. + if (capture_hidden) { + const auto & cl = hparams.dflash_target_capture_layers; + for (int k = 0; k < 5; ++k) { + if ((int)cl[k] == il) { + // ggml_cont ensures the slot is a standalone contiguous node + // (cur may be a non-owning view after certain ops). + cap_slots[k] = ggml_cont(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens)); + break; + } + } + } + cur = build_cvec(cur, il); cb(cur, "l_out", il); @@ -72,6 +111,24 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa } cur = inpL; + // dflash hidden capture: concat the 5 collected slots along dim 1 into + // [n_embd, 5*n_tokens] and register as t_hidden_capture (OUTPUT). + // The concat result is a regular compute node — gallocr schedules it on the + // device backend and the sched syncs it to host after the forward pass. 
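+    // Layout note: capture slot k occupies columns [k*n_tokens, (k+1)*n_tokens)
+    // of the concat below. The draft-side feature fusion (dflash-draft.cpp)
+    // repacks this [n_embd, 5*n_tokens] tensor into the [5*n_embd, ctx_len]
+    // shape expected by the fc projection.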
+    if (capture_hidden) {
+        for (int k = 0; k < 5; ++k) {
+            GGML_ASSERT(cap_slots[k] != nullptr &&
+                        "dflash_target_capture_layers must cover all 5 slots; check hparams");
+        }
+        ggml_tensor * cap = ggml_concat(ctx0, cap_slots[0], cap_slots[1], 1);
+        cap = ggml_concat(ctx0, cap, cap_slots[2], 1);
+        cap = ggml_concat(ctx0, cap, cap_slots[3], 1);
+        cap = ggml_concat(ctx0, cap, cap_slots[4], 1);
+        ggml_set_name(cap, "dflash_hidden_capture");
+        ggml_build_forward_expand(gf, cap);
+        res->t_hidden_capture = cap;
+    }
+
     // Final norm
     cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
@@ -82,9 +139,35 @@
     cur = build_lora_mm(model.output, cur);
     cb(cur, "result_output", -1);

-    res->t_logits = cur;
-    ggml_build_forward_expand(gf, cur);
+    if (dflash_draft_top_k > 0) {
+        const int top_k = (int) std::min((int64_t) dflash_draft_top_k, cur->ne[0]);
+
+        if (top_k == 1) {
+            ggml_tensor * top_ids = ggml_argmax(ctx0, cur);
+            top_ids = ggml_reshape_2d(ctx0, top_ids, 1, cur->ne[1]);
+            cb(top_ids, "dflash_target_argmax_ids", -1);
+
+            res->t_dflash_top_ids = top_ids;
+            ggml_build_forward_expand(gf, top_ids);
+        } else {
+            ggml_tensor * top_ids = ggml_top_k(ctx0, cur, top_k);
+            cb(top_ids, "dflash_target_top_ids", -1);
+
+            ggml_tensor * logits_rows = ggml_reshape_3d(ctx0, cur, 1, cur->ne[0], cur->ne[1]);
+            ggml_tensor * top_logits = ggml_get_rows(ctx0, logits_rows, top_ids);
+            top_logits = ggml_reshape_2d(ctx0, top_logits, top_k, cur->ne[1]);
+            cb(top_logits, "dflash_target_top_logits", -1);
+
+            res->t_dflash_top_ids    = top_ids;
+            res->t_dflash_top_logits = top_logits;
+            ggml_build_forward_expand(gf, top_ids);
+            ggml_build_forward_expand(gf, top_logits);
+        }
+    } else {
+        res->t_logits = cur;
+        ggml_build_forward_expand(gf, cur);
+    }
+}

 std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35::build_qkvz(
@@ -253,6 +336,29 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
     const int64_t conv_kernel_size = conv_kernel->ne[0];
     const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;

+    ggml_tensor * conv_persist = nullptr;
+    if (parent_ids != nullptr && dflash_persist_conv_l != nullptr &&
+        il >= 0 && il < (int32_t) dflash_persist_conv_l->size()) {
+        conv_persist = (*dflash_persist_conv_l)[il];
+    }
+
+    ggml_tensor * ssm_persist = nullptr;
+    if (parent_ids != nullptr && dflash_persist_inter_l != nullptr &&
+        il >= 0 && il < (int32_t) dflash_persist_inter_l->size()) {
+        ssm_persist = (*dflash_persist_inter_l)[il];
+    }
+
+    static const bool s_skip_tree_live_updates = []{
+        const char * e = getenv("LLAMA_DDTREE_SKIP_TREE_LIVE_UPDATES");
+        return e == nullptr || e[0] != '0';
+    }();
+    const bool skip_tree_live_updates =
+        s_skip_tree_live_updates &&
+        parent_ids != nullptr &&
+        n_seq_tokens > 1 &&
+        conv_persist != nullptr &&
+        ssm_persist != nullptr;
+
     conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
     cb(conv_states, "conv_states_reshaped", il);
@@ -274,13 +380,25 @@
         kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
     cb(state_update_target, "state_update_target", il);

-    ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
+    if (!skip_tree_live_updates) {
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
+    }

     ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
     state = ggml_reshape_4d(ctx0, state, head_v_dim,
head_v_dim, num_v_heads, n_seqs); cb(state, "state_predelta", il); - ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); + // use tree conv when parent_ids are set; identical output shape to ggml_ssm_conv. + // Phase 5 fix: when a per-layer conv-persist buffer is allocated, use the + // _persist variant so each token writes its post-state for SSM rollback. + ggml_tensor * conv_output_proper; + if (parent_ids != nullptr) { + conv_output_proper = (conv_persist != nullptr) + ? ggml_ssm_conv_tree_persist(ctx0, conv_input, conv_kernel, parent_ids, conv_persist) + : ggml_ssm_conv_tree (ctx0, conv_input, conv_kernel, parent_ids); + } else { + conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); + } cb(conv_output_proper, "conv_output_raw", il); ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper); @@ -344,10 +462,12 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( cb(new_state, "new_state", il); // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + if (!skip_tree_live_updates) { + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, new_state, + ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + } // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); diff --git a/tests/.gitignore b/tests/.gitignore index 52b292b1f87..aea409dd1b9 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -1,6 +1,8 @@ * !*.* !snapshots/ +!fixtures/ +!fixtures/** *.o ggml-common.h **/*.swp diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index cd4bc5ef1d3..0e3dcd61c72 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -210,6 +210,8 @@ llama_build_and_test( peg-parser/tests.h ) llama_build_and_test(test-regex-partial.cpp) +llama_build_and_test(test-speculative-tree.cpp) +llama_build_and_test(test-speculative-draft-backend.cpp) if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") set(MODEL_NAME "tinyllamas/stories15M-q4_0.gguf") @@ -298,3 +300,59 @@ if (TARGET gguf-model-data) target_link_libraries(export-graph-ops PRIVATE gguf-model-data) target_compile_definitions(export-graph-ops PRIVATE LLAMA_HF_FETCH) endif() + +# DDTree Phase 1 acceptance test. +# Requires a 16+ GB Qwen3.5-27B GGUF not available in CI. +# Enable manually: cmake -DLLAMA_BUILD_TESTS_QWEN35_TREE=ON ... +option(LLAMA_BUILD_TESTS_QWEN35_TREE "Build DDTree tree-mode acceptance test (requires large GGUF)" OFF) +if (LLAMA_BUILD_TESTS_QWEN35_TREE) + llama_build(test-qwen35-tree.cpp) + # Note: NOT registered with llama_test / add_test — must be run manually. +endif() + +# DDTree diagnostic: chain forward vs single tree-mode root forward. +# Used to isolate whether the long-prompt ddtree bug is in the tree kernels +# or above (driver / KV slot / capture). Same large GGUF requirement. +option(LLAMA_BUILD_TESTS_QWEN35_ROOT_VS_CHAIN "Build DDTree root-vs-chain diagnostic (requires large GGUF)" OFF) +if (LLAMA_BUILD_TESTS_QWEN35_ROOT_VS_CHAIN) + llama_build(test-qwen35-root-vs-chain.cpp) + # Manual run only. +endif() + +# DDTree Phase 2 acceptance test: snapshot/restore symmetry. +# Requires a 16+ GB Qwen3.5-27B GGUF not available in CI. 
+# Requires the llama_seq_snapshot / llama_seq_restore / llama_seq_release API +# delivered by the Phase 2 implementation agent. +# Enable manually: cmake -DLLAMA_BUILD_TESTS_QWEN35_TREE_ROLLBACK=ON ... +option(LLAMA_BUILD_TESTS_QWEN35_TREE_ROLLBACK "Build DDTree snapshot/restore acceptance test (requires large GGUF)" OFF) +if (LLAMA_BUILD_TESTS_QWEN35_TREE_ROLLBACK) + llama_build(test-qwen35-tree-rollback.cpp) + # Note: NOT registered with llama_test / add_test — must be run manually. +endif() + +# DDTree Phase 3 acceptance tests: dflash-draft forward + hidden-state capture. +# Requires a 16+ GB Qwen3.5-27B GGUF and a converted dflash-draft GGUF. +# Requires Phase 3 API from the implementation agent: +# llama_model_token_embd_lookup, llama_set_capture_hidden, llama_get_hidden_capture. +# Enable manually: cmake -DLLAMA_BUILD_TESTS_DFLASH_DRAFT=ON ... +option(LLAMA_BUILD_TESTS_DFLASH_DRAFT "Build DDTree Phase 3 dflash-draft and chain-capture acceptance tests (requires large GGUFs)" OFF) +if (LLAMA_BUILD_TESTS_DFLASH_DRAFT) + llama_build(test-dflash-draft.cpp) + llama_build(test-qwen35-chain-capture.cpp) + # Note: NOT registered with llama_test / add_test — must be run manually. +endif() + +# DDTree Phase 4 end-to-end speculative decode acceptance test. +# Requires Qwen3.5-27B GGUF (~16 GB) and dflash-draft GGUF. Not available in CI. +# Requires Phase 4 API from the implementation agent: +# llama_speculative_tree_driver_init / _step / _free (common/speculative-tree-driver.h) +# llama_set_target_feat_raw (Phase 3 gap, llama.h) +# Enable manually: cmake -DLLAMA_BUILD_TESTS_SPECULATIVE_TREE_E2E=ON ... +option(LLAMA_BUILD_TESTS_SPECULATIVE_TREE_E2E + "Build DDTree Phase 4 end-to-end speculative decode acceptance test (requires large GGUFs)" + OFF) +if (LLAMA_BUILD_TESTS_SPECULATIVE_TREE_E2E) + llama_build(test-speculative-tree-e2e.cpp) + # Note: NOT registered with llama_test / add_test — must be run manually. + # Both --target-model and --draft-model are required at runtime. +endif() diff --git a/tests/fixtures/ddtree/README.md b/tests/fixtures/ddtree/README.md new file mode 100644 index 00000000000..de6a84d11a7 --- /dev/null +++ b/tests/fixtures/ddtree/README.md @@ -0,0 +1,332 @@ +# DDTree Phase 1 test fixtures + +## Files + +- `tree_5node.json` — 5-node tree fixture (1 root + 4 verify nodes). Used by `test-qwen35-tree --mode tree`. Token IDs are mid-range Qwen3.5 placeholders; regenerate from real sampled tokens as needed. + +- `short_prompt.bin` — 16 int32 LE token IDs for chain-mode warm-up (Test 1.A). Hardcoded Qwen3.5 BOS/system-prompt tokens. Regenerate with `make_short_prompt.py`. + +- `make_short_prompt.py` — standalone script; reads a text string and writes int32 LE tokens to `short_prompt.bin`. Requires `llama-cpp-python`. Falls back to the hardcoded 16-token fixture if the package is absent. + +## Build + +The test binary is gated behind a CMake option (off by default, not in ctest): + +``` +cmake -DLLAMA_BUILD_TESTS_QWEN35_TREE=ON .. +make test-qwen35-tree +``` + +Requires a Qwen3.5-27B GGUF (~16 GB). Not available in CI. 
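+
+The fixture layout is simple enough to inspect by hand. For reference, a
+minimal standalone reader for `short_prompt.bin` (a bare array of int32 LE
+token IDs, no header) might look like the sketch below; only the file layout
+is taken from the fixture description above, everything else is illustrative:
+
+```cpp
+// Illustrative reader for short_prompt.bin: raw int32 LE token IDs, no header.
+#include <cstdint>
+#include <cstdio>
+#include <fstream>
+#include <vector>
+
+int main(int argc, char ** argv) {
+    const char * path = argc > 1 ? argv[1] : "short_prompt.bin";
+    std::ifstream f(path, std::ios::binary | std::ios::ate);
+    if (!f) { fprintf(stderr, "cannot open %s\n", path); return 1; }
+    const std::streamsize sz = f.tellg();
+    if (sz % sizeof(int32_t) != 0) { fprintf(stderr, "size not a multiple of 4\n"); return 1; }
+    std::vector<int32_t> tokens(sz / sizeof(int32_t));
+    f.seekg(0);
+    f.read(reinterpret_cast<char *>(tokens.data()), sz);
+    for (int32_t t : tokens) printf("%d\n", t);
+    return 0;
+}
+```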
+ +## Running (castle only) + +Test 1.A — chain mode does not regress vs current fork master: + +```bash +./test-qwen35-tree --mode chain \ + --model /path/to/Qwen3.5-27B-Q4_K_M.gguf \ + --prompt-tokens fixtures/ddtree/short_prompt.bin \ + --out-logits /tmp/chain.bin + +# Compare against a golden dump produced by an unmodified fork build: +python3 scripts/compare_logits.py /tmp/chain_golden.bin /tmp/chain.bin +``` + +Test 1.B — tree mode aligns with test_dflash (blocked, see below): + +```bash +./test-qwen35-tree --mode tree \ + --model /path/to/Qwen3.5-27B-Q4_K_M.gguf \ + --tree-fixture fixtures/ddtree/tree_5node.json \ + --out-logits /tmp/tree.bin + +python3 scripts/compare_logits.py /tmp/test_dflash_golden.bin /tmp/tree.bin +``` + +## Blocker: Test 1.B prerequisite + +Test 1.B requires `test_dflash` to support `--dump-verify-logits`, which dumps the per-node logits produced by the dflash tree forward. This flag is a Phase 0 prerequisite listed in roadmap section 7.2 and is not yet implemented. Until it lands, Test 1.B cannot produce a golden reference and cannot be run end-to-end. + +--- + +# DDTree Phase 2 test fixtures + +## Build + +Gated behind a CMake option (off by default, not in ctest): + +``` +cmake -DLLAMA_BUILD_TESTS_QWEN35_TREE_ROLLBACK=ON .. +make test-qwen35-tree-rollback +``` + +Requires Phase 2 `llama_seq_snapshot` / `llama_seq_restore` / `llama_seq_release` API from the implementation agent. + +## Test 2.A — snapshot/restore symmetry + +```bash +./test-qwen35-tree-rollback \ + --model /path/to/Qwen3.5-27B-Q4_K_M.gguf \ + --prompt-tokens fixtures/ddtree/short_prompt.bin \ + --gen 8 \ + --out-logits-pre /tmp/pre.bin \ + --out-logits-post /tmp/post.bin + +# Both logit dumps must be bit-equal: +python3 scripts/compare_logits.py /tmp/pre.bin /tmp/post.bin --abs-tol 0 --rel-tol 0 +``` + +## Test 2.B — BLOCKED + +Requires `test_dflash --dump-state-at-commit`, a Phase 0 prerequisite not yet implemented. Until that flag lands, Test 2.B (tree partial-accept vs sequential golden state) cannot be run. + +## Test 2.C — deferred + +Long-prompt OOM stress test deferred to Phase 5 server integration. + +--- + +# DDTree Phase 3 test fixtures + +## Files + +- `dflash_draft_metadata_smoke.json` — expected GGUF metadata fields for a + converted dflash-draft model. Used by `check_dflash_draft_gguf.py`. + +## Build + +Both Phase 3 test binaries are gated behind a single CMake option (off by +default, not in ctest): + +``` +cmake -DLLAMA_BUILD_TESTS_DFLASH_DRAFT=ON .. +make test-dflash-draft test-qwen35-chain-capture +``` + +Requires: +- A Qwen3.5-27B GGUF (~16 GB). +- A converted dflash-draft GGUF (see conversion step below). +- Phase 3 implementation API: `llama_model_token_embd_lookup`, + `llama_set_capture_hidden`, `llama_get_hidden_capture`. + +## Converting safetensors to dflash-draft GGUF + +The conversion script is written by the implementation agent in parallel. +Once it lands at `repo/dflash/scripts/convert_dflash_draft.py`: + +```bash +python repo/dflash/scripts/convert_dflash_draft.py \ + /path/to/dflash_draft/model.safetensors \ + -o /path/to/draft.gguf +``` + +Until the script lands this step is a TODO. + +## Validating the converted GGUF (Test 3.A) + +```bash +python repo/scripts/check_dflash_draft_gguf.py \ + /path/to/draft.gguf \ + tests/fixtures/ddtree/dflash_draft_metadata_smoke.json +# Exit 0: PASS. Exit 1: one line per discrepant field on stderr. 
+``` + +## Test 3.B — Draft forward bit-equal vs dflash reference + +BLOCKED on Phase 0 prerequisite: `test_dflash --dump-draft-output` flag is +not yet implemented. Until that flag lands, no golden reference exists and +the end-to-end comparison cannot be run. + +The driver (`test-dflash-draft`) can still be used standalone to inspect +draft logits: + +```bash +./test-dflash-draft \ + --target-model /path/to/Qwen3.5-27B-Q4_K_M.gguf \ + --draft-model /path/to/draft.gguf \ + --last-tok 12345 \ + --target-feat-bin /path/to/target_feat.bin \ + --out-logits /tmp/draft_logits.bin + +# Once the Phase 0 flag lands, compare against the dflash reference dump: +python3 scripts/compare_logits.py /tmp/dflash_draft_golden.bin /tmp/draft_logits.bin +``` + +## Test 3.C — Hidden capture does not break chain mode + +Run both capture and no-capture modes in a single invocation. The driver +asserts logits are bit-equal and that the capture buffer contains valid +(non-NaN, non-zero) values. + +```bash +./test-qwen35-chain-capture \ + --model /path/to/Qwen3.5-27B-Q4_K_M.gguf \ + --prompt-tokens fixtures/ddtree/short_prompt.bin \ + --out-logits /tmp/capture_logits.bin \ + --out-capture /tmp/capture_buf.bin +# Exit 0: both assertions passed. +``` + +Regression-only mode (skips Mode A, writes no-capture logits for external +comparison against a Phase 1 chain golden dump): + +```bash +./test-qwen35-chain-capture \ + --model /path/to/Qwen3.5-27B-Q4_K_M.gguf \ + --prompt-tokens fixtures/ddtree/short_prompt.bin \ + --out-logits /tmp/nocapture_logits.bin \ + --out-capture /dev/null \ + --no-capture + +# Compare against Phase 1 chain baseline: +python3 scripts/compare_logits.py /tmp/chain_golden.bin /tmp/nocapture_logits.bin --abs-tol 0 --rel-tol 0 +``` + +--- + +# DDTree Phase 4 test fixtures + +## Build + +The Phase 4 end-to-end test binary is gated behind its own CMake option (off +by default, not in ctest): + +``` +cmake -DLLAMA_BUILD_TESTS_SPECULATIVE_TREE_E2E=ON .. +make test-speculative-tree-e2e +``` + +Requires: +- A Qwen3.5-27B GGUF (~16 GB). +- A converted dflash-draft GGUF. +- Phase 4 implementation API: + `llama_speculative_tree_driver_init` / `_step` / `_free` + (`common/speculative-tree-driver.h`) and `llama_set_target_feat_raw` + (Phase 3 gap; `llama.h`). + +## Test 4.A — Spec-decode token trajectory matches chain reference + +This is the canonical Phase 4 acceptance test. With `--temp 0` (greedy), +DDTree speculative decoding is lossless: the target verifies each draft token +against its own argmax before accepting it. The resulting token sequence MUST +be bit-equal to a plain greedy chain decode from the same prompt. + +```bash +./test-speculative-tree-e2e \ + --target-model /path/to/Qwen3.5-27B-Q4_K_M.gguf \ + --draft-model /path/to/draft.gguf \ + --prompt-tokens fixtures/ddtree/short_prompt.bin \ + --gen 64 \ + --out-chain /tmp/chain.tokens \ + --out-spec /tmp/spec.tokens \ + --ddtree-budget 22 \ + --temp 0 + +# The driver prints: chain_n=X spec_n=Y first_divergence=none bytes_match=Z/Z +# Exit 0 = PASS. + +# Optional: offline comparison using the script: +python3 scripts/compare_tokens.py /tmp/chain.tokens /tmp/spec.tokens +# Exit 0 = all positions match AND n_a == n_b. +``` + +**Note**: `--temp 0` is required for the bit-equal guarantee. Non-zero +temperature introduces stochastic sampling, which makes the two sequences +non-deterministic relative to each other. With non-zero temp the comparison +is informational only (the driver does not assert bit-equality). 
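+
+For intuition, the greedy acceptance rule behind the bit-equal guarantee can
+be sketched as follows. This is an illustrative chain-shaped reduction with
+hypothetical names and signature; the real verifier walks a tree and lives in
+`common/speculative-tree.cpp` / `common/speculative-tree-driver.cpp`:
+
+```cpp
+#include <algorithm>
+#include <cstddef>
+#include <vector>
+
+// Accept draft tokens while each equals the target's own argmax at that
+// position; on the first mismatch, emit the target argmax instead and stop.
+// target_logits[i] is the target logit row that predicts draft[i].
+std::vector<int> verify_greedy(const std::vector<int> & draft,
+                               const std::vector<std::vector<float>> & target_logits) {
+    std::vector<int> out;
+    for (size_t i = 0; i < draft.size(); ++i) {
+        const auto & row = target_logits[i];
+        const int best = (int)(std::max_element(row.begin(), row.end()) - row.begin());
+        out.push_back(best);   // always the target's own choice...
+        if (best != draft[i]) {
+            break;             // ...so a mismatch becomes the correction token
+        }
+    }
+    return out;
+}
+```
+
+Because every emitted token is the target's own argmax, the output trajectory
+is identical to plain greedy decoding; speculation only changes how many
+tokens are produced per target forward pass.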
+ +## Test 4.B — BLOCKED + +Test 4.B (comparison of the spec-decode token sequence against the output of +the `test_dflash` daemon) is blocked on a Phase 0 prerequisite: the +`test_dflash` daemon mode interface (`--daemon` flag) is not yet implemented. +Until that flag lands, no golden `test_dflash` token stream can be produced +for comparison. + +The chain-reference comparison in Test 4.A gives strong independent functional +verification and is sufficient for Phase 4 sign-off. + +--- + +# DDTree Phase 5 test fixtures + +Phase 5 integrates the DDTree driver into `llama-server` as a selectable +speculative-decode mode (`--speculative-mode ddtree`). There are no new +binary test fixtures; validation is done via two shell scripts in the +super-repo `scripts/` directory that run against the server on Castle. + +## New CLI flags (impl agent deliverables) + +| Flag | Type | Default | Notes | +|------|------|---------|-------| +| `--speculative-mode {chain,ddtree}` | string | chain | selects speculative backend | +| `--ddtree-budget N` | int | 22 | max draft tokens per tree step | +| `--ddtree-temp F` | float | 0.0 | draft sampling temperature | +| `--ddtree-no-chain-seed` | bool flag | off | disable chain-seed warmup | + +These flags are parsed in `common/arg.cpp`. The HTTP API surface is +unchanged: same OpenAI-compatible `/v1/chat/completions` and +`/v1/messages` endpoints, SSE streaming, `tool_use`, and +`reasoning_content` all work identically to chain mode. + +Only `--parallel 1` (single slot) is supported in Phase 5. + +## Test 5.A — Smoke test (primary acceptance) + +Run from the local mac: + +```bash +# Default (port 8003, single prompt) +./repo/scripts/run_server_ddtree_castle.sh + +# Custom port and prompt +./repo/scripts/run_server_ddtree_castle.sh 8003 "Write a haiku." +``` + +What the script does: + +1. Verifies the `llama-server` binary exists on Castle. +2. Kills any leftover DDTree-mode server (idempotent). +3. Starts the server via `nohup` on Castle, logging to `/tmp/ddtree_server.log`. +4. Polls `/health` up to 60 s (2 s interval). +5. Sends one non-streaming `POST /v1/chat/completions` and validates + `choices[0].message.content` is non-empty. +6. Sends one streaming request and confirms SSE `data:` lines arrive. +7. Prints the last 50 lines of the server log. +8. Stops the server. +9. Exits 0 (SMOKE PASS) or non-zero (SMOKE FAIL). + +## Test 5.B — Mode comparison (optional / informational) + +```bash +./repo/scripts/compare_server_modes_castle.sh +# or with a custom prompt: +./repo/scripts/compare_server_modes_castle.sh "Describe the sky in exactly 32 tokens." +``` + +Starts both a chain-mode server (port 8001) and a DDTree-mode server +(port 8003) on Castle, sends the same greedy (`temperature: 0`, +`max_tokens: 32`) prompt to each, and reports the first word-level +divergence index. + +**Expected outcome**: divergence at word index >= 17. This matches the +Phase 4 finding that chain and DDTree outputs are bit-equal up to +approximately 17 tokens per speculative-step boundary, then diverge due +to KV-cache / conversation-state differences in the server slot state +machine. Divergence at or above that threshold is not a regression. +Early divergence (word index < 17) warrants investigation. + +The script always exits 0; the comparison is informational. + +## Phase 5 acceptance criteria + +Phase 5 acceptance is **smoke level only**: + +- `run_server_ddtree_castle.sh` exits 0 (SMOKE PASS). +- Non-streaming completion returns valid JSON with non-empty content. 
+- SSE streaming delivers at least one `data:` chunk.
+
+Full production replacement of `dflash/scripts/server.py` (pointing
+Claude Code at `http://castle.local:8002/v1`) is the user's **manual**
+validation step and is outside automated testing scope.
diff --git a/tests/fixtures/ddtree/dflash_draft_metadata_smoke.json b/tests/fixtures/ddtree/dflash_draft_metadata_smoke.json
new file mode 100644
index 00000000000..dc8099a2000
--- /dev/null
+++ b/tests/fixtures/ddtree/dflash_draft_metadata_smoke.json
@@ -0,0 +1,11 @@
+{
+    "expected_arch": "dflash-draft",
+    "expected_n_layer": 5,
+    "expected_n_embd": 5120,
+    "expected_n_head": 32,
+    "expected_n_head_kv": 8,
+    "expected_target_n_embd": 5120,
+    "expected_mask_token_id": 248070,
+    "expected_block_size": 16,
+    "expected_capture_layers": [1, 16, 31, 46, 61]
+}
diff --git a/tests/fixtures/ddtree/make_short_prompt.py b/tests/fixtures/ddtree/make_short_prompt.py
new file mode 100755
index 00000000000..e07eecf79b5
--- /dev/null
+++ b/tests/fixtures/ddtree/make_short_prompt.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+"""
+make_short_prompt.py
+
+Tokenize a text string with a Qwen3.5 GGUF and write int32 LE token IDs to
+short_prompt.bin next to this script.
+
+Usage:
+    python3 make_short_prompt.py --text "You are a helpful assistant." \
+        --model-path /path/to/Qwen3.5-27B-Q4_K_M.gguf
+
+Requires the `gguf` and `llama-cpp-python` Python packages:
+    pip install gguf llama-cpp-python
+
+If either package is unavailable, the script falls back to the hardcoded
+16-token fixture already committed in short_prompt.bin and prints a warning.
+"""
+
+import argparse
+import pathlib
+import struct
+import sys
+
+SCRIPT_DIR = pathlib.Path(__file__).parent
+
+FALLBACK_TOKENS = [
+    151644, 8948, 198, 2610, 525, 264, 10950, 17847,
+    13, 151645, 198, 151644, 872, 198, 2610, 7291,
+]
+
+
+def tokenize_via_gguf(text: str, model_path: str):
+    try:
+        from gguf import GGUFReader  # type: ignore
+    except ImportError:
+        return None
+
+    # GGUFReader gives access to metadata but not a tokenizer runtime.
+    # For actual tokenization we need llama-cpp-python or similar.
+    try:
+        from llama_cpp import Llama  # type: ignore
+    except ImportError:
+        return None
+
+    llm = Llama(model_path=model_path, vocab_only=True, verbose=False)
+    tokens = llm.tokenize(text.encode(), add_bos=True, special=True)
+    return tokens
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Generate short_prompt.bin from text.")
+    parser.add_argument("--text", default="You are a helpful assistant.")
+    parser.add_argument("--model-path", default="")
+    parser.add_argument("--out", default=str(SCRIPT_DIR / "short_prompt.bin"))
+    args = parser.parse_args()
+
+    tokens = None
+    if args.model_path:
+        tokens = tokenize_via_gguf(args.text, args.model_path)
+        if tokens is None:
+            print(
+                "WARNING: llama_cpp Python package not available. 
" + "Writing hardcoded fallback fixture.", + file=sys.stderr, + ) + + if tokens is None: + tokens = FALLBACK_TOKENS + + out_path = pathlib.Path(args.out) + data = struct.pack("<" + "i" * len(tokens), *tokens) + out_path.write_bytes(data) + print(f"wrote {len(tokens)} tokens to {out_path}") + + +if __name__ == "__main__": + main() diff --git a/tests/fixtures/ddtree/short_prompt.bin b/tests/fixtures/ddtree/short_prompt.bin new file mode 100644 index 0000000000000000000000000000000000000000..bdea24d7dac1b048864bef745e29ef529e5fb0ee GIT binary patch literal 64 zcmazEU}E^9#K3Tjfq}t@i-CcciGhKG5y;nKVA$>o +#include +#include +#include +#include +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// constants +// --------------------------------------------------------------------------- + +static constexpr int32_t DRAFT_BATCH_SIZE = 16; +static constexpr int32_t DEFAULT_MASK_TOK_ID = 248070; + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +static void usage(const char * prog) { + fprintf(stderr, + "Usage: %s\n" + " --target-model PATH (Qwen3.5-27B GGUF; source of token_embd lookup; required)\n" + " --draft-model PATH (dflash-draft GGUF; required)\n" + " --last-tok N (int32 token id; required)\n" + " --target-feat-bin PATH (F32 binary [ctx_len * 5 * embd_dim]; required)\n" + " --ctx-len N (number of positions in target-feat-bin; 0 = derive from file)\n" + " --out-logits PATH (F32 binary output for 16 positions; required)\n" + " --mask-token-id N (override mask token id; default %d)\n" + " --n-gpu-layers N (default 99)\n", + prog, DEFAULT_MASK_TOK_ID); +} + +// Read a raw F32 binary file into a host buffer. 
+static std::vector<float> read_f32_bin(const std::string & path) {
+    std::ifstream f(path, std::ios::binary);
+    if (!f) {
+        throw std::runtime_error("cannot open binary file: " + path);
+    }
+    f.seekg(0, std::ios::end);
+    const std::streamsize sz = f.tellg();
+    f.seekg(0, std::ios::beg);
+    if (sz % sizeof(float) != 0) {
+        throw std::runtime_error("binary file size not a multiple of 4: " + path);
+    }
+    std::vector<float> buf(sz / sizeof(float));
+    f.read(reinterpret_cast<char *>(buf.data()), sz);
+    return buf;
+}
+
+static void write_logits(const std::string & path,
+                         const std::vector<float> & data,
+                         int32_t n_tokens,
+                         int32_t vocab_size) {
+    std::ofstream f(path, std::ios::binary);
+    if (!f) {
+        throw std::runtime_error("cannot open out-logits for writing: " + path);
+    }
+    f.write(reinterpret_cast<const char *>(&n_tokens), sizeof(int32_t));
+    f.write(reinterpret_cast<const char *>(&vocab_size), sizeof(int32_t));
+    f.write(reinterpret_cast<const char *>(data.data()),
+            (std::streamsize)(data.size() * sizeof(float)));
+}
+
+// ---------------------------------------------------------------------------
+// main
+// ---------------------------------------------------------------------------
+
+int main(int argc, char ** argv) {
+    std::string target_model_path;
+    std::string draft_model_path;
+    std::string target_feat_path;
+    std::string out_logits_path;
+    int32_t last_tok = -1;
+    int32_t ctx_len = 0;
+    int32_t mask_tok_id = DEFAULT_MASK_TOK_ID;
+    int32_t n_gpu_layers = 99;
+
+    for (int i = 1; i < argc; ++i) {
+        std::string arg = argv[i];
+        if (arg == "--target-model" && i + 1 < argc) {
+            target_model_path = argv[++i];
+        } else if (arg == "--draft-model" && i + 1 < argc) {
+            draft_model_path = argv[++i];
+        } else if (arg == "--last-tok" && i + 1 < argc) {
+            last_tok = std::atoi(argv[++i]);
+        } else if (arg == "--target-feat-bin" && i + 1 < argc) {
+            target_feat_path = argv[++i];
+        } else if (arg == "--ctx-len" && i + 1 < argc) {
+            ctx_len = std::atoi(argv[++i]);
+        } else if (arg == "--out-logits" && i + 1 < argc) {
+            out_logits_path = argv[++i];
+        } else if (arg == "--mask-token-id" && i + 1 < argc) {
+            mask_tok_id = std::atoi(argv[++i]);
+        } else if (arg == "--n-gpu-layers" && i + 1 < argc) {
+            n_gpu_layers = std::atoi(argv[++i]);
+        } else if (arg == "-h" || arg == "--help") {
+            usage(argv[0]);
+            return 0;
+        } else {
+            fprintf(stderr, "unknown argument: %s\n", arg.c_str());
+            usage(argv[0]);
+            return 1;
+        }
+    }
+
+    if (target_model_path.empty()) { fprintf(stderr, "--target-model is required\n"); return 1; }
+    if (draft_model_path.empty())  { fprintf(stderr, "--draft-model is required\n"); return 1; }
+    if (last_tok < 0)              { fprintf(stderr, "--last-tok is required\n"); return 1; }
+    if (target_feat_path.empty())  { fprintf(stderr, "--target-feat-bin is required\n"); return 1; }
+    if (out_logits_path.empty())   { fprintf(stderr, "--out-logits is required\n"); return 1; }
+
+    llama_backend_init();
+
+    auto mparams = llama_model_default_params();
+    mparams.n_gpu_layers = n_gpu_layers;
+
+    // ------------------------------------------------------------------
+    // Step 1: load target model (for token_embd lookup only)
+    // ------------------------------------------------------------------
+    llama_model * target_model = llama_model_load_from_file(target_model_path.c_str(), mparams);
+    if (!target_model) {
+        LOG_ERR("failed to load target model: %s\n", target_model_path.c_str());
+        llama_backend_free();
+        return 1;
+    }
+
+    // ------------------------------------------------------------------
+    // Step 2: load draft model
+    // 
------------------------------------------------------------------ + llama_model * draft_model = llama_model_load_from_file(draft_model_path.c_str(), mparams); + if (!draft_model) { + LOG_ERR("failed to load draft model: %s\n", draft_model_path.c_str()); + llama_model_free(target_model); + llama_backend_free(); + return 1; + } + + int ret = 1; + try { + // ------------------------------------------------------------------ + // Step 3: derive embd_dim from target model + // ------------------------------------------------------------------ + // llama_model_n_embd returns the embedding dimension of the model. + const int32_t embd_dim = llama_model_n_embd(target_model); + LOG_INF("embd_dim = %d\n", embd_dim); + + // ------------------------------------------------------------------ + // Step 4: look up token embeddings for the 16-token batch + // tokens: [last_tok, mask_tok_id, mask_tok_id, ..., mask_tok_id] + // (1 + 15 = 16 tokens) + // ------------------------------------------------------------------ + std::vector batch_tokens(DRAFT_BATCH_SIZE); + batch_tokens[0] = (llama_token)last_tok; + for (int i = 1; i < DRAFT_BATCH_SIZE; ++i) { + batch_tokens[i] = (llama_token)mask_tok_id; + } + + // out_embd: [DRAFT_BATCH_SIZE * embd_dim] floats + // Public API is one token per call; loop over the batch. + std::vector token_embd((size_t)DRAFT_BATCH_SIZE * embd_dim, 0.0f); + for (int i = 0; i < DRAFT_BATCH_SIZE; ++i) { + const int rc = llama_model_token_embd_lookup( + target_model, + batch_tokens[i], + token_embd.data() + (size_t)i * embd_dim, + embd_dim); + if (rc != 0) { + throw std::runtime_error( + "llama_model_token_embd_lookup failed for token " + + std::to_string(batch_tokens[i])); + } + } + LOG_INF("token_embd lookup done (%d tokens x %d dim)\n", DRAFT_BATCH_SIZE, embd_dim); + + // ------------------------------------------------------------------ + // Step 5: read target_feat binary + // ------------------------------------------------------------------ + std::vector target_feat = read_f32_bin(target_feat_path); + + // Derive or validate ctx_len. + // Expected layout: [ctx_len * 5 * embd_dim] floats + const int32_t feat_width = 5 * embd_dim; + if (ctx_len == 0) { + if ((int32_t)target_feat.size() % feat_width != 0) { + throw std::runtime_error( + "target-feat-bin size not divisible by 5*embd_dim=" + + std::to_string(feat_width)); + } + ctx_len = (int32_t)(target_feat.size() / feat_width); + LOG_INF("derived ctx_len = %d from target-feat-bin\n", ctx_len); + } else { + const size_t expected = (size_t)ctx_len * feat_width; + if (target_feat.size() != expected) { + throw std::runtime_error( + "target-feat-bin has " + std::to_string(target_feat.size()) + + " floats, expected " + std::to_string(expected) + + " (ctx_len=" + std::to_string(ctx_len) + + " * feat_width=" + std::to_string(feat_width) + ")"); + } + } + LOG_INF("target_feat: %d positions x %d floats\n", ctx_len, feat_width); + + // ------------------------------------------------------------------ + // Step 6: init draft context + // + // n_ctx must cover both the draft batch (16) and the target feat + // positions (ctx_len). Use the larger of the two. 
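+    // Worked example (illustrative numbers): with ctx_len = 512 and
+    // DRAFT_BATCH_SIZE = 16, n_ctx_draft = max(512, 16) + 64 = 576;
+    // the +64 below is headroom on top of whichever bound is larger.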
+ // ------------------------------------------------------------------ + const int32_t n_ctx_draft = std::max(ctx_len, DRAFT_BATCH_SIZE) + 64; + auto cparams = llama_context_default_params(); + cparams.n_ctx = (uint32_t)n_ctx_draft; + cparams.n_batch = (uint32_t)DRAFT_BATCH_SIZE; + + llama_context * draft_ctx = llama_init_from_model(draft_model, cparams); + if (!draft_ctx) { + throw std::runtime_error("failed to create draft context"); + } + + // ------------------------------------------------------------------ + // Step 7: build the embedding input buffer for the draft forward. + // + // The draft graph builder expects an embd batch where the embd pointer + // contains the concatenation of: + // [token_embd rows: DRAFT_BATCH_SIZE * embd_dim floats] + // [target_feat : ctx_len * 5 * embd_dim floats ] + // + // The batch is created with embd != 0 so llama_decode dispatches the + // embd path. Token IDs are left unset (embd takes precedence). + // + // NOTE: This layout is the current best guess from the roadmap. If + // the implementation agent uses a different mechanism (e.g., a separate + // set_target_feat() call), the user will reconcile and update this + // driver before compiling. + // ------------------------------------------------------------------ + const size_t embd_buf_floats = + (size_t)DRAFT_BATCH_SIZE * embd_dim + + (size_t)ctx_len * feat_width; + + std::vector embd_buf(embd_buf_floats); + // Copy token embeddings first + memcpy(embd_buf.data(), + token_embd.data(), + (size_t)DRAFT_BATCH_SIZE * embd_dim * sizeof(float)); + // Then target_feat + memcpy(embd_buf.data() + (size_t)DRAFT_BATCH_SIZE * embd_dim, + target_feat.data(), + (size_t)ctx_len * feat_width * sizeof(float)); + + // Build a batch that feeds embeddings directly. + // embd = 1 tells llama_batch_init to allocate an embd array; however + // we want to point at our own buffer, so we create the struct manually. + llama_batch batch; + memset(&batch, 0, sizeof(batch)); + batch.n_tokens = DRAFT_BATCH_SIZE; + // embd points to our concatenated buffer + batch.embd = embd_buf.data(); + + // Allocate ancillary arrays on the stack/heap. 
+ std::vector pos_arr(DRAFT_BATCH_SIZE); + std::vector n_seq_id_arr(DRAFT_BATCH_SIZE, 1); + std::vector seq_id_val(DRAFT_BATCH_SIZE, 0); + std::vector seq_id_arr(DRAFT_BATCH_SIZE); + std::vector logits_arr(DRAFT_BATCH_SIZE, 1); + + for (int i = 0; i < DRAFT_BATCH_SIZE; ++i) { + pos_arr[i] = (llama_pos)i; + seq_id_arr[i] = &seq_id_val[i]; + } + batch.pos = pos_arr.data(); + batch.n_seq_id = n_seq_id_arr.data(); + batch.seq_id = seq_id_arr.data(); + batch.logits = logits_arr.data(); + + // ------------------------------------------------------------------ + // Step 8: run draft forward + // ------------------------------------------------------------------ + if (llama_decode(draft_ctx, batch) != 0) { + llama_free(draft_ctx); + throw std::runtime_error("llama_decode (draft) failed"); + } + + // ------------------------------------------------------------------ + // Step 9: collect logits for all 16 positions and dump + // ------------------------------------------------------------------ + const auto * vocab = llama_model_get_vocab(draft_model); + const int32_t vocab_size = llama_vocab_n_tokens(vocab); + + std::vector logits_out((size_t)DRAFT_BATCH_SIZE * vocab_size); + for (int i = 0; i < DRAFT_BATCH_SIZE; ++i) { + const float * row = llama_get_logits_ith(draft_ctx, i); + memcpy(&logits_out[(size_t)i * vocab_size], row, + vocab_size * sizeof(float)); + } + + llama_free(draft_ctx); + write_logits(out_logits_path, logits_out, DRAFT_BATCH_SIZE, vocab_size); + LOG_INF("draft forward done: wrote %d x %d logits to %s\n", + DRAFT_BATCH_SIZE, vocab_size, out_logits_path.c_str()); + ret = 0; + + } catch (const std::exception & e) { + LOG_ERR("error: %s\n", e.what()); + ret = 1; + } + + llama_model_free(draft_model); + llama_model_free(target_model); + llama_backend_free(); + return ret; +} diff --git a/tests/test-qwen35-chain-capture.cpp b/tests/test-qwen35-chain-capture.cpp new file mode 100644 index 00000000000..78fdf6910c9 --- /dev/null +++ b/tests/test-qwen35-chain-capture.cpp @@ -0,0 +1,354 @@ +// test-qwen35-chain-capture.cpp +// +// Phase 3 acceptance test (Test 3.C) for hidden-state capture. +// +// Two modes are run in a single invocation using the same model/context: +// +// Mode A (capture): +// - Calls llama_set_capture_hidden(ctx, true). +// - Decodes the prompt as a chain batch. +// - Dumps logits for the last token to --out-logits. +// - Reads the hidden capture buffer via llama_get_hidden_capture(). +// - Dumps the capture buffer to --out-capture. +// - Asserts: shape is [5 * hidden_dim, n_tokens]; no NaN/Inf; not all-zero. +// +// Mode B (regression, --no-capture): +// - Calls llama_set_capture_hidden(ctx, false) then re-decodes same prompt. +// - Asserts logits are BIT-EQUAL to Mode A output (same values, not just close). +// - Skips capture dump. +// +// Both modes run inside a single process so logits can be compared in memory. +// The --out-logits file is written once (from Mode A). If Mode B differs, +// the driver exits with code 1 and prints the first discrepant index. +// +// Build: requires -DLLAMA_BUILD_TESTS_DFLASH_DRAFT=ON (not in ctest). +// +// API assumptions (implementation agent deliverables): +// void llama_set_capture_hidden(llama_context * ctx, bool enable) +// -- opt-in to hidden-state capture for the target model. +// ggml_tensor * llama_get_hidden_capture(llama_context * ctx) +// -- returns a pointer to the capture tensor after decode. +// tensor shape: [5 * hidden_dim, n_tokens] (F32, host-accessible). 
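+//      (The shape check in Mode A below also accepts the transposed
+//      [hidden_dim, 5 * n_tokens] layout, which is what qwen35.cpp produces
+//      by stacking the five capture slots along ne[1].)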
+//      Returns NULL if capture was not enabled or graph not yet run.
+//
+// Output binary format (--out-logits):
+//   int32_t n_tokens     (= 1, last-position logit row)
+//   int32_t vocab_size
+//   float   logits[vocab_size]
+//
+// Output binary format (--out-capture):
+//   int32_t feat_dim     (= 5 * hidden_dim)
+//   int32_t n_tokens     (number of prompt tokens)
+//   float   buf[feat_dim * n_tokens]   (row-major: row i = token i's features)
+
+#include "llama.h"
+#include "ggml.h"
+#include "log.h"
+
+#include <cmath>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <fstream>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+// ---------------------------------------------------------------------------
+// helpers
+// ---------------------------------------------------------------------------
+
+static void usage(const char * prog) {
+    fprintf(stderr,
+        "Usage: %s\n"
+        "  --model PATH          (Qwen3.5-27B GGUF; required)\n"
+        "  --prompt-tokens PATH  (binary int32 LE token IDs; required)\n"
+        "  --out-logits PATH     (F32 binary; required)\n"
+        "  --out-capture PATH    (F32 binary capture dump; required)\n"
+        "  --no-capture          (skip Mode A, only run Mode B regression check)\n"
+        "  --n-gpu-layers N      (default 99)\n"
+        "  --n-ctx N             (default 4096)\n"
+        "\n"
+        "Both capture and no-capture modes run in sequence within one invocation.\n"
+        "Logits from both modes are compared in memory and must be bit-equal.\n",
+        prog);
+}
+
+static std::vector<int32_t> read_prompt_tokens(const std::string & path) {
+    std::ifstream f(path, std::ios::binary);
+    if (!f) {
+        throw std::runtime_error("cannot open prompt-tokens file: " + path);
+    }
+    f.seekg(0, std::ios::end);
+    const std::streamsize sz = f.tellg();
+    f.seekg(0, std::ios::beg);
+    if (sz % sizeof(int32_t) != 0) {
+        throw std::runtime_error("prompt-tokens file size not a multiple of 4: " + path);
+    }
+    std::vector<int32_t> tokens(sz / sizeof(int32_t));
+    f.read(reinterpret_cast<char *>(tokens.data()), sz);
+    return tokens;
+}
+
+static void write_logits(const std::string & path,
+                         const float * data,
+                         int32_t n_tokens,
+                         int32_t vocab_size) {
+    std::ofstream f(path, std::ios::binary);
+    if (!f) {
+        throw std::runtime_error("cannot open out-logits for writing: " + path);
+    }
+    f.write(reinterpret_cast<const char *>(&n_tokens), sizeof(int32_t));
+    f.write(reinterpret_cast<const char *>(&vocab_size), sizeof(int32_t));
+    f.write(reinterpret_cast<const char *>(data),
+            (std::streamsize)((size_t)n_tokens * vocab_size * sizeof(float)));
+}
+
+static void write_capture(const std::string & path,
+                          const float * data,
+                          int32_t feat_dim,
+                          int32_t n_tokens) {
+    std::ofstream f(path, std::ios::binary);
+    if (!f) {
+        throw std::runtime_error("cannot open out-capture for writing: " + path);
+    }
+    f.write(reinterpret_cast<const char *>(&feat_dim), sizeof(int32_t));
+    f.write(reinterpret_cast<const char *>(&n_tokens), sizeof(int32_t));
+    f.write(reinterpret_cast<const char *>(data),
+            (std::streamsize)((size_t)feat_dim * n_tokens * sizeof(float)));
+}
+
+// Decode a prompt batch and return the logits for the last position.
+// The returned vector is a copy (safe across re-use of the context).
+static std::vector<float> decode_chain(llama_context * ctx,
+                                       const std::vector<int32_t> & prompt,
+                                       int32_t vocab_size) {
+    const int32_t n_tokens = (int32_t)prompt.size();
+    llama_batch batch = llama_batch_init(n_tokens, /*embd=*/0, /*n_seq_max=*/1);
+
+    for (int32_t i = 0; i < n_tokens; ++i) {
+        batch.token[i]     = (llama_token)prompt[i];
+        batch.pos[i]       = (llama_pos)i;
+        batch.n_seq_id[i]  = 1;
+        batch.seq_id[i][0] = 0;
+        // Only request logits for the last token to match Phase 1 chain mode.
+        batch.logits[i] = (i == n_tokens - 1) ? 
1 : 0; + } + batch.n_tokens = n_tokens; + + if (llama_decode(ctx, batch) != 0) { + llama_batch_free(batch); + throw std::runtime_error("llama_decode failed"); + } + + // Copy last-token logits before freeing batch. + const float * row = llama_get_logits_ith(ctx, n_tokens - 1); + std::vector logits(row, row + vocab_size); + + llama_batch_free(batch); + return logits; +} + +// Clear context memory (KV cache + recurrent state) between the two decode runs. +static void clear_kv(llama_context * ctx) { + llama_memory_clear(llama_get_memory(ctx), /*data=*/true); +} + +// --------------------------------------------------------------------------- +// main +// --------------------------------------------------------------------------- + +int main(int argc, char ** argv) { + std::string model_path; + std::string prompt_tokens_path; + std::string out_logits_path; + std::string out_capture_path; + bool no_capture = false; + int32_t n_gpu_layers = 99; + int32_t n_ctx = 4096; + + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg == "--model" && i + 1 < argc) { + model_path = argv[++i]; + } else if (arg == "--prompt-tokens" && i + 1 < argc) { + prompt_tokens_path = argv[++i]; + } else if (arg == "--out-logits" && i + 1 < argc) { + out_logits_path = argv[++i]; + } else if (arg == "--out-capture" && i + 1 < argc) { + out_capture_path = argv[++i]; + } else if (arg == "--no-capture") { + no_capture = true; + } else if (arg == "--n-gpu-layers" && i + 1 < argc) { + n_gpu_layers = std::atoi(argv[++i]); + } else if (arg == "--n-ctx" && i + 1 < argc) { + n_ctx = std::atoi(argv[++i]); + } else if (arg == "-h" || arg == "--help") { + usage(argv[0]); + return 0; + } else { + fprintf(stderr, "unknown argument: %s\n", arg.c_str()); + usage(argv[0]); + return 1; + } + } + + if (model_path.empty()) { fprintf(stderr, "--model is required\n"); return 1; } + if (prompt_tokens_path.empty()) { fprintf(stderr, "--prompt-tokens is required\n"); return 1; } + if (out_logits_path.empty()) { fprintf(stderr, "--out-logits is required\n"); return 1; } + if (out_capture_path.empty() && !no_capture) { + fprintf(stderr, "--out-capture is required (or pass --no-capture to skip Mode A)\n"); + return 1; + } + + llama_backend_init(); + + auto mparams = llama_model_default_params(); + mparams.n_gpu_layers = n_gpu_layers; + + llama_model * model = llama_model_load_from_file(model_path.c_str(), mparams); + if (!model) { + LOG_ERR("failed to load model: %s\n", model_path.c_str()); + llama_backend_free(); + return 1; + } + + auto cparams = llama_context_default_params(); + cparams.n_ctx = (uint32_t)n_ctx; + cparams.n_batch = (uint32_t)n_ctx; + + llama_context * ctx = llama_init_from_model(model, cparams); + if (!ctx) { + LOG_ERR("failed to create context\n"); + llama_model_free(model); + llama_backend_free(); + return 1; + } + + int ret = 1; + try { + const auto * vocab = llama_model_get_vocab(model); + const int32_t vocab_size = llama_vocab_n_tokens(vocab); + const int32_t hidden_dim = llama_model_n_embd(model); + const int32_t feat_dim = 5 * hidden_dim; + + std::vector prompt = read_prompt_tokens(prompt_tokens_path); + const int32_t n_prompt = (int32_t)prompt.size(); + + std::vector logits_capture; + std::vector logits_nocapture; + + // ================================================================== + // Mode A: capture enabled + // ================================================================== + if (!no_capture) { + LOG_INF("--- Mode A: capture enabled ---\n"); + llama_set_capture_hidden(ctx, true); + + 
logits_capture = decode_chain(ctx, prompt, vocab_size); + + // Write logits (last token only). + write_logits(out_logits_path, logits_capture.data(), 1, vocab_size); + LOG_INF("Mode A: logits written to %s\n", out_logits_path.c_str()); + + // Read hidden capture data (host-side, populated via ggml_backend_tensor_get_async). + int64_t ne0 = 0, ne1 = 0; + const float * cap_data = llama_get_hidden_capture_data(ctx, &ne0, &ne1); + if (!cap_data) { + throw std::runtime_error( + "llama_get_hidden_capture_data returned NULL after capture decode; " + "check that llama_set_capture_hidden is wired in the graph builder"); + } + + // Validate shape. qwen35.cpp allocates [n_embd, 5*n_tokens] (slots stacked + // along ne[1]); accept both layouts so the assertion stays portable. + const bool layout_stacked_ne1 = + ne0 == (int64_t)hidden_dim && ne1 == (int64_t)5 * n_prompt; + const bool layout_stacked_ne0 = + ne0 == (int64_t)feat_dim && ne1 == (int64_t)n_prompt; + if (!layout_stacked_ne1 && !layout_stacked_ne0) { + throw std::runtime_error( + "hidden capture tensor shape mismatch: got [" + + std::to_string(ne0) + ", " + std::to_string(ne1) + + "], expected [" + std::to_string(hidden_dim) + ", " + + std::to_string(5 * n_prompt) + "] or [" + + std::to_string(feat_dim) + ", " + std::to_string(n_prompt) + "]"); + } + LOG_INF("capture shape: [%lld, %lld] — OK (%s)\n", + (long long)ne0, (long long)ne1, + layout_stacked_ne1 ? "stacked along ne[1]" : "stacked along ne[0]"); + + // Validate: no NaN/Inf and not all-zero. + const size_t cap_n = (size_t) ne0 * (size_t) ne1; + bool any_nonzero = false; + for (size_t k = 0; k < cap_n; ++k) { + float v = cap_data[k]; + if (!std::isfinite(v)) { + throw std::runtime_error( + "hidden capture contains non-finite value at index " + + std::to_string(k)); + } + if (v != 0.0f) { + any_nonzero = true; + } + } + if (!any_nonzero) { + throw std::runtime_error( + "hidden capture is all-zero; capture hook is likely not wired"); + } + LOG_INF("capture: no NaN/Inf, at least one non-zero value — OK\n"); + + write_capture(out_capture_path, cap_data, feat_dim, n_prompt); + LOG_INF("Mode A: capture written to %s\n", out_capture_path.c_str()); + + clear_kv(ctx); + } + + // ================================================================== + // Mode B: capture disabled — must produce bit-equal logits + // ================================================================== + LOG_INF("--- Mode B: capture disabled ---\n"); + llama_set_capture_hidden(ctx, false); + + logits_nocapture = decode_chain(ctx, prompt, vocab_size); + + if (!no_capture) { + // Compare bit-for-bit against Mode A. + bool mismatch = false; + for (int32_t v = 0; v < vocab_size; ++v) { + if (logits_capture[v] != logits_nocapture[v]) { + fprintf(stderr, + "FAIL: logit mismatch at vocab index %d: " + "capture=%.8e no-capture=%.8e\n", + v, logits_capture[v], logits_nocapture[v]); + mismatch = true; + break; // report first discrepancy only + } + } + if (mismatch) { + throw std::runtime_error( + "Mode A and Mode B logits are not bit-equal; " + "hidden capture hook may be altering the compute graph"); + } + LOG_INF("Mode B: logits bit-equal to Mode A — OK\n"); + } else { + // no-capture-only run: write logits so the caller can compare + // against a Phase 1 golden dump externally. 
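+            // e.g.:
+            //   python3 scripts/compare_logits.py /tmp/chain_golden.bin \
+            //       /tmp/nocapture_logits.bin --abs-tol 0 --rel-tol 0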
+ write_logits(out_logits_path, logits_nocapture.data(), 1, vocab_size); + LOG_INF("Mode B only: logits written to %s\n", out_logits_path.c_str()); + } + + LOG_INF("all assertions passed\n"); + ret = 0; + + } catch (const std::exception & e) { + LOG_ERR("error: %s\n", e.what()); + ret = 1; + } + + llama_free(ctx); + llama_model_free(model); + llama_backend_free(); + return ret; +} diff --git a/tests/test-qwen35-root-vs-chain.cpp b/tests/test-qwen35-root-vs-chain.cpp new file mode 100644 index 00000000000..d7f092436e4 --- /dev/null +++ b/tests/test-qwen35-root-vs-chain.cpp @@ -0,0 +1,635 @@ +// test-qwen35-root-vs-chain.cpp +// +// DDTree diagnostic: verify that a single tree-mode forward at the root node +// (parent_id = -1) is equivalent to a chain forward of the same token at the +// same position. +// +// Two passes inside the same process (model loaded once): +// pass A (chain): chain prefill tokens[0 .. N-1], record logits at index N-1 +// pass B (tree-root): chain prefill tokens[0 .. N-2], then a single +// tree-mode batch with one node {token = tokens[N-1], parent_id = -1, +// pos = N-1}, record logits at index 0 +// +// If the tree kernel + tree input wiring are correct, A and B should match +// within numerical tolerance for that one position. +// +// Build: -DLLAMA_BUILD_TESTS_QWEN35_ROOT_VS_CHAIN=ON +// ./build-server/bin/test-qwen35-root-vs-chain \ +// --model PATH --prompt-tokens tokens.bin --out-summary diff.txt + +#include "llama.h" +#include "common.h" +#include "log.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static void usage(const char * prog) { + fprintf(stderr, + "Usage: %s\n" + " --model PATH (GGUF; required)\n" + " --prompt-tokens PATH (binary int32 LE token IDs)\n" + " --prompt-text STR (alternative to --prompt-tokens; tokenized in-process)\n" + " --prompt-text-file PATH (alternative to --prompt-text; UTF-8 text file)\n" + " --out-summary PATH (text summary; required)\n" + " --n-siblings N (extra sibling nodes at depth 1; default 0)\n" + " --n-spec-steps N (1 or 2; default 1; 2 chains step1 -> compact/rollback -> step2)\n" + " --gapped-accept (two-step diagnostic: accept DFS path 0,16,19)\n" + " --skip-rollback (diagnostic: compact accepted root but do not rollback SSM)\n" + " --n-gpu-layers N (default 99)\n" + " --n-ctx N (default 4096)\n", + prog); +} + +static std::vector read_prompt_tokens(const std::string & path) { + std::ifstream f(path, std::ios::binary); + if (!f) { + throw std::runtime_error("cannot open prompt-tokens file: " + path); + } + f.seekg(0, std::ios::end); + auto size = f.tellg(); + f.seekg(0, std::ios::beg); + if (size % sizeof(int32_t) != 0) { + throw std::runtime_error("prompt-tokens file size not a multiple of 4: " + path); + } + std::vector tokens(size / sizeof(int32_t)); + f.read(reinterpret_cast(tokens.data()), size); + return tokens; +} + +static std::vector run_chain_capture_last(llama_model * model, + const llama_context_params & cparams, + const std::vector & tokens, + std::vector * hidden_last = nullptr) { + llama_context * ctx = llama_init_from_model(model, cparams); + if (!ctx) { + throw std::runtime_error("failed to create chain context"); + } + if (hidden_last != nullptr) { + llama_set_capture_hidden(ctx, true); + } + + const int32_t n_tokens = (int32_t)tokens.size(); + const auto * vocab = llama_model_get_vocab(model); + const int32_t vocab_size = llama_vocab_n_tokens(vocab); + + llama_batch batch = llama_batch_init(n_tokens, /*embd=*/0, /*n_seq_max=*/1); + for 
(int32_t i = 0; i < n_tokens; ++i) { + batch.token[i] = (llama_token)tokens[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = (i == n_tokens - 1) ? 1 : 0; + } + batch.n_tokens = n_tokens; + + if (llama_decode(ctx, batch) != 0) { + llama_batch_free(batch); + llama_free(ctx); + throw std::runtime_error("chain llama_decode failed"); + } + + const float * row = llama_get_logits_ith(ctx, n_tokens - 1); + std::vector out(vocab_size); + memcpy(out.data(), row, (size_t)vocab_size * sizeof(float)); + + if (hidden_last != nullptr) { + int64_t ne0 = 0; + int64_t ne1 = 0; + const float * cap = llama_get_hidden_capture_data(ctx, &ne0, &ne1); + if (cap == nullptr || ne0 <= 0 || ne1 != 5 * n_tokens) { + llama_batch_free(batch); + llama_free(ctx); + throw std::runtime_error("chain hidden capture unavailable or shape mismatch"); + } + hidden_last->assign((size_t)5 * ne0, 0.0f); + for (int64_t l = 0; l < 5; ++l) { + const float * src = cap + (l * n_tokens + (n_tokens - 1)) * ne0; + float * dst = hidden_last->data() + l * ne0; + memcpy(dst, src, (size_t)ne0 * sizeof(float)); + } + } + + llama_batch_free(batch); + llama_free(ctx); + return out; +} + +// Build a tree batch of (1 + n_siblings) nodes at the given root pos. +// Sibling tokens come from the prompt history (cyclic) so they are distinct. +static llama_batch build_tree_batch(const std::vector & tokens, + int32_t root_pos, + llama_token root_token, + int n_siblings) { + const int n_nodes = 1 + n_siblings; + llama_batch tb = llama_batch_init_tree(n_nodes, /*embd=*/0, /*n_seq_max=*/1); + tb.token[0] = root_token; + tb.pos[0] = root_pos; + tb.n_seq_id[0] = 1; + tb.seq_id[0][0] = 0; + tb.parent_id[0] = -1; + tb.logits[0] = 1; + for (int i = 1; i < n_nodes; ++i) { + const int n = (int)tokens.size(); + const int src = ((root_pos - i) % n + n) % n; + tb.token[i] = (llama_token)tokens[src]; + tb.pos[i] = root_pos + 1; + tb.n_seq_id[i] = 1; + tb.seq_id[i][0] = 0; + tb.parent_id[i] = 0; + tb.logits[i] = 0; + } + tb.n_tokens = n_nodes; + return tb; +} + +static llama_batch build_tree_gapped_accept_batch(const std::vector & tokens, + int32_t root_pos) { + const int n_nodes = 20; + llama_batch tb = llama_batch_init_tree(n_nodes, /*embd=*/0, /*n_seq_max=*/1); + for (int i = 0; i < n_nodes; ++i) { + const int n = (int)tokens.size(); + const int src = ((root_pos - 1 - i) % n + n) % n; + tb.token[i] = (llama_token)tokens[src]; + tb.pos[i] = root_pos + 1; + tb.n_seq_id[i] = 1; + tb.seq_id[i][0] = 0; + tb.parent_id[i] = 0; + tb.logits[i] = 0; + } + + tb.token[0] = (llama_token)tokens[root_pos]; + tb.pos[0] = root_pos; + tb.parent_id[0] = -1; + tb.logits[0] = 1; + + tb.token[16] = (llama_token)tokens[root_pos + 1]; + tb.pos[16] = root_pos + 1; + tb.parent_id[16] = 0; + + tb.token[19] = (llama_token)tokens[root_pos + 2]; + tb.pos[19] = root_pos + 2; + tb.parent_id[19] = 16; + + tb.n_tokens = n_nodes; + return tb; +} + +// Two-spec-step variant: chain prefill [0..N-3], step1 (root=tokens[N-2]), accept +// root only -> compact+rollback, step2 (root=tokens[N-1]); return logits at step2 root. 
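+// Position layout for the default (non-gapped) case, accept_depth = 1:
+//   chain prefill : tokens[0 .. N-3]    at pos 0 .. N-3   (n_prefix = N-2)
+//   spec step 1   : root = tokens[N-2]  at pos N-2, parent_id = -1
+//   accept root   : compact tree + SSM rollback, tail pos = N-2
+//   spec step 2   : root = tokens[N-1]  at pos N-1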
+static std::vector run_chain_then_tree_two_step(llama_model * model, + const llama_context_params & cparams, + const std::vector & tokens, + int n_siblings, + bool skip_rollback, + bool gapped_accept, + std::vector * hidden_root = nullptr) { + llama_context * ctx = llama_init_from_model(model, cparams); + if (!ctx) { + throw std::runtime_error("failed to create two-step context"); + } + if (hidden_root != nullptr) { + llama_set_capture_hidden(ctx, true); + } + const int32_t N = (int32_t)tokens.size(); + const int32_t accept_depth = gapped_accept ? 3 : 1; + const int32_t n_prefix = N - accept_depth - 1; + const auto * vocab = llama_model_get_vocab(model); + const int32_t vocab_sz = llama_vocab_n_tokens(vocab); + + // chain prefill + if (n_prefix > 0) { + llama_batch b = llama_batch_init(n_prefix, /*embd=*/0, /*n_seq_max=*/1); + for (int32_t i = 0; i < n_prefix; ++i) { + b.token[i] = (llama_token)tokens[i]; b.pos[i] = i; + b.n_seq_id[i] = 1; b.seq_id[i][0] = 0; b.logits[i] = 0; + } + b.n_tokens = n_prefix; + if (llama_decode(ctx, b) != 0) { + llama_batch_free(b); llama_free(ctx); + throw std::runtime_error("two-step: prefill failed"); + } + llama_batch_free(b); + } + + // spec step 1: root @ pos n_prefix (token = tokens[n_prefix]) + { + llama_batch t = gapped_accept + ? build_tree_gapped_accept_batch(tokens, n_prefix) + : build_tree_batch(tokens, n_prefix, (llama_token)tokens[n_prefix], n_siblings); + if (llama_decode(ctx, t) != 0) { + llama_batch_free(t); llama_free(ctx); + throw std::runtime_error("two-step: spec step 1 failed"); + } + llama_batch_free(t); + } + // accept root only -> compact tree + SSM rollback + int32_t accepted_dfs_root[1] = {0}; + int32_t accepted_dfs_gapped[3] = {0, 16, 19}; + int32_t * accepted_dfs = gapped_accept ? accepted_dfs_gapped : accepted_dfs_root; + llama_kv_cache_seq_compact_tree(ctx, /*seq_id=*/0, accepted_dfs, + /*n_accepted=*/accept_depth, /*commit_n=*/accept_depth, + /*spine_start=*/n_prefix); + if (!skip_rollback) { + llama_dflash_rollback_ssm_to_dfs(ctx, /*seq_id=*/0, /*accepted_dfs_node=*/accepted_dfs[accept_depth - 1]); + llama_dflash_set_recurrent_tail_pos(ctx, /*seq_id=*/0, /*pos=*/n_prefix + accept_depth - 1); + } + + // spec step 2: root after the accepted chain. 
+ std::vector out(vocab_sz); + { + llama_batch t = build_tree_batch(tokens, n_prefix + accept_depth, + (llama_token)tokens[n_prefix + accept_depth], n_siblings); + if (llama_decode(ctx, t) != 0) { + llama_batch_free(t); llama_free(ctx); + throw std::runtime_error("two-step: spec step 2 failed"); + } + memcpy(out.data(), llama_get_logits_ith(ctx, 0), + (size_t)vocab_sz * sizeof(float)); + if (hidden_root != nullptr) { + int64_t ne0 = 0; + int64_t ne1 = 0; + const float * cap = llama_get_hidden_capture_data(ctx, &ne0, &ne1); + const int n_nodes = 1 + n_siblings; + if (cap == nullptr || ne0 <= 0 || ne1 != 5 * n_nodes) { + llama_batch_free(t); + llama_free(ctx); + throw std::runtime_error("two-step hidden capture unavailable or shape mismatch"); + } + hidden_root->assign((size_t)5 * ne0, 0.0f); + for (int64_t l = 0; l < 5; ++l) { + const float * src = cap + l * n_nodes * ne0; + float * dst = hidden_root->data() + l * ne0; + memcpy(dst, src, (size_t)ne0 * sizeof(float)); + } + } + llama_batch_free(t); + } + + llama_free(ctx); + return out; +} + +static std::vector run_chain_then_tree_root(llama_model * model, + const llama_context_params & cparams, + const std::vector & tokens, + int n_siblings, + std::vector * hidden_root = nullptr) { + llama_context * ctx = llama_init_from_model(model, cparams); + if (!ctx) { + throw std::runtime_error("failed to create tree-root context"); + } + if (hidden_root != nullptr) { + llama_set_capture_hidden(ctx, true); + } + + const int32_t n_tokens = (int32_t)tokens.size(); + const auto * vocab = llama_model_get_vocab(model); + const int32_t vocab_size = llama_vocab_n_tokens(vocab); + const int32_t n_prefix = n_tokens - 1; + + // chain prefill tokens[0 .. n_prefix - 1] + if (n_prefix > 0) { + llama_batch batch = llama_batch_init(n_prefix, /*embd=*/0, /*n_seq_max=*/1); + for (int32_t i = 0; i < n_prefix; ++i) { + batch.token[i] = (llama_token)tokens[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = 0; + } + batch.n_tokens = n_prefix; + + if (llama_decode(ctx, batch) != 0) { + llama_batch_free(batch); + llama_free(ctx); + throw std::runtime_error("chain prefill llama_decode failed"); + } + llama_batch_free(batch); + } + + // tree batch: root at index 0 + n_siblings nodes at depth 1 (parent = root) + const int n_nodes = 1 + n_siblings; + llama_batch tbatch = llama_batch_init_tree(/*n_tokens=*/n_nodes, /*embd=*/0, /*n_seq_max=*/1); + // root + tbatch.token[0] = (llama_token)tokens[n_prefix]; + tbatch.pos[0] = n_prefix; + tbatch.n_seq_id[0] = 1; + tbatch.seq_id[0][0] = 0; + tbatch.parent_id[0] = -1; + tbatch.logits[0] = 1; + // siblings: depth 1, parent = root, token taken from prompt history (cyclic) + // so the tree batch has *distinct* tokens like a real spec verify ubatch. + for (int i = 1; i < n_nodes; ++i) { + const int src = (n_prefix - 1 - i + (int)tokens.size()) % (int)tokens.size(); + tbatch.token[i] = (llama_token)tokens[(src < 0 ? 
src + (int)tokens.size() : src)]; + tbatch.pos[i] = n_prefix + 1; + tbatch.n_seq_id[i] = 1; + tbatch.seq_id[i][0] = 0; + tbatch.parent_id[i] = 0; + tbatch.logits[i] = 0; + } + tbatch.n_tokens = n_nodes; + + if (llama_decode(ctx, tbatch) != 0) { + llama_batch_free(tbatch); + llama_free(ctx); + throw std::runtime_error("tree-root llama_decode failed"); + } + + const float * row = llama_get_logits_ith(ctx, 0); + std::vector out(vocab_size); + memcpy(out.data(), row, (size_t)vocab_size * sizeof(float)); + + if (hidden_root != nullptr) { + int64_t ne0 = 0; + int64_t ne1 = 0; + const float * cap = llama_get_hidden_capture_data(ctx, &ne0, &ne1); + if (cap == nullptr || ne0 <= 0 || ne1 != 5 * n_nodes) { + llama_batch_free(tbatch); + llama_free(ctx); + throw std::runtime_error("tree hidden capture unavailable or shape mismatch"); + } + hidden_root->assign((size_t)5 * ne0, 0.0f); + for (int64_t l = 0; l < 5; ++l) { + const float * src = cap + l * n_nodes * ne0; + float * dst = hidden_root->data() + l * ne0; + memcpy(dst, src, (size_t)ne0 * sizeof(float)); + } + } + + llama_batch_free(tbatch); + llama_free(ctx); + return out; +} + +static std::vector run_chain_then_tree_gapped_last(llama_model * model, + const llama_context_params & cparams, + const std::vector & tokens, + std::vector * hidden_last = nullptr) { + llama_context * ctx = llama_init_from_model(model, cparams); + if (!ctx) { + throw std::runtime_error("failed to create gapped tree context"); + } + if (hidden_last != nullptr) { + llama_set_capture_hidden(ctx, true); + } + + const int32_t n_tokens = (int32_t)tokens.size(); + const int32_t n_prefix = n_tokens - 3; + const auto * vocab = llama_model_get_vocab(model); + const int32_t vocab_size = llama_vocab_n_tokens(vocab); + + if (n_prefix > 0) { + llama_batch batch = llama_batch_init(n_prefix, /*embd=*/0, /*n_seq_max=*/1); + for (int32_t i = 0; i < n_prefix; ++i) { + batch.token[i] = (llama_token)tokens[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = 0; + } + batch.n_tokens = n_prefix; + if (llama_decode(ctx, batch) != 0) { + llama_batch_free(batch); + llama_free(ctx); + throw std::runtime_error("gapped prefill llama_decode failed"); + } + llama_batch_free(batch); + } + + llama_batch tbatch = build_tree_gapped_accept_batch(tokens, n_prefix); + tbatch.logits[0] = 0; + tbatch.logits[19] = 1; + if (llama_decode(ctx, tbatch) != 0) { + llama_batch_free(tbatch); + llama_free(ctx); + throw std::runtime_error("gapped tree llama_decode failed"); + } + + const float * row = llama_get_logits_ith(ctx, 19); + std::vector out(vocab_size); + memcpy(out.data(), row, (size_t)vocab_size * sizeof(float)); + + if (hidden_last != nullptr) { + int64_t ne0 = 0; + int64_t ne1 = 0; + const float * cap = llama_get_hidden_capture_data(ctx, &ne0, &ne1); + const int n_nodes = 20; + if (cap == nullptr || ne0 <= 0 || ne1 != 5 * n_nodes) { + llama_batch_free(tbatch); + llama_free(ctx); + throw std::runtime_error("gapped hidden capture unavailable or shape mismatch"); + } + hidden_last->assign((size_t)5 * ne0, 0.0f); + for (int64_t l = 0; l < 5; ++l) { + const float * src = cap + (l * n_nodes + 19) * ne0; + float * dst = hidden_last->data() + l * ne0; + memcpy(dst, src, (size_t)ne0 * sizeof(float)); + } + } + + llama_batch_free(tbatch); + llama_free(ctx); + return out; +} + +struct DiffStats { + double max_abs_diff; + double mean_abs_diff; + int argmax_a; + int argmax_b; + std::vector top5_a; + std::vector top5_b; +}; + +static std::vector top_k_indices(const std::vector & v, 
                                       int k) {
+    std::vector<int> idx(v.size());
+    std::iota(idx.begin(), idx.end(), 0);
+    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
+                      [&](int a, int b) { return v[a] > v[b]; });
+    idx.resize(k);
+    return idx;
+}
+
+static DiffStats diff_logits(const std::vector<float> & a, const std::vector<float> & b) {
+    DiffStats s = {};
+    if (a.size() != b.size() || a.empty()) {
+        throw std::runtime_error("logits size mismatch");
+    }
+    double sum_abs = 0.0;
+    double max_abs = 0.0;
+    for (size_t i = 0; i < a.size(); ++i) {
+        const double d = std::fabs((double)a[i] - (double)b[i]);
+        sum_abs += d;
+        if (d > max_abs) max_abs = d;
+    }
+    s.max_abs_diff = max_abs;
+    s.mean_abs_diff = sum_abs / (double)a.size();
+    s.argmax_a = (int)(std::max_element(a.begin(), a.end()) - a.begin());
+    s.argmax_b = (int)(std::max_element(b.begin(), b.end()) - b.begin());
+    s.top5_a = top_k_indices(a, 5);
+    s.top5_b = top_k_indices(b, 5);
+    return s;
+}
+
+static std::string read_text_file(const std::string & path) {
+    std::ifstream f(path);
+    if (!f) {
+        throw std::runtime_error("cannot open prompt-text-file: " + path);
+    }
+    std::string s((std::istreambuf_iterator<char>(f)), std::istreambuf_iterator<char>());
+    return s;
+}
+
+static std::vector<int32_t> tokenize_text(llama_model * model, const std::string & text) {
+    const auto * vocab = llama_model_get_vocab(model);
+    int32_t n = -llama_tokenize(vocab, text.data(), (int32_t)text.size(),
+                                nullptr, 0, /*add_special=*/true, /*parse_special=*/false);
+    if (n <= 0) {
+        throw std::runtime_error("llama_tokenize sizing failed");
+    }
+    std::vector<llama_token> tmp(n);
+    int32_t got = llama_tokenize(vocab, text.data(), (int32_t)text.size(),
+                                 tmp.data(), n, true, false);
+    if (got != n) {
+        throw std::runtime_error("llama_tokenize result mismatch");
+    }
+    std::vector<int32_t> out(tmp.begin(), tmp.end());
+    return out;
+}
+
+int main(int argc, char ** argv) {
+    std::string model_path;
+    std::string prompt_tokens_path;
+    std::string prompt_text;
+    std::string prompt_text_file;
+    std::string out_summary;
+    int32_t n_gpu_layers = 99;
+    int32_t n_ctx = 4096;
+    int32_t n_siblings = 0;
+    int32_t n_spec_steps = 1;
+    bool skip_rollback = false;
+    bool gapped_accept = false;
+
+    for (int i = 1; i < argc; ++i) {
+        std::string arg = argv[i];
+        if (arg == "--model" && i + 1 < argc) model_path = argv[++i];
+        else if (arg == "--prompt-tokens" && i + 1 < argc) prompt_tokens_path = argv[++i];
+        else if (arg == "--prompt-text" && i + 1 < argc) prompt_text = argv[++i];
+        else if (arg == "--prompt-text-file" && i + 1 < argc) prompt_text_file = argv[++i];
+        else if (arg == "--out-summary" && i + 1 < argc) out_summary = argv[++i];
+        else if (arg == "--n-siblings" && i + 1 < argc) n_siblings = std::atoi(argv[++i]);
+        else if (arg == "--n-spec-steps" && i + 1 < argc) n_spec_steps = std::atoi(argv[++i]);
+        else if (arg == "--gapped-accept") gapped_accept = true;
+        else if (arg == "--skip-rollback") skip_rollback = true;
+        else if (arg == "--n-gpu-layers" && i + 1 < argc) n_gpu_layers = std::atoi(argv[++i]);
+        else if (arg == "--n-ctx" && i + 1 < argc) n_ctx = std::atoi(argv[++i]);
+        else if (arg == "-h" || arg == "--help") { usage(argv[0]); return 0; }
+        else { fprintf(stderr, "unknown argument: %s\n", arg.c_str()); usage(argv[0]); return 1; }
+    }
+
+    int input_modes = (!prompt_tokens_path.empty()) + (!prompt_text.empty()) + (!prompt_text_file.empty());
+    if (model_path.empty() || out_summary.empty() || input_modes != 1) {
+        fprintf(stderr, "must provide exactly one of --prompt-tokens, --prompt-text, --prompt-text-file\n");
+        usage(argv[0]); return 1;
+    }
+
+    llama_backend_init();
+
+    auto mparams = llama_model_default_params();
+    mparams.n_gpu_layers = n_gpu_layers;
+
+    llama_model * model = llama_model_load_from_file(model_path.c_str(), mparams);
+    if (!model) {
+        LOG_ERR("failed to load model: %s\n", model_path.c_str());
+        llama_backend_free();
+        return 1;
+    }
+
+    auto cparams = llama_context_default_params();
+    cparams.n_ctx = (uint32_t)n_ctx;
+    cparams.n_batch = (uint32_t)n_ctx;
+
+    int rc = 1;
+    try {
+        std::vector<int32_t> tokens;
+        if (!prompt_tokens_path.empty()) {
+            tokens = read_prompt_tokens(prompt_tokens_path);
+        } else {
+            std::string text = !prompt_text.empty()
+                ? prompt_text
+                : read_text_file(prompt_text_file);
+            tokens = tokenize_text(model, text);
+            LOG_INF("tokenized %zu tokens from text\n", tokens.size());
+        }
+        if (tokens.size() < 2) {
+            throw std::runtime_error("need at least 2 tokens");
+        }
+
+        LOG_INF("loaded %zu tokens; running chain pass...\n", tokens.size());
+        std::vector<float> hidden_a;
+        std::vector<float> hidden_b;
+        std::vector<float> A = run_chain_capture_last(model, cparams, tokens, &hidden_a);
+
+        LOG_INF("running chain-prefill + tree-root pass (n_siblings=%d, n_spec_steps=%d)...\n",
+                n_siblings, n_spec_steps);
+        std::vector<float> B;
+        if (n_spec_steps == 1 && gapped_accept) {
+            B = run_chain_then_tree_gapped_last(model, cparams, tokens, &hidden_b);
+        } else if (n_spec_steps == 1) {
+            B = run_chain_then_tree_root(model, cparams, tokens, n_siblings, &hidden_b);
+        } else if (n_spec_steps == 2) {
+            B = run_chain_then_tree_two_step(model, cparams, tokens, n_siblings, skip_rollback, gapped_accept, &hidden_b);
+        } else {
+            throw std::runtime_error("--n-spec-steps must be 1 or 2");
+        }
+
+        DiffStats s = diff_logits(A, B);
+
+        std::ofstream f(out_summary);
+        f << "n_tokens=" << tokens.size() << "\n";
+        f << "n_siblings=" << n_siblings << "\n";
+        f << "n_spec_steps=" << n_spec_steps << "\n";
+        f << "skip_rollback=" << (skip_rollback ? 1 : 0) << "\n";
+        f << "gapped_accept=" << (gapped_accept ? 1 : 0) << "\n";
+        f << "vocab_size=" << A.size() << "\n";
+        f << "max_abs_diff=" << s.max_abs_diff << "\n";
+        f << "mean_abs_diff=" << s.mean_abs_diff << "\n";
+        if (!hidden_a.empty() && !hidden_b.empty()) {
+            DiffStats hs = diff_logits(hidden_a, hidden_b);
+            f << "hidden_max_abs_diff=" << hs.max_abs_diff << "\n";
+            f << "hidden_mean_abs_diff=" << hs.mean_abs_diff << "\n";
+        }
+        f << "argmax_chain=" << s.argmax_a << "\n";
+        f << "argmax_tree_root=" << s.argmax_b << "\n";
+        f << "top5_chain=";
+        for (int x : s.top5_a) f << x << " ";
+        f << "\ntop5_tree_root=";
+        for (int x : s.top5_b) f << x << " ";
+        f << "\n";
+
+        fprintf(stderr, "max_abs_diff = %.6g\n", s.max_abs_diff);
+        fprintf(stderr, "mean_abs_diff = %.6g\n", s.mean_abs_diff);
+        fprintf(stderr, "argmax: chain=%d tree_root=%d %s\n",
+                s.argmax_a, s.argmax_b,
+                s.argmax_a == s.argmax_b ? "MATCH" : "DIFF");
+
+        rc = 0;
+    } catch (const std::exception & e) {
+        LOG_ERR("error: %s\n", e.what());
+        rc = 1;
+    }
+
+    llama_model_free(model);
+    llama_backend_free();
+    return rc;
+}
diff --git a/tests/test-qwen35-tree-rollback.cpp b/tests/test-qwen35-tree-rollback.cpp
new file mode 100644
index 00000000000..014771eb542
--- /dev/null
+++ b/tests/test-qwen35-tree-rollback.cpp
@@ -0,0 +1,299 @@
+// test-qwen35-tree-rollback.cpp
+//
+// Phase 2 acceptance test for DDTree snapshot/restore symmetry (Test 2.A).
+// Loads a Qwen3.5-27B GGUF, decodes a prompt chain, takes a recurrent-state
+// snapshot, runs N decode steps, restores the snapshot, runs the same N steps
+// again from the same starting token, and dumps both runs' final logits.
+// The two logit files must be bit-equal (--abs-tol 0 with compare_logits.py).
+//
+// Build: requires -DLLAMA_BUILD_TESTS_QWEN35_TREE_ROLLBACK=ON (not in ctest).
+//
+// API assumptions (implementation agent deliverables):
+//   typedef int32_t llama_mem_snapshot_id;
+//   llama_mem_snapshot_id llama_seq_snapshot(llama_context *, llama_seq_id);
+//   bool llama_seq_restore (llama_context *, llama_mem_snapshot_id);
+//   void llama_seq_release (llama_context *, llama_mem_snapshot_id);
+//
+// Output binary format (--out-logits-pre / --out-logits-post):
+//   int32_t n_tokens (= 1, the single last-step logit row)
+//   int32_t vocab_size
+//   float logits[vocab_size]
+
+#include "llama.h"
+#include "log.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <fstream>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+// ---------------------------------------------------------------------------
+// helpers
+// ---------------------------------------------------------------------------
+
+static void usage(const char * prog) {
+    fprintf(stderr,
+            "Usage: %s\n"
+            "  --model PATH           (Qwen3.5-27B GGUF; required)\n"
+            "  --prompt-tokens PATH   (binary int32 LE token IDs; required)\n"
+            "  --gen N                (chain decode steps per run; default 8)\n"
+            "  --out-logits-pre PATH  (logit dump from first run; required)\n"
+            "  --out-logits-post PATH (logit dump from second run after restore; required)\n"
+            "  --n-gpu-layers N       (default 99)\n"
+            "  --n-ctx N              (default 4096)\n",
+            prog);
+}
+
+static std::vector<int32_t> read_prompt_tokens(const std::string & path) {
+    std::ifstream f(path, std::ios::binary);
+    if (!f) {
+        throw std::runtime_error("cannot open prompt-tokens file: " + path);
+    }
+    f.seekg(0, std::ios::end);
+    auto sz = f.tellg();
+    f.seekg(0, std::ios::beg);
+    if (sz % sizeof(int32_t) != 0) {
+        throw std::runtime_error("prompt-tokens file size not a multiple of 4: " + path);
+    }
+    std::vector<int32_t> tokens(sz / sizeof(int32_t));
+    f.read(reinterpret_cast<char *>(tokens.data()), sz);
+    return tokens;
+}
+
+static void write_logits(const std::string & path,
+                         const float * data,
+                         int32_t n_tokens,
+                         int32_t vocab_size) {
+    std::ofstream f(path, std::ios::binary);
+    if (!f) {
+        throw std::runtime_error("cannot open out-logits for writing: " + path);
+    }
+    f.write(reinterpret_cast<const char *>(&n_tokens), sizeof(int32_t));
+    f.write(reinterpret_cast<const char *>(&vocab_size), sizeof(int32_t));
+    f.write(reinterpret_cast<const char *>(data), (std::streamsize)(n_tokens * vocab_size * sizeof(float)));
+}
+
+// Return the argmax token id from a logits row.
+static llama_token argmax(const float * logits, int32_t vocab_size) {
+    return (llama_token)(std::max_element(logits, logits + vocab_size) - logits);
+}
+
+// Decode a single token at the given position and return the logits pointer.
+// The returned pointer is valid until the next llama_decode call.
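+// Callers that need the row after a subsequent decode (e.g. the last-step
+// dumps below) must memcpy it out first.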
+static const float * decode_single(llama_context * ctx,
+                                   llama_token tok,
+                                   llama_pos pos,
+                                   int32_t vocab_size) {
+    llama_batch batch = llama_batch_init(1, /*embd=*/0, /*n_seq_max=*/1);
+    batch.token[0] = tok;
+    batch.pos[0] = pos;
+    batch.n_seq_id[0] = 1;
+    batch.seq_id[0][0] = 0;
+    batch.logits[0] = 1;
+    batch.n_tokens = 1;
+
+    if (llama_decode(ctx, batch) != 0) {
+        llama_batch_free(batch);
+        throw std::runtime_error("llama_decode failed for single token");
+    }
+
+    const float * row = llama_get_logits_ith(ctx, 0);
+    // Freeing the batch here is safe: the logits buffer is owned by the context, not the batch.
+    llama_batch_free(batch);
+    (void)vocab_size; // size used by caller
+    return row;
+}
+
+// ---------------------------------------------------------------------------
+// main
+// ---------------------------------------------------------------------------
+
+int main(int argc, char ** argv) {
+    std::string model_path;
+    std::string prompt_tokens_path;
+    std::string out_logits_pre_path;
+    std::string out_logits_post_path;
+    int32_t gen = 8;
+    int32_t n_gpu_layers = 99;
+    int32_t n_ctx = 4096;
+
+    for (int i = 1; i < argc; ++i) {
+        std::string arg = argv[i];
+        if (arg == "--model" && i + 1 < argc) {
+            model_path = argv[++i];
+        } else if (arg == "--prompt-tokens" && i + 1 < argc) {
+            prompt_tokens_path = argv[++i];
+        } else if (arg == "--gen" && i + 1 < argc) {
+            gen = std::atoi(argv[++i]);
+        } else if (arg == "--out-logits-pre" && i + 1 < argc) {
+            out_logits_pre_path = argv[++i];
+        } else if (arg == "--out-logits-post" && i + 1 < argc) {
+            out_logits_post_path = argv[++i];
+        } else if (arg == "--n-gpu-layers" && i + 1 < argc) {
+            n_gpu_layers = std::atoi(argv[++i]);
+        } else if (arg == "--n-ctx" && i + 1 < argc) {
+            n_ctx = std::atoi(argv[++i]);
+        } else if (arg == "-h" || arg == "--help") {
+            usage(argv[0]);
+            return 0;
+        } else {
+            fprintf(stderr, "unknown argument: %s\n", arg.c_str());
+            usage(argv[0]);
+            return 1;
+        }
+    }
+
+    if (model_path.empty()) {
+        fprintf(stderr, "--model is required\n");
+        return 1;
+    }
+    if (prompt_tokens_path.empty()) {
+        fprintf(stderr, "--prompt-tokens is required\n");
+        return 1;
+    }
+    if (out_logits_pre_path.empty()) {
+        fprintf(stderr, "--out-logits-pre is required\n");
+        return 1;
+    }
+    if (out_logits_post_path.empty()) {
+        fprintf(stderr, "--out-logits-post is required\n");
+        return 1;
+    }
+    if (gen < 1) {
+        fprintf(stderr, "--gen must be >= 1\n");
+        return 1;
+    }
+
+    llama_backend_init();
+
+    auto mparams = llama_model_default_params();
+    mparams.n_gpu_layers = n_gpu_layers;
+
+    llama_model * model = llama_model_load_from_file(model_path.c_str(), mparams);
+    if (!model) {
+        LOG_ERR("failed to load model: %s\n", model_path.c_str());
+        llama_backend_free();
+        return 1;
+    }
+
+    auto cparams = llama_context_default_params();
+    cparams.n_ctx = (uint32_t)n_ctx;
+    cparams.n_batch = (uint32_t)n_ctx;
+
+    llama_context * ctx = llama_init_from_model(model, cparams);
+    if (!ctx) {
+        LOG_ERR("failed to create context\n");
+        llama_model_free(model);
+        llama_backend_free();
+        return 1;
+    }
+
+    int ret = 1;
+    try {
+        const auto * vocab = llama_model_get_vocab(model);
+        const int32_t vocab_size = llama_vocab_n_tokens(vocab);
+
+        // ------------------------------------------------------------------
+        // Step 1: decode the prompt as a single chain batch to prime state
+        // ------------------------------------------------------------------
+        std::vector<int32_t> prompt = read_prompt_tokens(prompt_tokens_path);
+        const int32_t n_prompt = (int32_t)prompt.size();
+
+        {
+            llama_batch batch = llama_batch_init(n_prompt, /*embd=*/0, /*n_seq_max=*/1);
+            for (int32_t i = 0; i < n_prompt; ++i) {
+                batch.token[i] = (llama_token)prompt[i];
+                batch.pos[i] = (llama_pos)i;
+                batch.n_seq_id[i] = 1;
+                batch.seq_id[i][0] = 0;
+                batch.logits[i] = (i == n_prompt - 1) ? 1 : 0;
+            }
+            batch.n_tokens = n_prompt;
+
+            if (llama_decode(ctx, batch) != 0) {
+                llama_batch_free(batch);
+                throw std::runtime_error("prompt decode failed");
+            }
+            llama_batch_free(batch);
+        }
+        LOG_INF("prompt decoded (%d tokens)\n", n_prompt);
+
+        // ------------------------------------------------------------------
+        // Step 2: snapshot the recurrent state BEFORE the decode loop
+        //   The first token for both runs is the argmax of the prompt's last
+        //   position logits, captured now so both runs start identically.
+        // ------------------------------------------------------------------
+        const float * prompt_logits = llama_get_logits_ith(ctx, n_prompt - 1);
+        llama_token tok_first = argmax(prompt_logits, vocab_size);
+        LOG_INF("first token after prompt: %d\n", (int)tok_first);
+
+        llama_mem_snapshot_id snap = llama_seq_snapshot(ctx, /*seq_id=*/0);
+        if (snap < 0) {
+            throw std::runtime_error("llama_seq_snapshot returned negative id");
+        }
+        LOG_INF("snapshot id: %d\n", (int)snap);
+
+        // ------------------------------------------------------------------
+        // Step 3: run K decode steps (run 1), save logits of last step
+        // ------------------------------------------------------------------
+        std::vector<float> last_logits_pre((size_t)vocab_size);
+        {
+            llama_token cur = tok_first;
+            llama_pos pos = (llama_pos)n_prompt;
+            for (int step = 0; step < gen; ++step) {
+                const float * row = decode_single(ctx, cur, pos, vocab_size);
+                if (step == gen - 1) {
+                    memcpy(last_logits_pre.data(), row, vocab_size * sizeof(float));
+                }
+                cur = argmax(row, vocab_size);
+                ++pos;
+            }
+        }
+        write_logits(out_logits_pre_path, last_logits_pre.data(), 1, vocab_size);
+        LOG_INF("run 1 complete, logits written to %s\n", out_logits_pre_path.c_str());
+
+        // ------------------------------------------------------------------
+        // Step 4: restore snapshot and run K steps again with same first token
+        // ------------------------------------------------------------------
+        if (!llama_seq_restore(ctx, snap)) {
+            throw std::runtime_error("llama_seq_restore failed");
+        }
+        LOG_INF("snapshot restored\n");
+
+        std::vector<float> last_logits_post((size_t)vocab_size);
+        {
+            llama_token cur = tok_first; // same first token as run 1
+            llama_pos pos = (llama_pos)n_prompt;
+            for (int step = 0; step < gen; ++step) {
+                const float * row = decode_single(ctx, cur, pos, vocab_size);
+                if (step == gen - 1) {
+                    memcpy(last_logits_post.data(), row, vocab_size * sizeof(float));
+                }
+                cur = argmax(row, vocab_size);
+                ++pos;
+            }
+        }
+        write_logits(out_logits_post_path, last_logits_post.data(), 1, vocab_size);
+        LOG_INF("run 2 complete, logits written to %s\n", out_logits_post_path.c_str());
+
+        // ------------------------------------------------------------------
+        // Step 5: release snapshot
+        // ------------------------------------------------------------------
+        llama_seq_release(ctx, snap);
+        LOG_INF("snapshot released\n");
+
+        ret = 0;
+    } catch (const std::exception & e) {
+        LOG_ERR("error: %s\n", e.what());
+        ret = 1;
+    }
+
+    llama_free(ctx);
+    llama_model_free(model);
+    llama_backend_free();
+    return ret;
+}
diff --git a/tests/test-qwen35-tree.cpp b/tests/test-qwen35-tree.cpp
new file mode 100644
index 00000000000..2516577c284
--- /dev/null
+++ b/tests/test-qwen35-tree.cpp
@@ -0,0 +1,311 @@
+// test-qwen35-tree.cpp
+//
+// Phase 1 acceptance test for DDTree tree-mode forward pass.
+// Loads a Qwen3.5-27B GGUF, runs either a plain chain forward or a tree forward,
+// and dumps raw F32 logits for offline comparison.
+//
+// Build: requires -DLLAMA_BUILD_TESTS_QWEN35_TREE=ON (not added to ctest by default).
+//
+// API assumptions (implementation agent deliverables):
+//   - llama_batch.parent_id : int32_t *, NULL in chain mode; -1 = root, else flat parent index
+//   - llama_batch_init_tree(n_tokens, embd, n_seq_max) : like llama_batch_init but also
+//     allocates parent_id array of size n_tokens
+//   - llama_batch_free() frees parent_id when non-NULL
+//   - llama_decode() reads batch.parent_id and dispatches tree forward when non-NULL
+//
+// Output binary format (--out-logits):
+//   int32_t n_tokens
+//   int32_t vocab_size
+//   float logits[n_tokens * vocab_size] (row-major, little-endian)
+
+#include "llama.h"
+#include "common.h"
+#include "log.h"
+
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <fstream>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+// nlohmann/json is vendored at vendor/nlohmann/json.hpp
+#include <nlohmann/json.hpp>
+
+using json = nlohmann::json;
+
+// ---------------------------------------------------------------------------
+// helpers
+// ---------------------------------------------------------------------------
+
+static void usage(const char * prog) {
+    fprintf(stderr,
+            "Usage: %s\n"
+            "  --mode {chain,tree}  (required)\n"
+            "  --model PATH         (GGUF; required)\n"
+            "  --prompt-tokens PATH (binary int32 LE token IDs; required)\n"
+            "  --tree-fixture PATH  (JSON; required in tree mode)\n"
+            "  --out-logits PATH    (binary F32 output; required)\n"
+            "  --n-gpu-layers N     (default 99)\n"
+            "  --n-ctx N            (default 4096)\n",
+            prog);
+}
+
+static std::vector<int32_t> read_prompt_tokens(const std::string & path) {
+    std::ifstream f(path, std::ios::binary);
+    if (!f) {
+        throw std::runtime_error("cannot open prompt-tokens file: " + path);
+    }
+    f.seekg(0, std::ios::end);
+    auto size = f.tellg();
+    f.seekg(0, std::ios::beg);
+    if (size % sizeof(int32_t) != 0) {
+        throw std::runtime_error("prompt-tokens file size not a multiple of 4: " + path);
+    }
+    std::vector<int32_t> tokens(size / sizeof(int32_t));
+    f.read(reinterpret_cast<char *>(tokens.data()), size);
+    return tokens;
+}
+
+struct TreeNode {
+    int32_t flat_idx;
+    int32_t token_id;
+    int32_t parent_idx; // -1 = root
+    int32_t depth;
+};
+
+static std::vector<TreeNode> parse_tree_fixture(const std::string & path, int32_t & committed_offset) {
+    std::ifstream f(path);
+    if (!f) {
+        throw std::runtime_error("cannot open tree-fixture: " + path);
+    }
+    json j;
+    f >> j;
+
+    committed_offset = j.value("committed_offset", 0);
+
+    std::vector<TreeNode> nodes;
+    for (const auto & n : j["nodes"]) {
+        TreeNode node;
+        node.flat_idx = n["flat_idx"].get<int32_t>();
+        node.token_id = n["token_id"].get<int32_t>();
+        node.parent_idx = n["parent_idx"].get<int32_t>();
+        node.depth = n["depth"].get<int32_t>();
+        nodes.push_back(node);
+    }
+    return nodes;
+}
+
+static void write_logits(const std::string & path,
+                         const std::vector<float> & data,
+                         int32_t n_tokens,
+                         int32_t vocab_size) {
+    std::ofstream f(path, std::ios::binary);
+    if (!f) {
+        throw std::runtime_error("cannot open out-logits for writing: " + path);
+    }
+    f.write(reinterpret_cast<const char *>(&n_tokens), sizeof(int32_t));
+    f.write(reinterpret_cast<const char *>(&vocab_size), sizeof(int32_t));
+    f.write(reinterpret_cast<const char *>(data.data()), (std::streamsize)(data.size() * sizeof(float)));
+}
+
+// ---------------------------------------------------------------------------
+// chain mode: plain llama_batch forward
+// ---------------------------------------------------------------------------
+
+static int run_chain(llama_model * model,
+                     llama_context * ctx,
+                     const std::vector<int32_t> & prompt,
+                     const std::string & out_path) {
+    const int32_t n_tokens = (int32_t)prompt.size();
+    const auto * vocab = llama_model_get_vocab(model);
+    const int32_t vocab_size = llama_vocab_n_tokens(vocab);
+
+    llama_batch batch = llama_batch_init(n_tokens, /*embd=*/0, /*n_seq_max=*/1);
+
+    for (int32_t i = 0; i < n_tokens; ++i) {
+        batch.token[i] = (llama_token)prompt[i];
+        batch.pos[i] = i;
+        batch.n_seq_id[i] = 1;
+        batch.seq_id[i][0] = 0;
+        batch.logits[i] = 1; // request logits for every position
+    }
+    batch.n_tokens = n_tokens;
+
+    if (llama_decode(ctx, batch) != 0) {
+        LOG_ERR("%s: llama_decode failed\n", __func__);
+        llama_batch_free(batch);
+        return 1;
+    }
+
+    std::vector<float> logits_out((size_t)n_tokens * vocab_size);
+    for (int32_t i = 0; i < n_tokens; ++i) {
+        const float * row = llama_get_logits_ith(ctx, i);
+        memcpy(&logits_out[(size_t)i * vocab_size], row, vocab_size * sizeof(float));
+    }
+
+    llama_batch_free(batch);
+    write_logits(out_path, logits_out, n_tokens, vocab_size);
+    LOG_INF("chain: wrote %d x %d logits to %s\n", n_tokens, vocab_size, out_path.c_str());
+    return 0;
+}
+
+// ---------------------------------------------------------------------------
+// tree mode: tree-batch forward
+// ---------------------------------------------------------------------------
+//
+// Position assignment (Phase 1, 1D only):
+//   pos[i] = committed_offset + node.depth
+//
+// M-RoPE 4-axis positions are deferred to Phase 3 (UNKNOWN-3 in the roadmap).
+// When that work lands, pos[] will need to be a 4-tuple per token and
+// llama_batch will need a corresponding multi-axis pos field.
+
+static int run_tree(llama_model * model,
+                    llama_context * ctx,
+                    const std::string & fixture_path,
+                    const std::string & out_path) {
+    int32_t committed_offset = 0;
+    std::vector<TreeNode> nodes = parse_tree_fixture(fixture_path, committed_offset);
+
+    const int32_t n_tokens = (int32_t)nodes.size();
+    const auto * vocab = llama_model_get_vocab(model);
+    const int32_t vocab_size = llama_vocab_n_tokens(vocab);
+
+    // llama_batch_init_tree is the new API added by the implementation agent.
+    // It behaves like llama_batch_init but additionally allocates batch.parent_id.
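+    //
+    // A plausible shape for that helper, under the API assumptions listed at
+    // the top of this file (a sketch only, not the actual implementation):
+    //
+    //   llama_batch llama_batch_init_tree(int32_t n_tokens, int32_t embd, int32_t n_seq_max) {
+    //       llama_batch batch = llama_batch_init(n_tokens, embd, n_seq_max);
+    //       batch.parent_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
+    //       return batch;
+    //   }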
+    llama_batch batch = llama_batch_init_tree(n_tokens, /*embd=*/0, /*n_seq_max=*/1);
+
+    for (int32_t i = 0; i < n_tokens; ++i) {
+        const TreeNode & node = nodes[i];
+        batch.token[i] = (llama_token)node.token_id;
+        batch.pos[i] = (llama_pos)(committed_offset + node.depth);
+        batch.n_seq_id[i] = 1;
+        batch.seq_id[i][0] = 0;
+        batch.parent_id[i] = node.parent_idx; // -1 = root
+        batch.logits[i] = 1;
+    }
+    batch.n_tokens = n_tokens;
+
+    if (llama_decode(ctx, batch) != 0) {
+        LOG_ERR("%s: llama_decode (tree) failed\n", __func__);
+        llama_batch_free(batch);
+        return 1;
+    }
+
+    std::vector<float> logits_out((size_t)n_tokens * vocab_size);
+    for (int32_t i = 0; i < n_tokens; ++i) {
+        const float * row = llama_get_logits_ith(ctx, i);
+        memcpy(&logits_out[(size_t)i * vocab_size], row, vocab_size * sizeof(float));
+    }
+
+    llama_batch_free(batch);
+    write_logits(out_path, logits_out, n_tokens, vocab_size);
+    LOG_INF("tree: wrote %d x %d logits to %s\n", n_tokens, vocab_size, out_path.c_str());
+    return 0;
+}
+
+// ---------------------------------------------------------------------------
+// main
+// ---------------------------------------------------------------------------
+
+int main(int argc, char ** argv) {
+    std::string mode;
+    std::string model_path;
+    std::string prompt_tokens_path;
+    std::string tree_fixture_path;
+    std::string out_logits_path;
+    int32_t n_gpu_layers = 99;
+    int32_t n_ctx = 4096;
+
+    for (int i = 1; i < argc; ++i) {
+        std::string arg = argv[i];
+        if (arg == "--mode" && i + 1 < argc) {
+            mode = argv[++i];
+        } else if (arg == "--model" && i + 1 < argc) {
+            model_path = argv[++i];
+        } else if (arg == "--prompt-tokens" && i + 1 < argc) {
+            prompt_tokens_path = argv[++i];
+        } else if (arg == "--tree-fixture" && i + 1 < argc) {
+            tree_fixture_path = argv[++i];
+        } else if (arg == "--out-logits" && i + 1 < argc) {
+            out_logits_path = argv[++i];
+        } else if (arg == "--n-gpu-layers" && i + 1 < argc) {
+            n_gpu_layers = std::atoi(argv[++i]);
+        } else if (arg == "--n-ctx" && i + 1 < argc) {
+            n_ctx = std::atoi(argv[++i]);
+        } else if (arg == "-h" || arg == "--help") {
+            usage(argv[0]);
+            return 0;
+        } else {
+            fprintf(stderr, "unknown argument: %s\n", arg.c_str());
+            usage(argv[0]);
+            return 1;
+        }
+    }
+
+    if (mode != "chain" && mode != "tree") {
+        fprintf(stderr, "--mode must be 'chain' or 'tree'\n");
+        usage(argv[0]);
+        return 1;
+    }
+    if (model_path.empty()) {
+        fprintf(stderr, "--model is required\n");
+        return 1;
+    }
+    if (prompt_tokens_path.empty() && mode == "chain") {
+        fprintf(stderr, "--prompt-tokens is required in chain mode\n");
+        return 1;
+    }
+    if (tree_fixture_path.empty() && mode == "tree") {
+        fprintf(stderr, "--tree-fixture is required in tree mode\n");
+        return 1;
+    }
+    if (out_logits_path.empty()) {
+        fprintf(stderr, "--out-logits is required\n");
+        return 1;
+    }
+
+    llama_backend_init();
+
+    auto mparams = llama_model_default_params();
+    mparams.n_gpu_layers = n_gpu_layers;
+
+    llama_model * model = llama_model_load_from_file(model_path.c_str(), mparams);
+    if (!model) {
+        LOG_ERR("failed to load model: %s\n", model_path.c_str());
+        llama_backend_free();
+        return 1;
+    }
+
+    auto cparams = llama_context_default_params();
+    cparams.n_ctx = (uint32_t)n_ctx;
+    cparams.n_batch = (uint32_t)n_ctx;
+
+    llama_context * ctx = llama_init_from_model(model, cparams);
+    if (!ctx) {
+        LOG_ERR("failed to create context\n");
+        llama_model_free(model);
+        llama_backend_free();
+        return 1;
+    }
+
+    int ret = 1;
+    try {
+        if (mode == "chain") {
+            std::vector<int32_t> prompt = read_prompt_tokens(prompt_tokens_path);
+            ret = run_chain(model, ctx, prompt, out_logits_path);
+        } else {
+            ret = run_tree(model, ctx, tree_fixture_path, out_logits_path);
+        }
+    } catch (const std::exception & e) {
+        LOG_ERR("error: %s\n", e.what());
+        ret = 1;
+    }
+
+    llama_free(ctx);
+    llama_model_free(model);
+    llama_backend_free();
+    return ret;
+}
diff --git a/tests/test-speculative-draft-backend.cpp b/tests/test-speculative-draft-backend.cpp
new file mode 100644
index 00000000000..82a1794d23d
--- /dev/null
+++ b/tests/test-speculative-draft-backend.cpp
@@ -0,0 +1,99 @@
+#include "speculative-draft-backend.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <vector>
+
+static void require(bool ok, const char * expr, int line) {
+    if (!ok) {
+        std::fprintf(stderr, "test-speculative-draft-backend:%d: check failed: %s\n", line, expr);
+        std::abort();
+    }
+}
+
+#define REQUIRE(expr) require((expr), #expr, __LINE__)
+
+static void test_top_k_width() {
+    llama_ddtree_params p;
+    p.block_size = 16;
+    p.budget = 40;
+    p.top_k = 0;
+    REQUIRE(llama_speculative_draft_top_k_width(p.block_size, p) == 40);
+
+    p.budget = 8;
+    REQUIRE(llama_speculative_draft_top_k_width(p.block_size, p) == 8);
+
+    p.top_k = 4;
+    REQUIRE(llama_speculative_draft_top_k_width(p.block_size, p) == 4);
+}
+
+static void test_pack_target_feat_no_wrap() {
+    const int64_t fc = 3;
+    const int64_t cap = 4;
+    std::vector<float> ring((size_t) fc * cap);
+    for (int64_t col = 0; col < cap; ++col) {
+        for (int64_t row = 0; row < fc; ++row) {
+            ring[(size_t) col * fc + row] = (float) (10 * col + row);
+        }
+    }
+
+    llama_speculative_draft_target_feat_view view{ ring.data(), 3, cap, fc };
+    std::vector<float> out;
+    int64_t ctx_len = 0;
+    REQUIRE(llama_speculative_draft_pack_target_feat(view, out, ctx_len));
+    REQUIRE(ctx_len == 3);
+    REQUIRE(out.size() == 9);
+
+    for (int64_t col = 0; col < ctx_len; ++col) {
+        for (int64_t row = 0; row < fc; ++row) {
+            REQUIRE(out[(size_t) col * fc + row] == ring[(size_t) col * fc + row]);
+        }
+    }
+}
+
+static void test_pack_target_feat_wrap() {
+    const int64_t fc = 3;
+    const int64_t cap = 4;
+    std::vector<float> ring((size_t) fc * cap);
+
+    // logical columns 2, 3, 4, 5 live in ring slots 2, 3, 0, 1.
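+    // Equivalently, slot = logical % cap: with 6 columns written and cap = 4,
+    // the oldest retained logical column is 2 (= 6 - 4), and logical columns
+    // 4 and 5 wrap around to slots 0 and 1.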
+    const int64_t logical_by_slot[4] = { 4, 5, 2, 3 };
+    for (int64_t slot = 0; slot < cap; ++slot) {
+        const int64_t logical = logical_by_slot[slot];
+        for (int64_t row = 0; row < fc; ++row) {
+            ring[(size_t) slot * fc + row] = (float) (100 * logical + row);
+        }
+    }
+
+    llama_speculative_draft_target_feat_view view{ ring.data(), 6, cap, fc };
+    std::vector<float> out;
+    int64_t ctx_len = 0;
+    REQUIRE(llama_speculative_draft_pack_target_feat(view, out, ctx_len));
+    REQUIRE(ctx_len == cap);
+    REQUIRE(out.size() == (size_t) fc * cap);
+
+    for (int64_t col = 0; col < ctx_len; ++col) {
+        const int64_t logical = 2 + col;
+        for (int64_t row = 0; row < fc; ++row) {
+            REQUIRE(out[(size_t) col * fc + row] == (float) (100 * logical + row));
+        }
+    }
+}
+
+static void test_pack_target_feat_empty() {
+    std::vector<float> out{ 1.0f };
+    int64_t ctx_len = 123;
+    llama_speculative_draft_target_feat_view view{};
+    REQUIRE(!llama_speculative_draft_pack_target_feat(view, out, ctx_len));
+    REQUIRE(ctx_len == 0);
+}
+
+int main() {
+    test_top_k_width();
+    test_pack_target_feat_no_wrap();
+    test_pack_target_feat_wrap();
+    test_pack_target_feat_empty();
+    return 0;
+}
diff --git a/tests/test-speculative-tree-e2e.cpp b/tests/test-speculative-tree-e2e.cpp
new file mode 100644
index 00000000000..6342bff59e2
--- /dev/null
+++ b/tests/test-speculative-tree-e2e.cpp
@@ -0,0 +1,765 @@
+// test-speculative-tree-e2e.cpp
+//
+// Phase 4 end-to-end acceptance test (Test 4.A) for DDTree speculative decoding.
+//
+// Two decode runs are performed back-to-back:
+//
+//   Run 1 (chain reference):
+//     Load target model. Decode the prompt as a plain chain batch. Then loop
+//     greedy-argmax N times to collect N reference tokens. Write to --out-chain.
+//
+//   Run 2 (spec decode):
+//     Reload a fresh target context with capture_hidden=true. Load draft model.
+//     Decode prompt as chain batch (primes hidden capture). Init the DDTree
+//     speculative driver. Loop spec steps until N tokens are collected.
+//     Write (first N) to --out-spec.
+//
+// Acceptance criterion (--temp 0 / greedy):
+//   The first min(chain_n, spec_n) tokens MUST be bit-equal.
+//   DDTree is lossless speculative decoding: each accepted draft token has
+//   been verified as matching target-argmax at that position.
+//
+// Build: requires -DLLAMA_BUILD_TESTS_SPECULATIVE_TREE_E2E=ON (not in ctest).
+//
+// API assumptions (implementation agent deliverables):
+//   -- From Phase 3 gap:
+//      void llama_set_target_feat_raw(llama_context * ctx,
+//                                     const float * data,
+//                                     int64_t n_embd_fc,
+//                                     int64_t ctx_len);
+//
+//   -- Phase 4 driver (common/speculative-tree-driver.h):
+//      struct llama_speculative_tree_driver;
+//      llama_speculative_tree_driver * llama_speculative_tree_driver_init(
+//          llama_context * target_ctx,
+//          llama_context * draft_ctx,
+//          const llama_ddtree_params & params);
+//      void llama_speculative_tree_driver_free(llama_speculative_tree_driver * d);
+//      std::vector<llama_token> llama_speculative_tree_driver_step(
+//          llama_speculative_tree_driver * d,
+//          llama_token root_token,
+//          llama_pos committed_pos);
+//
+//   -- Phase 3 (already landed):
+//      void llama_set_capture_hidden(llama_context * ctx, bool enable);
+//
+// Output binary format (--out-chain / --out-spec):
+//   int32_t n_tokens (number of generated tokens written)
+//   int32_t tokens[n_tokens] (little-endian int32, one per generated token)
+//
+// Note: --temp 0 is required for the bit-equal trajectory guarantee.
+// Non-zero temperature introduces stochastic sampling and invalidates the
+// comparison.
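+//
+// A minimal reader for that format, as a sketch ("load_token_file" is a
+// hypothetical helper, not part of this test or of the llama API):
+//
+//   static std::vector<int32_t> load_token_file(const std::string & path) {
+//       std::ifstream f(path, std::ios::binary);
+//       int32_t n = 0;
+//       f.read(reinterpret_cast<char *>(&n), sizeof(int32_t));
+//       std::vector<int32_t> toks(n);
+//       f.read(reinterpret_cast<char *>(toks.data()), n * sizeof(int32_t));
+//       return toks;
+//   }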
+
+#include "llama.h"
+#include "common.h"
+#include "log.h"
+#include "speculative-tree.h"
+#include "speculative-tree-driver.h"
+
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <fstream>
+#include <functional>
+#include <iterator>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+// ---------------------------------------------------------------------------
+// constants
+// ---------------------------------------------------------------------------
+
+// Qwen3.5 EOS token id. Accept this token but stop further generation.
+static constexpr llama_token QWEN35_EOS = 248045;
+
+// ---------------------------------------------------------------------------
+// helpers
+// ---------------------------------------------------------------------------
+
+static void usage(const char * prog) {
+    fprintf(stderr,
+            "Usage: %s\n"
+            "  --target-model PATH   (Qwen3.5-27B GGUF; required)\n"
+            "  --draft-model PATH    (dflash-draft GGUF; required)\n"
+            "  --prompt-tokens PATH  (binary int32 LE token IDs; required unless --prompt-text)\n"
+            "  --prompt-text PATH    (raw rendered prompt text; tokenized with target vocab)\n"
+            "  --prompt-add-special  (with --prompt-text, request tokenizer special BOS/EOS insertion)\n"
+            "  --no-prompt-parse-special\n"
+            "                        (with --prompt-text, do not parse <|...|> as special tokens)\n"
+            "  --gen N               (tokens to generate; default 32)\n"
+            "  --out-spec PATH       (spec-decode output tokens, int32 LE; required)\n"
+            "  --out-chain PATH      (chain-decode reference tokens, int32 LE; required)\n"
+            "  --ddtree-budget N     (DDTree node budget; default 22)\n"
+            "  --ddtree-no-chain-seed (disable chain-seed heuristic; default: on)\n"
+            "  --require-ddtree      (fail unless multi-node DDTree verify ran)\n"
+            "  --require-replay      (fail unless snapshot+replay fallback ran)\n"
+            "  --require-full-prompt-ingest\n"
+            "                        (fail unless DDTree ingested every prompt token capture)\n"
+            "  --temp F              (sampling temperature; default 0.0 = greedy)\n"
+            "  --n-gpu-layers N      (default 99)\n"
+            "  --n-ctx N             (default 4096)\n"
+            "  --n-batch N           (logical prompt batch; default min(n_ctx, 2048))\n"
+            "  --n-ubatch N          (physical prompt batch; default 512)\n"
+            "  --prompt-chunk N      (prompt ingest chunk; default n_ubatch)\n"
+            "  --no-flash-attn       (disable Flash Attention)\n"
+            "\n"
+            "Pass --temp 0 (greedy) to enable token-trajectory bit-equal assertion.\n",
+            prog);
+}
+
+static std::vector<int32_t> read_int32_file(const std::string & path) {
+    std::ifstream f(path, std::ios::binary);
+    if (!f) {
+        throw std::runtime_error("cannot open file: " + path);
+    }
+    f.seekg(0, std::ios::end);
+    auto sz = f.tellg();
+    f.seekg(0, std::ios::beg);
+    if (sz % sizeof(int32_t) != 0) {
+        throw std::runtime_error("file size not a multiple of 4: " + path);
+    }
+    std::vector<int32_t> buf(sz / sizeof(int32_t));
+    f.read(reinterpret_cast<char *>(buf.data()), sz);
+    return buf;
+}
+
+static std::string read_text_file(const std::string & path) {
+    std::ifstream f(path, std::ios::binary);
+    if (!f) {
+        throw std::runtime_error("cannot open file: " + path);
+    }
+    return std::string((std::istreambuf_iterator<char>(f)), std::istreambuf_iterator<char>());
+}
+
+static std::vector<llama_token> tokenize_text(
+        const llama_vocab * vocab,
+        const std::string & text,
+        bool add_special,
+        bool parse_special) {
+    int32_t n = -llama_tokenize(vocab, text.data(), (int32_t)text.size(),
+                                nullptr, 0, add_special, parse_special);
+    if (n <= 0) {
+        throw std::runtime_error("llama_tokenize sizing failed");
+    }
+    std::vector<llama_token> tmp(n);
+    int32_t got = llama_tokenize(vocab, text.data(), (int32_t)text.size(),
+                                 tmp.data(), n, add_special, parse_special);
+    if (got != n) {
+        throw std::runtime_error("llama_tokenize result mismatch");
+    }
+    return std::vector<llama_token>(tmp.begin(), tmp.end());
+}
+
+static void write_token_file(const std::string & path,
+                             const std::vector<llama_token> & tokens) {
+    std::ofstream f(path, std::ios::binary);
+    if (!f) {
+        throw std::runtime_error("cannot open for writing: " + path);
+    }
+    int32_t n = (int32_t)tokens.size();
+    f.write(reinterpret_cast<const char *>(&n), sizeof(int32_t));
+    f.write(reinterpret_cast<const char *>(tokens.data()),
+            (std::streamsize)(tokens.size() * sizeof(llama_token)));
+}
+
+// Decode prompt as plain chain batches, return last-token logits (copy).
+// The optional per_chunk callback runs after every llama_decode() and is used
+// by the DDTree run to ingest exactly the hidden capture columns produced by
+// that physical prompt chunk.
+static std::vector<float> decode_chain_prompt(llama_context * ctx,
+                                              const std::vector<int32_t> & prompt,
+                                              int32_t vocab_size,
+                                              int32_t prompt_chunk,
+                                              const std::function<void(int32_t)> & per_chunk = {}) {
+    const int32_t n = (int32_t)prompt.size();
+    if (prompt_chunk <= 0) {
+        throw std::runtime_error("prompt_chunk must be > 0");
+    }
+
+    std::vector<float> logits;
+    for (int32_t off = 0; off < n; off += prompt_chunk) {
+        const int32_t n_cur = std::min(prompt_chunk, n - off);
+        llama_batch batch = llama_batch_init(n_cur, /*embd=*/0, /*n_seq_max=*/1);
+
+        for (int32_t i = 0; i < n_cur; ++i) {
+            const int32_t pos = off + i;
+            batch.token[i] = (llama_token)prompt[pos];
+            batch.pos[i] = (llama_pos)pos;
+            batch.n_seq_id[i] = 1;
+            batch.seq_id[i][0] = 0;
+            batch.logits[i] = (pos == n - 1) ? 1 : 0;
+        }
+        batch.n_tokens = n_cur;
+
+        if (llama_decode(ctx, batch) != 0) {
+            llama_batch_free(batch);
+            throw std::runtime_error("llama_decode failed on prompt chunk");
+        }
+
+        if (per_chunk) {
+            per_chunk(n_cur);
+        }
+
+        if (off + n_cur == n) {
+            const float * row = llama_get_logits_ith(ctx, n_cur - 1);
+            if (!row) {
+                llama_batch_free(batch);
+                throw std::runtime_error("prompt final logits unavailable");
+            }
+            logits.assign(row, row + vocab_size);
+        }
+
+        llama_batch_free(batch);
+    }
+
+    if (logits.empty()) {
+        throw std::runtime_error("prompt decode produced no logits");
+    }
+    return logits;
+}
+
+// Greedy argmax over a logit row.
+static llama_token argmax(const float * logits, int32_t vocab_size) {
+    llama_token best = 0;
+    float best_val = logits[0];
+    for (int32_t v = 1; v < vocab_size; ++v) {
+        if (logits[v] > best_val) {
+            best_val = logits[v];
+            best = v;
+        }
+    }
+    return best;
+}
+
+// Decode a single token at position pos, return logits for that position.
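+// Unlike the raw-pointer decode_single in test-qwen35-tree-rollback.cpp, this
+// copies the row into a std::vector so it stays valid across later decodes.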
+static std::vector<float> decode_single(llama_context * ctx,
+                                        llama_token tok,
+                                        llama_pos pos,
+                                        int32_t vocab_size) {
+    llama_batch batch = llama_batch_init(1, /*embd=*/0, /*n_seq_max=*/1);
+    batch.token[0] = tok;
+    batch.pos[0] = pos;
+    batch.n_seq_id[0] = 1;
+    batch.seq_id[0][0] = 0;
+    batch.logits[0] = 1;
+    batch.n_tokens = 1;
+
+    if (llama_decode(ctx, batch) != 0) {
+        llama_batch_free(batch);
+        throw std::runtime_error("llama_decode failed on single token");
+    }
+
+    const float * row = llama_get_logits_ith(ctx, 0);
+    std::vector<float> logits(row, row + vocab_size);
+    llama_batch_free(batch);
+    return logits;
+}
+
+// ---------------------------------------------------------------------------
+// Run 1: chain reference decode
+// ---------------------------------------------------------------------------
+
+static std::vector<llama_token> run_chain(
+        llama_model * model,
+        const llama_context_params & cparams,
+        const std::vector<int32_t> & prompt,
+        int32_t gen,
+        int32_t vocab_size,
+        int32_t prompt_chunk) {
+
+    llama_context * ctx = llama_init_from_model(model, cparams);
+    if (!ctx) {
+        throw std::runtime_error("chain: failed to create target context");
+    }
+
+    std::vector<llama_token> out;
+    out.reserve(gen);
+
+    // Decode prompt; logits for last prompt token give the first generated token.
+    const auto prompt_t0 = std::chrono::steady_clock::now();
+    std::vector<float> logits = decode_chain_prompt(ctx, prompt, vocab_size, prompt_chunk);
+    const auto prompt_t1 = std::chrono::steady_clock::now();
+
+    llama_pos pos = (llama_pos)prompt.size(); // next decode position
+
+    double decode_ms = 0.0;
+    int32_t decode_steps = 0;
+    for (int32_t i = 0; i < gen; ++i) {
+        llama_token tok = argmax(logits.data(), vocab_size);
+        out.push_back(tok);
+        if (tok == QWEN35_EOS) {
+            LOG_INF("chain: EOS at step %d\n", i);
+            break;
+        }
+        const auto decode_t0 = std::chrono::steady_clock::now();
+        logits = decode_single(ctx, tok, pos, vocab_size);
+        decode_ms += std::chrono::duration<double, std::milli>(std::chrono::steady_clock::now() - decode_t0).count();
+        decode_steps++;
+        pos++;
+    }
+
+    llama_free(ctx);
+    LOG_INF("chain: generated %d tokens\n", (int)out.size());
+    LOG_INF("chain timing detail: prompt=%.2f ms decode_steps=%d decode_avg=%.2f ms decode_total=%.2f ms\n",
+            std::chrono::duration<double, std::milli>(prompt_t1 - prompt_t0).count(),
+            (int)decode_steps,
+            decode_steps > 0 ? decode_ms / (double)decode_steps : 0.0,
+            decode_ms);
+    return out;
+}
+
+// ---------------------------------------------------------------------------
+// Run 2: speculative decode
+// ---------------------------------------------------------------------------
+
+static std::vector<llama_token> run_spec(
+        llama_model * target_model,
+        llama_model * draft_model,
+        const llama_context_params & target_cparams,
+        const llama_context_params & draft_cparams,
+        const std::vector<int32_t> & prompt,
+        int32_t gen,
+        int32_t vocab_size,
+        const llama_ddtree_params & ddparams,
+        int32_t prompt_chunk,
+        llama_speculative_tree_driver_stats * out_stats) {
+
+    // Target context with hidden capture enabled (required by the driver).
+    llama_context * target_ctx = llama_init_from_model(target_model, target_cparams);
+    if (!target_ctx) {
+        throw std::runtime_error("spec: failed to create target context");
+    }
+    llama_set_capture_hidden(target_ctx, true);
+
+    llama_context * draft_ctx = llama_init_from_model(draft_model, draft_cparams);
+    if (!draft_ctx) {
+        llama_free(target_ctx);
+        throw std::runtime_error("spec: failed to create draft context");
+    }
+
+    // Init spec driver.
+    llama_speculative_tree_driver * driver =
+        llama_speculative_tree_driver_init(target_ctx, draft_ctx, ddparams);
+    if (!driver) {
+        llama_free(draft_ctx);
+        llama_free(target_ctx);
+        throw std::runtime_error("spec: llama_speculative_tree_driver_init returned NULL");
+    }
+
+    // Prime hidden capture in physical prompt chunks and ingest each chunk
+    // immediately. A single logical 16k decode only leaves the last ubatch in
+    // the capture tensor, which is not a valid DDTree/DFlash prompt state.
+    std::vector<float> prompt_logits =
+        decode_chain_prompt(target_ctx, prompt, vocab_size, prompt_chunk,
+                            [&](int32_t n_cur) {
+                                llama_speculative_tree_driver_ingest_prompt_capture(driver, n_cur);
+                            });
+
+    // Root token = argmax of last prompt position.
+    llama_token root_token = argmax(prompt_logits.data(), vocab_size);
+    llama_pos committed_pos = (llama_pos)prompt.size();
+
+    std::vector<llama_token> out;
+    out.reserve(gen);
+
+    bool hit_eos = false;
+    while ((int32_t)out.size() < gen && !hit_eos) {
+        std::vector<llama_token> accepted =
+            llama_speculative_tree_driver_step(driver, root_token, committed_pos);
+
+        if (accepted.empty()) {
+            // Driver signals terminal condition (e.g. EOS from target).
+            LOG_INF("spec: driver returned empty accepted list at out_n=%d\n",
+                    (int)out.size());
+            break;
+        }
+
+        // Driver returns [committed_tokens..., bonus]. The bonus is the next
+        // step's root_token and is NOT yet in the KV cache, so it's not part
+        // of the committed output and doesn't advance committed_pos.
+        const int32_t n_committed = (int32_t)accepted.size() - 1;
+        for (int32_t i = 0; i < n_committed; ++i) {
+            llama_token t = accepted[i];
+            out.push_back(t);
+            if (t == QWEN35_EOS) {
+                hit_eos = true;
+                break;
+            }
+            if ((int32_t)out.size() >= gen) {
+                break;
+            }
+        }
+
+        root_token = accepted.back(); // bonus, fed as next step's tree[0]
+        committed_pos += (llama_pos)n_committed;
+    }
+
+    if (out_stats != nullptr) {
+        *out_stats = llama_speculative_tree_driver_get_stats(driver);
+    }
+
+    llama_speculative_tree_driver_free(driver);
+    llama_free(draft_ctx);
+    llama_free(target_ctx);
+
+    LOG_INF("spec: generated %d tokens\n", (int)out.size());
+    return out;
+}
+
+// ---------------------------------------------------------------------------
+// main
+// ---------------------------------------------------------------------------
+
+int main(int argc, char ** argv) {
+    std::string target_model_path;
+    std::string draft_model_path;
+    std::string prompt_tokens_path;
+    std::string prompt_text_path;
+    std::string out_spec_path;
+    std::string out_chain_path;
+    int32_t gen = 32;
+    int32_t n_gpu_layers = 99;
+    int32_t n_gpu_layers_draft = -1;
+    int32_t n_ctx = 4096;
+    int32_t n_batch_arg = 0;
+    int32_t n_ubatch_arg = 512;
+    int32_t prompt_chunk_arg = 0;
+    float temp = 0.0f;
+    std::string kv_type_str = "f16"; // "f16", "q8_0", or "q4_0"
+    bool require_ddtree = false;
+    bool require_replay = false;
+    bool require_full_prompt_ingest = false;
+    bool prompt_add_special = false;
+    bool prompt_parse_special = true;
+    bool no_flash_attn = false;
+
+    llama_ddtree_params ddparams; // defaults: budget=22, chain_seed=true
+    // temp is set separately below after arg parsing
+
+    for (int i = 1; i < argc; ++i) {
+        std::string arg = argv[i];
+        if (arg == "--target-model" && i + 1 < argc) {
+            target_model_path = argv[++i];
+        } else if (arg == "--draft-model" && i + 1 < argc) {
+            draft_model_path = argv[++i];
+        } else if (arg == "--prompt-tokens" && i + 1 < argc) {
+            prompt_tokens_path = argv[++i];
+        } else if (arg == "--prompt-text" && i + 1 < argc) {
+            prompt_text_path = argv[++i];
+        } else if (arg == "--prompt-add-special") {
+            prompt_add_special = true;
+        } else if (arg == "--no-prompt-parse-special") {
+            prompt_parse_special = false;
+        } else if (arg == "--gen" && i + 1 < argc) {
+            gen = std::atoi(argv[++i]);
+        } else if (arg == "--out-spec" && i + 1 < argc) {
+            out_spec_path = argv[++i];
+        } else if (arg == "--out-chain" && i + 1 < argc) {
+            out_chain_path = argv[++i];
+        } else if (arg == "--ddtree-budget" && i + 1 < argc) {
+            ddparams.budget = std::atoi(argv[++i]);
+        } else if (arg == "--ddtree-top-k" && i + 1 < argc) {
+            ddparams.top_k = std::atoi(argv[++i]);
+        } else if (arg == "--ddtree-no-chain-seed") {
+            ddparams.chain_seed = false;
+        } else if (arg == "--require-ddtree") {
+            require_ddtree = true;
+        } else if (arg == "--require-replay") {
+            require_replay = true;
+        } else if (arg == "--require-full-prompt-ingest") {
+            require_full_prompt_ingest = true;
+        } else if (arg == "--temp" && i + 1 < argc) {
+            temp = std::stof(argv[++i]);
+        } else if (arg == "--n-gpu-layers" && i + 1 < argc) {
+            n_gpu_layers = std::atoi(argv[++i]);
+        } else if (arg == "--draft-gpu-layers" && i + 1 < argc) {
+            n_gpu_layers_draft = std::atoi(argv[++i]);
+        } else if (arg == "--n-ctx" && i + 1 < argc) {
+            n_ctx = std::atoi(argv[++i]);
+        } else if (arg == "--n-batch" && i + 1 < argc) {
+            n_batch_arg = std::atoi(argv[++i]);
+        } else if (arg == "--n-ubatch" && i + 1 < argc) {
+            n_ubatch_arg = std::atoi(argv[++i]);
+        } else if (arg == "--prompt-chunk" && i + 1 < argc) {
+            prompt_chunk_arg = std::atoi(argv[++i]);
+        } else if (arg == "--no-flash-attn") {
+            no_flash_attn = true;
+        } else if (arg == "--kv-type" && i + 1 < argc) {
+            kv_type_str = argv[++i];
+        } else if (arg == "-h" || arg == "--help") {
+            usage(argv[0]);
+            return 0;
+        } else {
+            fprintf(stderr, "unknown argument: %s\n", arg.c_str());
+            usage(argv[0]);
+            return 1;
+        }
+    }
+
+    if (target_model_path.empty()) { fprintf(stderr, "--target-model is required\n"); return 1; }
+    if (draft_model_path.empty())  { fprintf(stderr, "--draft-model is required\n"); return 1; }
+    if (prompt_tokens_path.empty() && prompt_text_path.empty()) {
+        fprintf(stderr, "one of --prompt-tokens or --prompt-text is required\n");
+        return 1;
+    }
+    if (!prompt_tokens_path.empty() && !prompt_text_path.empty()) {
+        fprintf(stderr, "use only one of --prompt-tokens or --prompt-text\n");
+        return 1;
+    }
+    if (out_spec_path.empty())  { fprintf(stderr, "--out-spec is required\n"); return 1; }
+    if (out_chain_path.empty()) { fprintf(stderr, "--out-chain is required\n"); return 1; }
+    if (gen <= 0) { fprintf(stderr, "--gen must be > 0\n"); return 1; }
+
+    ddparams.temp = temp;
+
+    const bool greedy = (temp == 0.0f);
+    if (!greedy) {
+        fprintf(stderr,
+                "warning: --temp %.4f is non-zero; token-trajectory bit-equal assertion "
+                "is DISABLED (stochastic sampling makes sequences non-deterministic)\n",
+                (double)temp);
+    }
+
+    llama_backend_init();
+
+    int ret = 1;
+
+    llama_model * target_model = nullptr;
+    llama_model * draft_model = nullptr;
+
+    try {
+        // Load target model.
+        {
+            auto mparams = llama_model_default_params();
+            mparams.n_gpu_layers = n_gpu_layers;
+            target_model = llama_model_load_from_file(target_model_path.c_str(), mparams);
+            if (!target_model) {
+                throw std::runtime_error("failed to load target model: " + target_model_path);
+            }
+        }
+
+        // Load draft model.
+        {
+            auto mparams = llama_model_default_params();
+            mparams.n_gpu_layers = n_gpu_layers_draft >= 0 ? n_gpu_layers_draft : n_gpu_layers;
+            mparams.target_model = target_model;
+            draft_model = llama_model_load_from_file(draft_model_path.c_str(), mparams);
+            if (!draft_model) {
+                throw std::runtime_error("failed to load draft model: " + draft_model_path);
+            }
+        }
+
+        const auto * vocab = llama_model_get_vocab(target_model);
+        const int32_t vocab_size = llama_vocab_n_tokens(vocab);
+        std::vector<int32_t> prompt;
+        if (!prompt_tokens_path.empty()) {
+            prompt = read_int32_file(prompt_tokens_path);
+        } else {
+            const std::string prompt_text = read_text_file(prompt_text_path);
+            prompt = tokenize_text(vocab, prompt_text, prompt_add_special, prompt_parse_special);
+        }
+        if (prompt.empty()) {
+            throw std::runtime_error("prompt is empty after loading/tokenization");
+        }
+        LOG_INF("prompt: %d tokens\n", (int)prompt.size());
+
+        // Context params shared by both target contexts (chain and spec runs).
+        const uint32_t n_batch = (uint32_t)(n_batch_arg > 0 ? n_batch_arg : std::min(n_ctx, 2048));
+        const uint32_t n_ubatch = (uint32_t)(n_ubatch_arg > 0 ? n_ubatch_arg : 512);
+        ggml_type kv_type = GGML_TYPE_F16;
+        if (kv_type_str == "f16") kv_type = GGML_TYPE_F16;
+        else if (kv_type_str == "q8_0") kv_type = GGML_TYPE_Q8_0;
+        else if (kv_type_str == "q4_0") kv_type = GGML_TYPE_Q4_0;
+        else { fprintf(stderr, "unknown --kv-type: %s\n", kv_type_str.c_str()); return 1; }
+        auto target_cparams = llama_context_default_params();
+        target_cparams.n_ctx = (uint32_t)n_ctx;
+        target_cparams.n_batch = n_batch;
+        target_cparams.n_ubatch = std::min(n_batch, n_ubatch);
+        target_cparams.type_k = kv_type;
+        target_cparams.type_v = kv_type;
+        if (no_flash_attn) {
+            target_cparams.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+        }
+        const int32_t prompt_chunk = prompt_chunk_arg > 0
+            ? std::min(prompt_chunk_arg, (int32_t)target_cparams.n_batch)
+            : (int32_t)target_cparams.n_ubatch;
+
+        // Draft context: dflash-draft doesn't keep a prompt KV cache; it consumes
+        // KV slots only for spec block decode (pos = committed_pos+i). A short
+        // ctx sized to prompt+gen+budget margin is sufficient and avoids the
+        // compute-buffer blow-up that target n_ctx would otherwise impose.
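+        // Worked example with the formula below: a 1000-token prompt with
+        // --gen 32 and the default budget of 22 gives
+        // max(1000 + 32 + 22 + 64, 1024) = 1118, so draft_n_ctx =
+        // min(4096, 1118) = 1118 instead of inheriting the target's n_ctx.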
+        const uint32_t draft_n_ctx = (uint32_t)std::min(
+            (int32_t)4096,
+            std::max((int32_t)prompt.size() + gen + ddparams.budget + 64, (int32_t)1024));
+        auto draft_cparams = llama_context_default_params();
+        draft_cparams.n_ctx = draft_n_ctx;
+        draft_cparams.n_batch = std::min(draft_n_ctx, (uint32_t)64);
+        draft_cparams.n_ubatch = std::min(draft_cparams.n_batch, n_ubatch);
+        if (no_flash_attn) {
+            draft_cparams.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+        }
+
+        // ---------------------------------------------------------------
+        // Run 1: chain reference
+        // ---------------------------------------------------------------
+        LOG_INF("=== Run 1: chain reference decode ===\n");
+        const auto chain_t0 = std::chrono::steady_clock::now();
+        std::vector<llama_token> chain_tokens =
+            run_chain(target_model, target_cparams, prompt, gen, vocab_size, prompt_chunk);
+        const auto chain_t1 = std::chrono::steady_clock::now();
+
+        write_token_file(out_chain_path, chain_tokens);
+        LOG_INF("chain: wrote %d tokens to %s\n",
+                (int)chain_tokens.size(), out_chain_path.c_str());
+        LOG_INF("chain timing: %.3f sec\n",
+                std::chrono::duration<double>(chain_t1 - chain_t0).count());
+
+        // ---------------------------------------------------------------
+        // Run 2: speculative decode
+        // ---------------------------------------------------------------
+        LOG_INF("=== Run 2: speculative decode ===\n");
+        const auto spec_t0 = std::chrono::steady_clock::now();
+        llama_speculative_tree_driver_stats spec_stats;
+        std::vector<llama_token> spec_tokens =
+            run_spec(target_model, draft_model,
+                     target_cparams, draft_cparams,
+                     prompt, gen, vocab_size, ddparams, prompt_chunk, &spec_stats);
+        const auto spec_t1 = std::chrono::steady_clock::now();
+
+        LOG_INF("spec stats: steps=%lld tree_verifies=%lld tree_nodes_total=%lld max_tree_nodes=%d dfs_last=%lld snapshot_replays=%lld fast_batched_replays=%lld fast_batched_cb=%lld fast_rollback=%lld committed=%lld max_commit=%d batched_committed=%lld batched_max_commit=%d batched_exact_same=%lld batched_exact_diff=%lld batched_longer=%lld batched_shorter=%lld prompt_ingests=%lld prompt_tokens=%lld tree_tokens=%lld replay_tokens=%lld capture_clamps=%lld\n",
+                (long long)spec_stats.n_steps,
+                (long long)spec_stats.n_tree_verifies,
+                (long long)spec_stats.n_tree_nodes_total,
+                (int)spec_stats.max_tree_nodes,
+                (long long)spec_stats.n_dfs_last_commits,
+                (long long)spec_stats.n_snapshot_replays,
+                (long long)spec_stats.n_fast_batched_replays,
+                (long long)spec_stats.n_fast_batched_callback_steps,
+                (long long)spec_stats.n_fast_rollback_steps,
+                (long long)spec_stats.n_committed_tokens,
+                (int)spec_stats.max_committed_tokens_per_step,
+                (long long)spec_stats.n_batched_posterior_committed_tokens,
+                (int)spec_stats.max_batched_posterior_committed_tokens_per_step,
+                (long long)spec_stats.n_batched_exact_same,
+                (long long)spec_stats.n_batched_exact_diff,
+                (long long)spec_stats.n_batched_exact_longer,
+                (long long)spec_stats.n_batched_exact_shorter,
+                (long long)spec_stats.n_prompt_ingest_calls,
+                (long long)spec_stats.n_prompt_ingested_tokens,
+                (long long)spec_stats.n_tree_ingested_tokens,
+                (long long)spec_stats.n_replay_ingested_tokens,
+                (long long)spec_stats.n_capture_clamps);
+        if (spec_stats.n_steps > 0) {
+            LOG_INF("spec acceptance: exact_avg_commit_per_step=%.3f batched_avg_commit_per_step=%.3f\n",
+                    (double)spec_stats.n_committed_tokens / (double)spec_stats.n_steps,
+                    (double)spec_stats.n_batched_posterior_committed_tokens / (double)spec_stats.n_steps);
+            const double inv_steps = 1.0 / (double)spec_stats.n_steps;
+            LOG_INF("spec timing avg: step=%.2f ms pack=%.2f draft=%.2f topk=%.2f build=%.2f snap=%.2f target_tree=%.2f posterior=%.2f accept=%.2f compact=%.2f rollback=%.2f ingest=%.2f tree_ingest=%.2f replay_ingest=%.2f replay=%.2f exact=%.2f\n",
+                    spec_stats.t_step_ms * inv_steps,
+                    spec_stats.t_target_feat_pack_ms * inv_steps,
+                    spec_stats.t_draft_decode_ms * inv_steps,
+                    spec_stats.t_topk_ms * inv_steps,
+                    spec_stats.t_build_tree_ms * inv_steps,
+                    spec_stats.t_snapshot_ms * inv_steps,
+                    spec_stats.t_target_tree_decode_ms * inv_steps,
+                    spec_stats.t_posterior_scan_ms * inv_steps,
+                    spec_stats.t_accept_path_ms * inv_steps,
+                    spec_stats.t_kv_compact_ms * inv_steps,
+                    spec_stats.t_ssm_rollback_ms * inv_steps,
+                    spec_stats.t_ingest_capture_ms * inv_steps,
+                    spec_stats.t_tree_ingest_ms * inv_steps,
+                    spec_stats.t_replay_ingest_ms * inv_steps,
+                    spec_stats.t_replay_ms * inv_steps,
+                    spec_stats.t_exact_validate_ms * inv_steps);
+            LOG_INF("spec timing total: prompt_ingest=%.2f ms tree_ingest=%.2f ms replay_ingest=%.2f ms\n",
+                    spec_stats.t_prompt_ingest_ms,
+                    spec_stats.t_tree_ingest_ms,
+                    spec_stats.t_replay_ingest_ms);
+        }
+        LOG_INF("spec timing: %.3f sec\n",
+                std::chrono::duration<double>(spec_t1 - spec_t0).count());
+
+        if (require_ddtree && (spec_stats.n_tree_verifies <= 0 || spec_stats.max_tree_nodes <= 1)) {
+            throw std::runtime_error("--require-ddtree failed: no multi-node DDTree verify observed");
+        }
+        if (require_replay && spec_stats.n_snapshot_replays <= 0) {
+            throw std::runtime_error("--require-replay failed: snapshot+replay fallback was not exercised");
+        }
+        if (require_full_prompt_ingest &&
+            (spec_stats.n_capture_clamps != 0 ||
+             spec_stats.n_prompt_ingested_tokens != (int64_t)prompt.size())) {
+            throw std::runtime_error("--require-full-prompt-ingest failed: prompt hidden capture was incomplete");
+        }
+
+        // Truncate to gen if the driver produced more tokens than requested.
+        if ((int32_t)spec_tokens.size() > gen) {
+            spec_tokens.resize(gen);
+        }
+
+        write_token_file(out_spec_path, spec_tokens);
+        LOG_INF("spec: wrote %d tokens to %s\n",
+                (int)spec_tokens.size(), out_spec_path.c_str());
+
+        // ---------------------------------------------------------------
+        // Compare trajectories
+        // ---------------------------------------------------------------
+        const int32_t chain_n = (int32_t)chain_tokens.size();
+        const int32_t spec_n = (int32_t)spec_tokens.size();
+        const int32_t cmp_n = std::min(chain_n, spec_n);
+
+        int32_t first_divergence = -1;
+        int32_t match_count = 0;
+        for (int32_t k = 0; k < cmp_n; ++k) {
+            if (chain_tokens[k] == spec_tokens[k]) {
+                match_count++;
+            } else if (first_divergence < 0) {
+                first_divergence = k;
+                break;
+            }
+        }
+
+        if (first_divergence < 0 && match_count == cmp_n) {
+            // All positions matched.
+            printf("chain_n=%d spec_n=%d first_divergence=none bytes_match=%d/%d\n",
+                   chain_n, spec_n, match_count, cmp_n);
+        } else {
+            printf("chain_n=%d spec_n=%d first_divergence=%d bytes_match=%d/%d\n",
+                   chain_n, spec_n, first_divergence, match_count, cmp_n);
+        }
+
+        if (greedy) {
+            if (first_divergence >= 0) {
+                fprintf(stderr,
+                        "FAIL: token-trajectory divergence at position %d "
+                        "(greedy decoding MUST produce bit-equal sequences)\n"
+                        "  chain[%d] = %d\n"
+                        "  spec[%d] = %d\n",
+                        first_divergence,
+                        first_divergence, (int)chain_tokens[first_divergence],
+                        first_divergence, (int)spec_tokens[first_divergence]);
+                ret = 1;
+            } else {
+                LOG_INF("PASS: all %d token positions are bit-equal\n", cmp_n);
+                ret = 0;
+            }
+        } else {
+            // Non-greedy: no hard assertion, just report.
+            LOG_INF("non-greedy mode: token-trajectory comparison is informational only\n");
+            ret = 0;
+        }
+
+    } catch (const std::exception & e) {
+        LOG_ERR("error: %s\n", e.what());
+        ret = 1;
+    }
+
+    if (draft_model) { llama_model_free(draft_model); }
+    if (target_model) { llama_model_free(target_model); }
+    llama_backend_free();
+    return ret;
+}
diff --git a/tests/test-speculative-tree.cpp b/tests/test-speculative-tree.cpp
new file mode 100644
index 00000000000..61644447cc1
--- /dev/null
+++ b/tests/test-speculative-tree.cpp
@@ -0,0 +1,336 @@
+// Standalone unit tests for speculative-tree.{h,cpp}.
+// No model or GPU required. Uses hand-computed fixtures.
+// Exits non-zero on first failure.
+
+#ifdef NDEBUG
+#undef NDEBUG
+#endif
+
+#include "speculative-tree.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdint>
+#include <cstdio>
+#include <vector>
+
+// ---------------------------------------------------------------------------
+// Test 1: build_ddtree small case
+//
+// L=2, K=2, budget=4 (total nodes including root), chain_seed=true, root=10.
+//
+// top_log_probs[L*K]:
+//   position 0 (depth 1): [-0.1, -1.5] tokens [20, 21]
+//   position 1 (depth 2): [-0.2, -1.6] tokens [30, 31]
+//
+// Chain seeding (chain_depth = min(2, budget-1=3) = 2):
+//   d=1: insert tok=20 as node 1 (parent=0, depth=1), cum_logw=-0.1
+//        push sibling: logw=-1.5, parent=0, depth=1, rank=1, tok=21
+//   d=2: insert tok=30 as node 2 (parent=1, depth=2), cum_logw=-0.3
+//        push sibling: logw=-0.3-(-0.2)+(-1.6)=-1.7, parent=1, depth=2, rank=1, tok=31
+//
+// Heap after chain: {logw=-1.5,tok=21} and {logw=-1.7,tok=31}
+// Pop best: logw=-1.5 → tok=21 inserted as node 3 (parent=0, depth=1). Done.
+//
+// Expected tree nodes:
+//   [0] root(10), parent=-1, depth=0
+//   [1] tok=20, parent=0, depth=1
+//   [2] tok=30, parent=1, depth=2
+//   [3] tok=21, parent=0, depth=1
+// ---------------------------------------------------------------------------
+static void test_build_ddtree_small() {
+    const int L = 2, K = 2;
+    const float top_log_probs[] = {
+        -0.1f, -1.5f, // position 0
+        -0.2f, -1.6f, // position 1
+    };
+    const int32_t top_token_ids[] = {
+        20, 21, // position 0
+        30, 31, // position 1
+    };
+
+    llama_ddtree_params p;
+    p.budget = 4; // total node cap including root
+    p.chain_seed = true;
+    p.temp = 1.0f;
+
+    const llama_ddtree tree = build_ddtree(
+        top_log_probs, top_token_ids, L, K, /*root_token*/ 10, p);
+
+    assert(tree.nodes.size() == 4);
+
+    // Node 0: root
+    assert(tree.nodes[0].token_id == 10);
+    assert(tree.nodes[0].parent_idx == -1);
+    assert(tree.nodes[0].depth == 0);
+
+    // Node 1: tok=20, depth-1 chain top-1
+    assert(tree.nodes[1].token_id == 20);
+    assert(tree.nodes[1].parent_idx == 0);
+    assert(tree.nodes[1].depth == 1);
+
+    // Node 2: tok=30, depth-2 chain top-1
+    assert(tree.nodes[2].token_id == 30);
+    assert(tree.nodes[2].parent_idx == 1);
+    assert(tree.nodes[2].depth == 2);
+
+    // Node 3: tok=21, best heap candidate (sibling of 20 at depth 1)
+    assert(tree.nodes[3].token_id == 21);
+    assert(tree.nodes[3].parent_idx == 0);
+    assert(tree.nodes[3].depth == 1);
+
+    // Visibility: 4x4 mask, row i has 1 at all ancestors of i (inclusive).
+ // node 0 ancestors: {0} + // node 1 ancestors: {0, 1} + // node 2 ancestors: {0, 1, 2} + // node 3 ancestors: {0, 3} + const int N = 4; + assert(tree.visibility.size() == (size_t)(N * N)); + // row 0 + assert(tree.visibility[0*N+0] == 1); + assert(tree.visibility[0*N+1] == 0); + assert(tree.visibility[0*N+2] == 0); + assert(tree.visibility[0*N+3] == 0); + // row 1 + assert(tree.visibility[1*N+0] == 1); + assert(tree.visibility[1*N+1] == 1); + assert(tree.visibility[1*N+2] == 0); + assert(tree.visibility[1*N+3] == 0); + // row 2 + assert(tree.visibility[2*N+0] == 1); + assert(tree.visibility[2*N+1] == 1); + assert(tree.visibility[2*N+2] == 1); + assert(tree.visibility[2*N+3] == 0); + // row 3 + assert(tree.visibility[3*N+0] == 1); + assert(tree.visibility[3*N+1] == 0); + assert(tree.visibility[3*N+2] == 0); + assert(tree.visibility[3*N+3] == 1); + + printf("PASS: test_build_ddtree_small\n"); +} + +// --------------------------------------------------------------------------- +// Test 2: follow_verified_tree +// +// Uses the same 4-node tree from test 1. +// Tree structure: +// node 0: root(10), no parent +// node 1: tok=20, parent=0 +// node 2: tok=30, parent=1 +// node 3: tok=21, parent=0 +// +// child_maps derived from parent_idx: +// node 0 children: {20→1, 21→3} +// node 1 children: {30→2} +// node 2 children: {} +// node 3 children: {} +// +// Semantic: posterior[i] is the target model's argmax prediction at node i. +// The walk starts at node 0; at each step it looks for a child whose +// token_id matches posterior[current]. If found, advance; otherwise stop. +// accepted = [visited indices]; next_token = posterior[deepest accepted]. +// +// Case A: posterior = [20, 30, 99, 99] +// Start at 0. posterior[0]=20 → child 1 exists. Move to 1. +// posterior[1]=30 → child 2 exists. Move to 2. +// posterior[2]=99 → no child. Stop. +// accepted=[0,1,2], next_token=99. +// +// Case B: posterior = [21, 99, 99, 99] +// Start at 0. posterior[0]=21 → child 3 exists. Move to 3. +// posterior[3]=99 → no child. Stop. +// accepted=[0,3], next_token=99. +// +// Case C: posterior = [5, 99, 99, 99] +// Start at 0. posterior[0]=5 → no child. Stop immediately. +// accepted=[0], next_token=5. +// --------------------------------------------------------------------------- +static void test_follow_verified_tree() { + // Build the same 4-node tree via build_ddtree. 
+    const int L = 2, K = 2;
+    const float   top_log_probs[] = { -0.1f, -1.5f, -0.2f, -1.6f };
+    const int32_t top_token_ids[] = { 20, 21, 30, 31 };
+
+    llama_ddtree_params p;
+    p.budget     = 4;
+    p.chain_seed = true;
+    p.temp       = 1.0f;
+
+    const llama_ddtree tree = build_ddtree(
+        top_log_probs, top_token_ids, L, K, 10, p);
+
+    std::vector<int32_t> accepted;
+    llama_token next_tok = -1;
+
+    // Case A: greedy chain match
+    {
+        const int32_t posterior[] = { 20, 30, 99, 99 };
+        follow_verified_tree(tree, posterior, accepted, next_tok);
+        assert(accepted.size() == 3);
+        assert(accepted[0] == 0);
+        assert(accepted[1] == 1);
+        assert(accepted[2] == 2);
+        assert(next_tok == 99);
+    }
+
+    // Case B: branch match (tok=21 path)
+    {
+        const int32_t posterior[] = { 21, 99, 99, 99 };
+        follow_verified_tree(tree, posterior, accepted, next_tok);
+        assert(accepted.size() == 2);
+        assert(accepted[0] == 0);
+        assert(accepted[1] == 3);
+        assert(next_tok == 99);
+    }
+
+    // Case C: no match at root level — only root accepted
+    {
+        const int32_t posterior[] = { 5, 99, 99, 99 };
+        follow_verified_tree(tree, posterior, accepted, next_tok);
+        assert(accepted.size() == 1);
+        assert(accepted[0] == 0);
+        assert(next_tok == 5);
+    }
+
+    printf("PASS: test_follow_verified_tree\n");
+}
+
+// ---------------------------------------------------------------------------
+// Test 3: build_tree_visibility — hand-constructed 5-node tree
+//
+// Manually define a tree:
+//   node 0: root,       parent=-1
+//   node 1: child of 0, parent=0
+//   node 2: child of 1, parent=1
+//   node 3: child of 0, parent=0
+//   node 4: child of 3, parent=3
+//
+// Expected visibility (5x5):
+//   row 0: {0}     → [1,0,0,0,0]
+//   row 1: {0,1}   → [1,1,0,0,0]
+//   row 2: {0,1,2} → [1,1,1,0,0]
+//   row 3: {0,3}   → [1,0,0,1,0]
+//   row 4: {0,3,4} → [1,0,0,1,1]
+// ---------------------------------------------------------------------------
+static void test_build_tree_visibility() {
+    std::vector<llama_ddtree_node> nodes = {
+        { 10, -1, 0 }, // 0: root
+        { 20,  0, 1 }, // 1: child of 0
+        { 30,  1, 2 }, // 2: child of 1
+        { 40,  0, 1 }, // 3: child of 0
+        { 50,  3, 2 }, // 4: child of 3
+    };
+
+    const int N = (int)nodes.size();
+    std::vector<int8_t> vis(N * N, 0);
+    build_tree_visibility(nodes, vis.data());
+
+    // Row 0
+    assert(vis[0*N+0] == 1); assert(vis[0*N+1] == 0);
+    assert(vis[0*N+2] == 0); assert(vis[0*N+3] == 0); assert(vis[0*N+4] == 0);
+    // Row 1
+    assert(vis[1*N+0] == 1); assert(vis[1*N+1] == 1);
+    assert(vis[1*N+2] == 0); assert(vis[1*N+3] == 0); assert(vis[1*N+4] == 0);
+    // Row 2
+    assert(vis[2*N+0] == 1); assert(vis[2*N+1] == 1);
+    assert(vis[2*N+2] == 1); assert(vis[2*N+3] == 0); assert(vis[2*N+4] == 0);
+    // Row 3
+    assert(vis[3*N+0] == 1); assert(vis[3*N+1] == 0);
+    assert(vis[3*N+2] == 0); assert(vis[3*N+3] == 1); assert(vis[3*N+4] == 0);
+    // Row 4
+    assert(vis[4*N+0] == 1); assert(vis[4*N+1] == 0);
+    assert(vis[4*N+2] == 0); assert(vis[4*N+3] == 1); assert(vis[4*N+4] == 1);
+
+    printf("PASS: test_build_tree_visibility\n");
+}
+
+// ---------------------------------------------------------------------------
+// Test 4: extract_top_k_logprobs
+//
+// Feed a [3, 8] logits matrix with known values at temp=1.0.
+// K=3. Verify output ordering (descending log-prob) and values to 1e-5.
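+// logprob(id) = logit[id]/temp - logsumexp(logits/temp); with temp=1.0 this
+// reduces to logit[id] - logsumexp(logits).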
+//
+// Row 0: logits = [0,1,2,3,4,5,6,7]  (argmax = id=7)
+// Row 1: logits = [7,6,5,4,3,2,1,0]  (argmax = id=0)
+// Row 2: logits = [0,0,0,0,10,0,0,0] (argmax = id=4, dominant)
+//
+// For row 0: log_z = logsumexp([0,1,2,3,4,5,6,7])
+//   log_z = 7 + log(sum of exp(k-7) for k=0..7) = 7 + log(exp(-7)+...+exp(0))
+// Top 3 by logit: ids [7,6,5], logprobs = [7-log_z, 6-log_z, 5-log_z]
+//
+// We compute expected values in the test itself using std::log and std::exp.
+// ---------------------------------------------------------------------------
+static float logsumexp_vec(const float * v, int n) {
+    float mx = v[0];
+    for (int i = 1; i < n; i++) if (v[i] > mx) mx = v[i];
+    float s = 0.0f;
+    for (int i = 0; i < n; i++) s += std::exp(v[i] - mx);
+    return mx + std::log(s);
+}
+
+static void test_extract_top_k_logprobs() {
+    const int L = 3, V = 8, K = 3;
+
+    const float logits[L * V] = {
+        0.0f, 1.0f, 2.0f, 3.0f,  4.0f, 5.0f, 6.0f, 7.0f, // row 0
+        7.0f, 6.0f, 5.0f, 4.0f,  3.0f, 2.0f, 1.0f, 0.0f, // row 1
+        0.0f, 0.0f, 0.0f, 0.0f, 10.0f, 0.0f, 0.0f, 0.0f, // row 2
+    };
+
+    std::vector<float>   out_lp(L * K);
+    std::vector<int32_t> out_id(L * K);
+
+    extract_top_k_logprobs(logits, L, V, K, 1.0f,
+        out_lp.data(), out_id.data());
+
+    // Check row 0: top-3 tokens by logit are ids 7, 6, 5.
+    assert(out_id[0*K+0] == 7);
+    assert(out_id[0*K+1] == 6);
+    assert(out_id[0*K+2] == 5);
+    // Check log-prob values against manual logsumexp.
+    {
+        const float log_z = logsumexp_vec(logits + 0*V, V);
+        assert(std::fabs(out_lp[0*K+0] - (7.0f - log_z)) < 1e-5f);
+        assert(std::fabs(out_lp[0*K+1] - (6.0f - log_z)) < 1e-5f);
+        assert(std::fabs(out_lp[0*K+2] - (5.0f - log_z)) < 1e-5f);
+    }
+
+    // Check row 1: top-3 tokens are ids 0, 1, 2.
+    assert(out_id[1*K+0] == 0);
+    assert(out_id[1*K+1] == 1);
+    assert(out_id[1*K+2] == 2);
+    {
+        const float log_z = logsumexp_vec(logits + 1*V, V);
+        assert(std::fabs(out_lp[1*K+0] - (7.0f - log_z)) < 1e-5f);
+        assert(std::fabs(out_lp[1*K+1] - (6.0f - log_z)) < 1e-5f);
+        assert(std::fabs(out_lp[1*K+2] - (5.0f - log_z)) < 1e-5f);
+    }
+
+    // Check row 2: id=4 dominates with logit=10.
+    assert(out_id[2*K+0] == 4);
+    {
+        const float log_z = logsumexp_vec(logits + 2*V, V);
+        assert(std::fabs(out_lp[2*K+0] - (10.0f - log_z)) < 1e-5f);
+    }
+
+    // Verify descending order within each row.
+    for (int row = 0; row < L; row++) {
+        for (int k = 0; k < K - 1; k++) {
+            assert(out_lp[row*K+k] >= out_lp[row*K+k+1]);
+        }
+    }
+
+    printf("PASS: test_extract_top_k_logprobs\n");
+}
+
+// ---------------------------------------------------------------------------
+
+int main() {
+    test_build_ddtree_small();
+    test_follow_verified_tree();
+    test_build_tree_visibility();
+    test_extract_top_k_logprobs();
+    printf("All tests passed.\n");
+    return 0;
+}
diff --git a/tools/server/README.md b/tools/server/README.md
index b30309bf3b0..ee5aaa6b5e0 100644
--- a/tools/server/README.md
+++ b/tools/server/README.md
@@ -1825,3 +1825,39 @@ You can use html formatting if needed.
 ```
+
+## Speculative decoding: DDTree (dflash)
+
+DDTree is a tree-structured speculative decoding method that uses a compact dflash-draft companion model to propose multiple token paths in parallel, then verifies them against the target model in a single forward pass. It typically yields 5-10 accepted tokens per target forward pass, versus 1 for plain autoregressive decoding.
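+
+For intuition, the accept rule is a pure argmax walk over the drafted tree: starting at the root, the verifier follows the child whose token matches the target model's prediction at the current node and stops at the first miss; the prediction at the deepest accepted node is kept as a free "bonus" token. A minimal sketch of that walk (illustrative types only, not the server's internal API):
+
+```cpp
+#include <cstdint>
+#include <vector>
+
+struct node { int32_t token; int32_t parent; };
+
+// posterior[i] is the target model's argmax at node i (all nodes are scored
+// in a single forward pass). Returns the accepted node indices, root first;
+// `next` receives the bonus token predicted at the deepest accepted node.
+static std::vector<int32_t> accept_walk(const std::vector<node> & nodes,
+                                        const int32_t * posterior,
+                                        int32_t & next) {
+    std::vector<int32_t> accepted = { 0 }; // the root is always accepted
+    int32_t cur = 0;
+    for (;;) {
+        int32_t child = -1;
+        for (int32_t i = 1; i < (int32_t) nodes.size(); ++i) {
+            // follow the drafted child that matches the target's prediction
+            if (nodes[i].parent == cur && nodes[i].token == posterior[cur]) {
+                child = i;
+                break;
+            }
+        }
+        if (child < 0) {
+            break; // first miss ends the walk
+        }
+        accepted.push_back(child);
+        cur = child;
+    }
+    next = posterior[cur]; // bonus token: one extra committed step per verify
+    return accepted;
+}
+```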
+
+**Required flags:**
+
+- `-m <target-model.gguf>` — target model (Qwen3.5-27B Q4_K_M or similar)
+- `-md <dflash-draft.gguf>` — dflash-draft GGUF (arch `LLM_ARCH_DFLASH_DRAFT`)
+- `--speculative-mode ddtree`
+
+**Optional flags:**
+
+- `--ddtree-budget N` — tree node budget per spec step (default: 22)
+- `--ddtree-temp F` — temperature for draft log-prob extraction (default: 1.0)
+- `--ddtree-no-chain-seed` — disable greedy chain seed for the tree heap
+
+**Constraints (Phase 5):**
+
+- `--parallel 1` only — multi-slot DDTree is out of scope for Phase 5
+- Target must be Qwen3.5-27B; draft must be the matching `dflash-draft` GGUF
+- Greedy verification only — DDTree's accept decision is argmax-based; temperature and top-p affect only draft log-prob extraction, not acceptance
+- Known limitation: the SSM conv state may diverge ~17 tokens after a spec-step boundary; full bit-equality awaits the `ggml_ssm_conv_tree_persist` op (follow-up)
+
+**Example command:**
+
+```bash
+llama-server \
+    -m models/Qwen3.5-27B-Q4_K_M.gguf \
+    -md models/draft/model.gguf \
+    --speculative-mode ddtree --ddtree-budget 22 \
+    -ctk tq3_0 -ctv tq3_0 \
+    --port 8002 -ngl 99 -c 16384
+```
+
+Compatible with `--api-key`, `--chat-template`, `--jinja`, and all standard server flags.
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index e134b3cfb26..f52344fb2fe 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -9,6 +9,7 @@
 #include "log.h"
 #include "sampling.h"
 #include "speculative.h"
+#include "speculative-tree-driver.h"
 #include "mtmd.h"
 #include "mtmd-helper.h"
@@ -58,6 +59,11 @@ struct server_slot {
     common_speculative * spec = nullptr;
+    // DDTree speculative decoding state (Phase 5); null when ddtree_mode is off
+    llama_speculative_tree_driver * spec_driver = nullptr;
+    llama_token ddtree_root_tok      = LLAMA_TOKEN_NULL; // bonus token from prev step / first sampled
+    llama_pos   ddtree_committed_pos = 0;                // KV positions committed so far
+
     // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
     // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
     std::unique_ptr<server_task> task;
@@ -187,6 +193,14 @@ struct server_slot {
         n_draft_total    = 0;
         n_draft_accepted = 0;
+        // free DDTree driver if one was created for this request
+        if (spec_driver) {
+            llama_speculative_tree_driver_free(spec_driver);
+            spec_driver = nullptr;
+        }
+        ddtree_root_tok      = LLAMA_TOKEN_NULL;
+        ddtree_committed_pos = 0;
+
         task_prev = std::move(task);
         task.reset();
@@ -399,6 +413,56 @@ struct server_slot {
             );
         }
+        if (spec_driver) {
+            const llama_speculative_tree_driver_stats st =
+                llama_speculative_tree_driver_get_stats(spec_driver);
+            if (st.n_steps > 0) {
+                SLT_CNT(*this,
+                    "ddtree stats: steps=%lld exact_avg_commit=%0.3f batched_avg_commit=%0.3f exact_max=%d batched_max=%d snapshot_replays=%lld fast_batched_replays=%lld fast_batched_cb=%lld fast_rollback=%lld batched_exact_diff=%lld batched_longer=%lld batched_shorter=%lld capture_clamps=%lld\n",
+                    (long long)st.n_steps,
+                    (double)st.n_committed_tokens / (double)st.n_steps,
+                    (double)st.n_batched_posterior_committed_tokens / (double)st.n_steps,
+                    (int)st.max_committed_tokens_per_step,
+                    (int)st.max_batched_posterior_committed_tokens_per_step,
+                    (long long)st.n_snapshot_replays,
+                    (long long)st.n_fast_batched_replays,
+                    (long long)st.n_fast_batched_callback_steps,
+                    (long long)st.n_fast_rollback_steps,
+                    (long long)st.n_batched_exact_diff,
+                    (long long)st.n_batched_exact_longer,
+                    (long long)st.n_batched_exact_shorter,
+                    (long long)st.n_capture_clamps);
+                const double inv_steps = 1.0 / (double)st.n_steps;
+                SLT_CNT(*this,
+                    "ddtree timing avg: step=%0.2f ms pack=%0.2f draft=%0.2f topk=%0.2f build=%0.2f snap=%0.2f target_tree=%0.2f posterior=%0.2f accept=%0.2f compact=%0.2f rollback=%0.2f ingest=%0.2f tree_ingest=%0.2f replay_ingest=%0.2f replay=%0.2f exact=%0.2f exact_decode=%0.2f exact_sample=%0.2f exact_advance=%0.2f exact_nodes=%0.2f\n",
+                    st.t_step_ms * inv_steps,
+                    st.t_target_feat_pack_ms * inv_steps,
+                    st.t_draft_decode_ms * inv_steps,
+                    st.t_topk_ms * inv_steps,
+                    st.t_build_tree_ms * inv_steps,
+                    st.t_snapshot_ms * inv_steps,
+                    st.t_target_tree_decode_ms * inv_steps,
+                    st.t_posterior_scan_ms * inv_steps,
+                    st.t_accept_path_ms * inv_steps,
+                    st.t_kv_compact_ms * inv_steps,
+                    st.t_ssm_rollback_ms * inv_steps,
+                    st.t_ingest_capture_ms * inv_steps,
+                    st.t_tree_ingest_ms * inv_steps,
+                    st.t_replay_ingest_ms * inv_steps,
+                    st.t_replay_ms * inv_steps,
+                    st.t_exact_validate_ms * inv_steps,
+                    st.t_exact_decode_ms * inv_steps,
+                    st.t_exact_sample_ms * inv_steps,
+                    st.t_exact_advance_ms * inv_steps,
+                    (double)st.n_exact_validate_nodes * inv_steps);
+                SLT_CNT(*this,
+                    "ddtree timing total: prompt_ingest=%0.2f ms tree_ingest=%0.2f ms replay_ingest=%0.2f ms\n",
+                    st.t_prompt_ingest_ms,
+                    st.t_tree_ingest_ms,
+                    st.t_replay_ingest_ms);
+            }
+        }
+
         common_speculative_print_stats(spec);
     }
@@ -563,6 +627,11 @@ struct server_context_impl {
     llama_model_ptr model_dft;
+    // DDTree draft context — separate from the chain-mode draft since it needs
+    // different n_ctx / n_batch sizing (small, fixed to draft block_size).
+    // Null when ddtree_mode is off.
+    llama_context * ctx_ddtree_dft = nullptr;
+
     bool add_bos_token = true;
     int32_t n_ctx; // total context for all clients / slots
@@ -600,6 +669,16 @@ struct server_context_impl {
     for (server_slot & slot : slots) {
         common_speculative_free(slot.spec);
         slot.spec = nullptr;
+
+        if (slot.spec_driver) {
+            llama_speculative_tree_driver_free(slot.spec_driver);
+            slot.spec_driver = nullptr;
+        }
+    }
+
+    if (ctx_ddtree_dft) {
+        llama_free(ctx_ddtree_dft);
+        ctx_ddtree_dft = nullptr;
     }
     llama_batch_free(batch);
@@ -682,6 +761,9 @@ struct server_context_impl {
     params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides;
     auto mparams_dft = common_model_params_to_llama(params_dft);
+    if (params_base.speculative.ddtree_mode) {
+        mparams_dft.target_model = model;
+    }
     model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
     if (model_dft == nullptr) {
@@ -691,6 +773,32 @@ struct server_context_impl {
     params_base.speculative.model_dft   = model_dft.get();
     params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft);
+
+        // DDTree mode: create a dedicated draft context with the sizing the
+        // dflash-draft model expects (small n_ctx, small n_batch = block_size).
+        if (params_base.speculative.ddtree_mode) {
+            // Phase 5: single-slot only — enforce this up front.
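+            // (The driver assumes the entire hidden-capture buffer and the
+            //  target KV sequence belong to a single request; concurrent slots
+            //  would interleave captures and corrupt the feature ring.)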
+            if (params_base.n_parallel > 1) {
+                SRV_ERR("%s", "DDTree mode supports only --parallel 1 in Phase 5\n");
+                return false;
+            }
+
+            llama_context_params cparams_ddft = llama_context_default_params();
+            cparams_ddft.n_ctx   = 2048 + 16; // DRAFT_CTX_MAX + block_size
+            cparams_ddft.n_batch = 16;        // one block per decode
+
+            ctx_ddtree_dft = llama_init_from_model(model_dft.get(), cparams_ddft);
+            if (!ctx_ddtree_dft) {
+                SRV_ERR("%s", "failed to create DDTree draft context\n");
+                return false;
+            }
+            SRV_INF("%s", "DDTree draft context initialized\n");
+
+            // Enable hidden capture on the target context so the driver can
+            // read intermediate layer features for tree scoring.
+            llama_set_capture_hidden(ctx, true);
+            SRV_INF("%s", "DDTree: hidden capture enabled on target context\n");
+        }
     }
     std::string & mmproj_path = params_base.mmproj.path;
@@ -1214,6 +1322,14 @@ struct server_context_impl {
         slot.task = std::make_unique<server_task>(std::move(task));
+        if (params_base.speculative.ddtree_mode && !slot.task->is_child()) {
+            const std::string prompt_text = slot.task->tokens.detokenize(ctx, true);
+            SLT_INF(slot, "DDTree request prompt: tokens = %d, chars = %zu\n",
+                slot.task->n_tokens(), prompt_text.size());
+            SLT_INF(slot, "DDTree request prompt begin\n%s\nDDTree request prompt end\n",
+                prompt_text.c_str());
+        }
+
         slot.state = slot.task->is_child()
             ? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
             : SLOT_STATE_STARTED;
@@ -2091,6 +2207,15 @@ struct server_context_impl {
                 continue;
             }
+            // DDTree slots drive their own batch submissions internally via the driver.
+            // Skip the normal token-addition and batch-decode path for them.
+            if (params_base.speculative.ddtree_mode && slot.spec_driver) {
+                if (!slot_batched) {
+                    slot_batched = &slot;
+                }
+                continue;
+            }
+
             // check if we can batch this slot with the previous one
             if (!slot_batched) {
                 slot_batched = &slot;
@@ -2155,9 +2280,14 @@ struct server_context_impl {
             }
         }
-    // process in chunks of params.n_batch
-    int32_t n_batch = llama_n_batch(ctx);
+    // process in chunks of params.n_batch. In DDTree mode hidden capture is
+    // only retained for the physical ubatch produced by llama_decode(), so
+    // prompt prefill must be submitted in n_ubatch-sized chunks to keep the
+    // driver's target feature ring complete.
     int32_t n_ubatch = llama_n_ubatch(ctx);
+    const int32_t n_batch_default = llama_n_batch(ctx);
+    const int32_t n_batch_prompt  = params_base.speculative.ddtree_mode ?
n_ubatch : n_batch_default; + int32_t n_batch = n_batch_prompt; float alora_scale = -1.0f; size_t alora_disabled_id = 0; @@ -2333,6 +2463,19 @@ struct server_context_impl { SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past); } + + if (params_base.speculative.ddtree_mode && n_past > 0) { + const int32_t ddtree_rebuild_nt = llama_speculative_tree_driver_context_window(); + const int32_t n_rebuild = std::min(ddtree_rebuild_nt, slot.task->n_tokens()); + const int32_t n_past_max = std::max(0, slot.task->n_tokens() - n_rebuild); + + if (n_past > n_past_max) { + SLT_WRN(slot, + "DDTree prompt cache reuse capped from %d to %d to rebuild the last %d target-feature tokens\n", + n_past, n_past_max, n_rebuild); + n_past = n_past_max; + } + } } else { // if we don't cache the prompt, we have to remove all previous tokens n_past = 0; @@ -2590,6 +2733,12 @@ struct server_context_impl { break; } } + if (params_base.speculative.ddtree_mode) { + const int32_t ddtree_rebuild_nt = llama_speculative_tree_driver_context_window(); + if (slot.task->n_tokens() == slot.prompt.n_tokens() + ddtree_rebuild_nt) { + should_break = true; + } + } if (should_break) { break; } @@ -2598,6 +2747,9 @@ struct server_context_impl { // the number of tokens added to the batch for the current slot const auto n_tokens_cur = batch.n_tokens - n_tokens_prev; + const bool ddtree_rebuild_checkpoint = + params_base.speculative.ddtree_mode && + slot.task->n_tokens() == slot.prompt.n_tokens() + llama_speculative_tree_driver_context_window(); // entire prompt has been processed if (slot.prompt.n_tokens() == slot.task->n_tokens()) { @@ -2614,7 +2766,11 @@ struct server_context_impl { slot.init_sampler(); SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens); } else { - if (slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch) { + if (ddtree_rebuild_checkpoint) { + do_checkpoint = do_checkpoint && true; + SLT_INF(slot, "creating DDTree rebuild checkpoint before the last %d prompt tokens at position %d\n", + llama_speculative_tree_driver_context_window(), slot.prompt.n_tokens()); + } else if (slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch) { // near the end of the prompt do_checkpoint = do_checkpoint && true; } else { @@ -2713,10 +2869,31 @@ struct server_context_impl { } if (batch.n_tokens == 0) { - SRV_WRN("%s", "no tokens to decode\n"); + // DDTree slots don't put tokens in the main batch (the driver handles its + // own tree-mode decodes after the main loop). When ddtree_mode is on, the + // main batch can legitimately be empty for several consecutive ticks while + // slots transition through DONE_PROMPT → GENERATING or wait for the next + // request — don't treat that as a hung scheduler. 
+ bool ddtree_active = false; + if (params_base.speculative.ddtree_mode) { + for (const auto & slot : slots) { + if (slot.spec_driver != nullptr || + slot.state == SLOT_STATE_PROCESSING_PROMPT || + slot.state == SLOT_STATE_DONE_PROMPT || + slot.state == SLOT_STATE_STARTED || + slot.state == SLOT_STATE_GENERATING) { + ddtree_active = true; + break; + } + } + } + + if (!ddtree_active) { + SRV_WRN("%s", "no tokens to decode\n"); - if (++n_empty_consecutive > 3) { - GGML_ABORT("fatal error - please provide logs and repro in %s\n", "https://github.com/ggml-org/llama.cpp/pull/20277"); + if (++n_empty_consecutive > 3) { + GGML_ABORT("fatal error - please provide logs and repro in %s\n", "https://github.com/ggml-org/llama.cpp/pull/20277"); + } } } else { n_empty_consecutive = 0; @@ -2795,7 +2972,37 @@ struct server_context_impl { i_next = i + n_tokens; // on successful decode, restore the original batch size - n_batch = llama_n_batch(ctx); + n_batch = n_batch_prompt; + + // DDTree: incrementally ingest the just-decoded ubatch's hidden capture + // into each prompt-processing slot's ring buffer. The capture buffer is + // overwritten on every llama_decode, so we MUST consume it before the + // next inner-loop iteration. Phase 5 is single-slot, so the entire + // batch_view belongs to one slot. + if (params_base.speculative.ddtree_mode && ctx_ddtree_dft) { + for (auto & slot : slots) { + if (slot.state != SLOT_STATE_PROCESSING_PROMPT && + slot.state != SLOT_STATE_DONE_PROMPT) { + continue; + } + if (slot.spec_driver == nullptr) { + llama_ddtree_params dp; + dp.budget = params_base.speculative.ddtree_budget; + dp.temp = params_base.speculative.ddtree_temp; + dp.chain_seed = params_base.speculative.ddtree_chain_seed; + dp.top_k = params_base.speculative.ddtree_top_k; + dp.block_size = 16; + slot.spec_driver = llama_speculative_tree_driver_init(ctx, ctx_ddtree_dft, dp); + if (!slot.spec_driver) { + SLT_ERR(slot, "%s", "failed to allocate DDTree driver during prompt processing\n"); + continue; + } + } + // Append n_tokens columns from this decode's capture buffer to the ring. + llama_speculative_tree_driver_ingest_prompt_capture( + slot.spec_driver, (int32_t)n_tokens); + } + } // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too for (auto & slot : slots) { @@ -2853,6 +3060,46 @@ struct server_context_impl { // prompt evaluated for next-token prediction slot.state = SLOT_STATE_GENERATING; + if (params_base.speculative.ddtree_mode && ctx_ddtree_dft) { + // DDTree: the driver was lazy-allocated and the ring was filled + // incrementally during prompt prefill (one ingest per inner-loop + // ubatch decode). If something went wrong upstream we may not + // have a driver here — fall back to EOS. + if (!slot.spec_driver) { + SLT_ERR(slot, "%s", "DDTree driver missing at GENERATING transition\n"); + slot.stop = STOP_TYPE_EOS; + slot.has_next_token = false; + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + slot.release(); + continue; + } + + // Greedy-sample the first generated token from the last prompt logit. 
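+                // (DDTree acceptance is argmax-based, per the constraints in
+                //  tools/server/README.md, so the root token fed to the driver
+                //  must be the target's greedy pick rather than a sampled one.)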
+                const int tok_idx = slot.i_batch - i;
+                const float * logits = llama_get_logits_ith(ctx, tok_idx);
+                const int n_vocab = llama_vocab_n_tokens(vocab);
+                llama_token first_tok = 0;
+                float best = logits[0];
+                for (int v = 1; v < n_vocab; ++v) {
+                    if (logits[v] > best) { best = logits[v]; first_tok = (llama_token)v; }
+                }
+
+                slot.ddtree_root_tok      = first_tok;
+                slot.ddtree_committed_pos = (llama_pos)slot.prompt.tokens.size();
+                slot.i_batch = -1;
+                // slot.reset() doesn't touch has_next_token; if the previous
+                // request ended on a stop condition the flag is still false,
+                // and the DDTree gen block would skip this slot forever.
+                slot.has_next_token = true;
+
+                slot.t_start_generation  = ggml_time_us();
+                slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
+                metrics.on_prompt_eval(slot);
+                continue; // will be handled in the DDTree generation loop below
+            }
+
             if (slot.can_speculate()) {
                 common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens());
             }
@@ -2860,6 +3107,11 @@ struct server_context_impl {
             continue; // continue loop of slots
         }
+        // DDTree slots run their driver step outside this loop (after llama_decode)
+        if (params_base.speculative.ddtree_mode && slot.spec_driver) {
+            continue;
+        }
+
         if (slot.i_batch_dft.size() > 0) {
             continue; // sample using speculative decoding
         }
@@ -2962,6 +3214,151 @@ struct server_context_impl {
             }
         }
+        // DDTree generation: run one spec-decode step per slot, after the main llama_decode.
+        // The driver calls llama_decode on ctx internally (tree-mode batch), so it must run
+        // outside the main decode loop.
+        if (params_base.speculative.ddtree_mode) {
+            for (auto & slot : slots) {
+                if (slot.state != SLOT_STATE_GENERATING || !slot.spec_driver) {
+                    continue;
+                }
+                if (slot.ddtree_root_tok == LLAMA_TOKEN_NULL || !slot.has_next_token) {
+                    continue;
+                }
+
+                // Grammar-aware verify: clone the slot sampler and let the
+                // driver pick each chain step via the cloned sampler+grammar.
+                // The clone gets root_token accepted up front so the first
+                // sample at row 0 sees the grammar state "after root".
+                //
+                // Spec verify is grammar-free by default. Reference DFlash /
+                // DDTree implementations verify with greedy target argmax
+                // only — grammar / penalties / dry are the main sampler's
+                // responsibility on commit, not the verify walk's. Running
+                // grammar inside the verify walk costs ~50 ms / cb on
+                // tool-call JSON schemas for zero acceptance change
+                // (batched_exact_diff = 0 measured across multi-turn agent
+                // runs; see archived/GRAMMAR_VERIFY_DEFAULT_OFF_2026-05-02.md).
+                //
+                // LLAMA_DDTREE_GRAMMAR_VERIFY=1 opts back in if a future
+                // grammar-tight workload needs it; when unset, the driver's
+                // internal argmax verify is used (this is also the baseline
+                // for accept-rate comparison).
+                static const bool s_grammar_verify = []{
+                    const char * e = std::getenv("LLAMA_DDTREE_GRAMMAR_VERIFY");
+                    return e && e[0] == '1';
+                }();
+                // Opt-in batched-argmax short-circuit. Skips the full sampler
+                // chain when the driver's raw argmax is already grammar-valid,
+                // saving ~30 ms per cb call. NOT safe when the chain contains
+                // score-modifying samplers (penalties/dry/xtc) whose effect
+                // can shift the argmax: those are exactly the samplers that
+                // prevent agent reasoning loops, so dropping them lets the
+                // model re-emit the same tool call indefinitely.
+                // Stays off unless explicitly enabled by env var; greedy
+                // chains with only mask-style samplers (top_k/top_p/min_p/temp)
+                // can opt in.
+                static const bool s_batched_shortcircuit = []{
+                    const char * e = std::getenv("LLAMA_DDTREE_BATCHED_SHORTCIRCUIT");
+                    return e && e[0] == '1';
+                }();
+                struct ddtree_verify_state {
+                    common_sampler * smpl;
+                    llama_context  * ctx;
+                    bool use_shortcircuit;
+                };
+                ddtree_verify_state vstate{
+                    /*smpl=*/ (s_grammar_verify && slot.smpl) ? common_sampler_clone(slot.smpl.get()) : nullptr,
+                    /*ctx =*/ ctx,
+                    /*use_shortcircuit=*/ s_batched_shortcircuit,
+                };
+                if (vstate.smpl) {
+                    common_sampler_accept(vstate.smpl, slot.ddtree_root_tok, true);
+                }
+                llama_speculative_tree_verify_cbs vcbs{};
+                vcbs.user_data = &vstate;
+                vcbs.sample_cb = [](void * ud, int32_t logits_row_idx, llama_token batched_pick) -> int32_t {
+                    auto * s = (ddtree_verify_state *)ud;
+                    if (!s->smpl) {
+                        return 0; // shouldn't happen; driver falls back if cb null
+                    }
+                    if (s->use_shortcircuit &&
+                        batched_pick != LLAMA_TOKEN_NULL &&
+                        common_sampler_grammar_token_valid(s->smpl, batched_pick)) {
+                        return (int32_t)batched_pick;
+                    }
+                    return (int32_t)common_sampler_sample(s->smpl, s->ctx, logits_row_idx, /*grammar_first=*/true);
+                };
+                vcbs.advance_cb = [](void * ud, llama_token tok) {
+                    auto * s = (ddtree_verify_state *)ud;
+                    if (s->smpl) {
+                        common_sampler_accept(s->smpl, tok, true);
+                    }
+                };
+
+                auto accepted = llama_speculative_tree_driver_step(
+                    slot.spec_driver, slot.ddtree_root_tok, slot.ddtree_committed_pos,
+                    vstate.smpl ? &vcbs : nullptr);
+
+                if (vstate.smpl) {
+                    common_sampler_free(vstate.smpl);
+                }
+
+                if (accepted.empty()) {
+                    SLT_ERR(slot, "%s", "DDTree driver step returned empty result, treating as EOS\n");
+                    slot.stop = STOP_TYPE_EOS;
+                    slot.has_next_token = false;
+                    slot.print_timings();
+                    send_final_response(slot);
+                    metrics.on_prediction(slot);
+                    slot.release();
+                    continue;
+                }
+
+                // accepted: [root_echo, draft_accepted..., bonus]
+                // commit everything except the bonus token. When the
+                // grammar-aware verify ran (vcbs above), every accepted token
+                // is in the sampler+grammar's allowed set; with the default
+                // argmax verify they are the target model's greedy picks.
+                const int n_committed = (int)accepted.size() - 1;
+
+                const int64_t t_current = ggml_time_us();
+                slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
+
+                bool slot_done = false;
+                for (int ai = 0; ai < n_committed; ++ai) {
+                    const llama_token tok = accepted[ai];
+
+                    slot.n_decoded += 1;
+
+                    completion_token_output result;
+                    result.tok          = tok;
+                    result.text_to_send = common_token_to_piece(ctx, tok, accept_special_token(slot, tok));
+                    result.prob         = 1.0f;
+
+                    // update sampler history so repetition penalties remain correct
+                    common_sampler_accept(slot.smpl.get(), tok, true);
+
+                    // track position in prompt token list
+                    slot.prompt.tokens.push_back(tok);
+
+                    if (!process_token(result, slot)) {
+                        slot.print_timings();
+                        send_final_response(slot);
+                        metrics.on_prediction(slot);
+                        slot.release();
+                        slot_done = true;
+                        break;
+                    }
+                }
+
+                if (!slot_done) {
+                    slot.ddtree_root_tok       = accepted.back(); // bonus = next root
+                    slot.ddtree_committed_pos += (llama_pos)n_committed;
+                    slot.n_draft_total        += params_base.speculative.ddtree_budget;
+                    slot.n_draft_accepted     += n_committed - 1; // root was not a draft, rest were
+                }
+            }
+        }
+
     SRV_DBG("%s", "run slots completed\n");
 }