Skip to content

Commit fd0955f

Browse files
committed
support multi-seq
1 parent 402a60f commit fd0955f

6 files changed

Lines changed: 968 additions & 184 deletions

File tree

src/llama-graph.cpp

Lines changed: 87 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,48 @@ static void dsv4_set_kq_mask(
683683
}
684684
}
685685

686+
// Creates the KQ mask input tensor for the DSV4 raw KV cache.
//
// The returned tensor has shape [n_kv, n_tokens/n_stream, 1, n_stream], is flagged as a
// graph input and is named "attn_inp_kq_mask". Only the tensor is created here; no mask
// values are written by this function.
//
//   ctx      - ggml context the tensor is created in
//   mctx     - raw cache context; supplies n_kv (first dimension)
//   ubatch   - current micro-batch; supplies n_tokens
//   cparams  - flash_attn / kv_unified select the element type (F16 vs F32)
//   n_stream - number of streams; must be positive and divide n_tokens evenly
static ggml_tensor * dsv4_build_raw_kq_mask(
        ggml_context * ctx,
        const llama_kv_cache_dsv4_raw_context * mctx,
        const llama_ubatch & ubatch,
        const llama_cparams & cparams,
        int64_t n_stream) {
    const auto kv_cols      = mctx->get_n_kv();
    const auto tokens_total = ubatch.n_tokens;

    // the stream count is used as a divisor below, so validate it first
    GGML_ASSERT(n_stream > 0);
    GGML_ASSERT(tokens_total%n_stream == 0);

    const auto rows_per_stream = tokens_total/n_stream;

    // F16 only when flash attention is enabled and the cache is not
    // (unified with more than one stream); F32 in every other case
    const bool f16_mask  = cparams.flash_attn && !(cparams.kv_unified && n_stream != 1);
    const auto mask_type = f16_mask ? GGML_TYPE_F16 : GGML_TYPE_F32;

    ggml_tensor * mask = ggml_new_tensor_4d(ctx, mask_type, kv_cols, rows_per_stream, 1, n_stream);
    ggml_set_input(mask);
    ggml_set_name(mask, "attn_inp_kq_mask");

    return mask;
}
707+
708+
// Reports whether an existing raw KQ mask tensor still has the shape required for `ubatch`.
//
// The expected shape is [n_kv, n_tokens/n_stream, 1, n_stream]; the function returns true
// only if all four dimensions of `kq_mask` match. The element type is not compared.
//
//   kq_mask  - previously built mask tensor (must be non-null)
//   mctx     - raw cache context; supplies the expected n_kv
//   ubatch   - micro-batch about to be processed; supplies n_tokens
//   n_stream - number of streams; must be positive
//
// NOTE(review): dsv4_build_raw_kq_mask asserts n_tokens%n_stream == 0, this check does not,
// so the quotient below truncates for a non-divisible batch - confirm callers guarantee it.
static bool dsv4_can_reuse_raw_kq_mask(
        ggml_tensor * kq_mask,
        const llama_kv_cache_dsv4_raw_context * mctx,
        const llama_ubatch & ubatch,
        int64_t n_stream) {
    const auto kv_cols      = mctx->get_n_kv();
    const auto tokens_total = ubatch.n_tokens;

    // n_stream is a divisor below
    GGML_ASSERT(n_stream > 0);

    if (kq_mask->ne[0] != kv_cols) {
        return false;
    }
    if (kq_mask->ne[1] != tokens_total/n_stream) {
        return false;
    }
    if (kq_mask->ne[2] != 1) {
        return false;
    }
    if (kq_mask->ne[3] != n_stream) {
        return false;
    }

    return true;
}
727+
686728
static std::string dsv4_plan_positions(const std::vector<int32_t> & values) {
687729
std::ostringstream ss;
688730
ss << "[";
@@ -808,15 +850,32 @@ static void dsv4_build_comp_inputs(
808850
}
809851
}
810852

853+
void llm_graph_input_dsv4_raw::set_input(const llama_ubatch * ubatch) {
854+
if (self_k_idxs && self_k_idxs->buffer) {
855+
mctx->set_input_k_idxs(self_k_idxs);
856+
}
857+
858+
if (self_kq_mask && self_kq_mask->buffer) {
859+
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
860+
}
861+
862+
if (self_k_rot) {
863+
mctx->set_input_k_rot(self_k_rot);
864+
}
865+
}
866+
811867
void llm_graph_input_dsv4::set_input(const llama_ubatch * ubatch) {
868+
const auto & plan_csa = mctx->get_csa_plan(*ubatch);
869+
const auto & plan_hca = mctx->get_hca_plan(*ubatch);
870+
const auto & plan_lid = mctx->get_lid_plan(*ubatch);
871+
const int64_t n_stream = plan_csa.n_stream;
872+
812873
inp_raw->mctx = mctx->get_raw();
813874
inp_raw->set_input(ubatch);
814875

815-
const int64_t n_stream = cparams.kv_unified ? 1 : ubatch->n_seqs_unq;
816-
817-
dsv4_set_comp_inputs(inp_csa, mctx->get_csa_plan(*ubatch), "csa", debug > 0, ubatch->n_tokens, n_stream);
818-
dsv4_set_comp_inputs(inp_hca, mctx->get_hca_plan(*ubatch), "hca", debug > 0, ubatch->n_tokens, n_stream);
819-
dsv4_set_comp_inputs(inp_lid, mctx->get_lid_plan(*ubatch), "lid", debug > 0, ubatch->n_tokens, n_stream);
876+
dsv4_set_comp_inputs(inp_csa, plan_csa, "csa", debug > 0, ubatch->n_tokens, n_stream);
877+
dsv4_set_comp_inputs(inp_hca, plan_hca, "hca", debug > 0, ubatch->n_tokens, n_stream);
878+
dsv4_set_comp_inputs(inp_lid, plan_lid, "lid", debug > 0, ubatch->n_tokens, n_stream);
820879

821880
if (inp_lid.k_rot && inp_lid.k_rot->buffer) {
822881
mctx->get_lid()->set_input_k_rot(inp_lid.k_rot);
@@ -831,15 +890,24 @@ bool llm_graph_input_dsv4::can_reuse(const llm_graph_params & params) {
831890

832891
bool res = true;
833892

834-
llm_graph_params raw_params = params;
835-
raw_params.mctx = mctx->get_raw();
836-
res &= inp_raw->can_reuse(raw_params);
893+
const auto & plan_csa = mctx->get_csa_plan(params.ubatch);
894+
const auto & plan_hca = mctx->get_hca_plan(params.ubatch);
895+
const auto & plan_lid = mctx->get_lid_plan(params.ubatch);
896+
const int64_t n_stream = plan_csa.n_stream;
897+
898+
const auto * raw_ctx = mctx->get_raw();
899+
inp_raw->mctx = raw_ctx;
837900

838-
const int64_t n_stream = params.cparams.kv_unified ? 1 : params.ubatch.n_seqs_unq;
901+
if (inp_raw->self_k_idxs && inp_raw->self_k_idxs->buffer) {
902+
res &= inp_raw->self_k_idxs->ne[0] == raw_ctx->get_n_write();
903+
}
904+
if (inp_raw->self_kq_mask && inp_raw->self_kq_mask->buffer) {
905+
res &= dsv4_can_reuse_raw_kq_mask(inp_raw->self_kq_mask, raw_ctx, params.ubatch, n_stream);
906+
}
839907

840-
res &= dsv4_can_reuse_comp_input(inp_csa, mctx->get_csa_plan(params.ubatch), params.ubatch.n_tokens, n_stream);
841-
res &= dsv4_can_reuse_comp_input(inp_hca, mctx->get_hca_plan(params.ubatch), params.ubatch.n_tokens, n_stream);
842-
res &= dsv4_can_reuse_comp_input(inp_lid, mctx->get_lid_plan(params.ubatch), params.ubatch.n_tokens, n_stream);
908+
res &= dsv4_can_reuse_comp_input(inp_csa, plan_csa, params.ubatch.n_tokens, n_stream);
909+
res &= dsv4_can_reuse_comp_input(inp_hca, plan_hca, params.ubatch.n_tokens, n_stream);
910+
res &= dsv4_can_reuse_comp_input(inp_lid, plan_lid, params.ubatch.n_tokens, n_stream);
843911

844912
return res;
845913
}
@@ -2995,28 +3063,19 @@ llm_graph_input_dsv4 * llm_graph_context::build_inp_dsv4() const {
29953063
const auto * mctx_cur = static_cast<const llama_kv_cache_dsv4_context *>(mctx);
29963064
const auto * raw_ctx = mctx_cur->get_raw();
29973065

2998-
auto inp_raw = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, raw_ctx);
3066+
auto inp_raw = std::make_unique<llm_graph_input_dsv4_raw>(cparams, raw_ctx);
29993067

3000-
{
3001-
inp_raw->self_k_idxs = raw_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
3002-
inp_raw->self_kq_mask = build_attn_inp_kq_mask(ctx0, raw_ctx->get_base(), ubatch, cparams);
3003-
inp_raw->self_kq_mask_cnv = inp_raw->self_kq_mask;
3004-
}
3068+
const int64_t n_stream = mctx_cur->get_csa_plan(ubatch).n_stream;
30053069

3006-
{
3007-
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "DSV4 expects SWA raw cache");
3070+
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "DSV4 expects SWA raw cache");
30083071

3009-
inp_raw->self_k_idxs_swa = raw_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
3010-
inp_raw->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, raw_ctx->get_swa(), ubatch, cparams);
3011-
inp_raw->self_kq_mask_swa_cnv = inp_raw->self_kq_mask_swa;
3012-
}
3072+
inp_raw->self_k_idxs = raw_ctx->build_input_k_idxs(ctx0, ubatch);
3073+
inp_raw->self_kq_mask = dsv4_build_raw_kq_mask(ctx0, raw_ctx, ubatch, cparams, n_stream);
3074+
inp_raw->self_kq_mask_cnv = inp_raw->self_kq_mask;
30133075

3014-
inp_raw->self_k_rot = raw_ctx->get_base()->build_input_k_rot(ctx0);
3015-
inp_raw->self_k_rot_swa = raw_ctx->get_swa()->build_input_k_rot(ctx0);
3076+
inp_raw->self_k_rot = raw_ctx->build_input_k_rot(ctx0);
30163077
auto inp = std::make_unique<llm_graph_input_dsv4>(cparams, std::move(inp_raw), mctx_cur);
30173078

3018-
const int64_t n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
3019-
30203079
dsv4_build_comp_inputs(ctx0, inp->inp_csa, mctx_cur->get_csa_plan(ubatch), "csa", n_stream);
30213080
dsv4_build_comp_inputs(ctx0, inp->inp_hca, mctx_cur->get_hca_plan(ubatch), "hca", n_stream);
30223081
dsv4_build_comp_inputs(ctx0, inp->inp_lid, mctx_cur->get_lid_plan(ubatch), "lid", n_stream);

src/llama-graph.h

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ struct llama_memory_context_i;
2323

2424
class llama_kv_cache_context;
2525
class llama_kv_cache_dsa_context;
26+
class llama_kv_cache_dsv4_raw_context;
2627
class llama_kv_cache_dsv4_context;
2728
class llama_kv_cache_iswa_context;
2829
class llama_memory_recurrent_context;
@@ -460,6 +461,34 @@ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
460461
const llama_kv_cache_iswa_context * mctx;
461462
};
462463

464+
// DSV4 raw graph inputs are SWA-only, but their mask may be stream-shaped
465+
// so raw K can be concatenated with DSV4 compressed K in one attention op.
466+
class llm_graph_input_dsv4_raw {
467+
public:
468+
llm_graph_input_dsv4_raw(
469+
const llama_cparams & cparams,
470+
const llama_kv_cache_dsv4_raw_context * mctx) :
471+
cparams(cparams),
472+
mctx(mctx) {
473+
}
474+
475+
void set_input(const llama_ubatch * ubatch);
476+
477+
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
478+
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
479+
480+
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
481+
482+
ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream]
483+
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
484+
485+
ggml_tensor * self_k_rot = nullptr;
486+
487+
const llama_cparams cparams;
488+
489+
const llama_kv_cache_dsv4_raw_context * mctx;
490+
};
491+
463492
class llm_graph_input_dsv4 : public llm_graph_input_i {
464493
public:
465494
struct comp_input {
@@ -477,7 +506,7 @@ class llm_graph_input_dsv4 : public llm_graph_input_i {
477506

478507
llm_graph_input_dsv4(
479508
const llama_cparams & cparams,
480-
std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_raw,
509+
std::unique_ptr<llm_graph_input_dsv4_raw> inp_raw,
481510
const llama_kv_cache_dsv4_context * mctx) :
482511
inp_raw(std::move(inp_raw)),
483512
cparams(cparams),
@@ -489,12 +518,12 @@ class llm_graph_input_dsv4 : public llm_graph_input_i {
489518

490519
bool can_reuse(const llm_graph_params & params) override;
491520

492-
llm_graph_input_attn_kv_iswa * get_raw() const { return inp_raw.get(); }
521+
llm_graph_input_dsv4_raw * get_raw() const { return inp_raw.get(); }
493522
const comp_input & get_csa() const { return inp_csa; }
494523
const comp_input & get_hca() const { return inp_hca; }
495524
const comp_input & get_lid() const { return inp_lid; }
496525

497-
std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_raw;
526+
std::unique_ptr<llm_graph_input_dsv4_raw> inp_raw;
498527

499528
comp_input inp_csa;
500529
comp_input inp_hca;

0 commit comments

Comments
 (0)