Skip to content

Commit 3d86a38

Browse files
authored
feat: support multiple images per column in image context (#257)
* allow image context column to have multiple images
* pack multimodal context at the front before user text messages
* Fix edge case with numpy array
1 parent f6a2c57 commit 3d86a38

5 files changed

Lines changed: 196 additions & 29 deletions

File tree

packages/data-designer-config/src/data_designer/config/models.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from __future__ import annotations
55

6+
import json
67
import logging
78
from abc import ABC, abstractmethod
89
from enum import Enum
@@ -65,7 +66,7 @@ class ModalityContext(ABC, BaseModel):
6566
data_type: ModalityDataType
6667

6768
@abstractmethod
68-
def get_context(self, record: dict) -> dict[str, Any]: ...
69+
def get_contexts(self, record: dict) -> list[dict[str, Any]]: ...
6970

7071

7172
class ImageContext(ModalityContext):
@@ -81,25 +82,53 @@ class ImageContext(ModalityContext):
8182
modality: Modality = Modality.IMAGE
8283
image_format: ImageFormat | None = None
8384

84-
def get_context(self, record: dict) -> dict[str, Any]:
85-
"""Get the context for the image modality.
85+
def get_contexts(self, record: dict) -> list[dict[str, Any]]:
86+
"""Get the contexts for the image modality.
8687
8788
Args:
88-
record: The record containing the image data.
89+
record: The record containing the image data. The data can be:
90+
- A JSON serialized list of strings
91+
- A list of strings
92+
- A single string
8993
9094
Returns:
91-
The context for the image modality.
95+
A list of image contexts.
9296
"""
93-
context = dict(type="image_url")
94-
context_value = record[self.column_name]
95-
if self.data_type == ModalityDataType.URL:
96-
context["image_url"] = context_value
97+
raw_value = record[self.column_name]
98+
99+
# Normalize to list of strings
100+
if isinstance(raw_value, str):
101+
# Try to parse as JSON first
102+
try:
103+
parsed_value = json.loads(raw_value)
104+
if isinstance(parsed_value, list):
105+
context_values = parsed_value
106+
else:
107+
context_values = [raw_value]
108+
except (json.JSONDecodeError, TypeError):
109+
context_values = [raw_value]
110+
elif isinstance(raw_value, list):
111+
context_values = raw_value
112+
elif hasattr(raw_value, "__iter__") and not isinstance(raw_value, (str, bytes, dict)):
113+
# Handle array-like objects (numpy arrays, pandas Series, etc.)
114+
context_values = list(raw_value)
97115
else:
98-
context["image_url"] = {
99-
"url": f"data:image/{self.image_format.value};base64,{context_value}",
100-
"format": self.image_format.value,
101-
}
102-
return context
116+
context_values = [raw_value]
117+
118+
# Build context list
119+
contexts = []
120+
for context_value in context_values:
121+
context = dict(type="image_url")
122+
if self.data_type == ModalityDataType.URL:
123+
context["image_url"] = context_value
124+
else:
125+
context["image_url"] = {
126+
"url": f"data:image/{self.image_format.value};base64,{context_value}",
127+
"format": self.image_format.value,
128+
}
129+
contexts.append(context)
130+
131+
return contexts
103132

104133
@model_validator(mode="after")
105134
def _validate_image_format(self) -> Self:

packages/data-designer-config/tests/config/test_models.py

Lines changed: 147 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import json
55
import tempfile
66
from collections import Counter
7+
from typing import TYPE_CHECKING
78

89
import pytest
910
import yaml
@@ -24,22 +25,159 @@
2425
UniformDistributionParams,
2526
load_model_configs,
2627
)
28+
from data_designer.lazy_heavy_imports import np
2729

30+
if TYPE_CHECKING:
31+
import numpy as np
2832

29-
def test_image_context_get_context():
33+
34+
def test_image_context_get_contexts_single_string():
35+
"""Test get_contexts with a single string value."""
3036
image_context = ImageContext(
3137
column_name="image_base64", data_type=ModalityDataType.BASE64, image_format=ImageFormat.PNG
3238
)
33-
assert image_context.get_context({"image_base64": "somebase64encodedimagestring"}) == {
34-
"type": "image_url",
35-
"image_url": {"url": "data:image/png;base64,somebase64encodedimagestring", "format": "png"},
36-
}
39+
assert image_context.get_contexts({"image_base64": "somebase64encodedimagestring"}) == [
40+
{
41+
"type": "image_url",
42+
"image_url": {"url": "data:image/png;base64,somebase64encodedimagestring", "format": "png"},
43+
}
44+
]
3745

3846
image_context = ImageContext(column_name="image_url", data_type=ModalityDataType.URL)
39-
assert image_context.get_context({"image_url": "https://example.com/examle_image.png"}) == {
40-
"type": "image_url",
41-
"image_url": "https://example.com/examle_image.png",
42-
}
47+
assert image_context.get_contexts({"image_url": "https://example.com/examle_image.png"}) == [
48+
{
49+
"type": "image_url",
50+
"image_url": "https://example.com/examle_image.png",
51+
}
52+
]
53+
54+
55+
def test_image_context_get_contexts_list_of_strings():
56+
"""Test get_contexts with a list of strings."""
57+
image_context = ImageContext(
58+
column_name="image_base64", data_type=ModalityDataType.BASE64, image_format=ImageFormat.PNG
59+
)
60+
assert image_context.get_contexts({"image_base64": ["image1base64", "image2base64", "image3base64"]}) == [
61+
{
62+
"type": "image_url",
63+
"image_url": {"url": "data:image/png;base64,image1base64", "format": "png"},
64+
},
65+
{
66+
"type": "image_url",
67+
"image_url": {"url": "data:image/png;base64,image2base64", "format": "png"},
68+
},
69+
{
70+
"type": "image_url",
71+
"image_url": {"url": "data:image/png;base64,image3base64", "format": "png"},
72+
},
73+
]
74+
75+
image_context = ImageContext(column_name="image_url", data_type=ModalityDataType.URL)
76+
assert image_context.get_contexts(
77+
{"image_url": ["https://example.com/image1.png", "https://example.com/image2.png"]}
78+
) == [
79+
{
80+
"type": "image_url",
81+
"image_url": "https://example.com/image1.png",
82+
},
83+
{
84+
"type": "image_url",
85+
"image_url": "https://example.com/image2.png",
86+
},
87+
]
88+
89+
90+
def test_image_context_get_contexts_numpy_array():
91+
"""Test get_contexts with numpy arrays (happens after parquet serialization)."""
92+
image_context = ImageContext(
93+
column_name="image_base64", data_type=ModalityDataType.BASE64, image_format=ImageFormat.PNG
94+
)
95+
numpy_array = np.array(["image1base64", "image2base64"])
96+
assert image_context.get_contexts({"image_base64": numpy_array}) == [
97+
{
98+
"type": "image_url",
99+
"image_url": {"url": "data:image/png;base64,image1base64", "format": "png"},
100+
},
101+
{
102+
"type": "image_url",
103+
"image_url": {"url": "data:image/png;base64,image2base64", "format": "png"},
104+
},
105+
]
106+
107+
image_context = ImageContext(column_name="image_url", data_type=ModalityDataType.URL)
108+
numpy_array = np.array(["https://example.com/image1.png", "https://example.com/image2.png"])
109+
assert image_context.get_contexts({"image_url": numpy_array}) == [
110+
{
111+
"type": "image_url",
112+
"image_url": "https://example.com/image1.png",
113+
},
114+
{
115+
"type": "image_url",
116+
"image_url": "https://example.com/image2.png",
117+
},
118+
]
119+
120+
121+
def test_image_context_get_contexts_json_serialized_list():
122+
"""Test get_contexts with a JSON serialized list of strings."""
123+
image_context = ImageContext(
124+
column_name="image_base64", data_type=ModalityDataType.BASE64, image_format=ImageFormat.PNG
125+
)
126+
json_str = json.dumps(["image1base64", "image2base64"])
127+
assert image_context.get_contexts({"image_base64": json_str}) == [
128+
{
129+
"type": "image_url",
130+
"image_url": {"url": "data:image/png;base64,image1base64", "format": "png"},
131+
},
132+
{
133+
"type": "image_url",
134+
"image_url": {"url": "data:image/png;base64,image2base64", "format": "png"},
135+
},
136+
]
137+
138+
image_context = ImageContext(column_name="image_url", data_type=ModalityDataType.URL)
139+
json_str = json.dumps(["https://example.com/image1.png", "https://example.com/image2.png"])
140+
assert image_context.get_contexts({"image_url": json_str}) == [
141+
{
142+
"type": "image_url",
143+
"image_url": "https://example.com/image1.png",
144+
},
145+
{
146+
"type": "image_url",
147+
"image_url": "https://example.com/image2.png",
148+
},
149+
]
150+
151+
152+
def test_image_context_get_contexts_json_string_not_list():
153+
"""Test get_contexts with a JSON string that isn't a list (should treat as single string)."""
154+
image_context = ImageContext(column_name="image_url", data_type=ModalityDataType.URL)
155+
json_str = json.dumps({"nested": "object"})
156+
# Should treat the entire JSON string as a single image URL
157+
assert image_context.get_contexts({"image_url": json_str}) == [
158+
{
159+
"type": "image_url",
160+
"image_url": json_str,
161+
}
162+
]
163+
164+
165+
def test_image_context_get_contexts_invalid_json():
166+
"""Test get_contexts with invalid JSON string (should treat as single string)."""
167+
image_context = ImageContext(column_name="image_url", data_type=ModalityDataType.URL)
168+
invalid_json = "not a valid json string"
169+
assert image_context.get_contexts({"image_url": invalid_json}) == [
170+
{
171+
"type": "image_url",
172+
"image_url": invalid_json,
173+
}
174+
]
175+
176+
177+
def test_image_context_get_contexts_empty_list():
178+
"""Test get_contexts with an empty list."""
179+
image_context = ImageContext(column_name="image_url", data_type=ModalityDataType.URL)
180+
assert image_context.get_contexts({"image_url": []}) == []
43181

44182

45183
def test_image_context_validate_image_format():

packages/data-designer-engine/src/data_designer/engine/column_generators/generators/llm_completion.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@ def generate(self, data: dict) -> dict:
6262

6363
multi_modal_context = None
6464
if self.config.multi_modal_context is not None and len(self.config.multi_modal_context) > 0:
65-
multi_modal_context = [
66-
context.get_context(deserialized_record) for context in self.config.multi_modal_context
67-
]
65+
multi_modal_context = []
66+
for context in self.config.multi_modal_context:
67+
multi_modal_context.extend(context.get_contexts(deserialized_record))
6868

6969
response, reasoning_trace = self.model.generate(
7070
prompt=self.prompt_renderer.render(

packages/data-designer-engine/src/data_designer/engine/models/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ def prompt_to_messages(
2121
user_content = user_prompt
2222
if multi_modal_context and len(multi_modal_context) > 0:
2323
user_content = []
24-
user_content.append({"type": "text", "text": user_prompt})
2524
for context in multi_modal_context:
2625
user_content.append(context)
26+
user_content.append({"type": "text", "text": user_prompt})
2727
return (
2828
[
2929
str_to_message(content=system_prompt, role="system"),

packages/data-designer-engine/tests/engine/models/test_model_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ def test_prompt_to_messages():
2626
{"content": "hello", "role": "user"},
2727
]
2828
assert prompt_to_messages(user_prompt="hello", multi_modal_context=[mult_modal_context]) == [
29-
{"content": [{"type": "text", "text": "hello"}, mult_modal_context], "role": "user"}
29+
{"content": [mult_modal_context, {"type": "text", "text": "hello"}], "role": "user"}
3030
]
3131
assert prompt_to_messages(
3232
user_prompt="hello", system_prompt=stub_system_prompt, multi_modal_context=[mult_modal_context]
3333
) == [
3434
{"content": stub_system_prompt, "role": "system"},
35-
{"content": [{"type": "text", "text": "hello"}, mult_modal_context], "role": "user"},
35+
{"content": [mult_modal_context, {"type": "text", "text": "hello"}], "role": "user"},
3636
]

0 commit comments

Comments
 (0)