Skip to content

Commit d3c30a0

Browse files
fix: vision image token insertion (#671)
* fix: vision model image token insertion and transformers v5 compatibility
* fix: resolve test failures
* style: remove trailing comma in pytest.mark.skipif
* fix: formatting and lint issues
* fix: resolve conflicts

Signed-off-by: yashasvi <yashasvi@ibm.com>
Signed-off-by: Yashasvi Chaurasia <46622381+YashasviChaurasia@users.noreply.github.com>
1 parent bcc6fd0 commit d3c30a0

3 files changed

Lines changed: 312 additions & 8 deletions

File tree

tests/data/test_data_preprocessing.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
DataPreProcessorConfig,
7676
DataSetConfig,
7777
)
78+
from tuning.data.data_handlers import apply_tokenizer_chat_template
7879
from tuning.data.data_preprocessing_utils import get_data_collator
7980
from tuning.data.data_processors import DataPreProcessor, get_datapreprocessor
8081
from tuning.data.setup_dataprocessor import (
@@ -2156,6 +2157,191 @@ def test_multimodal_processor_injection_in_data_pipeline(model_name):
21562157
assert "image" in result["train"].column_names
21572158

21582159

2160+
@pytest.mark.parametrize(
    "model_name",
    [TINY_LLAMA_VISION_MODEL_NAME, TINY_GRANITE_VISION_MODEL_NAME],
)
def test_vision_chat_template_with_image_tokens(model_name):
    """Regression test: apply_tokenizer_chat_template must insert image tokens.

    Guards against the bug where OpenAI-format vision datasets failed with
    'Image features and image tokens do not match: tokens: 0, features 18432'
    because no image tokens were inserted into the formatted text.

    Covers three paths:
    1. Explicit conversation_column_name="messages".
    2. Auto-detection when conversation_column_name is None.
    3. VisionDataCollator tokenizing the formatted output end to end.
    """
    processor = AutoProcessor.from_pretrained(model_name)

    # A minimal OpenAI-style conversational element carrying one image,
    # mirroring the structure of the user-reported dataset.
    sample_image = Image.fromarray(
        np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
    )
    sample = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this image."},
                    {"type": "image"},
                ],
            },
        ],
        "image": [sample_image],
    }

    # --- Path 1: conversation column named explicitly -----------------------
    result = apply_tokenizer_chat_template(
        element=sample,
        formatted_text_column_name="text",
        conversation_column_name="messages",
        tokenizer=processor.tokenizer,
        processor=processor,
    )
    assert "text" in result
    formatted_text = result["text"]
    assert processor.image_token in formatted_text, (
        f"Image token '{processor.image_token}' not found in formatted text. "
        f"This would cause 'Image features and image tokens do not match' error. "
        f"Text: {formatted_text}"
    )

    # --- Path 2: conversation column left to auto-detection -----------------
    result_auto = apply_tokenizer_chat_template(
        element=sample,
        formatted_text_column_name="text",
        conversation_column_name=None,  # exercise the auto-detection branch
        tokenizer=processor.tokenizer,
        processor=processor,
    )
    assert "text" in result_auto
    formatted_text_auto = result_auto["text"]
    assert processor.image_token in formatted_text_auto, (
        f"Auto-detection failed: Image token '{processor.image_token}' not found. "
        f"Text: {formatted_text_auto}"
    )

    # --- Path 3: the collator must tokenize the formatted text cleanly ------
    collator = VisionDataCollator(processor)
    batch = collator(
        [
            {
                "text": result["text"],
                "image": sample["image"],
                "fields_name": {
                    "dataset_text_field": "text",
                    "dataset_image_field": "image",
                },
                "processor_kwargs": {"return_tensors": "pt", "padding": True},
            }
        ]
    )

    # The collated batch must expose the standard vision-training tensors.
    for key in ("input_ids", "pixel_values", "labels"):
        assert key in batch

    # The tokenized ids must actually contain image token ids — tokens == 0
    # is exactly the failure mode this test guards against.
    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
    num_image_tokens = (batch["input_ids"] == image_token_id).sum().item()
    assert num_image_tokens > 0, (
        f"No image tokens found in tokenized input_ids. "
        f"This is the exact bug: tokens=0 would cause the error. "
        f"Image token ID: {image_token_id}, input_ids: {batch['input_ids']}"
    )

    # Pixel values must be non-empty for the single supplied image.
    assert batch["pixel_values"].shape[0] > 0, "No pixel values in batch"
2255+
@pytest.mark.parametrize(
    "model_name",
    [TINY_LLAMA_VISION_MODEL_NAME, TINY_GRANITE_VISION_MODEL_NAME],
)
def test_vision_chat_template_error_without_conversation_column(model_name):
    """apply_tokenizer_chat_template must fail loudly when auto-detection misses.

    When the dataset's conversation column has a non-standard name and no
    explicit conversation_column_name is given, the call should surface a
    clear exception rather than silently producing text without image tokens.
    """
    processor = AutoProcessor.from_pretrained(model_name)

    # Element whose conversation lives under a name auto-detection won't find.
    dummy_image = Image.fromarray(
        np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
    )
    bad_element = {
        "custom_conversation_field": [
            {
                "role": "user",
                "content": "Just plain text, no image placeholder",
            },
        ],
        "image": [dummy_image],
    }

    # Either failure mode is acceptable:
    #   - template rendering fails outright (KeyError), or
    #   - the formatted text lacks image tokens (ValueError).
    # The requirement is simply that the call does not succeed silently.
    with pytest.raises((ValueError, KeyError)) as raised:
        apply_tokenizer_chat_template(
            element=bad_element,
            formatted_text_column_name="text",
            conversation_column_name=None,  # force auto-detection to miss
            tokenizer=processor.tokenizer,
            processor=processor,
        )

    # A non-empty message is the minimum bar for an actionable error.
    message = str(raised.value)
    assert message
2301+
@pytest.mark.parametrize(
    "model_name",
    [TINY_LLAMA_VISION_MODEL_NAME, TINY_GRANITE_VISION_MODEL_NAME],
)
def test_vision_chat_template_multiple_images(model_name):
    """A conversation referencing several images must still format cleanly."""
    processor = AutoProcessor.from_pretrained(model_name)

    # Two distinct random images for a single multi-image user turn.
    first_image, second_image = (
        Image.fromarray(np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8))
        for _ in range(2)
    )
    multi_image_element = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Compare these images:"},
                    {"type": "image"},
                    {"type": "text", "text": "and"},
                    {"type": "image"},
                ],
            },
        ],
        "image": [first_image, second_image],
    }

    result = apply_tokenizer_chat_template(
        element=multi_image_element,
        formatted_text_column_name="text",
        conversation_column_name="messages",
        tokenizer=processor.tokenizer,
        processor=processor,
    )
    formatted_text = result["text"]

    # The template must emit image placeholders; the exact count per image is
    # model/template dependent, so only require at least one to be present.
    num_image_tokens_in_text = formatted_text.count(processor.image_token)
    assert num_image_tokens_in_text >= 1, (
        f"Expected at least 1 image token, found {num_image_tokens_in_text}. "
        f"Text: {formatted_text}"
    )
2344+
21592345
def test_multimodal_processor_missing_raises_error():
21602346
"""Test that prepare_multimodal_data_processor raises ValueError when
21612347
processor is None (i.e., DataPreProcessor was created without a processor).

tests/test_sft_trainer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1918,6 +1918,10 @@ def test_run_moe_ft_with_save_model_dir(dataset_path):
19181918
assert os.path.exists(os.path.join(save_model_dir))
19191919

19201920

1921+
@pytest.mark.skipif(
1922+
torch.backends.mps.is_available() and not torch.cuda.is_available(),
1923+
reason="MoE models have histogram incompatibility with MPS backend",
1924+
)
19211925
@pytest.mark.parametrize(
19221926
"datafiles, dataconfigfile",
19231927
[
@@ -2192,6 +2196,7 @@ def test_empty_data():
21922196
data_args = copy.deepcopy(DATA_ARGS)
21932197
data_args.training_data_path = EMPTY_DATA
21942198

2199+
# StopIteration is raised by datasets library in transformers v5 when processing empty files
21952200
with pytest.raises((DatasetGenerationError, ValueError, StopIteration)):
21962201
sft_trainer.train(
21972202
copy.deepcopy(MODEL_ARGS),

0 commit comments

Comments
 (0)