Skip to content

Commit 14bb7ae

Browse files
committed
llama: fix quantized kv-cache for dsv4
1 parent 8c146a8 commit 14bb7ae

6 files changed

Lines changed: 94 additions & 66 deletions

File tree

src/llama-graph.cpp

Lines changed: 18 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -63,26 +63,6 @@ static bool can_reuse_kq_mask(
6363

6464
// impl
6565

66-
static ggml_tensor * ggml_mul_mat_aux(
67-
ggml_context * ctx,
68-
ggml_tensor * cur,
69-
ggml_tensor * rot) {
70-
const auto n = rot->ne[0];
71-
72-
ggml_tensor * res;
73-
74-
if (!ggml_is_contiguous(cur)) {
75-
res = ggml_cont_2d (ctx, cur, n, ggml_nelements(cur)/n);
76-
} else {
77-
res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
78-
}
79-
res = ggml_mul_mat (ctx, rot, res);
80-
ggml_mul_mat_set_hint(res, GGML_HINT_SRC0_IS_HADAMARD);
81-
res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
82-
83-
return res;
84-
}
85-
8666
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
8767
if (ubatch->token) {
8868
const int64_t n_tokens = ubatch->n_tokens;
@@ -881,6 +861,14 @@ void llm_graph_input_dsv4::set_input(const llama_ubatch * ubatch) {
881861
dsv4_set_comp_inputs(inp_hca, plan_hca, "hca", debug > 0, ubatch->n_tokens, n_stream);
882862
dsv4_set_comp_inputs(inp_lid, plan_lid, "lid", debug > 0, ubatch->n_tokens, n_stream);
883863

864+
if (inp_csa.k_rot && inp_csa.k_rot->buffer) {
865+
mctx->get_csa()->set_input_k_rot(inp_csa.k_rot);
866+
}
867+
868+
if (inp_hca.k_rot && inp_hca.k_rot->buffer) {
869+
mctx->get_hca()->set_input_k_rot(inp_hca.k_rot);
870+
}
871+
884872
if (inp_lid.k_rot && inp_lid.k_rot->buffer) {
885873
mctx->get_lid()->set_input_k_rot(inp_lid.k_rot);
886874
}
@@ -2633,12 +2621,12 @@ ggml_tensor * llm_graph_context::build_attn(
26332621
GGML_ASSERT(v_mla == nullptr);
26342622

26352623
if (inp->self_k_rot) {
2636-
q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
2637-
k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
2624+
q_cur = llama_mul_mat_hadamard(ctx0, q_cur, inp->self_k_rot);
2625+
k_cur = llama_mul_mat_hadamard(ctx0, k_cur, inp->self_k_rot);
26382626
}
26392627

26402628
if (inp->self_v_rot) {
2641-
v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
2629+
v_cur = llama_mul_mat_hadamard(ctx0, v_cur, inp->self_v_rot);
26422630
}
26432631

26442632
// these nodes are added to the graph together so that they are not reordered
@@ -2669,7 +2657,7 @@ ggml_tensor * llm_graph_context::build_attn(
26692657
cb(cur, "kqv_out", il);
26702658

26712659
if (inp->self_v_rot) {
2672-
cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
2660+
cur = llama_mul_mat_hadamard(ctx0, cur, inp->self_v_rot);
26732661
}
26742662

26752663
if (wo) {
@@ -2874,14 +2862,14 @@ ggml_tensor * llm_graph_context::build_attn(
28742862
auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot;
28752863

28762864
if (k_rot) {
2877-
q_cur = ggml_mul_mat_aux(ctx0, q_cur, k_rot);
2865+
q_cur = llama_mul_mat_hadamard(ctx0, q_cur, k_rot);
28782866
if (k_cur) {
2879-
k_cur = ggml_mul_mat_aux(ctx0, k_cur, k_rot);
2867+
k_cur = llama_mul_mat_hadamard(ctx0, k_cur, k_rot);
28802868
}
28812869
}
28822870
if (v_rot) {
28832871
if (v_cur) {
2884-
v_cur = ggml_mul_mat_aux(ctx0, v_cur, v_rot);
2872+
v_cur = llama_mul_mat_hadamard(ctx0, v_cur, v_rot);
28852873
}
28862874
}
28872875

@@ -2924,7 +2912,7 @@ ggml_tensor * llm_graph_context::build_attn(
29242912
cb(cur, "kqv_out", il);
29252913

29262914
if (v_rot) {
2927-
cur = ggml_mul_mat_aux(ctx0, cur, v_rot);
2915+
cur = llama_mul_mat_hadamard(ctx0, cur, v_rot);
29282916
}
29292917

29302918
if (wo) {
@@ -3084,6 +3072,8 @@ llm_graph_input_dsv4 * llm_graph_context::build_inp_dsv4() const {
30843072
dsv4_build_comp_inputs(ctx0, inp->inp_csa, mctx_cur->get_csa_plan(ubatch), "csa", n_stream);
30853073
dsv4_build_comp_inputs(ctx0, inp->inp_hca, mctx_cur->get_hca_plan(ubatch), "hca", n_stream);
30863074
dsv4_build_comp_inputs(ctx0, inp->inp_lid, mctx_cur->get_lid_plan(ubatch), "lid", n_stream);
3075+
inp->inp_csa.k_rot = mctx_cur->get_csa()->build_input_k_rot(ctx0);
3076+
inp->inp_hca.k_rot = mctx_cur->get_hca()->build_input_k_rot(ctx0);
30873077
inp->inp_lid.k_rot = mctx_cur->get_lid()->build_input_k_rot(ctx0);
30883078

30893079
return (llm_graph_input_dsv4 *) res->add_input(std::move(inp));

src/llama-impl.h

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,26 @@ static inline dst_t llama_cast(src_t v) {
5454
}
5555
}
5656

57+
static inline ggml_tensor * llama_mul_mat_hadamard(
58+
ggml_context * ctx,
59+
ggml_tensor * cur,
60+
ggml_tensor * rot) {
61+
const auto n = rot->ne[0];
62+
63+
ggml_tensor * res;
64+
65+
if (!ggml_is_contiguous(cur)) {
66+
res = ggml_cont_2d(ctx, cur, n, ggml_nelements(cur)/n);
67+
} else {
68+
res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
69+
}
70+
res = ggml_mul_mat(ctx, rot, res);
71+
ggml_mul_mat_set_hint(res, GGML_HINT_SRC0_IS_HADAMARD);
72+
res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
73+
74+
return res;
75+
}
76+
5777
struct time_meas {
5878
time_meas(int64_t & t_acc, bool disable = false);
5979
~time_meas();

src/llama-kv-cache-dsv4.cpp

Lines changed: 14 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,9 @@
1717

1818
static constexpr uint32_t DSV4_CSA_RATIO = 4;
1919
static constexpr uint32_t DSV4_HCA_RATIO = 128;
20+
// [TAG_DSV4_CACHE_PAD]
21+
// matches MATRIX_ROW_PADDING used by backends for quantized row padding
22+
static constexpr uint32_t DSV4_CACHE_PAD = 512;
2023

2124
static constexpr uint32_t DSV4_STATE_MAGIC = 0x34565344; // DSV4
2225
static constexpr uint32_t DSV4_STATE_VERSION = 1;
@@ -519,7 +522,7 @@ static llama_kv_cache_dsv4_context::comp_plan dsv4_build_comp_plan(
519522
overlap_cur_reads.begin(), overlap_cur_reads.end());
520523
}
521524

522-
plan.n_kv = GGML_PAD(plan.n_kv, 256u);
525+
plan.n_kv = GGML_PAD(plan.n_kv, DSV4_CACHE_PAD);
523526

524527
std::sort(persist_rows.begin(), persist_rows.end(),
525528
[](const persist_row & a, const persist_row & b) {
@@ -950,13 +953,16 @@ llama_kv_cache_dsv4::llama_kv_cache_dsv4(
950953
// Keep DSV4 KV/state streams per sequence even when public KV mode is unified.
951954
const bool unified_raw = false;
952955

956+
const uint32_t kv_size_raw = GGML_PAD(kv_size, DSV4_CACHE_PAD);
957+
const uint32_t n_pad_raw = std::max(n_pad, DSV4_CACHE_PAD);
958+
953959
LLAMA_LOG_INFO("%s: creating DSV4 raw KV cache\n", __func__);
954960

955961
dsv4_make_k_only(hparams_raw);
956962

957963
kv_raw = std::make_unique<llama_kv_cache_iswa>(
958964
model, hparams_raw, type_k, type_v,
959-
v_trans, offload, swa_full, unified_raw, kv_size, n_seq_max, n_ubatch, n_pad,
965+
v_trans, offload, swa_full, unified_raw, kv_size_raw, n_seq_max, n_ubatch, n_pad_raw,
960966
nullptr, filter_raw, reuse, nullptr);
961967

962968
dsv4_make_k_only(hparams_csa);
@@ -989,27 +995,27 @@ llama_kv_cache_dsv4::llama_kv_cache_dsv4(
989995
const bool unified_compressed = false;
990996

991997
LLAMA_LOG_INFO("%s: creating DSV4 CSA compressed KV cache, size = %u cells\n",
992-
__func__, dsv4_comp_size(kv_size, DSV4_CSA_RATIO));
998+
__func__, dsv4_comp_size(kv_size_raw, DSV4_CSA_RATIO));
993999

9941000
kv_csa = std::make_unique<llama_kv_cache>(
9951001
model, hparams_csa, type_k, type_v,
996-
v_trans, offload, unified_compressed, GGML_PAD(dsv4_comp_size(kv_size, DSV4_CSA_RATIO), 256u), n_seq_max, n_pad,
1002+
v_trans, offload, unified_compressed, GGML_PAD(dsv4_comp_size(kv_size_raw, DSV4_CSA_RATIO), DSV4_CACHE_PAD), n_seq_max, n_pad,
9971003
0, LLAMA_SWA_TYPE_NONE, nullptr, filter_csa, nullptr, nullptr);
9981004

9991005
LLAMA_LOG_INFO("%s: creating DSV4 HCA compressed KV cache, size = %u cells\n",
1000-
__func__, dsv4_comp_size(kv_size, DSV4_HCA_RATIO));
1006+
__func__, dsv4_comp_size(kv_size_raw, DSV4_HCA_RATIO));
10011007

10021008
kv_hca = std::make_unique<llama_kv_cache>(
10031009
model, hparams_hca, type_k, type_v,
1004-
v_trans, offload, unified_compressed, GGML_PAD(dsv4_comp_size(kv_size, DSV4_HCA_RATIO), 256u), n_seq_max, n_pad,
1010+
v_trans, offload, unified_compressed, GGML_PAD(dsv4_comp_size(kv_size_raw, DSV4_HCA_RATIO), DSV4_CACHE_PAD), n_seq_max, n_pad,
10051011
0, LLAMA_SWA_TYPE_NONE, nullptr, filter_hca, nullptr, nullptr);
10061012

10071013
LLAMA_LOG_INFO("%s: creating DSV4 lightning-indexer KV cache, size = %u cells\n",
1008-
__func__, dsv4_comp_size(kv_size, DSV4_CSA_RATIO));
1014+
__func__, dsv4_comp_size(kv_size_raw, DSV4_CSA_RATIO));
10091015

10101016
kv_lid = std::make_unique<llama_kv_cache>(
10111017
model, hparams_lid, type_k, type_v,
1012-
v_trans, offload, unified_compressed, GGML_PAD(dsv4_comp_size(kv_size, DSV4_CSA_RATIO), 256u), n_seq_max, n_pad,
1018+
v_trans, offload, unified_compressed, GGML_PAD(dsv4_comp_size(kv_size_raw, DSV4_CSA_RATIO), DSV4_CACHE_PAD), n_seq_max, n_pad,
10131019
0, LLAMA_SWA_TYPE_NONE, nullptr, filter_csa, nullptr, nullptr);
10141020

10151021
LLAMA_LOG_INFO("%s: creating DSV4 CSA compressor state\n", __func__);

src/llama-kv-cache-iswa.cpp

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -68,9 +68,10 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
6868

6969
const uint32_t size_base = kv_size;
7070

71-
// note: the SWA cache is always padded to 256 for performance
71+
// note: the SWA cache is always padded to at least 256 for performance
7272
// https://github.com/ggml-org/llama.cpp/issues/17037
73-
uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), 256);
73+
const uint32_t n_pad_swa = std::max(n_pad, 256u);
74+
uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), n_pad_swa);
7475

7576
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
7677
if (swa_full) {

src/llama-kv-cache.cpp

Lines changed: 2 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -57,22 +57,6 @@ static void ggml_gen_hadamard(ggml_tensor * tensor) {
5757
}
5858
}
5959

60-
static ggml_tensor * ggml_mul_mat_aux(
61-
ggml_context * ctx,
62-
ggml_tensor * cur,
63-
ggml_tensor * rot) {
64-
const auto n = rot->ne[0];
65-
66-
ggml_tensor * res;
67-
68-
res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
69-
res = ggml_mul_mat (ctx, rot, res);
70-
ggml_mul_mat_set_hint(res, GGML_HINT_SRC0_IS_HADAMARD);
71-
res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
72-
73-
return res;
74-
}
75-
7660
//
7761
// llama_kv_cache
7862
//
@@ -1875,14 +1859,14 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
18751859
tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
18761860

18771861
// rotate back
1878-
tmp = ggml_mul_mat_aux(ctx, tmp, rot);
1862+
tmp = llama_mul_mat_hadamard(ctx, tmp, rot);
18791863

18801864
tmp = ggml_rope_ext(ctx, tmp,
18811865
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
18821866
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
18831867

18841868
// rotate fwd
1885-
tmp = ggml_mul_mat_aux(ctx, tmp, rot);
1869+
tmp = llama_mul_mat_hadamard(ctx, tmp, rot);
18861870

18871871
tmp = ggml_cpy(ctx, tmp, cur);
18881872
} else {

src/models/deepseek4.cpp

Lines changed: 37 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -557,7 +557,7 @@ ggml_tensor * llama_model_deepseek4::graph::build_lid_top_k(
557557
cb(indexer_q_pe, "lid_q_pe", il);
558558

559559
indexer_q = ggml_concat(ctx0, indexer_q_nope, indexer_q_pe, 0);
560-
indexer_q = ggml_mul_mat(ctx0, inp_lid.k_rot, indexer_q);
560+
indexer_q = llama_mul_mat_hadamard(ctx0, indexer_q, inp_lid.k_rot);
561561
cb(indexer_q, "lid_q_rot", il);
562562

563563
ggml_tensor * indexer_weights = build_lora_mm(layer.indexer_proj, cur);
@@ -652,10 +652,15 @@ ggml_tensor * llama_model_deepseek4::graph::build_csa_lid_attention(
652652
int il) const {
653653
const auto & inp_csa = inp_dsv4->get_csa();
654654
GGML_ASSERT(inp_csa.kq_mask);
655-
GGML_ASSERT(inp_attn->self_k_rot == nullptr);
656655

657656
ggml_tensor * top_k = build_lid_top_k(model, inp_dsv4, qr, cur, inp_pos, il);
658657

658+
ggml_tensor * k_rot = inp_attn->self_k_rot;
659+
if (k_rot) {
660+
q = llama_mul_mat_hadamard(ctx0, q, k_rot);
661+
kv = llama_mul_mat_hadamard(ctx0, kv, k_rot);
662+
}
663+
659664
ggml_build_forward_expand(gf, q);
660665
ggml_build_forward_expand(gf, kv);
661666

@@ -696,6 +701,9 @@ ggml_tensor * llama_model_deepseek4::graph::build_csa_lid_attention(
696701

697702
ggml_tensor * kq_b = dsv4_build_kq_zero_bias(ctx0, cparams, kq_mask, q->ne[1]);
698703
ggml_tensor * out = build_attn_mha(q, k_all, k_all, kq_b, kq_mask, sinks, nullptr, kq_scale, il);
704+
if (k_rot) {
705+
out = llama_mul_mat_hadamard(ctx0, out, k_rot);
706+
}
699707
cb(out, "attn_csa_lid", il);
700708

701709
return out;
@@ -711,7 +719,12 @@ ggml_tensor * llama_model_deepseek4::graph::build_hca_attention(
711719
int il) const {
712720
const auto & inp_hca = inp_dsv4->get_hca();
713721
GGML_ASSERT(inp_hca.kq_mask);
714-
GGML_ASSERT(inp_attn->self_k_rot == nullptr);
722+
723+
ggml_tensor * k_rot = inp_attn->self_k_rot;
724+
if (k_rot) {
725+
q = llama_mul_mat_hadamard(ctx0, q, k_rot);
726+
kv = llama_mul_mat_hadamard(ctx0, kv, k_rot);
727+
}
715728

716729
ggml_build_forward_expand(gf, q);
717730
ggml_build_forward_expand(gf, kv);
@@ -753,6 +766,9 @@ ggml_tensor * llama_model_deepseek4::graph::build_hca_attention(
753766

754767
ggml_tensor * kq_b = dsv4_build_kq_zero_bias(ctx0, cparams, kq_mask, q->ne[1]);
755768
ggml_tensor * out = build_attn_mha(q, k_all, k_all, kq_b, kq_mask, sinks, nullptr, kq_scale, il);
769+
if (k_rot) {
770+
out = llama_mul_mat_hadamard(ctx0, out, k_rot);
771+
}
756772
cb(out, "attn_hca", il);
757773

758774
return out;
@@ -770,8 +786,8 @@ ggml_tensor * llama_model_deepseek4::graph::build_raw_attention(
770786
ggml_tensor * k_rot = inp_attn->self_k_rot;
771787

772788
if (k_rot) {
773-
q = ggml_mul_mat(ctx0, k_rot, q);
774-
kv = ggml_mul_mat(ctx0, k_rot, kv);
789+
q = llama_mul_mat_hadamard(ctx0, q, k_rot);
790+
kv = llama_mul_mat_hadamard(ctx0, kv, k_rot);
775791
}
776792

777793
ggml_build_forward_expand(gf, q);
@@ -788,6 +804,9 @@ ggml_tensor * llama_model_deepseek4::graph::build_raw_attention(
788804

789805
ggml_tensor * kq_b = dsv4_build_kq_zero_bias(ctx0, cparams, kq_mask, q->ne[1]);
790806
ggml_tensor * out = build_attn_mha(q, k, k, kq_b, kq_mask, sinks, nullptr, kq_scale, il);
807+
if (k_rot) {
808+
out = llama_mul_mat_hadamard(ctx0, out, k_rot);
809+
}
791810
cb(out, "attn_raw", il);
792811

793812
return out;
@@ -917,6 +936,11 @@ ggml_tensor * llama_model_deepseek4::graph::build_attention(
917936
"csa_state_compress",
918937
il);
919938

939+
if (inp_dsv4->get_csa().k_rot) {
940+
kv_comp_csa_state = llama_mul_mat_hadamard(ctx0, kv_comp_csa_state, inp_dsv4->get_csa().k_rot);
941+
cb(kv_comp_csa_state, "csa_state_compress_rot", il);
942+
}
943+
920944
ggml_build_forward_expand(gf, inp_dsv4->mctx->get_csa()->cpy_k(ctx0,
921945
kv_comp_csa_state, inp_dsv4->get_csa().state_write_idxs, il));
922946

@@ -965,7 +989,7 @@ ggml_tensor * llama_model_deepseek4::graph::build_attention(
965989
il);
966990

967991
if (inp_dsv4->get_lid().k_rot) {
968-
kv_comp_lid_state = ggml_mul_mat(ctx0, inp_dsv4->get_lid().k_rot, kv_comp_lid_state);
992+
kv_comp_lid_state = llama_mul_mat_hadamard(ctx0, kv_comp_lid_state, inp_dsv4->get_lid().k_rot);
969993
cb(kv_comp_lid_state, "lid_state_compress_rot", il);
970994
}
971995

@@ -1007,6 +1031,11 @@ ggml_tensor * llama_model_deepseek4::graph::build_attention(
10071031
"hca_state_compress",
10081032
il);
10091033

1034+
if (inp_dsv4->get_hca().k_rot) {
1035+
kv_comp_hca = llama_mul_mat_hadamard(ctx0, kv_comp_hca, inp_dsv4->get_hca().k_rot);
1036+
cb(kv_comp_hca, "hca_state_compress_rot", il);
1037+
}
1038+
10101039
ggml_build_forward_expand(gf, inp_dsv4->mctx->get_hca()->cpy_k(ctx0,
10111040
kv_comp_hca, inp_dsv4->get_hca().state_write_idxs, il));
10121041
hca_state_dep = kv_comp_hca;
@@ -1035,13 +1064,11 @@ ggml_tensor * llama_model_deepseek4::graph::build_attention(
10351064
if (ratio == DSV4_CSA_RATIO &&
10361065
inp_dsv4->get_csa().kq_mask &&
10371066
inp_dsv4->get_lid().kq_mask &&
1038-
inp_dsv4->get_lid().k_rot &&
1039-
inp_attn->self_k_rot == nullptr) {
1067+
inp_dsv4->get_lid().k_rot) {
10401068
out = build_csa_lid_attention(model, inp_dsv4, inp_attn, q, kv, qr, cur, inp_pos, layer.attn_sinks,
10411069
1.0f/sqrtf(float(n_embd_head)), il);
10421070
} else if (ratio == DSV4_HCA_RATIO &&
1043-
inp_dsv4->get_hca().kq_mask &&
1044-
inp_attn->self_k_rot == nullptr) {
1071+
inp_dsv4->get_hca().kq_mask) {
10451072
out = build_hca_attention(inp_dsv4, inp_attn, q, kv, layer.attn_sinks,
10461073
1.0f/sqrtf(float(n_embd_head)), il);
10471074
} else {

0 commit comments

Comments
 (0)