Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@

LLM inference in C/C++

## DFlash / DDTree local notes

This fork carries experimental Qwen3.5 DFlash / DDTree work. Dataset benchmark
setup, Castle commands, correctness caveats, and current llama.cpp-vs-Python
numbers are tracked in [docs/ddtree-dataset-eval-plan.md](docs/ddtree-dataset-eval-plan.md).

## Recent API changes

- [Changelog for `libllama` API](https://github.com/ggml-org/llama.cpp/issues/9289)
Expand Down
6 changes: 6 additions & 0 deletions common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ add_library(${TARGET} STATIC
sampling.h
speculative.cpp
speculative.h
speculative-draft-backend.cpp
speculative-draft-backend.h
speculative-tree.cpp
speculative-tree.h
speculative-tree-driver.cpp
speculative-tree-driver.h
unicode.cpp
unicode.h
jinja/lexer.cpp
Expand Down
54 changes: 54 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
throw std::invalid_argument("error: --model is required\n");
}

// DDTree mode requires a draft model
if (params.speculative.ddtree_mode && !params.speculative.has_dft()) {
throw std::invalid_argument("error: --speculative-mode ddtree requires -md/--model-draft\n");
}

if (params.escape) {
string_process_escapes(params.prompt);
string_process_escapes(params.input_prefix);
Expand Down Expand Up @@ -3554,6 +3559,55 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.ngram_min_hits = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
// --speculative-mode: selects between the standard chain speculative decoder
// and the experimental DDTree decoder. Stored as a bool flag on
// common_params_speculative; the ddtree path additionally requires a draft
// model (-md), which is validated in common_params_parse_ex.
add_opt(common_arg(
{"--speculative-mode"}, "[chain|ddtree]",
"speculative decoding mode: 'chain' = standard draft-model chain (default), "
"'ddtree' = DDTree dflash-draft speculative decoding (requires -md)",
[](common_params & params, const std::string & value) {
if (value == "chain") {
params.speculative.ddtree_mode = false;
} else if (value == "ddtree") {
params.speculative.ddtree_mode = true;
} else {
// reject anything but the two known mode names
throw std::invalid_argument("--speculative-mode must be 'chain' or 'ddtree'");
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SPECULATIVE_MODE"));
// --ddtree-budget: total non-root node budget for the draft tree per
// speculative step; must be at least 1 so the tree is non-trivial.
add_opt(common_arg(
{"--ddtree-budget"}, "N",
string_format("DDTree: total tree node budget per spec step (default: %d)", params.speculative.ddtree_budget),
[](common_params & params, int value) {
if (value < 1) {
throw std::invalid_argument("--ddtree-budget must be >= 1");
}
params.speculative.ddtree_budget = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_DDTREE_BUDGET"));
// --ddtree-temp: temperature applied when extracting draft log-probs.
// NOTE(review): std::stof throws std::invalid_argument/std::out_of_range on
// malformed input, which matches the parser's error convention, but no range
// check is done here — presumably any float (incl. <= 0) is intentional;
// confirm against the DDTree draft sampler.
add_opt(common_arg(
{"--ddtree-temp"}, "F",
string_format("DDTree: temperature for draft log-prob extraction (default: %.1f)", (double)params.speculative.ddtree_temp),
[](common_params & params, const std::string & value) {
params.speculative.ddtree_temp = std::stof(value);
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_DDTREE_TEMP"));
// --ddtree-no-chain-seed: flag-style option (no value) that turns OFF the
// default greedy chain-seed initialization of the tree heap.
add_opt(common_arg(
{"--ddtree-no-chain-seed"},
"DDTree: disable chain-seed greedy initialization (enabled by default)",
[](common_params & params) {
params.speculative.ddtree_chain_seed = false;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
// --ddtree-top-k: per-position draft top-K width; 0 is a sentinel meaning
// "standalone-compatible auto" (see common_params_speculative::ddtree_top_k).
add_opt(common_arg(
{"--ddtree-top-k"}, "N",
string_format("DDTree: per-position draft top-K width (default: %d, 0 = standalone-compatible auto)",
params.speculative.ddtree_top_k),
[](common_params & params, int value) {
if (value < 0) {
throw std::invalid_argument("--ddtree-top-k must be >= 0");
}
params.speculative.ddtree_top_k = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_DDTREE_TOP_K"));
add_opt(common_arg(
{"-ctkd", "--cache-type-k-draft"}, "TYPE",
string_format(
Expand Down
7 changes: 7 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,13 @@ struct common_params_speculative {
// A draft model is considered configured when either a local model path or a
// Hugging Face repo has been provided for the draft model parameters.
bool has_dft() const {
    const bool no_path = mparams_dft.path.empty();
    const bool no_repo = mparams_dft.hf_repo.empty();
    return !(no_path && no_repo);
}

// DDTree speculative decoding parameters (Phase 5)
bool ddtree_mode = false; // true when --speculative-mode ddtree is set
int32_t ddtree_budget = 22; // non-root tree node budget (matches dflash default)
float ddtree_temp = 1.0f; // temperature for draft log-prob extraction
bool ddtree_chain_seed = true; // seed the tree heap with greedy chain (recommended)
int32_t ddtree_top_k = 0; // per-position draft top-K width; 0 = standalone-compatible auto
};

struct common_params_vocoder {
Expand Down
47 changes: 46 additions & 1 deletion common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,22 @@ static bool grammar_should_apply(struct common_sampler * gsmpl) {
return true;
}

bool common_sampler_grammar_token_valid(struct common_sampler * gsmpl, llama_token token) {
if (!gsmpl) {
return false;
}
if (!grammar_should_apply(gsmpl)) {
return true;
}
// Apply the grammar sampler to a single-token candidate array. The grammar
// sampler masks invalid tokens to -INFINITY; we only inspect the result and
// do not advance any sampler state.
llama_token_data single = { token, 1.0f, 0.0f };
llama_token_data_array single_array = { &single, 1, -1, false };
llama_sampler_apply(gsmpl->grmr, &single_array);
return single_array.data[0].logit != -INFINITY;
}

void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
if (!gsmpl) {
return;
Expand Down Expand Up @@ -524,7 +540,36 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
}

llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
llama_synchronize(ctx);
// Optional sub-timing instrumentation (LLAMA_DDTREE_PROFILE_CB=1).
// Splits the cb cost into the GPU sync wait vs the host-side sampler
// work so we can tell whether prefetching logits would help.
static const bool s_profile_cb = []{
const char * e = std::getenv("LLAMA_DDTREE_PROFILE_CB");
return e && e[0] == '1';
}();
int64_t cb_sync_us = 0;
int64_t cb_work_start_us = 0;
if (s_profile_cb) {
const int64_t t0 = ggml_time_us();
llama_synchronize(ctx);
cb_sync_us = ggml_time_us() - t0;
cb_work_start_us = ggml_time_us();
} else {
llama_synchronize(ctx);
}
struct cb_timing_guard {
bool active;
int64_t sync_us;
int64_t work_start_us;
~cb_timing_guard() {
if (active) {
const double work_ms = (ggml_time_us() - work_start_us) * 1e-3;
fprintf(stderr, "cb_timing: sync=%.3f work=%.3f total=%.3f ms\n",
sync_us * 1e-3, work_ms, sync_us * 1e-3 + work_ms);
}
}
} cb_guard{ s_profile_cb, cb_sync_us, cb_work_start_us };
(void) cb_guard;

// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
const auto tm = gsmpl->tm();
Expand Down
6 changes: 6 additions & 0 deletions common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
//
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);

// Cheap single-token grammar check used by speculative decoding short-circuits.
// Returns true when the token would be accepted by the grammar (or when no
// grammar is active / not currently applicable). Does NOT advance any sampler
// state; safe to call from inside a verify-walk callback.
bool common_sampler_grammar_token_valid(struct common_sampler * gsmpl, llama_token token);

// generalized version of common_sampler_sample
//
// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
Expand Down
Loading
Loading