|
8 | 8 | #include "llama-kv-cache.h" |
9 | 9 | #include "llama-kv-cache-iswa.h" |
10 | 10 | #include "llama-kv-cache-dsa.h" |
| 11 | +#include "llama-kv-cache-dsv4.h" |
11 | 12 | #include "llama-memory-hybrid.h" |
12 | 13 | #include "llama-memory-hybrid-iswa.h" |
13 | 14 | #include "llama-memory-recurrent.h" |
|
17 | 18 | #include <cstring> |
18 | 19 | #include <numeric> |
19 | 20 | #include <sstream> |
| 21 | +#include <string> |
20 | 22 | #include <unordered_set> |
21 | 23 |
|
22 | 24 | // dedup helpers |
@@ -620,6 +622,223 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { |
620 | 622 | return res; |
621 | 623 | } |
622 | 624 |
|
| 625 | +static void dsv4_set_i64(ggml_tensor * dst, const std::vector<int64_t> & src) { |
| 626 | + if (!dst || !dst->buffer) { |
| 627 | + return; |
| 628 | + } |
| 629 | + |
| 630 | + GGML_ASSERT(dst->ne[0] == (int64_t) src.size()); |
| 631 | + ggml_backend_tensor_set(dst, src.data(), 0, src.size()*ggml_element_size(dst)); |
| 632 | +} |
| 633 | + |
| 634 | +static void dsv4_set_i32(ggml_tensor * dst, const std::vector<int32_t> & src) { |
| 635 | + if (!dst || !dst->buffer) { |
| 636 | + return; |
| 637 | + } |
| 638 | + |
| 639 | + GGML_ASSERT(dst->ne[0] == (int64_t) src.size()); |
| 640 | + ggml_backend_tensor_set(dst, src.data(), 0, src.size()*ggml_element_size(dst)); |
| 641 | +} |
| 642 | + |
| 643 | +static void dsv4_set_kq_mask( |
| 644 | + ggml_tensor * dst, |
| 645 | + const llama_kv_cache_dsv4_context::comp_plan & plan, |
| 646 | + uint32_t n_tokens) { |
| 647 | + if (!dst || !dst->buffer) { |
| 648 | + return; |
| 649 | + } |
| 650 | + |
| 651 | + GGML_ASSERT(dst->type == GGML_TYPE_F32); |
| 652 | + GGML_ASSERT(dst->ne[0] == plan.n_kv); |
| 653 | + GGML_ASSERT(dst->ne[1] == (int64_t) n_tokens); |
| 654 | + GGML_ASSERT(dst->ne[2] == 1); |
| 655 | + GGML_ASSERT(dst->ne[3] == 1); |
| 656 | + GGML_ASSERT((int64_t) plan.n_visible.size() == dst->ne[1]); |
| 657 | + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); |
| 658 | + |
| 659 | + float * data = (float *) dst->data; |
| 660 | + |
| 661 | + for (int64_t i = 0; i < dst->ne[1]; ++i) { |
| 662 | + const int32_t n_visible = plan.n_visible[i]; |
| 663 | + |
| 664 | + for (int64_t j = 0; j < dst->ne[0]; ++j) { |
| 665 | + data[i*dst->ne[0] + j] = j < n_visible ? 0.0f : -INFINITY; |
| 666 | + } |
| 667 | + } |
| 668 | +} |
| 669 | + |
| 670 | +static std::string dsv4_plan_positions(const std::vector<int32_t> & values) { |
| 671 | + std::ostringstream ss; |
| 672 | + ss << "["; |
| 673 | + for (size_t i = 0; i < values.size(); ++i) { |
| 674 | + if (i > 0) { |
| 675 | + ss << ", "; |
| 676 | + } |
| 677 | + ss << values[i]; |
| 678 | + } |
| 679 | + ss << "]"; |
| 680 | + return ss.str(); |
| 681 | +} |
| 682 | + |
| 683 | +static bool dsv4_compress_debug() { |
| 684 | + static const bool debug = []() { |
| 685 | + const char * env = getenv("LLAMA_DSV4_COMPRESS_DEBUG"); |
| 686 | + return env && atoi(env) > 0; |
| 687 | + }(); |
| 688 | + |
| 689 | + return debug; |
| 690 | +} |
| 691 | + |
| 692 | +static void dsv4_set_comp_inputs( |
| 693 | + const llm_graph_input_dsv4::comp_input & inp, |
| 694 | + const llama_kv_cache_dsv4_context::comp_plan & plan, |
| 695 | + const char * name, |
| 696 | + bool debug, |
| 697 | + uint32_t n_tokens) { |
| 698 | + dsv4_set_i64(inp.write_idxs, plan.write_idxs); |
| 699 | + dsv4_set_i32(inp.write_pos, plan.write_pos); |
| 700 | + dsv4_set_i32(inp.write_end, plan.write_end); |
| 701 | + dsv4_set_i32(inp.pending_end, plan.pending_end); |
| 702 | + dsv4_set_i32(inp.state_idxs, plan.state_idxs); |
| 703 | + dsv4_set_i32(inp.state_pos, plan.state_pos); |
| 704 | + dsv4_set_i32(inp.state_read_idxs, plan.state_read_idxs); |
| 705 | + dsv4_set_i64(inp.state_write_idxs, plan.state_write_idxs); |
| 706 | + dsv4_set_i32(inp.state_write_pos, plan.state_write_pos); |
| 707 | + dsv4_set_i32(inp.state_write_end, plan.state_write_end); |
| 708 | + dsv4_set_i32(inp.n_visible, plan.n_visible); |
| 709 | + dsv4_set_kq_mask(inp.kq_mask, plan, n_tokens); |
| 710 | + |
| 711 | + if (debug || dsv4_compress_debug()) { |
| 712 | + LLAMA_LOG_INFO("%s: %s ratio=%u, n_tokens=%u, write_end=%s, state_write_end=%s, pending_end=%s\n", |
| 713 | + __func__, name, plan.ratio, n_tokens, |
| 714 | + dsv4_plan_positions(plan.write_end).c_str(), |
| 715 | + dsv4_plan_positions(plan.state_write_end).c_str(), |
| 716 | + dsv4_plan_positions(plan.pending_end).c_str()); |
| 717 | + } |
| 718 | +} |
| 719 | + |
| 720 | +static bool dsv4_can_reuse_tensor_1d(ggml_tensor * t, int64_t ne0) { |
| 721 | + return (t == nullptr && ne0 == 0) || (t != nullptr && t->ne[0] == ne0); |
| 722 | +} |
| 723 | + |
| 724 | +static bool dsv4_can_reuse_kq_mask( |
| 725 | + ggml_tensor * t, |
| 726 | + const llama_kv_cache_dsv4_context::comp_plan & plan, |
| 727 | + uint32_t n_tokens) { |
| 728 | + if (plan.n_kv == 0) { |
| 729 | + return t == nullptr; |
| 730 | + } |
| 731 | + |
| 732 | + return t != nullptr && |
| 733 | + t->ne[0] == plan.n_kv && |
| 734 | + t->ne[1] == (int64_t) n_tokens && |
| 735 | + t->ne[2] == 1 && |
| 736 | + t->ne[3] == 1; |
| 737 | +} |
| 738 | + |
| 739 | +static bool dsv4_can_reuse_comp_input( |
| 740 | + const llm_graph_input_dsv4::comp_input & inp, |
| 741 | + const llama_kv_cache_dsv4_context::comp_plan & plan, |
| 742 | + uint32_t n_tokens) { |
| 743 | + const int64_t n_write = plan.write_idxs.size(); |
| 744 | + |
| 745 | + bool res = true; |
| 746 | + res &= dsv4_can_reuse_tensor_1d(inp.write_idxs, n_write); |
| 747 | + res &= dsv4_can_reuse_tensor_1d(inp.write_pos, n_write); |
| 748 | + res &= dsv4_can_reuse_tensor_1d(inp.write_end, n_write); |
| 749 | + res &= dsv4_can_reuse_tensor_1d(inp.pending_end, plan.pending_end.size()); |
| 750 | + res &= dsv4_can_reuse_tensor_1d(inp.state_idxs, plan.state_idxs.size()); |
| 751 | + res &= dsv4_can_reuse_tensor_1d(inp.state_pos, plan.state_pos.size()); |
| 752 | + res &= dsv4_can_reuse_tensor_1d(inp.state_read_idxs, plan.state_read_idxs.size()); |
| 753 | + res &= dsv4_can_reuse_tensor_1d(inp.state_write_idxs, plan.state_write_idxs.size()); |
| 754 | + res &= dsv4_can_reuse_tensor_1d(inp.state_write_pos, plan.state_write_pos.size()); |
| 755 | + res &= dsv4_can_reuse_tensor_1d(inp.state_write_end, plan.state_write_end.size()); |
| 756 | + res &= dsv4_can_reuse_tensor_1d(inp.n_visible, plan.n_visible.size()); |
| 757 | + res &= dsv4_can_reuse_kq_mask(inp.kq_mask, plan, n_tokens); |
| 758 | + |
| 759 | + return res; |
| 760 | +} |
| 761 | + |
| 762 | +static ggml_tensor * dsv4_build_input_1d( |
| 763 | + ggml_context * ctx, |
| 764 | + ggml_type type, |
| 765 | + int64_t ne0, |
| 766 | + const std::string & name) { |
| 767 | + if (ne0 == 0) { |
| 768 | + return nullptr; |
| 769 | + } |
| 770 | + |
| 771 | + ggml_tensor * res = ggml_new_tensor_1d(ctx, type, ne0); |
| 772 | + ggml_set_input(res); |
| 773 | + ggml_set_name(res, name.c_str()); |
| 774 | + |
| 775 | + return res; |
| 776 | +} |
| 777 | + |
| 778 | +static void dsv4_build_comp_inputs( |
| 779 | + ggml_context * ctx, |
| 780 | + llm_graph_input_dsv4::comp_input & inp, |
| 781 | + const llama_kv_cache_dsv4_context::comp_plan & plan, |
| 782 | + const char * name) { |
| 783 | + const int64_t n_write = plan.write_idxs.size(); |
| 784 | + |
| 785 | + inp.write_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I64, n_write, std::string("dsv4_") + name + "_write_idxs"); |
| 786 | + inp.write_pos = dsv4_build_input_1d(ctx, GGML_TYPE_I32, n_write, std::string("dsv4_") + name + "_write_pos"); |
| 787 | + inp.write_end = dsv4_build_input_1d(ctx, GGML_TYPE_I32, n_write, std::string("dsv4_") + name + "_write_end"); |
| 788 | + inp.pending_end = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.pending_end.size(), std::string("dsv4_") + name + "_pending_end"); |
| 789 | + inp.state_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_idxs.size(), std::string("dsv4_") + name + "_state_idxs"); |
| 790 | + inp.state_pos = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_pos.size(), std::string("dsv4_") + name + "_state_pos"); |
| 791 | + inp.state_read_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_read_idxs.size(), std::string("dsv4_") + name + "_state_read_idxs"); |
| 792 | + inp.state_write_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I64, plan.state_write_idxs.size(), std::string("dsv4_") + name + "_state_write_idxs"); |
| 793 | + inp.state_write_pos = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_write_pos.size(), std::string("dsv4_") + name + "_state_write_pos"); |
| 794 | + inp.state_write_end = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_write_end.size(), std::string("dsv4_") + name + "_state_write_end"); |
| 795 | + inp.n_visible = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.n_visible.size(), std::string("dsv4_") + name + "_n_visible"); |
| 796 | + |
| 797 | + if (plan.n_kv > 0) { |
| 798 | + inp.kq_mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, plan.n_kv, plan.n_visible.size(), 1, 1); |
| 799 | + ggml_set_input(inp.kq_mask); |
| 800 | + ggml_set_name(inp.kq_mask, (std::string("dsv4_") + name + "_kq_mask").c_str()); |
| 801 | + } |
| 802 | +} |
| 803 | + |
| 804 | +void llm_graph_input_dsv4::set_input(const llama_ubatch * ubatch) { |
| 805 | + inp_raw->mctx = mctx->get_raw(); |
| 806 | + inp_raw->set_input(ubatch); |
| 807 | + |
| 808 | + dsv4_set_comp_inputs(inp_csa, mctx->get_csa_plan(), "csa", debug > 0, ubatch->n_tokens); |
| 809 | + dsv4_set_comp_inputs(inp_hca, mctx->get_hca_plan(), "hca", debug > 0, ubatch->n_tokens); |
| 810 | + dsv4_set_comp_inputs(inp_lid, mctx->get_lid_plan(), "lid", debug > 0, ubatch->n_tokens); |
| 811 | + |
| 812 | + if (inp_lid.k_rot && inp_lid.k_rot->buffer) { |
| 813 | + mctx->get_lid()->set_input_k_rot(inp_lid.k_rot); |
| 814 | + } |
| 815 | +} |
| 816 | + |
| 817 | +bool llm_graph_input_dsv4::can_reuse(const llm_graph_params & params) { |
| 818 | + const auto * mctx = static_cast<const llama_kv_cache_dsv4_context *>(params.mctx); |
| 819 | + |
| 820 | + this->mctx = mctx; |
| 821 | + inp_raw->mctx = mctx->get_raw(); |
| 822 | + |
| 823 | + bool res = true; |
| 824 | + |
| 825 | + if (inp_raw->self_k_idxs && inp_raw->self_k_idxs->buffer) { |
| 826 | + res &= inp_raw->self_k_idxs->ne[0] == params.ubatch.n_tokens; |
| 827 | + res &= can_reuse_kq_mask(inp_raw->self_kq_mask, mctx->get_raw()->get_base(), params.ubatch, params.cparams); |
| 828 | + } |
| 829 | + |
| 830 | + if (inp_raw->self_k_idxs_swa && inp_raw->self_k_idxs_swa->buffer) { |
| 831 | + res &= inp_raw->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; |
| 832 | + res &= can_reuse_kq_mask(inp_raw->self_kq_mask_swa, mctx->get_raw()->get_swa(), params.ubatch, params.cparams); |
| 833 | + } |
| 834 | + |
| 835 | + res &= dsv4_can_reuse_comp_input(inp_csa, mctx->get_csa_plan(), params.ubatch.n_tokens); |
| 836 | + res &= dsv4_can_reuse_comp_input(inp_hca, mctx->get_hca_plan(), params.ubatch.n_tokens); |
| 837 | + res &= dsv4_can_reuse_comp_input(inp_lid, mctx->get_lid_plan(), params.ubatch.n_tokens); |
| 838 | + |
| 839 | + return res; |
| 840 | +} |
| 841 | + |
623 | 842 | void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { |
624 | 843 | GGML_ASSERT(cross_kq_mask); |
625 | 844 |
|
@@ -2731,6 +2950,46 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const |
2731 | 2950 | return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp)); |
2732 | 2951 | } |
2733 | 2952 |
|
| 2953 | +llm_graph_input_dsv4 * llm_graph_context::build_inp_dsv4() const { |
| 2954 | + const auto * mctx_cur = static_cast<const llama_kv_cache_dsv4_context *>(mctx); |
| 2955 | + const auto * raw_ctx = mctx_cur->get_raw(); |
| 2956 | + |
| 2957 | + auto inp_raw = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, raw_ctx); |
| 2958 | + |
| 2959 | + { |
| 2960 | + inp_raw->self_k_idxs = raw_ctx->get_base()->build_input_k_idxs(ctx0, ubatch); |
| 2961 | + inp_raw->self_v_idxs = raw_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); |
| 2962 | + |
| 2963 | + inp_raw->self_kq_mask = build_attn_inp_kq_mask(ctx0, raw_ctx->get_base(), ubatch, cparams); |
| 2964 | + inp_raw->self_kq_mask_cnv = inp_raw->self_kq_mask; |
| 2965 | + } |
| 2966 | + |
| 2967 | + { |
| 2968 | + GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "DSV4 expects SWA raw cache"); |
| 2969 | + |
| 2970 | + inp_raw->self_k_idxs_swa = raw_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch); |
| 2971 | + inp_raw->self_v_idxs_swa = raw_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); |
| 2972 | + |
| 2973 | + inp_raw->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, raw_ctx->get_swa(), ubatch, cparams); |
| 2974 | + inp_raw->self_kq_mask_swa_cnv = inp_raw->self_kq_mask_swa; |
| 2975 | + } |
| 2976 | + |
| 2977 | + inp_raw->self_k_rot = raw_ctx->get_base()->build_input_k_rot(ctx0); |
| 2978 | + inp_raw->self_v_rot = raw_ctx->get_base()->build_input_v_rot(ctx0); |
| 2979 | + |
| 2980 | + inp_raw->self_k_rot_swa = raw_ctx->get_swa()->build_input_k_rot(ctx0); |
| 2981 | + inp_raw->self_v_rot_swa = raw_ctx->get_swa()->build_input_v_rot(ctx0); |
| 2982 | + |
| 2983 | + auto inp = std::make_unique<llm_graph_input_dsv4>(cparams, std::move(inp_raw), mctx_cur); |
| 2984 | + |
| 2985 | + dsv4_build_comp_inputs(ctx0, inp->inp_csa, mctx_cur->get_csa_plan(), "csa"); |
| 2986 | + dsv4_build_comp_inputs(ctx0, inp->inp_hca, mctx_cur->get_hca_plan(), "hca"); |
| 2987 | + dsv4_build_comp_inputs(ctx0, inp->inp_lid, mctx_cur->get_lid_plan(), "lid"); |
| 2988 | + inp->inp_lid.k_rot = mctx_cur->get_lid()->build_input_k_rot(ctx0); |
| 2989 | + |
| 2990 | + return (llm_graph_input_dsv4 *) res->add_input(std::move(inp)); |
| 2991 | +} |
| 2992 | + |
2734 | 2993 | ggml_tensor * llm_graph_context::build_rs( |
2735 | 2994 | ggml_tensor * s, |
2736 | 2995 | ggml_tensor * state_copy_main, |
|
0 commit comments