@@ -119,6 +119,24 @@ static bool common_dflash_gpu_ring_allowed(llama_context * ctx_tgt, llama_contex
119119 return true ;
120120}
121121
122+ static bool common_dflash_argmax_token_valid (int32_t token_id, int n_vocab) {
123+ return token_id >= 0 && token_id < n_vocab;
124+ }
125+
126+ static bool common_dflash_argmax_shape_valid (
127+ const char * where,
128+ int rows_available,
129+ int rows_required,
130+ int top_k) {
131+ if (top_k < 1 || rows_available < rows_required) {
132+ LOG_ERR (" dflash: invalid reduced-logits shape in %s (rows=%d required=%d top_k=%d)\n " ,
133+ where, rows_available, rows_required, top_k);
134+ return false ;
135+ }
136+
137+ return true ;
138+ }
139+
122140common_dflash_ring_write common_dflash_ring_write_plan (int ring_size, int ring_pos, int n_tokens) {
123141 if (ring_size <= 0 || n_tokens <= 0 ) {
124142 return { 0 , 0 , 0 };
@@ -2565,7 +2583,17 @@ struct common_speculative_state_dflash : public common_speculative_state {
25652583 int32_t * argmax = llama_get_logits_argmax (ctx_dft);
25662584 float * argmax_probs = llama_get_logits_argmax_probs (ctx_dft);
25672585 const int K_flat = llama_get_logits_argmax_k (ctx_dft);
2586+ const int argmax_rows = llama_get_logits_argmax_n (ctx_dft);
25682587 if (argmax) {
2588+ const int n_vocab = llama_vocab_n_tokens (llama_model_get_vocab (model_dft));
2589+ if (!common_dflash_argmax_shape_valid (__func__, argmax_rows, batch_len, K_flat)) {
2590+ if (draft_log_probs) {
2591+ draft_log_probs->clear ();
2592+ }
2593+ result.clear ();
2594+ return ;
2595+ }
2596+
25692597 // GPU argmax path - only 64-128 bytes transferred instead of 15.9MB
25702598 for (int i = 1 ; i < batch_len && (int ) result.size () < n_draft; ++i) {
25712599 if (argmax_probs && params.p_min > 0 .0f && (int ) result.size () >= params.n_min ) {
@@ -2577,7 +2605,18 @@ struct common_speculative_state_dflash : public common_speculative_state {
25772605 break ;
25782606 }
25792607 }
2580- result.push_back ((llama_token) argmax[i * K_flat]);
2608+ const int32_t token_raw = argmax[i * K_flat];
2609+ if (!common_dflash_argmax_token_valid (token_raw, n_vocab)) {
2610+ LOG_ERR (" dflash: invalid reduced-logits token %d in %s at row=%d/%d (top_k=%d committed=%d cross_len=%d)\n " ,
2611+ token_raw, __func__, i, batch_len, K_flat, committed_len, cross_len);
2612+ if (draft_log_probs) {
2613+ draft_log_probs->clear ();
2614+ }
2615+ result.clear ();
2616+ return ;
2617+ }
2618+
2619+ result.push_back ((llama_token) token_raw);
25812620 if (draft_log_probs && argmax_probs) {
25822621 draft_log_probs->push_back (argmax_probs[i * K_flat]);
25832622 }
@@ -2692,6 +2731,11 @@ struct common_speculative_state_dflash : public common_speculative_state {
26922731 return ;
26932732 }
26942733 float * argmax_probs = llama_get_logits_argmax_probs (ctx_dft);
2734+ const int argmax_rows = llama_get_logits_argmax_n (ctx_dft);
2735+ const int n_vocab = llama_vocab_n_tokens (llama_model_get_vocab (model_dft));
2736+ if (!common_dflash_argmax_shape_valid (__func__, argmax_rows, depth_limit + 1 , K)) {
2737+ return ;
2738+ }
26952739
26962740 // Build tree using best-first heap expansion with chain-seed backbone
26972741 tree.tokens .clear ();
@@ -2711,7 +2755,14 @@ struct common_speculative_state_dflash : public common_speculative_state {
27112755 {
27122756 int parent = 0 ;
27132757 for (int d = 1 ; d <= depth_limit && d <= main_path_len && tree.n_nodes < tree_budget; ++d) {
2714- llama_token token_id = (llama_token) argmax[d * K];
2758+ const int32_t token_raw = argmax[d * K];
2759+ if (!common_dflash_argmax_token_valid (token_raw, n_vocab)) {
2760+ LOG_ERR (" dflash tree: invalid reduced-logits token %d in %s at depth=%d/%d (top_k=%d)\n " ,
2761+ token_raw, __func__, d, depth_limit, K);
2762+ break ;
2763+ }
2764+
2765+ llama_token token_id = (llama_token) token_raw;
27152766 float log_prob = argmax_probs ? argmax_probs[d * K + 0 ] : -INFINITY ;
27162767
27172768 int current_idx = tree.n_nodes + 1 ;
@@ -2757,8 +2808,9 @@ struct common_speculative_state_dflash : public common_speculative_state {
27572808 auto top = heap.top ();
27582809 heap.pop ();
27592810
2760- llama_token token_id = (llama_token) argmax[top.depth * K + top.rank ];
2761- if (token_id < 0 ) continue ;
2811+ const int32_t token_raw = argmax[top.depth * K + top.rank ];
2812+ if (!common_dflash_argmax_token_valid (token_raw, n_vocab)) continue ;
2813+ llama_token token_id = (llama_token) token_raw;
27622814 if (tree.child_maps [top.parent_idx ].count (token_id)) continue ;
27632815
27642816 int current_idx = tree.n_nodes + 1 ;
@@ -2790,8 +2842,9 @@ struct common_speculative_state_dflash : public common_speculative_state {
27902842 for (int d = 1 ; d <= main_path_len && (tree.n_nodes - tree.main_path_len ) < branch_budget; ++d) {
27912843 int parent_idx = (d == 1 ) ? 0 : d - 1 ;
27922844 for (int ki = 1 ; ki < K && (tree.n_nodes - tree.main_path_len ) < branch_budget; ++ki) {
2793- llama_token alt_token = (llama_token) argmax[d * K + ki];
2794- if (alt_token < 0 ) continue ;
2845+ const int32_t token_raw = argmax[d * K + ki];
2846+ if (!common_dflash_argmax_token_valid (token_raw, n_vocab)) continue ;
2847+ llama_token alt_token = (llama_token) token_raw;
27952848 if (tree.child_maps [parent_idx].count (alt_token)) continue ;
27962849
27972850 int current_idx = tree.n_nodes + 1 ;
@@ -3751,6 +3804,7 @@ void common_speculative_draft_batch(
37513804 int32_t * argmax = llama_get_logits_argmax (ctx_dft);
37523805 float * argmax_probs = llama_get_logits_argmax_probs (ctx_dft);
37533806 const int K_flat = llama_get_logits_argmax_k (ctx_dft);
3807+ const int argmax_rows = llama_get_logits_argmax_n (ctx_dft);
37543808
37553809 for (int r = 0 ; r < n_ready; r++) {
37563810 auto & rs = ready[r];
@@ -3759,14 +3813,34 @@ void common_speculative_draft_batch(
37593813 const int offset = r * batch_len;
37603814
37613815 if (argmax) {
3816+ const int n_vocab = llama_vocab_n_tokens (llama_model_get_vocab (model_dft));
3817+ if (!common_dflash_argmax_shape_valid (__func__, argmax_rows, n_ready * batch_len, K_flat)) {
3818+ if (log_probs) {
3819+ log_probs->clear ();
3820+ }
3821+ result.clear ();
3822+ return ;
3823+ }
3824+
37623825 for (int i = 1 ; i < batch_len && (int ) result.size () < n_draft; i++) {
37633826 if (argmax_probs && params.p_min > 0 .0f && (int ) result.size () >= params.n_min ) {
37643827 float log_prob = argmax_probs[(offset + i) * K_flat];
37653828 if (log_prob < logf (params.p_min )) {
37663829 break ;
37673830 }
37683831 }
3769- result.push_back ((llama_token) argmax[(offset + i) * K_flat]);
3832+ const int32_t token_raw = argmax[(offset + i) * K_flat];
3833+ if (!common_dflash_argmax_token_valid (token_raw, n_vocab)) {
3834+ LOG_ERR (" dflash batch: invalid reduced-logits token %d in %s at spec=%d row=%d/%d (top_k=%d offset=%d)\n " ,
3835+ token_raw, __func__, rs.spec_idx , i, batch_len, K_flat, offset);
3836+ if (log_probs) {
3837+ log_probs->clear ();
3838+ }
3839+ result.clear ();
3840+ break ;
3841+ }
3842+
3843+ result.push_back ((llama_token) token_raw);
37703844 if (log_probs && argmax_probs) {
37713845 log_probs->push_back (argmax_probs[(offset + i) * K_flat]);
37723846 }
0 commit comments