Skip to content

Commit 4b1d1ae

Browse files
committed
add Q rot when cache is quantized
1 parent 3c4e521 commit 4b1d1ae

3 files changed

Lines changed: 46 additions & 5 deletions

File tree

src/llama-graph.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,19 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
562562
void llm_graph_input_attn_src_kv_iswa::set_input(const llama_ubatch * ubatch) {
563563
src_mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
564564
src_mctx->get_swa() ->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
565+
566+
if (self_k_rot) {
567+
src_mctx->get_base()->set_input_k_rot(self_k_rot);
568+
}
569+
if (self_v_rot) {
570+
src_mctx->get_base()->set_input_v_rot(self_v_rot);
571+
}
572+
if (self_k_rot_swa) {
573+
src_mctx->get_swa()->set_input_k_rot(self_k_rot_swa);
574+
}
575+
if (self_v_rot_swa) {
576+
src_mctx->get_swa()->set_input_v_rot(self_v_rot_swa);
577+
}
565578
}
566579

567580
bool llm_graph_input_attn_src_kv_iswa::can_reuse(const llm_graph_params & params) {
@@ -2485,6 +2498,11 @@ llm_graph_input_attn_src_kv_iswa * llm_graph_context::build_attn_inp_src_kv_iswa
24852498
inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, src_iswa->get_swa(), ubatch, cparams);
24862499
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
24872500

2501+
inp->self_k_rot = src_iswa->get_base()->build_input_k_rot(ctx0);
2502+
inp->self_v_rot = src_iswa->get_base()->build_input_v_rot(ctx0);
2503+
inp->self_k_rot_swa = src_iswa->get_swa()->build_input_k_rot(ctx0);
2504+
inp->self_v_rot_swa = src_iswa->get_swa()->build_input_v_rot(ctx0);
2505+
24882506
return (llm_graph_input_attn_src_kv_iswa *) res->add_input(std::move(inp));
24892507
}
24902508

@@ -2507,6 +2525,13 @@ ggml_tensor * llm_graph_context::build_attn(
25072525

25082526
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
25092527

2528+
auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot;
2529+
auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot;
2530+
2531+
if (k_rot) {
2532+
q_cur = ggml_mul_mat_aux(ctx0, q_cur, k_rot);
2533+
}
2534+
25102535
ggml_build_forward_expand(gf, q_cur);
25112536

25122537
ggml_tensor * q = q_cur;
@@ -2539,6 +2564,10 @@ ggml_tensor * llm_graph_context::build_attn(
25392564
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il_assist);
25402565
cb(cur, "kqv_out", il_assist);
25412566

2567+
if (v_rot) {
2568+
cur = ggml_mul_mat_aux(ctx0, cur, v_rot);
2569+
}
2570+
25422571
if (wo) {
25432572
cur = build_lora_mm(wo, cur, wo_s);
25442573
}

src/llama-graph.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,11 @@ class llm_graph_input_attn_src_kv_iswa : public llm_graph_input_i {
427427
ggml_tensor * self_kq_mask_swa = nullptr;
428428
ggml_tensor * self_kq_mask_swa_cnv = nullptr;
429429

430+
ggml_tensor * self_k_rot = nullptr;
431+
ggml_tensor * self_v_rot = nullptr;
432+
ggml_tensor * self_k_rot_swa = nullptr;
433+
ggml_tensor * self_v_rot_swa = nullptr;
434+
430435
const llama_hparams hparams;
431436
const llama_cparams cparams;
432437

tools/server/server-context.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -816,13 +816,23 @@ struct server_context_impl {
816816

817817
SRV_INF("loading draft model '%s'\n", params_spec.mparams.path.c_str());
818818

819+
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
820+
params_base.speculative.types.end(),
821+
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
822+
819823
auto params_dft = params_base;
820824

821825
params_dft.devices = params_spec.devices;
822826
params_dft.model = params_spec.mparams;
823827
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
824-
params_dft.cache_type_k = params_spec.cache_type_k;
825-
params_dft.cache_type_v = params_spec.cache_type_v;
828+
// TODO: find a better way to expose that the cache is shared
829+
if (spec_mtp) {
830+
params_dft.cache_type_k = params_base.cache_type_k;
831+
params_dft.cache_type_v = params_base.cache_type_v;
832+
} else {
833+
params_dft.cache_type_k = params_spec.cache_type_k;
834+
params_dft.cache_type_v = params_spec.cache_type_v;
835+
}
826836

827837
if (params_spec.cpuparams.n_threads > 0) {
828838
params_dft.cpuparams.n_threads = params_spec.cpuparams.n_threads;
@@ -841,9 +851,6 @@ struct server_context_impl {
841851

842852
auto cparams = common_context_params_to_llama(params_dft);
843853

844-
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
845-
params_base.speculative.types.end(),
846-
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
847854
if (spec_mtp) {
848855
cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
849856
}

0 commit comments

Comments
 (0)