Skip to content

Commit 5bdcd99

Browse files
committed
PR Change 2
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
1 parent 1e78997 commit 5bdcd99

5 files changed

Lines changed: 41 additions & 30 deletions

File tree

tuning/data/data_preprocessing_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ def get_data_collator(
6969
"""
7070

7171
if processor:
72+
if is_padding_free or packing:
73+
raise ValueError(
74+
"Vision model tuning does not support packing or padding_free tuning. "
75+
"Please set packing=False and is_padding_free=False."
76+
)
7277
return VisionDataCollator(processor)
7378

7479
if packing:

tuning/data/setup_dataprocessor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,8 +345,7 @@ def _process_raw_data_args(
345345
handlers, dataset_text_field = _get_pretokenized_dataset_handlers(
346346
data_args, (is_eval_dataset_present and not is_evaldata_tokenized)
347347
)
348-
# TODO: Better way to handle vision this elif condition
349-
elif data_args.dataset_text_field and data_args.dataset_image_field:
348+
elif processor and data_args.dataset_text_field and data_args.dataset_image_field:
350349

351350
handlers, dataset_text_field = _get_vision_dataset_handlers(
352351
data_args, processor_kwargs

tuning/sft_trainer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,7 @@ def train(
230230
processor = None
231231
try:
232232
# try to load vision model
233-
model_loader = AutoModelForVision2Seq.from_pretrained
234-
model = model_loader(
233+
model = AutoModelForVision2Seq.from_pretrained(
235234
model_args.model_name_or_path,
236235
cache_dir=train_args.cache_dir,
237236
torch_dtype=get_torch_dtype(model_args.torch_dtype),

tuning/utils/collators.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,11 @@
1818

1919
class VisionDataCollator:
2020
"""
21-
A data collator specialized for multi-modal (text + image) inputs.
21+
A data collator specialized for vision model (text + image) inputs.
2222
It uses a processor (e.g., LlavaProcessor or MllamaProcessor) to
2323
combine text and images into model-ready tensors.
2424
25-
For padding-free tuning, configure the processor's arguments
26-
in `processor_kwargs`, for example:
27-
processor_kwargs = {
28-
"padding": False,
29-
"max_length": 1024,
30-
...
31-
}
32-
25+
Padding-free tuning is not supported.
3326
Args:
3427
processor: A processor (like `LlavaProcessor`, `MllamaProcessor`, etc.).
3528
"""

tuning/utils/tokenizer_data_utils.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -107,24 +107,10 @@ def tokenizer_and_embedding_resize(
107107
embedding_size = int(multiple_of * math.ceil(len(tokenizer) / multiple_of))
108108
num_new_tokens = num_new_tokens + embedding_size - len(tokenizer)
109109

110+
# For Mllama models, we need to resize the input and output embeddings
111+
# separately, as the model has different input and output embeddings.
110112
if isinstance(model, MllamaForConditionalGeneration):
111-
# Get new input embedding size
112-
current_input_embeddings = model.get_input_embeddings()
113-
current_output_embeddings = model.get_output_embeddings()
114-
input_embedding_size = current_input_embeddings.weight.shape[0] + (
115-
embedding_size - current_output_embeddings.weight.shape[0]
116-
)
117-
118-
# Save current input embedding
119-
resized_input_embeddings = model._get_resized_embeddings(
120-
current_input_embeddings,
121-
new_num_tokens=input_embedding_size,
122-
mean_resizing=True,
123-
)
124-
resized_input_embeddings = copy.deepcopy(resized_input_embeddings)
125-
resized_input_embeddings.requires_grad_(
126-
current_input_embeddings.weight.requires_grad
127-
)
113+
resized_input_embeddings = get_resized_input_embeddings(model, embedding_size)
128114

129115
# Resize input and output embeddings
130116
model.resize_token_embeddings(embedding_size)
@@ -153,3 +139,32 @@ def tokenizer_and_embedding_resize(
153139
output_embeddings[-num_new_tokens:] = output_embeddings_avg
154140

155141
return {"num_new_tokens": num_new_tokens, "new_embedding_size": embedding_size}
142+
143+
144+
def get_resized_input_embeddings(model, embedding_size):
    """Build a resized, standalone copy of a Mllama model's input embeddings.

    The new input-embedding size is derived by applying the same token-count
    delta (target ``embedding_size`` minus the current output-embedding vocab
    size) to the current input-embedding vocab size.

    Args:
        model: A Mllama model exposing ``get_input_embeddings``,
            ``get_output_embeddings`` and ``_get_resized_embeddings``.
        embedding_size: Target vocab size for the (output) embeddings.

    Returns:
        A deep-copied embedding module resized to the computed token count,
        with ``requires_grad`` mirroring the original input embeddings.
    """
    input_emb = model.get_input_embeddings()
    output_emb = model.get_output_embeddings()

    # Grow the input vocab by the same delta the output embeddings will get.
    delta = embedding_size - output_emb.weight.shape[0]
    target_num_tokens = input_emb.weight.shape[0] + delta

    # NOTE: relies on the (private) HF helper used elsewhere in this module;
    # mean_resizing=True initializes new rows from the existing distribution.
    resized = model._get_resized_embeddings(
        input_emb,
        new_num_tokens=target_num_tokens,
        mean_resizing=True,
    )

    # Deep-copy so the returned module is detached from the model's own
    # embedding object, then mirror the original layer's trainability flag.
    resized = copy.deepcopy(resized)
    resized.requires_grad_(input_emb.weight.requires_grad)
    return resized

0 commit comments

Comments
 (0)