@@ -210,7 +210,6 @@ static llama_kv_cache_dsv4_context::comp_plan dsv4_build_comp_plan(
210210 const llama_ubatch & ubatch,
211211 uint32_t ratio,
212212 bool overlap,
213- bool stateful,
214213 uint32_t state_size,
215214 uint32_t kv_size,
216215 uint32_t n_stream) {
@@ -256,12 +255,10 @@ static llama_kv_cache_dsv4_context::comp_plan dsv4_build_comp_plan(
256255
257256 const llama_seq_id seq_id = ubatch.seq_id [i][0 ];
258257
259- if (stateful) {
260- const int64_t stream_off = n_stream > 1 ? (int64_t ) seq_id*state_size : 0 ;
258+ const int64_t stream_off = n_stream > 1 ? (int64_t ) seq_id*state_size : 0 ;
261259
262- plan.state_idxs .push_back ((int32_t ) (stream_off + pos%state_size));
263- plan.state_pos .push_back ((int32_t ) (pos%ratio));
264- }
260+ plan.state_idxs .push_back ((int32_t ) (stream_off + pos%state_size));
261+ plan.state_pos .push_back ((int32_t ) (pos%ratio));
265262
266263 const int64_t n_visible = (int64_t ) (pos + 1 )/ratio;
267264 plan.n_visible [i] = (int32_t ) n_visible;
@@ -273,36 +270,26 @@ static llama_kv_cache_dsv4_context::comp_plan dsv4_build_comp_plan(
273270
274271 const llama_pos source_start = pos + 1 - ratio;
275272
276- if (stateful) {
277- const int64_t cache_off = n_stream > 1 ? (int64_t ) seq_id*kv_size : 0 ;
273+ const int64_t cache_off = n_stream > 1 ? (int64_t ) seq_id*kv_size : 0 ;
278274
279- plan.state_write_idxs .push_back (cache_off + pos/ratio);
280- plan.state_write_pos .push_back ((int32_t ) source_start);
281- plan.state_write_end .push_back ((int32_t ) pos);
275+ plan.state_write_idxs .push_back (cache_off + pos/ratio);
276+ plan.state_write_pos .push_back ((int32_t ) source_start);
277+ plan.state_write_end .push_back ((int32_t ) pos);
282278
283- if (overlap) {
284- const llama_pos prev_start = source_start - ratio;
279+ if (overlap) {
280+ const llama_pos prev_start = source_start - ratio;
285281
286- for (uint32_t j = 0 ; j < ratio; ++j) {
287- plan.state_read_idxs .push_back (state_source_idx (seq_id, prev_start + j));
288- }
289- for (uint32_t j = 0 ; j < ratio; ++j) {
290- plan.state_read_idxs .push_back (state_source_idx (seq_id, source_start + j));
291- }
292- } else {
293- for (uint32_t j = 0 ; j < ratio; ++j) {
294- plan.state_read_idxs .push_back (state_source_idx (seq_id, source_start + j));
295- }
282+ for (uint32_t j = 0 ; j < ratio; ++j) {
283+ plan.state_read_idxs .push_back (state_source_idx (seq_id, prev_start + j));
284+ }
285+ for (uint32_t j = 0 ; j < ratio; ++j) {
286+ plan.state_read_idxs .push_back (state_source_idx (seq_id, source_start + j));
287+ }
288+ } else {
289+ for (uint32_t j = 0 ; j < ratio; ++j) {
290+ plan.state_read_idxs .push_back (state_source_idx (seq_id, source_start + j));
296291 }
297-
298- continue ;
299292 }
300-
301- const int64_t stream_off = n_stream > 1 ? (int64_t ) seq_id*kv_size : 0 ;
302-
303- plan.write_idxs .push_back (stream_off + pos/ratio);
304- plan.write_pos .push_back ((int32_t ) (pos + 1 - ratio));
305- plan.write_end .push_back ((int32_t ) pos);
306293 }
307294
308295 static const bool debug = []() {
@@ -311,11 +298,9 @@ static llama_kv_cache_dsv4_context::comp_plan dsv4_build_comp_plan(
311298 }();
312299
313300 if (debug) {
314- LLAMA_LOG_INFO (" %s: ratio=%u, n_tokens=%u, write_end=%s, state_write_end=%s, pending_end =%s\n " ,
301+ LLAMA_LOG_INFO (" %s: ratio=%u, n_tokens=%u, state_write_end=%s\n " ,
315302 __func__, ratio, ubatch.n_tokens ,
316- dsv4_plan_positions (plan.write_end ).c_str (),
317- dsv4_plan_positions (plan.state_write_end ).c_str (),
318- dsv4_plan_positions (plan.pending_end ).c_str ());
303+ dsv4_plan_positions (plan.state_write_end ).c_str ());
319304 }
320305
321306 return plan;
@@ -325,15 +310,14 @@ static std::vector<llama_kv_cache_dsv4_context::comp_plan> dsv4_build_comp_plans
325310 const std::vector<llama_ubatch> & ubatches,
326311 uint32_t ratio,
327312 bool overlap,
328- bool stateful,
329313 uint32_t state_size,
330314 uint32_t kv_size,
331315 uint32_t n_stream) {
332316 std::vector<llama_kv_cache_dsv4_context::comp_plan> plans;
333317 plans.reserve (ubatches.size ());
334318
335319 for (const llama_ubatch & ubatch : ubatches) {
336- plans.push_back (dsv4_build_comp_plan (ubatch, ratio, overlap, stateful, state_size, kv_size, n_stream));
320+ plans.push_back (dsv4_build_comp_plan (ubatch, ratio, overlap, state_size, kv_size, n_stream));
337321 }
338322
339323 return plans;
@@ -1023,9 +1007,9 @@ llama_kv_cache_dsv4_context::llama_kv_cache_dsv4_context(
10231007 slot_info_vec_t sinfos_raw_swa,
10241008 std::vector<llama_ubatch> ubatches) :
10251009 ubatches(std::move(ubatches)),
1026- plans_csa(dsv4_build_comp_plans(this ->ubatches, DSV4_CSA_RATIO, true , true ,
1010+ plans_csa(dsv4_build_comp_plans(this ->ubatches, DSV4_CSA_RATIO, true ,
10271011 kv->get_csa_state ()->get_state_size(), kv->get_csa()->get_size(), kv->get_csa_state()->get_n_stream())),
1028- plans_hca(dsv4_build_comp_plans(this ->ubatches, DSV4_HCA_RATIO, false , true ,
1012+ plans_hca(dsv4_build_comp_plans(this ->ubatches, DSV4_HCA_RATIO, false ,
10291013 kv->get_hca_state ()->get_state_size(), kv->get_hca()->get_size(), kv->get_hca_state()->get_n_stream())),
10301014 plans_lid(plans_csa),
10311015 ctx_raw(new llama_kv_cache_iswa_context(kv->get_raw (), std::move(sinfos_raw_base), std::move(sinfos_raw_swa), this->ubatches)),
0 commit comments