Skip to content

Commit b61886b

Browse files
sjrljulian-risch
andauthored
feat: Update streaming chunk (#9424)
* Start expanding StreamingChunk * First pass at expanding Streaming Chunk * Working version! * Some tweaks and also make ToolInvoker stream a chunk with a finish reason * Properly update test * Change to tool_name, remove kw_only since its python 3.10 only and update HuggingFaceAPIChatGenerator to start following new StreamingChunk * Add reno * Some cleanup * Fix unit tests * Fix mypy and integration test * Fix pylint * Start refactoring huggingface local api * Refactor openai generator and chat generator to reuse util methods * Did some reorg * Reusue utility method in HuggingFaceAPI * Get rid of unneeded default values in tests * Update conversion of streaming chunks to chat message to not rely on openai dataclass anymore * Fix tests and loosen check in StreamingChunk post_init * Fixes * Fix license header * Add start and index to HFAPIGenerator * Fix mypy * Clean up * Update haystack/components/generators/utils.py Co-authored-by: Julian Risch <julian.risch@deepset.ai> * Update haystack/components/generators/utils.py Co-authored-by: Julian Risch <julian.risch@deepset.ai> * Change StreamingChunk.start to only a bool * PR comments * Fix unit test * PR comment * Fix test --------- Co-authored-by: Julian Risch <julian.risch@deepset.ai>
1 parent f85ce19 commit b61886b

15 files changed

Lines changed: 620 additions & 143 deletions

File tree

haystack/components/generators/chat/hugging_face_api.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@ def _convert_tools_to_hfapi_tools(
103103

104104

105105
def _convert_chat_completion_stream_output_to_streaming_chunk(
106-
chunk: "ChatCompletionStreamOutput", component_info: Optional[ComponentInfo] = None
106+
chunk: "ChatCompletionStreamOutput",
107+
previous_chunks: List[StreamingChunk],
108+
component_info: Optional[ComponentInfo] = None,
107109
) -> StreamingChunk:
108110
"""
109111
Converts the Hugging Face API ChatCompletionStreamOutput to a StreamingChunk.
@@ -127,6 +129,10 @@ def _convert_chat_completion_stream_output_to_streaming_chunk(
127129
content=choice.delta.content or "",
128130
meta={"model": chunk.model, "received_at": datetime.now().isoformat(), "finish_reason": choice.finish_reason},
129131
component_info=component_info,
132+
# Index must always be 0 since we don't allow tool calls in streaming mode.
133+
index=0 if choice.finish_reason is None else None,
134+
# start is True at the very beginning since first chunk contains role information + first part of the answer.
135+
start=len(previous_chunks) == 0,
130136
)
131137
return stream_chunk
132138

@@ -441,10 +447,10 @@ def _run_streaming(
441447
)
442448

443449
component_info = ComponentInfo.from_component(self)
444-
streaming_chunks = []
450+
streaming_chunks: List[StreamingChunk] = []
445451
for chunk in api_output:
446452
streaming_chunk = _convert_chat_completion_stream_output_to_streaming_chunk(
447-
chunk=chunk, component_info=component_info
453+
chunk=chunk, previous_chunks=streaming_chunks, component_info=component_info
448454
)
449455
streaming_chunks.append(streaming_chunk)
450456
streaming_callback(streaming_chunk)
@@ -505,10 +511,10 @@ async def _run_streaming_async(
505511
)
506512

507513
component_info = ComponentInfo.from_component(self)
508-
streaming_chunks = []
514+
streaming_chunks: List[StreamingChunk] = []
509515
async for chunk in api_output:
510516
stream_chunk = _convert_chat_completion_stream_output_to_streaming_chunk(
511-
chunk=chunk, component_info=component_info
517+
chunk=chunk, previous_chunks=streaming_chunks, component_info=component_info
512518
)
513519
streaming_chunks.append(stream_chunk)
514520
await streaming_callback(stream_chunk) # type: ignore

haystack/components/generators/chat/openai.py

Lines changed: 69 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
StreamingChunk,
2323
SyncStreamingCallbackT,
2424
ToolCall,
25+
ToolCallDelta,
2526
select_streaming_callback,
2627
)
2728
from haystack.tools import (
@@ -422,9 +423,12 @@ def _handle_stream_response(self, chat_completion: Stream, callback: SyncStreami
422423
chunks: List[StreamingChunk] = []
423424
for chunk in chat_completion: # pylint: disable=not-an-iterable
424425
assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
425-
chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk(chunk=chunk, component_info=component_info)
426-
chunks.append(chunk_delta)
427-
callback(chunk_delta)
426+
chunk_deltas = _convert_chat_completion_chunk_to_streaming_chunk(
427+
chunk=chunk, previous_chunks=chunks, component_info=component_info
428+
)
429+
for chunk_delta in chunk_deltas:
430+
chunks.append(chunk_delta)
431+
callback(chunk_delta)
428432
return [_convert_streaming_chunks_to_chat_message(chunks=chunks)]
429433

430434
async def _handle_async_stream_response(
@@ -434,9 +438,12 @@ async def _handle_async_stream_response(
434438
chunks: List[StreamingChunk] = []
435439
async for chunk in chat_completion: # pylint: disable=not-an-iterable
436440
assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
437-
chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk(chunk=chunk, component_info=component_info)
438-
chunks.append(chunk_delta)
439-
await callback(chunk_delta)
441+
chunk_deltas = _convert_chat_completion_chunk_to_streaming_chunk(
442+
chunk=chunk, previous_chunks=chunks, component_info=component_info
443+
)
444+
for chunk_delta in chunk_deltas:
445+
chunks.append(chunk_delta)
446+
await callback(chunk_delta)
440447
return [_convert_streaming_chunks_to_chat_message(chunks=chunks)]
441448

442449

@@ -497,34 +504,77 @@ def _convert_chat_completion_to_chat_message(completion: ChatCompletion, choice:
497504

498505

499506
def _convert_chat_completion_chunk_to_streaming_chunk(
500-
chunk: ChatCompletionChunk, component_info: Optional[ComponentInfo] = None
501-
) -> StreamingChunk:
507+
chunk: ChatCompletionChunk, previous_chunks: List[StreamingChunk], component_info: Optional[ComponentInfo] = None
508+
) -> List[StreamingChunk]:
502509
"""
503510
Converts the streaming response chunk from the OpenAI API to a StreamingChunk.
504511
505512
:param chunk: The chunk returned by the OpenAI API.
513+
:param previous_chunks: A list of previously received StreamingChunks.
514+
:param component_info: An optional `ComponentInfo` object containing information about the component that
515+
generated the chunk, such as the component name and type.
506516
507517
:returns:
508-
The StreamingChunk.
518+
A list of StreamingChunk objects representing the content of the chunk from the OpenAI API.
509519
"""
510520
# Choices is empty on the very first chunk which provides role information (e.g. "assistant").
511521
# It is also empty if include_usage is set to True where the usage information is returned.
512522
if len(chunk.choices) == 0:
513-
return StreamingChunk(
514-
content="",
515-
meta={
516-
"model": chunk.model,
517-
"received_at": datetime.now().isoformat(),
518-
"usage": _serialize_usage(chunk.usage),
519-
},
520-
component_info=component_info,
521-
)
523+
return [
524+
StreamingChunk(
525+
content="",
526+
component_info=component_info,
527+
# Index is None since it's only set to an int when a content block is present
528+
index=None,
529+
meta={
530+
"model": chunk.model,
531+
"received_at": datetime.now().isoformat(),
532+
"usage": _serialize_usage(chunk.usage),
533+
},
534+
)
535+
]
522536

523537
choice: ChunkChoice = chunk.choices[0]
524538
content = choice.delta.content or ""
525539

540+
# create a list of ToolCallDelta objects from the tool calls
541+
if choice.delta.tool_calls:
542+
chunk_messages = []
543+
for tool_call in choice.delta.tool_calls:
544+
function = tool_call.function
545+
chunk_message = StreamingChunk(
546+
content=content,
547+
# We adopt the tool_call.index as the index of the chunk
548+
component_info=component_info,
549+
index=tool_call.index,
550+
tool_call=ToolCallDelta(
551+
id=tool_call.id,
552+
tool_name=function.name if function else None,
553+
arguments=function.arguments if function and function.arguments else None,
554+
),
555+
start=function.name is not None if function else False,
556+
meta={
557+
"model": chunk.model,
558+
"index": choice.index,
559+
"tool_calls": choice.delta.tool_calls,
560+
"finish_reason": choice.finish_reason,
561+
"received_at": datetime.now().isoformat(),
562+
"usage": _serialize_usage(chunk.usage),
563+
},
564+
)
565+
chunk_messages.append(chunk_message)
566+
return chunk_messages
567+
526568
chunk_message = StreamingChunk(
527569
content=content,
570+
component_info=component_info,
571+
# We set the index to be 0 since if text content is being streamed then no tool calls are being streamed
572+
# NOTE: We may need to revisit this if OpenAI allows planning/thinking content before tool calls like
573+
# Anthropic Claude
574+
index=0,
575+
# The first chunk is always a start message chunk that only contains role information, so if we reach here
576+
# and previous_chunks is length 1 then this is the start of text content.
577+
start=len(previous_chunks) == 1,
528578
meta={
529579
"model": chunk.model,
530580
"index": choice.index,
@@ -533,9 +583,8 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
533583
"received_at": datetime.now().isoformat(),
534584
"usage": _serialize_usage(chunk.usage),
535585
},
536-
component_info=component_info,
537586
)
538-
return chunk_message
587+
return [chunk_message]
539588

540589

541590
def _serialize_usage(usage):

haystack/components/generators/hugging_face_api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,9 @@ def _stream_and_build_response(
235235
if first_chunk_time is None:
236236
first_chunk_time = datetime.now().isoformat()
237237

238-
stream_chunk = StreamingChunk(content=token.text, meta=chunk_metadata, component_info=component_info)
238+
stream_chunk = StreamingChunk(
239+
content=token.text, meta=chunk_metadata, component_info=component_info, index=0, start=len(chunks) == 0
240+
)
239241
chunks.append(stream_chunk)
240242
streaming_callback(stream_chunk)
241243

haystack/components/generators/openai.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,9 @@ def run(
247247
for chunk in completion:
248248
chunk_delta: StreamingChunk = _convert_chat_completion_chunk_to_streaming_chunk(
249249
chunk=chunk, # type: ignore
250+
previous_chunks=chunks,
250251
component_info=component_info,
251-
)
252+
)[0]
252253
chunks.append(chunk_delta)
253254
streaming_callback(chunk_delta)
254255

haystack/components/generators/utils.py

Lines changed: 52 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import json
6-
from typing import Any, Dict, List
7-
8-
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
6+
from typing import Dict, List
97

108
from haystack import logging
119
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
@@ -28,33 +26,38 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None:
2826
:param chunk: A chunk of streaming data containing content and optional metadata, such as tool calls and
2927
tool results.
3028
"""
31-
# Print tool call metadata if available (from ChatGenerator)
32-
if tool_calls := chunk.meta.get("tool_calls"):
33-
for tool_call in tool_calls:
34-
# Convert to dict if tool_call is a ChoiceDeltaToolCall
35-
tool_call_dict: Dict[str, Any] = (
36-
tool_call.to_dict() if isinstance(tool_call, ChoiceDeltaToolCall) else tool_call
37-
)
29+
if chunk.start and chunk.index and chunk.index > 0:
30+
# If this is the start of a new content block but not the first content block, print two new lines
31+
print("\n\n", flush=True, end="")
3832

39-
if function := tool_call_dict.get("function"):
40-
if name := function.get("name"):
41-
print("\n\n[TOOL CALL]\n", flush=True, end="")
42-
print(f"Tool: {name} ", flush=True, end="")
43-
print("\nArguments: ", flush=True, end="")
33+
## Tool Call streaming
34+
if chunk.tool_call:
35+
# If chunk.start is True indicates beginning of a tool call
36+
# Also presence of chunk.tool_call.name indicates the start of a tool call too
37+
if chunk.start:
38+
print("[TOOL CALL]\n", flush=True, end="")
39+
print(f"Tool: {chunk.tool_call.tool_name} ", flush=True, end="")
40+
print("\nArguments: ", flush=True, end="")
4441

45-
if arguments := function.get("arguments"):
46-
print(arguments, flush=True, end="")
42+
# print the tool arguments
43+
if chunk.tool_call.arguments:
44+
print(chunk.tool_call.arguments, flush=True, end="")
4745

46+
## Tool Call Result streaming
4847
# Print tool call results if available (from ToolInvoker)
49-
if tool_result := chunk.meta.get("tool_result"):
50-
print(f"\n\n[TOOL RESULT]\n{tool_result}", flush=True, end="")
48+
if chunk.tool_call_result:
49+
# Tool Call Result is fully formed so delta accumulation is not needed
50+
print(f"[TOOL RESULT]\n{chunk.tool_call_result.result}", flush=True, end="")
5151

52+
## Normal content streaming
5253
# Print the main content of the chunk (from ChatGenerator)
53-
if content := chunk.content:
54-
print(content, flush=True, end="")
54+
if chunk.content:
55+
if chunk.start:
56+
print("[ASSISTANT]\n", flush=True, end="")
57+
print(chunk.content, flush=True, end="")
5558

5659
# End of LLM assistant message so we add two new lines
57-
# This ensures spacing between multiple LLM messages (e.g. Agent)
60+
# This ensures spacing between multiple LLM messages (e.g. Agent) or multiple Tool Call Results
5861
if chunk.meta.get("finish_reason") is not None:
5962
print("\n\n", flush=True, end="")
6063

@@ -71,38 +74,41 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> C
7174
tool_calls = []
7275

7376
# Process tool calls if present in any chunk
74-
tool_call_data: Dict[str, Dict[str, str]] = {} # Track tool calls by index
75-
for chunk_payload in chunks:
76-
tool_calls_meta = chunk_payload.meta.get("tool_calls")
77-
if tool_calls_meta is not None:
78-
for delta in tool_calls_meta:
79-
# We use the index of the tool call to track it across chunks since the ID is not always provided
80-
if delta.index not in tool_call_data:
81-
tool_call_data[delta.index] = {"id": "", "name": "", "arguments": ""}
82-
83-
# Save the ID if present
84-
if delta.id is not None:
85-
tool_call_data[delta.index]["id"] = delta.id
86-
87-
if delta.function is not None:
88-
if delta.function.name is not None:
89-
tool_call_data[delta.index]["name"] += delta.function.name
90-
if delta.function.arguments is not None:
91-
tool_call_data[delta.index]["arguments"] += delta.function.arguments
77+
tool_call_data: Dict[int, Dict[str, str]] = {} # Track tool calls by index
78+
for chunk in chunks:
79+
if chunk.tool_call:
80+
# We do this to make sure mypy is happy, but we enforce index is not None in the StreamingChunk dataclass if
81+
# tool_call is present
82+
assert chunk.index is not None
83+
84+
# We use the index of the chunk to track the tool call across chunks since the ID is not always provided
85+
if chunk.index not in tool_call_data:
86+
tool_call_data[chunk.index] = {"id": "", "name": "", "arguments": ""}
87+
88+
# Save the ID if present
89+
if chunk.tool_call.id is not None:
90+
tool_call_data[chunk.index]["id"] = chunk.tool_call.id
91+
92+
if chunk.tool_call.tool_name is not None:
93+
tool_call_data[chunk.index]["name"] += chunk.tool_call.tool_name
94+
if chunk.tool_call.arguments is not None:
95+
tool_call_data[chunk.index]["arguments"] += chunk.tool_call.arguments
9296

9397
# Convert accumulated tool call data into ToolCall objects
94-
for call_data in tool_call_data.values():
98+
sorted_keys = sorted(tool_call_data.keys())
99+
for key in sorted_keys:
100+
tool_call = tool_call_data[key]
95101
try:
96-
arguments = json.loads(call_data["arguments"])
97-
tool_calls.append(ToolCall(id=call_data["id"], tool_name=call_data["name"], arguments=arguments))
102+
arguments = json.loads(tool_call["arguments"])
103+
tool_calls.append(ToolCall(id=tool_call["id"], tool_name=tool_call["name"], arguments=arguments))
98104
except json.JSONDecodeError:
99105
logger.warning(
100106
"OpenAI returned a malformed JSON string for tool call arguments. This tool call "
101107
"will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
102108
"Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
103-
_id=call_data["id"],
104-
_name=call_data["name"],
105-
_arguments=call_data["arguments"],
109+
_id=tool_call["id"],
110+
_name=tool_call["name"],
111+
_arguments=tool_call["arguments"],
106112
)
107113

108114
# finish_reason can appear in different places so we look for the last one

haystack/components/tools/tool_invoker.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,9 @@ def run(
504504
streaming_callback(
505505
StreamingChunk(
506506
content="",
507+
index=len(tool_messages) - 1,
508+
tool_call_result=tool_messages[-1].tool_call_results[0],
509+
start=True,
507510
meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call},
508511
)
509512
)
@@ -609,6 +612,9 @@ async def run_async(
609612
await streaming_callback(
610613
StreamingChunk(
611614
content="",
615+
index=len(tool_messages) - 1,
616+
tool_call_result=tool_messages[-1].tool_call_results[0],
617+
start=True,
612618
meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call},
613619
)
614620
) # type: ignore[misc] # we have checked that streaming_callback is not None and async

haystack/dataclasses/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
"sparse_embedding": ["SparseEmbedding"],
1616
"state": ["State"],
1717
"streaming_chunk": [
18-
"StreamingChunk",
1918
"AsyncStreamingCallbackT",
19+
"ComponentInfo",
2020
"StreamingCallbackT",
21+
"StreamingChunk",
2122
"SyncStreamingCallbackT",
23+
"ToolCallDelta",
2224
"select_streaming_callback",
23-
"ComponentInfo",
2425
],
2526
}
2627

@@ -37,6 +38,7 @@
3738
StreamingCallbackT,
3839
StreamingChunk,
3940
SyncStreamingCallbackT,
41+
ToolCallDelta,
4042
select_streaming_callback,
4143
)
4244
else:

0 commit comments

Comments
 (0)