@@ -894,31 +894,29 @@ ggml_tensor * llama_model_deepseek4::graph::build_attention(
894894 csa_state_score = ggml_add (ctx0, csa_state_score, csa_ape_rows);
895895 cb (csa_state_score, " csa_state_score_ape" , il);
896896
897- ggml_tensor * csa_state_dep = nullptr ;
898- if (inp_dsv4->get_csa ().state_write_idxs ) {
899- ggml_tensor * csa_source_kv = ggml_concat (ctx0,
900- inp_dsv4->mctx ->get_csa_state ()->get_kv (ctx0, il), csa_state_kv, 1 );
901- ggml_tensor * csa_source_score = ggml_concat (ctx0,
902- inp_dsv4->mctx ->get_csa_state ()->get_score (ctx0, il), csa_state_score, 1 );
903-
904- ggml_tensor * kv_comp_csa_state = build_overlap_compressed_kv_from_state (
905- csa_source_kv,
906- csa_source_score,
907- inp_dsv4->get_csa ().state_read_idxs ,
908- inp_dsv4->get_csa ().state_write_pos ,
909- layer.attn_comp_norm ,
910- DSV4_CSA_RATIO ,
911- n_embd_head,
912- " csa_state_compress" ,
913- il);
914-
915- ggml_build_forward_expand (gf, inp_dsv4->mctx ->get_csa ()->cpy_k (ctx0,
916- kv_comp_csa_state, inp_dsv4->get_csa ().state_write_idxs , il));
917- csa_state_dep = kv_comp_csa_state;
918- }
897+ GGML_ASSERT (inp_dsv4->get_csa ().state_write_idxs );
898+
899+ ggml_tensor * csa_source_kv = ggml_concat (ctx0,
900+ inp_dsv4->mctx ->get_csa_state ()->get_kv (ctx0, il), csa_state_kv, 1 );
901+ ggml_tensor * csa_source_score = ggml_concat (ctx0,
902+ inp_dsv4->mctx ->get_csa_state ()->get_score (ctx0, il), csa_state_score, 1 );
903+
904+ ggml_tensor * kv_comp_csa_state = build_overlap_compressed_kv_from_state (
905+ csa_source_kv,
906+ csa_source_score,
907+ inp_dsv4->get_csa ().state_read_idxs ,
908+ inp_dsv4->get_csa ().state_write_pos ,
909+ layer.attn_comp_norm ,
910+ DSV4_CSA_RATIO ,
911+ n_embd_head,
912+ " csa_state_compress" ,
913+ il);
919914
920- csa_state_kv = dsv4_with_zero_dep (ctx0, csa_state_kv, csa_state_dep);
921- csa_state_score = dsv4_with_zero_dep (ctx0, csa_state_score, csa_state_dep);
915+ ggml_build_forward_expand (gf, inp_dsv4->mctx ->get_csa ()->cpy_k (ctx0,
916+ kv_comp_csa_state, inp_dsv4->get_csa ().state_write_idxs , il));
917+
918+ csa_state_kv = dsv4_with_zero_dep (ctx0, csa_state_kv, kv_comp_csa_state);
919+ csa_state_score = dsv4_with_zero_dep (ctx0, csa_state_score, kv_comp_csa_state);
922920
923921 ggml_tensor * csa_persist_kv = ggml_get_rows (ctx0, csa_state_kv, inp_dsv4->get_csa ().state_persist_src_idxs );
924922 ggml_tensor * csa_persist_score = ggml_get_rows (ctx0, csa_state_score, inp_dsv4->get_csa ().state_persist_src_idxs );
@@ -946,36 +944,34 @@ ggml_tensor * llama_model_deepseek4::graph::build_attention(
946944 lid_state_score = ggml_add (ctx0, lid_state_score, lid_ape_rows);
947945 cb (lid_state_score, " lid_state_score_ape" , il);
948946
949- ggml_tensor * lid_state_dep = nullptr ;
950- if (inp_dsv4->get_lid ().state_write_idxs ) {
951- ggml_tensor * lid_source_kv = ggml_concat (ctx0,
952- inp_dsv4->mctx ->get_lid_state ()->get_kv (ctx0, il), lid_state_kv, 1 );
953- ggml_tensor * lid_source_score = ggml_concat (ctx0,
954- inp_dsv4->mctx ->get_lid_state ()->get_score (ctx0, il), lid_state_score, 1 );
955-
956- ggml_tensor * kv_comp_lid_state = build_overlap_compressed_kv_from_state (
957- lid_source_kv,
958- lid_source_score,
959- inp_dsv4->get_lid ().state_read_idxs ,
960- inp_dsv4->get_lid ().state_write_pos ,
961- layer.indexer_comp_norm ,
962- DSV4_CSA_RATIO ,
963- hparams.indexer_head_size ,
964- " lid_state_compress" ,
965- il);
966-
967- if (inp_dsv4->get_lid ().k_rot ) {
968- kv_comp_lid_state = ggml_mul_mat (ctx0, inp_dsv4->get_lid ().k_rot , kv_comp_lid_state);
969- cb (kv_comp_lid_state, " lid_state_compress_rot" , il);
970- }
947+ GGML_ASSERT (inp_dsv4->get_lid ().state_write_idxs );
948+
949+ ggml_tensor * lid_source_kv = ggml_concat (ctx0,
950+ inp_dsv4->mctx ->get_lid_state ()->get_kv (ctx0, il), lid_state_kv, 1 );
951+ ggml_tensor * lid_source_score = ggml_concat (ctx0,
952+ inp_dsv4->mctx ->get_lid_state ()->get_score (ctx0, il), lid_state_score, 1 );
953+
954+ ggml_tensor * kv_comp_lid_state = build_overlap_compressed_kv_from_state (
955+ lid_source_kv,
956+ lid_source_score,
957+ inp_dsv4->get_lid ().state_read_idxs ,
958+ inp_dsv4->get_lid ().state_write_pos ,
959+ layer.indexer_comp_norm ,
960+ DSV4_CSA_RATIO ,
961+ hparams.indexer_head_size ,
962+ " lid_state_compress" ,
963+ il);
971964
972- ggml_build_forward_expand (gf, inp_dsv4->mctx -> get_lid ()-> cpy_k (ctx0,
973- kv_comp_lid_state , inp_dsv4->get_lid ().state_write_idxs , il) );
974- lid_state_dep = kv_comp_lid_state ;
965+ if ( inp_dsv4->get_lid (). k_rot ) {
966+ kv_comp_lid_state = ggml_mul_mat (ctx0 , inp_dsv4->get_lid ().k_rot , kv_comp_lid_state );
967+ cb (kv_comp_lid_state, " lid_state_compress_rot " , il) ;
975968 }
976969
977- lid_state_kv = dsv4_with_zero_dep (ctx0, lid_state_kv, lid_state_dep);
978- lid_state_score = dsv4_with_zero_dep (ctx0, lid_state_score, lid_state_dep);
970+ ggml_build_forward_expand (gf, inp_dsv4->mctx ->get_lid ()->cpy_k (ctx0,
971+ kv_comp_lid_state, inp_dsv4->get_lid ().state_write_idxs , il));
972+
973+ lid_state_kv = dsv4_with_zero_dep (ctx0, lid_state_kv, kv_comp_lid_state);
974+ lid_state_score = dsv4_with_zero_dep (ctx0, lid_state_score, kv_comp_lid_state);
979975
980976 ggml_tensor * lid_persist_kv = ggml_get_rows (ctx0, lid_state_kv, inp_dsv4->get_lid ().state_persist_src_idxs );
981977 ggml_tensor * lid_persist_score = ggml_get_rows (ctx0, lid_state_score, inp_dsv4->get_lid ().state_persist_src_idxs );
0 commit comments