@@ -107,24 +107,10 @@ def tokenizer_and_embedding_resize(
107107 embedding_size = int (multiple_of * math .ceil (len (tokenizer ) / multiple_of ))
108108 num_new_tokens = num_new_tokens + embedding_size - len (tokenizer )
109109
110+ # For Mllama models, we need to resize the input and output embeddings
111+ # separately, as the model's input and output embeddings differ in size.
110112 if isinstance (model , MllamaForConditionalGeneration ):
111- # Get new input embedding size
112- current_input_embeddings = model .get_input_embeddings ()
113- current_output_embeddings = model .get_output_embeddings ()
114- input_embedding_size = current_input_embeddings .weight .shape [0 ] + (
115- embedding_size - current_output_embeddings .weight .shape [0 ]
116- )
117-
118- # Save current input embedding
119- resized_input_embeddings = model ._get_resized_embeddings (
120- current_input_embeddings ,
121- new_num_tokens = input_embedding_size ,
122- mean_resizing = True ,
123- )
124- resized_input_embeddings = copy .deepcopy (resized_input_embeddings )
125- resized_input_embeddings .requires_grad_ (
126- current_input_embeddings .weight .requires_grad
127- )
113+ resized_input_embeddings = get_resized_input_embeddings (model , embedding_size )
128114
129115 # Resize input and output embeddings
130116 model .resize_token_embeddings (embedding_size )
@@ -153,3 +139,32 @@ def tokenizer_and_embedding_resize(
153139 output_embeddings [- num_new_tokens :] = output_embeddings_avg
154140
155141 return {"num_new_tokens" : num_new_tokens , "new_embedding_size" : embedding_size }
142+
143+
def get_resized_input_embeddings(model, embedding_size):
    """Build a resized, detached copy of a Mllama model's input embeddings.

    Mllama models have input and output embeddings of different sizes. The
    new input-embedding size preserves that gap: the input embedding grows
    by exactly the delta the output embedding will grow when it is resized
    to ``embedding_size``.

    Args:
        model: Mllama model exposing ``get_input_embeddings``,
            ``get_output_embeddings`` and ``_get_resized_embeddings``.
        embedding_size: Target size for the output embeddings.

    Returns:
        A deep copy of the resized input-embedding module, with its
        ``requires_grad`` flag mirroring the original input embedding weight.
    """
    input_emb = model.get_input_embeddings()
    output_emb = model.get_output_embeddings()

    # Grow the input embedding by the same amount the output embedding
    # will grow, keeping the input/output vocab-size gap unchanged.
    new_input_size = input_emb.weight.shape[0] + (
        embedding_size - output_emb.weight.shape[0]
    )

    # NOTE(review): relies on a private HF helper; mean_resizing initializes
    # new rows from the mean of existing embeddings.
    resized = model._get_resized_embeddings(
        input_emb,
        new_num_tokens=new_input_size,
        mean_resizing=True,
    )

    # Deep-copy so the result is independent of the model's own modules
    # (the model is resized afterwards by the caller), then restore the
    # original weight's trainability.
    resized = copy.deepcopy(resized)
    resized.requires_grad_(input_emb.weight.requires_grad)
    return resized
0 commit comments