
Commit 2852c7c

stephencoxclaude committed
gemma4: fix audio encoder and LM precision issues
Audio encoder fixes:

- Fix swapped conv norm weight mapping in tensor_mapping.py (A_ENC_CONV_NORM and A_ENC_NORM_CONV had their gemma4 entries inverted, causing the conv pre-norm and internal norm weights to be swapped in GGUF. This produced 0.67 encoder cosine vs PyTorch; now 0.9999)
- Fix causal mask off-by-one: add (gq - gk) < max_past to match PyTorch's dist < left_window_size (was attending to 13 past tokens instead of 12)
- Use -1e9 instead of -INFINITY for masked positions to match PyTorch's attention_invalid_logits_value and avoid NaN in padded attention weights

LM fixes:

- Disable attention logit softcapping for Gemma4 (unlike Gemma2, Gemma4's text model does not use attn softcapping; was incorrectly hardcoded)
- Use BF16-rounded embedding scale constants to match PyTorch's native BF16 training precision (ref: PR ggml-org#21451). Fixes long-context coherence on CPU/Vulkan backends.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2ac0112 commit 2852c7c

4 files changed

Lines changed: 15 additions & 11 deletions


gguf-py/gguf/tensor_mapping.py

Lines changed: 4 additions & 4 deletions
@@ -2045,22 +2045,22 @@ class TensorNameMap:

        MODEL_TENSOR.A_ENC_CONV_NORM: (
            "conformer.layers.{bid}.conv.batch_norm",         # lfm2
-           "conformer.layers.{bid}.lconv1d.pre_layer_norm",  # gemma3n
+           "conformer.layers.{bid}.lconv1d.conv_norm",       # gemma4
        ),

        MODEL_TENSOR.A_ENC_CONV_PW1: (
            "conformer.layers.{bid}.conv.pointwise_conv1",    # lfm2
-           "conformer.layers.{bid}.lconv1d.linear_start",    # gemma3n
+           "conformer.layers.{bid}.lconv1d.linear_start",    # gemma4
        ),

        MODEL_TENSOR.A_ENC_CONV_PW2: (
            "conformer.layers.{bid}.conv.pointwise_conv2",    # lfm2
-           "conformer.layers.{bid}.lconv1d.linear_end",      # gemma3n
+           "conformer.layers.{bid}.lconv1d.linear_end",      # gemma4
        ),

        MODEL_TENSOR.A_ENC_NORM_CONV: (
            "conformer.layers.{bid}.norm_conv",               # lfm2
-           "conformer.layers.{bid}.lconv1d.conv_norm",       # gemma3n
+           "conformer.layers.{bid}.lconv1d.pre_layer_norm",  # gemma4
        ),

        MODEL_TENSOR.A_PER_DIM_K_SCALE: (
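The 0.67 → 0.9999 figure quoted in the commit message is a cosine similarity between the converted encoder's output and a PyTorch reference. As a rough illustration only (the toy vectors stand in for the two activation dumps; producing those dumps is not part of this commit), such a check could look like this:

#include <cmath>
#include <cstdio>
#include <vector>

// Cosine similarity between two flattened activation tensors, e.g. the
// llama.cpp encoder output vs. a PyTorch reference dump.
static double cosine_similarity(const std::vector<float> & a, const std::vector<float> & b) {
    double dot = 0.0, na = 0.0, nb = 0.0;
    const size_t n = a.size() < b.size() ? a.size() : b.size();
    for (size_t i = 0; i < n; i++) {
        dot += (double) a[i] * b[i];
        na  += (double) a[i] * a[i];
        nb  += (double) b[i] * b[i];
    }
    return dot / (std::sqrt(na) * std::sqrt(nb) + 1e-12);
}

int main() {
    // toy vectors standing in for the two activation dumps
    std::vector<float> ggml_out  = { 0.10f, -0.52f, 0.33f, 0.71f };
    std::vector<float> torch_ref = { 0.11f, -0.50f, 0.35f, 0.70f };
    printf("cosine: %.4f\n", cosine_similarity(ggml_out, torch_ref));
    return 0;
}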

src/llama-model.cpp

Lines changed: 4 additions & 2 deletions
@@ -1186,14 +1186,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                uint32_t swa_period = 2;
                ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false);
                hparams.set_swa_pattern(swa_period);
-               hparams.attn_soft_cap = true;
                hparams.rope_freq_base_train_swa  = hparams.rope_freq_base_train;
                hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;

                ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-               ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
+               // Gemma4 does NOT use attention logit softcapping (unlike Gemma2)
+               hparams.f_attn_logit_softcapping = 0.0f;
+               ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
+               hparams.attn_soft_cap = (hparams.f_attn_logit_softcapping > 0.0f);
                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);

                switch (hparams.n_layer) {
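For context, attention logit softcapping (the Gemma2 behaviour that was previously hardcoded on) squashes each pre-softmax score through a scaled tanh. A scalar sketch of the operation follows; it is not the actual graph-building code, just an illustration of why forcing f_attn_logit_softcapping to 0.0f makes the softcap a no-op unless the GGUF metadata explicitly provides a value:

#include <cmath>
#include <cstdio>

// Scalar form of attention logit softcapping as used by Gemma2-style models:
// the raw score is squashed into (-softcap, +softcap). With a softcap of 0
// (the Gemma4 default after this change) the score passes through unchanged.
static float softcap_logit(float score, float softcap) {
    if (softcap <= 0.0f) {
        return score;                       // softcapping disabled
    }
    return softcap * tanhf(score / softcap);
}

int main() {
    printf("%.4f\n", softcap_logit(80.0f, 50.0f));  // ~50 * tanh(1.6) ≈ 46.1
    printf("%.4f\n", softcap_logit(80.0f,  0.0f));  // unchanged: 80
    return 0;
}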

src/models/gemma4-iswa.cpp

Lines changed: 5 additions & 3 deletions
@@ -10,7 +10,8 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll
     inpL = build_inp_embd(model.tok_embd);

     // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
-    inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
+    // use BF16-rounded scale to match PyTorch's native BF16 training precision (ref: PR #21451)
+    inpL = ggml_scale(ctx0, inpL, ubatch.token ? ggml_bf16_to_fp32(ggml_fp32_to_bf16(sqrtf(n_embd))) : 1.0f);
     cb(inpL, "inp_scaled", -1);

     // inp_pos - contains the positions
@@ -139,8 +140,9 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll
         cb(cur_moe, "ffn_norm_2", il);

         // custom MoE logits calculation (router operates on attn_out, not cur)
+        // use BF16-rounded scale to match PyTorch's native BF16 training precision (ref: PR #21451)
         ggml_tensor * tmp = ggml_rms_norm(ctx0, attn_out, hparams.f_norm_rms_eps);
-        tmp = ggml_scale(ctx0, tmp, 1.0f / sqrtf((float) n_embd));
+        tmp = ggml_scale(ctx0, tmp, 1.0f / ggml_bf16_to_fp32(ggml_fp32_to_bf16(sqrtf((float) n_embd))));
         tmp = ggml_mul(ctx0, tmp, model.layers[il].ffn_gate_inp_s);
         ggml_tensor * logits = build_lora_mm(model.layers[il].ffn_gate_inp, tmp); // [n_expert, n_tokens]
         cb(logits, "ffn_moe_logits", il);
@@ -266,7 +268,7 @@ ggml_tensor * llm_build_gemma4_iswa::get_per_layer_inputs() {
         res->t_inp_tokens = inp->tokens;
         inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
         inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_per_layer, n_layer, n_tokens);
-        inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_per_layer));
+        inp_per_layer = ggml_scale(ctx0, inp_per_layer, ggml_bf16_to_fp32(ggml_fp32_to_bf16(sqrtf((float) n_embd_per_layer))));
         cb(inp_per_layer, "inp_per_layer_selected", -1);
         res->add_input(std::move(inp));
     } else {
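The ggml_bf16_to_fp32(ggml_fp32_to_bf16(...)) round trip above simply snaps the FP32 scale constant to the nearest BF16 value. A self-contained sketch of the same rounding (round-to-nearest-even on the upper 16 bits, NaN handling omitted), with a hypothetical hidden size purely to show the size of the precision difference:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Round an FP32 value to the nearest BF16 value and back. This mirrors what
// the ggml_bf16_to_fp32(ggml_fp32_to_bf16(...)) round trip does to the scale.
static float bf16_round(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    bits += 0x7FFF + ((bits >> 16) & 1);   // round to nearest, ties to even
    bits &= 0xFFFF0000u;                   // keep only the BF16 portion
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}

int main() {
    const float n_embd = 2560.0f;          // hypothetical hidden size, for illustration only
    const float fp32_scale = sqrtf(n_embd);
    printf("fp32 scale: %.6f\n", fp32_scale);              // 50.596443
    printf("bf16 scale: %.6f\n", bf16_round(fp32_scale));  // 50.500000
    return 0;
}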

tools/mtmd/clip.cpp

Lines changed: 2 additions & 2 deletions
@@ -3347,13 +3347,13 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima

     // Blocked causal attention mask: [context_size, chunk_size, num_blocks]
     {
-        std::vector<float> mask(context_size * chunk_size * num_blocks, -INFINITY);
+        std::vector<float> mask(context_size * chunk_size * num_blocks, -1e9f);
         for (int b = 0; b < num_blocks; b++) {
             for (int q = 0; q < chunk_size; q++) {
                 int gq = b * chunk_size + q;
                 for (int k = 0; k < context_size; k++) {
                     int gk = b * chunk_size - max_past + k;
-                    if (gq < n_pos && gk >= 0 && gk < n_pos && gk <= gq) {
+                    if (gq < n_pos && gk >= 0 && gk < n_pos && gk <= gq && (gq - gk) < max_past) {
                         mask[k + q * context_size + b * context_size * chunk_size] = 0.0f;
                     }
                 }
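To see the off-by-one the extra (gq - gk) < max_past condition removes, consider the first query of a block: the key window starts max_past positions before the block, so the old condition admitted max_past + 1 keys (the query itself plus max_past past positions), while the new one caps it at max_past. A standalone sketch with small, hypothetical values (chunk_size, max_past and n_pos here are illustrative, not the real encoder configuration):

#include <cstdio>

// Count how many keys the first query of a block may attend to under the old
// and new mask conditions from clip.cpp.
int main() {
    const int chunk_size   = 8;
    const int max_past     = 4;                    // hypothetical left window
    const int context_size = chunk_size + max_past;
    const int n_pos        = 64;
    const int b            = 2;                    // a block away from the start
    const int q            = 0;                    // first query of the block
    const int gq           = b * chunk_size + q;

    int old_keys = 0, new_keys = 0;
    for (int k = 0; k < context_size; k++) {
        const int gk = b * chunk_size - max_past + k;
        const bool base = gq < n_pos && gk >= 0 && gk < n_pos && gk <= gq;
        if (base)                         old_keys++;  // old: max_past + 1 keys
        if (base && (gq - gk) < max_past) new_keys++;  // new: max_past keys
    }
    printf("old mask: %d keys, new mask: %d keys\n", old_keys, new_keys);
    return 0;
}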
