
Commit 91094e1

vblagoje and sjrl authored
feat: Add finish_reason field to StreamingChunk (#9536)
* Initial commit
* Update deprecation version
* Improve comment
* Minor simplification
* Add reno note
* Remove deprecation warning
* Remove fallback in haystack/components/generators/utils.py
* FinishReason alphabetical import
* Add tool_call_results finish reason, adapt codebase
* Define finish_reason to be Optional[FinishReason]
* Add StreamingChunk finish_reason in HF generators
* Update reno note
* Repair merge issue
* Update tests for finish_reason
* Resolve mypy issues
* Lint issue
* Enhance HF finish_reason translation
* Remove irrelevant test
* PR comments

---------

Co-authored-by: Sebastian Husch Lee <sjrl423@gmail.com>
1 parent 1d1c13a commit 91094e1

13 files changed: 192 additions and 10 deletions


haystack/components/generators/chat/hugging_face_api.py (41 additions, 0 deletions)

@@ -18,6 +18,7 @@
     ToolCall,
     select_streaming_callback,
 )
+from haystack.dataclasses.streaming_chunk import FinishReason
 from haystack.lazy_imports import LazyImport
 from haystack.tools import (
     Tool,
@@ -41,6 +42,7 @@
         ChatCompletionOutput,
         ChatCompletionOutputToolCall,
         ChatCompletionStreamOutput,
+        ChatCompletionStreamOutputChoice,
         InferenceClient,
     )

@@ -110,6 +112,43 @@ def _convert_tools_to_hfapi_tools(
     return hf_tools


+def _map_hf_finish_reason_to_haystack(choice: "ChatCompletionStreamOutputChoice") -> Optional[FinishReason]:
+    """
+    Map HuggingFace finish reasons to Haystack FinishReason literals.
+
+    Uses the full choice object to detect tool calls and provide accurate mapping.
+
+    HuggingFace finish reasons (can be found here https://huggingface.github.io/text-generation-inference/ under
+    FinishReason):
+    - "length": number of generated tokens == `max_new_tokens`
+    - "eos_token": the model generated its end of sequence token
+    - "stop_sequence": the model generated a text included in `stop_sequences`
+
+    Additionally detects tool calls from delta.tool_calls or delta.tool_call_id.
+
+    :param choice: The HuggingFace ChatCompletionStreamOutputChoice object.
+    :returns: The corresponding Haystack FinishReason or None.
+    """
+    if choice.finish_reason is None:
+        return None
+
+    # Check if this choice contains tool call information
+    has_tool_calls = choice.delta.tool_calls is not None or choice.delta.tool_call_id is not None
+
+    # If we detect tool calls, override the finish reason
+    if has_tool_calls:
+        return "tool_calls"
+
+    # Map HuggingFace finish reasons to Haystack standard ones
+    mapping: Dict[str, FinishReason] = {
+        "length": "length",  # Direct match
+        "eos_token": "stop",  # EOS token means natural stop
+        "stop_sequence": "stop",  # Stop sequence means natural stop
+    }
+
+    return mapping.get(choice.finish_reason, "stop")  # Default to "stop" for unknown reasons
+
+
 def _convert_chat_completion_stream_output_to_streaming_chunk(
     chunk: "ChatCompletionStreamOutput",
     previous_chunks: List[StreamingChunk],
@@ -133,6 +172,7 @@ def _convert_chat_completion_stream_output_to_streaming_chunk(
     # the argument is probably allowed for compatibility with OpenAI
     # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
     choice = chunk.choices[0]
+    mapped_finish_reason = _map_hf_finish_reason_to_haystack(choice) if choice.finish_reason else None
     stream_chunk = StreamingChunk(
         content=choice.delta.content or "",
         meta={"model": chunk.model, "received_at": datetime.now().isoformat(), "finish_reason": choice.finish_reason},
@@ -141,6 +181,7 @@ def _convert_chat_completion_stream_output_to_streaming_chunk(
         index=0 if choice.finish_reason is None else None,
         # start is True at the very beginning since first chunk contains role information + first part of the answer.
         start=len(previous_chunks) == 0,
+        finish_reason=mapped_finish_reason,
     )
     return stream_chunk
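For illustration, a minimal sketch (not part of the commit) of how `_map_hf_finish_reason_to_haystack` behaves, using `SimpleNamespace` stand-ins for the `huggingface_hub` choice objects; only the attributes the helper actually reads are modeled:

```python
from types import SimpleNamespace

from haystack.components.generators.chat.hugging_face_api import _map_hf_finish_reason_to_haystack

def fake_choice(finish_reason, tool_calls=None, tool_call_id=None):
    # Hypothetical stand-in for ChatCompletionStreamOutputChoice and its delta.
    delta = SimpleNamespace(tool_calls=tool_calls, tool_call_id=tool_call_id)
    return SimpleNamespace(finish_reason=finish_reason, delta=delta)

# Both HF "natural stop" reasons normalize to the OpenAI-style "stop".
assert _map_hf_finish_reason_to_haystack(fake_choice("eos_token")) == "stop"
assert _map_hf_finish_reason_to_haystack(fake_choice("stop_sequence")) == "stop"
assert _map_hf_finish_reason_to_haystack(fake_choice("length")) == "length"
# Any tool-call payload in the delta overrides the raw reason.
assert _map_hf_finish_reason_to_haystack(fake_choice("eos_token", tool_call_id="abc")) == "tool_calls"
# Mid-stream chunks carry no finish reason yet.
assert _map_hf_finish_reason_to_haystack(fake_choice(None)) is None
```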

haystack/components/generators/chat/openai.py (12 additions, 1 deletion)

@@ -18,6 +18,7 @@
     AsyncStreamingCallbackT,
     ChatMessage,
     ComponentInfo,
+    FinishReason,
     StreamingCallbackT,
     StreamingChunk,
     SyncStreamingCallbackT,
@@ -517,8 +518,15 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
         generated the chunk, such as the component name and type.

     :returns:
-        A list of StreamingChunk objects representing the content of the chunk from the OpenAI API.
+        A StreamingChunk object representing the content of the chunk from the OpenAI API.
     """
+    finish_reason_mapping: Dict[str, FinishReason] = {
+        "stop": "stop",
+        "length": "length",
+        "content_filter": "content_filter",
+        "tool_calls": "tool_calls",
+        "function_call": "tool_calls",
+    }
     # On very first chunk so len(previous_chunks) == 0, the Choices field only provides role info (e.g. "assistant")
     # Choices is empty if include_usage is set to True where the usage information is returned.
     if len(chunk.choices) == 0:
@@ -527,6 +535,7 @@
             component_info=component_info,
             # Index is None since it's only set to an int when a content block is present
             index=None,
+            finish_reason=None,
             meta={
                 "model": chunk.model,
                 "received_at": datetime.now().isoformat(),
@@ -556,6 +565,7 @@
             index=tool_calls_deltas[0].index,
             tool_calls=tool_calls_deltas,
             start=tool_calls_deltas[0].tool_name is not None,
+            finish_reason=finish_reason_mapping.get(choice.finish_reason) if choice.finish_reason else None,
             meta={
                 "model": chunk.model,
                 "index": choice.index,
@@ -584,6 +594,7 @@
         # The first chunk is always a start message chunk that only contains role information, so if we reach here
         # and previous_chunks is length 1 then this is the start of text content.
         start=len(previous_chunks) == 1,
+        finish_reason=finish_reason_mapping.get(choice.finish_reason) if choice.finish_reason else None,
         meta={
             "model": chunk.model,
             "index": choice.index,

haystack/components/generators/hugging_face_api.py (15 additions, 1 deletion)

@@ -9,6 +9,7 @@
 from haystack import component, default_from_dict, default_to_dict
 from haystack.dataclasses import (
     ComponentInfo,
+    FinishReason,
     StreamingCallbackT,
     StreamingChunk,
     SyncStreamingCallbackT,
@@ -241,8 +242,21 @@ def _stream_and_build_response(
             if first_chunk_time is None:
                 first_chunk_time = datetime.now().isoformat()

+            mapping: Dict[str, FinishReason] = {
+                "length": "length",  # Direct match
+                "eos_token": "stop",  # EOS token means natural stop
+                "stop_sequence": "stop",  # Stop sequence means natural stop
+            }
+            mapped_finish_reason = (
+                mapping.get(chunk_metadata["finish_reason"], "stop") if chunk_metadata.get("finish_reason") else None
+            )
             stream_chunk = StreamingChunk(
-                content=token.text, meta=chunk_metadata, component_info=component_info, index=0, start=len(chunks) == 0
+                content=token.text,
+                meta=chunk_metadata,
+                component_info=component_info,
+                index=0,
+                start=len(chunks) == 0,
+                finish_reason=mapped_finish_reason,
             )
             chunks.append(stream_chunk)
             streaming_callback(stream_chunk)
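Unlike the chat path, this generator reads the reason out of `chunk_metadata`. A sketch (with made-up metadata dicts) of the fallback behavior:

```python
# Same expression as in the diff above, pulled out for illustration.
mapping = {"length": "length", "eos_token": "stop", "stop_sequence": "stop"}

def map_meta(chunk_metadata: dict):
    return mapping.get(chunk_metadata["finish_reason"], "stop") if chunk_metadata.get("finish_reason") else None

assert map_meta({"finish_reason": "length"}) == "length"
assert map_meta({"finish_reason": "some_future_reason"}) == "stop"  # unknown reasons default to "stop"
assert map_meta({}) is None  # token chunks without a finish_reason map to None
```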

haystack/components/generators/utils.py (2 additions, 4 deletions)

@@ -65,7 +65,7 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None:

     # End of LLM assistant message so we add two new lines
     # This ensures spacing between multiple LLM messages (e.g. Agent) or multiple Tool Call Results
-    if chunk.meta.get("finish_reason") is not None:
+    if chunk.finish_reason is not None:
         print("\n\n", flush=True, end="")


@@ -121,9 +121,7 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> C
     )

     # finish_reason can appear in different places so we look for the last one
-    finish_reasons = [
-        chunk.meta.get("finish_reason") for chunk in chunks if chunk.meta.get("finish_reason") is not None
-    ]
+    finish_reasons = [chunk.finish_reason for chunk in chunks if chunk.finish_reason]
     finish_reason = finish_reasons[-1] if finish_reasons else None

     meta = {
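A sketch of why the "look for the last one" rule matters: most chunks in a stream carry `finish_reason=None`, and only the closing chunk sets it:

```python
from haystack.dataclasses import StreamingChunk

chunks = [
    StreamingChunk(content="Hel", start=True),
    StreamingChunk(content="lo"),
    StreamingChunk(content="", finish_reason="stop"),  # final chunk carries the reason
]

finish_reasons = [chunk.finish_reason for chunk in chunks if chunk.finish_reason]
assert (finish_reasons[-1] if finish_reasons else None) == "stop"
```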

haystack/components/tools/tool_invoker.py (10 additions, 2 deletions)

@@ -553,7 +553,11 @@ def run(

         # We stream one more chunk that contains a finish_reason if tool_messages were generated
         if len(tool_messages) > 0 and streaming_callback is not None:
-            streaming_callback(StreamingChunk(content="", meta={"finish_reason": "tool_call_results"}))
+            streaming_callback(
+                StreamingChunk(
+                    content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"}
+                )
+            )

         return {"tool_messages": tool_messages, "state": state}

@@ -685,7 +689,11 @@ async def run_async(

         # We stream one more chunk that contains a finish_reason if tool_messages were generated
         if len(tool_messages) > 0 and streaming_callback is not None:
-            await streaming_callback(StreamingChunk(content="", meta={"finish_reason": "tool_call_results"}))
+            await streaming_callback(
+                StreamingChunk(
+                    content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"}
+                )
+            )

         return {"tool_messages": tool_messages, "state": state}
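The value is set both on the new typed field and in `meta`, so existing consumers keep working. A sketch of a user-side streaming callback that reacts to the Haystack-specific reason:

```python
from haystack.dataclasses import StreamingChunk

def on_chunk(chunk: StreamingChunk) -> None:
    print(chunk.content, end="", flush=True)
    if chunk.finish_reason == "tool_call_results":  # new typed field
        print("\n[tool execution finished]")
    # Older code reading chunk.meta.get("finish_reason") is unaffected.
```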

haystack/dataclasses/__init__.py (2 additions, 0 deletions)

@@ -17,6 +17,7 @@
     "streaming_chunk": [
         "AsyncStreamingCallbackT",
         "ComponentInfo",
+        "FinishReason",
         "StreamingCallbackT",
         "StreamingChunk",
         "SyncStreamingCallbackT",
@@ -40,6 +41,7 @@
 from .state import State as State
 from .streaming_chunk import AsyncStreamingCallbackT as AsyncStreamingCallbackT
 from .streaming_chunk import ComponentInfo as ComponentInfo
+from .streaming_chunk import FinishReason as FinishReason
 from .streaming_chunk import StreamingCallbackT as StreamingCallbackT
 from .streaming_chunk import StreamingChunk as StreamingChunk
 from .streaming_chunk import SyncStreamingCallbackT as SyncStreamingCallbackT

haystack/dataclasses/streaming_chunk.py (8 additions, 0 deletions)

@@ -9,6 +9,10 @@
 from haystack.dataclasses.chat_message import ToolCallResult
 from haystack.utils.asynchronous import is_callable_async_compatible

+# Type alias for standard finish_reason values following OpenAI's convention
+# plus Haystack-specific value ("tool_call_results")
+FinishReason = Literal["stop", "length", "tool_calls", "content_filter", "tool_call_results"]
+

 @dataclass
 class ToolCallDelta:
@@ -77,6 +81,9 @@ class StreamingChunk:
         chunk.
     :param tool_call_result: An optional ToolCallResult object representing the result of a tool call.
     :param start: A boolean indicating whether this chunk marks the start of a content block.
+    :param finish_reason: An optional value indicating the reason the generation finished.
+        Standard values follow OpenAI's convention: "stop", "length", "tool_calls", "content_filter",
+        plus Haystack-specific value "tool_call_results".
     """

     content: str
@@ -86,6 +93,7 @@
     tool_calls: Optional[List[ToolCallDelta]] = field(default=None)
     tool_call_result: Optional[ToolCallResult] = field(default=None)
     start: bool = field(default=False)
+    finish_reason: Optional[FinishReason] = field(default=None)

     def __post_init__(self):
         fields_set = sum(bool(x) for x in (self.content, self.tool_calls, self.tool_call_result))
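A short sketch of the new field in use; because `FinishReason` is a `Literal`, type checkers reject any string outside the five allowed values:

```python
from haystack.dataclasses import FinishReason, StreamingChunk

reason: FinishReason = "stop"  # e.g. "eos_token" here would fail mypy: not in the Literal
chunk = StreamingChunk(content="", finish_reason=reason)
assert chunk.finish_reason == "stop"
```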
Release note (reno YAML, new file: 6 additions, 0 deletions)

@@ -0,0 +1,6 @@
+---
+features:
+  - |
+    Added dedicated `finish_reason` field to `StreamingChunk` class to improve type safety and enable sophisticated streaming UI logic. The field uses a `FinishReason` type alias with standard values: "stop", "length", "tool_calls", "content_filter", plus Haystack-specific value "tool_call_results" (used by ToolInvoker to indicate tool execution completion).
+  - |
+    Updated `ToolInvoker` component to use the new `finish_reason` field when streaming tool results. The component now sets `finish_reason="tool_call_results"` in the final streaming chunk to indicate that tool execution has completed, while maintaining backward compatibility by also setting the value in `meta["finish_reason"]`.

test/components/generators/chat/test_hugging_face_api.py (1 addition, 0 deletions)

@@ -711,6 +711,7 @@ def test_convert_hfapi_tool_calls_invalid_type_arguments(self):
                     "model": "microsoft/Phi-3.5-mini-instruct",
                     "finish_reason": "stop",
                 },
+                finish_reason="stop",
             ),
             [0],
         ),

test/components/generators/chat/test_openai.py (2 additions, 1 deletion)

@@ -1143,6 +1143,7 @@ def streaming_chunks():
                 "received_at": ANY,
                 "usage": None,
             },
+            finish_reason="tool_calls",
         ),
         StreamingChunk(
             content="",
@@ -1174,7 +1175,7 @@ def test_convert_chat_completion_chunk_to_streaming_chunk(self, chat_completion_
             chunk=openai_chunk, previous_chunks=previous_chunks
         )
         assert stream_chunk == haystack_chunk
-        previous_chunks.append(openai_chunk)
+        previous_chunks.append(stream_chunk)

     def test_handle_stream_response(self, chat_completion_chunks):
         openai_chunks = chat_completion_chunks
