@@ -1261,6 +1261,36 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_GEMMA4:
+            {
+                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer);
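+                // swa_layers is the per-layer attention pattern: non-zero entries use sliding-window attention, zero means full attention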
+
+                uint32_t n_kv_shared_layers = 0;
+                ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false);
+
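+                // the last n_kv_shared_layers layers get no KV cache of their own; they reuse an earlier layer's (wired up in create_memory)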
+                hparams.n_layer_kv_from_start = hparams.n_layer - (int32_t) n_kv_shared_layers;
+                hparams.f_attention_scale = 1.0f; // Gemma4 uses self.scaling = 1.0 (no pre-attn scaling)
+
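+                // SWA layers may carry their own RoPE base and K/V head sizes, so the SWA-specific keys are read as well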
+                ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EMBEDDING_LENGTH_PER_LAYER, hparams.n_embd_per_layer);
+                ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa);
+                ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa);
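+                // optional final-logit softcapping, as in Gemma2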
+                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
+
+                switch (hparams.n_layer) {
+                    case 35: type = LLM_TYPE_E2B; break;
+                    case 42: type = LLM_TYPE_E4B; break; // to confirm: E4B or E5B?
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_GEMMA_EMBEDDING:
             {
                 hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC;
@@ -4229,6 +4255,105 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0);
                 }
             } break;
+        case LLM_ARCH_GEMMA4:
+            {
+                const uint32_t n_embd_per_layer = hparams.n_embd_per_layer;
+                const int64_t n_ff_exp = hparams.n_ff_exp;
+
+                if (n_embd_head_k != n_embd_head_v) {
+                    throw std::runtime_error("Gemma 4 requires n_embd_head_k == n_embd_head_v");
+                }
+                if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) {
+                    throw std::runtime_error("Gemma 4 requires n_embd_head_k_swa == n_embd_head_v_swa");
+                }
+
+                output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                // if output is NULL, init from the input tok embed
+                if (output == NULL) {
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                }
+
+                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
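+                // Gemma3n-style per-layer token embeddings: each token carries an n_embd_per_layer slice for every layer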
+                if (n_embd_per_layer > 0) {
+                    tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0);
+                    per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_per_layer * n_layer}, 0);
+                    per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_per_layer}, 0);
+                }
+
+                output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
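+                // only the first full-attention layer loads rope_freqs; later ones are marked TENSOR_DUPLICATED and share the same data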
+                int rope_freqs_flag = 0;
+
+                for (int i = 0; i < n_layer; ++i) {
+                    auto & layer = layers[i];
+                    const int64_t n_head = hparams.n_head(i);
+                    const int64_t n_embd_head = hparams.n_embd_head_k(i);
+                    const int64_t n_embd_k = hparams.n_embd_k_gqa(i);
+                    const int64_t n_embd_v = hparams.n_embd_v_gqa(i);
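+                    // shared-KV layers (has_kv(i) == false) may ship without their own k_proj / k_norm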
+                    const int kv_flags = hparams.has_kv(i) ? 0 : TENSOR_NOT_REQUIRED;
+
+                    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                    // note: use_alternative_attention (v_proj is optional; when it is absent, k_proj is reused in its place)
+                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head * n_head}, 0);
+                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k}, kv_flags);
+                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v}, TENSOR_NOT_REQUIRED);
+                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head * n_head, n_embd}, 0);
+
+                    layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head}, 0);
+                    layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head}, kv_flags);
+                    layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
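+                    // optional 1-element scale applied to this layer's output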
+                    layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1u}, TENSOR_NOT_REQUIRED);
+
+                    if (!hparams.is_swa(i)) {
+                        // full_attention layers use rope_freqs for proportional rope
+                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_embd_head/2}, rope_freqs_flag);
+                        rope_freqs_flag = TENSOR_DUPLICATED;
+                    }
+
+                    // n_ff can vary per layer (handles use_double_wide_mlp)
+                    int64_t n_ff_cur = hparams.n_ff(i);
+
+                    // for expert layers, the dense FFN below acts as the shared expert (matching the Python reference implementation)
+                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                    layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff_cur}, 0);
+                    layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur}, 0);
+                    layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0);
+                    layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                    // MoE router
+                    layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED);
+                    bool has_expert = layer.ffn_gate_inp != nullptr;
+
+                    // router scale and expert-layer norms
+                    if (has_expert) {
+                        layer.ffn_gate_inp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "scale", i), {n_embd}, 0);
+
+                        layer.ffn_pre_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_PRE_NORM_2, "weight", i), {n_embd}, 0);
+                        layer.ffn_post_norm_1 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_1, "weight", i), {n_embd}, 0);
+                        layer.ffn_post_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_2, "weight", i), {n_embd}, 0);
+
+                        // MoE FFN (gate and up projections are fused along dim 1, hence n_ff_exp * 2)
+                        layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", i), {n_embd, n_ff_exp * 2, n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
+
+                        // per-expert scale will be loaded as down_exps_s at the end of the current switch case
+                    }
+
+                    // per-layer embeddings
+                    if (n_embd_per_layer > 0) {
+                        layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_per_layer}, 0);
+                        layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_per_layer, n_embd}, 0);
+                        layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0);
+                    }
+                }
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -8233,7 +8354,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     } else {
         llama_memory_i::layer_reuse_cb reuse = nullptr;

-        if (arch == LLM_ARCH_GEMMA3N) {
+        if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) {
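+            // shared layers reuse a fixed donor: a SWA layer takes its KV from layer n_layer_kv_from_start - 2, a full-attention
+            // layer from n_layer_kv_from_start - 1 (assumes the cached prefix ends SWA-then-full, as in Gemma3n)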
            reuse = [&](int32_t il) {
                if (il >= (int32_t) hparams.n_layer_kv_from_start) {
                    return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1);
@@ -8486,6 +8607,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params);
             } break;
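+        // "_iswa" = interleaved sliding-window attention, matching the other Gemma graph builders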
+        case LLM_ARCH_GEMMA4:
+            {
+                llm = std::make_unique<llm_build_gemma4_iswa>(*this, params);
+            } break;
         case LLM_ARCH_GEMMA_EMBEDDING:
             {
                 llm = std::make_unique<llm_build_gemma_embedding>(*this, params);
@@ -9006,6 +9131,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GEMMA2:
         case LLM_ARCH_GEMMA3:
         case LLM_ARCH_GEMMA3N:
+        case LLM_ARCH_GEMMA4:
         case LLM_ARCH_GEMMA_EMBEDDING:
         case LLM_ARCH_STARCODER2:
         case LLM_ARCH_OPENELM: