Skip to content

Commit f9c9734

Browse files
committed
fix bugs
1 parent 20616c1 commit f9c9734

6 files changed

Lines changed: 98 additions & 18 deletions

File tree

ggml/src/ggml-backend.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,7 @@ static bool ggml_is_view_op(enum ggml_op op) {
754754
#endif
755755

756756
#ifndef GGML_SCHED_MAX_SPLIT_INPUTS
757-
#define GGML_SCHED_MAX_SPLIT_INPUTS 30
757+
#define GGML_SCHED_MAX_SPLIT_INPUTS 128
758758
#endif
759759

760760
#ifndef GGML_SCHED_MAX_COPIES

src/llama-graph.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,8 @@ static void dsv4_set_comp_inputs(
697697
uint32_t n_tokens) {
698698
dsv4_set_i32(inp.state_idxs, plan.state_idxs);
699699
dsv4_set_i32(inp.state_pos, plan.state_pos);
700+
dsv4_set_i32(inp.state_persist_src_idxs, plan.state_persist_src_idxs);
701+
dsv4_set_i32(inp.state_persist_dst_idxs, plan.state_persist_dst_idxs);
700702
dsv4_set_i32(inp.state_read_idxs, plan.state_read_idxs);
701703
dsv4_set_i64(inp.state_write_idxs, plan.state_write_idxs);
702704
dsv4_set_i32(inp.state_write_pos, plan.state_write_pos);
@@ -705,8 +707,9 @@ static void dsv4_set_comp_inputs(
705707
dsv4_set_kq_mask(inp.kq_mask, plan, n_tokens);
706708

707709
if (debug || dsv4_compress_debug()) {
708-
LLAMA_LOG_INFO("%s: %s ratio=%u, n_tokens=%u, state_write_end=%s\n",
710+
LLAMA_LOG_INFO("%s: %s ratio=%u, n_tokens=%u, state_persist_dst=%s, state_write_end=%s\n",
709711
__func__, name, plan.ratio, n_tokens,
712+
dsv4_plan_positions(plan.state_persist_dst_idxs).c_str(),
710713
dsv4_plan_positions(plan.state_write_end).c_str());
711714
}
712715
}
@@ -737,6 +740,8 @@ static bool dsv4_can_reuse_comp_input(
737740
bool res = true;
738741
res &= dsv4_can_reuse_tensor_1d(inp.state_idxs, plan.state_idxs.size());
739742
res &= dsv4_can_reuse_tensor_1d(inp.state_pos, plan.state_pos.size());
743+
res &= dsv4_can_reuse_tensor_1d(inp.state_persist_src_idxs, plan.state_persist_src_idxs.size());
744+
res &= dsv4_can_reuse_tensor_1d(inp.state_persist_dst_idxs, plan.state_persist_dst_idxs.size());
740745
res &= dsv4_can_reuse_tensor_1d(inp.state_read_idxs, plan.state_read_idxs.size());
741746
res &= dsv4_can_reuse_tensor_1d(inp.state_write_idxs, plan.state_write_idxs.size());
742747
res &= dsv4_can_reuse_tensor_1d(inp.state_write_pos, plan.state_write_pos.size());
@@ -770,6 +775,8 @@ static void dsv4_build_comp_inputs(
770775
const char * name) {
771776
inp.state_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_idxs.size(), std::string("dsv4_") + name + "_state_idxs");
772777
inp.state_pos = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_pos.size(), std::string("dsv4_") + name + "_state_pos");
778+
inp.state_persist_src_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_persist_src_idxs.size(), std::string("dsv4_") + name + "_state_persist_src_idxs");
779+
inp.state_persist_dst_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_persist_dst_idxs.size(), std::string("dsv4_") + name + "_state_persist_dst_idxs");
773780
inp.state_read_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_read_idxs.size(), std::string("dsv4_") + name + "_state_read_idxs");
774781
inp.state_write_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I64, plan.state_write_idxs.size(), std::string("dsv4_") + name + "_state_write_idxs");
775782
inp.state_write_pos = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_write_pos.size(), std::string("dsv4_") + name + "_state_write_pos");

src/llama-graph.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,8 @@ class llm_graph_input_dsv4 : public llm_graph_input_i {
465465
struct comp_input {
466466
ggml_tensor * state_idxs = nullptr; // I32 [n_state]
467467
ggml_tensor * state_pos = nullptr; // I32 [n_state]
468+
ggml_tensor * state_persist_src_idxs = nullptr; // I32 [n_state_persist]
469+
ggml_tensor * state_persist_dst_idxs = nullptr; // I32 [n_state_persist]
468470
ggml_tensor * state_read_idxs = nullptr; // I32 [ratio*n_state_write]
469471
ggml_tensor * state_write_idxs = nullptr; // I64 [n_state_write]
470472
ggml_tensor * state_write_pos = nullptr; // I32 [n_state_write]

src/llama-kv-cache-dsv4.cpp

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,24 @@ static llama_kv_cache_dsv4_context::comp_plan dsv4_build_comp_plan(
219219

220220
const int64_t state_rows = (int64_t) state_size*n_stream;
221221

222+
// One pending write into the persistent compressor state. Entries are keyed by
// destination row, so each destination row appears at most once in the plan;
// they are sorted by dst before being emitted as the persist src/dst index lists.
struct persist_row {
223+
int32_t dst; // persistent-state row id (the state_idx value); unique key, emitted as state_persist_dst_idxs
224+
int32_t src; // loop index i of the entry that supplies this row; emitted as state_persist_src_idxs
225+
llama_pos pos; // position of that entry; an entry with a larger pos for the same dst replaces src
226+
};
227+
228+
std::vector<persist_row> persist_rows;
229+
230+
// For the overlap compressor, build_overlap_compressed_kv_from_state() consumes
231+
// state_read_idxs as two contiguous halves: the first ratio*n_blocks entries are
232+
// the "previous-window" gather indices for every block, followed by the
233+
// "current-window" indices for every block. Collect them separately here and
234+
// append cur after prev once the loop has visited all completed blocks, instead
235+
// of interleaving [prev, cur] per block (which corrupted every block but the
236+
// last in multi-block ubatches / long-context prefill).
237+
std::vector<int32_t> overlap_prev_reads;
238+
std::vector<int32_t> overlap_cur_reads;
239+
222240
const auto current_token_idx = [&](llama_seq_id seq_id, llama_pos pos) -> int64_t {
223241
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
224242
if (ubatch.pos[i] == pos && ubatch.seq_id[i][0] == seq_id) {
@@ -257,9 +275,22 @@ static llama_kv_cache_dsv4_context::comp_plan dsv4_build_comp_plan(
257275

258276
const int64_t stream_off = n_stream > 1 ? (int64_t) seq_id*state_size : 0;
259277

260-
plan.state_idxs.push_back((int32_t) (stream_off + pos%state_size));
278+
const int32_t state_idx = (int32_t) (stream_off + pos%state_size);
279+
280+
plan.state_idxs.push_back(state_idx);
261281
plan.state_pos .push_back((int32_t) (pos%ratio));
262282

283+
const auto it = std::find_if(persist_rows.begin(), persist_rows.end(),
284+
[state_idx](const persist_row & row) {
285+
return row.dst == state_idx;
286+
});
287+
if (it == persist_rows.end()) {
288+
persist_rows.push_back({ state_idx, (int32_t) i, pos });
289+
} else if (pos > it->pos) {
290+
it->src = (int32_t) i;
291+
it->pos = pos;
292+
}
293+
263294
const int64_t n_visible = (int64_t) (pos + 1)/ratio;
264295
plan.n_visible[i] = (int32_t) n_visible;
265296
plan.n_kv = std::max(plan.n_kv, n_visible);
@@ -280,10 +311,10 @@ static llama_kv_cache_dsv4_context::comp_plan dsv4_build_comp_plan(
280311
const llama_pos prev_start = source_start - ratio;
281312

282313
for (uint32_t j = 0; j < ratio; ++j) {
283-
plan.state_read_idxs.push_back(state_source_idx(seq_id, prev_start + j));
314+
overlap_prev_reads.push_back(state_source_idx(seq_id, prev_start + j));
284315
}
285316
for (uint32_t j = 0; j < ratio; ++j) {
286-
plan.state_read_idxs.push_back(state_source_idx(seq_id, source_start + j));
317+
overlap_cur_reads.push_back(state_source_idx(seq_id, source_start + j));
287318
}
288319
} else {
289320
for (uint32_t j = 0; j < ratio; ++j) {
@@ -292,14 +323,34 @@ static llama_kv_cache_dsv4_context::comp_plan dsv4_build_comp_plan(
292323
}
293324
}
294325

326+
if (overlap) {
327+
// [ all blocks' prev-window indices | all blocks' cur-window indices ]
328+
plan.state_read_idxs.reserve(overlap_prev_reads.size() + overlap_cur_reads.size());
329+
plan.state_read_idxs.insert(plan.state_read_idxs.end(),
330+
overlap_prev_reads.begin(), overlap_prev_reads.end());
331+
plan.state_read_idxs.insert(plan.state_read_idxs.end(),
332+
overlap_cur_reads.begin(), overlap_cur_reads.end());
333+
}
334+
335+
std::sort(persist_rows.begin(), persist_rows.end(),
336+
[](const persist_row & a, const persist_row & b) {
337+
return a.dst < b.dst;
338+
});
339+
340+
for (const persist_row & row : persist_rows) {
341+
plan.state_persist_src_idxs.push_back(row.src);
342+
plan.state_persist_dst_idxs.push_back(row.dst);
343+
}
344+
295345
static const bool debug = []() {
296346
const char * env = getenv("LLAMA_DSV4_COMPRESS_DEBUG");
297347
return env && atoi(env) > 0;
298348
}();
299349

300350
if (debug) {
301-
LLAMA_LOG_INFO("%s: ratio=%u, n_tokens=%u, state_write_end=%s\n",
351+
LLAMA_LOG_INFO("%s: ratio=%u, n_tokens=%u, state_persist_dst=%s, state_write_end=%s\n",
302352
__func__, ratio, ubatch.n_tokens,
353+
dsv4_plan_positions(plan.state_persist_dst_idxs).c_str(),
303354
dsv4_plan_positions(plan.state_write_end).c_str());
304355
}
305356

@@ -668,6 +719,12 @@ llama_kv_cache_dsv4::llama_kv_cache_dsv4(
668719
lid_state = std::make_unique<llama_dsv4_comp_state>(
669720
model, offload, unified, n_seq_max, DSV4_CSA_RATIO, 2*DSV4_CSA_RATIO,
670721
2*model.hparams.indexer_head_size, "lid", filter_csa);
722+
723+
// DSV4 attention reads compressed-K / compressor-state rows that the current
724+
// graph does not necessarily overwrite; uninitialized buffer contents would
725+
// otherwise leak in (instance-specific garbage) and corrupt recall. Zero all
726+
// compressed buffers up front so reads of un-written rows are deterministic.
727+
clear_compressed(true);
671728
}
672729

673730
llama_memory_context_ptr llama_kv_cache_dsv4::init_batch(
@@ -766,7 +823,7 @@ void llama_kv_cache_dsv4::clear(bool data) {
766823
restored_trim_pos.clear();
767824

768825
kv_raw->clear(data);
769-
clear_compressed(data);
826+
clear_compressed(true); // DSV4 compressed buffers must never expose stale/uninit rows
770827
}
771828

772829
bool llama_kv_cache_dsv4::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
@@ -808,7 +865,7 @@ bool llama_kv_cache_dsv4::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1
808865
restored_trim_pos.clear();
809866
}
810867

811-
clear_compressed(false);
868+
clear_compressed(true);
812869
}
813870

814871
return res;
@@ -818,28 +875,28 @@ void llama_kv_cache_dsv4::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_ds
818875
restored_trim_pos.clear();
819876

820877
kv_raw->seq_cp(seq_id_src, seq_id_dst, p0, p1);
821-
clear_compressed(false);
878+
clear_compressed(true);
822879
}
823880

824881
void llama_kv_cache_dsv4::seq_keep(llama_seq_id seq_id) {
825882
restored_trim_pos.clear();
826883

827884
kv_raw->seq_keep(seq_id);
828-
clear_compressed(false);
885+
clear_compressed(true);
829886
}
830887

831888
void llama_kv_cache_dsv4::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
832889
restored_trim_pos.clear();
833890

834891
kv_raw->seq_add(seq_id, p0, p1, shift);
835-
clear_compressed(false);
892+
clear_compressed(true);
836893
}
837894

838895
void llama_kv_cache_dsv4::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
839896
restored_trim_pos.clear();
840897

841898
kv_raw->seq_div(seq_id, p0, p1, d);
842-
clear_compressed(false);
899+
clear_compressed(true);
843900
}
844901

845902
llama_pos llama_kv_cache_dsv4::seq_pos_min(llama_seq_id seq_id) const {

src/llama-kv-cache-dsv4.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,11 @@ class llama_kv_cache_dsv4_context : public llama_memory_context_i {
161161
// APE row ids, i.e. pos % ratio, for the compressor-state updates.
162162
std::vector<int32_t> state_pos;
163163

164+
// Current-ubatch source row ids and unique persistent-state
165+
// destination row ids for deterministic ring-state updates.
166+
std::vector<int32_t> state_persist_src_idxs;
167+
std::vector<int32_t> state_persist_dst_idxs;
168+
164169
// Flattened source row ids used for state-backed commits. Source rows
165170
// index the graph-local [persistent_state | current_ubatch_scratch]
166171
// tensor. For overlapped compression the first half is previous rows

src/models/deepseek-v4.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -890,10 +890,13 @@ ggml_tensor * llama_model_deepseek_v4_flash::graph::build_attention(
890890
csa_state_kv = dsv4_with_zero_dep(ctx0, csa_state_kv, csa_state_dep);
891891
csa_state_score = dsv4_with_zero_dep(ctx0, csa_state_score, csa_state_dep);
892892

893+
ggml_tensor * csa_persist_kv = ggml_get_rows(ctx0, csa_state_kv, inp_dsv4->get_csa().state_persist_src_idxs);
894+
ggml_tensor * csa_persist_score = ggml_get_rows(ctx0, csa_state_score, inp_dsv4->get_csa().state_persist_src_idxs);
895+
893896
csa_state_kv = inp_dsv4->mctx->get_csa_state()->cpy_kv(ctx0,
894-
csa_state_kv, inp_dsv4->get_csa().state_idxs, il);
897+
csa_persist_kv, inp_dsv4->get_csa().state_persist_dst_idxs, il);
895898
csa_state_score = inp_dsv4->mctx->get_csa_state()->cpy_score(ctx0,
896-
csa_state_score, inp_dsv4->get_csa().state_idxs, il);
899+
csa_persist_score, inp_dsv4->get_csa().state_persist_dst_idxs, il);
897900

898901
ggml_build_forward_expand(gf, csa_state_kv);
899902
ggml_build_forward_expand(gf, csa_state_score);
@@ -946,10 +949,13 @@ ggml_tensor * llama_model_deepseek_v4_flash::graph::build_attention(
946949
lid_state_kv = dsv4_with_zero_dep(ctx0, lid_state_kv, lid_state_dep);
947950
lid_state_score = dsv4_with_zero_dep(ctx0, lid_state_score, lid_state_dep);
948951

952+
ggml_tensor * lid_persist_kv = ggml_get_rows(ctx0, lid_state_kv, inp_dsv4->get_lid().state_persist_src_idxs);
953+
ggml_tensor * lid_persist_score = ggml_get_rows(ctx0, lid_state_score, inp_dsv4->get_lid().state_persist_src_idxs);
954+
949955
lid_state_kv = inp_dsv4->mctx->get_lid_state()->cpy_kv(ctx0,
950-
lid_state_kv, inp_dsv4->get_lid().state_idxs, il);
956+
lid_persist_kv, inp_dsv4->get_lid().state_persist_dst_idxs, il);
951957
lid_state_score = inp_dsv4->mctx->get_lid_state()->cpy_score(ctx0,
952-
lid_state_score, inp_dsv4->get_lid().state_idxs, il);
958+
lid_persist_score, inp_dsv4->get_lid().state_persist_dst_idxs, il);
953959

954960
ggml_build_forward_expand(gf, lid_state_kv);
955961
ggml_build_forward_expand(gf, lid_state_score);
@@ -987,10 +993,13 @@ ggml_tensor * llama_model_deepseek_v4_flash::graph::build_attention(
987993
hca_state_kv = dsv4_with_zero_dep(ctx0, hca_state_kv, hca_state_dep);
988994
hca_state_score = dsv4_with_zero_dep(ctx0, hca_state_score, hca_state_dep);
989995

996+
ggml_tensor * hca_persist_kv = ggml_get_rows(ctx0, hca_state_kv, inp_dsv4->get_hca().state_persist_src_idxs);
997+
ggml_tensor * hca_persist_score = ggml_get_rows(ctx0, hca_state_score, inp_dsv4->get_hca().state_persist_src_idxs);
998+
990999
hca_state_kv = inp_dsv4->mctx->get_hca_state()->cpy_kv(ctx0,
991-
hca_state_kv, inp_dsv4->get_hca().state_idxs, il);
1000+
hca_persist_kv, inp_dsv4->get_hca().state_persist_dst_idxs, il);
9921001
hca_state_score = inp_dsv4->mctx->get_hca_state()->cpy_score(ctx0,
993-
hca_state_score, inp_dsv4->get_hca().state_idxs, il);
1002+
hca_persist_score, inp_dsv4->get_hca().state_persist_dst_idxs, il);
9941003

9951004
ggml_build_forward_expand(gf, hca_state_kv);
9961005
ggml_build_forward_expand(gf, hca_state_score);

0 commit comments

Comments
 (0)