Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions src/llama-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
if (self_v_rot) {
mctx->get_base()->set_input_v_rot(self_v_rot);
}

if (self_k_rot_swa) {
mctx->get_swa()->set_input_k_rot(self_k_rot_swa);
}

if (self_v_rot_swa) {
mctx->get_swa()->set_input_v_rot(self_v_rot_swa);
}
}

bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
Expand Down Expand Up @@ -681,6 +689,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot);
}

if (inp_attn->self_k_rot_swa) {
attn_ctx->get_swa()->set_input_k_rot(inp_attn->self_k_rot_swa);
}

if (inp_attn->self_v_rot_swa) {
attn_ctx->get_swa()->set_input_v_rot(inp_attn->self_v_rot_swa);
}

const int64_t n_rs = mctx->get_recr()->get_n_rs();

if (inp_rs->s_copy) {
Expand Down Expand Up @@ -2233,15 +2249,20 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * v_mla,
float kq_scale,
int il) const {
if (inp->self_k_rot) {
q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
const bool is_swa = hparams.is_swa(il);

auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot;
auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot;

if (k_rot) {
q_cur = ggml_mul_mat_aux(ctx0, q_cur, k_rot);
if (k_cur) {
k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
k_cur = ggml_mul_mat_aux(ctx0, k_cur, k_rot);
}
}
if (inp->self_v_rot) {
if (v_rot) {
if (v_cur) {
v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
v_cur = ggml_mul_mat_aux(ctx0, v_cur, v_rot);
}
}

Expand All @@ -2259,8 +2280,6 @@ ggml_tensor * llm_graph_context::build_attn(

const auto * mctx_iswa = inp->mctx;

const bool is_swa = hparams.is_swa(il);

const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();

// optionally store to KV cache
Expand All @@ -2285,8 +2304,8 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);

if (inp->self_v_rot) {
cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
if (v_rot) {
cur = ggml_mul_mat_aux(ctx0, cur, v_rot);
}

if (wo) {
Expand Down Expand Up @@ -2388,6 +2407,9 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0);

inp->self_k_rot_swa = mctx_cur->get_swa()->build_input_k_rot(ctx0);
inp->self_v_rot_swa = mctx_cur->get_swa()->build_input_v_rot(ctx0);

return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
}

Expand Down
6 changes: 4 additions & 2 deletions src/llama-graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ class llm_graph_input_attn_kv : public llm_graph_input_i {
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]

// note: assumes v_rot^ == I
// note: assumes v_rot^2 == I
ggml_tensor * self_k_rot = nullptr;
ggml_tensor * self_v_rot = nullptr;

Expand Down Expand Up @@ -388,10 +388,12 @@ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]

// note: using same rotation matrices for both base and swa cache
ggml_tensor * self_k_rot = nullptr;
ggml_tensor * self_v_rot = nullptr;

ggml_tensor * self_k_rot_swa = nullptr;
ggml_tensor * self_v_rot_swa = nullptr;

const llama_hparams hparams;
const llama_cparams cparams;

Expand Down
24 changes: 18 additions & 6 deletions src/llama-kv-cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,18 @@ llama_kv_cache::llama_kv_cache(
continue;
}

if (n_embd_head_k_all == 0) {
n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il);
} else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) {
n_embd_head_k_all = -1;
}

if (n_embd_head_v_all == 0) {
n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il);
} else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) {
n_embd_head_v_all = -1;
}

// [TAG_V_CACHE_VARIABLE]
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
Expand Down Expand Up @@ -276,23 +288,23 @@ llama_kv_cache::llama_kv_cache(

attn_rot_k =
!attn_rot_disable &&
n_embd_head_k_all > 0 &&
ggml_is_quantized(type_k) &&
!hparams.is_n_embd_k_gqa_variable() &&
hparams.n_embd_head_k() % 64 == 0;

attn_rot_v =
!attn_rot_disable &&
n_embd_head_v_all > 0 &&
ggml_is_quantized(type_v) &&
!hparams.is_n_embd_v_gqa_variable() &&
hparams.n_embd_head_v() % 64 == 0;

LLAMA_LOG_INFO("%s: attn_rot_k = %d\n", __func__, attn_rot_k);
LLAMA_LOG_INFO("%s: attn_rot_v = %d\n", __func__, attn_rot_v);
LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all);
LLAMA_LOG_INFO("%s: attn_rot_v = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_v, n_embd_head_v_all);

// pre-compute the haramard matrices and keep them in host memory
// TODO: in the future, we can make copies in the backend buffers to avoid host -> device transfers
if (attn_rot_k || attn_rot_v) {
for (int64_t n = 64; n <= std::max(hparams.n_embd_head_k(), hparams.n_embd_head_v()); n *= 2) {
for (int64_t n = 64; n <= std::max(n_embd_head_k_all, n_embd_head_v_all); n *= 2) {
attn_rot_hadamard[n] = std::vector<float>(n*n);

ggml_init_params params = {
Expand Down Expand Up @@ -1308,7 +1320,7 @@ ggml_tensor * llama_kv_cache::build_input_k_rot(ggml_context * ctx) const {
// ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088
do {
nrot *= 2;
} while (hparams.n_embd_head_k() % nrot == 0);
} while (n_embd_head_k_all % nrot == 0);
nrot /= 2;

res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);
Expand Down
5 changes: 5 additions & 0 deletions src/llama-kv-cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,11 @@ class llama_kv_cache : public llama_memory_i {
bool attn_rot_k = false;
bool attn_rot_v = false;

// if all layers participating in the cache have constant head size, the value is stored here
// otherwise the value is -1
int32_t n_embd_head_k_all = 0;
int32_t n_embd_head_v_all = 0;

// pre-computed hadamard martrices
std::unordered_map<int64_t, std::vector<float>> attn_rot_hadamard;

Expand Down
Loading