Skip to content

Commit 73d4d12

Browse files
committed
dflash: harden invalid-draft fallback and split crash tracing
Validate reduced-logits argmax shape and token ids before constructing DFlash drafts so impossible GPU outputs fail closed instead of crashing speculative generation. Treat empty DFlash drafts as the existing one-token fallback path so speculative bookkeeping is cleared and committed suffix state advances normally instead of livelocking against stale committed_len. Move high-volume crash breadcrumbs and per-ubatch state dumps behind GGML_DFLASH_CRASH_TRACE, leaving GGML_DFLASH_DEBUG for the broader diagnostic logs.
1 parent a1080e9 commit 73d4d12

3 files changed

Lines changed: 371 additions & 29 deletions

File tree

common/speculative.cpp

Lines changed: 81 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,24 @@ static bool common_dflash_gpu_ring_allowed(llama_context * ctx_tgt, llama_contex
119119
return true;
120120
}
121121

122+
// Returns true when token_id lies in the half-open range [0, n_vocab), false otherwise.
// Call sites in this file pass a raw int32_t entry read from the reduced-logits argmax
// buffer together with the draft model's vocabulary size, and only cast the value to
// llama_token after this check succeeds.
static bool common_dflash_argmax_token_valid(int32_t token_id, int n_vocab) {
123+
return token_id >= 0 && token_id < n_vocab;
124+
}
125+
126+
// Validates the reported shape of the reduced-logits argmax output before it is indexed.
//
//   where          - label printed in the error message (call sites pass __func__)
//   rows_available - number of rows the context reports for the argmax buffer
//   rows_required  - number of rows the caller is about to read
//   top_k          - entries per row; call sites index the buffer as argmax[row * top_k + rank]
//
// Returns true when top_k >= 1 and rows_available >= rows_required. Otherwise logs one
// LOG_ERR line with all three values and returns false; the caller decides how to bail out.
static bool common_dflash_argmax_shape_valid(
127+
const char * where,
128+
int rows_available,
129+
int rows_required,
130+
int top_k) {
131+
// A top_k below 1 would make the row stride zero or negative; too few rows would let
// the caller read past the rows the context reports.
if (top_k < 1 || rows_available < rows_required) {
132+
LOG_ERR("dflash: invalid reduced-logits shape in %s (rows=%d required=%d top_k=%d)\n",
133+
where, rows_available, rows_required, top_k);
134+
return false;
135+
}
136+
137+
return true;
138+
}
139+
122140
common_dflash_ring_write common_dflash_ring_write_plan(int ring_size, int ring_pos, int n_tokens) {
123141
if (ring_size <= 0 || n_tokens <= 0) {
124142
return { 0, 0, 0 };
@@ -2565,7 +2583,17 @@ struct common_speculative_state_dflash : public common_speculative_state {
25652583
int32_t * argmax = llama_get_logits_argmax(ctx_dft);
25662584
float * argmax_probs = llama_get_logits_argmax_probs(ctx_dft);
25672585
const int K_flat = llama_get_logits_argmax_k(ctx_dft);
2586+
const int argmax_rows = llama_get_logits_argmax_n(ctx_dft);
25682587
if (argmax) {
2588+
const int n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model_dft));
2589+
if (!common_dflash_argmax_shape_valid(__func__, argmax_rows, batch_len, K_flat)) {
2590+
if (draft_log_probs) {
2591+
draft_log_probs->clear();
2592+
}
2593+
result.clear();
2594+
return;
2595+
}
2596+
25692597
// GPU argmax path - only 64-128 bytes transferred instead of 15.9MB
25702598
for (int i = 1; i < batch_len && (int) result.size() < n_draft; ++i) {
25712599
if (argmax_probs && params.p_min > 0.0f && (int) result.size() >= params.n_min) {
@@ -2577,7 +2605,18 @@ struct common_speculative_state_dflash : public common_speculative_state {
25772605
break;
25782606
}
25792607
}
2580-
result.push_back((llama_token) argmax[i * K_flat]);
2608+
const int32_t token_raw = argmax[i * K_flat];
2609+
if (!common_dflash_argmax_token_valid(token_raw, n_vocab)) {
2610+
LOG_ERR("dflash: invalid reduced-logits token %d in %s at row=%d/%d (top_k=%d committed=%d cross_len=%d)\n",
2611+
token_raw, __func__, i, batch_len, K_flat, committed_len, cross_len);
2612+
if (draft_log_probs) {
2613+
draft_log_probs->clear();
2614+
}
2615+
result.clear();
2616+
return;
2617+
}
2618+
2619+
result.push_back((llama_token) token_raw);
25812620
if (draft_log_probs && argmax_probs) {
25822621
draft_log_probs->push_back(argmax_probs[i * K_flat]);
25832622
}
@@ -2692,6 +2731,11 @@ struct common_speculative_state_dflash : public common_speculative_state {
26922731
return;
26932732
}
26942733
float * argmax_probs = llama_get_logits_argmax_probs(ctx_dft);
2734+
const int argmax_rows = llama_get_logits_argmax_n(ctx_dft);
2735+
const int n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model_dft));
2736+
if (!common_dflash_argmax_shape_valid(__func__, argmax_rows, depth_limit + 1, K)) {
2737+
return;
2738+
}
26952739

26962740
// Build tree using best-first heap expansion with chain-seed backbone
26972741
tree.tokens.clear();
@@ -2711,7 +2755,14 @@ struct common_speculative_state_dflash : public common_speculative_state {
27112755
{
27122756
int parent = 0;
27132757
for (int d = 1; d <= depth_limit && d <= main_path_len && tree.n_nodes < tree_budget; ++d) {
2714-
llama_token token_id = (llama_token) argmax[d * K];
2758+
const int32_t token_raw = argmax[d * K];
2759+
if (!common_dflash_argmax_token_valid(token_raw, n_vocab)) {
2760+
LOG_ERR("dflash tree: invalid reduced-logits token %d in %s at depth=%d/%d (top_k=%d)\n",
2761+
token_raw, __func__, d, depth_limit, K);
2762+
break;
2763+
}
2764+
2765+
llama_token token_id = (llama_token) token_raw;
27152766
float log_prob = argmax_probs ? argmax_probs[d * K + 0] : -INFINITY;
27162767

27172768
int current_idx = tree.n_nodes + 1;
@@ -2757,8 +2808,9 @@ struct common_speculative_state_dflash : public common_speculative_state {
27572808
auto top = heap.top();
27582809
heap.pop();
27592810

2760-
llama_token token_id = (llama_token) argmax[top.depth * K + top.rank];
2761-
if (token_id < 0) continue;
2811+
const int32_t token_raw = argmax[top.depth * K + top.rank];
2812+
if (!common_dflash_argmax_token_valid(token_raw, n_vocab)) continue;
2813+
llama_token token_id = (llama_token) token_raw;
27622814
if (tree.child_maps[top.parent_idx].count(token_id)) continue;
27632815

27642816
int current_idx = tree.n_nodes + 1;
@@ -2790,8 +2842,9 @@ struct common_speculative_state_dflash : public common_speculative_state {
27902842
for (int d = 1; d <= main_path_len && (tree.n_nodes - tree.main_path_len) < branch_budget; ++d) {
27912843
int parent_idx = (d == 1) ? 0 : d - 1;
27922844
for (int ki = 1; ki < K && (tree.n_nodes - tree.main_path_len) < branch_budget; ++ki) {
2793-
llama_token alt_token = (llama_token) argmax[d * K + ki];
2794-
if (alt_token < 0) continue;
2845+
const int32_t token_raw = argmax[d * K + ki];
2846+
if (!common_dflash_argmax_token_valid(token_raw, n_vocab)) continue;
2847+
llama_token alt_token = (llama_token) token_raw;
27952848
if (tree.child_maps[parent_idx].count(alt_token)) continue;
27962849

27972850
int current_idx = tree.n_nodes + 1;
@@ -3751,6 +3804,7 @@ void common_speculative_draft_batch(
37513804
int32_t * argmax = llama_get_logits_argmax(ctx_dft);
37523805
float * argmax_probs = llama_get_logits_argmax_probs(ctx_dft);
37533806
const int K_flat = llama_get_logits_argmax_k(ctx_dft);
3807+
const int argmax_rows = llama_get_logits_argmax_n(ctx_dft);
37543808

37553809
for (int r = 0; r < n_ready; r++) {
37563810
auto & rs = ready[r];
@@ -3759,14 +3813,34 @@ void common_speculative_draft_batch(
37593813
const int offset = r * batch_len;
37603814

37613815
if (argmax) {
3816+
const int n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model_dft));
3817+
if (!common_dflash_argmax_shape_valid(__func__, argmax_rows, n_ready * batch_len, K_flat)) {
3818+
if (log_probs) {
3819+
log_probs->clear();
3820+
}
3821+
result.clear();
3822+
return;
3823+
}
3824+
37623825
for (int i = 1; i < batch_len && (int) result.size() < n_draft; i++) {
37633826
if (argmax_probs && params.p_min > 0.0f && (int) result.size() >= params.n_min) {
37643827
float log_prob = argmax_probs[(offset + i) * K_flat];
37653828
if (log_prob < logf(params.p_min)) {
37663829
break;
37673830
}
37683831
}
3769-
result.push_back((llama_token) argmax[(offset + i) * K_flat]);
3832+
const int32_t token_raw = argmax[(offset + i) * K_flat];
3833+
if (!common_dflash_argmax_token_valid(token_raw, n_vocab)) {
3834+
LOG_ERR("dflash batch: invalid reduced-logits token %d in %s at spec=%d row=%d/%d (top_k=%d offset=%d)\n",
3835+
token_raw, __func__, rs.spec_idx, i, batch_len, K_flat, offset);
3836+
if (log_probs) {
3837+
log_probs->clear();
3838+
}
3839+
result.clear();
3840+
break;
3841+
}
3842+
3843+
result.push_back((llama_token) token_raw);
37703844
if (log_probs && argmax_probs) {
37713845
log_probs->push_back(argmax_probs[(offset + i) * K_flat]);
37723846
}

0 commit comments

Comments
 (0)