Skip to content

Commit 35e6936

Browse files
anakin87sjrl
andauthored
feat: add ReasoningContent to ChatMessage (#9696)
* feat: add ReasoningContent to ChatMessage * more tests * release note * Update haystack/dataclasses/chat_message.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> --------- Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>
1 parent 683c935 commit 35e6936

7 files changed

Lines changed: 237 additions & 13 deletions

File tree

haystack/dataclasses/chat_message.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,48 @@ def from_dict(cls, data: dict[str, Any]) -> "TextContent":
145145
return TextContent(**data)
146146

147147

148-
ChatMessageContentT = Union[TextContent, ToolCall, ToolCallResult, ImageContent]
148+
@dataclass
149+
class ReasoningContent:
150+
"""
151+
Represents the optional reasoning content prepared by the model, usually contained in an assistant message.
152+
153+
:param reasoning_text: The reasoning text produced by the model.
154+
:param extra: Dictionary of extra information about the reasoning content. Use to store provider-specific
155+
information. To avoid serialization issues, values should be JSON serializable.
156+
"""
157+
158+
reasoning_text: str
159+
extra: dict[str, Any] = field(default_factory=dict)
160+
161+
def to_dict(self) -> dict[str, Any]:
162+
"""
163+
Convert ReasoningContent into a dictionary.
164+
165+
:returns: A dictionary with keys 'reasoning_text', and 'extra'.
166+
"""
167+
return asdict(self)
168+
169+
@classmethod
170+
def from_dict(cls, data: dict[str, Any]) -> "ReasoningContent":
171+
"""
172+
Creates a new ReasoningContent object from a dictionary.
173+
174+
:param data:
175+
The dictionary to build the ReasoningContent object.
176+
:returns:
177+
The created object.
178+
"""
179+
return ReasoningContent(**data)
180+
181+
182+
ChatMessageContentT = Union[TextContent, ToolCall, ToolCallResult, ImageContent, ReasoningContent]
149183

150184
_CONTENT_PART_CLASSES_TO_SERIALIZATION_KEYS: dict[type[ChatMessageContentT], str] = {
151185
TextContent: "text",
152186
ToolCall: "tool_call",
153187
ToolCallResult: "tool_call_result",
154188
ImageContent: "image",
189+
ReasoningContent: "reasoning",
155190
}
156191

157192

@@ -200,7 +235,7 @@ def _serialize_content_part(part: ChatMessageContentT) -> dict[str, Any]:
200235

201236

202237
@dataclass
203-
class ChatMessage:
238+
class ChatMessage: # pylint: disable=too-many-public-methods # it's OK since we expose several properties
204239
"""
205240
Represents a message in a LLM chat conversation.
206241
@@ -334,6 +369,22 @@ def image(self) -> Optional[ImageContent]:
334369
return images[0]
335370
return None
336371

372+
@property
373+
def reasonings(self) -> list[ReasoningContent]:
374+
"""
375+
Returns the list of all reasoning contents contained in the message.
376+
"""
377+
return [content for content in self._content if isinstance(content, ReasoningContent)]
378+
379+
@property
380+
def reasoning(self) -> Optional[ReasoningContent]:
381+
"""
382+
Returns the first reasoning content contained in the message.
383+
"""
384+
if reasonings := self.reasonings:
385+
return reasonings[0]
386+
return None
387+
337388
def is_from(self, role: Union[ChatRole, str]) -> bool:
338389
"""
339390
Check if the message is from a specific role.
@@ -406,17 +457,27 @@ def from_assistant(
406457
meta: Optional[dict[str, Any]] = None,
407458
name: Optional[str] = None,
408459
tool_calls: Optional[list[ToolCall]] = None,
460+
*,
461+
reasoning: Optional[Union[str, ReasoningContent]] = None,
409462
) -> "ChatMessage":
410463
"""
411464
Create a message from the assistant.
412465
413466
:param text: The text content of the message.
414467
:param meta: Additional metadata associated with the message.
415-
:param tool_calls: The Tool calls to include in the message.
416468
:param name: An optional name for the participant. This field is only supported by OpenAI.
469+
:param tool_calls: The Tool calls to include in the message.
470+
:param reasoning: The reasoning content to include in the message.
417471
:returns: A new ChatMessage instance.
418472
"""
419473
content: list[ChatMessageContentT] = []
474+
if reasoning:
475+
if isinstance(reasoning, str):
476+
content.append(ReasoningContent(reasoning_text=reasoning))
477+
elif isinstance(reasoning, ReasoningContent):
478+
content.append(reasoning)
479+
else:
480+
raise TypeError(f"reasoning must be a string or a ReasoningContent object, got {type(reasoning)}")
420481
if text is not None:
421482
content.append(TextContent(text=text))
422483
if tool_calls:
@@ -576,6 +637,7 @@ def to_openai_dict_format(self, require_tool_call_ids: bool = True) -> dict[str,
576637
return openai_msg
577638

578639
# system and assistant messages
640+
# OpenAI Chat Completions API does not support reasoning content, so we ignore it
579641
if text_contents:
580642
openai_msg["content"] = text_contents[0]
581643
if tool_calls:

haystack/utils/jinja2_chat_extension.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
ChatMessageContentT,
1515
ChatRole,
1616
ImageContent,
17+
ReasoningContent,
1718
TextContent,
1819
ToolCall,
1920
ToolCallResult,
@@ -236,14 +237,19 @@ def _validate_build_chat_message(
236237
if role == "assistant":
237238
texts = [part.text for part in parts if isinstance(part, TextContent)]
238239
tool_calls = [part for part in parts if isinstance(part, ToolCall)]
240+
reasoning = [part for part in parts if isinstance(part, ReasoningContent)]
239241
if len(texts) > 1:
240242
raise ValueError("Assistant message must contain one text part at most.")
241243
if len(texts) == 0 and len(tool_calls) == 0:
242244
raise ValueError("Assistant message must contain at least one text or tool call part.")
243-
if len(parts) > len(texts) + len(tool_calls):
244-
raise ValueError("Assistant message must contain only text or tool call parts.")
245+
if len(parts) > len(texts) + len(tool_calls) + len(reasoning):
246+
raise ValueError("Assistant message must contain only text, tool call or reasoning parts.")
245247
return ChatMessage.from_assistant(
246-
meta=meta, name=name, text=texts[0] if texts else None, tool_calls=tool_calls or None
248+
meta=meta,
249+
name=name,
250+
text=texts[0] if texts else None,
251+
tool_calls=tool_calls or None,
252+
reasoning=reasoning[0] if reasoning else None,
247253
)
248254

249255
if role == "tool":
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
---
2+
features:
3+
- |
4+
Added `ReasoningContent` as a new content part to the `ChatMessage` dataclass. This allows storing model
5+
reasoning text and additional metadata in assistant messages.
6+
Assistant messages can now include reasoning content using the `reasoning` parameter in
7+
`ChatMessage.from_assistant()`.
8+
We will progressively update the implementations for Chat Generators with LLMs that support reasoning to use this
9+
new content part.

test/components/builders/test_chat_prompt_builder.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from haystack import component
1313
from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder
1414
from haystack.core.pipeline.pipeline import Pipeline
15-
from haystack.dataclasses.chat_message import ChatMessage, ImageContent
15+
from haystack.dataclasses.chat_message import ChatMessage, ImageContent, ReasoningContent
1616
from haystack.dataclasses.document import Document
1717

1818

@@ -891,6 +891,26 @@ def test_run_multiple_images(self, base64_image_string):
891891
)
892892
]
893893

894+
def test_run_reasoning(self):
895+
template = """
896+
{% message role="user" %}
897+
Hello! I am {{user_name}}. How much is 2 + 2?
898+
{% endmessage %}
899+
900+
{% message role="assistant" %}
901+
{{ reasoning | templatize_part }}
902+
The answer is 4.
903+
{% endmessage %}
904+
"""
905+
builder = ChatPromptBuilder(template=template)
906+
reasoning = ReasoningContent(reasoning_text="Let me think about it...", extra={"key": "value"})
907+
result = builder.run(user_name="John", reasoning=reasoning)
908+
909+
assert result["prompt"] == [
910+
ChatMessage.from_user(text="Hello! I am John. How much is 2 + 2?"),
911+
ChatMessage.from_assistant(reasoning=reasoning, text="The answer is 4."),
912+
]
913+
894914
def test_to_dict(self):
895915
template = """
896916
{% message role="user" %}

test/dataclasses/test_chat_message.py

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,14 @@
66

77
import pytest
88

9-
from haystack.dataclasses.chat_message import ChatMessage, ChatRole, TextContent, ToolCall, ToolCallResult
9+
from haystack.dataclasses.chat_message import (
10+
ChatMessage,
11+
ChatRole,
12+
ReasoningContent,
13+
TextContent,
14+
ToolCall,
15+
ToolCallResult,
16+
)
1017
from haystack.dataclasses.image_content import ImageContent
1118

1219

@@ -80,9 +87,28 @@ def test_text_content_from_dict(self):
8087
tc = TextContent.from_dict({"text": "Hello"})
8188
assert tc.text == "Hello"
8289

90+
def test_reasoning_content_init(self):
91+
rc = ReasoningContent(reasoning_text="Let me think about it...")
92+
93+
assert rc.reasoning_text == "Let me think about it..."
94+
assert rc.extra == {}
95+
96+
rc = ReasoningContent(reasoning_text="Let me think about it...", extra={"key": "value"})
97+
assert rc.reasoning_text == "Let me think about it..."
98+
assert rc.extra == {"key": "value"}
99+
100+
def test_reasoning_content_to_dict(self):
101+
rc = ReasoningContent(reasoning_text="Let me think about it...", extra={"key": "value"})
102+
assert rc.to_dict() == {"reasoning_text": "Let me think about it...", "extra": {"key": "value"}}
103+
104+
def test_reasoning_content_from_dict(self):
105+
rc = ReasoningContent.from_dict({"reasoning_text": "Let me think about it...", "extra": {"key": "value"}})
106+
assert rc.reasoning_text == "Let me think about it..."
107+
assert rc.extra == {"key": "value"}
108+
83109

84110
class TestChatMessage:
85-
def test_from_assistant_with_valid_content(self):
111+
def test_from_assistant_with_text(self):
86112
text = "Hello, how can I assist you?"
87113
message = ChatMessage.from_assistant(text)
88114

@@ -99,6 +125,8 @@ def test_from_assistant_with_valid_content(self):
99125
assert not message.tool_call_result
100126
assert not message.images
101127
assert not message.image
128+
assert not message.reasonings
129+
assert not message.reasoning
102130

103131
def test_from_assistant_with_tool_calls(self):
104132
tool_calls = [
@@ -120,6 +148,53 @@ def test_from_assistant_with_tool_calls(self):
120148
assert not message.tool_call_result
121149
assert not message.images
122150
assert not message.image
151+
assert not message.reasoning
152+
assert not message.reasonings
153+
154+
def test_from_assistant_with_reasoning_object(self):
155+
reasoning = ReasoningContent(reasoning_text="Let me think about it...", extra={"key": "value"})
156+
text = "After thinking about it, I can say that the answer is 42."
157+
message = ChatMessage.from_assistant(text=text, reasoning=reasoning)
158+
159+
assert message.role == ChatRole.ASSISTANT
160+
assert message._content == [reasoning, TextContent(text=text)]
161+
162+
assert message.texts == [text]
163+
assert message.text == text
164+
assert message.reasoning == reasoning
165+
assert message.reasonings == [reasoning]
166+
167+
assert not message.tool_calls
168+
assert not message.tool_call
169+
assert not message.tool_call_results
170+
assert not message.tool_call_result
171+
assert not message.images
172+
assert not message.image
173+
174+
def test_from_assistant_with_reasoning_string(self):
175+
reasoning = "Let me think about it..."
176+
text = "After thinking about it, I can say that the answer is 42."
177+
message = ChatMessage.from_assistant(text=text, reasoning=reasoning)
178+
179+
expected_reasoning_content = ReasoningContent(reasoning_text=reasoning)
180+
assert message.role == ChatRole.ASSISTANT
181+
assert message._content == [expected_reasoning_content, TextContent(text=text)]
182+
183+
assert message.texts == [text]
184+
assert message.text == text
185+
assert message.reasoning == expected_reasoning_content
186+
assert message.reasonings == [expected_reasoning_content]
187+
188+
assert not message.tool_calls
189+
assert not message.tool_call
190+
assert not message.tool_call_results
191+
assert not message.tool_call_result
192+
assert not message.images
193+
assert not message.image
194+
195+
def test_from_assistant_with_invalid_reasoning(self):
196+
with pytest.raises(TypeError):
197+
ChatMessage.from_assistant(text="text", reasoning=123)
123198

124199
def test_from_user_with_valid_content(self):
125200
text = "I have a question."
@@ -138,6 +213,8 @@ def test_from_user_with_valid_content(self):
138213
assert not message.tool_call_result
139214
assert not message.images
140215
assert not message.image
216+
assert not message.reasonings
217+
assert not message.reasoning
141218

142219
def test_from_user_with_name(self):
143220
text = "I have a question."
@@ -207,6 +284,8 @@ def test_from_system_with_valid_content(self):
207284
assert not message.tool_call_result
208285
assert not message.images
209286
assert not message.image
287+
assert not message.reasonings
288+
assert not message.reasoning
210289

211290
def test_from_tool_with_valid_content(self):
212291
tool_result = "Tool result"
@@ -227,6 +306,8 @@ def test_from_tool_with_valid_content(self):
227306
assert not message.text
228307
assert not message.images
229308
assert not message.image
309+
assert not message.reasonings
310+
assert not message.reasoning
230311

231312
def test_multiple_text_segments(self):
232313
texts = [TextContent(text="Hello"), TextContent(text="World")]
@@ -266,10 +347,13 @@ def test_serde(self, base64_image_string):
266347
meta={"key": "value"},
267348
validation=True,
268349
)
350+
reasoning_content = ReasoningContent(reasoning_text="Let me think about it...", extra={"key": "value"})
269351
meta = {"some": "info"}
270352

271353
message = ChatMessage(
272-
_role=role, _content=[text_content, tool_call, tool_call_result, image_content], _meta=meta
354+
_role=role,
355+
_content=[text_content, tool_call, tool_call_result, image_content, reasoning_content],
356+
_meta=meta,
273357
)
274358

275359
serialized_message = message.to_dict()
@@ -293,6 +377,7 @@ def test_serde(self, base64_image_string):
293377
"validation": True,
294378
}
295379
},
380+
{"reasoning": {"reasoning_text": "Let me think about it...", "extra": {"key": "value"}}},
296381
],
297382
"role": "assistant",
298383
"name": None,

0 commit comments

Comments
 (0)