@@ -511,6 +511,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
511511 if (self_v_rot) {
512512 mctx->get_base ()->set_input_v_rot (self_v_rot);
513513 }
514+
515+ if (self_k_rot_swa) {
516+ mctx->get_swa ()->set_input_k_rot (self_k_rot_swa);
517+ }
518+
519+ if (self_v_rot_swa) {
520+ mctx->get_swa ()->set_input_v_rot (self_v_rot_swa);
521+ }
514522}
515523
516524bool llm_graph_input_attn_kv_iswa::can_reuse (const llm_graph_params & params) {
@@ -681,6 +689,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
681689 attn_ctx->get_base ()->set_input_v_rot (inp_attn->self_v_rot );
682690 }
683691
692+ if (inp_attn->self_k_rot_swa ) {
693+ attn_ctx->get_swa ()->set_input_k_rot (inp_attn->self_k_rot_swa );
694+ }
695+
696+ if (inp_attn->self_v_rot_swa ) {
697+ attn_ctx->get_swa ()->set_input_v_rot (inp_attn->self_v_rot_swa );
698+ }
699+
684700 const int64_t n_rs = mctx->get_recr ()->get_n_rs ();
685701
686702 if (inp_rs->s_copy ) {
@@ -2233,15 +2249,20 @@ ggml_tensor * llm_graph_context::build_attn(
22332249 ggml_tensor * v_mla,
22342250 float kq_scale,
22352251 int il) const {
2236- if (inp->self_k_rot ) {
2237- q_cur = ggml_mul_mat_aux (ctx0, q_cur, inp->self_k_rot );
2252+ const bool is_swa = hparams.is_swa (il);
2253+
2254+ auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot ;
2255+ auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot ;
2256+
2257+ if (k_rot) {
2258+ q_cur = ggml_mul_mat_aux (ctx0, q_cur, k_rot);
22382259 if (k_cur) {
2239- k_cur = ggml_mul_mat_aux (ctx0, k_cur, inp-> self_k_rot );
2260+ k_cur = ggml_mul_mat_aux (ctx0, k_cur, k_rot );
22402261 }
22412262 }
2242- if (inp-> self_v_rot ) {
2263+ if (v_rot ) {
22432264 if (v_cur) {
2244- v_cur = ggml_mul_mat_aux (ctx0, v_cur, inp-> self_v_rot );
2265+ v_cur = ggml_mul_mat_aux (ctx0, v_cur, v_rot );
22452266 }
22462267 }
22472268
@@ -2259,8 +2280,6 @@ ggml_tensor * llm_graph_context::build_attn(
22592280
22602281 const auto * mctx_iswa = inp->mctx ;
22612282
2262- const bool is_swa = hparams.is_swa (il);
2263-
22642283 const auto * mctx_cur = is_swa ? mctx_iswa->get_swa () : mctx_iswa->get_base ();
22652284
22662285 // optionally store to KV cache
@@ -2285,8 +2304,8 @@ ggml_tensor * llm_graph_context::build_attn(
22852304 ggml_tensor * cur = build_attn_mha (q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
22862305 cb (cur, " kqv_out" , il);
22872306
2288- if (inp-> self_v_rot ) {
2289- cur = ggml_mul_mat_aux (ctx0, cur, inp-> self_v_rot );
2307+ if (v_rot ) {
2308+ cur = ggml_mul_mat_aux (ctx0, cur, v_rot );
22902309 }
22912310
22922311 if (wo) {
@@ -2388,6 +2407,9 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
23882407 inp->self_k_rot = mctx_cur->get_base ()->build_input_k_rot (ctx0);
23892408 inp->self_v_rot = mctx_cur->get_base ()->build_input_v_rot (ctx0);
23902409
2410+ inp->self_k_rot_swa = mctx_cur->get_swa ()->build_input_k_rot (ctx0);
2411+ inp->self_v_rot_swa = mctx_cur->get_swa ()->build_input_v_rot (ctx0);
2412+
23912413 return (llm_graph_input_attn_kv_iswa *) res->add_input (std::move (inp));
23922414}
23932415
0 commit comments