@@ -511,6 +511,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
511511 if (self_v_rot) {
512512 mctx->get_base ()->set_input_v_rot (self_v_rot);
513513 }
514+
515+ if (self_k_rot_swa) {
516+ mctx->get_swa ()->set_input_k_rot (self_k_rot_swa);
517+ }
518+
519+ if (self_v_rot_swa) {
520+ mctx->get_swa ()->set_input_v_rot (self_v_rot_swa);
521+ }
514522}
515523
516524bool llm_graph_input_attn_kv_iswa::can_reuse (const llm_graph_params & params) {
@@ -681,6 +689,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
681689 attn_ctx->get_base ()->set_input_v_rot (inp_attn->self_v_rot );
682690 }
683691
692+ if (inp_attn->self_k_rot_swa ) {
693+ attn_ctx->get_swa ()->set_input_k_rot (inp_attn->self_k_rot_swa );
694+ }
695+
696+ if (inp_attn->self_v_rot_swa ) {
697+ attn_ctx->get_swa ()->set_input_v_rot (inp_attn->self_v_rot_swa );
698+ }
699+
684700 const int64_t n_rs = mctx->get_recr ()->get_n_rs ();
685701
686702 if (inp_rs->s_copy ) {
@@ -2242,15 +2258,20 @@ ggml_tensor * llm_graph_context::build_attn(
22422258 ggml_tensor * v_mla,
22432259 float kq_scale,
22442260 int il) const {
2245- if (inp->self_k_rot ) {
2246- q_cur = ggml_mul_mat_aux (ctx0, q_cur, inp->self_k_rot );
2261+ const bool is_swa = hparams.is_swa (il);
2262+
2263+ auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot ;
2264+ auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot ;
2265+
2266+ if (k_rot) {
2267+ q_cur = ggml_mul_mat_aux (ctx0, q_cur, k_rot);
22472268 if (k_cur) {
2248- k_cur = ggml_mul_mat_aux (ctx0, k_cur, inp-> self_k_rot );
2269+ k_cur = ggml_mul_mat_aux (ctx0, k_cur, k_rot );
22492270 }
22502271 }
2251- if (inp-> self_v_rot ) {
2272+ if (v_rot ) {
22522273 if (v_cur) {
2253- v_cur = ggml_mul_mat_aux (ctx0, v_cur, inp-> self_v_rot );
2274+ v_cur = ggml_mul_mat_aux (ctx0, v_cur, v_rot );
22542275 }
22552276 }
22562277
@@ -2268,8 +2289,6 @@ ggml_tensor * llm_graph_context::build_attn(
22682289
22692290 const auto * mctx_iswa = inp->mctx ;
22702291
2271- const bool is_swa = hparams.is_swa (il);
2272-
22732292 const auto * mctx_cur = is_swa ? mctx_iswa->get_swa () : mctx_iswa->get_base ();
22742293
22752294 // optionally store to KV cache
@@ -2294,8 +2313,8 @@ ggml_tensor * llm_graph_context::build_attn(
22942313 ggml_tensor * cur = build_attn_mha (q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
22952314 cb (cur, " kqv_out" , il);
22962315
2297- if (inp-> self_v_rot ) {
2298- cur = ggml_mul_mat_aux (ctx0, cur, inp-> self_v_rot );
2316+ if (v_rot ) {
2317+ cur = ggml_mul_mat_aux (ctx0, cur, v_rot );
22992318 }
23002319
23012320 if (wo) {
@@ -2397,6 +2416,9 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
23972416 inp->self_k_rot = mctx_cur->get_base ()->build_input_k_rot (ctx0);
23982417 inp->self_v_rot = mctx_cur->get_base ()->build_input_v_rot (ctx0);
23992418
2419+ inp->self_k_rot_swa = mctx_cur->get_swa ()->build_input_k_rot (ctx0);
2420+ inp->self_v_rot_swa = mctx_cur->get_swa ()->build_input_v_rot (ctx0);
2421+
24002422 return (llm_graph_input_attn_kv_iswa *) res->add_input (std::move (inp));
24012423}
24022424
0 commit comments