@@ -562,6 +562,19 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
562562void llm_graph_input_attn_src_kv_iswa::set_input (const llama_ubatch * ubatch) {
563563 src_mctx->get_base ()->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
564564 src_mctx->get_swa () ->set_input_kq_mask (self_kq_mask_swa, ubatch, cparams.causal_attn );
565+
566+ if (self_k_rot) {
567+ src_mctx->get_base ()->set_input_k_rot (self_k_rot);
568+ }
569+ if (self_v_rot) {
570+ src_mctx->get_base ()->set_input_v_rot (self_v_rot);
571+ }
572+ if (self_k_rot_swa) {
573+ src_mctx->get_swa ()->set_input_k_rot (self_k_rot_swa);
574+ }
575+ if (self_v_rot_swa) {
576+ src_mctx->get_swa ()->set_input_v_rot (self_v_rot_swa);
577+ }
565578}
566579
567580bool llm_graph_input_attn_src_kv_iswa::can_reuse (const llm_graph_params & params) {
@@ -2485,6 +2498,11 @@ llm_graph_input_attn_src_kv_iswa * llm_graph_context::build_attn_inp_src_kv_iswa
24852498 inp->self_kq_mask_swa = build_attn_inp_kq_mask (ctx0, src_iswa->get_swa (), ubatch, cparams);
24862499 inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast (ctx0, inp->self_kq_mask_swa , GGML_TYPE_F16) : inp->self_kq_mask_swa ;
24872500
2501+ inp->self_k_rot = src_iswa->get_base ()->build_input_k_rot (ctx0);
2502+ inp->self_v_rot = src_iswa->get_base ()->build_input_v_rot (ctx0);
2503+ inp->self_k_rot_swa = src_iswa->get_swa ()->build_input_k_rot (ctx0);
2504+ inp->self_v_rot_swa = src_iswa->get_swa ()->build_input_v_rot (ctx0);
2505+
24882506 return (llm_graph_input_attn_src_kv_iswa *) res->add_input (std::move (inp));
24892507}
24902508
@@ -2507,6 +2525,13 @@ ggml_tensor * llm_graph_context::build_attn(
25072525
25082526 const auto & kq_mask = is_swa ? inp->get_kq_mask_swa () : inp->get_kq_mask ();
25092527
2528+ auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot ;
2529+ auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot ;
2530+
2531+ if (k_rot) {
2532+ q_cur = ggml_mul_mat_aux (ctx0, q_cur, k_rot);
2533+ }
2534+
25102535 ggml_build_forward_expand (gf, q_cur);
25112536
25122537 ggml_tensor * q = q_cur;
@@ -2539,6 +2564,10 @@ ggml_tensor * llm_graph_context::build_attn(
25392564 ggml_tensor * cur = build_attn_mha (q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il_assist);
25402565 cb (cur, " kqv_out" , il_assist);
25412566
2567+ if (v_rot) {
2568+ cur = ggml_mul_mat_aux (ctx0, cur, v_rot);
2569+ }
2570+
25422571 if (wo) {
25432572 cur = build_lora_mm (wo, cur, wo_s);
25442573 }
0 commit comments