feat: Update streaming chunk #9424
Changes from all commits
```diff
@@ -22,6 +22,7 @@
     StreamingChunk,
     SyncStreamingCallbackT,
     ToolCall,
+    ToolCallDelta,
     select_streaming_callback,
 )
 from haystack.tools import (
@@ -422,9 +423,12 @@ def _handle_stream_response(self, chat_completion: Stream, callback: SyncStreamingCallbackT):
         chunks: List[StreamingChunk] = []
         for chunk in chat_completion:  # pylint: disable=not-an-iterable
             assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
-            chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk(chunk=chunk, component_info=component_info)
-            chunks.append(chunk_delta)
-            callback(chunk_delta)
+            chunk_deltas = _convert_chat_completion_chunk_to_streaming_chunk(
+                chunk=chunk, previous_chunks=chunks, component_info=component_info
+            )
+            for chunk_delta in chunk_deltas:
+                chunks.append(chunk_delta)
+                callback(chunk_delta)
         return [_convert_streaming_chunks_to_chat_message(chunks=chunks)]

     async def _handle_async_stream_response(
@@ -434,9 +438,12 @@ async def _handle_async_stream_response(
         chunks: List[StreamingChunk] = []
         async for chunk in chat_completion:  # pylint: disable=not-an-iterable
             assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
-            chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk(chunk=chunk, component_info=component_info)
-            chunks.append(chunk_delta)
-            await callback(chunk_delta)
+            chunk_deltas = _convert_chat_completion_chunk_to_streaming_chunk(
+                chunk=chunk, previous_chunks=chunks, component_info=component_info
+            )
+            for chunk_delta in chunk_deltas:
+                chunks.append(chunk_delta)
+                await callback(chunk_delta)
         return [_convert_streaming_chunks_to_chat_message(chunks=chunks)]
@@ -497,34 +504,77 @@ def _convert_chat_completion_to_chat_message(completion: ChatCompletion, choice:

 def _convert_chat_completion_chunk_to_streaming_chunk(
-    chunk: ChatCompletionChunk, component_info: Optional[ComponentInfo] = None
-) -> StreamingChunk:
+    chunk: ChatCompletionChunk, previous_chunks: List[StreamingChunk], component_info: Optional[ComponentInfo] = None
+) -> List[StreamingChunk]:
     """
     Converts the streaming response chunk from the OpenAI API to a StreamingChunk.

     :param chunk: The chunk returned by the OpenAI API.
+    :param previous_chunks: A list of previously received StreamingChunks.
     :param component_info: An optional `ComponentInfo` object containing information about the component that
         generated the chunk, such as the component name and type.

     :returns:
-        The StreamingChunk.
+        A list of StreamingChunk objects representing the content of the chunk from the OpenAI API.
     """
     # Choices is empty on the very first chunk which provides role information (e.g. "assistant").
     # It is also empty if include_usage is set to True where the usage information is returned.
     if len(chunk.choices) == 0:
-        return StreamingChunk(
-            content="",
-            meta={
-                "model": chunk.model,
-                "received_at": datetime.now().isoformat(),
-                "usage": _serialize_usage(chunk.usage),
-            },
-            component_info=component_info,
-        )
+        return [
+            StreamingChunk(
+                content="",
+                component_info=component_info,
+                # Index is None since it's only set to an int when a content block is present
+                index=None,
+                meta={
+                    "model": chunk.model,
+                    "received_at": datetime.now().isoformat(),
+                    "usage": _serialize_usage(chunk.usage),
+                },
+            )
+        ]

     choice: ChunkChoice = chunk.choices[0]
     content = choice.delta.content or ""

+    # create a list of ToolCallDelta objects from the tool calls
+    if choice.delta.tool_calls:
+        chunk_messages = []
+        for tool_call in choice.delta.tool_calls:
+            function = tool_call.function
+            chunk_message = StreamingChunk(
+                content=content,
+                component_info=component_info,
+                # We adopt the tool_call.index as the index of the chunk
+                index=tool_call.index,
+                tool_call=ToolCallDelta(
+                    id=tool_call.id,
+                    tool_name=function.name if function else None,
+                    arguments=function.arguments if function and function.arguments else None,
+                ),
```
> **Contributor (author) commented on lines +547 to +554:** This choice was made since I believe it makes intuitive sense and is the format used by other providers like Bedrock, Cohere, and Anthropic.
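To make the convention concrete, here is an illustrative sketch of how a single streamed tool call surfaces as `StreamingChunk`s; the constructor arguments follow the diff above, but the concrete values are invented:

```python
# Illustrative only: one streamed tool call as it arrives in two deltas.
from haystack.dataclasses import StreamingChunk, ToolCallDelta

# The first delta carries the tool name, so it opens the content block (start=True)
opening = StreamingChunk(
    content="",
    index=0,  # adopted from tool_call.index
    tool_call=ToolCallDelta(id="call_123", tool_name="weather", arguments=None),
    start=True,
)
# Later deltas carry only argument fragments, so start=False
fragment = StreamingChunk(
    content="",
    index=0,
    tool_call=ToolCallDelta(id=None, tool_name=None, arguments='{"city": "Paris"}'),
    start=False,
)
```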
```diff
+                start=function.name is not None if function else False,
+                meta={
+                    "model": chunk.model,
+                    "index": choice.index,
+                    "tool_calls": choice.delta.tool_calls,
+                    "finish_reason": choice.finish_reason,
+                    "received_at": datetime.now().isoformat(),
+                    "usage": _serialize_usage(chunk.usage),
+                },
```
> **Contributor (author) commented on lines +556 to +563:** We still keep everything in `meta` like we did before to prevent any breaking changes.
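As a sketch of what that guarantee means for consumers, both the old meta-based path and the new typed path can read the same tool-call chunk (field names follow the diffs above):

```python
from haystack.dataclasses import StreamingChunk

def describe_tool_call(chunk: StreamingChunk) -> None:
    # Old, meta-based path: the raw OpenAI ChoiceDeltaToolCall objects, unchanged
    for raw_delta in chunk.meta.get("tool_calls") or []:
        print("meta:", raw_delta.index, raw_delta.id)

    # New, typed path introduced by this PR
    if chunk.tool_call is not None:
        print("typed:", chunk.tool_call.tool_name, chunk.tool_call.arguments)
```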
||
| ) | ||
| chunk_messages.append(chunk_message) | ||
| return chunk_messages | ||
|
|
||
| chunk_message = StreamingChunk( | ||
| content=content, | ||
| component_info=component_info, | ||
| # We set the index to be 0 since if text content is being streamed then no tool calls are being streamed | ||
| # NOTE: We may need to revisit this if OpenAI allows planning/thinking content before tool calls like | ||
| # Anthropic Claude | ||
| index=0, | ||
| # The first chunk is always a start message chunk that only contains role information, so if we reach here | ||
| # and previous_chunks is length 1 then this is the start of text content. | ||
| start=len(previous_chunks) == 1, | ||
| meta={ | ||
| "model": chunk.model, | ||
| "index": choice.index, | ||
|
|
@@ -533,9 +583,8 @@ def _convert_chat_completion_chunk_to_streaming_chunk( | |
| "received_at": datetime.now().isoformat(), | ||
| "usage": _serialize_usage(chunk.usage), | ||
| }, | ||
| component_info=component_info, | ||
| ) | ||
| return chunk_message | ||
| return [chunk_message] | ||
|
|
||
|
|
||
| def _serialize_usage(usage): | ||
|
|
||
---

```diff
@@ -3,9 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import json
-from typing import Any, Dict, List
-
-from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
+from typing import Dict, List

 from haystack import logging
 from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
```
> **Contributor (author):** We've updated both utility functions in this file to use the new `StreamingChunk` fields instead of relying on specific values in the metadata.
```diff
@@ -28,33 +26,38 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None:
     :param chunk: A chunk of streaming data containing content and optional metadata, such as tool calls and
         tool results.
     """
-    # Print tool call metadata if available (from ChatGenerator)
-    if tool_calls := chunk.meta.get("tool_calls"):
-        for tool_call in tool_calls:
-            # Convert to dict if tool_call is a ChoiceDeltaToolCall
-            tool_call_dict: Dict[str, Any] = (
-                tool_call.to_dict() if isinstance(tool_call, ChoiceDeltaToolCall) else tool_call
-            )
+    if chunk.start and chunk.index and chunk.index > 0:
+        # If this is the start of a new content block but not the first content block, print two new lines
+        print("\n\n", flush=True, end="")

-            if function := tool_call_dict.get("function"):
-                if name := function.get("name"):
-                    print("\n\n[TOOL CALL]\n", flush=True, end="")
-                    print(f"Tool: {name} ", flush=True, end="")
-                    print("\nArguments: ", flush=True, end="")
+    ## Tool Call streaming
+    if chunk.tool_call:
+        # If chunk.start is True indicates beginning of a tool call
+        # Also presence of chunk.tool_call.name indicates the start of a tool call too
+        if chunk.start:
+            print("[TOOL CALL]\n", flush=True, end="")
+            print(f"Tool: {chunk.tool_call.tool_name} ", flush=True, end="")
+            print("\nArguments: ", flush=True, end="")

-                if arguments := function.get("arguments"):
-                    print(arguments, flush=True, end="")
+        # print the tool arguments
+        if chunk.tool_call.arguments:
+            print(chunk.tool_call.arguments, flush=True, end="")

+    ## Tool Call Result streaming
     # Print tool call results if available (from ToolInvoker)
-    if tool_result := chunk.meta.get("tool_result"):
-        print(f"\n\n[TOOL RESULT]\n{tool_result}", flush=True, end="")
+    if chunk.tool_call_result:
+        # Tool Call Result is fully formed so delta accumulation is not needed
+        print(f"[TOOL RESULT]\n{chunk.tool_call_result.result}", flush=True, end="")

+    ## Normal content streaming
     # Print the main content of the chunk (from ChatGenerator)
-    if content := chunk.content:
-        print(content, flush=True, end="")
+    if chunk.content:
+        if chunk.start:
+            print("[ASSISTANT]\n", flush=True, end="")
+        print(chunk.content, flush=True, end="")

     # End of LLM assistant message so we add two new lines
-    # This ensures spacing between multiple LLM messages (e.g. Agent)
+    # This ensures spacing between multiple LLM messages (e.g. Agent) or multiple Tool Call Results
     if chunk.meta.get("finish_reason") is not None:
         print("\n\n", flush=True, end="")
```
```diff
@@ -71,38 +74,41 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> ChatMessage:
     tool_calls = []

     # Process tool calls if present in any chunk
-    tool_call_data: Dict[str, Dict[str, str]] = {}  # Track tool calls by index
-    for chunk_payload in chunks:
-        tool_calls_meta = chunk_payload.meta.get("tool_calls")
-        if tool_calls_meta is not None:
-            for delta in tool_calls_meta:
-                # We use the index of the tool call to track it across chunks since the ID is not always provided
-                if delta.index not in tool_call_data:
-                    tool_call_data[delta.index] = {"id": "", "name": "", "arguments": ""}
-
-                # Save the ID if present
-                if delta.id is not None:
-                    tool_call_data[delta.index]["id"] = delta.id
-
-                if delta.function is not None:
-                    if delta.function.name is not None:
-                        tool_call_data[delta.index]["name"] += delta.function.name
-                    if delta.function.arguments is not None:
-                        tool_call_data[delta.index]["arguments"] += delta.function.arguments
+    tool_call_data: Dict[int, Dict[str, str]] = {}  # Track tool calls by index
+    for chunk in chunks:
+        if chunk.tool_call:
+            # We do this to make sure mypy is happy, but we enforce index is not None in the StreamingChunk
+            # dataclass if tool_call is present
+            assert chunk.index is not None
+
+            # We use the index of the chunk to track the tool call across chunks since the ID is not always provided
+            if chunk.index not in tool_call_data:
+                tool_call_data[chunk.index] = {"id": "", "name": "", "arguments": ""}
+
+            # Save the ID if present
+            if chunk.tool_call.id is not None:
+                tool_call_data[chunk.index]["id"] = chunk.tool_call.id
+
+            if chunk.tool_call.tool_name is not None:
+                tool_call_data[chunk.index]["name"] += chunk.tool_call.tool_name
+            if chunk.tool_call.arguments is not None:
+                tool_call_data[chunk.index]["arguments"] += chunk.tool_call.arguments

     # Convert accumulated tool call data into ToolCall objects
-    for call_data in tool_call_data.values():
+    sorted_keys = sorted(tool_call_data.keys())
+    for key in sorted_keys:
+        tool_call = tool_call_data[key]
         try:
-            arguments = json.loads(call_data["arguments"])
-            tool_calls.append(ToolCall(id=call_data["id"], tool_name=call_data["name"], arguments=arguments))
+            arguments = json.loads(tool_call["arguments"])
+            tool_calls.append(ToolCall(id=tool_call["id"], tool_name=tool_call["name"], arguments=arguments))
         except json.JSONDecodeError:
             logger.warning(
                 "OpenAI returned a malformed JSON string for tool call arguments. This tool call "
                 "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
                 "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
-                _id=call_data["id"],
-                _name=call_data["name"],
-                _arguments=call_data["arguments"],
+                _id=tool_call["id"],
+                _name=tool_call["name"],
+                _arguments=tool_call["arguments"],
             )

     # finish_reason can appear in different places so we look for the last one
```
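To see the accumulation rule in isolation, a self-contained sketch; the `ToolCallDelta` constructor mirrors the diff above, and the fragment values are invented:

```python
# Deltas that share an index merge: the ID is taken when present, and the
# name/argument fragments concatenate into one string.
import json

from haystack.dataclasses import ToolCallDelta

deltas = [
    ToolCallDelta(id="call_1", tool_name="weather", arguments=None),
    ToolCallDelta(id=None, tool_name=None, arguments='{"city": '),
    ToolCallDelta(id=None, tool_name=None, arguments='"Paris"}'),
]

accumulated = {"id": "", "name": "", "arguments": ""}
for delta in deltas:
    if delta.id is not None:
        accumulated["id"] = delta.id
    if delta.tool_name is not None:
        accumulated["name"] += delta.tool_name
    if delta.arguments is not None:
        accumulated["arguments"] += delta.arguments

print(accumulated["name"], json.loads(accumulated["arguments"]))
# -> weather {'city': 'Paris'}
```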
---

```diff
@@ -504,6 +504,9 @@ def run(
                 streaming_callback(
                     StreamingChunk(
                         content="",
+                        index=len(tool_messages) - 1,
+                        tool_call_result=tool_messages[-1].tool_call_results[0],
+                        start=True,
```
> **Contributor (author) commented on lines +507 to +509:** In the `ToolInvoker` the whole `ToolCallResult` is always sent, so `start` is always `True` and `index` is just the current message index.
```diff
                         meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call},
                     )
                 )
@@ -609,6 +612,9 @@ async def run_async(
                 await streaming_callback(
                     StreamingChunk(
                         content="",
+                        index=len(tool_messages) - 1,
+                        tool_call_result=tool_messages[-1].tool_call_results[0],
+                        start=True,
                         meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call},
                     )
                 )  # type: ignore[misc]  # we have checked that streaming_callback is not None and async
```
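A sketch of a downstream callback that consumes these typed chunks instead of digging into `meta`, assuming only the fields shown in the diffs above:

```python
from haystack.dataclasses import StreamingChunk

def on_chunk(chunk: StreamingChunk) -> None:
    if chunk.tool_call_result is not None:
        # The result arrives fully formed, so it can be printed immediately;
        # chunk.index is the position of the tool message in this invocation.
        print(f"[tool #{chunk.index}] {chunk.tool_call_result.result}")
    elif chunk.content:
        print(chunk.content, end="", flush=True)
```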
> **Contributor (author):** We return multiple chunks now because our `StreamingChunk` has been set up to only contain one type of content block at a time. So if there are somehow two tool calls returned at once from OpenAI, we would convert these into two separate `StreamingChunk`s. As a note, I'm not sure how to trigger this behavior from OpenAI, but this technically needs to be done to follow their SDK spec.