Skip to content

Commit 1f60173

Browse files
ggerganovArberSephirotheca
authored andcommitted
kv-cache : support attention rotation for heterogeneous iSWA (ggml-org#21513)
* kv-cache : support attention rotation for heterogeneous iSWA * cont : remove assert
1 parent 69e22cb commit 1f60173

4 files changed

Lines changed: 58 additions & 17 deletions

File tree

src/llama-graph.cpp

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
511511
if (self_v_rot) {
512512
mctx->get_base()->set_input_v_rot(self_v_rot);
513513
}
514+
515+
if (self_k_rot_swa) {
516+
mctx->get_swa()->set_input_k_rot(self_k_rot_swa);
517+
}
518+
519+
if (self_v_rot_swa) {
520+
mctx->get_swa()->set_input_v_rot(self_v_rot_swa);
521+
}
514522
}
515523

516524
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
@@ -681,6 +689,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
681689
attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot);
682690
}
683691

692+
if (inp_attn->self_k_rot_swa) {
693+
attn_ctx->get_swa()->set_input_k_rot(inp_attn->self_k_rot_swa);
694+
}
695+
696+
if (inp_attn->self_v_rot_swa) {
697+
attn_ctx->get_swa()->set_input_v_rot(inp_attn->self_v_rot_swa);
698+
}
699+
684700
const int64_t n_rs = mctx->get_recr()->get_n_rs();
685701

686702
if (inp_rs->s_copy) {
@@ -2233,15 +2249,20 @@ ggml_tensor * llm_graph_context::build_attn(
22332249
ggml_tensor * v_mla,
22342250
float kq_scale,
22352251
int il) const {
2236-
if (inp->self_k_rot) {
2237-
q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
2252+
const bool is_swa = hparams.is_swa(il);
2253+
2254+
auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot;
2255+
auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot;
2256+
2257+
if (k_rot) {
2258+
q_cur = ggml_mul_mat_aux(ctx0, q_cur, k_rot);
22382259
if (k_cur) {
2239-
k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
2260+
k_cur = ggml_mul_mat_aux(ctx0, k_cur, k_rot);
22402261
}
22412262
}
2242-
if (inp->self_v_rot) {
2263+
if (v_rot) {
22432264
if (v_cur) {
2244-
v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
2265+
v_cur = ggml_mul_mat_aux(ctx0, v_cur, v_rot);
22452266
}
22462267
}
22472268

@@ -2259,8 +2280,6 @@ ggml_tensor * llm_graph_context::build_attn(
22592280

22602281
const auto * mctx_iswa = inp->mctx;
22612282

2262-
const bool is_swa = hparams.is_swa(il);
2263-
22642283
const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
22652284

22662285
// optionally store to KV cache
@@ -2285,8 +2304,8 @@ ggml_tensor * llm_graph_context::build_attn(
22852304
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
22862305
cb(cur, "kqv_out", il);
22872306

2288-
if (inp->self_v_rot) {
2289-
cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
2307+
if (v_rot) {
2308+
cur = ggml_mul_mat_aux(ctx0, cur, v_rot);
22902309
}
22912310

22922311
if (wo) {
@@ -2388,6 +2407,9 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
23882407
inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
23892408
inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0);
23902409

2410+
inp->self_k_rot_swa = mctx_cur->get_swa()->build_input_k_rot(ctx0);
2411+
inp->self_v_rot_swa = mctx_cur->get_swa()->build_input_v_rot(ctx0);
2412+
23912413
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
23922414
}
23932415

src/llama-graph.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ class llm_graph_input_attn_kv : public llm_graph_input_i {
308308
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
309309
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
310310

311-
// note: assumes v_rot^ == I
311+
// note: assumes v_rot^2 == I
312312
ggml_tensor * self_k_rot = nullptr;
313313
ggml_tensor * self_v_rot = nullptr;
314314

@@ -388,10 +388,12 @@ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
388388
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
389389
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
390390

391-
// note: using same rotation matrices for both base and swa cache
392391
ggml_tensor * self_k_rot = nullptr;
393392
ggml_tensor * self_v_rot = nullptr;
394393

394+
ggml_tensor * self_k_rot_swa = nullptr;
395+
ggml_tensor * self_v_rot_swa = nullptr;
396+
395397
const llama_hparams hparams;
396398
const llama_cparams cparams;
397399

src/llama-kv-cache.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,18 @@ llama_kv_cache::llama_kv_cache(
169169
continue;
170170
}
171171

172+
if (n_embd_head_k_all == 0) {
173+
n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il);
174+
} else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) {
175+
n_embd_head_k_all = -1;
176+
}
177+
178+
if (n_embd_head_v_all == 0) {
179+
n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il);
180+
} else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) {
181+
n_embd_head_v_all = -1;
182+
}
183+
172184
// [TAG_V_CACHE_VARIABLE]
173185
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
174186
const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
@@ -276,23 +288,23 @@ llama_kv_cache::llama_kv_cache(
276288

277289
attn_rot_k =
278290
!attn_rot_disable &&
291+
n_embd_head_k_all > 0 &&
279292
ggml_is_quantized(type_k) &&
280-
!hparams.is_n_embd_k_gqa_variable() &&
281293
hparams.n_embd_head_k() % 64 == 0;
282294

283295
attn_rot_v =
284296
!attn_rot_disable &&
297+
n_embd_head_v_all > 0 &&
285298
ggml_is_quantized(type_v) &&
286-
!hparams.is_n_embd_v_gqa_variable() &&
287299
hparams.n_embd_head_v() % 64 == 0;
288300

289-
LLAMA_LOG_INFO("%s: attn_rot_k = %d\n", __func__, attn_rot_k);
290-
LLAMA_LOG_INFO("%s: attn_rot_v = %d\n", __func__, attn_rot_v);
301+
LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all);
302+
LLAMA_LOG_INFO("%s: attn_rot_v = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_v, n_embd_head_v_all);
291303

292304
// pre-compute the haramard matrices and keep them in host memory
293305
// TODO: in the future, we can make copies in the backend buffers to avoid host -> device transfers
294306
if (attn_rot_k || attn_rot_v) {
295-
for (int64_t n = 64; n <= std::max(hparams.n_embd_head_k(), hparams.n_embd_head_v()); n *= 2) {
307+
for (int64_t n = 64; n <= std::max(n_embd_head_k_all, n_embd_head_v_all); n *= 2) {
296308
attn_rot_hadamard[n] = std::vector<float>(n*n);
297309

298310
ggml_init_params params = {
@@ -1308,7 +1320,7 @@ ggml_tensor * llama_kv_cache::build_input_k_rot(ggml_context * ctx) const {
13081320
// ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088
13091321
do {
13101322
nrot *= 2;
1311-
} while (hparams.n_embd_head_k() % nrot == 0);
1323+
} while (n_embd_head_k_all % nrot == 0);
13121324
nrot /= 2;
13131325

13141326
res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);

src/llama-kv-cache.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,11 @@ class llama_kv_cache : public llama_memory_i {
239239
bool attn_rot_k = false;
240240
bool attn_rot_v = false;
241241

242+
// if all layers participating in the cache have constant head size, the value is stored here
243+
// otherwise the value is -1
244+
int32_t n_embd_head_k_all = 0;
245+
int32_t n_embd_head_v_all = 0;
246+
242247
// pre-computed hadamard martrices
243248
std::unordered_map<int64_t, std::vector<float>> attn_rot_hadamard;
244249

0 commit comments

Comments
 (0)