Skip to content

Commit 415c6d0

Browse files
committed
enable graph reuse
1 parent 5a2bbea commit 415c6d0

4 files changed

Lines changed: 105 additions & 56 deletions

File tree

scripts/gen-chat-inline-templates.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,17 +84,22 @@ def main() -> None:
8484
arch, rel = match.group(1), match.group(2)
8585
# read the template verbatim (no newline translation) so the embedded
8686
# string is a byte-for-byte copy of the source .jinja file
87-
content = (repo_root / rel).read_text(encoding="utf-8", newline="")
87+
# Path.read_text() only grew a newline param in python 3.13
88+
with open(repo_root / rel, encoding="utf-8", newline="") as f:
89+
content = f.read()
8890
entries.append((arch, rel, content))
8991

9092
text = render(entries)
9193

9294
output = Path(args.output)
9395
# write only when the content changes to avoid spurious rebuilds
94-
if output.exists() and output.read_text(encoding="utf-8", newline="") == text:
95-
return
96+
if output.exists():
97+
with open(output, encoding="utf-8", newline="") as f:
98+
if f.read() == text:
99+
return
96100
output.parent.mkdir(parents=True, exist_ok=True)
97-
output.write_text(text, encoding="utf-8", newline="")
101+
with open(output, "w", encoding="utf-8", newline="") as f:
102+
f.write(text)
98103

99104

100105
if __name__ == "__main__":

src/llama-kv-cache-dsv4.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
static constexpr uint32_t DSV4_CSA_RATIO = 4;
1919
static constexpr uint32_t DSV4_HCA_RATIO = 128;
20+
static constexpr uint32_t DSV4_CSA_GRAPH_RAW_BUCKET = DSV4_HCA_RATIO;
2021

2122
static constexpr uint32_t DSV4_STATE_MAGIC = 0x34565344; // DSV4
2223
static constexpr uint32_t DSV4_STATE_VERSION = 1;
@@ -226,6 +227,7 @@ static llama_kv_cache_dsv4_context::comp_plan dsv4_build_comp_plan(
226227
};
227228

228229
std::vector<persist_row> persist_rows;
230+
llama_pos max_pos = -1;
229231

230232
// For the overlap compressor, build_overlap_compressed_kv_from_state() consumes
231233
// state_read_idxs as two contiguous halves: the first ratio*n_blocks entries are
@@ -272,6 +274,7 @@ static llama_kv_cache_dsv4_context::comp_plan dsv4_build_comp_plan(
272274
}
273275

274276
const llama_seq_id seq_id = ubatch.seq_id[i][0];
277+
max_pos = std::max(max_pos, pos);
275278

276279
const int64_t stream_off = n_stream > 1 ? (int64_t) seq_id*state_size : 0;
277280

@@ -323,6 +326,36 @@ static llama_kv_cache_dsv4_context::comp_plan dsv4_build_comp_plan(
323326
}
324327
}
325328

329+
if (ratio == DSV4_CSA_RATIO && plan.state_write_idxs.empty() && !plan.state_idxs.empty()) {
330+
assert(kv_size > 0);
331+
332+
uint32_t i = 0;
333+
while (i < ubatch.n_tokens && ubatch.pos[i] < 0) {
334+
++i;
335+
}
336+
assert(i < ubatch.n_tokens);
337+
338+
const llama_pos pos = ubatch.pos[i];
339+
const llama_seq_id seq_id = ubatch.seq_id[i][0];
340+
const int64_t cache_off = n_stream > 1 && seq_id >= 0 ? (int64_t) seq_id*kv_size : 0;
341+
const int32_t source_idx = state_source_idx(seq_id, pos);
342+
343+
plan.state_write_idxs.push_back(cache_off + kv_size - 1);
344+
plan.state_write_pos .push_back(0);
345+
plan.state_write_end .push_back(-1);
346+
347+
if (overlap) {
348+
for (uint32_t j = 0; j < ratio; ++j) {
349+
overlap_prev_reads.push_back(source_idx);
350+
overlap_cur_reads .push_back(source_idx);
351+
}
352+
} else {
353+
for (uint32_t j = 0; j < ratio; ++j) {
354+
plan.state_read_idxs.push_back(source_idx);
355+
}
356+
}
357+
}
358+
326359
if (overlap) {
327360
// [ all blocks' prev-window indices | all blocks' cur-window indices ]
328361
plan.state_read_idxs.reserve(overlap_prev_reads.size() + overlap_cur_reads.size());
@@ -332,6 +365,19 @@ static llama_kv_cache_dsv4_context::comp_plan dsv4_build_comp_plan(
332365
overlap_cur_reads.begin(), overlap_cur_reads.end());
333366
}
334367

368+
if (ratio == DSV4_CSA_RATIO && max_pos >= 0) {
369+
const int64_t raw_bucket = DSV4_CSA_GRAPH_RAW_BUCKET;
370+
const int64_t pos_p1 = max_pos + 1;
371+
int64_t n_raw_buckets = (pos_p1 + raw_bucket - 1)/raw_bucket;
372+
if (pos_p1 % raw_bucket == 0) {
373+
++n_raw_buckets;
374+
}
375+
376+
const int64_t bucketed_tokens = n_raw_buckets * raw_bucket;
377+
const int64_t bucketed_n_kv = (bucketed_tokens + ratio - 1)/ratio;
378+
plan.n_kv = std::min<int64_t>(kv_size, std::max<int64_t>(plan.n_kv, bucketed_n_kv));
379+
}
380+
335381
std::sort(persist_rows.begin(), persist_rows.end(),
336382
[](const persist_row & a, const persist_row & b) {
337383
return a.dst < b.dst;

src/llama-kv-cache-dsv4.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ class llama_kv_cache_dsv4_context : public llama_memory_context_i {
175175
std::vector<int32_t> state_read_idxs;
176176

177177
// Final compressed-cache row ids written by state-backed commits.
178+
// A non-boundary CSA/LID decode step can target a masked scratch row.
178179
std::vector<int64_t> state_write_idxs;
179180

180181
// RoPE positions for state-backed commits.
@@ -186,7 +187,8 @@ class llama_kv_cache_dsv4_context : public llama_memory_context_i {
186187
// Number of completed compressed rows visible for each query token.
187188
std::vector<int32_t> n_visible;
188189

189-
// Maximum compressed rows visible to this ubatch.
190+
// Graph-width for compressed rows. This can be larger than n_visible
191+
// so masked padding rows do not force a new graph at every CSA block.
190192
int64_t n_kv = 0;
191193
};
192194

src/models/deepseek-v4.cpp

Lines changed: 47 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -894,31 +894,29 @@ ggml_tensor * llama_model_deepseek4::graph::build_attention(
894894
csa_state_score = ggml_add(ctx0, csa_state_score, csa_ape_rows);
895895
cb(csa_state_score, "csa_state_score_ape", il);
896896

897-
ggml_tensor * csa_state_dep = nullptr;
898-
if (inp_dsv4->get_csa().state_write_idxs) {
899-
ggml_tensor * csa_source_kv = ggml_concat(ctx0,
900-
inp_dsv4->mctx->get_csa_state()->get_kv(ctx0, il), csa_state_kv, 1);
901-
ggml_tensor * csa_source_score = ggml_concat(ctx0,
902-
inp_dsv4->mctx->get_csa_state()->get_score(ctx0, il), csa_state_score, 1);
903-
904-
ggml_tensor * kv_comp_csa_state = build_overlap_compressed_kv_from_state(
905-
csa_source_kv,
906-
csa_source_score,
907-
inp_dsv4->get_csa().state_read_idxs,
908-
inp_dsv4->get_csa().state_write_pos,
909-
layer.attn_comp_norm,
910-
DSV4_CSA_RATIO,
911-
n_embd_head,
912-
"csa_state_compress",
913-
il);
914-
915-
ggml_build_forward_expand(gf, inp_dsv4->mctx->get_csa()->cpy_k(ctx0,
916-
kv_comp_csa_state, inp_dsv4->get_csa().state_write_idxs, il));
917-
csa_state_dep = kv_comp_csa_state;
918-
}
897+
GGML_ASSERT(inp_dsv4->get_csa().state_write_idxs);
898+
899+
ggml_tensor * csa_source_kv = ggml_concat(ctx0,
900+
inp_dsv4->mctx->get_csa_state()->get_kv(ctx0, il), csa_state_kv, 1);
901+
ggml_tensor * csa_source_score = ggml_concat(ctx0,
902+
inp_dsv4->mctx->get_csa_state()->get_score(ctx0, il), csa_state_score, 1);
903+
904+
ggml_tensor * kv_comp_csa_state = build_overlap_compressed_kv_from_state(
905+
csa_source_kv,
906+
csa_source_score,
907+
inp_dsv4->get_csa().state_read_idxs,
908+
inp_dsv4->get_csa().state_write_pos,
909+
layer.attn_comp_norm,
910+
DSV4_CSA_RATIO,
911+
n_embd_head,
912+
"csa_state_compress",
913+
il);
919914

920-
csa_state_kv = dsv4_with_zero_dep(ctx0, csa_state_kv, csa_state_dep);
921-
csa_state_score = dsv4_with_zero_dep(ctx0, csa_state_score, csa_state_dep);
915+
ggml_build_forward_expand(gf, inp_dsv4->mctx->get_csa()->cpy_k(ctx0,
916+
kv_comp_csa_state, inp_dsv4->get_csa().state_write_idxs, il));
917+
918+
csa_state_kv = dsv4_with_zero_dep(ctx0, csa_state_kv, kv_comp_csa_state);
919+
csa_state_score = dsv4_with_zero_dep(ctx0, csa_state_score, kv_comp_csa_state);
922920

923921
ggml_tensor * csa_persist_kv = ggml_get_rows(ctx0, csa_state_kv, inp_dsv4->get_csa().state_persist_src_idxs);
924922
ggml_tensor * csa_persist_score = ggml_get_rows(ctx0, csa_state_score, inp_dsv4->get_csa().state_persist_src_idxs);
@@ -946,36 +944,34 @@ ggml_tensor * llama_model_deepseek4::graph::build_attention(
946944
lid_state_score = ggml_add(ctx0, lid_state_score, lid_ape_rows);
947945
cb(lid_state_score, "lid_state_score_ape", il);
948946

949-
ggml_tensor * lid_state_dep = nullptr;
950-
if (inp_dsv4->get_lid().state_write_idxs) {
951-
ggml_tensor * lid_source_kv = ggml_concat(ctx0,
952-
inp_dsv4->mctx->get_lid_state()->get_kv(ctx0, il), lid_state_kv, 1);
953-
ggml_tensor * lid_source_score = ggml_concat(ctx0,
954-
inp_dsv4->mctx->get_lid_state()->get_score(ctx0, il), lid_state_score, 1);
955-
956-
ggml_tensor * kv_comp_lid_state = build_overlap_compressed_kv_from_state(
957-
lid_source_kv,
958-
lid_source_score,
959-
inp_dsv4->get_lid().state_read_idxs,
960-
inp_dsv4->get_lid().state_write_pos,
961-
layer.indexer_comp_norm,
962-
DSV4_CSA_RATIO,
963-
hparams.indexer_head_size,
964-
"lid_state_compress",
965-
il);
966-
967-
if (inp_dsv4->get_lid().k_rot) {
968-
kv_comp_lid_state = ggml_mul_mat(ctx0, inp_dsv4->get_lid().k_rot, kv_comp_lid_state);
969-
cb(kv_comp_lid_state, "lid_state_compress_rot", il);
970-
}
947+
GGML_ASSERT(inp_dsv4->get_lid().state_write_idxs);
948+
949+
ggml_tensor * lid_source_kv = ggml_concat(ctx0,
950+
inp_dsv4->mctx->get_lid_state()->get_kv(ctx0, il), lid_state_kv, 1);
951+
ggml_tensor * lid_source_score = ggml_concat(ctx0,
952+
inp_dsv4->mctx->get_lid_state()->get_score(ctx0, il), lid_state_score, 1);
953+
954+
ggml_tensor * kv_comp_lid_state = build_overlap_compressed_kv_from_state(
955+
lid_source_kv,
956+
lid_source_score,
957+
inp_dsv4->get_lid().state_read_idxs,
958+
inp_dsv4->get_lid().state_write_pos,
959+
layer.indexer_comp_norm,
960+
DSV4_CSA_RATIO,
961+
hparams.indexer_head_size,
962+
"lid_state_compress",
963+
il);
971964

972-
ggml_build_forward_expand(gf, inp_dsv4->mctx->get_lid()->cpy_k(ctx0,
973-
kv_comp_lid_state, inp_dsv4->get_lid().state_write_idxs, il));
974-
lid_state_dep = kv_comp_lid_state;
965+
if (inp_dsv4->get_lid().k_rot) {
966+
kv_comp_lid_state = ggml_mul_mat(ctx0, inp_dsv4->get_lid().k_rot, kv_comp_lid_state);
967+
cb(kv_comp_lid_state, "lid_state_compress_rot", il);
975968
}
976969

977-
lid_state_kv = dsv4_with_zero_dep(ctx0, lid_state_kv, lid_state_dep);
978-
lid_state_score = dsv4_with_zero_dep(ctx0, lid_state_score, lid_state_dep);
970+
ggml_build_forward_expand(gf, inp_dsv4->mctx->get_lid()->cpy_k(ctx0,
971+
kv_comp_lid_state, inp_dsv4->get_lid().state_write_idxs, il));
972+
973+
lid_state_kv = dsv4_with_zero_dep(ctx0, lid_state_kv, kv_comp_lid_state);
974+
lid_state_score = dsv4_with_zero_dep(ctx0, lid_state_score, kv_comp_lid_state);
979975

980976
ggml_tensor * lid_persist_kv = ggml_get_rows(ctx0, lid_state_kv, inp_dsv4->get_lid().state_persist_src_idxs);
981977
ggml_tensor * lid_persist_score = ggml_get_rows(ctx0, lid_state_score, inp_dsv4->get_lid().state_persist_src_idxs);

0 commit comments

Comments
 (0)