Skip to content

Commit c419fd5

Browse files
committed
Add Gemma4Unified models and logits bias handling for suppress tokens
- Introduced Gemma4UnifiedForConditionalGeneration and Gemma4UnifiedVisionAudioModel classes to enhance multimodal capabilities. - Implemented functionality to retrieve and apply suppress tokens from generation configuration, improving model output control. - Updated tensor modification logic to accommodate new model architectures and ensure proper handling of positional embeddings. - Enhanced logits bias handling in the C++ implementation to mirror suppress tokens functionality, addressing known output issues.
1 parent 9ca009a commit c419fd5

2 files changed

Lines changed: 122 additions & 2 deletions

File tree

convert_hf_to_gguf.py

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7692,7 +7692,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
76927692
yield from super().modify_tensors(data_torch, name, bid)
76937693

76947694

7695-
@ModelBase.register("Gemma4ForConditionalGeneration")
7695+
@ModelBase.register("Gemma4ForConditionalGeneration", "Gemma4ForCausalLM")
76967696
class Gemma4Model(Gemma3Model):
76977697
model_arch = gguf.MODEL_ARCH.GEMMA4
76987698

@@ -7817,6 +7817,26 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
78177817
yield from super().modify_tensors(data_torch, name, bid)
78187818

78197819

7820+
@ModelBase.register("Gemma4UnifiedForConditionalGeneration")
7821+
class Gemma4UnifiedModel(Gemma4Model):
7822+
model_arch = gguf.MODEL_ARCH.GEMMA4
7823+
7824+
def _get_suppress_tokens(self) -> Sequence[int] | None:
7825+
gen_cfg_path = self.dir_model / "generation_config.json"
7826+
if gen_cfg_path.is_file():
7827+
with open(gen_cfg_path, encoding="utf-8") as f:
7828+
gen_cfg = json.load(f)
7829+
return gen_cfg.get("suppress_tokens")
7830+
return None
7831+
7832+
def set_gguf_parameters(self):
7833+
super().set_gguf_parameters()
7834+
7835+
suppress_tokens = self._get_suppress_tokens()
7836+
if suppress_tokens is not None:
7837+
self.gguf_writer.add_suppress_tokens(suppress_tokens)
7838+
7839+
78207840
@ModelBase.register("Gemma4AssistantForCausalLM")
78217841
class Gemma4AssistantModel(Gemma4Model):
78227842
model_arch = gguf.MODEL_ARCH.GEMMA4_ASSISTANT
@@ -7891,7 +7911,8 @@ def __init__(self, *args, **kwargs):
78917911
# remap audio hparams
78927912
if self.hparams_audio:
78937913
self.hparams_audio["feat_in"] = self.hparams_audio.get("input_feat_size", 128)
7894-
self.hparams_audio["intermediate_size"] = self.hparams_audio["hidden_size"] * 4
7914+
if "hidden_size" in self.hparams_audio:
7915+
self.hparams_audio["intermediate_size"] = self.hparams_audio["hidden_size"] * 4
78957916
else:
78967917
self.has_audio_encoder = False
78977918

@@ -7956,6 +7977,70 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
79567977
yield (mapped_name, data_torch)
79577978

79587979

7980+
@ModelBase.register("Gemma4UnifiedForConditionalGeneration")
7981+
class Gemma4UnifiedVisionAudioModel(Gemma4VisionAudioModel):
7982+
has_audio_encoder = True
7983+
has_vision_encoder = True
7984+
7985+
def __init__(self, *args, **kwargs):
7986+
super().__init__(*args, **kwargs)
7987+
assert self.hparams_vision is not None
7988+
assert self.hparams_audio is not None
7989+
text_embd_dim = self.hparams_vision["mm_embed_dim"]
7990+
self.hparams_vision["hidden_size"] = text_embd_dim
7991+
self.hparams_audio["hidden_size"] = self.hparams_audio["audio_embed_dim"]
7992+
# this is a transformer-less vision tower, the params below are redundant but set to avoid error
7993+
self.hparams_vision["intermediate_size"] = 0
7994+
self.hparams_vision["num_layers"] = 0
7995+
self.hparams_vision["num_attention_heads"] = 0
7996+
self.hparams_audio["intermediate_size"] = 0
7997+
self.hparams_audio["num_layers"] = 0
7998+
self.hparams_audio["num_attention_heads"] = 0
7999+
8000+
def set_gguf_parameters(self):
8001+
super().set_gguf_parameters()
8002+
self.gguf_writer.add_clip_vision_projector_type(gguf.VisionProjectorType.GEMMA4UV)
8003+
self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.GEMMA4UA)
8004+
8005+
def modify_tensors(self, data_torch, name, bid):
8006+
if name.endswith("pos_embedding"):
8007+
name += ".weight"
8008+
data_torch = data_torch.permute(1, 0, 2)
8009+
elif ".pos_norm." in name:
8010+
# rename to patch_ln3 to reuse the tensor name scheme
8011+
name = name.replace(".pos_norm.", ".patch_ln3.")
8012+
elif "patch_dense.weight" in name:
8013+
# ggml im2col outputs in RR..GG..BB.. (CHW) order, but weight expects RGBRGB.. (HWC).
8014+
# Permute columns so column i aligns with CHW input position i.
8015+
assert self.hparams_vision is not None
8016+
if "model_patch_size" in self.hparams_vision:
8017+
p = self.hparams_vision["model_patch_size"]
8018+
else:
8019+
p = self.hparams_vision["patch_size"] * self.hparams_vision["pooling_kernel_size"]
8020+
i = torch.arange(p * p * 3)
8021+
ch = i // (p * p)
8022+
row = (i % (p * p)) // p
8023+
col = i % p
8024+
# perm[i] = HWC column index for CHW position i
8025+
perm = row * p * 3 + col * 3 + ch
8026+
data_torch = data_torch[:, perm]
8027+
elif "patch_ln1.weight" in name or "patch_ln1.bias" in name:
8028+
# same permutation for patch_ln1 as patch_dense to align with CHW input order
8029+
assert self.hparams_vision is not None
8030+
if "model_patch_size" in self.hparams_vision:
8031+
p = self.hparams_vision["model_patch_size"]
8032+
else:
8033+
p = self.hparams_vision["patch_size"] * self.hparams_vision["pooling_kernel_size"]
8034+
i = torch.arange(p * p * 3)
8035+
ch = i // (p * p)
8036+
row = (i % (p * p)) // p
8037+
col = i % p
8038+
# perm[i] = HWC index for CHW position i
8039+
perm = row * p * 3 + col * 3 + ch
8040+
data_torch = data_torch[perm]
8041+
return super().modify_tensors(data_torch, name, bid)
8042+
8043+
79598044
@ModelBase.register("Starcoder2ForCausalLM")
79608045
class StarCoder2Model(TextModel):
79618046
model_arch = gguf.MODEL_ARCH.STARCODER2

src/models/gemma4-iswa.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,31 @@ static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, in
77
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
88
}
99

10+
// TODO @ngxson : maybe improve this in the future
11+
class llm_graph_input_logits_bias : public llm_graph_input_i {
12+
public:
13+
llm_graph_input_logits_bias(const llama_vocab & vocab) {
14+
arr.resize(vocab.n_tokens(), 0.0f);
15+
for (llama_token id : vocab.get_suppress_tokens()) {
16+
if (0 <= id && id < (int32_t)vocab.n_tokens()) {
17+
arr[id] = -INFINITY;
18+
}
19+
}
20+
}
21+
virtual ~llm_graph_input_logits_bias() = default;
22+
23+
void set_input(const llama_ubatch *) override {
24+
const int64_t n_vocab = arr.size();
25+
ggml_backend_tensor_set(logits_bias, arr.data(), 0, n_vocab*ggml_element_size(logits_bias));
26+
}
27+
28+
// bool can_reuse(const llm_graph_params & params) override;
29+
30+
ggml_tensor * logits_bias = nullptr; // F32 [n_vocab]
31+
32+
std::vector<float> arr;
33+
};
34+
1035
llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params) :
1136
llm_graph_context(params),
1237
model(model),
@@ -253,6 +278,16 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll
253278
cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
254279
}
255280

281+
// apply logits bias if needed (e.g. for gemma4_unified patch)
282+
// this is to mirror the suppress_tokens patch on transformers, to avoid model from outputing <image|> and <audio|> tokens (which is a known issue related to the checkpoint)
283+
// TODO: maybe handle this inside the sampling system in the future
284+
if (!model.vocab.get_suppress_tokens().empty()) {
285+
auto inp_bias = std::make_unique<llm_graph_input_logits_bias>(model.vocab);
286+
inp_bias->logits_bias = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, inp_bias->arr.size());
287+
cur = ggml_add(ctx0, cur, inp_bias->logits_bias);
288+
res->add_input(std::move(inp_bias));
289+
}
290+
256291
cb(cur, "result_output", -1);
257292
res->t_logits = cur;
258293

0 commit comments

Comments
 (0)