diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index fda06e6ebd..dd3d7cee25 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -427,12 +427,11 @@ def _handle_stream_response(self, chat_completion: Stream, callback: SyncStreami chunks: List[StreamingChunk] = [] for chunk in chat_completion: # pylint: disable=not-an-iterable assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice." - chunk_deltas = _convert_chat_completion_chunk_to_streaming_chunk( + chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk( chunk=chunk, previous_chunks=chunks, component_info=component_info ) - for chunk_delta in chunk_deltas: - chunks.append(chunk_delta) - callback(chunk_delta) + chunks.append(chunk_delta) + callback(chunk_delta) return [_convert_streaming_chunks_to_chat_message(chunks=chunks)] async def _handle_async_stream_response( @@ -442,12 +441,11 @@ async def _handle_async_stream_response( chunks: List[StreamingChunk] = [] async for chunk in chat_completion: # pylint: disable=not-an-iterable assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice." 
- chunk_deltas = _convert_chat_completion_chunk_to_streaming_chunk( + chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk( chunk=chunk, previous_chunks=chunks, component_info=component_info ) - for chunk_delta in chunk_deltas: - chunks.append(chunk_delta) - await callback(chunk_delta) + chunks.append(chunk_delta) + await callback(chunk_delta) return [_convert_streaming_chunks_to_chat_message(chunks=chunks)] @@ -509,7 +507,7 @@ def _convert_chat_completion_to_chat_message(completion: ChatCompletion, choice: def _convert_chat_completion_chunk_to_streaming_chunk( chunk: ChatCompletionChunk, previous_chunks: List[StreamingChunk], component_info: Optional[ComponentInfo] = None -) -> List[StreamingChunk]: +) -> StreamingChunk: """ Converts the streaming response chunk from the OpenAI API to a StreamingChunk. @@ -521,61 +519,68 @@ def _convert_chat_completion_chunk_to_streaming_chunk( :returns: A list of StreamingChunk objects representing the content of the chunk from the OpenAI API. """ - # Choices is empty on the very first chunk which provides role information (e.g. "assistant"). - # It is also empty if include_usage is set to True where the usage information is returned. + # On very first chunk so len(previous_chunks) == 0, the Choices field only provides role info (e.g. "assistant") + # Choices is empty if include_usage is set to True where the usage information is returned. 
if len(chunk.choices) == 0: - return [ - StreamingChunk( - content="", - component_info=component_info, - # Index is None since it's only set to an int when a content block is present - index=None, - meta={ - "model": chunk.model, - "received_at": datetime.now().isoformat(), - "usage": _serialize_usage(chunk.usage), - }, - ) - ] + return StreamingChunk( + content="", + component_info=component_info, + # Index is None since it's only set to an int when a content block is present + index=None, + meta={ + "model": chunk.model, + "received_at": datetime.now().isoformat(), + "usage": _serialize_usage(chunk.usage), + }, + ) choice: ChunkChoice = chunk.choices[0] - content = choice.delta.content or "" # create a list of ToolCallDelta objects from the tool calls if choice.delta.tool_calls: - chunk_messages = [] + tool_calls_deltas = [] for tool_call in choice.delta.tool_calls: function = tool_call.function - chunk_message = StreamingChunk( - content=content, - # We adopt the tool_call.index as the index of the chunk - component_info=component_info, - index=tool_call.index, - tool_call=ToolCallDelta( + tool_calls_deltas.append( + ToolCallDelta( + index=tool_call.index, id=tool_call.id, tool_name=function.name if function else None, arguments=function.arguments if function and function.arguments else None, - ), - start=function.name is not None if function else False, - meta={ - "model": chunk.model, - "index": choice.index, - "tool_calls": choice.delta.tool_calls, - "finish_reason": choice.finish_reason, - "received_at": datetime.now().isoformat(), - "usage": _serialize_usage(chunk.usage), - }, + ) ) - chunk_messages.append(chunk_message) - return chunk_messages + chunk_message = StreamingChunk( + content=choice.delta.content or "", + component_info=component_info, + # We adopt the first tool_calls_deltas.index as the overall index of the chunk. 
+ index=tool_calls_deltas[0].index, + tool_calls=tool_calls_deltas, + start=tool_calls_deltas[0].tool_name is not None, + meta={ + "model": chunk.model, + "index": choice.index, + "tool_calls": choice.delta.tool_calls, + "finish_reason": choice.finish_reason, + "received_at": datetime.now().isoformat(), + "usage": _serialize_usage(chunk.usage), + }, + ) + return chunk_message - chunk_message = StreamingChunk( - content=content, - component_info=component_info, + # On very first chunk the choice field only provides role info (e.g. "assistant") so we set index to None + # We set all chunks missing the content field to index of None. E.g. can happen if chunk only contains finish + # reason. + if choice.delta.content is None or choice.delta.role is not None: + resolved_index = None + else: # We set the index to be 0 since if text content is being streamed then no tool calls are being streamed # NOTE: We may need to revisit this if OpenAI allows planning/thinking content before tool calls like # Anthropic Claude - index=0, + resolved_index = 0 + chunk_message = StreamingChunk( + content=choice.delta.content or "", + component_info=component_info, + index=resolved_index, # The first chunk is always a start message chunk that only contains role information, so if we reach here # and previous_chunks is length 1 then this is the start of text content. 
start=len(previous_chunks) == 1, @@ -588,7 +593,7 @@ def _convert_chat_completion_chunk_to_streaming_chunk( "usage": _serialize_usage(chunk.usage), }, ) - return [chunk_message] + return chunk_message def _serialize_usage(usage): diff --git a/haystack/components/generators/openai.py b/haystack/components/generators/openai.py index 40e36a9ad3..11903d5868 100644 --- a/haystack/components/generators/openai.py +++ b/haystack/components/generators/openai.py @@ -249,7 +249,7 @@ def run( chunk=chunk, # type: ignore previous_chunks=chunks, component_info=component_info, - )[0] + ) chunks.append(chunk_delta) streaming_callback(chunk_delta) diff --git a/haystack/components/generators/utils.py b/haystack/components/generators/utils.py index e66b13786e..d8c5bd678a 100644 --- a/haystack/components/generators/utils.py +++ b/haystack/components/generators/utils.py @@ -31,17 +31,24 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None: print("\n\n", flush=True, end="") ## Tool Call streaming - if chunk.tool_call: - # If chunk.start is True indicates beginning of a tool call - # Also presence of chunk.tool_call.name indicates the start of a tool call too - if chunk.start: - print("[TOOL CALL]\n", flush=True, end="") - print(f"Tool: {chunk.tool_call.tool_name} ", flush=True, end="") - print("\nArguments: ", flush=True, end="") - - # print the tool arguments - if chunk.tool_call.arguments: - print(chunk.tool_call.arguments, flush=True, end="") + if chunk.tool_calls: + # Typically, if there are multiple tool calls in the chunk this means that the tool calls are fully formed and + # not just a delta. 
+            for tool_call in chunk.tool_calls:
+                # If chunk.start is True indicates beginning of a tool call
+                # Also presence of tool_call.tool_name indicates the start of a tool call too
+                if chunk.start:
+                    # If there is more than one tool call in the chunk, we print two new lines to separate them
+                    # We know there is more than one tool call if the index of the tool call is greater than the index of
+                    # the chunk.
+                    if chunk.index is not None and tool_call.index > chunk.index:
+                        print("\n\n", flush=True, end="")
+
+                    print(f"[TOOL CALL]\nTool: {tool_call.tool_name} \nArguments: ", flush=True, end="")
+
+                # print the tool arguments
+                if tool_call.arguments:
+                    print(tool_call.arguments, flush=True, end="")
 
     ## Tool Call Result streaming
     # Print tool call results if available (from ToolInvoker)
@@ -76,39 +83,41 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> C
     # Process tool calls if present in any chunk
     tool_call_data: Dict[int, Dict[str, str]] = {}  # Track tool calls by index
     for chunk in chunks:
-        if chunk.tool_call:
+        if chunk.tool_calls:
             # We do this to make sure mypy is happy, but we enforce index is not None in the StreamingChunk dataclass if
             # tool_call is present
             assert chunk.index is not None
 
-            # We use the index of the chunk to track the tool call across chunks since the ID is not always provided
-            if chunk.index not in tool_call_data:
-                tool_call_data[chunk.index] = {"id": "", "name": "", "arguments": ""}
+            for tool_call in chunk.tool_calls:
+                # We use the index of the tool_call to track the tool call across chunks since the ID is not always
+                # provided
+                if tool_call.index not in tool_call_data:
+                    tool_call_data[tool_call.index] = {"id": "", "name": "", "arguments": ""}
 
-            # Save the ID if present
-            if chunk.tool_call.id is not None:
-                tool_call_data[chunk.index]["id"] = chunk.tool_call.id
+                # Save the ID if present
+                if tool_call.id is not None:
+                    tool_call_data[tool_call.index]["id"] = tool_call.id
 
-            if chunk.tool_call.tool_name is not None:
-                tool_call_data[chunk.index]["name"] += chunk.tool_call.tool_name
-            if chunk.tool_call.arguments is not None:
-                tool_call_data[chunk.index]["arguments"] += chunk.tool_call.arguments
+                if tool_call.tool_name is not None:
+                    tool_call_data[tool_call.index]["name"] += tool_call.tool_name
+                if tool_call.arguments is not None:
+                    tool_call_data[tool_call.index]["arguments"] += tool_call.arguments
 
     # Convert accumulated tool call data into ToolCall objects
     sorted_keys = sorted(tool_call_data.keys())
     for key in sorted_keys:
-        tool_call = tool_call_data[key]
+        tool_call_dict = tool_call_data[key]
         try:
-            arguments = json.loads(tool_call["arguments"])
-            tool_calls.append(ToolCall(id=tool_call["id"], tool_name=tool_call["name"], arguments=arguments))
+            arguments = json.loads(tool_call_dict["arguments"])
+            tool_calls.append(ToolCall(id=tool_call_dict["id"], tool_name=tool_call_dict["name"], arguments=arguments))
         except json.JSONDecodeError:
             logger.warning(
                 "OpenAI returned a malformed JSON string for tool call arguments. This tool call "
                 "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. 
" "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}", - _id=tool_call["id"], - _name=tool_call["name"], - _arguments=tool_call["arguments"], + _id=tool_call_dict["id"], + _name=tool_call_dict["name"], + _arguments=tool_call_dict["arguments"], ) # finish_reason can appear in different places so we look for the last one diff --git a/haystack/dataclasses/streaming_chunk.py b/haystack/dataclasses/streaming_chunk.py index d01a61bc05..41abc6c03c 100644 --- a/haystack/dataclasses/streaming_chunk.py +++ b/haystack/dataclasses/streaming_chunk.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass, field -from typing import Any, Awaitable, Callable, Dict, Literal, Optional, Union, overload +from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Union, overload from haystack.core.component import Component from haystack.dataclasses.chat_message import ToolCallResult @@ -15,11 +15,13 @@ class ToolCallDelta: """ Represents a Tool call prepared by the model, usually contained in an assistant message. + :param index: The index of the Tool call in the list of Tool calls. :param tool_name: The name of the Tool to call. :param arguments: Either the full arguments in JSON format or a delta of the arguments. :param id: The ID of the Tool call. """ + index: int tool_name: Optional[str] = field(default=None) arguments: Optional[str] = field(default=None) id: Optional[str] = field(default=None) # noqa: A003 @@ -71,7 +73,8 @@ class StreamingChunk: :param component_info: A `ComponentInfo` object containing information about the component that generated the chunk, such as the component name and type. :param index: An optional integer index representing which content block this chunk belongs to. - :param tool_call: An optional ToolCallDelta object representing a tool call associated with the message chunk. 
+ :param tool_calls: An optional list of ToolCallDelta object representing a tool call associated with the message + chunk. :param tool_call_result: An optional ToolCallResult object representing the result of a tool call. :param start: A boolean indicating whether this chunk marks the start of a content block. """ @@ -80,21 +83,21 @@ class StreamingChunk: meta: Dict[str, Any] = field(default_factory=dict, hash=False) component_info: Optional[ComponentInfo] = field(default=None) index: Optional[int] = field(default=None) - tool_call: Optional[ToolCallDelta] = field(default=None) + tool_calls: Optional[List[ToolCallDelta]] = field(default=None) tool_call_result: Optional[ToolCallResult] = field(default=None) start: bool = field(default=False) def __post_init__(self): - fields_set = sum(bool(x) for x in (self.content, self.tool_call, self.tool_call_result)) + fields_set = sum(bool(x) for x in (self.content, self.tool_calls, self.tool_call_result)) if fields_set > 1: raise ValueError( "Only one of `content`, `tool_call`, or `tool_call_result` may be set in a StreamingChunk. 
" - f"Got content: '{self.content}', tool_call: '{self.tool_call}', " + f"Got content: '{self.content}', tool_call: '{self.tool_calls}', " f"tool_call_result: '{self.tool_call_result}'" ) # NOTE: We don't enforce this for self.content otherwise it would be a breaking change - if (self.tool_call or self.tool_call_result) and self.index is None: + if (self.tool_calls or self.tool_call_result) and self.index is None: raise ValueError("If `tool_call`, or `tool_call_result` is set, `index` must also be set.") diff --git a/releasenotes/notes/update-streaming-chunk-more-info-9008e05b21eef349.yaml b/releasenotes/notes/update-streaming-chunk-more-info-9008e05b21eef349.yaml index d4c7b3cd73..a15d78728f 100644 --- a/releasenotes/notes/update-streaming-chunk-more-info-9008e05b21eef349.yaml +++ b/releasenotes/notes/update-streaming-chunk-more-info-9008e05b21eef349.yaml @@ -1,8 +1,8 @@ --- features: - | - Updated StreamingChunk to add the fields `tool_call`, `tool_call_result`, `index`, and `start` to make it easier to format the stream in a streaming callback. - - Added new dataclass ToolCallDelta for the `StreamingChunk.tool_call` field to reflect that the arguments can be a string delta. + Updated StreamingChunk to add the fields `tool_calls`, `tool_call_result`, `index`, and `start` to make it easier to format the stream in a streaming callback. + - Added new dataclass ToolCallDelta for the `StreamingChunk.tool_calls` field to reflect that the arguments can be a string delta. - Updated `print_streaming_chunk` and `_convert_streaming_chunks_to_chat_message` utility methods to use these new fields. This especially improves the formatting when using `print_streaming_chunk` with Agent. - Updated `OpenAIGenerator`, `OpenAIChatGenerator`, `HuggingFaceAPIGenerator`, `HuggingFaceAPIChatGenerator`, `HuggingFaceLocalGenerator` and `HuggingFaceLocalChatGenerator` to follow the new dataclasses. - Updated `ToolInvoker` to follow the StreamingChunk dataclass. 
diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index 610c26bbcd..fd4c227e1e 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from unittest.mock import patch, MagicMock +from unittest.mock import patch, ANY, MagicMock import pytest @@ -21,7 +21,7 @@ from haystack import component from haystack.components.generators.utils import print_streaming_chunk -from haystack.dataclasses import StreamingChunk +from haystack.dataclasses import StreamingChunk, ToolCallDelta from haystack.utils.auth import Secret from haystack.dataclasses import ChatMessage, ToolCall from haystack.tools import ComponentTool, Tool @@ -598,295 +598,6 @@ def test_invalid_tool_call_json(self, tools, caplog): assert message.meta["finish_reason"] == "tool_calls" assert message.meta["usage"]["completion_tokens"] == 47 - def test_handle_stream_response(self): - openai_chunks = [ - ChatCompletionChunk( - id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", - choices=[chat_completion_chunk.Choice(delta=ChoiceDelta(role="assistant"), index=0)], - created=1747834733, - model="gpt-4o-mini-2024-07-18", - object="chat.completion.chunk", - service_tier="default", - system_fingerprint="fp_54eb4bd693", - ), - ChatCompletionChunk( - id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", - choices=[ - chat_completion_chunk.Choice( - delta=ChoiceDelta( - tool_calls=[ - ChoiceDeltaToolCall( - index=0, - id="call_zcvlnVaTeJWRjLAFfYxX69z4", - function=ChoiceDeltaToolCallFunction(arguments="", name="weather"), - type="function", - ) - ] - ), - index=0, - ) - ], - created=1747834733, - model="gpt-4o-mini-2024-07-18", - object="chat.completion.chunk", - service_tier="default", - system_fingerprint="fp_54eb4bd693", - ), - ChatCompletionChunk( - id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", - choices=[ - chat_completion_chunk.Choice( - delta=ChoiceDelta( - 
tool_calls=[ - ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='{"ci')) - ] - ), - index=0, - ) - ], - created=1747834733, - model="gpt-4o-mini-2024-07-18", - object="chat.completion.chunk", - service_tier="default", - system_fingerprint="fp_54eb4bd693", - ), - ChatCompletionChunk( - id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", - choices=[ - chat_completion_chunk.Choice( - delta=ChoiceDelta( - tool_calls=[ - ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='ty": ')) - ] - ), - index=0, - ) - ], - created=1747834733, - model="gpt-4o-mini-2024-07-18", - object="chat.completion.chunk", - service_tier="default", - system_fingerprint="fp_54eb4bd693", - ), - ChatCompletionChunk( - id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", - choices=[ - chat_completion_chunk.Choice( - delta=ChoiceDelta( - tool_calls=[ - ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='"Paris')) - ] - ), - index=0, - ) - ], - created=1747834733, - model="gpt-4o-mini-2024-07-18", - object="chat.completion.chunk", - service_tier="default", - system_fingerprint="fp_54eb4bd693", - ), - ChatCompletionChunk( - id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", - choices=[ - chat_completion_chunk.Choice( - delta=ChoiceDelta( - tool_calls=[ - ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='"}')) - ] - ), - index=0, - ) - ], - created=1747834733, - model="gpt-4o-mini-2024-07-18", - object="chat.completion.chunk", - service_tier="default", - system_fingerprint="fp_54eb4bd693", - ), - ChatCompletionChunk( - id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", - choices=[ - chat_completion_chunk.Choice( - delta=ChoiceDelta( - tool_calls=[ - ChoiceDeltaToolCall( - index=1, - id="call_C88m67V16CrETq6jbNXjdZI9", - function=ChoiceDeltaToolCallFunction(arguments="", name="weather"), - type="function", - ) - ] - ), - index=0, - ) - ], - created=1747834733, - model="gpt-4o-mini-2024-07-18", - object="chat.completion.chunk", - 
service_tier="default", - system_fingerprint="fp_54eb4bd693", - ), - ChatCompletionChunk( - id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", - choices=[ - chat_completion_chunk.Choice( - delta=ChoiceDelta( - tool_calls=[ - ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='{"ci')) - ] - ), - index=0, - ) - ], - created=1747834733, - model="gpt-4o-mini-2024-07-18", - object="chat.completion.chunk", - service_tier="default", - system_fingerprint="fp_54eb4bd693", - ), - ChatCompletionChunk( - id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", - choices=[ - chat_completion_chunk.Choice( - delta=ChoiceDelta( - tool_calls=[ - ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='ty": ')) - ] - ), - index=0, - ) - ], - created=1747834733, - model="gpt-4o-mini-2024-07-18", - object="chat.completion.chunk", - service_tier="default", - system_fingerprint="fp_54eb4bd693", - ), - ChatCompletionChunk( - id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", - choices=[ - chat_completion_chunk.Choice( - delta=ChoiceDelta( - tool_calls=[ - ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='"Berli')) - ] - ), - index=0, - ) - ], - created=1747834733, - model="gpt-4o-mini-2024-07-18", - object="chat.completion.chunk", - service_tier="default", - system_fingerprint="fp_54eb4bd693", - ), - ChatCompletionChunk( - id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", - choices=[ - chat_completion_chunk.Choice( - delta=ChoiceDelta( - tool_calls=[ - ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='n"}')) - ] - ), - index=0, - ) - ], - created=1747834733, - model="gpt-4o-mini-2024-07-18", - object="chat.completion.chunk", - service_tier="default", - system_fingerprint="fp_54eb4bd693", - ), - ChatCompletionChunk( - id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", - choices=[chat_completion_chunk.Choice(delta=ChoiceDelta(), finish_reason="tool_calls", index=0)], - created=1747834733, - model="gpt-4o-mini-2024-07-18", - 
object="chat.completion.chunk", - service_tier="default", - system_fingerprint="fp_54eb4bd693", - ), - ChatCompletionChunk( - id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", - choices=[], - created=1747834733, - model="gpt-4o-mini-2024-07-18", - object="chat.completion.chunk", - service_tier="default", - system_fingerprint="fp_54eb4bd693", - usage=CompletionUsage( - completion_tokens=42, - prompt_tokens=282, - total_tokens=324, - completion_tokens_details=CompletionTokensDetails( - accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0 - ), - prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0), - ), - ), - ] - component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key")) - result = component._handle_stream_response(openai_chunks, callback=lambda chunk: None)[0] # type: ignore - - assert not result.texts - assert not result.text - - # Verify both tool calls were found and processed - assert len(result.tool_calls) == 2 - assert result.tool_calls[0].id == "call_zcvlnVaTeJWRjLAFfYxX69z4" - assert result.tool_calls[0].tool_name == "weather" - assert result.tool_calls[0].arguments == {"city": "Paris"} - assert result.tool_calls[1].id == "call_C88m67V16CrETq6jbNXjdZI9" - assert result.tool_calls[1].tool_name == "weather" - assert result.tool_calls[1].arguments == {"city": "Berlin"} - - # Verify meta information - assert result.meta["model"] == "gpt-4o-mini-2024-07-18" - assert result.meta["finish_reason"] == "tool_calls" - assert result.meta["index"] == 0 - assert result.meta["completion_start_time"] is not None - assert result.meta["usage"] == { - "completion_tokens": 42, - "prompt_tokens": 282, - "total_tokens": 324, - "completion_tokens_details": { - "accepted_prediction_tokens": 0, - "audio_tokens": 0, - "reasoning_tokens": 0, - "rejected_prediction_tokens": 0, - }, - "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}, - } - - def test_convert_usage_chunk_to_streaming_chunk(self): - 
chunk = ChatCompletionChunk( - id="chatcmpl-BC1y4wqIhe17R8sv3lgLcWlB4tXCw", - choices=[], - created=1742207200, - model="gpt-4o-mini-2024-07-18", - object="chat.completion.chunk", - service_tier="default", - system_fingerprint="fp_06737a9306", - usage=CompletionUsage( - completion_tokens=8, - prompt_tokens=13, - total_tokens=21, - completion_tokens_details=CompletionTokensDetails( - accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0 - ), - prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0), - ), - ) - result = _convert_chat_completion_chunk_to_streaming_chunk(chunk=chunk, previous_chunks=[])[0] - assert result.content == "" - assert result.start is False - assert result.tool_call is None - assert result.tool_call_result is None - assert result.meta["model"] == "gpt-4o-mini-2024-07-18" - assert result.meta["received_at"] is not None - @pytest.mark.skipif( not os.environ.get("OPENAI_API_KEY", None), reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", @@ -1032,3 +743,497 @@ def test_live_run_with_toolset(self, tools): assert tool_call.tool_name == "weather" assert tool_call.arguments == {"city": "Paris"} assert message.meta["finish_reason"] == "tool_calls" + + +@pytest.fixture +def chat_completion_chunks(): + return [ + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[chat_completion_chunk.Choice(delta=ChoiceDelta(role="assistant"), index=0)], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id="call_zcvlnVaTeJWRjLAFfYxX69z4", + function=ChoiceDeltaToolCallFunction(arguments="", name="weather"), + type="function", + ) + ] + ), + 
index=0, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='{"ci')) + ] + ), + index=0, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='ty": ')) + ] + ), + index=0, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='"Paris')) + ] + ), + index=0, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='"}'))] + ), + index=0, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + 
delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=1, + id="call_C88m67V16CrETq6jbNXjdZI9", + function=ChoiceDeltaToolCallFunction(arguments="", name="weather"), + type="function", + ) + ] + ), + index=0, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='{"ci')) + ] + ), + index=0, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='ty": ')) + ] + ), + index=0, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='"Berli')) + ] + ), + index=0, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + tool_calls=[ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='n"}'))] + ), + index=0, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + 
object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[chat_completion_chunk.Choice(delta=ChoiceDelta(), finish_reason="tool_calls", index=0)], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + usage=CompletionUsage( + completion_tokens=42, + prompt_tokens=282, + total_tokens=324, + completion_tokens_details=CompletionTokensDetails( + accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0 + ), + prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0), + ), + ), + ] + + +@pytest.fixture +def streaming_chunks(): + return [ + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": None, + "finish_reason": None, + "received_at": ANY, + "usage": None, + }, + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ + ChoiceDeltaToolCall( + index=0, + id="call_zcvlnVaTeJWRjLAFfYxX69z4", + function=ChoiceDeltaToolCallFunction(arguments="", name="weather"), + type="function", + ) + ], + "finish_reason": None, + "received_at": ANY, + "usage": None, + }, + index=0, + tool_calls=[ToolCallDelta(tool_name="weather", id="call_zcvlnVaTeJWRjLAFfYxX69z4", index=0)], + start=True, + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='{"ci'))], + "finish_reason": None, + "received_at": ANY, + "usage": None, + }, + 
index=0, + tool_calls=[ToolCallDelta(arguments='{"ci', index=0)], + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='ty": '))], + "finish_reason": None, + "received_at": ANY, + "usage": None, + }, + index=0, + tool_calls=[ToolCallDelta(arguments='ty": ', index=0)], + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='"Paris'))], + "finish_reason": None, + "received_at": ANY, + "usage": None, + }, + index=0, + tool_calls=[ToolCallDelta(arguments='"Paris', index=0)], + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='"}'))], + "finish_reason": None, + "received_at": ANY, + "usage": None, + }, + index=0, + tool_calls=[ToolCallDelta(arguments='"}', index=0)], + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ + ChoiceDeltaToolCall( + index=1, + id="call_C88m67V16CrETq6jbNXjdZI9", + function=ChoiceDeltaToolCallFunction(arguments="", name="weather"), + type="function", + ) + ], + "finish_reason": None, + "received_at": ANY, + "usage": None, + }, + index=1, + tool_calls=[ToolCallDelta(tool_name="weather", id="call_C88m67V16CrETq6jbNXjdZI9", index=1)], + start=True, + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='{"ci'))], + "finish_reason": None, + "received_at": ANY, + "usage": None, + }, + index=1, + tool_calls=[ToolCallDelta(arguments='{"ci', index=1)], + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + 
"tool_calls": [ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='ty": '))], + "finish_reason": None, + "received_at": ANY, + "usage": None, + }, + index=1, + tool_calls=[ToolCallDelta(arguments='ty": ', index=1)], + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='"Berli'))], + "finish_reason": None, + "received_at": ANY, + "usage": None, + }, + index=1, + tool_calls=[ToolCallDelta(arguments='"Berli', index=1)], + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": [ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='n"}'))], + "finish_reason": None, + "received_at": ANY, + "usage": None, + }, + index=1, + tool_calls=[ToolCallDelta(arguments='n"}', index=1)], + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": None, + "finish_reason": "tool_calls", + "received_at": ANY, + "usage": None, + }, + ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "received_at": ANY, + "usage": { + "completion_tokens": 42, + "prompt_tokens": 282, + "total_tokens": 324, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0, + }, + "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}, + }, + }, + ), + ] + + +class TestChatCompletionChunkConversion: + def test_convert_chat_completion_chunk_to_streaming_chunk(self, chat_completion_chunks, streaming_chunks): + previous_chunks = [] + for openai_chunk, haystack_chunk in zip(chat_completion_chunks, streaming_chunks): + stream_chunk = _convert_chat_completion_chunk_to_streaming_chunk( + chunk=openai_chunk, previous_chunks=previous_chunks + ) + assert stream_chunk == haystack_chunk + 
previous_chunks.append(openai_chunk) + + def test_handle_stream_response(self, chat_completion_chunks): + openai_chunks = chat_completion_chunks + comp = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key")) + result = comp._handle_stream_response(openai_chunks, callback=lambda chunk: None)[0] # type: ignore + + assert not result.texts + assert not result.text + + # Verify both tool calls were found and processed + assert len(result.tool_calls) == 2 + assert result.tool_calls[0].id == "call_zcvlnVaTeJWRjLAFfYxX69z4" + assert result.tool_calls[0].tool_name == "weather" + assert result.tool_calls[0].arguments == {"city": "Paris"} + assert result.tool_calls[1].id == "call_C88m67V16CrETq6jbNXjdZI9" + assert result.tool_calls[1].tool_name == "weather" + assert result.tool_calls[1].arguments == {"city": "Berlin"} + + # Verify meta information + assert result.meta["model"] == "gpt-4o-mini-2024-07-18" + assert result.meta["finish_reason"] == "tool_calls" + assert result.meta["index"] == 0 + assert result.meta["completion_start_time"] is not None + assert result.meta["usage"] == { + "completion_tokens": 42, + "prompt_tokens": 282, + "total_tokens": 324, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0, + }, + "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}, + } + + def test_convert_usage_chunk_to_streaming_chunk(self): + usage_chunk = ChatCompletionChunk( + id="chatcmpl-BC1y4wqIhe17R8sv3lgLcWlB4tXCw", + choices=[], + created=1742207200, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_06737a9306", + usage=CompletionUsage( + completion_tokens=8, + prompt_tokens=13, + total_tokens=21, + completion_tokens_details=CompletionTokensDetails( + accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0 + ), + 
prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0), + ), + ) + result = _convert_chat_completion_chunk_to_streaming_chunk(chunk=usage_chunk, previous_chunks=[]) + assert result.content == "" + assert result.start is False + assert result.tool_calls is None + assert result.tool_call_result is None + assert result.meta["model"] == "gpt-4o-mini-2024-07-18" + assert result.meta["received_at"] is not None diff --git a/test/components/generators/test_utils.py b/test/components/generators/test_utils.py index cc6d6edfdf..f4307afd50 100644 --- a/test/components/generators/test_utils.py +++ b/test/components/generators/test_utils.py @@ -42,7 +42,9 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): component_info=ComponentInfo(name="test", type="test"), index=0, start=True, - tool_call=ToolCallDelta(id="call_ZOj5l67zhZOx6jqjg7ATQwb6", tool_name="rag_pipeline_tool", arguments=""), + tool_calls=[ + ToolCallDelta(id="call_ZOj5l67zhZOx6jqjg7ATQwb6", tool_name="rag_pipeline_tool", arguments="", index=0) + ], ), StreamingChunk( content="", @@ -59,7 +61,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): }, component_info=ComponentInfo(name="test", type="test"), index=0, - tool_call=ToolCallDelta(arguments='{"qu'), + tool_calls=[ToolCallDelta(arguments='{"qu', index=0)], ), StreamingChunk( content="", @@ -76,7 +78,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): }, component_info=ComponentInfo(name="test", type="test"), index=0, - tool_call=ToolCallDelta(arguments='ery":'), + tool_calls=[ToolCallDelta(arguments='ery":', index=0)], ), StreamingChunk( content="", @@ -93,7 +95,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): }, component_info=ComponentInfo(name="test", type="test"), index=0, - tool_call=ToolCallDelta(arguments=' "Wher'), + tool_calls=[ToolCallDelta(arguments=' "Wher', index=0)], ), StreamingChunk( content="", @@ -110,7 
+112,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): }, component_info=ComponentInfo(name="test", type="test"), index=0, - tool_call=ToolCallDelta(arguments="e do"), + tool_calls=[ToolCallDelta(arguments="e do", index=0)], ), StreamingChunk( content="", @@ -127,7 +129,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): }, component_info=ComponentInfo(name="test", type="test"), index=0, - tool_call=ToolCallDelta(arguments="es Ma"), + tool_calls=[ToolCallDelta(arguments="es Ma", index=0)], ), StreamingChunk( content="", @@ -144,7 +146,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): }, component_info=ComponentInfo(name="test", type="test"), index=0, - tool_call=ToolCallDelta(arguments="rk liv"), + tool_calls=[ToolCallDelta(arguments="rk liv", index=0)], ), StreamingChunk( content="", @@ -161,7 +163,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): }, component_info=ComponentInfo(name="test", type="test"), index=0, - tool_call=ToolCallDelta(arguments='e?"}'), + tool_calls=[ToolCallDelta(arguments='e?"}', index=0)], ), StreamingChunk( content="", @@ -182,7 +184,9 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): component_info=ComponentInfo(name="test", type="test"), index=1, start=True, - tool_call=ToolCallDelta(id="call_STxsYY69wVOvxWqopAt3uWTB", tool_name="get_weather", arguments=""), + tool_calls=[ + ToolCallDelta(id="call_STxsYY69wVOvxWqopAt3uWTB", tool_name="get_weather", arguments="", index=1) + ], ), StreamingChunk( content="", @@ -199,7 +203,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): }, component_info=ComponentInfo(name="test", type="test"), index=1, - tool_call=ToolCallDelta(arguments='{"ci'), + tool_calls=[ToolCallDelta(arguments='{"ci', index=1)], ), StreamingChunk( content="", @@ -216,7 +220,7 @@ def 
test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): }, component_info=ComponentInfo(name="test", type="test"), index=1, - tool_call=ToolCallDelta(arguments='ty": '), + tool_calls=[ToolCallDelta(arguments='ty": ', index=1)], ), StreamingChunk( content="", @@ -233,7 +237,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): }, component_info=ComponentInfo(name="test", type="test"), index=1, - tool_call=ToolCallDelta(arguments='"Berli'), + tool_calls=[ToolCallDelta(arguments='"Berli', index=1)], ), StreamingChunk( content="", @@ -250,7 +254,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): }, component_info=ComponentInfo(name="test", type="test"), index=1, - tool_call=ToolCallDelta(arguments='n"}'), + tool_calls=[ToolCallDelta(arguments='n"}', index=1)], ), StreamingChunk( content="", diff --git a/test/dataclasses/test_streaming_chunk.py b/test/dataclasses/test_streaming_chunk.py index 695d155483..aa7c13424f 100644 --- a/test/dataclasses/test_streaming_chunk.py +++ b/test/dataclasses/test_streaming_chunk.py @@ -58,7 +58,7 @@ def test_create_chunk_with_content_and_tool_call(): StreamingChunk( content="Test content", meta={"key": "value"}, - tool_call=ToolCallDelta(id="123", tool_name="test_tool", arguments='{"arg1": "value1"}'), + tool_calls=[ToolCallDelta(id="123", tool_name="test_tool", arguments='{"arg1": "value1"}', index=0)], ) @@ -92,12 +92,13 @@ def test_component_info_from_component_with_name_from_pipeline(): def test_tool_call_delta(): - tool_call = ToolCallDelta(id="123", tool_name="test_tool", arguments='{"arg1": "value1"}') + tool_call = ToolCallDelta(id="123", tool_name="test_tool", arguments='{"arg1": "value1"}', index=0) assert tool_call.id == "123" assert tool_call.tool_name == "test_tool" assert tool_call.arguments == '{"arg1": "value1"}' + assert tool_call.index == 0 def test_tool_call_delta_with_missing_fields(): with pytest.raises(ValueError): - _ = 
ToolCallDelta(id="123") + _ = ToolCallDelta(id="123", index=0)