
Commit 2ac0112

stephencoxclaude committed
mtmd: add Gemma 4 audio conformer encoder support
Add audio processing for Gemma 4 E2B/E4B via a USM-style Conformer.

Architecture:
- 12-layer Conformer: FFN → Self-Attention → Causal Conv1D → FFN → Norm
- Subsampling conv projection: 2x Conv2D(stride=2) with LayerNorm
- Full self-attention with sinusoidal RPE and a sliding window mask (24)
- Logit softcapping at 50.0, ClippableLinear clamping
- Output: 1024 → 1536 → RMSNorm → multimodal embedder

Mel preprocessing (dedicated mtmd_audio_preprocessor_gemma4a):
- HTK mel scale, 128 bins, magnitude STFT, mel_floor=1e-3
- Standard periodic Hann window (320 samples), zero-padded to FFT size
- Semicausal left-padding (frame_length/2 samples)
- Frame count matched to PyTorch (unfold formula)
- No pre-emphasis, no Whisper-style normalization
- Mel cosine similarity vs PyTorch: 0.9998

Key fixes:
- Tensor loading dedup: prevent get_tensor() from creating duplicate entries in ctx_data, now guarded with an std::unordered_set.
- ClippableLinear clamp_info loading moved after per-layer tensors.
- Sliding window mask (24 positions) matching the PyTorch context_size.
- Skip Whisper normalization for Gemma4 mel output.

Tested on E2B and E4B with CPU and Vulkan backends. Transcribes "Glad to see things are going well and business is starting to pick up", matching ground truth.

Ref: ggml-org#21325

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
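For reference, the scalar formulas named in the commit message above are small. A minimal sketch, assuming the standard HTK mel mapping, tanh-based softcapping (as in other Gemma models), and PyTorch unfold frame semantics; the helper names are illustrative, not part of the patch:

#include <cmath>

// HTK mel scale used by the preprocessor: mel = 2595 * log10(1 + f/700)
static float hz_to_mel_htk(float hz) {
    return 2595.0f * log10f(1.0f + hz / 700.0f);
}

// Logit softcapping at cap = 50.0: squashes attention logits into (-cap, cap)
static float softcap(float x, float cap = 50.0f) {
    return cap * tanhf(x / cap);
}

// Frame count matching torch.Tensor.unfold(dim, frame_length, hop):
// one frame per complete window, no trailing partial frame
static int n_frames(int n_samples, int frame_length, int hop) {
    return n_samples < frame_length ? 0 : (n_samples - frame_length) / hop + 1;
}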
1 parent d0a6dfe commit 2ac0112

11 files changed

Lines changed: 645 additions & 31 deletions


ggml/src/ggml-cuda/ssm-conv.cu

Lines changed: 2 additions & 1 deletion
@@ -134,8 +134,9 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
     switch (nc) {
         case 3: launch_kernel(std::integral_constant<int, 3>{}); break;
         case 4: launch_kernel(std::integral_constant<int, 4>{}); break;
+        case 5: launch_kernel(std::integral_constant<int, 5>{}); break;
         case 9: launch_kernel(std::integral_constant<int, 9>{}); break;
-        default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now.");
+        default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9 right now.");
     }
 }

tests/test-llama-archs.cpp

Lines changed: 16 additions & 3 deletions
@@ -86,6 +86,11 @@ static gguf_context_ptr get_gguf_ctx(const llm_arch arch, const bool moe) {
     uint32_t n_layer = 2;
     if (arch == LLM_ARCH_LLAMA4) {
         n_layer = 4; // hparams.n_no_rope_layer_step is hard-coded to 4
+    } else if (arch == LLM_ARCH_GEMMA4) {
+        n_embd = 128;
+        n_head = 2;
+        n_ff = 192;
+        n_layer = 5; // need at least 5 for swa_pattern (every 5th is full_attention)
     } else if (arch == LLM_ARCH_GEMMA3N) {
         n_embd = 64;
         n_head = 1;
@@ -167,7 +172,15 @@ static gguf_context_ptr get_gguf_ctx(const llm_arch arch, const bool moe) {
     ms.add_kv(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, uint32_t(8));
     ms.add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW, n_ctx/8);

-    if (arch == LLM_ARCH_MIMO2 || arch == LLM_ARCH_STEP35) {
+    if (arch == LLM_ARCH_GEMMA4) {
+        ms.add_kv(LLM_KV_EMBEDDING_LENGTH_PER_LAYER, n_embd/2);
+        ms.add_kv(LLM_KV_ATTENTION_SHARED_KV_LAYERS, uint32_t(0));
+        ms.add_kv(LLM_KV_ATTENTION_KEY_LENGTH_SWA, n_embd_head);
+        ms.add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, n_embd_head);
+        ms.add_kv(LLM_KV_ROPE_FREQ_BASE_SWA, 10000.0f);
+        // SWA pattern: every 5th layer is full attention (matches E2B layer_types)
+        ms.add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, uint32_t(5));
+    } else if (arch == LLM_ARCH_MIMO2 || arch == LLM_ARCH_STEP35) {
         std::vector<uint32_t> pattern;
         pattern.reserve(n_layer);
         for (uint32_t il = 0; il < n_layer; il++) {
@@ -386,7 +399,7 @@ static int save_models(const llm_arch target_arch, const size_t seed, const ggml
             continue; // Only half-implemented and to be removed in the future.
         }
         if (arch == LLM_ARCH_GEMMA4) {
-            continue; // FIXME @ngxson
+            continue; // FIXME: ISWA KV cache initialization needs more fixture params
         }
         if (arch == LLM_ARCH_RWKV6 || arch == LLM_ARCH_RWKV6QWEN2 || arch == LLM_ARCH_RWKV7 || arch == LLM_ARCH_ARWKV7) {
             continue; // FIXME
@@ -455,7 +468,7 @@ static int test_backends(const llm_arch target_arch, const size_t seed, const gg
             continue; // Only half-implemented and to be removed in the future.
         }
         if (arch == LLM_ARCH_GEMMA4) {
-            continue; // FIXME @ngxson
+            continue; // FIXME: ISWA KV cache initialization needs more fixture params
         }
         if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
             continue; // FIXME CUDA backend crashes.
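For context, the swa_pattern convention referenced in the comments above groups layers so that every 5th one is dense. A hedged illustration only (the authoritative predicate lives in llama.cpp's hparams; this sketch merely restates the comment):

// Illustrative only: with swa_pattern = 5, layers 0-3 of each group use
// sliding-window attention and layer 4 (every 5th) uses full attention.
static bool is_sliding_layer(uint32_t il, uint32_t swa_pattern) {
    return (il + 1) % swa_pattern != 0;
}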

tools/mtmd/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ add_library(mtmd
     models/models.h
     models/cogvlm.cpp
     models/conformer.cpp
+    models/gemma4a.cpp
     models/gemma4v.cpp
     models/glm4v.cpp
     models/hunyuanocr.cpp

tools/mtmd/clip-impl.h

Lines changed: 15 additions & 0 deletions
@@ -181,6 +181,21 @@
 #define TN_CONV_PW1 "%s.blk.%d.conv_pw1.%s"
 #define TN_CONV_PW2 "%s.blk.%d.conv_pw2.%s"

+// gemma4 audio conformer
+#define TN_A_MM_INP_PROJ "mm.a.input_projection.%s"
+#define TN_A_MM_SOFT_EMB_N "mm.a.soft_emb_norm.%s"
+#define TN_A_INP_PROJ "a.input_projection.%s"
+#define TN_A_CONV1D "a.conv1d.%d.%s"
+#define TN_A_CONV1D_NORM "a.conv1d.%d.norm.%s"
+#define TN_A_OUT_PROJ "a.pre_encode.out.%s"
+#define TN_A_ATTN_PRE_NORM "%s.blk.%d.attn_pre_norm.%s"
+#define TN_A_ATTN_POST_NORM "%s.blk.%d.attn_post_norm.%s"
+#define TN_A_ATTN_K_REL "%s.blk.%d.attn_k_rel.%s"
+#define TN_A_PER_DIM_SCALE "%s.blk.%d.per_dim_scale.%s"
+#define TN_A_PER_DIM_K_SCALE "%s.blk.%d.per_dim_k_scale.%s"
+#define TN_A_FFN_POST_NORM "%s.blk.%d.ffn_post_norm.%s"
+#define TN_A_FFN_POST_NORM_1 "%s.blk.%d.ffn_post_norm_1.%s"
+
 // mobilenetv5 (gemma3n) definitions
 #define TN_MNV5_STEM_CONV "v.conv_stem.conv.weight"
 #define TN_MNV5_STEM_BIAS "v.conv_stem.conv.bias"
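These are printf-style templates consumed by string_format; the expansions below are inferred mechanically from the format strings (with prefix "a" for the audio tower):

// string_format(TN_A_CONV1D, 0, "weight")             -> "a.conv1d.0.weight"
// string_format(TN_A_ATTN_PRE_NORM, "a", 3, "weight") -> "a.blk.3.attn_pre_norm.weight"
// string_format(TN_A_MM_INP_PROJ, "weight")           -> "mm.a.input_projection.weight"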

tools/mtmd/clip-model.h

Lines changed: 16 additions & 0 deletions
@@ -218,6 +218,13 @@ struct clip_layer {
     ggml_tensor * conv_pw2_w = nullptr;
     ggml_tensor * conv_pw2_b = nullptr;

+    // gemma4 audio conformer per-layer
+    ggml_tensor * attn_pre_norm_w = nullptr;
+    ggml_tensor * attn_k_rel_w = nullptr;
+    ggml_tensor * per_dim_scale_w = nullptr;
+    ggml_tensor * per_dim_k_scale_w = nullptr;
+    ggml_tensor * ff_post_norm_1_w = nullptr;
+
     bool has_deepstack() const {
         return deepstack_fc1_w != nullptr;
     }
@@ -460,6 +467,15 @@ struct clip_model {
     };
     std::map<std::string, clamp_info> clamp_info_map;

+    // gemma4 audio conformer
+    std::array<ggml_tensor *, 2> sscp_conv_w = {nullptr};
+    std::array<ggml_tensor *, 2> sscp_conv_b = {nullptr};
+    std::array<ggml_tensor *, 2> sscp_norm_w = {nullptr};
+    ggml_tensor * sscp_inp_proj_w = nullptr;
+    ggml_tensor * sscp_inp_proj_b = nullptr;
+    ggml_tensor * audio_out_proj_w = nullptr;
+    ggml_tensor * audio_out_proj_b = nullptr;
+
     bool audio_has_avgpool() const {
         return proj_type == PROJECTOR_TYPE_QWEN2A
             || proj_type == PROJECTOR_TYPE_VOXTRAL

tools/mtmd/clip.cpp

Lines changed: 155 additions & 4 deletions
@@ -922,6 +922,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 builder = std::make_unique<clip_graph_conformer>(ctx, img);
             } break;
+        case PROJECTOR_TYPE_GEMMA4A:
+            {
+                builder = std::make_unique<clip_graph_gemma4a>(ctx, img);
+            } break;
         case PROJECTOR_TYPE_GLM4V:
             {
                 builder = std::make_unique<clip_graph_glm4v>(ctx, img);
@@ -1429,6 +1433,16 @@ struct clip_model_loader {
                     hparams.audio_window_len = 400;
                     hparams.audio_hop_len = 160;
                 } break;
+            case PROJECTOR_TYPE_GEMMA4A:
+                {
+                    // Gemma4 feature_extraction_gemma4.py:
+                    // frame_length_ms=20 -> 320 samples, n_fft=512, hop=10ms -> 160
+                    hparams.audio_chunk_len = 0; // no fixed-length padding
+                    hparams.audio_sample_rate = 16000;
+                    hparams.audio_n_fft = 512;
+                    hparams.audio_window_len = 320; // 20ms frame (NOT 25ms/400)
+                    hparams.audio_hop_len = 160;
+                } break;
            case PROJECTOR_TYPE_JANUS_PRO:
                {
                    hparams.image_pad_color = {127, 127, 127};
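As a sanity check on the hunk above, the constants follow directly from the 16 kHz sample rate (illustrative arithmetic, not code from the patch):

const int sample_rate = 16000;
const int window_len  = 20 * sample_rate / 1000; // 20 ms frame -> 320 samples
const int hop_len     = 10 * sample_rate / 1000; // 10 ms hop   -> 160 samples
const int n_fft       = 512;                     // smallest power of two >= 320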
@@ -1531,16 +1545,21 @@ struct clip_model_loader {
         }

         // helper function
+        std::unordered_set<std::string> loaded_tensor_names;
         auto get_tensor = [&](const std::string & name, bool required = true) {
+            // Each tensor should only be loaded once; duplicates indicate a bug
+            if (loaded_tensor_names.count(name)) {
+                throw std::runtime_error(string_format("%s: tensor already loaded: %s\n", __func__, name.c_str()));
+            }
             ggml_tensor * cur = ggml_get_tensor(ctx_meta.get(), name.c_str());
             if (!cur && required) {
                 throw std::runtime_error(string_format("%s: unable to find tensor %s\n", __func__, name.c_str()));
             }
             if (cur) {
                 tensors_to_load.push_back(cur);
-                // add tensors to context
                 ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data.get(), cur);
                 ggml_set_name(data_tensor, cur->name);
+                loaded_tensor_names.insert(name);
                 cur = data_tensor;
             }
             return cur;
@@ -2113,6 +2132,74 @@ struct clip_model_loader {
                 model.mm_fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
                 model.mm_fc_b = get_tensor(string_format(TN_MM_PROJECTOR, "bias"));
             } break;
+        case PROJECTOR_TYPE_GEMMA4A:
+            {
+                for (int i = 0; i < 2; i++) {
+                    model.sscp_conv_w[i] = get_tensor(string_format(TN_A_CONV1D, i, "weight"));
+                    model.sscp_conv_b[i] = get_tensor(string_format(TN_A_CONV1D, i, "bias"), false);
+                    model.sscp_norm_w[i] = get_tensor(string_format(TN_A_CONV1D_NORM, i, "weight"), false);
+                }
+                model.sscp_inp_proj_w = get_tensor(string_format(TN_A_INP_PROJ, "weight"));
+                model.sscp_inp_proj_b = get_tensor(string_format(TN_A_INP_PROJ, "bias"), false);
+                model.audio_out_proj_w = get_tensor(string_format(TN_A_OUT_PROJ, "weight"), false);
+                model.audio_out_proj_b = get_tensor(string_format(TN_A_OUT_PROJ, "bias"), false);
+                // audio multimodal embedder (mm.a.* namespace, not mm.*)
+                model.mm_soft_emb_norm_w = get_tensor(string_format(TN_A_MM_SOFT_EMB_N, "weight"), false);
+                model.mm_input_proj_w = get_tensor(string_format(TN_A_MM_INP_PROJ, "weight"), false);
+
+                // Per-layer tensors NOT loaded by the generic loop above
+                for (int il = 0; il < hparams.n_layer; ++il) {
+                    auto & layer = model.layers[il];
+
+                    // Gemma4 audio conformer-specific tensors
+                    layer.ff_norm_w = get_tensor(string_format(TN_FFN_NORM, prefix, il, "weight"));
+                    layer.attn_pre_norm_w = get_tensor(string_format(TN_A_ATTN_PRE_NORM, prefix, il, "weight"), false);
+                    layer.per_dim_scale_w = get_tensor(string_format(TN_A_PER_DIM_SCALE, prefix, il, "weight"), false);
+                    layer.per_dim_k_scale_w = get_tensor(string_format(TN_A_PER_DIM_K_SCALE, prefix, il, "weight"), false);
+                    layer.attn_k_rel_w = get_tensor(string_format(TN_A_ATTN_K_REL, prefix, il, "weight"), false);
+
+                    // Convolution module
+                    layer.norm_conv_w = get_tensor(string_format(TN_NORM_CONV, prefix, il, "weight"), false);
+                    layer.norm_conv_b = get_tensor(string_format(TN_NORM_CONV, prefix, il, "bias"), false);
+                    layer.conv_pw1_w = get_tensor(string_format(TN_CONV_PW1, prefix, il, "weight"));
+                    layer.conv_pw1_b = get_tensor(string_format(TN_CONV_PW1, prefix, il, "bias"), false);
+                    layer.conv_dw_w = get_tensor(string_format(TN_CONV_DW, prefix, il, "weight"));
+                    layer.conv_dw_b = get_tensor(string_format(TN_CONV_DW, prefix, il, "bias"), false);
+                    layer.conv_norm_w = get_tensor(string_format(TN_CONV_NORM, prefix, il, "weight"), false);
+                    layer.conv_norm_b = get_tensor(string_format(TN_CONV_NORM, prefix, il, "bias"), false);
+                    layer.conv_pw2_w = get_tensor(string_format(TN_CONV_PW2, prefix, il, "weight"));
+                    layer.conv_pw2_b = get_tensor(string_format(TN_CONV_PW2, prefix, il, "bias"), false);
+
+                    // FFN2 (second half-step)
+                    layer.ff_norm_1_w = get_tensor(string_format(TN_FFN_NORM_1, prefix, il, "weight"));
+                    layer.ff_up_1_w = get_tensor(string_format(TN_FFN_UP_1, prefix, il, "weight"));
+                    layer.ff_up_1_b = get_tensor(string_format(TN_FFN_UP_1, prefix, il, "bias"), false);
+                    layer.ff_down_1_w = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "weight"));
+                    layer.ff_down_1_b = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "bias"), false);
+                    layer.ff_post_norm_1_w = get_tensor(string_format(TN_A_FFN_POST_NORM_1, prefix, il, "weight"), false);
+                }
+
+                // Load clamp info for ClippableLinear AFTER all tensors are loaded
+                for (auto * tensor : tensors_to_load) {
+                    std::string name = tensor->name;
+                    if (string_ends_with(name, ".weight")) {
+                        std::string name_inp_max = name;
+                        std::string name_inp_min = name;
+                        std::string name_out_max = name;
+                        std::string name_out_min = name;
+                        string_replace_all(name_inp_max, ".weight", ".input_max");
+                        string_replace_all(name_inp_min, ".weight", ".input_min");
+                        string_replace_all(name_out_max, ".weight", ".output_max");
+                        string_replace_all(name_out_min, ".weight", ".output_min");
+                        model.clamp_info_map[name] = {
+                            get_scalar(name_inp_max, FLT_MAX),
+                            get_scalar(name_inp_min, -FLT_MAX),
+                            get_scalar(name_out_max, FLT_MAX),
+                            get_scalar(name_out_min, -FLT_MAX)
+                        };
+                    }
+                }
+            } break;
         case PROJECTOR_TYPE_LFM2A:
             {
                 for (int i : {0, 2, 3, 5, 6}) {
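For reference, the clamp_info bounds loaded above wrap an ordinary linear layer on both sides; they default to ±FLT_MAX, i.e. a no-op. A minimal scalar sketch of the ClippableLinear forward pass implied by these fields (a hypothetical helper, not the patch's ggml implementation):

#include <algorithm>
#include <cfloat>

struct clamp_bounds {
    float input_max = FLT_MAX, input_min = -FLT_MAX;
    float output_max = FLT_MAX, output_min = -FLT_MAX;
};

// y = clamp(w * clamp(x, in_min, in_max) + b, out_min, out_max)
static float clippable_linear_1d(float w, float b, float x, const clamp_bounds & c) {
    x = std::min(std::max(x, c.input_min), c.input_max);
    const float y = w * x + b;
    return std::min(std::max(y, c.output_min), c.output_max);
}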
@@ -2173,7 +2260,10 @@ struct clip_model_loader {
         ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
         for (auto & t : tensors_to_load) {
             ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
-            const size_t offset = tensor_offset[t->name];
+            GGML_ASSERT(cur && "tensor not found in ctx_data");
+            auto it_off = tensor_offset.find(t->name);
+            GGML_ASSERT(it_off != tensor_offset.end() && "no offset for tensor");
+            const size_t offset = it_off->second;
             fin.seekg(offset, std::ios::beg);
             if (!fin) {
                 throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name));
@@ -2465,8 +2555,7 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params

         // TODO: we don't support audio for Gemma 3N, but GGUF contains audio tensors
         // we can remove this check when we implement audio support for Gemma 3N
-        skip_audio = ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA3NV
-                  || ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA4V;
+        skip_audio = ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA3NV;
     }

     if (loader.has_audio && !skip_audio) {
@@ -2808,6 +2897,16 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
             {
                 n_patches = ((((img->nx + 1) / 2) + 1) / 2 + 1) / 2;
             } break;
+        case PROJECTOR_TYPE_GEMMA4A:
+            {
+                // Two Conv2D stride-2: O = floor((I + 2p - k) / s) + 1, p=1, k=3, s=2
+                // O = floor((I - 1) / 2) + 1
+                int n = img->nx;
+                for (int i = 0; i < 2; i++) {
+                    n = (n - 1) / 2 + 1;
+                }
+                n_patches = n;
+            } break;
         default:
             GGML_ABORT("unsupported projector type");
     }
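Worked example for the formula above: 1000 mel frames (10 s of audio at the 10 ms hop) subsample to (1000 - 1)/2 + 1 = 500, then (500 - 1)/2 + 1 = 250 output tokens, i.e. one token per roughly 40 ms of audio.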
@@ -3232,6 +3331,56 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                 }
                 set_input_i32("pos_w", pos_data);
             } break;
+        case PROJECTOR_TYPE_GEMMA4A:
+            {
+                GGML_ASSERT(imgs.entries.size() == 1);
+                const auto & img0 = imgs.entries.front();
+                // Compute n_pos matching SSCP output: two stride-2 convs
+                int n_pos = img0->nx;
+                for (int i = 0; i < 2; i++) { n_pos = (n_pos - 1) / 2 + 1; }
+
+                // Chunked local attention: blocked causal mask and RPE
+                const int chunk_size = 12;
+                const int max_past = 12;
+                const int context_size = chunk_size + max_past;
+                const int num_blocks = (n_pos + chunk_size - 1) / chunk_size;
+
+                // Blocked causal attention mask: [context_size, chunk_size, num_blocks]
+                {
+                    std::vector<float> mask(context_size * chunk_size * num_blocks, -INFINITY);
+                    for (int b = 0; b < num_blocks; b++) {
+                        for (int q = 0; q < chunk_size; q++) {
+                            int gq = b * chunk_size + q;
+                            for (int k = 0; k < context_size; k++) {
+                                int gk = b * chunk_size - max_past + k;
+                                if (gq < n_pos && gk >= 0 && gk < n_pos && gk <= gq) {
+                                    mask[k + q * context_size + b * context_size * chunk_size] = 0.0f;
+                                }
+                            }
+                        }
+                    }
+                    set_input_f32("kq_mask", mask);
+                }
+
+                // Sinusoidal RPE: 13 positions [12, 11, ..., 0]
+                {
+                    const int n_embd = ctx->model.hparams.n_embd;
+                    const int num_timescales = n_embd / 2;
+                    const float log_timescale_increment = logf(10000.0f) / std::max(num_timescales - 1, 1);
+                    const int rpe_len = max_past + 1;
+                    std::vector<float> pos_emb(n_embd * rpe_len, 0.0f);
+                    for (int p = 0; p < rpe_len; p++) {
+                        float position = (float)(max_past - p);
+                        for (int i = 0; i < num_timescales; i++) {
+                            float inv_ts = expf(-(float)i * log_timescale_increment);
+                            float scaled = position * inv_ts;
+                            pos_emb[p * n_embd + i] = sinf(scaled);
+                            pos_emb[p * n_embd + i + num_timescales] = cosf(scaled);
+                        }
+                    }
+                    set_input_f32("pos_emb", pos_emb);
+                }
+            } break;
         case PROJECTOR_TYPE_LFM2A:
             {
                 GGML_ASSERT(imgs.entries.size() == 1);
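Worked example for the mask above: with n_pos = 30, num_blocks = 3 and each query row gets context_size = 24 key slots. A query at global position gq = 14 (block b = 1, q = 2) sees key slots covering gk = 0..23; the mask zeroes exactly those with gk <= 14, so it attends to 15 past/current positions, while slots with gk < 0 or gk >= n_pos stay at -INFINITY. The RPE table is built once, with rpe_len = 13 rows for relative distances 12 down to 0.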
@@ -3391,6 +3540,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         return ctx->model.mm_fc_w->ne[1];
     case PROJECTOR_TYPE_LFM2A:
         return ctx->model.position_embeddings->ne[0];
+    case PROJECTOR_TYPE_GEMMA4A:
+        return ctx->model.hparams.projection_dim;
     case PROJECTOR_TYPE_GLM4V:
         return ctx->model.mm_ffn_down_w->ne[1];
     default:
