Skip to content

Commit 67e0db0

Browse files
anakin87 authored and OGuggenbuehl committed
feat: add ImageContent dataclass to include images in ChatMessage + OpenAI support (deepset-ai#9626)
1 parent 682b3d9 commit 67e0db0

12 files changed

Lines changed: 574 additions & 47 deletions

File tree

docs/pydoc/config/data_classess_api.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ loaders:
22
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
33
search_path: [../../../haystack/dataclasses]
44
modules:
5-
["answer", "byte_stream", "chat_message", "document", "sparse_embedding", "streaming_chunk"]
5+
["answer", "byte_stream", "chat_message", "document", "image_content", "sparse_embedding", "streaming_chunk"]
66
ignore_when_discovered: ["__init__"]
77
processors:
88
- type: filter

docs/pydoc/config_docusaurus/data_classess_api.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ loaders:
22
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
33
search_path: [../../../haystack/dataclasses]
44
modules:
5-
["answer", "byte_stream", "chat_message", "document", "sparse_embedding", "streaming_chunk"]
5+
["answer", "byte_stream", "chat_message", "document", "image_content", "sparse_embedding", "streaming_chunk"]
66
ignore_when_discovered: ["__init__"]
77
processors:
88
- type: filter

haystack/dataclasses/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"answer": ["Answer", "ExtractedAnswer", "GeneratedAnswer"],
1212
"byte_stream": ["ByteStream"],
1313
"chat_message": ["ChatMessage", "ChatRole", "TextContent", "ToolCall", "ToolCallResult"],
14+
"image_content": ["ImageContent"],
1415
"document": ["Document"],
1516
"sparse_embedding": ["SparseEmbedding"],
1617
"state": ["State"],
@@ -37,6 +38,7 @@
3738
from .chat_message import ToolCall as ToolCall
3839
from .chat_message import ToolCallResult as ToolCallResult
3940
from .document import Document as Document
41+
from .image_content import ImageContent as ImageContent
4042
from .sparse_embedding import SparseEmbedding as SparseEmbedding
4143
from .streaming_chunk import AsyncStreamingCallbackT as AsyncStreamingCallbackT
4244
from .streaming_chunk import ComponentInfo as ComponentInfo

haystack/dataclasses/chat_message.py

Lines changed: 127 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Any, Dict, List, Optional, Sequence, Union
99

1010
from haystack import logging
11+
from haystack.dataclasses.image_content import ImageContent
1112

1213
logger = logging.getLogger(__name__)
1314

@@ -86,7 +87,33 @@ class TextContent:
8687
text: str
8788

8889

89-
ChatMessageContentT = Union[TextContent, ToolCall, ToolCallResult]
90+
ChatMessageContentT = Union[TextContent, ToolCall, ToolCallResult, ImageContent]
91+
92+
93+
def _deserialize_content_part(part: Dict[str, Any]) -> ChatMessageContentT:
    """
    Deserialize a single content part of a serialized ChatMessage.

    The kind of content is identified by the key found in the dictionary. Keys are checked in this order:
    `text`, `tool_call`, `tool_call_result`, `image`.

    :param part:
        A dictionary representing a single content part of a serialized ChatMessage.
    :returns:
        A ChatMessageContentT object.
    :raises ValueError:
        If the part is not a dictionary or does not contain any of the supported keys.
    """
    # A non-dict part (for example a plain string) could satisfy the `in` checks below through substring or
    # element matching and then fail on subscripting with an unrelated TypeError. Reject it up front with the
    # documented ValueError instead.
    if not isinstance(part, dict):
        raise ValueError(f"Unsupported content part in the serialized ChatMessage: `{part}`")

    if "text" in part:
        return TextContent(text=part["text"])
    if "tool_call" in part:
        return ToolCall(**part["tool_call"])
    if "tool_call_result" in part:
        serialized_result = part["tool_call_result"]
        # The originating tool call is stored as a nested dictionary and must be rebuilt first-class.
        return ToolCallResult(
            result=serialized_result["result"],
            origin=ToolCall(**serialized_result["origin"]),
            error=serialized_result["error"],
        )
    if "image" in part:
        return ImageContent(**part["image"])
    raise ValueError(f"Unsupported content part in the serialized ChatMessage: `{part}`")
90117

91118

92119
def _deserialize_content(serialized_content: List[Dict[str, Any]]) -> List[ChatMessageContentT]:
@@ -99,30 +126,29 @@ def _deserialize_content(serialized_content: List[Dict[str, Any]]) -> List[ChatM
99126
:returns:
100127
Deserialized `content` field as a list of `ChatMessageContentT` objects.
101128
"""
102-
content: List[ChatMessageContentT] = []
103-
104-
for part in serialized_content:
105-
if "text" in part:
106-
content.append(TextContent(text=part["text"]))
107-
elif "tool_call" in part:
108-
content.append(ToolCall(**part["tool_call"]))
109-
elif "tool_call_result" in part:
110-
result = part["tool_call_result"]["result"]
111-
origin = ToolCall(**part["tool_call_result"]["origin"])
112-
error = part["tool_call_result"]["error"]
113-
tcr = ToolCallResult(result=result, origin=origin, error=error)
114-
content.append(tcr)
115-
else:
116-
raise ValueError(
117-
f"Unsupported content part in the serialized ChatMessage: {part}. "
118-
"The `content` field of the serialized ChatMessage must be a list of dictionaries, where each "
119-
"dictionary contains one of these keys: 'text', 'tool_call', or 'tool_call_result'. "
120-
f"Valid formats: [{{'text': 'Hello'}}, "
121-
f"{{'tool_call': {{'tool_name': 'search', 'arguments': {{}}, 'id': 'call_123'}}}}, "
122-
f"{{'tool_call_result': {{'result': 'data', 'origin': {{...}}, 'error': false}}}}]"
123-
)
129+
return [_deserialize_content_part(part) for part in serialized_content]
124130

125-
return content
131+
132+
def _serialize_content_part(part: ChatMessageContentT) -> Dict[str, Any]:
    """
    Convert one content part of a ChatMessage into its dictionary form.

    :param part:
        A ChatMessageContentT object.
    :returns:
        A single-key dictionary whose key identifies the kind of content
        (`text`, `tool_call`, `tool_call_result` or `image`).
    :raises TypeError:
        If the part is not a valid ChatMessageContentT object.
    """
    # Text is stored as the bare string rather than as a dump of the dataclass.
    if isinstance(part, TextContent):
        return {"text": part.text}

    # The remaining kinds are dumped with `asdict`; the types are tested in a fixed order.
    dataclass_kinds = ((ToolCall, "tool_call"), (ToolCallResult, "tool_call_result"), (ImageContent, "image"))
    for content_type, key in dataclass_kinds:
        if isinstance(part, content_type):
            return {key: asdict(part)}

    raise TypeError(f"Unsupported type in ChatMessage content: `{type(part).__name__}` for `{part}`.")
126152

127153

128154
@dataclass
@@ -252,6 +278,22 @@ def tool_call_result(self) -> Optional[ToolCallResult]:
252278
return tool_call_results[0]
253279
return None
254280

281+
@property
282+
def images(self) -> List[ImageContent]:
    """
    Returns every `ImageContent` part of the message, in the order in which they appear in the content.
    """
    found: List[ImageContent] = []
    for part in self._content:
        if isinstance(part, ImageContent):
            found.append(part)
    return found
287+
288+
@property
289+
def image(self) -> Optional[ImageContent]:
    """
    Returns the first image contained in the message, or None if the message has no images.
    """
    all_images = self.images
    return all_images[0] if all_images else None
296+
255297
def is_from(self, role: Union[ChatRole, str]) -> bool:
256298
"""
257299
Check if the message is from a specific role.
@@ -264,16 +306,44 @@ def is_from(self, role: Union[ChatRole, str]) -> bool:
264306
return self._role == role
265307

266308
@classmethod
267-
def from_user(cls, text: str, meta: Optional[Dict[str, Any]] = None, name: Optional[str] = None) -> "ChatMessage":
309+
def from_user(
    cls,
    text: Optional[str] = None,
    meta: Optional[Dict[str, Any]] = None,
    name: Optional[str] = None,
    *,
    content_parts: Optional[Sequence[Union[TextContent, str, ImageContent]]] = None,
) -> "ChatMessage":
    """
    Create a message from the user.

    :param text: The text content of the message. Specify this or content_parts.
    :param meta: Additional metadata associated with the message.
    :param name: An optional name for the participant. This field is only supported by OpenAI.
    :param content_parts: A list of content parts to include in the message. Plain strings are wrapped in
        `TextContent`. Specify this or text.
    :returns: A new ChatMessage instance.
    :raises ValueError: If neither or both of text and content_parts are provided, if content_parts contains
        no textual part, or if content_parts contains anything other than text or image parts.
    """
    # Compare `text` with None instead of relying on truthiness: an empty string is a legitimate text
    # (it was accepted when `text` was the only way to build a user message), and it must not let
    # `content_parts` slip past the mutual-exclusion check and be silently discarded.
    if text is None and not content_parts:
        raise ValueError("Either text or content_parts must be provided.")
    if text is not None and content_parts:
        raise ValueError("Only one of text or content_parts can be provided.")

    content: List[Union[TextContent, ImageContent]]

    if text is not None:
        content = [TextContent(text=text)]
    else:
        # The first check guarantees content_parts is non-empty here; `or []` only helps type narrowing.
        content = [TextContent(text=el) if isinstance(el, str) else el for el in content_parts or []]

        if not any(isinstance(el, TextContent) for el in content):
            raise ValueError("The user message must contain at least one textual part.")

        unsupported_parts = [el for el in content if not isinstance(el, (ImageContent, TextContent))]
        if unsupported_parts:
            raise ValueError(
                f"The user message must contain only text or image parts. Unsupported parts: {unsupported_parts}"
            )

    return cls(_role=ChatRole.USER, _content=content, _meta=meta or {}, _name=name)
277347

278348
@classmethod
279349
def from_system(cls, text: str, meta: Optional[Dict[str, Any]] = None, name: Optional[str] = None) -> "ChatMessage":
@@ -343,18 +413,8 @@ def to_dict(self) -> Dict[str, Any]:
343413
serialized["role"] = self._role.value
344414
serialized["meta"] = self._meta
345415
serialized["name"] = self._name
346-
content: List[Dict[str, Any]] = []
347-
for part in self._content:
348-
if isinstance(part, TextContent):
349-
content.append({"text": part.text})
350-
elif isinstance(part, ToolCall):
351-
content.append({"tool_call": asdict(part)})
352-
elif isinstance(part, ToolCallResult):
353-
content.append({"tool_call_result": asdict(part)})
354-
else:
355-
raise TypeError(f"Unsupported type in ChatMessage content: `{type(part).__name__}` for `{part}`.")
356416

357-
serialized["content"] = content
417+
serialized["content"] = [_serialize_content_part(part) for part in self._content]
358418
return serialized
359419

360420
@classmethod
@@ -426,30 +486,56 @@ def to_openai_dict_format(self, require_tool_call_ids: bool = True) -> Dict[str,
426486
"A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`."
427487
)
428488
if len(text_contents) + len(tool_call_results) > 1:
429-
raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.")
489+
raise ValueError(
490+
"For OpenAI compatibility, a `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`."
491+
)
430492

431493
openai_msg: Dict[str, Any] = {"role": self._role.value}
432494

433495
# Add name field if present
434496
if self._name is not None:
435497
openai_msg["name"] = self._name
436498

499+
# user message
500+
if openai_msg["role"] == "user":
501+
if len(self._content) == 1:
502+
openai_msg["content"] = self.text
503+
return openai_msg
504+
505+
# if the user message contains a list of text and images, OpenAI expects a list of dictionaries
506+
content = []
507+
for part in self._content:
508+
if isinstance(part, TextContent):
509+
content.append({"type": "text", "text": part.text})
510+
elif isinstance(part, ImageContent):
511+
image_item: Dict[str, Any] = {
512+
"type": "image_url",
513+
# If no MIME type is provided, default to JPEG.
514+
# OpenAI API appears to tolerate MIME type mismatches.
515+
"image_url": {"url": f"data:{part.mime_type or 'image/jpeg'};base64,{part.base64_image}"},
516+
}
517+
if part.detail:
518+
image_item["image_url"]["detail"] = part.detail
519+
content.append(image_item)
520+
openai_msg["content"] = content
521+
return openai_msg
522+
523+
# tool message
437524
if tool_call_results:
438525
result = tool_call_results[0]
439526
openai_msg["content"] = result.result
440527
if result.origin.id is not None:
441528
openai_msg["tool_call_id"] = result.origin.id
442529
elif require_tool_call_ids:
443530
raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with OpenAI.")
444-
445531
# OpenAI does not provide a way to communicate errors in tool invocations, so we ignore the error field
446532
return openai_msg
447533

534+
# system and assistant messages
448535
if text_contents:
449536
openai_msg["content"] = text_contents[0]
450537
if tool_calls:
451538
openai_tool_calls = []
452-
453539
for tc in tool_calls:
454540
openai_tool_call = {
455541
"type": "function",
@@ -461,7 +547,6 @@ def to_openai_dict_format(self, require_tool_call_ids: bool = True) -> Dict[str,
461547
elif require_tool_call_ids:
462548
raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with OpenAI.")
463549
openai_tool_calls.append(openai_tool_call)
464-
465550
openai_msg["tool_calls"] = openai_tool_calls
466551
return openai_msg
467552

0 commit comments

Comments
 (0)