Skip to content

Commit db1fcf9

Browse files
committed
Added test_vision_data_collator
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
1 parent 5bdcd99 commit db1fcf9

1 file changed

Lines changed: 53 additions & 0 deletions

File tree

tests/test_sft_trainer.py

Lines changed: 53 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,10 @@
2727

2828
# Third Party
2929
from datasets.exceptions import DatasetGenerationError, DatasetNotFoundError
30+
from PIL import Image
31+
from transformers import AutoProcessor
3032
from transformers.trainer_callback import TrainerCallback
33+
import numpy as np
3134
import pytest
3235
import torch
3336
import transformers
@@ -90,6 +93,7 @@
9093
DataHandlerType,
9194
add_tokenizer_eos_token,
9295
)
96+
from tuning.utils.collators import VisionDataCollator
9397
from tuning.utils.import_utils import is_fms_accelerate_available
9498

9599
MODEL_ARGS = configs.ModelArguments(
@@ -123,6 +127,7 @@
123127
)
124128

125129
PEFT_LORA_ARGS = peft_config.LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05)
130+
LLAMA_VISION_MODEL_NAME = "tests/artifacts/tiny-llama-vision-model"
126131

127132

128133
@pytest.mark.parametrize(
@@ -1969,3 +1974,51 @@ def test_handler(element, tokenizer, **kwargs):
19691974
},
19701975
)
19711976
_validate_training(tempdir)
1977+
1978+
1979+
def test_vision_data_collator():
    """Run VisionDataCollator on two text+image samples and check the batch.

    Each sample pairs a short prompt with a single randomly generated
    32x32, 3-channel uint8 image wrapped as a PIL Image. The collated batch
    must expose "input_ids", "labels" and "attention_mask", and the
    "input_ids" and "labels" entries must have the same shape.
    """
    collator = VisionDataCollator(
        AutoProcessor.from_pretrained(LLAMA_VISION_MODEL_NAME)
    )

    # One kwargs dict, deliberately shared by every feature (same object).
    processor_kwargs = {"return_tensors": "pt", "padding": True}
    height_width = (32, 32)

    def _random_pil_image(size=height_width):
        """Build a PIL Image from random uint8 pixel data of the given size."""
        pixels = np.random.randint(0, 256, size=(*size, 3), dtype=np.uint8)
        return Image.fromarray(pixels)

    prompts = ("Describe the image.", "What is in the image?")

    # Images are generated in prompt order, one per feature; each feature
    # gets its own "fields_name" mapping that points at the text/image keys.
    features = [
        {
            "processor_kwargs": processor_kwargs,
            "fields_name": {
                "dataset_text_field": "text",
                "dataset_image_field": "image",
            },
            "text": prompt,
            "image": [_random_pil_image()],
        }
        for prompt in prompts
    ]

    # Collate the features into a single batch dictionary.
    batch = collator(features)

    for expected_key in ("input_ids", "labels", "attention_mask"):
        assert expected_key in batch
    assert batch["input_ids"].shape == batch["labels"].shape

0 commit comments

Comments
 (0)