
Commit fbd2d4e

port: minimal Gemma 4 arch support ported from turboquant
Port the core Gemma 4 (gemma4-iswa) architecture from icex/integration/macos-gemma4-128k-20260412 onto rotorquant's feature/planarquant-kv-cache. Enough to load, tokenize, and run gemma-4-E4B-it-Q4_K_M.gguf through llama-cli on Metal.

What this includes:
- src/models/gemma4-iswa.cpp (311 lines, verbatim from turboquant)
- Gemma 4 arch / kv / tensor enum entries in llama-arch.*
- Gemma 4 hparams (shared_kv_layers, embedding_length_per_layer)
- Gemma 4 model loader + graph builder registration in llama-model.*
- Gemma 4 vocab pretokenizer (LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50)
- Newline-split tokenizer BOS handling in llama-vocab.cpp
- CMake entry for gemma4-iswa.cpp
- unicode_regex_split: add byte_encode default parameter shim (ignored in body; mirrors PR ggml-org#21406 signature so the Gemma 4 vocab path compiles without pulling in the full custom-newlines impl)

What this deliberately skips:
- MTMD vision/audio projector changes (text-only benchmarking)
- common/chat.cpp Gemma 4 parser (llama-cli does not need it)
- convert_hf_to_gguf.py updates (pre-converted GGUF already on disk)
- tools/mtmd clip-model/clip-impl updates
- Tokenizer test suite additions
- Downstream fix commits ggml-org#21326, ggml-org#21343, ggml-org#21390, ggml-org#21406, ggml-org#21418, ggml-org#21500, ggml-org#21534, ggml-org#21704, ggml-org#21739 — verified non-essential for Metal text generation; may be revisited if quality issues surface

Smoke-tested with:
llama-cli -m gemma-4-E4B-it-Q4_K_M.gguf -p 'The capital of France is'
→ correctly produces 'The capital of France is Paris.'

Rotorquant still has the planar3/iso3/turbo3 KV cache types from its own feature/planarquant-kv-cache line; they are unaffected by this port but were not rerun on Gemma 4 in this session.
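Side note on the unicode_regex_split shim listed above: it only widens the signature with a defaulted parameter that the body ignores, so existing two-argument call sites keep compiling. A minimal, self-contained sketch of the pattern follows; the parameter name/type and the stand-in body are assumptions for illustration, not code copied from PR ggml-org#21406.

#include <string>
#include <vector>

// Sketch of the shim: the declaration gains a defaulted parameter so existing call
// sites compile unchanged; the real splitter body is not reproduced here.
std::vector<std::string> unicode_regex_split(
        const std::string & text,
        const std::vector<std::string> & regex_exprs,
        bool byte_encode = false);   // assumed parameter name/type

std::vector<std::string> unicode_regex_split(
        const std::string & text,
        const std::vector<std::string> & regex_exprs,
        bool byte_encode) {
    (void) byte_encode;   // shim: accepted but intentionally ignored for now
    (void) regex_exprs;
    return { text };      // stand-in body; the real function applies the regex passes
}

int main() {
    // an old two-argument call keeps compiling thanks to the default argument
    auto parts = unicode_regex_split("hello world", {});
    return parts.empty() ? 1 : 0;
}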
1 parent 20efe75 commit fbd2d4e

12 files changed

Lines changed: 607 additions & 8 deletions

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -73,6 +73,7 @@ add_library(llama
             models/gemma2-iswa.cpp
             models/gemma3.cpp
             models/gemma3n-iswa.cpp
+            models/gemma4-iswa.cpp
             models/glm4-moe.cpp
             models/glm4.cpp
             models/gpt2.cpp

src/llama-arch.cpp

Lines changed: 45 additions & 0 deletions
@@ -56,6 +56,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GEMMA2, "gemma2" },
     { LLM_ARCH_GEMMA3, "gemma3" },
     { LLM_ARCH_GEMMA3N, "gemma3n" },
+    { LLM_ARCH_GEMMA4, "gemma4" },
     { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" },
     { LLM_ARCH_STARCODER2, "starcoder2" },
     { LLM_ARCH_MAMBA, "mamba" },
@@ -165,6 +166,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
     { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
     { LLM_KV_EMBEDDING_LENGTH_OUT, "%s.embedding_length_out" },
+    { LLM_KV_EMBEDDING_LENGTH_PER_LAYER, "%s.embedding_length_per_layer_input" },
     { LLM_KV_FEATURES_LENGTH, "%s.features_length" },
     { LLM_KV_BLOCK_COUNT, "%s.block_count" },
     { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" },
@@ -238,6 +240,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" },
     { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" },
     { LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" },
+    { LLM_KV_ATTENTION_SHARED_KV_LAYERS, "%s.attention.shared_kv_layers" },

     { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
     { LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" },
@@ -364,6 +367,9 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
     { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
     { LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" },
     { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+    { LLM_TENSOR_FFN_POST_NORM_1, "blk.%d.post_ffw_norm_1" },
+    { LLM_TENSOR_FFN_POST_NORM_2, "blk.%d.post_ffw_norm_2" },
+    { LLM_TENSOR_FFN_PRE_NORM_2, "blk.%d.pre_ffw_norm_2" },
     { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
     { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
     { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
@@ -373,6 +379,7 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
     { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
     { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
     { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
+    { LLM_TENSOR_LAYER_OUT_SCALE, "blk.%d.layer_output_scale" },
     { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
     { LLM_TENSOR_POS_EMBD, "position_embd" },
     { LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act" },
@@ -557,6 +564,8 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
                LLM_TENSOR_OUTPUT_NORM,
                LLM_TENSOR_OUTPUT,
                LLM_TENSOR_ROPE_FREQS,
+               LLM_TENSOR_ROPE_FACTORS_LONG,
+               LLM_TENSOR_ROPE_FACTORS_SHORT,
                LLM_TENSOR_ATTN_NORM,
                LLM_TENSOR_ATTN_Q,
                LLM_TENSOR_ATTN_K,
@@ -1340,6 +1349,38 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
                LLM_TENSOR_LAUREL_R,
                LLM_TENSOR_LAUREL_POST_NORM,
            };
+        case LLM_ARCH_GEMMA4:
+            return {
+                LLM_TENSOR_ROPE_FREQS,
+                LLM_TENSOR_TOKEN_EMBD,
+                LLM_TENSOR_OUTPUT_NORM,
+                LLM_TENSOR_ATTN_NORM,
+                LLM_TENSOR_ATTN_Q,
+                LLM_TENSOR_ATTN_Q_NORM,
+                LLM_TENSOR_ATTN_K,
+                LLM_TENSOR_ATTN_K_NORM,
+                LLM_TENSOR_ATTN_V,
+                LLM_TENSOR_ATTN_OUT,
+                LLM_TENSOR_ATTN_POST_NORM,
+                LLM_TENSOR_FFN_NORM,
+                LLM_TENSOR_FFN_GATE,
+                LLM_TENSOR_FFN_DOWN,
+                LLM_TENSOR_FFN_UP,
+                LLM_TENSOR_FFN_GATE_UP_EXPS,
+                LLM_TENSOR_FFN_DOWN_EXPS,
+                LLM_TENSOR_FFN_GATE_INP,
+                LLM_TENSOR_FFN_POST_NORM,
+                LLM_TENSOR_FFN_POST_NORM_1,
+                LLM_TENSOR_FFN_POST_NORM_2,
+                LLM_TENSOR_FFN_PRE_NORM_2,
+                LLM_TENSOR_LAYER_OUT_SCALE,
+                LLM_TENSOR_PER_LAYER_TOKEN_EMBD,
+                LLM_TENSOR_PER_LAYER_MODEL_PROJ,
+                LLM_TENSOR_PER_LAYER_PROJ_NORM,
+                LLM_TENSOR_PER_LAYER_INP_GATE,
+                LLM_TENSOR_PER_LAYER_PROJ,
+                LLM_TENSOR_PER_LAYER_POST_NORM,
+            };
        case LLM_ARCH_GEMMA_EMBEDDING:
            return {
                LLM_TENSOR_TOKEN_EMBD,
@@ -2652,11 +2693,15 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
    {LLM_TENSOR_ATTN_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_ATTN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+   {LLM_TENSOR_FFN_PRE_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+   {LLM_TENSOR_FFN_POST_NORM_1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+   {LLM_TENSOR_FFN_POST_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_FFN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_FFN_NORM_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_ATTN_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_LAYER_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+   {LLM_TENSOR_LAYER_OUT_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_ATTN_Q_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_ATTN_KV_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_ATTN_SUB_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
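For readers unfamiliar with the naming scheme: the LLM_KV_NAMES entries above are printf-style format strings whose "%s" is filled with the architecture name, and the LLM_TENSOR_NAMES entries use "blk.%d" for the layer index, with a suffix such as ".weight" or ".scale" appended when the tensor is looked up. A standalone illustration of the expansion (not the repository's tn() helper itself):

#include <cstdio>

int main() {
    // KV key: "%s.attention.shared_kv_layers" expanded with the arch name "gemma4"
    char kv_key[128];
    std::snprintf(kv_key, sizeof(kv_key), "%s.attention.shared_kv_layers", "gemma4");
    std::printf("%s\n", kv_key); // gemma4.attention.shared_kv_layers

    // tensor name: "blk.%d.post_ffw_norm_1" expanded per layer, plus the ".weight" suffix
    for (int il = 0; il < 2; ++il) {
        char t_name[128];
        std::snprintf(t_name, sizeof(t_name), "blk.%d.post_ffw_norm_1", il);
        std::printf("%s.weight\n", t_name); // blk.0.post_ffw_norm_1.weight, blk.1.post_ffw_norm_1.weight
    }
    return 0;
}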

src/llama-arch.h

Lines changed: 7 additions & 0 deletions
@@ -60,6 +60,7 @@ enum llm_arch {
     LLM_ARCH_GEMMA2,
     LLM_ARCH_GEMMA3,
     LLM_ARCH_GEMMA3N,
+    LLM_ARCH_GEMMA4,
     LLM_ARCH_GEMMA_EMBEDDING,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
@@ -169,6 +170,7 @@ enum llm_kv {
     LLM_KV_CONTEXT_LENGTH,
     LLM_KV_EMBEDDING_LENGTH,
     LLM_KV_EMBEDDING_LENGTH_OUT,
+    LLM_KV_EMBEDDING_LENGTH_PER_LAYER,
     LLM_KV_FEATURES_LENGTH,
     LLM_KV_BLOCK_COUNT,
     LLM_KV_LEADING_DENSE_BLOCK_COUNT,
@@ -242,6 +244,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_INDEXER_HEAD_COUNT,
     LLM_KV_ATTENTION_INDEXER_KEY_LENGTH,
     LLM_KV_ATTENTION_INDEXER_TOP_K,
+    LLM_KV_ATTENTION_SHARED_KV_LAYERS,

     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_COUNT_SWA,
@@ -369,6 +372,9 @@ enum llm_tensor {
     LLM_TENSOR_FFN_GATE_INP_SHEXP,
     LLM_TENSOR_FFN_NORM,
     LLM_TENSOR_FFN_POST_NORM,
+    LLM_TENSOR_FFN_POST_NORM_1,
+    LLM_TENSOR_FFN_POST_NORM_2,
+    LLM_TENSOR_FFN_PRE_NORM_2,
     LLM_TENSOR_FFN_GATE,
     LLM_TENSOR_FFN_DOWN,
     LLM_TENSOR_FFN_UP,
@@ -393,6 +399,7 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_NORM,
     LLM_TENSOR_ATTN_K_NORM,
     LLM_TENSOR_LAYER_OUT_NORM,
+    LLM_TENSOR_LAYER_OUT_SCALE,
     LLM_TENSOR_POST_ATTN_NORM,
     LLM_TENSOR_POST_MLP_NORM,
     LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n

src/llama-hparams.h

Lines changed: 3 additions & 0 deletions
@@ -209,6 +209,9 @@ struct llama_hparams {
     // qwen3vl deepstack
     uint32_t n_deepstack_layers = 0;

+    // gemma4 per-layer embedding
+    uint32_t n_embd_per_layer = 0;
+
     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggml-org/llama.cpp/pull/8141
     llama_token dec_start_token_id = LLAMA_TOKEN_NULL;

src/llama-model.cpp

Lines changed: 127 additions & 1 deletion
@@ -1261,6 +1261,32 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                    default: type = LLM_TYPE_UNKNOWN;
                }
            } break;
+        case LLM_ARCH_GEMMA4:
+            {
+                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer);
+
+                uint32_t n_kv_shared_layers = 0;
+                ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false);
+
+                hparams.n_layer_kv_from_start = hparams.n_layer - (int32_t)n_kv_shared_layers;
+                hparams.f_attention_scale = 1.0f; // Gemma4 uses self.scaling = 1.0 (no pre-attn scaling)
+
+                ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EMBEDDING_LENGTH_PER_LAYER, hparams.n_embd_per_layer);
+                ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa);
+                ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa);
+                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
+
+                switch (hparams.n_layer) {
+                    case 35: type = LLM_TYPE_E2B; break;
+                    case 42: type = LLM_TYPE_E4B; break; // to confirm: E4B or E5B?
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
        case LLM_ARCH_GEMMA_EMBEDDING:
            {
                hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC;
@@ -4229,6 +4255,101 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                    layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0);
                }
            } break;
+        case LLM_ARCH_GEMMA4:
+            {
+                const uint32_t n_embd_per_layer = hparams.n_embd_per_layer;
+                const int64_t n_ff_exp = hparams.n_ff_exp;
+
+                if (n_embd_head_k != n_embd_head_v) {
+                    throw std::runtime_error("Gemma 4 requires n_embd_head_k == n_embd_head_v");
+                }
+                if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) {
+                    throw std::runtime_error("Gemma 4 requires n_embd_head_k_swa == n_embd_head_v_swa");
+                }
+
+                output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                // if output is NULL, init from the input tok embed
+                if (output == NULL) {
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                }
+
+                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                if (n_embd_per_layer > 0) {
+                    tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0);
+                    per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_per_layer * n_layer}, 0);
+                    per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_per_layer}, 0);
+                }
+
+                output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+                int rope_freqs_flag = 0;
+
+                for (int i = 0; i < n_layer; ++i) {
+                    auto & layer = layers[i];
+                    const int64_t n_head = hparams.n_head(i);
+                    const int64_t n_embd_head = hparams.n_embd_head_k(i);
+                    const int64_t n_embd_k = hparams.n_embd_k_gqa(i);
+                    const int64_t n_embd_v = hparams.n_embd_v_gqa(i);
+                    const int kv_flags = hparams.has_kv(i) ? 0 : TENSOR_NOT_REQUIRED;
+
+                    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                    // note: use_alternative_attention (v_proj is optional, if it's not present, use k_proj)
+                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head * n_head}, 0);
+                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k}, kv_flags);
+                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v}, TENSOR_NOT_REQUIRED);
+                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head * n_head, n_embd}, 0);
+
+                    layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head}, 0);
+                    layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head}, kv_flags);
+                    layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1u}, TENSOR_NOT_REQUIRED);
+
+                    if (!hparams.is_swa(i)) {
+                        // full_attention layers use rope_freqs for proportional rope
+                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_embd_head/2}, rope_freqs_flag);
+                        rope_freqs_flag = TENSOR_DUPLICATED;
+                    }
+
+                    // handle use_double_wide_mlp
+                    int64_t n_ff_cur = hparams.n_ff(i);
+
+                    // for expert layers, we use normal FFN as shared expert (same as python code)
+                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                    layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff_cur}, 0);
+                    layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur}, 0);
+                    layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0);
+                    layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                    // MoE router
+                    layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED);
+                    bool has_expert = layer.ffn_gate_inp != nullptr;
+
+                    // norm
+                    if (has_expert) {
+                        layer.ffn_gate_inp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "scale", i), {n_embd}, 0);
+
+                        layer.ffn_pre_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_PRE_NORM_2, "weight", i), {n_embd}, 0);
+                        layer.ffn_post_norm_1 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_1, "weight", i), {n_embd}, 0);
+                        layer.ffn_post_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_2, "weight", i), {n_embd}, 0);
+
+                        // MoE FFN
+                        layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", i), {n_embd, n_ff_exp * 2, n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
+
+                        // per-expert scale will be loaded as down_exps_s at the end of the current switch case
+                    }
+
+                    // per-layer embeddings
+                    if (n_embd_per_layer > 0) {
+                        layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_per_layer}, 0);
+                        layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_per_layer, n_embd}, 0);
+                        layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0);
+                    }
+                }
+            } break;
        case LLM_ARCH_STARCODER2:
            {
                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -8233,7 +8354,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
        } else {
            llama_memory_i::layer_reuse_cb reuse = nullptr;

-           if (arch == LLM_ARCH_GEMMA3N) {
+           if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) {
                reuse = [&](int32_t il) {
                    if (il >= (int32_t) hparams.n_layer_kv_from_start) {
                        return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1);
@@ -8486,6 +8607,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
            {
                llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params);
            } break;
+        case LLM_ARCH_GEMMA4:
+            {
+                llm = std::make_unique<llm_build_gemma4_iswa>(*this, params);
+            } break;
        case LLM_ARCH_GEMMA_EMBEDDING:
            {
                llm = std::make_unique<llm_build_gemma_embedding>(*this, params);
@@ -9006,6 +9131,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
        case LLM_ARCH_GEMMA2:
        case LLM_ARCH_GEMMA3:
        case LLM_ARCH_GEMMA3N:
+       case LLM_ARCH_GEMMA4:
        case LLM_ARCH_GEMMA_EMBEDDING:
        case LLM_ARCH_STARCODER2:
        case LLM_ARCH_OPENELM:
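The create_memory change above is what lets shared-KV Gemma 4 layers reuse the cache of an earlier layer. Below is a standalone sketch of the arithmetic in that callback, with illustrative layer counts rather than values read from the GGUF; the is_swa decision is passed in explicitly instead of read from hparams, and the -1 "no reuse" fallback is an assumption, since the hunk above is truncated before that branch.

#include <cstdint>
#include <cstdio>

int main() {
    // Illustrative numbers only: a 42-layer model where the last 12 layers share KV,
    // i.e. n_layer_kv_from_start = 42 - 12 = 30.
    const int32_t n_layer               = 42;
    const int32_t n_kv_shared_layers    = 12;
    const int32_t n_layer_kv_from_start = n_layer - n_kv_shared_layers;

    // Mirrors the callback registered for GEMMA3N/GEMMA4: a shared layer points at the
    // last full-attention or SWA layer that still owns its own KV cache.
    auto reuse = [&](int32_t il, bool is_swa) -> int32_t {
        if (il >= n_layer_kv_from_start) {
            return n_layer_kv_from_start - (is_swa ? 2 : 1);
        }
        return -1; // assumed: layers below the cutoff keep their own cache
    };

    std::printf("layer 35 (swa)  -> reuses layer %d\n", (int) reuse(35, true));  // 28
    std::printf("layer 41 (full) -> reuses layer %d\n", (int) reuse(41, false)); // 29
    std::printf("layer 10        -> reuses layer %d\n", (int) reuse(10, false)); // -1 (owns its cache)
    return 0;
}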

src/llama-model.h

Lines changed: 7 additions & 0 deletions
@@ -270,6 +270,9 @@ struct llama_layer {
     struct ggml_tensor * ffn_norm = nullptr;
     struct ggml_tensor * ffn_norm_b = nullptr;
     struct ggml_tensor * ffn_post_norm = nullptr;
+    struct ggml_tensor * ffn_post_norm_1 = nullptr; // gemma4
+    struct ggml_tensor * ffn_post_norm_2 = nullptr; // gemma4
+    struct ggml_tensor * ffn_pre_norm_2 = nullptr; // gemma4
     struct ggml_tensor * layer_out_norm = nullptr;
     struct ggml_tensor * layer_out_norm_b = nullptr;
     struct ggml_tensor * ffn_norm_exps = nullptr;
@@ -285,6 +288,7 @@ struct llama_layer {

     // ff MoE
     struct ggml_tensor * ffn_gate_inp = nullptr;
+    struct ggml_tensor * ffn_gate_inp_s = nullptr; // gemma4
     struct ggml_tensor * ffn_gate_exps = nullptr;
     struct ggml_tensor * ffn_down_exps = nullptr;
     struct ggml_tensor * ffn_up_exps = nullptr;
@@ -483,6 +487,9 @@ struct llama_layer {
     struct ggml_tensor * indexer_attn_k = nullptr;
     struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias

+    // gemma4 layer output scale
+    struct ggml_tensor * out_scale = nullptr;
+
     struct llama_layer_posnet posnet;

     struct llama_layer_convnext convnext;
