|
27 | 27 |
|
28 | 28 | # Third Party |
29 | 29 | from datasets.exceptions import DatasetGenerationError, DatasetNotFoundError |
| 30 | +from PIL import Image |
| 31 | +from transformers import AutoProcessor |
30 | 32 | from transformers.trainer_callback import TrainerCallback |
| 33 | +import numpy as np |
31 | 34 | import pytest |
32 | 35 | import torch |
33 | 36 | import transformers |
|
90 | 93 | DataHandlerType, |
91 | 94 | add_tokenizer_eos_token, |
92 | 95 | ) |
| 96 | +from tuning.utils.collators import VisionDataCollator |
93 | 97 | from tuning.utils.import_utils import is_fms_accelerate_available |
94 | 98 |
|
95 | 99 | MODEL_ARGS = configs.ModelArguments( |
|
123 | 127 | ) |
124 | 128 |
|
125 | 129 | PEFT_LORA_ARGS = peft_config.LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05) |
| 130 | +LLAMA_VISION_MODEL_NAME = "tests/artifacts/tiny-llama-vision-model" |
126 | 131 |
|
127 | 132 |
|
128 | 133 | @pytest.mark.parametrize( |
@@ -1969,3 +1974,51 @@ def test_handler(element, tokenizer, **kwargs): |
1969 | 1974 | }, |
1970 | 1975 | ) |
1971 | 1976 | _validate_training(tempdir) |
| 1977 | + |
| 1978 | + |
| 1979 | +def test_vision_data_collator(): |
| 1980 | + """Test the VisionDataCollator with dummy Image data.""" |
| 1981 | + |
| 1982 | + processor = AutoProcessor.from_pretrained(LLAMA_VISION_MODEL_NAME) |
| 1983 | + collator = VisionDataCollator(processor) |
| 1984 | + processor_kwargs = {} |
| 1985 | + processor_kwargs["return_tensors"] = "pt" |
| 1986 | + processor_kwargs["padding"] = True |
| 1987 | + image_size = (32, 32) |
| 1988 | + |
| 1989 | + def generate_pil_image(size=image_size): |
| 1990 | + """Generate a dummy image array of the specified size and return PIL Image.""" |
| 1991 | + image_array = np.random.randint(0, 256, size=(*size, 3), dtype=np.uint8) |
| 1992 | + return Image.fromarray(image_array) |
| 1993 | + |
| 1994 | + image1 = generate_pil_image() |
| 1995 | + image2 = generate_pil_image() |
| 1996 | + |
| 1997 | + features = [ |
| 1998 | + { |
| 1999 | + "processor_kwargs": processor_kwargs, |
| 2000 | + "fields_name": { |
| 2001 | + "dataset_text_field": "text", |
| 2002 | + "dataset_image_field": "image", |
| 2003 | + }, |
| 2004 | + "text": "Describe the image.", |
| 2005 | + "image": [image1], |
| 2006 | + }, |
| 2007 | + { |
| 2008 | + "processor_kwargs": processor_kwargs, |
| 2009 | + "fields_name": { |
| 2010 | + "dataset_text_field": "text", |
| 2011 | + "dataset_image_field": "image", |
| 2012 | + }, |
| 2013 | + "text": "What is in the image?", |
| 2014 | + "image": [image2], |
| 2015 | + }, |
| 2016 | + ] |
| 2017 | + |
| 2018 | + # Call the collator which returns a batch dictionary containing "input_ids" and "labels" |
| 2019 | + batch = collator(features) |
| 2020 | + |
| 2021 | + assert "input_ids" in batch |
| 2022 | + assert "labels" in batch |
| 2023 | + assert "attention_mask" in batch |
| 2024 | + assert batch["input_ids"].shape == batch["labels"].shape |
0 commit comments