
Commit f5ba49f

ggerganov and OsamaMazhar authored and committed
kv-cache : support attention rotation for heterogeneous iSWA (ggml-org#21513)
* kv-cache : support attention rotation for heterogeneous iSWA
* cont : remove assert
1 parent 89879a5 commit f5ba49f

4 files changed: 58 additions & 17 deletions


src/llama-graph.cpp

Lines changed: 31 additions & 9 deletions
```diff
@@ -511,6 +511,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
     if (self_v_rot) {
         mctx->get_base()->set_input_v_rot(self_v_rot);
     }
+
+    if (self_k_rot_swa) {
+        mctx->get_swa()->set_input_k_rot(self_k_rot_swa);
+    }
+
+    if (self_v_rot_swa) {
+        mctx->get_swa()->set_input_v_rot(self_v_rot_swa);
+    }
 }
 
 bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
```
```diff
@@ -681,6 +689,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
         attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot);
     }
 
+    if (inp_attn->self_k_rot_swa) {
+        attn_ctx->get_swa()->set_input_k_rot(inp_attn->self_k_rot_swa);
+    }
+
+    if (inp_attn->self_v_rot_swa) {
+        attn_ctx->get_swa()->set_input_v_rot(inp_attn->self_v_rot_swa);
+    }
+
     const int64_t n_rs = mctx->get_recr()->get_n_rs();
 
     if (inp_rs->s_copy) {
```
```diff
@@ -2242,15 +2258,20 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * v_mla,
               float   kq_scale,
                 int   il) const {
-    if (inp->self_k_rot) {
-        q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
+    const bool is_swa = hparams.is_swa(il);
+
+    auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot;
+    auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot;
+
+    if (k_rot) {
+        q_cur = ggml_mul_mat_aux(ctx0, q_cur, k_rot);
         if (k_cur) {
-            k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
+            k_cur = ggml_mul_mat_aux(ctx0, k_cur, k_rot);
         }
     }
-    if (inp->self_v_rot) {
+    if (v_rot) {
         if (v_cur) {
-            v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
+            v_cur = ggml_mul_mat_aux(ctx0, v_cur, v_rot);
         }
     }
 
```
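Not part of the diff, but useful context for this hunk: the rotation matrices used here are orthogonal, so applying the same matrix to both the query and the key leaves the attention scores unchanged, since (qR)(kR)^T = q R R^T k^T = q k^T. A tiny standalone check of that identity, using a 2x2 normalized Hadamard matrix as the rotation (illustrative only, not llama.cpp code):

```cpp
// Illustrative only: rotating q and k by the same orthogonal matrix R leaves
// the attention score q.k unchanged, because (qR).(kR) = q R R^T k^T = q.k.
// R here is a 2x2 normalized Hadamard matrix.
#include <cmath>
#include <cstdio>

int main() {
    const float s = 1.0f / std::sqrt(2.0f);
    const float R[2][2] = {{ s,  s },
                           { s, -s }};

    const float q[2] = { 0.3f, -1.2f };
    const float k[2] = { 2.0f,  0.5f };

    // rotate both vectors by R
    float qr[2] = { 0.0f, 0.0f };
    float kr[2] = { 0.0f, 0.0f };
    for (int i = 0; i < 2; ++i) {
        for (int j = 0; j < 2; ++j) {
            qr[i] += q[j]*R[j][i];
            kr[i] += k[j]*R[j][i];
        }
    }

    const float dot   = q[0]*k[0]   + q[1]*k[1];
    const float dot_r = qr[0]*kr[0] + qr[1]*kr[1];

    // the two values agree up to floating-point rounding
    std::printf("q.k = %f  (qR).(kR) = %f\n", dot, dot_r);
    return 0;
}
```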
```diff
@@ -2268,8 +2289,6 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto * mctx_iswa = inp->mctx;
 
-    const bool is_swa = hparams.is_swa(il);
-
     const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
 
     // optionally store to KV cache
```
```diff
@@ -2294,8 +2313,8 @@
     ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
 
-    if (inp->self_v_rot) {
-        cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
+    if (v_rot) {
+        cur = ggml_mul_mat_aux(ctx0, cur, v_rot);
     }
 
     if (wo) {
```
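The extra multiplication here undoes the V rotation: v_cur was rotated before being stored in the cache, so the attention output comes back in the rotated basis, and multiplying by v_rot once more restores it. This relies on the "v_rot^2 == I" note in llama-graph.h; the rotation matrices are built from Hadamard matrices (see attn_rot_hadamard in llama-kv-cache.cpp), which when normalized are symmetric and orthogonal and therefore their own inverse. A minimal check of that property (illustrative only):

```cpp
// Illustrative only: a normalized Hadamard matrix H is symmetric and
// orthogonal, so H*H == I -- the property the "undo" multiplication relies on.
#include <cmath>
#include <cstdio>

int main() {
    const float s = 1.0f / std::sqrt(2.0f);
    const float H[2][2] = {{ s,  s },
                           { s, -s }};

    float HH[2][2] = {{ 0.0f, 0.0f }, { 0.0f, 0.0f }};
    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j)
            for (int k = 0; k < 2; ++k)
                HH[i][j] += H[i][k]*H[k][j];

    // prints (approximately) the identity matrix
    std::printf("[[%.3f %.3f]\n [%.3f %.3f]]\n", HH[0][0], HH[0][1], HH[1][0], HH[1][1]);
    return 0;
}
```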
```diff
@@ -2397,6 +2416,9 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
     inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
     inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0);
 
+    inp->self_k_rot_swa = mctx_cur->get_swa()->build_input_k_rot(ctx0);
+    inp->self_v_rot_swa = mctx_cur->get_swa()->build_input_v_rot(ctx0);
+
     return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
 }
 
```
src/llama-graph.h

Lines changed: 4 additions & 2 deletions
```diff
@@ -308,7 +308,7 @@ class llm_graph_input_attn_kv : public llm_graph_input_i {
     ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
     ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
 
-    // note: assumes v_rot^ == I
+    // note: assumes v_rot^2 == I
     ggml_tensor * self_k_rot = nullptr;
     ggml_tensor * self_v_rot = nullptr;
 
```
```diff
@@ -388,10 +388,12 @@ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
     ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
     ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
 
-    // note: using same rotation matrices for both base and swa cache
     ggml_tensor * self_k_rot = nullptr;
     ggml_tensor * self_v_rot = nullptr;
 
+    ggml_tensor * self_k_rot_swa = nullptr;
+    ggml_tensor * self_v_rot_swa = nullptr;
+
     const llama_hparams hparams;
     const llama_cparams cparams;
 
```
src/llama-kv-cache.cpp

Lines changed: 18 additions & 6 deletions
```diff
@@ -169,6 +169,18 @@ llama_kv_cache::llama_kv_cache(
             continue;
         }
 
+        if (n_embd_head_k_all == 0) {
+            n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il);
+        } else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) {
+            n_embd_head_k_all = -1;
+        }
+
+        if (n_embd_head_v_all == 0) {
+            n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il);
+        } else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) {
+            n_embd_head_v_all = -1;
+        }
+
         // [TAG_V_CACHE_VARIABLE]
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
         const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
```
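The two blocks above use a simple "common value or -1" fold over the layers that participate in the cache: 0 means no layer has been seen yet, a positive value means all layers seen so far share that head size, and -1 means they disagree. A standalone sketch of the same pattern (the helper name here is hypothetical, not code from the diff):

```cpp
// Sketch of the aggregation used for n_embd_head_k_all / n_embd_head_v_all:
//   0  -> no layer seen yet
//   >0 -> all layers seen so far share this value
//   -1 -> at least two layers disagree
#include <cstdint>
#include <initializer_list>

static void fold_common(int32_t & acc, int32_t value) {
    if (acc == 0) {
        acc = value;
    } else if (acc > 0 && acc != value) {
        acc = -1;
    }
}

int main() {
    int32_t n_embd_head_all = 0;
    for (int32_t head_size : { 128, 128, 64 }) { // heterogeneous head sizes
        fold_common(n_embd_head_all, head_size);
    }
    // n_embd_head_all is now -1, so the rotation path stays disabled for this cache
    return 0;
}
```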
```diff
@@ -276,23 +288,23 @@ llama_kv_cache::llama_kv_cache(
 
     attn_rot_k =
         !attn_rot_disable &&
+        n_embd_head_k_all > 0 &&
         ggml_is_quantized(type_k) &&
-        !hparams.is_n_embd_k_gqa_variable() &&
         hparams.n_embd_head_k() % 64 == 0;
 
     attn_rot_v =
         !attn_rot_disable &&
+        n_embd_head_v_all > 0 &&
         ggml_is_quantized(type_v) &&
-        !hparams.is_n_embd_v_gqa_variable() &&
         hparams.n_embd_head_v() % 64 == 0;
 
-    LLAMA_LOG_INFO("%s: attn_rot_k = %d\n", __func__, attn_rot_k);
-    LLAMA_LOG_INFO("%s: attn_rot_v = %d\n", __func__, attn_rot_v);
+    LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all);
+    LLAMA_LOG_INFO("%s: attn_rot_v = %d, n_embd_head_v_all = %d\n", __func__, attn_rot_v, n_embd_head_v_all);
 
     // pre-compute the hadamard matrices and keep them in host memory
     // TODO: in the future, we can make copies in the backend buffers to avoid host -> device transfers
     if (attn_rot_k || attn_rot_v) {
-        for (int64_t n = 64; n <= std::max(hparams.n_embd_head_k(), hparams.n_embd_head_v()); n *= 2) {
+        for (int64_t n = 64; n <= std::max(n_embd_head_k_all, n_embd_head_v_all); n *= 2) {
             attn_rot_hadamard[n] = std::vector<float>(n*n);
 
             ggml_init_params params = {
```
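This hunk only shows that one n x n float matrix is precomputed for every power-of-two size from 64 up to the largest common head size; it does not show how attn_rot_hadamard[n] is filled. For reference, a normalized Hadamard matrix of power-of-two order can be built with the Sylvester construction, which is one plausible way to produce such a matrix (a sketch, not the actual llama.cpp implementation):

```cpp
// Hypothetical sketch: build a normalized n x n Hadamard matrix for n a power
// of two via the Sylvester construction.  With the 1/sqrt(n) scaling the
// matrix is symmetric and orthogonal, which is what makes the "v_rot^2 == I"
// assumption in llama-graph.h hold.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<float> make_hadamard(int64_t n) {
    // H(1) = [1], H(2m) = [[H(m), H(m)], [H(m), -H(m)]]
    std::vector<float> h(n*n, 0.0f);
    h[0] = 1.0f;
    for (int64_t m = 1; m < n; m *= 2) {
        for (int64_t i = 0; i < m; ++i) {
            for (int64_t j = 0; j < m; ++j) {
                const float v = h[i*n + j];
                h[i*n + (j + m)]       =  v;
                h[(i + m)*n + j]       =  v;
                h[(i + m)*n + (j + m)] = -v;
            }
        }
    }
    // normalize so that the matrix is orthogonal
    const float s = 1.0f / std::sqrt((float) n);
    for (auto & x : h) {
        x *= s;
    }
    return h;
}

int main() {
    const int64_t n = 4;
    const std::vector<float> h = make_hadamard(n);

    // spot-check orthogonality: row0.row0 ~ 1, row0.row1 ~ 0
    float d00 = 0.0f, d01 = 0.0f;
    for (int64_t j = 0; j < n; ++j) {
        d00 += h[0*n + j]*h[0*n + j];
        d01 += h[0*n + j]*h[1*n + j];
    }
    std::printf("row0.row0 = %f, row0.row1 = %f\n", d00, d01);
    return 0;
}
```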
```diff
@@ -1308,7 +1320,7 @@ ggml_tensor * llama_kv_cache::build_input_k_rot(ggml_context * ctx) const {
     // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088
     do {
         nrot *= 2;
-    } while (hparams.n_embd_head_k() % nrot == 0);
+    } while (n_embd_head_k_all % nrot == 0);
     nrot /= 2;
 
     res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);
```
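The do/while above selects nrot as the largest power of two that divides the (now cache-wide constant) head size, so the rotation operates on the biggest power-of-two block that fits: a head size of 128 gives nrot = 128, while 192 gives nrot = 64. A standalone sketch of the same computation (the value nrot holds before the loop is not visible in this hunk; 32 is assumed here so the first doubling lands on the 64 minimum used elsewhere in the constructor):

```cpp
// Sketch of the nrot selection in build_input_k_rot: find the largest
// power-of-two block size that divides the head dimension.
#include <cstdint>
#include <cstdio>

static int64_t pick_nrot(int64_t n_embd_head) {
    int64_t nrot = 32; // assumed starting value, not shown in the hunk above
    do {
        nrot *= 2;
    } while (n_embd_head % nrot == 0);
    nrot /= 2;
    return nrot;
}

int main() {
    std::printf("head  64 -> nrot %lld\n", (long long) pick_nrot(64));  // 64
    std::printf("head 128 -> nrot %lld\n", (long long) pick_nrot(128)); // 128
    std::printf("head 192 -> nrot %lld\n", (long long) pick_nrot(192)); // 64
    return 0;
}
```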

src/llama-kv-cache.h

Lines changed: 5 additions & 0 deletions
```diff
@@ -239,6 +239,11 @@ class llama_kv_cache : public llama_memory_i {
     bool attn_rot_k = false;
     bool attn_rot_v = false;
 
+    // if all layers participating in the cache have the same head size, it is stored here
+    // otherwise the value is -1
+    int32_t n_embd_head_k_all = 0;
+    int32_t n_embd_head_v_all = 0;
+
     // pre-computed hadamard matrices
     std::unordered_map<int64_t, std::vector<float>> attn_rot_hadamard;
 
```
