
Commit 7b40644

stephencoxclaude committed

gemma4: fix audio encoder and LM precision issues
Audio encoder fixes:

- Fix swapped conv norm weight mapping in tensor_mapping.py (A_ENC_CONV_NORM and A_ENC_NORM_CONV had their gemma4 entries inverted, causing the conv pre-norm and internal norm weights to be swapped in GGUF. This produced 0.67 encoder cosine vs PyTorch; now 0.9999)
- Fix causal mask off-by-one: add (gq - gk) < max_past to match PyTorch's dist < left_window_size (was attending to 13 past tokens instead of 12)
- Use -1e9 instead of -INFINITY for masked positions to match PyTorch's attention_invalid_logits_value and avoid NaN in padded attention weights

LM fixes:

- Disable attention logit softcapping for Gemma4 (unlike Gemma2, Gemma4's text model does not use attn softcapping; was incorrectly hardcoded)
- Use BF16-rounded embedding scale constants to match PyTorch's native BF16 training precision (ref: PR ggml-org#21451). Fixes long-context coherence on CPU/Vulkan backends.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 30a1937

4 files changed: 14 additions & 10 deletions


gguf-py/gguf/tensor_mapping.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -2056,22 +2056,22 @@ class TensorNameMap:
 
         MODEL_TENSOR.A_ENC_CONV_NORM: (
             "conformer.layers.{bid}.conv.batch_norm", # lfm2
-            "conformer.layers.{bid}.lconv1d.pre_layer_norm", # gemma3n
+            "conformer.layers.{bid}.lconv1d.conv_norm", # gemma4
         ),
 
         MODEL_TENSOR.A_ENC_CONV_PW1: (
             "conformer.layers.{bid}.conv.pointwise_conv1", # lfm2
-            "conformer.layers.{bid}.lconv1d.linear_start", # gemma3n
+            "conformer.layers.{bid}.lconv1d.linear_start", # gemma4
         ),
 
         MODEL_TENSOR.A_ENC_CONV_PW2: (
             "conformer.layers.{bid}.conv.pointwise_conv2", # lfm2
-            "conformer.layers.{bid}.lconv1d.linear_end", # gemma3n
+            "conformer.layers.{bid}.lconv1d.linear_end", # gemma4
         ),
 
         MODEL_TENSOR.A_ENC_NORM_CONV: (
             "conformer.layers.{bid}.norm_conv", # lfm2
-            "conformer.layers.{bid}.lconv1d.conv_norm", # gemma3n
+            "conformer.layers.{bid}.lconv1d.pre_layer_norm", # gemma4
         ),
 
         MODEL_TENSOR.A_PER_DIM_K_SCALE: (
```
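For reference, the 0.67 → 0.9999 numbers in the commit message come from a cosine-similarity comparison of encoder outputs against PyTorch. A minimal standalone sketch of that kind of check (the vectors below are stand-ins, not real encoder outputs):

```cpp
// Sketch: cosine similarity between flattened encoder outputs, e.g. the
// GGUF-backed encoder vs the PyTorch reference. Values near 1.0 indicate
// the tensors were mapped correctly; the swapped norms gave ~0.67.
#include <cmath>
#include <cstdio>
#include <vector>

static float cosine_sim(const std::vector<float> & a, const std::vector<float> & b) {
    float dot = 0.0f, na = 0.0f, nb = 0.0f;
    for (size_t i = 0; i < a.size(); i++) {
        dot += a[i] * b[i];
        na  += a[i] * a[i];
        nb  += b[i] * b[i];
    }
    return dot / (sqrtf(na) * sqrtf(nb) + 1e-12f);
}

int main() {
    // stand-in vectors; in practice these would be the two encoders' outputs
    std::vector<float> ref  = { 0.10f, -0.30f, 0.70f, 0.20f };
    std::vector<float> test = { 0.11f, -0.29f, 0.69f, 0.21f };
    printf("cosine: %.4f\n", cosine_sim(ref, test));
}
```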

src/llama-model.cpp

Lines changed: 4 additions & 2 deletions
```diff
@@ -1186,14 +1186,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 uint32_t swa_period = 2;
                 ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false);
                 hparams.set_swa_pattern(swa_period);
-                hparams.attn_soft_cap = true;
                 hparams.rope_freq_base_train_swa  = hparams.rope_freq_base_train;
                 hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
 
                 ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
+                // Gemma4 does NOT use attention logit softcapping (unlike Gemma2)
+                hparams.f_attn_logit_softcapping = 0.0f;
+                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
+                hparams.attn_soft_cap = (hparams.f_attn_logit_softcapping > 0.0f);
                 ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
 
                 switch (hparams.n_layer) {
```
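For context, a minimal sketch of what the zeroed cap disables: Gemma2-style attention logit softcapping squashes logits as cap * tanh(logits / cap), and a cap of 0.0f bypasses the transform, consistent with the hparams.attn_soft_cap flag above. The constants below are illustrative only:

```cpp
// Sketch of Gemma2-style attention logit softcapping; with the cap set to
// 0.0f the transform is skipped, which is the new default for Gemma4.
#include <cmath>
#include <cstdio>

static float softcap(float logit, float cap) {
    // cap > 0: squash logits into (-cap, cap); cap == 0: identity (disabled)
    return cap > 0.0f ? cap * tanhf(logit / cap) : logit;
}

int main() {
    printf("cap=50 : %.3f\n", softcap(120.0f, 50.0f)); // ~49.2, squashed
    printf("cap=0  : %.3f\n", softcap(120.0f,  0.0f)); // 120.0, unchanged
}
```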

src/models/gemma4-iswa.cpp

Lines changed: 4 additions & 2 deletions
```diff
@@ -17,7 +17,8 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll
     inpL = build_inp_embd(model.tok_embd);
 
     // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
-    inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
+    // use BF16-rounded scale to match PyTorch's native BF16 training precision (ref: PR #21451)
+    inpL = ggml_scale(ctx0, inpL, ubatch.token ? ggml_bf16_to_fp32(ggml_fp32_to_bf16(sqrtf(n_embd))) : 1.0f);
     cb(inpL, "inp_scaled", -1);
 
     // inp_pos - contains the positions
@@ -149,8 +150,9 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll
         cb(cur_moe, "ffn_norm_2", il);
 
         // custom MoE logits calculation (router operates on attn_out, not cur)
+        // use BF16-rounded scale to match PyTorch's native BF16 training precision (ref: PR #21451)
         ggml_tensor * tmp = ggml_rms_norm(ctx0, attn_out, hparams.f_norm_rms_eps);
-        tmp = ggml_scale(ctx0, tmp, 1.0f / sqrtf((float) n_embd));
+        tmp = ggml_scale(ctx0, tmp, 1.0f / ggml_bf16_to_fp32(ggml_fp32_to_bf16(sqrtf((float) n_embd))));
         tmp = ggml_mul(ctx0, tmp, model.layers[il].ffn_gate_inp_s);
         ggml_tensor * logits = build_lora_mm(model.layers[il].ffn_gate_inp, tmp); // [n_expert, n_tokens]
         cb(logits, "ffn_moe_logits", il);
```
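A standalone sketch of the BF16 round-trip applied to the scale constant, mirroring the ggml_fp32_to_bf16/ggml_bf16_to_fp32 pair without pulling in ggml headers (round-to-nearest-even; the NaN special case is omitted, and n_embd is a made-up width):

```cpp
// Sketch: round a float32 to bfloat16 and back, as
// ggml_bf16_to_fp32(ggml_fp32_to_bf16(x)) does. bfloat16 keeps only the top
// 16 bits of the float32 pattern, so sqrtf(n_embd) loses low mantissa bits,
// matching the constant PyTorch effectively uses in native BF16 training.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

static float bf16_round(float x) {
    uint32_t u;
    std::memcpy(&u, &x, sizeof(u));
    u = ((u + 0x7fff + ((u >> 16) & 1)) >> 16) << 16; // round-to-nearest-even
    float y;
    std::memcpy(&y, &u, sizeof(y));
    return y;
}

int main() {
    const int n_embd = 2560; // hypothetical embedding width
    const float s = sqrtf((float) n_embd);
    printf("fp32: %.6f  bf16-rounded: %.6f\n", s, bf16_round(s));
}
```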

tools/mtmd/clip.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -3392,13 +3392,13 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
 
     // Blocked causal attention mask: [context_size, chunk_size, num_blocks]
     {
-        std::vector<float> mask(context_size * chunk_size * num_blocks, -INFINITY);
+        std::vector<float> mask(context_size * chunk_size * num_blocks, -1e9f);
         for (int b = 0; b < num_blocks; b++) {
             for (int q = 0; q < chunk_size; q++) {
                 int gq = b * chunk_size + q;
                 for (int k = 0; k < context_size; k++) {
                     int gk = b * chunk_size - max_past + k;
-                    if (gq < n_pos && gk >= 0 && gk < n_pos && gk <= gq) {
+                    if (gq < n_pos && gk >= 0 && gk < n_pos && gk <= gq && (gq - gk) < max_past) {
                         mask[k + q * context_size + b * context_size * chunk_size] = 0.0f;
                     }
                 }
```
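Both mask fixes can be demonstrated in isolation. The sketch below uses max_past = 13 (inferred from the 13-vs-12 counts in the commit message, not read from the model) to show how the new distance guard trims the attention window, and a max-subtracting softmax to show why -INFINITY fills turn fully padded rows into NaN while -1e9 stays finite:

```cpp
// Sketch of the two mask fixes. (1) Without the distance guard, a query can
// attend to max_past strictly-past keys; with it, max_past - 1. (2) In a
// fully masked (padded) row, a max-subtracting softmax computes
// -inf - (-inf) = NaN for -INFINITY fills, but 0 - 0 = 0 for -1e9 fills.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static void softmax(std::vector<float> & x) {
    const float m = *std::max_element(x.begin(), x.end());
    float sum = 0.0f;
    for (float & v : x) { v = expf(v - m); sum += v; }
    for (float & v : x) { v /= sum; }
}

int main() {
    const int max_past = 13; // illustrative window size
    const int gq = 100;      // arbitrary query position
    int n_old = 0, n_new = 0;
    for (int gk = gq - max_past; gk < gq; gk++) {
        n_old++;                         // old condition: gk <= gq only
        if (gq - gk < max_past) n_new++; // new: also (gq - gk) < max_past
    }
    printf("past keys attended: old=%d new=%d\n", n_old, n_new); // 13 vs 12

    std::vector<float> inf_row(4, -INFINITY), big_row(4, -1e9f);
    softmax(inf_row); // NaN: every v - m is -inf - (-inf)
    softmax(big_row); // uniform 0.25: every v - m is 0
    printf("-INFINITY row: %f  -1e9 row: %f\n", inf_row[0], big_row[0]);
}
```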
