|
75 | 75 | DataPreProcessorConfig, |
76 | 76 | DataSetConfig, |
77 | 77 | ) |
| 78 | +from tuning.data.data_handlers import apply_tokenizer_chat_template |
78 | 79 | from tuning.data.data_preprocessing_utils import get_data_collator |
79 | 80 | from tuning.data.data_processors import DataPreProcessor, get_datapreprocessor |
80 | 81 | from tuning.data.setup_dataprocessor import ( |
@@ -2156,6 +2157,191 @@ def test_multimodal_processor_injection_in_data_pipeline(model_name): |
2156 | 2157 | assert "image" in result["train"].column_names |
2157 | 2158 |
|
2158 | 2159 |
|
@pytest.mark.parametrize(
    "model_name",
    [TINY_LLAMA_VISION_MODEL_NAME, TINY_GRANITE_VISION_MODEL_NAME],
)
def test_vision_chat_template_with_image_tokens(model_name):
    """Regression test: apply_tokenizer_chat_template must emit image tokens.

    Guards against the bug where vision datasets in OpenAI conversational
    format failed with 'Image features and image tokens do not match:
    tokens: 0, features 18432' because no image tokens were inserted into
    the formatted text.

    Covers:
    1. Explicit conversation_column_name='messages'
    2. Auto-detection when conversation_column_name=None
    3. End-to-end collation of the formatted sample
    """
    processor = AutoProcessor.from_pretrained(model_name)

    # OpenAI-style conversational sample carrying one attached image,
    # mirroring the dataset structure that originally triggered the bug.
    sample_image = Image.fromarray(
        np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
    )
    element = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this image."},
                    {"type": "image"},
                ],
            },
        ],
        "image": [sample_image],
    }

    # Case 1: caller names the conversation column explicitly.
    result = apply_tokenizer_chat_template(
        element=element,
        formatted_text_column_name="text",
        conversation_column_name="messages",
        tokenizer=processor.tokenizer,
        processor=processor,
    )
    assert "text" in result
    formatted_text = result["text"]
    assert processor.image_token in formatted_text, (
        f"Image token '{processor.image_token}' not found in formatted text. "
        f"This would cause 'Image features and image tokens do not match' error. "
        f"Text: {formatted_text}"
    )

    # Case 2: column name omitted, relying on auto-detection.
    result_auto = apply_tokenizer_chat_template(
        element=element,
        formatted_text_column_name="text",
        conversation_column_name=None,  # Trigger auto-detection
        tokenizer=processor.tokenizer,
        processor=processor,
    )
    assert "text" in result_auto
    formatted_text_auto = result_auto["text"]
    assert processor.image_token in formatted_text_auto, (
        f"Auto-detection failed: Image token '{processor.image_token}' not found. "
        f"Text: {formatted_text_auto}"
    )

    # Case 3: the collator must be able to tokenize the formatted sample.
    collator = VisionDataCollator(processor)
    batch = collator(
        [
            {
                "text": result["text"],
                "image": element["image"],
                "fields_name": {
                    "dataset_text_field": "text",
                    "dataset_image_field": "image",
                },
                "processor_kwargs": {"return_tensors": "pt", "padding": True},
            }
        ]
    )

    # The collated batch must expose the standard vision-LM tensors.
    for key in ("input_ids", "pixel_values", "labels"):
        assert key in batch

    # A zero count here is precisely the reported failure mode.
    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
    num_image_tokens = (batch["input_ids"] == image_token_id).sum().item()
    assert num_image_tokens > 0, (
        f"No image tokens found in tokenized input_ids. "
        f"This is the exact bug: tokens=0 would cause the error. "
        f"Image token ID: {image_token_id}, input_ids: {batch['input_ids']}"
    )

    assert batch["pixel_values"].shape[0] > 0, "No pixel values in batch"
| 2254 | + |
@pytest.mark.parametrize(
    "model_name",
    [TINY_LLAMA_VISION_MODEL_NAME, TINY_GRANITE_VISION_MODEL_NAME],
)
def test_vision_chat_template_error_without_conversation_column(model_name):
    """apply_tokenizer_chat_template must fail loudly when auto-detection misses.

    A dataset whose conversation lives under a non-standard column name
    should surface an actionable exception rather than silently producing
    text without image tokens.
    """
    processor = AutoProcessor.from_pretrained(model_name)

    # Non-standard column name that auto-detection will not recognize.
    sample_image = Image.fromarray(
        np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
    )
    element = {
        "custom_conversation_field": [
            {
                "role": "user",
                "content": "Just plain text, no image placeholder",
            },
        ],
        "image": [sample_image],
    }

    # Expected failure modes, both acceptable:
    #   KeyError   - template rendering cannot find the conversation column
    #   ValueError - formatted text ends up without image tokens
    with pytest.raises((ValueError, KeyError)) as exc_info:
        apply_tokenizer_chat_template(
            element=element,
            formatted_text_column_name="text",
            conversation_column_name=None,  # Let auto-detection fail
            tokenizer=processor.tokenizer,
            processor=processor,
        )

    # A non-empty message is the minimum bar for an actionable error.
    assert str(exc_info.value)
| 2299 | + |
| 2300 | + |
@pytest.mark.parametrize(
    "model_name",
    [TINY_LLAMA_VISION_MODEL_NAME, TINY_GRANITE_VISION_MODEL_NAME],
)
def test_vision_chat_template_multiple_images(model_name):
    """Test that multiple images in a conversation are handled correctly.

    The formatted text must contain at least one image token per attached
    image; a weaker check (e.g. ">= 1") would still pass if the second
    image placeholder were silently dropped — the exact token/feature
    count mismatch this suite guards against.
    """
    processor = AutoProcessor.from_pretrained(model_name)

    # Conversation interleaving text with two image placeholders.
    pil_image1 = Image.fromarray(np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8))
    pil_image2 = Image.fromarray(np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8))

    element = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Compare these images:"},
                    {"type": "image"},
                    {"type": "text", "text": "and"},
                    {"type": "image"},
                ],
            },
        ],
        "image": [pil_image1, pil_image2],
    }

    result = apply_tokenizer_chat_template(
        element=element,
        formatted_text_column_name="text",
        conversation_column_name="messages",
        tokenizer=processor.tokenizer,
        processor=processor,
    )

    formatted_text = result["text"]
    # Require at least one token per image. Use ">=" rather than "==": some
    # chat templates expand a single placeholder into several image tokens,
    # but fewer tokens than images means one image was dropped.
    num_images = len(element["image"])
    num_image_tokens_in_text = formatted_text.count(processor.image_token)
    assert num_image_tokens_in_text >= num_images, (
        f"Expected at least {num_images} image tokens (one per image), "
        f"found {num_image_tokens_in_text}. "
        f"Text: {formatted_text}"
    )
| 2343 | + |
| 2344 | + |
2159 | 2345 | def test_multimodal_processor_missing_raises_error(): |
2160 | 2346 | """Test that prepare_multimodal_data_processor raises ValueError when |
2161 | 2347 | processor is None (i.e., DataPreProcessor was created without a processor). |
|
0 commit comments