@@ -683,6 +683,48 @@ static void dsv4_set_kq_mask(
683683 }
684684}
685685
686+ static ggml_tensor * dsv4_build_raw_kq_mask (
687+ ggml_context * ctx,
688+ const llama_kv_cache_dsv4_raw_context * mctx,
689+ const llama_ubatch & ubatch,
690+ const llama_cparams & cparams,
691+ int64_t n_stream) {
692+ const auto n_kv = mctx->get_n_kv ();
693+ const auto n_tokens = ubatch.n_tokens ;
694+
695+ GGML_ASSERT (n_stream > 0 );
696+ GGML_ASSERT (n_tokens%n_stream == 0 );
697+
698+ const bool use_fattn = cparams.flash_attn && (!cparams.kv_unified || n_stream == 1 );
699+ const auto type = use_fattn ? GGML_TYPE_F16 : GGML_TYPE_F32 ;
700+
701+ ggml_tensor * res = ggml_new_tensor_4d (ctx, type, n_kv, n_tokens/n_stream, 1 , n_stream);
702+ ggml_set_input (res);
703+ ggml_set_name (res, " attn_inp_kq_mask" );
704+
705+ return res;
706+ }
707+
708+ static bool dsv4_can_reuse_raw_kq_mask (
709+ ggml_tensor * kq_mask,
710+ const llama_kv_cache_dsv4_raw_context * mctx,
711+ const llama_ubatch & ubatch,
712+ int64_t n_stream) {
713+ const auto n_kv = mctx->get_n_kv ();
714+ const auto n_tokens = ubatch.n_tokens ;
715+
716+ GGML_ASSERT (n_stream > 0 );
717+
718+ bool res = true ;
719+
720+ res &= (kq_mask->ne [0 ] == n_kv);
721+ res &= (kq_mask->ne [1 ] == n_tokens/n_stream);
722+ res &= (kq_mask->ne [2 ] == 1 );
723+ res &= (kq_mask->ne [3 ] == n_stream);
724+
725+ return res;
726+ }
727+
686728static std::string dsv4_plan_positions (const std::vector<int32_t > & values) {
687729 std::ostringstream ss;
688730 ss << " [" ;
@@ -808,15 +850,32 @@ static void dsv4_build_comp_inputs(
808850 }
809851}
810852
853+ void llm_graph_input_dsv4_raw::set_input (const llama_ubatch * ubatch) {
854+ if (self_k_idxs && self_k_idxs->buffer ) {
855+ mctx->set_input_k_idxs (self_k_idxs);
856+ }
857+
858+ if (self_kq_mask && self_kq_mask->buffer ) {
859+ mctx->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
860+ }
861+
862+ if (self_k_rot) {
863+ mctx->set_input_k_rot (self_k_rot);
864+ }
865+ }
866+
811867void llm_graph_input_dsv4::set_input (const llama_ubatch * ubatch) {
868+ const auto & plan_csa = mctx->get_csa_plan (*ubatch);
869+ const auto & plan_hca = mctx->get_hca_plan (*ubatch);
870+ const auto & plan_lid = mctx->get_lid_plan (*ubatch);
871+ const int64_t n_stream = plan_csa.n_stream ;
872+
812873 inp_raw->mctx = mctx->get_raw ();
813874 inp_raw->set_input (ubatch);
814875
815- const int64_t n_stream = cparams.kv_unified ? 1 : ubatch->n_seqs_unq ;
816-
817- dsv4_set_comp_inputs (inp_csa, mctx->get_csa_plan (*ubatch), " csa" , debug > 0 , ubatch->n_tokens , n_stream);
818- dsv4_set_comp_inputs (inp_hca, mctx->get_hca_plan (*ubatch), " hca" , debug > 0 , ubatch->n_tokens , n_stream);
819- dsv4_set_comp_inputs (inp_lid, mctx->get_lid_plan (*ubatch), " lid" , debug > 0 , ubatch->n_tokens , n_stream);
876+ dsv4_set_comp_inputs (inp_csa, plan_csa, " csa" , debug > 0 , ubatch->n_tokens , n_stream);
877+ dsv4_set_comp_inputs (inp_hca, plan_hca, " hca" , debug > 0 , ubatch->n_tokens , n_stream);
878+ dsv4_set_comp_inputs (inp_lid, plan_lid, " lid" , debug > 0 , ubatch->n_tokens , n_stream);
820879
821880 if (inp_lid.k_rot && inp_lid.k_rot ->buffer ) {
822881 mctx->get_lid ()->set_input_k_rot (inp_lid.k_rot );
@@ -831,15 +890,24 @@ bool llm_graph_input_dsv4::can_reuse(const llm_graph_params & params) {
831890
832891 bool res = true ;
833892
834- llm_graph_params raw_params = params;
835- raw_params.mctx = mctx->get_raw ();
836- res &= inp_raw->can_reuse (raw_params);
893+ const auto & plan_csa = mctx->get_csa_plan (params.ubatch );
894+ const auto & plan_hca = mctx->get_hca_plan (params.ubatch );
895+ const auto & plan_lid = mctx->get_lid_plan (params.ubatch );
896+ const int64_t n_stream = plan_csa.n_stream ;
897+
898+ const auto * raw_ctx = mctx->get_raw ();
899+ inp_raw->mctx = raw_ctx;
837900
838- const int64_t n_stream = params.cparams .kv_unified ? 1 : params.ubatch .n_seqs_unq ;
901+ if (inp_raw->self_k_idxs && inp_raw->self_k_idxs ->buffer ) {
902+ res &= inp_raw->self_k_idxs ->ne [0 ] == raw_ctx->get_n_write ();
903+ }
904+ if (inp_raw->self_kq_mask && inp_raw->self_kq_mask ->buffer ) {
905+ res &= dsv4_can_reuse_raw_kq_mask (inp_raw->self_kq_mask , raw_ctx, params.ubatch , n_stream);
906+ }
839907
840- res &= dsv4_can_reuse_comp_input (inp_csa, mctx-> get_csa_plan (params. ubatch ) , params.ubatch .n_tokens , n_stream);
841- res &= dsv4_can_reuse_comp_input (inp_hca, mctx-> get_hca_plan (params. ubatch ) , params.ubatch .n_tokens , n_stream);
842- res &= dsv4_can_reuse_comp_input (inp_lid, mctx-> get_lid_plan (params. ubatch ) , params.ubatch .n_tokens , n_stream);
908+ res &= dsv4_can_reuse_comp_input (inp_csa, plan_csa , params.ubatch .n_tokens , n_stream);
909+ res &= dsv4_can_reuse_comp_input (inp_hca, plan_hca , params.ubatch .n_tokens , n_stream);
910+ res &= dsv4_can_reuse_comp_input (inp_lid, plan_lid , params.ubatch .n_tokens , n_stream);
843911
844912 return res;
845913}
@@ -2995,28 +3063,19 @@ llm_graph_input_dsv4 * llm_graph_context::build_inp_dsv4() const {
29953063 const auto * mctx_cur = static_cast <const llama_kv_cache_dsv4_context *>(mctx);
29963064 const auto * raw_ctx = mctx_cur->get_raw ();
29973065
2998- auto inp_raw = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, raw_ctx);
3066+ auto inp_raw = std::make_unique<llm_graph_input_dsv4_raw>( cparams, raw_ctx);
29993067
3000- {
3001- inp_raw->self_k_idxs = raw_ctx->get_base ()->build_input_k_idxs (ctx0, ubatch);
3002- inp_raw->self_kq_mask = build_attn_inp_kq_mask (ctx0, raw_ctx->get_base (), ubatch, cparams);
3003- inp_raw->self_kq_mask_cnv = inp_raw->self_kq_mask ;
3004- }
3068+ const int64_t n_stream = mctx_cur->get_csa_plan (ubatch).n_stream ;
30053069
3006- {
3007- GGML_ASSERT (hparams.swa_type != LLAMA_SWA_TYPE_NONE && " DSV4 expects SWA raw cache" );
3070+ GGML_ASSERT (hparams.swa_type != LLAMA_SWA_TYPE_NONE && " DSV4 expects SWA raw cache" );
30083071
3009- inp_raw->self_k_idxs_swa = raw_ctx->get_swa ()->build_input_k_idxs (ctx0, ubatch);
3010- inp_raw->self_kq_mask_swa = build_attn_inp_kq_mask (ctx0, raw_ctx->get_swa (), ubatch, cparams);
3011- inp_raw->self_kq_mask_swa_cnv = inp_raw->self_kq_mask_swa ;
3012- }
3072+ inp_raw->self_k_idxs = raw_ctx->build_input_k_idxs (ctx0, ubatch);
3073+ inp_raw->self_kq_mask = dsv4_build_raw_kq_mask (ctx0, raw_ctx, ubatch, cparams, n_stream);
3074+ inp_raw->self_kq_mask_cnv = inp_raw->self_kq_mask ;
30133075
3014- inp_raw->self_k_rot = raw_ctx->get_base ()->build_input_k_rot (ctx0);
3015- inp_raw->self_k_rot_swa = raw_ctx->get_swa ()->build_input_k_rot (ctx0);
3076+ inp_raw->self_k_rot = raw_ctx->build_input_k_rot (ctx0);
30163077 auto inp = std::make_unique<llm_graph_input_dsv4>(cparams, std::move (inp_raw), mctx_cur);
30173078
3018- const int64_t n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq ;
3019-
30203079 dsv4_build_comp_inputs (ctx0, inp->inp_csa , mctx_cur->get_csa_plan (ubatch), " csa" , n_stream);
30213080 dsv4_build_comp_inputs (ctx0, inp->inp_hca , mctx_cur->get_hca_plan (ubatch), " hca" , n_stream);
30223081 dsv4_build_comp_inputs (ctx0, inp->inp_lid , mctx_cur->get_lid_plan (ubatch), " lid" , n_stream);
0 commit comments