Skip to content

Commit cf24e84

Browse files
ggerganov and iamwavecut
authored and committed
kv-cache : support attention rotation for heterogeneous iSWA (ggml-org#21513)
* kv-cache : support attention rotation for heterogeneous iSWA
* cont : remove assert
1 parent 115311f commit cf24e84

File tree

4 files changed

+58
-17
lines changed

4 files changed

+58
-17
lines changed

src/llama-graph.cpp

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
511511
if (self_v_rot) {
512512
mctx->get_base()->set_input_v_rot(self_v_rot);
513513
}
514+
515+
if (self_k_rot_swa) {
516+
mctx->get_swa()->set_input_k_rot(self_k_rot_swa);
517+
}
518+
519+
if (self_v_rot_swa) {
520+
mctx->get_swa()->set_input_v_rot(self_v_rot_swa);
521+
}
514522
}
515523

516524
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
@@ -681,6 +689,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
681689
attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot);
682690
}
683691

692+
if (inp_attn->self_k_rot_swa) {
693+
attn_ctx->get_swa()->set_input_k_rot(inp_attn->self_k_rot_swa);
694+
}
695+
696+
if (inp_attn->self_v_rot_swa) {
697+
attn_ctx->get_swa()->set_input_v_rot(inp_attn->self_v_rot_swa);
698+
}
699+
684700
const int64_t n_rs = mctx->get_recr()->get_n_rs();
685701

686702
if (inp_rs->s_copy) {
@@ -2328,15 +2344,20 @@ ggml_tensor * llm_graph_context::build_attn(
23282344
ggml_tensor * v_mla,
23292345
float kq_scale,
23302346
int il) const {
2331-
if (inp->self_k_rot) {
2332-
q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
2347+
const bool is_swa = hparams.is_swa(il);
2348+
2349+
auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot;
2350+
auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot;
2351+
2352+
if (k_rot) {
2353+
q_cur = ggml_mul_mat_aux(ctx0, q_cur, k_rot);
23332354
if (k_cur) {
2334-
k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
2355+
k_cur = ggml_mul_mat_aux(ctx0, k_cur, k_rot);
23352356
}
23362357
}
2337-
if (inp->self_v_rot) {
2358+
if (v_rot) {
23382359
if (v_cur) {
2339-
v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
2360+
v_cur = ggml_mul_mat_aux(ctx0, v_cur, v_rot);
23402361
}
23412362
}
23422363

@@ -2354,8 +2375,6 @@ ggml_tensor * llm_graph_context::build_attn(
23542375

23552376
const auto * mctx_iswa = inp->mctx;
23562377

2357-
const bool is_swa = hparams.is_swa(il);
2358-
23592378
const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
23602379

23612380
// optionally store to KV cache
@@ -2406,8 +2425,8 @@ ggml_tensor * llm_graph_context::build_attn(
24062425
}
24072426
}
24082427

2409-
if (inp->self_v_rot) {
2410-
cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
2428+
if (v_rot) {
2429+
cur = ggml_mul_mat_aux(ctx0, cur, v_rot);
24112430
}
24122431

24132432
if (wo) {
@@ -2509,6 +2528,9 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
25092528
inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
25102529
inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0);
25112530

2531+
inp->self_k_rot_swa = mctx_cur->get_swa()->build_input_k_rot(ctx0);
2532+
inp->self_v_rot_swa = mctx_cur->get_swa()->build_input_v_rot(ctx0);
2533+
25122534
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
25132535
}
25142536

src/llama-graph.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ class llm_graph_input_attn_kv : public llm_graph_input_i {
308308
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
309309
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
310310

311-
// note: assumes v_rot^ == I
311+
// note: assumes v_rot^2 == I
312312
ggml_tensor * self_k_rot = nullptr;
313313
ggml_tensor * self_v_rot = nullptr;
314314

@@ -388,10 +388,12 @@ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
388388
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
389389
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
390390

391-
// note: using same rotation matrices for both base and swa cache
392391
ggml_tensor * self_k_rot = nullptr;
393392
ggml_tensor * self_v_rot = nullptr;
394393

394+
ggml_tensor * self_k_rot_swa = nullptr;
395+
ggml_tensor * self_v_rot_swa = nullptr;
396+
395397
const llama_hparams hparams;
396398
const llama_cparams cparams;
397399

src/llama-kv-cache.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,18 @@ llama_kv_cache::llama_kv_cache(
194194
continue;
195195
}
196196

197+
if (n_embd_head_k_all == 0) {
198+
n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il);
199+
} else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) {
200+
n_embd_head_k_all = -1;
201+
}
202+
203+
if (n_embd_head_v_all == 0) {
204+
n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il);
205+
} else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) {
206+
n_embd_head_v_all = -1;
207+
}
208+
197209
// [TAG_V_CACHE_VARIABLE]
198210
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
199211
const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
@@ -437,23 +449,23 @@ llama_kv_cache::llama_kv_cache(
437449

438450
attn_rot_k =
439451
!attn_rot_disable &&
452+
n_embd_head_k_all > 0 &&
440453
ggml_is_quantized(type_k) &&
441-
!hparams.is_n_embd_k_gqa_variable() &&
442454
hparams.n_embd_head_k() % 64 == 0;
443455

444456
attn_rot_v =
445457
!attn_rot_disable &&
458+
n_embd_head_v_all > 0 &&
446459
ggml_is_quantized(type_v) &&
447-
!hparams.is_n_embd_v_gqa_variable() &&
448460
hparams.n_embd_head_v() % 64 == 0;
449461

450-
LLAMA_LOG_INFO("%s: attn_rot_k = %d\n", __func__, attn_rot_k);
451-
LLAMA_LOG_INFO("%s: attn_rot_v = %d\n", __func__, attn_rot_v);
462+
LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all);
463+
LLAMA_LOG_INFO("%s: attn_rot_v = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_v, n_embd_head_v_all);
452464

453465
// pre-compute the haramard matrices and keep them in host memory
454466
// TODO: in the future, we can make copies in the backend buffers to avoid host -> device transfers
455467
if (attn_rot_k || attn_rot_v) {
456-
for (int64_t n = 64; n <= std::max(hparams.n_embd_head_k(), hparams.n_embd_head_v()); n *= 2) {
468+
for (int64_t n = 64; n <= std::max(n_embd_head_k_all, n_embd_head_v_all); n *= 2) {
457469
attn_rot_hadamard[n] = std::vector<float>(n*n);
458470

459471
ggml_init_params params = {
@@ -1535,7 +1547,7 @@ ggml_tensor * llama_kv_cache::build_input_k_rot(ggml_context * ctx) const {
15351547
// ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088
15361548
do {
15371549
nrot *= 2;
1538-
} while (hparams.n_embd_head_k() % nrot == 0);
1550+
} while (n_embd_head_k_all % nrot == 0);
15391551
nrot /= 2;
15401552

15411553
res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);

src/llama-kv-cache.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,11 @@ class llama_kv_cache : public llama_memory_i {
248248
bool attn_rot_k = false;
249249
bool attn_rot_v = false;
250250

251+
// if all layers participating in the cache have constant head size, the value is stored here
252+
// otherwise the value is -1
253+
int32_t n_embd_head_k_all = 0;
254+
int32_t n_embd_head_v_all = 0;
255+
251256
// pre-computed hadamard martrices
252257
std::unordered_map<int64_t, std::vector<float>> attn_rot_hadamard;
253258

0 commit comments

Comments (0)