@@ -209,38 +209,6 @@ llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & ba
209209 return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE );
210210}
211211
212- llama_memory_context_ptr llama_kv_cache_iswa::init_batch_with_sinfos (
213- llama_batch_allocr & balloc,
214- uint32_t n_ubatch,
215- const llama_kv_cache::slot_info_vec_t & sinfos,
216- bool is_inplace_update) {
217- if (sinfos.empty ()) {
218- return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE );
219- }
220-
221- balloc.split_reset ();
222-
223- std::vector<llama_ubatch> ubatches;
224- const uint32_t n_stream = kv_base->get_n_stream ();
225- while (true ) {
226- auto ubatch = n_stream == 1 ? balloc.split_simple (n_ubatch) : balloc.split_equal (n_ubatch, true );
227- if (ubatch.n_tokens == 0 ) {
228- break ;
229- }
230- ubatches.push_back (std::move (ubatch)); // NOLINT
231- }
232-
233- if (ubatches.size () != sinfos.size ()) {
234- return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE );
235- }
236-
237- auto sinfos_base = sinfos;
238- auto sinfos_swa = sinfos;
239-
240- return std::make_unique<llama_kv_cache_iswa_context>(
241- this , std::move (sinfos_base), std::move (sinfos_swa), std::move (ubatches), is_inplace_update);
242- }
243-
244212llama_memory_context_ptr llama_kv_cache_iswa::init_full () {
245213 return std::make_unique<llama_kv_cache_iswa_context>(this );
246214}
@@ -279,6 +247,20 @@ llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
279247 return kv_swa.get ();
280248}
281249
250+ void llama_kv_cache_iswa::set_swa_reuse_guard (llama_pos query_pos) {
251+ kv_base->clear_swa_reuse_guard ();
252+ kv_swa->set_swa_reuse_guard (query_pos);
253+ }
254+
255+ void llama_kv_cache_iswa::clear_swa_reuse_guard () {
256+ kv_base->clear_swa_reuse_guard ();
257+ kv_swa->clear_swa_reuse_guard ();
258+ }
259+
260+ bool llama_kv_cache_iswa::consume_swa_reuse_guard_block_prepare () {
261+ return kv_swa->consume_swa_reuse_guard_block_prepare ();
262+ }
263+
282264//
283265// llama_kv_cache_iswa_context
284266//
@@ -313,19 +295,6 @@ llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
313295 status(llama_memory_status_combine(ctx_base->get_status (), ctx_swa->get_status())) {
314296}
315297
316- llama_kv_cache_iswa_context::llama_kv_cache_iswa_context (
317- llama_kv_cache_iswa * kv,
318- slot_info_vec_t sinfos_base,
319- slot_info_vec_t sinfos_swa,
320- std::vector<llama_ubatch> ubatches,
321- bool is_inplace_update) :
322- ubatches(std::move(ubatches)),
323- // note: here we copy the ubatches. not sure if this is ideal
324- ctx_base(new llama_kv_cache_context(kv->get_base (), std::move(sinfos_base), this->ubatches, is_inplace_update)),
325- ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches, is_inplace_update)),
326- status(llama_memory_status_combine(ctx_base->get_status (), ctx_swa->get_status())) {
327- }
328-
329298llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context () = default ;
330299
331300bool llama_kv_cache_iswa_context::next () {
@@ -373,12 +342,3 @@ const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa() const {
373342
374343 return static_cast <const llama_kv_cache_context *>(ctx_swa.get ());
375344}
376-
377- void llama_kv_cache_iswa_context::set_inplace (bool value) {
378- auto * base = const_cast <llama_kv_cache_context *>(
379- static_cast <const llama_kv_cache_context *>(ctx_base.get ()));
380- auto * swa = const_cast <llama_kv_cache_context *>(
381- static_cast <const llama_kv_cache_context *>(ctx_swa.get ()));
382- if (base) { base->set_inplace (value); }
383- if (swa) { swa ->set_inplace (value); }
384- }
0 commit comments