|
1 | 1 | #include "llama-graph.h" |
2 | 2 |
|
3 | 3 | #include "llama-impl.h" |
| 4 | +#include "llama-model.h" |
4 | 5 | #include "llama-batch.h" |
5 | 6 | #include "llama-cparams.h" |
6 | 7 |
|
@@ -992,6 +993,67 @@ ggml_tensor * llm_graph_context::build_norm( |
992 | 993 | return cur; |
993 | 994 | } |
994 | 995 |
|
| 996 | + |
| 997 | +llm_graph_qkv llm_graph_context::build_qkv( |
| 998 | + const llama_layer & layer, |
| 999 | + ggml_tensor * cur, |
| 1000 | + int64_t n_embd_head, |
| 1001 | + int64_t n_head, |
| 1002 | + int64_t n_head_kv, |
| 1003 | + int il) const { |
| 1004 | + const int64_t n_embd_q = n_embd_head * n_head; |
| 1005 | + const int64_t n_embd_kv = n_embd_head * n_head_kv; |
| 1006 | + |
| 1007 | + ggml_tensor * Qcur, * Kcur, * Vcur; |
| 1008 | + |
| 1009 | + if (layer.wqkv) { |
| 1010 | + // fused QKV path |
| 1011 | + ggml_tensor * qkv = build_lora_mm(layer.wqkv, cur); |
| 1012 | + cb(qkv, "wqkv", il); |
| 1013 | + if (layer.bqkv) { |
| 1014 | + qkv = ggml_add(ctx0, qkv, layer.bqkv); |
| 1015 | + cb(qkv, "bqkv", il); |
| 1016 | + } |
| 1017 | + Qcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head, n_tokens, |
| 1018 | + ggml_element_size(qkv) * n_embd_head, qkv->nb[1], 0); |
| 1019 | + Kcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens, |
| 1020 | + ggml_element_size(qkv) * n_embd_head, qkv->nb[1], |
| 1021 | + ggml_element_size(qkv) * n_embd_q); |
| 1022 | + Vcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens, |
| 1023 | + ggml_element_size(qkv) * n_embd_head, qkv->nb[1], |
| 1024 | + ggml_element_size(qkv) * (n_embd_q + n_embd_kv)); |
| 1025 | + } else { |
| 1026 | + // separate Q/K/V path |
| 1027 | + Qcur = build_lora_mm(layer.wq, cur, layer.wq_s); |
| 1028 | + cb(Qcur, "Qcur", il); |
| 1029 | + if (layer.bq) { |
| 1030 | + Qcur = ggml_add(ctx0, Qcur, layer.bq); |
| 1031 | + cb(Qcur, "Qcur", il); |
| 1032 | + } |
| 1033 | + Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); |
| 1034 | + cb(Kcur, "Kcur", il); |
| 1035 | + if (layer.bk) { |
| 1036 | + Kcur = ggml_add(ctx0, Kcur, layer.bk); |
| 1037 | + cb(Kcur, "Kcur", il); |
| 1038 | + } |
| 1039 | + Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); |
| 1040 | + cb(Vcur, "Vcur", il); |
| 1041 | + if (layer.bv) { |
| 1042 | + Vcur = ggml_add(ctx0, Vcur, layer.bv); |
| 1043 | + cb(Vcur, "Vcur", il); |
| 1044 | + } |
| 1045 | + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); |
| 1046 | + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); |
| 1047 | + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); |
| 1048 | + } |
| 1049 | + |
| 1050 | + cb(Qcur, "Qcur", il); |
| 1051 | + cb(Kcur, "Kcur", il); |
| 1052 | + cb(Vcur, "Vcur", il); |
| 1053 | + |
| 1054 | + return { Qcur, Kcur, Vcur }; |
| 1055 | +} |
| 1056 | + |
995 | 1057 | ggml_tensor * llm_graph_context::build_ffn( |
996 | 1058 | ggml_tensor * cur, |
997 | 1059 | ggml_tensor * up, |
|
0 commit comments