Skip to content

Commit 4caa0a4

Browse files
committed
Fix KVarN workspace store and compact checkpoints
Read split workspace flush groups per token so the CUDA KVarN store path no longer pulls stale stage tails at ubatch boundaries. Serialize KVarN partial checkpoints as stage-only overlays while preserving full record serialization for portable/full cache state, matching the compact partial-state shape used by normal quant caches.
1 parent 1a115e2 commit 4caa0a4

4 files changed

Lines changed: 50 additions & 14 deletions

File tree

ggml/src/ggml-cuda/kvarn.cu

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,15 +1140,14 @@ static __global__ void kvarn_store_workspace_flush_kernel(
11401140

11411141
const int flush_start = flush_group * KVAR_N_DIM;
11421142
const int stage_base = stream * KVAR_N_DIM * KVAR_N_STAGE_GROUPS;
1143-
const bool from_workspace = flush_start >= start_local && flush_start + KVAR_N_DIM <= end_local;
11441143
float * tile = shared;
11451144
for (int i = threadIdx.x; i < KVAR_N_TILE_VALUES; i += blockDim.x) {
11461145
const int row = i / KVAR_N_DIM;
11471146
const int col = i % KVAR_N_DIM;
11481147
const int token = value ? row : col;
11491148
const int dim = value ? col : row;
1150-
if (from_workspace) {
1151-
const int local_pos = flush_start + token;
1149+
const int local_pos = flush_start + token;
1150+
if (local_pos >= start_local && local_pos < end_local) {
11521151
const int src_token = token_base + local_pos - start_local;
11531152
tile[i] = __half2float(workspace[((int64_t) src_token * n_heads + head) * KVAR_N_DIM + dim]);
11541153
} else {

src/llama-kv-cache-kvarn.cpp

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ namespace {
1919
constexpr uint32_t KVAR_N_GROUP = 128;
2020
constexpr uint32_t KVAR_N_STAGE_GROUPS = 3;
2121
constexpr uint32_t KVAR_N_STATE_MAGIC = 0x4e52564b; // "KVRN"
22-
constexpr uint32_t KVAR_N_STATE_VERSION = 3;
22+
constexpr uint32_t KVAR_N_STATE_VERSION = 4;
23+
constexpr uint32_t KVAR_N_STATE_RECORDS_FULL = 0;
24+
constexpr uint32_t KVAR_N_STATE_STAGE_ONLY_PARTIAL = 1;
2325

2426
bool kvarn_backend_supports_native_ops(ggml_backend_dev_t dev) {
2527
if (dev == nullptr || ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) {
@@ -694,6 +696,7 @@ bool llama_kv_cache_kvarn::apply_pending_stream_copies(llama_context * lctx) {
694696

695697
void llama_kv_cache_kvarn::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
696698
metadata->state_write(io, seq_id, flags);
699+
const bool partial_state = (flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) != 0 && seq_id >= 0;
697700

698701
std::vector<uint32_t> saved_streams;
699702
if (seq_id == -1) {
@@ -718,6 +721,8 @@ void llama_kv_cache_kvarn::state_write(llama_io_write_i & io, llama_seq_id seq_i
718721
for (const uint32_t stream : saved_streams) {
719722
io.write(&stream, sizeof(stream));
720723
}
724+
const uint32_t state_kind = partial_state ? KVAR_N_STATE_STAGE_ONLY_PARTIAL : KVAR_N_STATE_RECORDS_FULL;
725+
io.write(&state_kind, sizeof(state_kind));
721726

722727
// n_groups_used is single-valued across all saved streams. This is correct
723728
// because when seq_id >= 0, saved_streams has exactly 1 entry (the stream
@@ -743,10 +748,12 @@ void llama_kv_cache_kvarn::state_write(llama_io_write_i & io, llama_seq_id seq_i
743748
for (const uint32_t stream : saved_streams) {
744749
io.write(&stream, sizeof(stream));
745750

746-
const size_t k_records_used = n_groups_used * layer.k_records_stream[stream]->nb[2];
747-
const size_t v_records_used = n_groups_used * layer.v_records_stream[stream]->nb[2];
748-
write_kvarn_tensor_slice(io, layer.k_records_stream[stream], 0, k_records_used);
749-
write_kvarn_tensor_slice(io, layer.v_records_stream[stream], 0, v_records_used);
751+
if (state_kind == KVAR_N_STATE_RECORDS_FULL) {
752+
const size_t k_records_used = n_groups_used * layer.k_records_stream[stream]->nb[2];
753+
const size_t v_records_used = n_groups_used * layer.v_records_stream[stream]->nb[2];
754+
write_kvarn_tensor_slice(io, layer.k_records_stream[stream], 0, k_records_used);
755+
write_kvarn_tensor_slice(io, layer.v_records_stream[stream], 0, v_records_used);
756+
}
750757
write_kvarn_tensor(io, layer.k_stage_stream[stream]);
751758
write_kvarn_tensor(io, layer.v_stage_stream[stream]);
752759
}
@@ -798,6 +805,14 @@ void llama_kv_cache_kvarn::state_read(llama_io_read_i & io, llama_seq_id seq_id,
798805
}
799806
}
800807

808+
uint32_t state_kind = KVAR_N_STATE_RECORDS_FULL;
809+
if (version >= 4) {
810+
io.read(&state_kind, sizeof(state_kind));
811+
if (state_kind != KVAR_N_STATE_RECORDS_FULL && state_kind != KVAR_N_STATE_STAGE_ONLY_PARTIAL) {
812+
throw std::runtime_error("invalid KVarN cache state kind");
813+
}
814+
}
815+
801816
uint32_t n_groups_used = n_groups_per_stream;
802817
if (version >= 3) {
803818
io.read(&n_groups_used, sizeof(n_groups_used));
@@ -810,6 +825,11 @@ void llama_kv_cache_kvarn::state_read(llama_io_read_i & io, llama_seq_id seq_id,
810825
if (seq_id != -1 && seq_stream >= n_stream) {
811826
throw std::runtime_error("invalid KVarN sequence stream");
812827
}
828+
if (state_kind == KVAR_N_STATE_STAGE_ONLY_PARTIAL) {
829+
if (seq_id < 0) {
830+
throw std::runtime_error("KVarN stage-only state requires a destination sequence");
831+
}
832+
}
813833

814834
for (const auto & layer : layers) {
815835
uint32_t il;
@@ -833,11 +853,13 @@ void llama_kv_cache_kvarn::state_read(llama_io_read_i & io, llama_seq_id seq_id,
833853
const size_t k_records_total = n_groups_per_stream * layer.k_records_stream[stream_dst]->nb[2];
834854
const size_t v_records_total = n_groups_per_stream * layer.v_records_stream[stream_dst]->nb[2];
835855

836-
read_kvarn_tensor_slice(io, layer.k_records_stream[stream_dst], 0, k_records_used);
837-
zero_kvarn_tensor_range(layer.k_records_stream[stream_dst], k_records_used, k_records_total - k_records_used);
856+
if (state_kind == KVAR_N_STATE_RECORDS_FULL) {
857+
read_kvarn_tensor_slice(io, layer.k_records_stream[stream_dst], 0, k_records_used);
858+
zero_kvarn_tensor_range(layer.k_records_stream[stream_dst], k_records_used, k_records_total - k_records_used);
838859

839-
read_kvarn_tensor_slice(io, layer.v_records_stream[stream_dst], 0, v_records_used);
840-
zero_kvarn_tensor_range(layer.v_records_stream[stream_dst], v_records_used, v_records_total - v_records_used);
860+
read_kvarn_tensor_slice(io, layer.v_records_stream[stream_dst], 0, v_records_used);
861+
zero_kvarn_tensor_range(layer.v_records_stream[stream_dst], v_records_used, v_records_total - v_records_used);
862+
}
841863

842864
read_kvarn_tensor(io, layer.k_stage_stream[stream_dst]);
843865
read_kvarn_tensor(io, layer.v_stage_stream[stream_dst]);

tests/test-dflash-plumbing.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2620,17 +2620,22 @@ int main(int argc, char ** argv) {
26202620
ok &= expect(kv_cache_kvarn_cpp.find("GGML_ABORT(\"KVarN does not support position shifts\")") != std::string::npos &&
26212621
kv_cache_kvarn_cpp.find("GGML_ABORT(\"KVarN does not support position division\")") != std::string::npos,
26222622
"KVarN seq_add/seq_div must fail fast instead of logging and continuing");
2623-
ok &= expect(kv_cache_kvarn_cpp.find("constexpr uint32_t KVAR_N_STATE_VERSION = 3") != std::string::npos &&
2623+
ok &= expect(kv_cache_kvarn_cpp.find("constexpr uint32_t KVAR_N_STATE_VERSION = 4") != std::string::npos &&
2624+
kv_cache_kvarn_cpp.find("KVAR_N_STATE_RECORDS_FULL") != std::string::npos &&
2625+
kv_cache_kvarn_cpp.find("KVAR_N_STATE_STAGE_ONLY_PARTIAL") != std::string::npos &&
2626+
kv_cache_kvarn_cpp.find("const uint32_t state_kind = partial_state ? KVAR_N_STATE_STAGE_ONLY_PARTIAL : KVAR_N_STATE_RECORDS_FULL;") != std::string::npos &&
26242627
kv_cache_kvarn_cpp.find("saved_streams") != std::string::npos &&
26252628
kv_cache_kvarn_cpp.find("metadata->get_stream_for_seq(seq_id)") != std::string::npos &&
26262629
kv_cache_kvarn_cpp.find("layer.k_records_stream[stream]") != std::string::npos &&
26272630
kv_cache_kvarn_cpp.find("layer.v_stage_stream[stream]") != std::string::npos &&
26282631
kv_cache_kvarn_cpp.find("n_groups_used") != std::string::npos &&
26292632
kv_cache_kvarn_cpp.find("write_kvarn_tensor_slice") != std::string::npos &&
26302633
kv_cache_kvarn_cpp.find("read_kvarn_tensor_slice") != std::string::npos &&
2634+
kv_cache_kvarn_cpp.find("if (state_kind == KVAR_N_STATE_RECORDS_FULL)") != std::string::npos &&
2635+
kv_cache_kvarn_cpp.find("if (state_kind == KVAR_N_STATE_STAGE_ONLY_PARTIAL)") != std::string::npos &&
26312636
kv_cache_kvarn_cpp.find("ggml_backend_tensor_memset") != std::string::npos &&
26322637
kv_cache_kvarn_cpp.find("const uint64_t size64 = size") != std::string::npos,
2633-
"KVarN sequence state must serialize stream-scoped tensors with group-range compression");
2638+
"KVarN sequence state must serialize full records compactly and partial checkpoints as stage-only overlays");
26342639
ok &= expect(kv_cache_kvarn_cpp.find("pending_stream_copies") != std::string::npos &&
26352640
kv_cache_kvarn_cpp.find("llama_synchronize(lctx)") != std::string::npos &&
26362641
kv_cache_kvarn_cpp.find("copy_kvarn_stream") != std::string::npos,

tests/test-kvarn.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,16 @@ static void test_store_legacy_parity_gpu() {
661661
}
662662
}
663663

664+
for (int bits : { 2, 3, 4, 5, 6, 8 }) {
665+
for (bool value : { false, true }) {
666+
const std::vector<uint8_t> modern = test_store_records(
667+
backend, bits, value, false, 1, 2, 512, 200, false, true);
668+
const std::vector<uint8_t> legacy = test_store_records(
669+
backend, bits, value, true, 1, 2, 512, 200, false, true);
670+
require(modern == legacy, "KVarN CUDA split workspace store records differ from legacy path");
671+
}
672+
}
673+
664674
for (int bits : { 2, 3, 4, 5, 6, 8 }) {
665675
for (bool value : { false, true }) {
666676
const std::vector<uint8_t> modern = test_store_records(

0 commit comments

Comments
 (0)