Commit 5e7594c

qwen35: fix do_spec_decode argmax OOB on prefix-cache partial restore
`n_last_chunk = committed % PREFILL_UBATCH` only equals the last prefill chunk's actual size when prefill started at kv_offset=0. With prefix-cache partial restore, `restore_and_generate` runs delta-prefill from kv_offset>0, so the last chunk's `n_tokens` is `prompt_len - kv_offset`, not `committed % PREFILL_UBATCH`. The resulting read index could then land at or past sg_.argmax_tokens->ne[0], firing the "tensor read out of bounds" assert on the first DFlash spec-decode request against any prompt the cache had already seen.

Read the actual last-chunk size from sg_.argmax_tokens->ne[0], which the graph builder sized to match the bound chunk. No-op when kv_offset==0, where the old modulo-derived size already equals ne[0].
1 parent 230c303 commit 5e7594c

1 file changed

Lines changed: 10 additions & 4 deletions

dflash/src/qwen35/qwen35_backend.cpp

@@ -832,11 +832,17 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
     // holds argmax for ALL positions in that chunk. We need the LAST position.
     int32_t last_tok;
     {
-        const int PREFILL_UBATCH = 512;
-        int n_last_chunk = committed % PREFILL_UBATCH;
-        if (n_last_chunk == 0) n_last_chunk = PREFILL_UBATCH;
+        // The last prefill chunk's size is whatever the still-bound argmax
+        // tensor says. Deriving it from `committed % PREFILL_UBATCH` was wrong
+        // when prefix-cache partial restore made the chunk shorter than the
+        // committed position would suggest (delta-prefill from kv_offset != 0).
+        const int64_t n_last_chunk = sg_.argmax_tokens ? sg_.argmax_tokens->ne[0] : 0;
+        if (n_last_chunk <= 0) {
+            std::fprintf(stderr, "do_spec_decode: argmax_tokens missing or empty\n");
+            return false;
+        }
         ggml_backend_tensor_get(sg_.argmax_tokens, &last_tok,
-                                sizeof(int32_t) * (n_last_chunk - 1),
+                                sizeof(int32_t) * (size_t)(n_last_chunk - 1),
                                 sizeof(int32_t));
     }
