Skip to content

Commit e802356

Browse files
authored
convert: Fix Gemma 4 Unified conversion (ggml-org#24118)
* Fix Gemma 4 Unified conversion * Set audio hidden size to audio_embed_dim
1 parent 4c51309 commit e802356

1 file changed

Lines changed: 11 additions & 4 deletions

File tree

conversion/gemma.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,8 @@ def __init__(self, *args, **kwargs):
798798
# remap audio hparams
799799
if self.hparams_audio:
800800
self.hparams_audio["feat_in"] = self.hparams_audio.get("input_feat_size", 128)
801-
self.hparams_audio["intermediate_size"] = self.hparams_audio["hidden_size"] * 4
801+
if "hidden_size" in self.hparams_audio:
802+
self.hparams_audio["intermediate_size"] = self.hparams_audio["hidden_size"] * 4
802803
else:
803804
self.has_audio_encoder = False
804805

@@ -872,7 +873,7 @@ def __init__(self, *args, **kwargs):
872873
assert self.hparams_audio is not None
873874
text_embd_dim = self.hparams_vision["mm_embed_dim"]
874875
self.hparams_vision["hidden_size"] = text_embd_dim
875-
self.hparams_audio["hidden_size"] = text_embd_dim
876+
self.hparams_audio["hidden_size"] = self.hparams_audio["audio_embed_dim"]
876877
# this is a transformer-less vision tower, the params below are redundant but set to avoid error
877878
self.hparams_vision["intermediate_size"] = 0
878879
self.hparams_vision["num_layers"] = 0
@@ -897,7 +898,10 @@ def modify_tensors(self, data_torch, name, bid):
897898
# ggml im2col outputs in RR..GG..BB.. (CHW) order, but weight expects RGBRGB.. (HWC).
898899
# Permute columns so column i aligns with CHW input position i.
899900
assert self.hparams_vision is not None
900-
p = self.hparams_vision["model_patch_size"]
901+
if "model_patch_size" in self.hparams_vision:
902+
p = self.hparams_vision["model_patch_size"]
903+
else:
904+
p = self.hparams_vision["patch_size"] * self.hparams_vision["pooling_kernel_size"]
901905
i = torch.arange(p * p * 3)
902906
ch = i // (p * p)
903907
row = (i % (p * p)) // p
@@ -908,7 +912,10 @@ def modify_tensors(self, data_torch, name, bid):
908912
elif "patch_ln1.weight" in name or "patch_ln1.bias" in name:
909913
# same permutation for patch_ln1 as patch_dense to align with CHW input order
910914
assert self.hparams_vision is not None
911-
p = self.hparams_vision["model_patch_size"]
915+
if "model_patch_size" in self.hparams_vision:
916+
p = self.hparams_vision["model_patch_size"]
917+
else:
918+
p = self.hparams_vision["patch_size"] * self.hparams_vision["pooling_kernel_size"]
912919
i = torch.arange(p * p * 3)
913920
ch = i // (p * p)
914921
row = (i % (p * p)) // p

0 commit comments

Comments
 (0)