Skip to content

Commit c800cf2

Browse files
ngxsonwordingone
authored andcommitted
model, mtmd: fix gguf conversion for audio/vision mmproj (ggml-org#21309)
* fix gguf conversion for audio/vision mmproj * fix test
1 parent 5dd0f44 commit c800cf2

26 files changed

Lines changed: 1440 additions & 150 deletions

common/chat-auto-parser-generator.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,8 @@ common_peg_parser analyze_tools::build_parser(parser_build_context & ctx) const
168168
return build_tool_parser_tag_json(ctx);
169169
case tool_format::TAG_WITH_TAGGED:
170170
return build_tool_parser_tag_tagged(ctx);
171+
case tool_format::TAG_WITH_GEMMA4_DICT:
172+
return build_tool_parser_tag_gemma4_dict(ctx);
171173
default:
172174
LOG_ERR("[ERROR] Template seems to support tool calls, but failed to determine tool format. Tool calling will not work properly. "
173175
"Check for a fixed template for your model in the models/templates directory of your llama.cpp installation or "
@@ -439,4 +441,113 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
439441
p.end();
440442
}
441443

444+
common_peg_parser analyze_tools::build_tool_parser_tag_gemma4_dict(parser_build_context & ctx) const {
445+
auto & p = ctx.p;
446+
const auto & inputs = ctx.inputs;
447+
bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
448+
449+
// The Gemma4 string quote token used in place of JSON "
450+
static const std::string QUOTE = "<|\"|>";
451+
452+
common_peg_parser tool_choice = p.choice();
453+
454+
foreach_function(inputs.tools, [&](const json & tool) {
455+
const auto & func = tool.at("function");
456+
std::string name = func.at("name");
457+
const auto & params = func.at("parameters");
458+
459+
if (!params.contains("properties") || !params.at("properties").is_object()) {
460+
// No arguments - just match the function name with empty braces
461+
auto func_parser = p.atomic(
462+
p.tool_open(p.literal(function.name_prefix) + p.tool_name(p.literal(name)) + p.literal("{")) +
463+
p.tool_args(p.eps()) +
464+
p.tool_close(p.literal("}")));
465+
tool_choice |= p.rule("tool-" + name, func_parser);
466+
return;
467+
}
468+
469+
const auto & properties = params.at("properties");
470+
std::set<std::string> required;
471+
if (params.contains("required") && params.at("required").is_array()) {
472+
params.at("required").get_to(required);
473+
}
474+
475+
// Build per-argument parsers, sorted alphabetically (matching template's dictsort)
476+
struct arg_entry {
477+
std::string param_name;
478+
common_peg_parser parser;
479+
};
480+
std::vector<arg_entry> arg_entries;
481+
482+
for (const auto & [param_name, param_schema] : properties.items()) {
483+
std::string type = "object";
484+
auto type_v = param_schema.contains("type") ? param_schema.at("type") : json::object();
485+
if (type_v.is_string()) type_v.get_to(type);
486+
487+
common_peg_parser value_parser = p.eps();
488+
if (type == "string") {
489+
// String values are delimited by <|"|>...<|"|>
490+
value_parser =
491+
p.literal(QUOTE) +
492+
p.tool_arg_string_value(p.schema(p.until(QUOTE),
493+
"tool-" + name + "-arg-" + param_name + "-schema", param_schema, true)) +
494+
p.literal(QUOTE);
495+
} else {
496+
// Numbers, booleans: raw text up to the next comma or closing brace
497+
value_parser = p.tool_arg_value(p.until_one_of({",", "}"}));
498+
}
499+
500+
auto arg = p.tool_arg(
501+
p.tool_arg_open(p.tool_arg_name(p.literal(param_name)) + p.literal(":")) +
502+
value_parser +
503+
p.tool_arg_close(p.eps()));
504+
505+
arg_entries.push_back({param_name, p.rule("tool-" + name + "-arg-" + param_name, arg)});
506+
}
507+
508+
// Sort alphabetically to match Jinja's dictsort
509+
std::sort(arg_entries.begin(), arg_entries.end(), [](const auto & a, const auto & b) {
510+
return a.param_name < b.param_name;
511+
});
512+
513+
// Build arg sequence: any arg, then zero-or-more comma-separated additional args
514+
common_peg_parser args_seq = p.eps();
515+
if (!arg_entries.empty()) {
516+
common_peg_parser any_arg = p.choice();
517+
for (auto & entry : arg_entries) {
518+
any_arg |= entry.parser;
519+
}
520+
args_seq = p.optional(
521+
any_arg + p.repeat(p.literal(",") + any_arg, 0, (int) arg_entries.size() - 1));
522+
}
523+
524+
// Full parser: call:name{args}
525+
auto func_parser = p.atomic(
526+
p.tool_open(p.literal(function.name_prefix) + p.tool_name(p.literal(name)) + p.literal("{")) +
527+
p.tool_args(args_seq) +
528+
p.tool_close(p.literal("}")));
529+
530+
tool_choice |= p.rule("tool-" + name, func_parser);
531+
});
532+
533+
// Wrap each call in <|tool_call>...</tool_call|>
534+
auto wrapped_call = p.literal(format.per_call_start) + tool_choice + p.literal(format.per_call_end);
535+
536+
common_peg_parser tool_calls = p.eps();
537+
if (inputs.parallel_tool_calls) {
538+
tool_calls = p.trigger_rule("tool-call", wrapped_call + p.zero_or_more(p.space() + wrapped_call));
539+
} else {
540+
tool_calls = p.trigger_rule("tool-call", wrapped_call);
541+
}
542+
543+
if (!force_tools) {
544+
tool_calls = p.optional(tool_calls);
545+
}
546+
547+
auto content_before_tools = p.until(format.per_call_start);
548+
return ctx.reasoning_parser +
549+
(force_tools ? p.eps() : p.optional(p.content(content_before_tools))) +
550+
tool_calls + p.end();
551+
}
552+
442553
} // namespace autoparser

common/chat-auto-parser.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ enum class tool_format {
144144
JSON_NATIVE, // Pure JSON: {"name": "X", "arguments": {...}}
145145
TAG_WITH_JSON, // Tag-based with JSON args: <function=X>{...}</function>
146146
TAG_WITH_TAGGED, // Tag-based with tagged args: <param=key>value</param>
147+
TAG_WITH_GEMMA4_DICT, // Gemma4 custom dict: <|tool_call>call:name{key:<|"|>val<|"|>}<tool_call|>
147148
};
148149

149150
inline std::ostream & operator<<(std::ostream & os, const tool_format & format) {
@@ -156,6 +157,8 @@ inline std::ostream & operator<<(std::ostream & os, const tool_format & format)
156157
return os << "TAG_WITH_JSON";
157158
case tool_format::TAG_WITH_TAGGED:
158159
return os << "TAG_WITH_TAGGED";
160+
case tool_format::TAG_WITH_GEMMA4_DICT:
161+
return os << "TAG_WITH_GEMMA4_DICT";
159162
default:
160163
return os << "UNKNOWN";
161164
}

common/chat-diff-analyzer.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,33 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
9292
LOG_DBG(ANSI_ORANGE "[Patch: Functionary 3.1]\n" ANSI_RESET);
9393
}
9494
},
95+
// Gemma4 - custom dict format: <|tool_call>call:name{key:<|"|>val<|"|>}<tool_call|>
96+
[](const common_chat_template & tmpl, autoparser & analysis) -> void {
97+
if (tmpl.src.find("'<|tool_call>call:'") != std::string::npos) {
98+
analysis.tools.format.mode = tool_format::TAG_WITH_GEMMA4_DICT;
99+
analysis.tools.format.per_call_start = "<|tool_call>";
100+
analysis.tools.format.per_call_end = "<tool_call|>";
101+
analysis.tools.format.section_start = "";
102+
analysis.tools.format.section_end = "";
103+
analysis.tools.function.name_prefix = "call:";
104+
analysis.tools.function.name_suffix = "";
105+
analysis.tools.arguments.start = "{";
106+
analysis.tools.arguments.end = "}";
107+
analysis.tools.arguments.name_suffix = ":";
108+
analysis.tools.arguments.separator = ",";
109+
analysis.reasoning.mode = reasoning_mode::TAG_BASED;
110+
analysis.reasoning.start = "<|channel>thought\n";
111+
analysis.reasoning.end = "<channel|>";
112+
analysis.preserved_tokens.clear();
113+
analysis.preserved_tokens.push_back("<|tool_call>");
114+
analysis.preserved_tokens.push_back("<tool_call|>");
115+
analysis.preserved_tokens.push_back("<|tool_response>");
116+
analysis.preserved_tokens.push_back("<tool_response|>");
117+
analysis.preserved_tokens.push_back("<|\"|>");
118+
analysis.preserved_tokens.push_back("<|turn>");
119+
LOG_DBG(ANSI_ORANGE "[Patch: Gemma4]\n" ANSI_RESET);
120+
}
121+
},
95122
// DeepSeek-R1-Distill-Qwen
96123
[](const common_chat_template & tmpl, autoparser & analysis) -> void {
97124
if (tmpl.src.find(

common/chat.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1857,6 +1857,10 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
18571857
workaround::func_args_not_string(params.messages);
18581858
}
18591859

1860+
if (src.find("'<|tool_call>call:'") != std::string::npos) {
1861+
workaround::convert_tool_responses_gemma4(params.messages);
1862+
}
1863+
18601864
params.add_generation_prompt = false;
18611865
std::string no_gen_prompt = common_chat_template_direct_apply(tmpl, params);
18621866
params.add_generation_prompt = true;

convert_hf_to_gguf.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,7 +1125,7 @@ def set_gguf_parameters(self):
11251125
if (n_experts := self.find_hparam(["num_local_experts", "num_experts"], optional=True)) is not None:
11261126
self.gguf_writer.add_expert_count(n_experts)
11271127
logger.info(f"gguf: expert count = {n_experts}")
1128-
if (n_experts_used := self.find_hparam(["num_experts_per_tok", "num_experts_per_token"], optional=True)) is not None:
1128+
if (n_experts_used := self.find_hparam(["num_experts_per_tok", "num_experts_per_token", "top_k_experts"], optional=True)) is not None:
11291129
self.gguf_writer.add_expert_used_count(n_experts_used)
11301130
logger.info(f"gguf: experts used count = {n_experts_used}")
11311131
if (n_expert_groups := self.hparams.get("n_group")) is not None:
@@ -6686,7 +6686,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
66866686
@ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
66876687
class Gemma3Model(TextModel):
66886688
model_arch = gguf.MODEL_ARCH.GEMMA3
6689-
norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value
6689+
6690+
def norm_shift(self, name: str) -> float:
6691+
return 1.0 if name.endswith("norm.weight") else 0.0 # Gemma3RMSNorm adds 1.0 to the norm value
66906692

66916693
def set_vocab(self):
66926694
if (self.dir_model / "tokenizer.model").is_file():
@@ -6724,17 +6726,22 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
67246726

67256727
# remove OOV (out-of-vocabulary) rows in token_embd
67266728
if "embed_tokens.weight" in name:
6729+
n_vocab_real = -1
67276730
if (self.dir_model / "tokenizer.model").is_file():
67286731
tokens = self._create_vocab_sentencepiece()[0]
6732+
n_vocab_real = len(tokens)
67296733
else:
6730-
tokens = self.get_vocab_base()[0]
6731-
data_torch = data_torch[:len(tokens)]
6734+
with open(self.dir_model / "tokenizer.json", "r", encoding="utf-8") as f:
6735+
tokenizer_json = json.load(f)
6736+
n_vocab_real = len(tokenizer_json["model"]["vocab"]) + len(tokenizer_json["added_tokens"])
6737+
data_torch = data_torch[:n_vocab_real]
67326738

67336739
# ref code in Gemma3RMSNorm
67346740
# output = output * (1.0 + self.weight.float())
67356741
# note: this is not the case on gemma3n
6736-
if name.endswith("norm.weight"):
6737-
data_torch = data_torch + self.norm_shift
6742+
f_shift = self.norm_shift(name)
6743+
if f_shift != 0.0:
6744+
data_torch = data_torch + f_shift
67386745

67396746
yield from super().modify_tensors(data_torch, name, bid)
67406747

@@ -6908,7 +6915,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
69086915
assert data_torch.shape[2] == 1
69096916
data_torch = data_torch.reshape(data_torch.shape[0], data_torch.shape[1])
69106917

6911-
yield from super().modify_tensors(data_torch, name, bid)
6918+
mapped_name = self.map_tensor_name(name, (".weight", ".bias", ".input_max", ".input_min", ".output_max", ".output_min"))
6919+
yield (mapped_name, data_torch)
69126920

69136921

69146922
@ModelBase.register("Gemma3nForConditionalGeneration")
@@ -7033,7 +7041,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
70337041
@ModelBase.register("Gemma3nForCausalLM", "Gemma3nForConditionalGeneration")
70347042
class Gemma3NModel(Gemma3Model):
70357043
model_arch = gguf.MODEL_ARCH.GEMMA3N
7036-
norm_shift = 0.0 # same value with Gemma3p5RMSNorm scale_shift on python code
70377044

70387045
_altup_proj: list[Tensor] = []
70397046
_altup_unembd: list[Tensor] = []
@@ -7052,6 +7059,10 @@ def __init__(self, *args, **kwargs):
70527059
torch.Tensor(), # to be replaced
70537060
]
70547061

7062+
def norm_shift(self, name: str) -> float:
7063+
del name
7064+
return 0.0 # same value with Gemma3p5RMSNorm scale_shift on python code
7065+
70557066
def set_vocab(self):
70567067
# For Gemma3n multimodal models, we need the FULL vocab_size (262400)
70577068
# which includes special tokens from 262144-262399 for vision/audio.
@@ -7197,6 +7208,9 @@ def set_vocab(self):
71977208

71987209
assert len(tokens) == vocab.vocab_size
71997210

7211+
# TODO @ngxson : there are some known (rare) issues with the tokenizer during development
7212+
# but I don't have time to dive into them right now;
7213+
# using a dedicated tokenizer name so that we can fix later without re-converting GGUF
72007214
self.gguf_writer.add_tokenizer_model("gemma4")
72017215
self.gguf_writer.add_token_list(tokens)
72027216
self.gguf_writer.add_token_scores(scores)

examples/eval-callback/eval-callback.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,18 @@ static bool run(llama_context * ctx, const common_params & params) {
1515

1616
const bool add_bos = llama_vocab_get_add_bos(vocab);
1717

18-
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
18+
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos, true);
1919

2020
if (tokens.empty()) {
2121
LOG_ERR("%s : there are not input tokens to process - (try to provide a prompt with '-p')\n", __func__);
2222
return false;
2323
}
2424

25+
LOG_INF("number of input tokens = %zu\n", tokens.size());
26+
for (size_t i = 0; i < tokens.size(); ++i) {
27+
LOG_INF(" %d\n", tokens[i]);
28+
}
29+
2530
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
2631
LOG_ERR("%s : failed to eval\n", __func__);
2732
return false;

0 commit comments

Comments
 (0)