|
| 1 | +#include "models.h" |
| 2 | + |
| 3 | +void llama_model_gemma4_assistant::load_arch_hparams(llama_model_loader & ml) { |
| 4 | + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; |
| 5 | + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); |
| 6 | + |
| 7 | + uint32_t n_kv_shared_layers = 0; |
| 8 | + ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false); |
| 9 | + |
| 10 | + hparams.n_layer_kv_from_start = hparams.n_layer - (int32_t) n_kv_shared_layers; |
| 11 | + hparams.f_attention_scale = 1.0f; |
| 12 | + |
| 13 | + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); |
| 14 | + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); |
| 15 | + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); |
| 16 | + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); |
| 17 | + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); |
| 18 | + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); |
| 19 | + |
| 20 | + if (hparams.n_layer == 4) { |
| 21 | + type = LLM_TYPE_31B; |
| 22 | + } |
| 23 | +} |
| 24 | + |
| 25 | +void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) { |
| 26 | + LLAMA_LOAD_LOCALS; |
| 27 | + |
| 28 | + if (n_embd_head_k != n_embd_head_v) { |
| 29 | + throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k == n_embd_head_v"); |
| 30 | + } |
| 31 | + if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) { |
| 32 | + throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k_swa == n_embd_head_v_swa"); |
| 33 | + } |
| 34 | + if (hparams.n_embd_out() == n_embd) { |
| 35 | + throw std::runtime_error("Gemma 4 assistant requires embedding_length_out to carry the target hidden size"); |
| 36 | + } |
| 37 | + |
| 38 | + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); |
| 39 | + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); |
| 40 | + |
| 41 | + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); |
| 42 | + |
| 43 | + const int64_t n_embd_backbone = hparams.n_embd_out(); |
| 44 | + nextn_pre_proj = create_tensor(tn(LLM_TENSOR_NEXTN_PRE_PROJ, "weight"), { 2*n_embd_backbone, n_embd }, 0); |
| 45 | + nextn_post_proj = create_tensor(tn(LLM_TENSOR_NEXTN_POST_PROJ, "weight"), { n_embd, n_embd_backbone }, 0); |
| 46 | + |
| 47 | + int rope_freqs_flag = 0; |
| 48 | + |
| 49 | + for (int i = 0; i < n_layer; ++i) { |
| 50 | + auto & layer = layers[i]; |
| 51 | + |
| 52 | + const int64_t n_head = hparams.n_head(i); |
| 53 | + const int64_t n_embd_head = hparams.n_embd_head_k(i); |
| 54 | + const int64_t n_ff = hparams.n_ff(i); |
| 55 | + |
| 56 | + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); |
| 57 | + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head*n_head }, 0); |
| 58 | + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head*n_head, n_embd }, 0); |
| 59 | + |
| 60 | + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head }, 0); |
| 61 | + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); |
| 62 | + |
| 63 | + layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), { 1u }, 0); |
| 64 | + |
| 65 | + if (!hparams.is_swa(i)) { |
| 66 | + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_embd_head/2 }, rope_freqs_flag); |
| 67 | + rope_freqs_flag = TENSOR_DUPLICATED; |
| 68 | + } |
| 69 | + |
| 70 | + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); |
| 71 | + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); |
| 72 | + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); |
| 73 | + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); |
| 74 | + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), { n_embd }, 0); |
| 75 | + } |
| 76 | +} |
| 77 | + |
| 78 | +std::unique_ptr<llm_graph_context> llama_model_gemma4_assistant::build_arch_graph(const llm_graph_params & params) const { |
| 79 | + return std::make_unique<graph>(*this, params); |
| 80 | +} |
| 81 | + |
| 82 | +llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_graph_params & params) : |
| 83 | + llm_graph_context(params) { |
| 84 | + GGML_ASSERT(src_mctx && "Gemma 4 assistant graph requires an MTP source (llama_set_mtp_source)"); |
| 85 | + GGML_ASSERT(src_model && "Gemma 4 assistant graph requires a source model"); |
| 86 | + GGML_ASSERT(src_model->tok_embd && "source model missing tok_embd"); |
| 87 | + |
| 88 | + const auto & src_hparams = src_model->hparams; |
| 89 | + |
| 90 | + // By convention the MTP draft reads from the trunk's final SWA and full layers. |
| 91 | + const int32_t src_layer_full = (int32_t) src_hparams.n_layer - 1; |
| 92 | + const int32_t src_layer_swa = (int32_t) src_hparams.n_layer - 2; |
| 93 | + GGML_ASSERT(!src_hparams.is_swa(src_layer_full) && "trunk's last layer must be full attention"); |
| 94 | + GGML_ASSERT( src_hparams.is_swa(src_layer_swa) && "trunk's penultimate layer must be SWA"); |
| 95 | + |
| 96 | + const int64_t n_embd_backbone = hparams.n_embd_out(); |
| 97 | + |
| 98 | + ggml_tensor * inp_tokens; |
| 99 | + ggml_tensor * inp_h; |
| 100 | + { |
| 101 | + auto inp = std::make_unique<llm_graph_input_embd>(n_embd_backbone); |
| 102 | + |
| 103 | + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); |
| 104 | + cb(inp->tokens, "inp_tokens", -1); |
| 105 | + ggml_set_input(inp->tokens); |
| 106 | + inp_tokens = inp->tokens; |
| 107 | + res->t_inp_tokens = inp->tokens; |
| 108 | + |
| 109 | + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_backbone, ubatch.n_tokens); |
| 110 | + cb(inp->embd, "inp_h", -1); |
| 111 | + ggml_set_input(inp->embd); |
| 112 | + inp_h = inp->embd; |
| 113 | + res->t_inp_embd = inp->embd; |
| 114 | + |
| 115 | + res->add_input(std::move(inp)); |
| 116 | + } |
| 117 | + |
| 118 | + ggml_tensor * x = ggml_get_rows(ctx0, src_model->tok_embd, inp_tokens); |
| 119 | + x = ggml_scale(ctx0, x, sqrtf((float) n_embd_backbone)); |
| 120 | + cb(x, "inp_embd_target", -1); |
| 121 | + |
| 122 | + ggml_tensor * xh = ggml_concat(ctx0, x, inp_h, 0); |
| 123 | + cb(xh, "inp_xh", -1); |
| 124 | + |
| 125 | + ggml_tensor * cur = ggml_mul_mat(ctx0, model.nextn_pre_proj, xh); |
| 126 | + cb(cur, "pre_proj", -1); |
| 127 | + |
| 128 | + auto * inp_attn = build_attn_inp_src_kv_iswa(); |
| 129 | + ggml_tensor * inp_pos = build_inp_pos(); |
| 130 | + ggml_tensor * inp_out_ids = build_inp_out_ids(); |
| 131 | + |
| 132 | + ggml_tensor * inpL = cur; |
| 133 | + |
| 134 | + for (int il = 0; il < n_layer; ++il) { |
| 135 | + const bool is_swa = hparams.is_swa(il); |
| 136 | + const int32_t il_src = is_swa ? src_layer_swa : src_layer_full; |
| 137 | + |
| 138 | + const int64_t n_embd_head = hparams.n_embd_head_k(il); |
| 139 | + const int64_t n_head = hparams.n_head(il); |
| 140 | + |
| 141 | + const float freq_base_l = model.get_rope_freq_base(cparams, il); |
| 142 | + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); |
| 143 | + const int n_rot_l = hparams.n_rot(il); |
| 144 | + |
| 145 | + ggml_tensor * cur_norm = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); |
| 146 | + cb(cur_norm, "attn_norm", il); |
| 147 | + |
| 148 | + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur_norm); |
| 149 | + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); |
| 150 | + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); |
| 151 | + cb(Qcur, "Qcur_normed", il); |
| 152 | + |
| 153 | + ggml_tensor * freq_factors = is_swa ? nullptr : model.layers[il].rope_freqs; |
| 154 | + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, |
| 155 | + freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); |
| 156 | + cb(Qcur, "Qcur_pos", il); |
| 157 | + |
| 158 | + cur = build_attn(inp_attn, model.layers[il].wo, nullptr, nullptr, |
| 159 | + Qcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il, il_src); |
| 160 | + |
| 161 | + if (il == n_layer - 1 && inp_out_ids) { |
| 162 | + cur = ggml_get_rows(ctx0, cur, inp_out_ids); |
| 163 | + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); |
| 164 | + } |
| 165 | + |
| 166 | + cur = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); |
| 167 | + cb(cur, "attn_post_norm", il); |
| 168 | + |
| 169 | + ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL); |
| 170 | + cb(attn_out, "attn_out", il); |
| 171 | + |
| 172 | + cur = build_norm(attn_out, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il); |
| 173 | + cb(cur, "ffn_norm", il); |
| 174 | + |
| 175 | + cur = build_ffn(cur, |
| 176 | + model.layers[il].ffn_up, nullptr, nullptr, |
| 177 | + model.layers[il].ffn_gate, nullptr, nullptr, |
| 178 | + model.layers[il].ffn_down, nullptr, nullptr, |
| 179 | + nullptr, |
| 180 | + LLM_FFN_GELU, LLM_FFN_PAR, il); |
| 181 | + cb(cur, "ffn_out", il); |
| 182 | + |
| 183 | + cur = build_norm(cur, model.layers[il].ffn_post_norm, nullptr, LLM_NORM_RMS, -1); |
| 184 | + cb(cur, "ffn_post_norm", il); |
| 185 | + |
| 186 | + cur = ggml_add(ctx0, cur, attn_out); |
| 187 | + |
| 188 | + cur = ggml_mul(ctx0, cur, model.layers[il].out_scale); |
| 189 | + cb(cur, "out_scaled", il); |
| 190 | + |
| 191 | + inpL = cur; |
| 192 | + } |
| 193 | + cur = inpL; |
| 194 | + |
| 195 | + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); |
| 196 | + cb(cur, "result_norm", -1); |
| 197 | + |
| 198 | + ggml_tensor * logits = build_lora_mm(model.output, cur); |
| 199 | + cb(logits, "result_output", -1); |
| 200 | + res->t_logits = logits; |
| 201 | + |
| 202 | + ggml_tensor * h_next = ggml_mul_mat(ctx0, model.nextn_post_proj, cur); |
| 203 | + cb(h_next, "result_h_pre_norm", -1); |
| 204 | + res->t_h_pre_norm = h_next; |
| 205 | + |
| 206 | + ggml_build_forward_expand(gf, logits); |
| 207 | + ggml_build_forward_expand(gf, h_next); |
| 208 | +} |
0 commit comments