Skip to content

Commit ec37138

Browse files
authored
refactor: Update to StreamingChunk, better index setting and change tool_call to tool_calls (#9525)
* Fixes to setting StreamingChunk.index properly and refactoring tests for conversion * Make _convert_chat_completion_chunk_to_streaming_chunk a member of OpenAIChatGenerator so we can overwrite it in integrations that inherit from it * Fixes * Modify streaming chunk to accept a list of tool call deltas. * Fix tests * Fix mypy and update original reno * Undo change * Update conversion to return a single streaming chunk * update to print streaming chunk * Fix types * PR comments
1 parent f911459 commit ec37138

8 files changed

Lines changed: 620 additions & 393 deletions

File tree

haystack/components/generators/chat/openai.py

Lines changed: 54 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -427,12 +427,11 @@ def _handle_stream_response(self, chat_completion: Stream, callback: SyncStreami
427427
chunks: List[StreamingChunk] = []
428428
for chunk in chat_completion: # pylint: disable=not-an-iterable
429429
assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
430-
chunk_deltas = _convert_chat_completion_chunk_to_streaming_chunk(
430+
chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk(
431431
chunk=chunk, previous_chunks=chunks, component_info=component_info
432432
)
433-
for chunk_delta in chunk_deltas:
434-
chunks.append(chunk_delta)
435-
callback(chunk_delta)
433+
chunks.append(chunk_delta)
434+
callback(chunk_delta)
436435
return [_convert_streaming_chunks_to_chat_message(chunks=chunks)]
437436

438437
async def _handle_async_stream_response(
@@ -442,12 +441,11 @@ async def _handle_async_stream_response(
442441
chunks: List[StreamingChunk] = []
443442
async for chunk in chat_completion: # pylint: disable=not-an-iterable
444443
assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
445-
chunk_deltas = _convert_chat_completion_chunk_to_streaming_chunk(
444+
chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk(
446445
chunk=chunk, previous_chunks=chunks, component_info=component_info
447446
)
448-
for chunk_delta in chunk_deltas:
449-
chunks.append(chunk_delta)
450-
await callback(chunk_delta)
447+
chunks.append(chunk_delta)
448+
await callback(chunk_delta)
451449
return [_convert_streaming_chunks_to_chat_message(chunks=chunks)]
452450

453451

@@ -509,7 +507,7 @@ def _convert_chat_completion_to_chat_message(completion: ChatCompletion, choice:
509507

510508
def _convert_chat_completion_chunk_to_streaming_chunk(
511509
chunk: ChatCompletionChunk, previous_chunks: List[StreamingChunk], component_info: Optional[ComponentInfo] = None
512-
) -> List[StreamingChunk]:
510+
) -> StreamingChunk:
513511
"""
514512
Converts the streaming response chunk from the OpenAI API to a StreamingChunk.
515513
@@ -521,61 +519,68 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
521519
:returns:
522520
A StreamingChunk object representing the content of the chunk from the OpenAI API.
523521
"""
524-
# Choices is empty on the very first chunk which provides role information (e.g. "assistant").
525-
# It is also empty if include_usage is set to True where the usage information is returned.
522+
# On very first chunk so len(previous_chunks) == 0, the Choices field only provides role info (e.g. "assistant")
523+
# Choices is empty if include_usage is set to True where the usage information is returned.
526524
if len(chunk.choices) == 0:
527-
return [
528-
StreamingChunk(
529-
content="",
530-
component_info=component_info,
531-
# Index is None since it's only set to an int when a content block is present
532-
index=None,
533-
meta={
534-
"model": chunk.model,
535-
"received_at": datetime.now().isoformat(),
536-
"usage": _serialize_usage(chunk.usage),
537-
},
538-
)
539-
]
525+
return StreamingChunk(
526+
content="",
527+
component_info=component_info,
528+
# Index is None since it's only set to an int when a content block is present
529+
index=None,
530+
meta={
531+
"model": chunk.model,
532+
"received_at": datetime.now().isoformat(),
533+
"usage": _serialize_usage(chunk.usage),
534+
},
535+
)
540536

541537
choice: ChunkChoice = chunk.choices[0]
542-
content = choice.delta.content or ""
543538

544539
# create a list of ToolCallDelta objects from the tool calls
545540
if choice.delta.tool_calls:
546-
chunk_messages = []
541+
tool_calls_deltas = []
547542
for tool_call in choice.delta.tool_calls:
548543
function = tool_call.function
549-
chunk_message = StreamingChunk(
550-
content=content,
551-
# We adopt the tool_call.index as the index of the chunk
552-
component_info=component_info,
553-
index=tool_call.index,
554-
tool_call=ToolCallDelta(
544+
tool_calls_deltas.append(
545+
ToolCallDelta(
546+
index=tool_call.index,
555547
id=tool_call.id,
556548
tool_name=function.name if function else None,
557549
arguments=function.arguments if function and function.arguments else None,
558-
),
559-
start=function.name is not None if function else False,
560-
meta={
561-
"model": chunk.model,
562-
"index": choice.index,
563-
"tool_calls": choice.delta.tool_calls,
564-
"finish_reason": choice.finish_reason,
565-
"received_at": datetime.now().isoformat(),
566-
"usage": _serialize_usage(chunk.usage),
567-
},
550+
)
568551
)
569-
chunk_messages.append(chunk_message)
570-
return chunk_messages
552+
chunk_message = StreamingChunk(
553+
content=choice.delta.content or "",
554+
component_info=component_info,
555+
# We adopt the first tool_calls_deltas.index as the overall index of the chunk.
556+
index=tool_calls_deltas[0].index,
557+
tool_calls=tool_calls_deltas,
558+
start=tool_calls_deltas[0].tool_name is not None,
559+
meta={
560+
"model": chunk.model,
561+
"index": choice.index,
562+
"tool_calls": choice.delta.tool_calls,
563+
"finish_reason": choice.finish_reason,
564+
"received_at": datetime.now().isoformat(),
565+
"usage": _serialize_usage(chunk.usage),
566+
},
567+
)
568+
return chunk_message
571569

572-
chunk_message = StreamingChunk(
573-
content=content,
574-
component_info=component_info,
570+
# On very first chunk the choice field only provides role info (e.g. "assistant") so we set index to None
571+
# We set all chunks missing the content field to index of None. E.g. can happen if chunk only contains finish
572+
# reason.
573+
if choice.delta.content is None or choice.delta.role is not None:
574+
resolved_index = None
575+
else:
575576
# We set the index to be 0 since if text content is being streamed then no tool calls are being streamed
576577
# NOTE: We may need to revisit this if OpenAI allows planning/thinking content before tool calls like
577578
# Anthropic Claude
578-
index=0,
579+
resolved_index = 0
580+
chunk_message = StreamingChunk(
581+
content=choice.delta.content or "",
582+
component_info=component_info,
583+
index=resolved_index,
579584
# The first chunk is always a start message chunk that only contains role information, so if we reach here
580585
# and previous_chunks is length 1 then this is the start of text content.
581586
start=len(previous_chunks) == 1,
@@ -588,7 +593,7 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
588593
"usage": _serialize_usage(chunk.usage),
589594
},
590595
)
591-
return [chunk_message]
596+
return chunk_message
592597

593598

594599
def _serialize_usage(usage):

haystack/components/generators/openai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def run(
249249
chunk=chunk, # type: ignore
250250
previous_chunks=chunks,
251251
component_info=component_info,
252-
)[0]
252+
)
253253
chunks.append(chunk_delta)
254254
streaming_callback(chunk_delta)
255255

haystack/components/generators/utils.py

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,24 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None:
3131
print("\n\n", flush=True, end="")
3232

3333
## Tool Call streaming
34-
if chunk.tool_call:
35-
# If chunk.start is True indicates beginning of a tool call
36-
# Also presence of chunk.tool_call.name indicates the start of a tool call too
37-
if chunk.start:
38-
print("[TOOL CALL]\n", flush=True, end="")
39-
print(f"Tool: {chunk.tool_call.tool_name} ", flush=True, end="")
40-
print("\nArguments: ", flush=True, end="")
41-
42-
# print the tool arguments
43-
if chunk.tool_call.arguments:
44-
print(chunk.tool_call.arguments, flush=True, end="")
34+
if chunk.tool_calls:
35+
# Typically, if there are multiple tool calls in the chunk this means that the tool calls are fully formed and
36+
# not just a delta.
37+
for tool_call in chunk.tool_calls:
38+
# If chunk.start is True indicates beginning of a tool call
39+
# Also presence of tool_call.tool_name indicates the start of a tool call too
40+
if chunk.start:
41+
# If there is more than one tool call in the chunk, we print two new lines to separate them
42+
# We know there is more than one tool call if the index of the tool call is greater than the index of
43+
# the chunk.
44+
if chunk.index and tool_call.index > chunk.index:
45+
print("\n\n", flush=True, end="")
46+
47+
print(f"[TOOL CALL]\nTool: {tool_call.tool_name} \nArguments: ", flush=True, end="")
48+
49+
# print the tool arguments
50+
if tool_call.arguments:
51+
print(tool_call.arguments, flush=True, end="")
4552

4653
## Tool Call Result streaming
4754
# Print tool call results if available (from ToolInvoker)
@@ -76,39 +83,41 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> C
7683
# Process tool calls if present in any chunk
7784
tool_call_data: Dict[int, Dict[str, str]] = {} # Track tool calls by index
7885
for chunk in chunks:
79-
if chunk.tool_call:
86+
if chunk.tool_calls:
8087
# We do this to make sure mypy is happy, but we enforce index is not None in the StreamingChunk dataclass if
8188
# tool_calls is present
8289
assert chunk.index is not None
8390

84-
# We use the index of the chunk to track the tool call across chunks since the ID is not always provided
85-
if chunk.index not in tool_call_data:
86-
tool_call_data[chunk.index] = {"id": "", "name": "", "arguments": ""}
91+
for tool_call in chunk.tool_calls:
92+
# We use the index of the tool_call to track the tool call across chunks since the ID is not always
93+
# provided
94+
if tool_call.index not in tool_call_data:
95+
tool_call_data[tool_call.index] = {"id": "", "name": "", "arguments": ""}
8796

88-
# Save the ID if present
89-
if chunk.tool_call.id is not None:
90-
tool_call_data[chunk.index]["id"] = chunk.tool_call.id
97+
# Save the ID if present
98+
if tool_call.id is not None:
99+
tool_call_data[tool_call.index]["id"] = tool_call.id
91100

92-
if chunk.tool_call.tool_name is not None:
93-
tool_call_data[chunk.index]["name"] += chunk.tool_call.tool_name
94-
if chunk.tool_call.arguments is not None:
95-
tool_call_data[chunk.index]["arguments"] += chunk.tool_call.arguments
101+
if tool_call.tool_name is not None:
102+
tool_call_data[tool_call.index]["name"] += tool_call.tool_name
103+
if tool_call.arguments is not None:
104+
tool_call_data[tool_call.index]["arguments"] += tool_call.arguments
96105

97106
# Convert accumulated tool call data into ToolCall objects
98107
sorted_keys = sorted(tool_call_data.keys())
99108
for key in sorted_keys:
100-
tool_call = tool_call_data[key]
109+
tool_call_dict = tool_call_data[key]
101110
try:
102-
arguments = json.loads(tool_call["arguments"])
103-
tool_calls.append(ToolCall(id=tool_call["id"], tool_name=tool_call["name"], arguments=arguments))
111+
arguments = json.loads(tool_call_dict["arguments"])
112+
tool_calls.append(ToolCall(id=tool_call_dict["id"], tool_name=tool_call_dict["name"], arguments=arguments))
104113
except json.JSONDecodeError:
105114
logger.warning(
106115
"OpenAI returned a malformed JSON string for tool call arguments. This tool call "
107116
"will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
108117
"Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
109-
_id=tool_call["id"],
110-
_name=tool_call["name"],
111-
_arguments=tool_call["arguments"],
118+
_id=tool_call_dict["id"],
119+
_name=tool_call_dict["name"],
120+
_arguments=tool_call_dict["arguments"],
112121
)
113122

114123
# finish_reason can appear in different places so we look for the last one

haystack/dataclasses/streaming_chunk.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from dataclasses import dataclass, field
6-
from typing import Any, Awaitable, Callable, Dict, Literal, Optional, Union, overload
6+
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Union, overload
77

88
from haystack.core.component import Component
99
from haystack.dataclasses.chat_message import ToolCallResult
@@ -15,11 +15,13 @@ class ToolCallDelta:
1515
"""
1616
Represents a Tool call prepared by the model, usually contained in an assistant message.
1717
18+
:param index: The index of the Tool call in the list of Tool calls.
1819
:param tool_name: The name of the Tool to call.
1920
:param arguments: Either the full arguments in JSON format or a delta of the arguments.
2021
:param id: The ID of the Tool call.
2122
"""
2223

24+
index: int
2325
tool_name: Optional[str] = field(default=None)
2426
arguments: Optional[str] = field(default=None)
2527
id: Optional[str] = field(default=None) # noqa: A003
@@ -71,7 +73,8 @@ class StreamingChunk:
7173
:param component_info: A `ComponentInfo` object containing information about the component that generated the chunk,
7274
such as the component name and type.
7375
:param index: An optional integer index representing which content block this chunk belongs to.
74-
:param tool_call: An optional ToolCallDelta object representing a tool call associated with the message chunk.
76+
:param tool_calls: An optional list of ToolCallDelta objects representing tool calls associated with the message
77+
chunk.
7578
:param tool_call_result: An optional ToolCallResult object representing the result of a tool call.
7679
:param start: A boolean indicating whether this chunk marks the start of a content block.
7780
"""
@@ -80,21 +83,21 @@ class StreamingChunk:
8083
meta: Dict[str, Any] = field(default_factory=dict, hash=False)
8184
component_info: Optional[ComponentInfo] = field(default=None)
8285
index: Optional[int] = field(default=None)
83-
tool_call: Optional[ToolCallDelta] = field(default=None)
86+
tool_calls: Optional[List[ToolCallDelta]] = field(default=None)
8487
tool_call_result: Optional[ToolCallResult] = field(default=None)
8588
start: bool = field(default=False)
8689

8790
def __post_init__(self):
88-
fields_set = sum(bool(x) for x in (self.content, self.tool_call, self.tool_call_result))
91+
fields_set = sum(bool(x) for x in (self.content, self.tool_calls, self.tool_call_result))
8992
if fields_set > 1:
9093
raise ValueError(
9194
"Only one of `content`, `tool_calls`, or `tool_call_result` may be set in a StreamingChunk. "
92-
f"Got content: '{self.content}', tool_call: '{self.tool_call}', "
95+
f"Got content: '{self.content}', tool_calls: '{self.tool_calls}', "
9396
f"tool_call_result: '{self.tool_call_result}'"
9497
)
9598

9699
# NOTE: We don't enforce this for self.content otherwise it would be a breaking change
97-
if (self.tool_call or self.tool_call_result) and self.index is None:
100+
if (self.tool_calls or self.tool_call_result) and self.index is None:
98101
raise ValueError("If `tool_calls` or `tool_call_result` is set, `index` must also be set.")
99102

100103

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
---
22
features:
33
- |
4-
Updated StreamingChunk to add the fields `tool_call`, `tool_call_result`, `index`, and `start` to make it easier to format the stream in a streaming callback.
5-
- Added new dataclass ToolCallDelta for the `StreamingChunk.tool_call` field to reflect that the arguments can be a string delta.
4+
Updated StreamingChunk to add the fields `tool_calls`, `tool_call_result`, `index`, and `start` to make it easier to format the stream in a streaming callback.
5+
- Added new dataclass ToolCallDelta for the `StreamingChunk.tool_calls` field to reflect that the arguments can be a string delta.
66
- Updated `print_streaming_chunk` and `_convert_streaming_chunks_to_chat_message` utility methods to use these new fields. This especially improves the formatting when using `print_streaming_chunk` with Agent.
77
- Updated `OpenAIGenerator`, `OpenAIChatGenerator`, `HuggingFaceAPIGenerator`, `HuggingFaceAPIChatGenerator`, `HuggingFaceLocalGenerator` and `HuggingFaceLocalChatGenerator` to follow the new dataclasses.
88
- Updated `ToolInvoker` to follow the StreamingChunk dataclass.

0 commit comments

Comments
 (0)