Skip to content

Commit 6b7855d

Browse files
committed
add llm_graph_input_dsv4
1 parent 07c3713 commit 6b7855d

9 files changed

Lines changed: 2323 additions & 12 deletions

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ add_library(llama
2525
llama-kv-cache.cpp
2626
llama-kv-cache-iswa.cpp
2727
llama-kv-cache-dsa.cpp
28+
llama-kv-cache-dsv4.cpp
2829
llama-memory.cpp
2930
llama-memory-hybrid.cpp
3031
llama-memory-hybrid-iswa.cpp

src/llama-graph.cpp

Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "llama-kv-cache.h"
99
#include "llama-kv-cache-iswa.h"
1010
#include "llama-kv-cache-dsa.h"
11+
#include "llama-kv-cache-dsv4.h"
1112
#include "llama-memory-hybrid.h"
1213
#include "llama-memory-hybrid-iswa.h"
1314
#include "llama-memory-recurrent.h"
@@ -17,6 +18,7 @@
1718
#include <cstring>
1819
#include <numeric>
1920
#include <sstream>
21+
#include <string>
2022
#include <unordered_set>
2123

2224
// dedup helpers
@@ -620,6 +622,223 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
620622
return res;
621623
}
622624

625+
static void dsv4_set_i64(ggml_tensor * dst, const std::vector<int64_t> & src) {
626+
if (!dst || !dst->buffer) {
627+
return;
628+
}
629+
630+
GGML_ASSERT(dst->ne[0] == (int64_t) src.size());
631+
ggml_backend_tensor_set(dst, src.data(), 0, src.size()*ggml_element_size(dst));
632+
}
633+
634+
static void dsv4_set_i32(ggml_tensor * dst, const std::vector<int32_t> & src) {
635+
if (!dst || !dst->buffer) {
636+
return;
637+
}
638+
639+
GGML_ASSERT(dst->ne[0] == (int64_t) src.size());
640+
ggml_backend_tensor_set(dst, src.data(), 0, src.size()*ggml_element_size(dst));
641+
}
642+
643+
static void dsv4_set_kq_mask(
644+
ggml_tensor * dst,
645+
const llama_kv_cache_dsv4_context::comp_plan & plan,
646+
uint32_t n_tokens) {
647+
if (!dst || !dst->buffer) {
648+
return;
649+
}
650+
651+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
652+
GGML_ASSERT(dst->ne[0] == plan.n_kv);
653+
GGML_ASSERT(dst->ne[1] == (int64_t) n_tokens);
654+
GGML_ASSERT(dst->ne[2] == 1);
655+
GGML_ASSERT(dst->ne[3] == 1);
656+
GGML_ASSERT((int64_t) plan.n_visible.size() == dst->ne[1]);
657+
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
658+
659+
float * data = (float *) dst->data;
660+
661+
for (int64_t i = 0; i < dst->ne[1]; ++i) {
662+
const int32_t n_visible = plan.n_visible[i];
663+
664+
for (int64_t j = 0; j < dst->ne[0]; ++j) {
665+
data[i*dst->ne[0] + j] = j < n_visible ? 0.0f : -INFINITY;
666+
}
667+
}
668+
}
669+
670+
static std::string dsv4_plan_positions(const std::vector<int32_t> & values) {
671+
std::ostringstream ss;
672+
ss << "[";
673+
for (size_t i = 0; i < values.size(); ++i) {
674+
if (i > 0) {
675+
ss << ", ";
676+
}
677+
ss << values[i];
678+
}
679+
ss << "]";
680+
return ss.str();
681+
}
682+
683+
static bool dsv4_compress_debug() {
684+
static const bool debug = []() {
685+
const char * env = getenv("LLAMA_DSV4_COMPRESS_DEBUG");
686+
return env && atoi(env) > 0;
687+
}();
688+
689+
return debug;
690+
}
691+
692+
static void dsv4_set_comp_inputs(
693+
const llm_graph_input_dsv4::comp_input & inp,
694+
const llama_kv_cache_dsv4_context::comp_plan & plan,
695+
const char * name,
696+
bool debug,
697+
uint32_t n_tokens) {
698+
dsv4_set_i64(inp.write_idxs, plan.write_idxs);
699+
dsv4_set_i32(inp.write_pos, plan.write_pos);
700+
dsv4_set_i32(inp.write_end, plan.write_end);
701+
dsv4_set_i32(inp.pending_end, plan.pending_end);
702+
dsv4_set_i32(inp.state_idxs, plan.state_idxs);
703+
dsv4_set_i32(inp.state_pos, plan.state_pos);
704+
dsv4_set_i32(inp.state_read_idxs, plan.state_read_idxs);
705+
dsv4_set_i64(inp.state_write_idxs, plan.state_write_idxs);
706+
dsv4_set_i32(inp.state_write_pos, plan.state_write_pos);
707+
dsv4_set_i32(inp.state_write_end, plan.state_write_end);
708+
dsv4_set_i32(inp.n_visible, plan.n_visible);
709+
dsv4_set_kq_mask(inp.kq_mask, plan, n_tokens);
710+
711+
if (debug || dsv4_compress_debug()) {
712+
LLAMA_LOG_INFO("%s: %s ratio=%u, n_tokens=%u, write_end=%s, state_write_end=%s, pending_end=%s\n",
713+
__func__, name, plan.ratio, n_tokens,
714+
dsv4_plan_positions(plan.write_end).c_str(),
715+
dsv4_plan_positions(plan.state_write_end).c_str(),
716+
dsv4_plan_positions(plan.pending_end).c_str());
717+
}
718+
}
719+
720+
static bool dsv4_can_reuse_tensor_1d(ggml_tensor * t, int64_t ne0) {
721+
return (t == nullptr && ne0 == 0) || (t != nullptr && t->ne[0] == ne0);
722+
}
723+
724+
static bool dsv4_can_reuse_kq_mask(
725+
ggml_tensor * t,
726+
const llama_kv_cache_dsv4_context::comp_plan & plan,
727+
uint32_t n_tokens) {
728+
if (plan.n_kv == 0) {
729+
return t == nullptr;
730+
}
731+
732+
return t != nullptr &&
733+
t->ne[0] == plan.n_kv &&
734+
t->ne[1] == (int64_t) n_tokens &&
735+
t->ne[2] == 1 &&
736+
t->ne[3] == 1;
737+
}
738+
739+
static bool dsv4_can_reuse_comp_input(
740+
const llm_graph_input_dsv4::comp_input & inp,
741+
const llama_kv_cache_dsv4_context::comp_plan & plan,
742+
uint32_t n_tokens) {
743+
const int64_t n_write = plan.write_idxs.size();
744+
745+
bool res = true;
746+
res &= dsv4_can_reuse_tensor_1d(inp.write_idxs, n_write);
747+
res &= dsv4_can_reuse_tensor_1d(inp.write_pos, n_write);
748+
res &= dsv4_can_reuse_tensor_1d(inp.write_end, n_write);
749+
res &= dsv4_can_reuse_tensor_1d(inp.pending_end, plan.pending_end.size());
750+
res &= dsv4_can_reuse_tensor_1d(inp.state_idxs, plan.state_idxs.size());
751+
res &= dsv4_can_reuse_tensor_1d(inp.state_pos, plan.state_pos.size());
752+
res &= dsv4_can_reuse_tensor_1d(inp.state_read_idxs, plan.state_read_idxs.size());
753+
res &= dsv4_can_reuse_tensor_1d(inp.state_write_idxs, plan.state_write_idxs.size());
754+
res &= dsv4_can_reuse_tensor_1d(inp.state_write_pos, plan.state_write_pos.size());
755+
res &= dsv4_can_reuse_tensor_1d(inp.state_write_end, plan.state_write_end.size());
756+
res &= dsv4_can_reuse_tensor_1d(inp.n_visible, plan.n_visible.size());
757+
res &= dsv4_can_reuse_kq_mask(inp.kq_mask, plan, n_tokens);
758+
759+
return res;
760+
}
761+
762+
static ggml_tensor * dsv4_build_input_1d(
763+
ggml_context * ctx,
764+
ggml_type type,
765+
int64_t ne0,
766+
const std::string & name) {
767+
if (ne0 == 0) {
768+
return nullptr;
769+
}
770+
771+
ggml_tensor * res = ggml_new_tensor_1d(ctx, type, ne0);
772+
ggml_set_input(res);
773+
ggml_set_name(res, name.c_str());
774+
775+
return res;
776+
}
777+
778+
static void dsv4_build_comp_inputs(
779+
ggml_context * ctx,
780+
llm_graph_input_dsv4::comp_input & inp,
781+
const llama_kv_cache_dsv4_context::comp_plan & plan,
782+
const char * name) {
783+
const int64_t n_write = plan.write_idxs.size();
784+
785+
inp.write_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I64, n_write, std::string("dsv4_") + name + "_write_idxs");
786+
inp.write_pos = dsv4_build_input_1d(ctx, GGML_TYPE_I32, n_write, std::string("dsv4_") + name + "_write_pos");
787+
inp.write_end = dsv4_build_input_1d(ctx, GGML_TYPE_I32, n_write, std::string("dsv4_") + name + "_write_end");
788+
inp.pending_end = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.pending_end.size(), std::string("dsv4_") + name + "_pending_end");
789+
inp.state_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_idxs.size(), std::string("dsv4_") + name + "_state_idxs");
790+
inp.state_pos = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_pos.size(), std::string("dsv4_") + name + "_state_pos");
791+
inp.state_read_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_read_idxs.size(), std::string("dsv4_") + name + "_state_read_idxs");
792+
inp.state_write_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I64, plan.state_write_idxs.size(), std::string("dsv4_") + name + "_state_write_idxs");
793+
inp.state_write_pos = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_write_pos.size(), std::string("dsv4_") + name + "_state_write_pos");
794+
inp.state_write_end = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_write_end.size(), std::string("dsv4_") + name + "_state_write_end");
795+
inp.n_visible = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.n_visible.size(), std::string("dsv4_") + name + "_n_visible");
796+
797+
if (plan.n_kv > 0) {
798+
inp.kq_mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, plan.n_kv, plan.n_visible.size(), 1, 1);
799+
ggml_set_input(inp.kq_mask);
800+
ggml_set_name(inp.kq_mask, (std::string("dsv4_") + name + "_kq_mask").c_str());
801+
}
802+
}
803+
804+
void llm_graph_input_dsv4::set_input(const llama_ubatch * ubatch) {
805+
inp_raw->mctx = mctx->get_raw();
806+
inp_raw->set_input(ubatch);
807+
808+
dsv4_set_comp_inputs(inp_csa, mctx->get_csa_plan(), "csa", debug > 0, ubatch->n_tokens);
809+
dsv4_set_comp_inputs(inp_hca, mctx->get_hca_plan(), "hca", debug > 0, ubatch->n_tokens);
810+
dsv4_set_comp_inputs(inp_lid, mctx->get_lid_plan(), "lid", debug > 0, ubatch->n_tokens);
811+
812+
if (inp_lid.k_rot && inp_lid.k_rot->buffer) {
813+
mctx->get_lid()->set_input_k_rot(inp_lid.k_rot);
814+
}
815+
}
816+
817+
bool llm_graph_input_dsv4::can_reuse(const llm_graph_params & params) {
818+
const auto * mctx = static_cast<const llama_kv_cache_dsv4_context *>(params.mctx);
819+
820+
this->mctx = mctx;
821+
inp_raw->mctx = mctx->get_raw();
822+
823+
bool res = true;
824+
825+
if (inp_raw->self_k_idxs && inp_raw->self_k_idxs->buffer) {
826+
res &= inp_raw->self_k_idxs->ne[0] == params.ubatch.n_tokens;
827+
res &= can_reuse_kq_mask(inp_raw->self_kq_mask, mctx->get_raw()->get_base(), params.ubatch, params.cparams);
828+
}
829+
830+
if (inp_raw->self_k_idxs_swa && inp_raw->self_k_idxs_swa->buffer) {
831+
res &= inp_raw->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
832+
res &= can_reuse_kq_mask(inp_raw->self_kq_mask_swa, mctx->get_raw()->get_swa(), params.ubatch, params.cparams);
833+
}
834+
835+
res &= dsv4_can_reuse_comp_input(inp_csa, mctx->get_csa_plan(), params.ubatch.n_tokens);
836+
res &= dsv4_can_reuse_comp_input(inp_hca, mctx->get_hca_plan(), params.ubatch.n_tokens);
837+
res &= dsv4_can_reuse_comp_input(inp_lid, mctx->get_lid_plan(), params.ubatch.n_tokens);
838+
839+
return res;
840+
}
841+
623842
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
624843
GGML_ASSERT(cross_kq_mask);
625844

@@ -2731,6 +2950,46 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
27312950
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
27322951
}
27332952

2953+
llm_graph_input_dsv4 * llm_graph_context::build_inp_dsv4() const {
2954+
const auto * mctx_cur = static_cast<const llama_kv_cache_dsv4_context *>(mctx);
2955+
const auto * raw_ctx = mctx_cur->get_raw();
2956+
2957+
auto inp_raw = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, raw_ctx);
2958+
2959+
{
2960+
inp_raw->self_k_idxs = raw_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
2961+
inp_raw->self_v_idxs = raw_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
2962+
2963+
inp_raw->self_kq_mask = build_attn_inp_kq_mask(ctx0, raw_ctx->get_base(), ubatch, cparams);
2964+
inp_raw->self_kq_mask_cnv = inp_raw->self_kq_mask;
2965+
}
2966+
2967+
{
2968+
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "DSV4 expects SWA raw cache");
2969+
2970+
inp_raw->self_k_idxs_swa = raw_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
2971+
inp_raw->self_v_idxs_swa = raw_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
2972+
2973+
inp_raw->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, raw_ctx->get_swa(), ubatch, cparams);
2974+
inp_raw->self_kq_mask_swa_cnv = inp_raw->self_kq_mask_swa;
2975+
}
2976+
2977+
inp_raw->self_k_rot = raw_ctx->get_base()->build_input_k_rot(ctx0);
2978+
inp_raw->self_v_rot = raw_ctx->get_base()->build_input_v_rot(ctx0);
2979+
2980+
inp_raw->self_k_rot_swa = raw_ctx->get_swa()->build_input_k_rot(ctx0);
2981+
inp_raw->self_v_rot_swa = raw_ctx->get_swa()->build_input_v_rot(ctx0);
2982+
2983+
auto inp = std::make_unique<llm_graph_input_dsv4>(cparams, std::move(inp_raw), mctx_cur);
2984+
2985+
dsv4_build_comp_inputs(ctx0, inp->inp_csa, mctx_cur->get_csa_plan(), "csa");
2986+
dsv4_build_comp_inputs(ctx0, inp->inp_hca, mctx_cur->get_hca_plan(), "hca");
2987+
dsv4_build_comp_inputs(ctx0, inp->inp_lid, mctx_cur->get_lid_plan(), "lid");
2988+
inp->inp_lid.k_rot = mctx_cur->get_lid()->build_input_k_rot(ctx0);
2989+
2990+
return (llm_graph_input_dsv4 *) res->add_input(std::move(inp));
2991+
}
2992+
27342993
ggml_tensor * llm_graph_context::build_rs(
27352994
ggml_tensor * s,
27362995
ggml_tensor * state_copy_main,

src/llama-graph.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ struct llama_memory_context_i;
2323

2424
class llama_kv_cache_context;
2525
class llama_kv_cache_dsa_context;
26+
class llama_kv_cache_dsv4_context;
2627
class llama_kv_cache_iswa_context;
2728
class llama_memory_recurrent_context;
2829
class llama_memory_hybrid_context;
@@ -458,6 +459,57 @@ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
458459
const llama_kv_cache_iswa_context * mctx;
459460
};
460461

462+
class llm_graph_input_dsv4 : public llm_graph_input_i {
463+
public:
464+
struct comp_input {
465+
ggml_tensor * write_idxs = nullptr; // I64 [n_write]
466+
ggml_tensor * write_pos = nullptr; // I32 [n_write]
467+
ggml_tensor * write_end = nullptr; // I32 [n_write]
468+
ggml_tensor * pending_end = nullptr; // I32 [n_pending]
469+
470+
ggml_tensor * state_idxs = nullptr; // I32 [n_state]
471+
ggml_tensor * state_pos = nullptr; // I32 [n_state]
472+
ggml_tensor * state_read_idxs = nullptr; // I32 [ratio*n_state_write]
473+
ggml_tensor * state_write_idxs = nullptr; // I64 [n_state_write]
474+
ggml_tensor * state_write_pos = nullptr; // I32 [n_state_write]
475+
ggml_tensor * state_write_end = nullptr; // I32 [n_state_write]
476+
477+
ggml_tensor * n_visible = nullptr; // I32 [n_batch]
478+
ggml_tensor * kq_mask = nullptr; // F32 [n_kv, n_batch]
479+
480+
ggml_tensor * k_rot = nullptr;
481+
};
482+
483+
llm_graph_input_dsv4(
484+
const llama_cparams & cparams,
485+
std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_raw,
486+
const llama_kv_cache_dsv4_context * mctx) :
487+
inp_raw(std::move(inp_raw)),
488+
cparams(cparams),
489+
mctx(mctx) {
490+
}
491+
~llm_graph_input_dsv4() = default;
492+
493+
void set_input(const llama_ubatch * ubatch) override;
494+
495+
bool can_reuse(const llm_graph_params & params) override;
496+
497+
llm_graph_input_attn_kv_iswa * get_raw() const { return inp_raw.get(); }
498+
const comp_input & get_csa() const { return inp_csa; }
499+
const comp_input & get_hca() const { return inp_hca; }
500+
const comp_input & get_lid() const { return inp_lid; }
501+
502+
std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_raw;
503+
504+
comp_input inp_csa;
505+
comp_input inp_hca;
506+
comp_input inp_lid;
507+
508+
const llama_cparams cparams;
509+
510+
const llama_kv_cache_dsv4_context * mctx;
511+
};
512+
461513
class llm_graph_input_attn_cross : public llm_graph_input_i {
462514
public:
463515
llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
@@ -1033,6 +1085,8 @@ struct llm_graph_context {
10331085

10341086
llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
10351087

1088+
llm_graph_input_dsv4 * build_inp_dsv4() const;
1089+
10361090
// note: if k_cur or v_cur are not provided, they will not be stored in the memory
10371091
ggml_tensor * build_attn(
10381092
llm_graph_input_attn_kv_iswa * inp,

0 commit comments

Comments
 (0)