@@ -19,7 +19,9 @@ namespace {
1919constexpr uint32_t KVAR_N_GROUP = 128 ;
2020constexpr uint32_t KVAR_N_STAGE_GROUPS = 3 ;
2121constexpr uint32_t KVAR_N_STATE_MAGIC = 0x4e52564b ; // "KVRN"
22- constexpr uint32_t KVAR_N_STATE_VERSION = 3 ;
22+ constexpr uint32_t KVAR_N_STATE_VERSION = 4 ;
23+ constexpr uint32_t KVAR_N_STATE_RECORDS_FULL = 0 ;
24+ constexpr uint32_t KVAR_N_STATE_STAGE_ONLY_PARTIAL = 1 ;
2325
2426bool kvarn_backend_supports_native_ops (ggml_backend_dev_t dev) {
2527 if (dev == nullptr || ggml_backend_dev_type (dev) == GGML_BACKEND_DEVICE_TYPE_CPU ) {
@@ -694,6 +696,7 @@ bool llama_kv_cache_kvarn::apply_pending_stream_copies(llama_context * lctx) {
694696
695697void llama_kv_cache_kvarn::state_write (llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
696698 metadata->state_write (io, seq_id, flags);
699+ const bool partial_state = (flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY ) != 0 && seq_id >= 0 ;
697700
698701 std::vector<uint32_t > saved_streams;
699702 if (seq_id == -1 ) {
@@ -718,6 +721,8 @@ void llama_kv_cache_kvarn::state_write(llama_io_write_i & io, llama_seq_id seq_i
718721 for (const uint32_t stream : saved_streams) {
719722 io.write (&stream, sizeof (stream));
720723 }
724+ const uint32_t state_kind = partial_state ? KVAR_N_STATE_STAGE_ONLY_PARTIAL : KVAR_N_STATE_RECORDS_FULL ;
725+ io.write (&state_kind, sizeof (state_kind));
721726
722727 // n_groups_used is single-valued across all saved streams. This is correct
723728 // because when seq_id >= 0, saved_streams has exactly 1 entry (the stream
@@ -743,10 +748,12 @@ void llama_kv_cache_kvarn::state_write(llama_io_write_i & io, llama_seq_id seq_i
743748 for (const uint32_t stream : saved_streams) {
744749 io.write (&stream, sizeof (stream));
745750
746- const size_t k_records_used = n_groups_used * layer.k_records_stream [stream]->nb [2 ];
747- const size_t v_records_used = n_groups_used * layer.v_records_stream [stream]->nb [2 ];
748- write_kvarn_tensor_slice (io, layer.k_records_stream [stream], 0 , k_records_used);
749- write_kvarn_tensor_slice (io, layer.v_records_stream [stream], 0 , v_records_used);
751+ if (state_kind == KVAR_N_STATE_RECORDS_FULL ) {
752+ const size_t k_records_used = n_groups_used * layer.k_records_stream [stream]->nb [2 ];
753+ const size_t v_records_used = n_groups_used * layer.v_records_stream [stream]->nb [2 ];
754+ write_kvarn_tensor_slice (io, layer.k_records_stream [stream], 0 , k_records_used);
755+ write_kvarn_tensor_slice (io, layer.v_records_stream [stream], 0 , v_records_used);
756+ }
750757 write_kvarn_tensor (io, layer.k_stage_stream [stream]);
751758 write_kvarn_tensor (io, layer.v_stage_stream [stream]);
752759 }
@@ -798,6 +805,14 @@ void llama_kv_cache_kvarn::state_read(llama_io_read_i & io, llama_seq_id seq_id,
798805 }
799806 }
800807
808+ uint32_t state_kind = KVAR_N_STATE_RECORDS_FULL ;
809+ if (version >= 4 ) {
810+ io.read (&state_kind, sizeof (state_kind));
811+ if (state_kind != KVAR_N_STATE_RECORDS_FULL && state_kind != KVAR_N_STATE_STAGE_ONLY_PARTIAL ) {
812+ throw std::runtime_error (" invalid KVarN cache state kind" );
813+ }
814+ }
815+
801816 uint32_t n_groups_used = n_groups_per_stream;
802817 if (version >= 3 ) {
803818 io.read (&n_groups_used, sizeof (n_groups_used));
@@ -810,6 +825,11 @@ void llama_kv_cache_kvarn::state_read(llama_io_read_i & io, llama_seq_id seq_id,
810825 if (seq_id != -1 && seq_stream >= n_stream) {
811826 throw std::runtime_error (" invalid KVarN sequence stream" );
812827 }
828+ if (state_kind == KVAR_N_STATE_STAGE_ONLY_PARTIAL ) {
829+ if (seq_id < 0 ) {
830+ throw std::runtime_error (" KVarN stage-only state requires a destination sequence" );
831+ }
832+ }
813833
814834 for (const auto & layer : layers) {
815835 uint32_t il;
@@ -833,11 +853,13 @@ void llama_kv_cache_kvarn::state_read(llama_io_read_i & io, llama_seq_id seq_id,
833853 const size_t k_records_total = n_groups_per_stream * layer.k_records_stream [stream_dst]->nb [2 ];
834854 const size_t v_records_total = n_groups_per_stream * layer.v_records_stream [stream_dst]->nb [2 ];
835855
836- read_kvarn_tensor_slice (io, layer.k_records_stream [stream_dst], 0 , k_records_used);
837- zero_kvarn_tensor_range (layer.k_records_stream [stream_dst], k_records_used, k_records_total - k_records_used);
856+ if (state_kind == KVAR_N_STATE_RECORDS_FULL ) {
857+ read_kvarn_tensor_slice (io, layer.k_records_stream [stream_dst], 0 , k_records_used);
858+ zero_kvarn_tensor_range (layer.k_records_stream [stream_dst], k_records_used, k_records_total - k_records_used);
838859
839- read_kvarn_tensor_slice (io, layer.v_records_stream [stream_dst], 0 , v_records_used);
840- zero_kvarn_tensor_range (layer.v_records_stream [stream_dst], v_records_used, v_records_total - v_records_used);
860+ read_kvarn_tensor_slice (io, layer.v_records_stream [stream_dst], 0 , v_records_used);
861+ zero_kvarn_tensor_range (layer.v_records_stream [stream_dst], v_records_used, v_records_total - v_records_used);
862+ }
841863
842864 read_kvarn_tensor (io, layer.k_stage_stream [stream_dst]);
843865 read_kvarn_tensor (io, layer.v_stage_stream [stream_dst]);
0 commit comments