Skip to content

Commit 9e00db6

Browse files
committed
remove redundant V cache
1 parent 22676c1 commit 9e00db6

9 files changed

Lines changed: 99 additions & 23 deletions

conversion/deepseek.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ def _e8m0_to_float(scale: Tensor) -> Tensor:
516516
return scale.float()
517517

518518
bits = scale.view(torch.uint8).float()
519-
return torch.pow(torch.tensor(2.0, device=bits.device), bits - 127.0)
519+
return torch.exp2(bits - 127.0)
520520

521521
def _collect_source_dtypes(self) -> None:
522522
for name, gen in self.model_tensors.items():

src/llama-graph.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -566,15 +566,19 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
566566
// base tensors may not be allocated if there are no non-SWA attention layers
567567
if (self_k_idxs && self_k_idxs->buffer) {
568568
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
569-
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
569+
if (self_v_idxs) {
570+
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
571+
}
570572

571573
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
572574
}
573575

574576
// swa tensors may not be allocated if there are no SWA attention layers
575577
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
576578
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
577-
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
579+
if (self_v_idxs_swa) {
580+
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
581+
}
578582

579583
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
580584
}
@@ -2947,8 +2951,6 @@ llm_graph_input_dsv4 * llm_graph_context::build_inp_dsv4() const {
29472951

29482952
{
29492953
inp_raw->self_k_idxs = raw_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
2950-
inp_raw->self_v_idxs = raw_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
2951-
29522954
inp_raw->self_kq_mask = build_attn_inp_kq_mask(ctx0, raw_ctx->get_base(), ubatch, cparams);
29532955
inp_raw->self_kq_mask_cnv = inp_raw->self_kq_mask;
29542956
}
@@ -2957,18 +2959,12 @@ llm_graph_input_dsv4 * llm_graph_context::build_inp_dsv4() const {
29572959
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "DSV4 expects SWA raw cache");
29582960

29592961
inp_raw->self_k_idxs_swa = raw_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
2960-
inp_raw->self_v_idxs_swa = raw_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
2961-
29622962
inp_raw->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, raw_ctx->get_swa(), ubatch, cparams);
29632963
inp_raw->self_kq_mask_swa_cnv = inp_raw->self_kq_mask_swa;
29642964
}
29652965

29662966
inp_raw->self_k_rot = raw_ctx->get_base()->build_input_k_rot(ctx0);
2967-
inp_raw->self_v_rot = raw_ctx->get_base()->build_input_v_rot(ctx0);
2968-
29692967
inp_raw->self_k_rot_swa = raw_ctx->get_swa()->build_input_k_rot(ctx0);
2970-
inp_raw->self_v_rot_swa = raw_ctx->get_swa()->build_input_v_rot(ctx0);
2971-
29722968
auto inp = std::make_unique<llm_graph_input_dsv4>(cparams, std::move(inp_raw), mctx_cur);
29732969

29742970
dsv4_build_comp_inputs(ctx0, inp->inp_csa, mctx_cur->get_csa_plan(), "csa");

src/llama-kv-cache-dsv4.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,7 @@ llama_kv_cache_dsv4::llama_kv_cache_dsv4(
632632
uint32_t n_pad,
633633
const layer_filter_cb & filter,
634634
const layer_reuse_cb & reuse) :
635+
hparams_raw(model.hparams),
635636
hparams_csa(model.hparams),
636637
hparams_hca(model.hparams),
637638
hparams_lid(model.hparams) {
@@ -646,8 +647,10 @@ llama_kv_cache_dsv4::llama_kv_cache_dsv4(
646647

647648
LLAMA_LOG_INFO("%s: creating DSV4 raw KV cache\n", __func__);
648649

650+
dsv4_make_k_only(hparams_raw);
651+
649652
kv_raw = std::make_unique<llama_kv_cache_iswa>(
650-
model, type_k, type_v,
653+
model, hparams_raw, type_k, type_v,
651654
v_trans, offload, swa_full, unified, kv_size, n_seq_max, n_ubatch, n_pad,
652655
filter_raw, reuse);
653656

src/llama-kv-cache-dsv4.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class llama_kv_cache_dsv4 : public llama_memory_i {
131131
llama_dsv4_comp_state * get_lid_state() const;
132132

133133
private:
134+
llama_hparams hparams_raw;
134135
llama_hparams hparams_csa;
135136
llama_hparams hparams_hca;
136137
llama_hparams hparams_lid;

src/llama-kv-cache-iswa.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,26 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
2424
uint32_t n_ubatch,
2525
uint32_t n_pad,
2626
const layer_filter_cb & filter,
27-
const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
27+
const layer_reuse_cb & reuse) :
28+
llama_kv_cache_iswa(model, model.hparams, type_k, type_v, v_trans, offload, swa_full, unified,
29+
kv_size, n_seq_max, n_ubatch, n_pad, filter, reuse) {
30+
}
31+
32+
llama_kv_cache_iswa::llama_kv_cache_iswa(
33+
const llama_model & model,
34+
const llama_hparams & hparams,
35+
ggml_type type_k,
36+
ggml_type type_v,
37+
bool v_trans,
38+
bool offload,
39+
bool swa_full,
40+
bool unified,
41+
uint32_t kv_size,
42+
uint32_t n_seq_max,
43+
uint32_t n_ubatch,
44+
uint32_t n_pad,
45+
const layer_filter_cb & filter,
46+
const layer_reuse_cb & reuse) : hparams(hparams), unified(unified) {
2847

2948
// chain filters
3049
const layer_filter_cb filter_base = [&](int32_t il) {

src/llama-kv-cache-iswa.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,22 @@ class llama_kv_cache_iswa : public llama_memory_i {
2828
const layer_filter_cb & filter,
2929
const layer_reuse_cb & reuse);
3030

31+
llama_kv_cache_iswa(
32+
const llama_model & model,
33+
const llama_hparams & hparams,
34+
ggml_type type_k,
35+
ggml_type type_v,
36+
bool v_trans,
37+
bool offload,
38+
bool swa_full,
39+
bool unified,
40+
uint32_t kv_size,
41+
uint32_t n_seq_max,
42+
uint32_t n_ubatch,
43+
uint32_t n_pad,
44+
const layer_filter_cb & filter,
45+
const layer_reuse_cb & reuse);
46+
3147
~llama_kv_cache_iswa() = default;
3248

3349
//

src/llama-kv-cache.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,12 @@ llama_kv_cache::llama_kv_cache(
177177
n_embd_head_k_all = -1;
178178
}
179179

180-
if (n_embd_head_v_all == 0) {
181-
n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il);
182-
} else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) {
183-
n_embd_head_v_all = -1;
180+
if (!is_mla) {
181+
if (n_embd_head_v_all == 0) {
182+
n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il);
183+
} else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) {
184+
n_embd_head_v_all = -1;
185+
}
184186
}
185187

186188
// [TAG_V_CACHE_VARIABLE]

src/models/deepseek-v4.cpp

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,6 @@ ggml_tensor * llama_model_deepseek_v4_flash::graph::build_csa_lid_attention(
651651
const llama_kv_cache_context * mctx_swa = inp_attn->mctx->get_swa();
652652

653653
ggml_build_forward_expand(gf, mctx_swa->cpy_k(ctx0, kv, inp_attn->get_k_idxs_swa(), il));
654-
ggml_build_forward_expand(gf, mctx_swa->cpy_v(ctx0, kv, inp_attn->get_v_idxs_swa(), il));
655654

656655
ggml_tensor * raw_k = mctx_swa->get_k(ctx0, il);
657656
if (raw_k->type != GGML_TYPE_F32) {
@@ -709,7 +708,6 @@ ggml_tensor * llama_model_deepseek_v4_flash::graph::build_hca_attention(
709708
const llama_kv_cache_context * mctx_swa = inp_attn->mctx->get_swa();
710709

711710
ggml_build_forward_expand(gf, mctx_swa->cpy_k(ctx0, kv, inp_attn->get_k_idxs_swa(), il));
712-
ggml_build_forward_expand(gf, mctx_swa->cpy_v(ctx0, kv, inp_attn->get_v_idxs_swa(), il));
713711

714712
ggml_tensor * raw_k = mctx_swa->get_k(ctx0, il);
715713
if (raw_k->type != GGML_TYPE_F32) {
@@ -748,6 +746,42 @@ ggml_tensor * llama_model_deepseek_v4_flash::graph::build_hca_attention(
748746
return out;
749747
}
750748

749+
ggml_tensor * llama_model_deepseek_v4_flash::graph::build_raw_attention(
750+
llm_graph_input_attn_kv_iswa * inp_attn,
751+
ggml_tensor * q,
752+
ggml_tensor * kv,
753+
ggml_tensor * sinks,
754+
float kq_scale,
755+
int il) const {
756+
const bool is_swa = hparams.is_swa(il);
757+
758+
ggml_tensor * k_rot = is_swa ? inp_attn->self_k_rot_swa : inp_attn->self_k_rot;
759+
ggml_tensor * v_rot = is_swa ? inp_attn->self_v_rot_swa : inp_attn->self_v_rot;
760+
GGML_ASSERT(v_rot == nullptr);
761+
762+
if (k_rot) {
763+
q = ggml_mul_mat(ctx0, k_rot, q);
764+
kv = ggml_mul_mat(ctx0, k_rot, kv);
765+
}
766+
767+
ggml_build_forward_expand(gf, q);
768+
ggml_build_forward_expand(gf, kv);
769+
770+
const llama_kv_cache_context * mctx_cur = is_swa ? inp_attn->mctx->get_swa() : inp_attn->mctx->get_base();
771+
const auto & k_idxs = is_swa ? inp_attn->get_k_idxs_swa() : inp_attn->get_k_idxs();
772+
773+
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, kv, k_idxs, il));
774+
775+
const auto & kq_mask = is_swa ? inp_attn->get_kq_mask_swa() : inp_attn->get_kq_mask();
776+
777+
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
778+
779+
ggml_tensor * out = build_attn_mha(q, k, k, nullptr, kq_mask, sinks, nullptr, kq_scale, il);
780+
cb(out, "attn_raw", il);
781+
782+
return out;
783+
}
784+
751785
ggml_tensor * llama_model_deepseek_v4_flash::graph::build_attention(
752786
const llama_model & model,
753787
llm_graph_input_dsv4 * inp_dsv4,
@@ -1021,11 +1055,8 @@ ggml_tensor * llama_model_deepseek_v4_flash::graph::build_attention(
10211055
out = build_hca_attention(inp_dsv4, inp_attn, q, kv, layer.attn_sinks,
10221056
1.0f/sqrtf(float(n_embd_head)), il);
10231057
} else {
1024-
out = build_attn(inp_attn,
1025-
nullptr, nullptr, nullptr,
1026-
q, kv, kv, nullptr, layer.attn_sinks, nullptr,
1058+
out = build_raw_attention(inp_attn, q, kv, layer.attn_sinks,
10271059
1.0f/sqrtf(float(n_embd_head)), il);
1028-
cb(out, "attn_raw", il);
10291060
}
10301061

10311062
out = ggml_reshape_3d(ctx0, out, n_embd_head, n_head, nt);

src/models/models.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,6 +1149,14 @@ struct llama_model_deepseek_v4_flash : public llama_model_base {
11491149
float kq_scale,
11501150
int il) const;
11511151

1152+
ggml_tensor * build_raw_attention(
1153+
llm_graph_input_attn_kv_iswa * inp_attn,
1154+
ggml_tensor * q,
1155+
ggml_tensor * kv,
1156+
ggml_tensor * sinks,
1157+
float kq_scale,
1158+
int il) const;
1159+
11521160
ggml_tensor * build_hc_weighted_sum(
11531161
ggml_tensor * x,
11541162
ggml_tensor * weights) const;

0 commit comments

Comments
 (0)