@@ -798,7 +798,8 @@ def __init__(self, *args, **kwargs):
798798 # remap audio hparams
799799 if self .hparams_audio :
800800 self .hparams_audio ["feat_in" ] = self .hparams_audio .get ("input_feat_size" , 128 )
801- self .hparams_audio ["intermediate_size" ] = self .hparams_audio ["hidden_size" ] * 4
801+ if "hidden_size" in self .hparams_audio :
802+ self .hparams_audio ["intermediate_size" ] = self .hparams_audio ["hidden_size" ] * 4
802803 else :
803804 self .has_audio_encoder = False
804805
@@ -872,7 +873,7 @@ def __init__(self, *args, **kwargs):
872873 assert self .hparams_audio is not None
873874 text_embd_dim = self .hparams_vision ["mm_embed_dim" ]
874875 self .hparams_vision ["hidden_size" ] = text_embd_dim
875- self .hparams_audio ["hidden_size" ] = text_embd_dim
876+ self .hparams_audio ["hidden_size" ] = self . hparams_audio [ "audio_embed_dim" ]
876877 # this is a transformer-less vision tower, the params below are redundant but set to avoid error
877878 self .hparams_vision ["intermediate_size" ] = 0
878879 self .hparams_vision ["num_layers" ] = 0
@@ -897,7 +898,10 @@ def modify_tensors(self, data_torch, name, bid):
897898 # ggml im2col outputs in RR..GG..BB.. (CHW) order, but weight expects RGBRGB.. (HWC).
898899 # Permute columns so column i aligns with CHW input position i.
899900 assert self .hparams_vision is not None
900- p = self .hparams_vision ["model_patch_size" ]
901+ if "model_patch_size" in self .hparams_vision :
902+ p = self .hparams_vision ["model_patch_size" ]
903+ else :
904+ p = self .hparams_vision ["patch_size" ] * self .hparams_vision ["pooling_kernel_size" ]
901905 i = torch .arange (p * p * 3 )
902906 ch = i // (p * p )
903907 row = (i % (p * p )) // p
@@ -908,7 +912,10 @@ def modify_tensors(self, data_torch, name, bid):
908912 elif "patch_ln1.weight" in name or "patch_ln1.bias" in name :
909913 # same permutation for patch_ln1 as patch_dense to align with CHW input order
910914 assert self .hparams_vision is not None
911- p = self .hparams_vision ["model_patch_size" ]
915+ if "model_patch_size" in self .hparams_vision :
916+ p = self .hparams_vision ["model_patch_size" ]
917+ else :
918+ p = self .hparams_vision ["patch_size" ] * self .hparams_vision ["pooling_kernel_size" ]
912919 i = torch .arange (p * p * 3 )
913920 ch = i // (p * p )
914921 row = (i % (p * p )) // p
0 commit comments