Skip to content

Commit dd97604

Browse files
committed
move assistant to separate file
1 parent c0da00a commit dd97604

2 files changed

Lines changed: 208 additions & 207 deletions

File tree

src/models/gemma4-assistant.cpp

Lines changed: 208 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,208 @@
1+
#include "models.h"
2+
3+
// Populates the architecture-specific hyper-parameters of the Gemma 4 assistant
// model from the metadata exposed by `ml`.
//
// Fixed by this function (not read from the file):
//   - hparams.swa_type          = LLAMA_SWA_TYPE_STANDARD
//   - hparams.f_attention_scale = 1.0f
//
// The get_key() calls that pass a trailing `false` are presumably optional
// lookups (value left untouched when the key is absent) - confirm against
// llama_model_loader::get_key.
void llama_model_gemma4_assistant::load_arch_hparams(llama_model_loader & ml) {
    // --- sliding-window attention layout -----------------------------------
    hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
    ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer);

    // --- KV sharing: defaults to "no shared layers" ------------------------
    uint32_t shared_kv_layers = 0;
    ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, shared_kv_layers, false);

    hparams.n_layer_kv_from_start = hparams.n_layer - static_cast<int32_t>(shared_kv_layers);
    hparams.f_attention_scale     = 1.0f;

    // --- remaining metadata, read in the same order as before --------------
    ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS,        hparams.nextn_predict_layers,     false);
    ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA,          hparams.rope_freq_base_train_swa, false);
    ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
    ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
    ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA,    hparams.n_embd_head_k_swa);
    ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA,  hparams.n_embd_head_v_swa);

    // --- model size label: only the 4-layer variant is recognised here -----
    switch (hparams.n_layer) {
        case 4:
            type = LLM_TYPE_31B;
            break;
        default:
            // `type` keeps whatever value it had before this call
            break;
    }
}
24+
25+
// Declares every weight tensor of the Gemma 4 assistant model.
//
// Throws std::runtime_error when
//   - the key and value head sizes differ (global or SWA variant), or
//   - hparams.n_embd_out() equals n_embd (see the message below: the output
//     embedding length is expected to carry the target model's hidden size).
//
// The order of the create_tensor() calls is kept as-is: tensors requested with
// TENSOR_DUPLICATED must follow the first request for the same tensor name.
void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) {
    LLAMA_LOAD_LOCALS;

    // ---- sanity checks on the hyper-parameters ----------------------------
    if (n_embd_head_k != n_embd_head_v) {
        throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k == n_embd_head_v");
    }
    if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) {
        throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k_swa == n_embd_head_v_swa");
    }

    // width of the hidden state exchanged with the target ("backbone") model
    const int64_t n_embd_backbone = hparams.n_embd_out();
    if (n_embd_backbone == n_embd) {
        throw std::runtime_error("Gemma 4 assistant requires embedding_length_out to carry the target hidden size");
    }

    // ---- embeddings / output head -----------------------------------------
    // the output projection is requested under the token-embedding tensor name
    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
    output   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);

    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);

    // ---- projections between backbone width and assistant width -----------
    nextn_pre_proj  = create_tensor(tn(LLM_TENSOR_NEXTN_PRE_PROJ,  "weight"), { 2*n_embd_backbone, n_embd }, 0);
    nextn_post_proj = create_tensor(tn(LLM_TENSOR_NEXTN_POST_PROJ, "weight"), { n_embd, n_embd_backbone }, 0);

    // the first non-SWA layer creates rope_freqs; later ones are flagged as duplicates
    bool rope_freqs_seen = false;

    for (int il = 0; il < n_layer; ++il) {
        auto & lyr = layers[il];

        // per-layer dimensions
        const int64_t head_count = hparams.n_head(il);
        const int64_t head_dim   = hparams.n_embd_head_k(il);
        const int64_t ff_dim     = hparams.n_ff(il);
        const int64_t q_dim      = head_dim*head_count;

        // attention: only Q and the output projection are created here
        lyr.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0);
        lyr.wq        = create_tensor(tn(LLM_TENSOR_ATTN_Q,    "weight", il), { n_embd, q_dim }, 0);
        lyr.wo        = create_tensor(tn(LLM_TENSOR_ATTN_OUT,  "weight", il), { q_dim, n_embd }, 0);

        lyr.attn_q_norm    = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM,    "weight", il), { head_dim }, 0);
        lyr.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd },   0);

        // single-element per-layer output scale
        lyr.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", il), { 1u }, 0);

        // RoPE frequency factors exist only for non-SWA layers
        if (!hparams.is_swa(il)) {
            const int flags = rope_freqs_seen ? TENSOR_DUPLICATED : 0;
            lyr.rope_freqs  = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", il), { head_dim/2 }, flags);
            rope_freqs_seen = true;
        }

        // feed-forward
        lyr.ffn_norm      = create_tensor(tn(LLM_TENSOR_FFN_NORM,      "weight", il), { n_embd },         0);
        lyr.ffn_gate      = create_tensor(tn(LLM_TENSOR_FFN_GATE,      "weight", il), { n_embd, ff_dim }, 0);
        lyr.ffn_up        = create_tensor(tn(LLM_TENSOR_FFN_UP,        "weight", il), { n_embd, ff_dim }, 0);
        lyr.ffn_down      = create_tensor(tn(LLM_TENSOR_FFN_DOWN,      "weight", il), { ff_dim, n_embd }, 0);
        lyr.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", il), { n_embd },         0);
    }
}
77+
78+
std::unique_ptr<llm_graph_context> llama_model_gemma4_assistant::build_arch_graph(const llm_graph_params & params) const {
79+
return std::make_unique<graph>(*this, params);
80+
}
81+
82+
// Builds the compute graph of the Gemma 4 assistant.
//
// Graph inputs registered on `res`:
//   - "inp_tokens": I32 vector with one entry per token of the micro-batch
//   - "inp_h"     : F32 matrix of shape { n_embd_out(), n_tokens }
// Graph outputs stored on `res`:
//   - t_logits  : model.output applied to the final normed state
//   - t_h_nextn : model.nextn_post_proj applied to the same state
//
// Aborts (GGML_ASSERT) when no MTP source / source model is attached, when the
// source model has no token embedding, or when the source model's last two
// layers are not (SWA, full attention) in that order.
llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_graph_params & params) :
    llm_graph_context(params) {
    // ---- preconditions on the attached source ("trunk") model --------------
    GGML_ASSERT(src_mctx && "Gemma 4 assistant graph requires an MTP source (llama_set_mtp_source)");
    GGML_ASSERT(src_model && "Gemma 4 assistant graph requires a source model");
    GGML_ASSERT(src_model->tok_embd && "source model missing tok_embd");

    const auto & trunk_hparams = src_model->hparams;

    // By convention the draft reads from the trunk's last full-attention layer
    // and from the SWA layer right before it.
    const int32_t trunk_n_layer  = (int32_t) trunk_hparams.n_layer;
    const int32_t src_layer_full = trunk_n_layer - 1;
    const int32_t src_layer_swa  = trunk_n_layer - 2;
    GGML_ASSERT(!trunk_hparams.is_swa(src_layer_full) && "trunk's last layer must be full attention");
    GGML_ASSERT( trunk_hparams.is_swa(src_layer_swa)  && "trunk's penultimate layer must be SWA");

    const int64_t n_embd_backbone = hparams.n_embd_out();

    // ---- graph inputs: token ids + hidden state, owned by one input object --
    auto inp_embd = std::make_unique<llm_graph_input_embd>(n_embd_backbone);

    ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
    cb(inp_tokens, "inp_tokens", -1);
    ggml_set_input(inp_tokens);
    inp_embd->tokens  = inp_tokens;
    res->t_inp_tokens = inp_tokens;

    ggml_tensor * inp_h = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_backbone, ubatch.n_tokens);
    cb(inp_h, "inp_h", -1);
    ggml_set_input(inp_h);
    inp_embd->embd  = inp_h;
    res->t_inp_embd = inp_h;

    res->add_input(std::move(inp_embd));

    // ---- token embedding from the SOURCE model, scaled by sqrt(backbone width)
    ggml_tensor * tok_h = ggml_get_rows(ctx0, src_model->tok_embd, inp_tokens);
    tok_h = ggml_scale(ctx0, tok_h, sqrtf((float) n_embd_backbone));
    cb(tok_h, "inp_embd_target", -1);

    // [token embedding ; incoming hidden state] stacked along dim 0, then
    // projected by nextn_pre_proj
    ggml_tensor * tok_and_h = ggml_concat(ctx0, tok_h, inp_h, 0);
    cb(tok_and_h, "inp_xh", -1);

    ggml_tensor * cur = ggml_mul_mat(ctx0, model.nextn_pre_proj, tok_and_h);
    cb(cur, "pre_proj", -1);

    // ---- shared graph inputs (created in this order) -----------------------
    auto        * inp_attn    = build_attn_inp_src_kv_iswa();
    ggml_tensor * inp_pos     = build_inp_pos();
    ggml_tensor * inp_out_ids = build_inp_out_ids();

    ggml_tensor * inpL = cur;

    for (int il = 0; il < n_layer; ++il) {
        const auto & layer = model.layers[il];

        // SWA layers pair with the trunk's SWA layer, the others with its full layer
        const bool    layer_is_swa = hparams.is_swa(il);
        const int32_t il_src       = layer_is_swa ? src_layer_swa : src_layer_full;

        const int64_t head_dim   = hparams.n_embd_head_k(il);
        const int64_t head_count = hparams.n_head(il);

        const float rope_base_l  = model.get_rope_freq_base(cparams, il);
        const float rope_scale_l = model.get_rope_freq_scale(cparams, il);
        const int   rope_dims_l  = hparams.n_rot(il);

        // -- attention: pre-norm -> Q projection -> per-head Q norm -> RoPE --
        ggml_tensor * attn_in = build_norm(inpL, layer.attn_norm, nullptr, LLM_NORM_RMS, il);
        cb(attn_in, "attn_norm", il);

        ggml_tensor * Qcur = build_lora_mm(layer.wq, attn_in);
        Qcur = ggml_reshape_3d(ctx0, Qcur, head_dim, head_count, n_tokens);
        Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il);
        cb(Qcur, "Qcur_normed", il);

        // frequency factors are only applied on non-SWA layers
        ggml_tensor * rope_factors = layer_is_swa ? nullptr : layer.rope_freqs;
        Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, rope_dims_l, rope_type, n_ctx_orig,
                             rope_base_l, rope_scale_l, ext_factor, attn_factor, beta_fast, beta_slow);
        cb(Qcur, "Qcur_pos", il);

        // Only Q is supplied; the tensor arguments after it are null and the
        // source layer index is forwarded - presumably K/V are taken from the
        // source model's cache at il_src (confirm against build_attn).
        cur = build_attn(inp_attn, layer.wo, nullptr, nullptr,
                         Qcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il, il_src);

        // on the last layer keep only the rows selected by inp_out_ids
        const bool is_last_layer = il == n_layer - 1;
        if (is_last_layer && inp_out_ids) {
            cur  = ggml_get_rows(ctx0, cur,  inp_out_ids);
            inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
        }

        cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il);
        cb(cur, "attn_post_norm", il);

        // first residual connection
        ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL);
        cb(attn_out, "attn_out", il);

        // -- feed-forward: pre-norm -> gated GELU (parallel) -> post-norm ----
        cur = build_norm(attn_out, layer.ffn_norm, nullptr, LLM_NORM_RMS, il);
        cb(cur, "ffn_norm", il);

        cur = build_ffn(cur,
                layer.ffn_up,   nullptr, nullptr,
                layer.ffn_gate, nullptr, nullptr,
                layer.ffn_down, nullptr, nullptr,
                nullptr,
                LLM_FFN_GELU, LLM_FFN_PAR, il);
        cb(cur, "ffn_out", il);

        // NOTE(review): this build_norm is given layer index -1 while every
        // other per-layer call passes `il` - kept as-is, confirm it is intended.
        cur = build_norm(cur, layer.ffn_post_norm, nullptr, LLM_NORM_RMS, -1);
        cb(cur, "ffn_post_norm", il);

        // second residual connection, then the per-layer output scale
        cur = ggml_add(ctx0, cur, attn_out);

        cur = ggml_mul(ctx0, cur, layer.out_scale);
        cb(cur, "out_scaled", il);

        inpL = cur;
    }
    cur = inpL;

    // ---- output head --------------------------------------------------------
    cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
    cb(cur, "result_norm", -1);

    ggml_tensor * logits = build_lora_mm(model.output, cur);
    cb(logits, "result_output", -1);
    res->t_logits = logits;

    // same normed state projected with nextn_post_proj and exposed as t_h_nextn
    ggml_tensor * h_next = ggml_mul_mat(ctx0, model.nextn_post_proj, cur);
    cb(h_next, "h_nextn", -1);
    res->t_h_nextn = h_next;

    // both outputs must be reachable from the graph
    ggml_build_forward_expand(gf, logits);
    ggml_build_forward_expand(gf, h_next);
}

0 commit comments

Comments
 (0)