@@ -219,6 +219,24 @@ static llama_kv_cache_dsv4_context::comp_plan dsv4_build_comp_plan(
219219
220220 const int64_t state_rows = (int64_t ) state_size*n_stream;
221221
// Pending "persist" entry recorded while building the plan: one entry is kept
// per distinct destination state index. When a later token maps to the same
// destination, the entry with the greater position replaces the earlier one.
// After the loop the entries are sorted by dst and emitted into
// plan.state_persist_src_idxs / plan.state_persist_dst_idxs.
222+ struct persist_row {
// Destination state index (stream_off + pos%state_size); used as the dedup and sort key.
223+ int32_t dst;
// Loop index i of the entry that currently owns this destination
// (presumably the token index within the ubatch -- TODO confirm against the enclosing loop).
224+ int32_t src;
// Position of that entry; compared with `>` so the greatest position wins.
225+ llama_pos pos;
226+ };
227+
228+ std::vector<persist_row> persist_rows;
229+
230+ // For the overlap compressor, build_overlap_compressed_kv_from_state() consumes
231+ // state_read_idxs as two contiguous halves: the first ratio*n_blocks entries are
232+ // the "previous-window" gather indices for every block, followed by the
233+ // "current-window" indices for every block. Collect them separately here and
234+ // append cur after prev once the loop has visited all completed blocks, instead
235+ // of interleaving [prev, cur] per block (which corrupted every block but the
236+ // last in multi-block ubatches / long-context prefill).
237+ std::vector<int32_t > overlap_prev_reads;
238+ std::vector<int32_t > overlap_cur_reads;
239+
222240 const auto current_token_idx = [&](llama_seq_id seq_id, llama_pos pos) -> int64_t {
223241 for (uint32_t i = 0 ; i < ubatch.n_tokens ; ++i) {
224242 if (ubatch.pos [i] == pos && ubatch.seq_id [i][0 ] == seq_id) {
@@ -257,9 +275,22 @@ static llama_kv_cache_dsv4_context::comp_plan dsv4_build_comp_plan(
257275
258276 const int64_t stream_off = n_stream > 1 ? (int64_t ) seq_id*state_size : 0 ;
259277
260- plan.state_idxs .push_back ((int32_t ) (stream_off + pos%state_size));
278+ const int32_t state_idx = (int32_t ) (stream_off + pos%state_size);
279+
280+ plan.state_idxs .push_back (state_idx);
261281 plan.state_pos .push_back ((int32_t ) (pos%ratio));
262282
283+ const auto it = std::find_if (persist_rows.begin (), persist_rows.end (),
284+ [state_idx](const persist_row & row) {
285+ return row.dst == state_idx;
286+ });
287+ if (it == persist_rows.end ()) {
288+ persist_rows.push_back ({ state_idx, (int32_t ) i, pos });
289+ } else if (pos > it->pos ) {
290+ it->src = (int32_t ) i;
291+ it->pos = pos;
292+ }
293+
263294 const int64_t n_visible = (int64_t ) (pos + 1 )/ratio;
264295 plan.n_visible [i] = (int32_t ) n_visible;
265296 plan.n_kv = std::max (plan.n_kv , n_visible);
@@ -280,10 +311,10 @@ static llama_kv_cache_dsv4_context::comp_plan dsv4_build_comp_plan(
280311 const llama_pos prev_start = source_start - ratio;
281312
282313 for (uint32_t j = 0 ; j < ratio; ++j) {
283- plan. state_read_idxs .push_back (state_source_idx (seq_id, prev_start + j));
314+ overlap_prev_reads .push_back (state_source_idx (seq_id, prev_start + j));
284315 }
285316 for (uint32_t j = 0 ; j < ratio; ++j) {
286- plan. state_read_idxs .push_back (state_source_idx (seq_id, source_start + j));
317+ overlap_cur_reads .push_back (state_source_idx (seq_id, source_start + j));
287318 }
288319 } else {
289320 for (uint32_t j = 0 ; j < ratio; ++j) {
@@ -292,14 +323,34 @@ static llama_kv_cache_dsv4_context::comp_plan dsv4_build_comp_plan(
292323 }
293324 }
294325
326+ if (overlap) {
327+ // [ all blocks' prev-window indices | all blocks' cur-window indices ]
328+ plan.state_read_idxs .reserve (overlap_prev_reads.size () + overlap_cur_reads.size ());
329+ plan.state_read_idxs .insert (plan.state_read_idxs .end (),
330+ overlap_prev_reads.begin (), overlap_prev_reads.end ());
331+ plan.state_read_idxs .insert (plan.state_read_idxs .end (),
332+ overlap_cur_reads.begin (), overlap_cur_reads.end ());
333+ }
334+
335+ std::sort (persist_rows.begin (), persist_rows.end (),
336+ [](const persist_row & a, const persist_row & b) {
337+ return a.dst < b.dst ;
338+ });
339+
340+ for (const persist_row & row : persist_rows) {
341+ plan.state_persist_src_idxs .push_back (row.src );
342+ plan.state_persist_dst_idxs .push_back (row.dst );
343+ }
344+
295345 static const bool debug = []() {
296346 const char * env = getenv (" LLAMA_DSV4_COMPRESS_DEBUG" );
297347 return env && atoi (env) > 0 ;
298348 }();
299349
300350 if (debug) {
301- LLAMA_LOG_INFO (" %s: ratio=%u, n_tokens=%u, state_write_end=%s\n " ,
351+ LLAMA_LOG_INFO (" %s: ratio=%u, n_tokens=%u, state_persist_dst=%s, state_write_end=%s\n " ,
302352 __func__, ratio, ubatch.n_tokens ,
353+ dsv4_plan_positions (plan.state_persist_dst_idxs ).c_str (),
303354 dsv4_plan_positions (plan.state_write_end ).c_str ());
304355 }
305356
@@ -668,6 +719,12 @@ llama_kv_cache_dsv4::llama_kv_cache_dsv4(
668719 lid_state = std::make_unique<llama_dsv4_comp_state>(
669720 model, offload, unified, n_seq_max, DSV4_CSA_RATIO, 2 *DSV4_CSA_RATIO,
670721 2 *model.hparams .indexer_head_size , " lid" , filter_csa);
722+
723+ // DSV4 attention reads compressed-K / compressor-state rows that the current
724+ // graph does not necessarily overwrite; uninitialized buffer contents would
725+ // otherwise leak in (instance-specific garbage) and corrupt recall. Zero all
726+ // compressed buffers up front so reads of un-written rows are deterministic.
727+ clear_compressed (true );
671728}
672729
673730llama_memory_context_ptr llama_kv_cache_dsv4::init_batch (
@@ -766,7 +823,7 @@ void llama_kv_cache_dsv4::clear(bool data) {
766823 restored_trim_pos.clear ();
767824
768825 kv_raw->clear (data);
769- clear_compressed (data);
826+ clear_compressed (true ); // DSV4 compressed buffers must never expose stale/uninit rows
770827}
771828
772829bool llama_kv_cache_dsv4::seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
@@ -808,7 +865,7 @@ bool llama_kv_cache_dsv4::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1
808865 restored_trim_pos.clear ();
809866 }
810867
811- clear_compressed (false );
868+ clear_compressed (true );
812869 }
813870
814871 return res;
@@ -818,28 +875,28 @@ void llama_kv_cache_dsv4::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_ds
818875 restored_trim_pos.clear ();
819876
820877 kv_raw->seq_cp (seq_id_src, seq_id_dst, p0, p1);
821- clear_compressed (false );
878+ clear_compressed (true );
822879}
823880
// seq_keep: clears restored_trim_pos, forwards seq_id to kv_raw->seq_keep(),
// then resets the compressed side via clear_compressed().
// NOTE: this listing is a diff view -- the `-` line is the call being removed
// and the `+` line is its replacement.
824881void llama_kv_cache_dsv4::seq_keep (llama_seq_id seq_id) {
825882 restored_trim_pos.clear ();
826883
827884 kv_raw->seq_keep (seq_id);
// Argument switched from false to true; presumably `true` also zeroes the
// buffer contents rather than only resetting bookkeeping -- TODO confirm
// against the clear_compressed() definition.
828- clear_compressed (false );
885+ clear_compressed (true );
829886}
830887
// seq_add: clears restored_trim_pos, forwards (seq_id, p0, p1, shift) unchanged
// to kv_raw->seq_add(), then resets the compressed side via clear_compressed().
// NOTE: this listing is a diff view -- the `-` line is the call being removed
// and the `+` line is its replacement.
831888void llama_kv_cache_dsv4::seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
832889 restored_trim_pos.clear ();
833890
834891 kv_raw->seq_add (seq_id, p0, p1, shift);
// Argument switched from false to true; presumably `true` also zeroes the
// buffer contents rather than only resetting bookkeeping -- TODO confirm
// against the clear_compressed() definition.
835- clear_compressed (false );
892+ clear_compressed (true );
836893}
837894
// seq_div: clears restored_trim_pos, forwards (seq_id, p0, p1, d) unchanged
// to kv_raw->seq_div(), then resets the compressed side via clear_compressed().
// NOTE: this listing is a diff view -- the `-` line is the call being removed
// and the `+` line is its replacement.
838895void llama_kv_cache_dsv4::seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
839896 restored_trim_pos.clear ();
840897
841898 kv_raw->seq_div (seq_id, p0, p1, d);
// Argument switched from false to true; presumably `true` also zeroes the
// buffer contents rather than only resetting bookkeeping -- TODO confirm
// against the clear_compressed() definition.
842- clear_compressed (false );
899+ clear_compressed (true );
843900}
844901
845902llama_pos llama_kv_cache_dsv4::seq_pos_min (llama_seq_id seq_id) const {
0 commit comments