Skip to content

Commit 9ca009a

Browse files
committed
Enhance multimodal capabilities with new projector types and suppress tokens support
- Introduced support for new projector types: Gemma 4 Unified Vision (GEMMA4UV) and Gemma 4 Unified Audio (GEMMA4UA). - Added functionality to suppress specific tokens during processing, improving model performance and flexibility. - Updated tensor mappings and constants to accommodate new projector types and suppress tokens. - Enhanced audio preprocessing for GEMMA4UA to handle raw waveform inputs efficiently. - Revised model building logic for GEMMA4UV to utilize LayerNorm and positional embeddings effectively.
1 parent 0a635dc commit 9ca009a

18 files changed

Lines changed: 269 additions & 6 deletions

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1148,7 +1148,24 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
11481148
case GGML_OP_ROPE:
11491149
return true;
11501150
case GGML_OP_IM2COL:
1151-
return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
1151+
{
1152+
if (!(ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32))) {
1153+
return false;
1154+
}
1155+
// The Metal im2col kernel launches KH*KW threads per threadgroup
1156+
// (one per kernel element). If the conv kernel is large enough that
1157+
// KH*KW exceeds the Apple GPU threadgroup cap (1024), the kernel
1158+
// would hit a runtime GGML_ASSERT. Decline here so the op falls back
1159+
// to CPU instead of crashing. Affects large-kernel patch convs such
1160+
// as Gemma 4 unified vision (gemma4uv).
1161+
const bool is_2D = ggml_get_op_params_i32(op, 6) == 1;
1162+
const int64_t KW = op->src[0]->ne[0];
1163+
const int64_t KH = is_2D ? op->src[0]->ne[1] : 1;
1164+
if (KH*KW > 1024) {
1165+
return false;
1166+
}
1167+
return true;
1168+
}
11521169
case GGML_OP_CONV_2D:
11531170
return ggml_is_contiguous(op->src[0]) &&
11541171
op->src[1]->type == GGML_TYPE_F32 &&

gguf-py/gguf/constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ class Tokenizer:
268268
ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
269269
REMOVE_EXTRA_WS = "tokenizer.ggml.remove_extra_whitespaces"
270270
PRECOMPILED_CHARSMAP = "tokenizer.ggml.precompiled_charsmap"
271+
SUPPRESS_TOKENS = "tokenizer.ggml.suppress_tokens"
271272
HF_JSON = "tokenizer.huggingface.json"
272273
RWKV = "tokenizer.rwkv.world"
273274
CHAT_TEMPLATE = "tokenizer.chat_template"
@@ -722,6 +723,7 @@ class MODEL_TENSOR(IntEnum):
722723
V_ENC_EMBD_CLS = auto()
723724
V_ENC_EMBD_PATCH = auto()
724725
V_ENC_EMBD_NORM = auto()
726+
V_ENC_EMBD_PATCH_NORM = auto() # allow multiple norms in the same embd, e.g. for gemma4u
725727
V_ENC_EMBD_POS = auto()
726728
V_ENC_INPUT_NORM = auto()
727729
V_ENC_ATTN_QKV = auto()
@@ -1212,6 +1214,7 @@ class MODEL_TENSOR(IntEnum):
12121214
MODEL_TENSOR.V_ENC_EMBD_CLS: "v.class_embd",
12131215
MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd",
12141216
MODEL_TENSOR.V_ENC_EMBD_NORM: "v.norm_embd",
1217+
MODEL_TENSOR.V_ENC_EMBD_PATCH_NORM: "v.patch_norm.{bid}",
12151218
MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd",
12161219
MODEL_TENSOR.V_ENC_ATTN_QKV: "v.blk.{bid}.attn_qkv",
12171220
MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q",
@@ -1366,6 +1369,7 @@ class MODEL_TENSOR(IntEnum):
13661369
MODEL_TENSOR.V_ENC_EMBD_CLS,
13671370
MODEL_TENSOR.V_ENC_EMBD_PATCH,
13681371
MODEL_TENSOR.V_ENC_EMBD_NORM,
1372+
MODEL_TENSOR.V_ENC_EMBD_PATCH_NORM,
13691373
MODEL_TENSOR.V_ENC_EMBD_POS,
13701374
MODEL_TENSOR.V_ENC_EMBD_IMGNL,
13711375
MODEL_TENSOR.V_ENC_EMBD_VSEP,
@@ -4149,6 +4153,8 @@ class VisionProjectorType:
41494153
GEMMA3NA = "gemma3na"
41504154
GEMMA4V = "gemma4v"
41514155
GEMMA4A = "gemma4a"
4156+
GEMMA4UV = "gemma4uv" # "unified" variant
4157+
GEMMA4UA = "gemma4ua" # "unified" variant
41524158
PHI4 = "phi4"
41534159
IDEFICS3 = "idefics3"
41544160
PIXTRAL = "pixtral"

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,6 +1122,9 @@ def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
11221122

11231123
self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)
11241124

1125+
def add_suppress_tokens(self, tokens: Sequence[int]) -> None:
1126+
self.add_array(Keys.Tokenizer.SUPPRESS_TOKENS, tokens)
1127+
11251128
def add_eot_token_id(self, id: int) -> None:
11261129
self.add_uint32(Keys.Tokenizer.EOT_ID, id)
11271130

gguf-py/gguf/tensor_mapping.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,13 +1409,18 @@ class TensorNameMap:
14091409
"model.vision_tower.patch_embedder.input_proj", # gemma4
14101410
"vision_tower.patch_embed.patchifier.proj", # dots.ocr
14111411
"vision_model.conv1", # Step3-VL
1412+
"model.vision_embedder.patch_dense", # gemma4 unified
14121413
),
14131414

14141415
MODEL_TENSOR.V_ENC_EMBD_NORM: (
14151416
"visual.post_conv_layernorm", # glm4v
14161417
"vision_tower.patch_embed.patchifier.norm", # dots.ocr
14171418
),
14181419

1420+
MODEL_TENSOR.V_ENC_EMBD_PATCH_NORM: (
1421+
"model.vision_embedder.patch_ln{bid}", # gemma4 unified
1422+
),
1423+
14191424
MODEL_TENSOR.V_ENC_EMBD_POS: (
14201425
"vision_tower.vision_model.embeddings.position_embedding",
14211426
"model.vision_tower.embeddings.position_embeddings", # Intern-S1
@@ -1430,6 +1435,7 @@ class TensorNameMap:
14301435
"vision_model.radio_model.model.patch_generator.pos_embed", # Nemotron Nano v2 VL
14311436
"model.vision_tower.patch_embedder.position_embedding_table", # gemma4
14321437
"vision_model.positional_embedding", # Step3-VL
1438+
"model.vision_embedder.pos_embedding", # gemma4 unified
14331439
),
14341440

14351441
MODEL_TENSOR.V_ENC_EMBD_IMGNL: (

src/llama-arch.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
330330
{ LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" },
331331
{ LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" },
332332
{ LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" },
333+
{ LLM_KV_TOKENIZER_SUPPRESS_TOKENS, "tokenizer.ggml.suppress_tokens" },
333334

334335
{ LLM_KV_ADAPTER_TYPE, "adapter.type" },
335336
{ LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ enum llm_kv {
312312
LLM_KV_TOKENIZER_FIM_PAD_ID,
313313
LLM_KV_TOKENIZER_FIM_REP_ID,
314314
LLM_KV_TOKENIZER_FIM_SEP_ID,
315+
LLM_KV_TOKENIZER_SUPPRESS_TOKENS,
315316

316317
LLM_KV_ADAPTER_TYPE,
317318
LLM_KV_ADAPTER_LORA_ALPHA,

src/llama-vocab.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,6 +1668,8 @@ struct llama_vocab::impl {
16681668
// set of all tokens that cause "end of generation"
16691669
std::set<llama_token> special_eog_ids;
16701670

1671+
std::vector<llama_token> suppress_tokens;
1672+
16711673
std::unique_ptr<llm_tokenizer> tokenizer;
16721674

16731675
std::vector<char> precompiled_charsmap;
@@ -2344,6 +2346,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
23442346
}
23452347
}
23462348

2349+
// suppress tokens
2350+
{
2351+
const int suppress_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SUPPRESS_TOKENS).c_str());
2352+
if (suppress_idx != -1) {
2353+
const int n = gguf_get_arr_n(ctx, suppress_idx);
2354+
const int32_t * data = (const int32_t *) gguf_get_arr_data(ctx, suppress_idx);
2355+
suppress_tokens.assign(data, data + n);
2356+
}
2357+
}
2358+
23472359
// auto-detect special tokens by text
23482360
// TODO: convert scripts should provide these tokens through the KV metadata LLM_KV_TOKENIZER_...
23492361
// for now, we apply this workaround to find the tokens based on their text
@@ -3758,6 +3770,10 @@ bool llama_vocab::get_treat_whitespace_as_suffix() const {
37583770
return pimpl->treat_whitespace_as_suffix;
37593771
}
37603772

3773+
const std::vector<llama_token> & llama_vocab::get_suppress_tokens() const {
3774+
return pimpl->suppress_tokens;
3775+
}
3776+
37613777
int llama_vocab::max_token_len() const {
37623778
return pimpl->max_token_len;
37633779
}

src/llama-vocab.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ struct llama_vocab {
137137
bool get_escape_whitespaces () const;
138138
bool get_treat_whitespace_as_suffix() const;
139139

140+
const std::vector<llama_token> & get_suppress_tokens() const;
141+
140142
int max_token_len() const;
141143

142144
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;

tools/mtmd/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ add_library(mtmd
2020
models/dotsocr.cpp
2121
models/gemma4a.cpp
2222
models/gemma4v.cpp
23+
models/gemma4ua.cpp
24+
models/gemma4uv.cpp
2325
models/glm4v.cpp
2426
models/hunyuanocr.cpp
2527
models/internvl.cpp

tools/mtmd/clip-impl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
#define TN_PATCH_EMBD_1 "v.patch_embd.weight.1"
7676
#define TN_PATCH_BIAS "v.patch_embd.bias"
7777
#define TN_NORM_EMBD "v.norm_embd.%s"
78+
#define TN_PATCH_NORM "v.patch_norm.%d.%s"
7879
#define TN_ATTN_QKV "%s.blk.%d.attn_qkv.%s"
7980
#define TN_ATTN_K "%s.blk.%d.attn_k.%s"
8081
#define TN_ATTN_Q "%s.blk.%d.attn_q.%s"
@@ -265,6 +266,8 @@ enum projector_type {
265266
PROJECTOR_TYPE_GEMMA3NA,
266267
PROJECTOR_TYPE_GEMMA4V,
267268
PROJECTOR_TYPE_GEMMA4A,
269+
PROJECTOR_TYPE_GEMMA4UV,
270+
PROJECTOR_TYPE_GEMMA4UA,
268271
PROJECTOR_TYPE_PHI4,
269272
PROJECTOR_TYPE_IDEFICS3,
270273
PROJECTOR_TYPE_PIXTRAL,
@@ -311,6 +314,8 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
311314
{ PROJECTOR_TYPE_GEMMA3NA, "gemma3na"},
312315
{ PROJECTOR_TYPE_GEMMA4V, "gemma4v"},
313316
{ PROJECTOR_TYPE_GEMMA4A, "gemma4a"},
317+
{ PROJECTOR_TYPE_GEMMA4UV, "gemma4uv"},
318+
{ PROJECTOR_TYPE_GEMMA4UA, "gemma4ua"},
314319
{ PROJECTOR_TYPE_PHI4, "phi4"},
315320
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
316321
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},

0 commit comments

Comments
 (0)