41 changes: 41 additions & 0 deletions haystack/components/generators/chat/hugging_face_api.py
@@ -18,6 +18,7 @@
ToolCall,
select_streaming_callback,
)
from haystack.dataclasses.streaming_chunk import FinishReason
from haystack.lazy_imports import LazyImport
from haystack.tools import (
Tool,
@@ -41,6 +42,7 @@
ChatCompletionOutput,
ChatCompletionOutputToolCall,
ChatCompletionStreamOutput,
ChatCompletionStreamOutputChoice,
InferenceClient,
)

@@ -110,6 +112,43 @@ def _convert_tools_to_hfapi_tools(
return hf_tools


def _map_hf_finish_reason_to_haystack(choice: "ChatCompletionStreamOutputChoice") -> Optional[FinishReason]:
"""
Map HuggingFace finish reasons to Haystack FinishReason literals.

Uses the full choice object to detect tool calls and provide an accurate mapping.

HuggingFace finish reasons (documented at https://huggingface.github.io/text-generation-inference/ under
FinishReason):
- "length": number of generated tokens == `max_new_tokens`
- "eos_token": the model generated its end of sequence token
- "stop_sequence": the model generated a text included in `stop_sequences`

Additionally detects tool calls from delta.tool_calls or delta.tool_call_id.

:param choice: The HuggingFace ChatCompletionStreamOutputChoice object.
:returns: The corresponding Haystack FinishReason or None.
"""
if choice.finish_reason is None:
return None

# Check if this choice contains tool call information
has_tool_calls = choice.delta.tool_calls is not None or choice.delta.tool_call_id is not None

# If we detect tool calls, override the finish reason
if has_tool_calls:
return "tool_calls"

# Map HuggingFace finish reasons to Haystack standard ones
mapping: Dict[str, FinishReason] = {
"length": "length", # Direct match
"eos_token": "stop", # EOS token means natural stop
"stop_sequence": "stop", # Stop sequence means natural stop
}

return mapping.get(choice.finish_reason, "stop") # Default to "stop" for unknown reasons


def _convert_chat_completion_stream_output_to_streaming_chunk(
chunk: "ChatCompletionStreamOutput",
previous_chunks: List[StreamingChunk],
@@ -133,6 +172,7 @@ def _convert_chat_completion_stream_output_to_streaming_chunk(
# the argument is probably allowed for compatibility with OpenAI
# see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
choice = chunk.choices[0]
mapped_finish_reason = _map_hf_finish_reason_to_haystack(choice) if choice.finish_reason else None
stream_chunk = StreamingChunk(
content=choice.delta.content or "",
meta={"model": chunk.model, "received_at": datetime.now().isoformat(), "finish_reason": choice.finish_reason},
@@ -141,6 +181,7 @@
index=0 if choice.finish_reason is None else None,
# start is True at the very beginning since first chunk contains role information + first part of the answer.
start=len(previous_chunks) == 0,
finish_reason=mapped_finish_reason,
)
return stream_chunk

13 changes: 12 additions & 1 deletion haystack/components/generators/chat/openai.py
@@ -18,6 +18,7 @@
AsyncStreamingCallbackT,
ChatMessage,
ComponentInfo,
FinishReason,
StreamingCallbackT,
StreamingChunk,
SyncStreamingCallbackT,
@@ -517,8 +518,15 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
generated the chunk, such as the component name and type.

:returns:
A list of StreamingChunk objects representing the content of the chunk from the OpenAI API.
A StreamingChunk object representing the content of the chunk from the OpenAI API.
"""
finish_reason_mapping: Dict[str, FinishReason] = {
"stop": "stop",
"length": "length",
"content_filter": "content_filter",
"tool_calls": "tool_calls",
"function_call": "tool_calls",
}
# On very first chunk so len(previous_chunks) == 0, the Choices field only provides role info (e.g. "assistant")
# Choices is empty if include_usage is set to True where the usage information is returned.
if len(chunk.choices) == 0:
@@ -527,6 +535,7 @@
component_info=component_info,
# Index is None since it's only set to an int when a content block is present
index=None,
finish_reason=None,
meta={
"model": chunk.model,
"received_at": datetime.now().isoformat(),
@@ -556,6 +565,7 @@
index=tool_calls_deltas[0].index,
tool_calls=tool_calls_deltas,
start=tool_calls_deltas[0].tool_name is not None,
finish_reason=finish_reason_mapping.get(choice.finish_reason) if choice.finish_reason else None,
meta={
"model": chunk.model,
"index": choice.index,
@@ -584,6 +594,7 @@
# The first chunk is always a start message chunk that only contains role information, so if we reach here
# and previous_chunks is length 1 then this is the start of text content.
start=len(previous_chunks) == 1,
finish_reason=finish_reason_mapping.get(choice.finish_reason) if choice.finish_reason else None,
meta={
"model": chunk.model,
"index": choice.index,
16 changes: 15 additions & 1 deletion haystack/components/generators/hugging_face_api.py
@@ -9,6 +9,7 @@
from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import (
ComponentInfo,
FinishReason,
StreamingCallbackT,
StreamingChunk,
SyncStreamingCallbackT,
@@ -241,8 +242,21 @@ def _stream_and_build_response(
if first_chunk_time is None:
first_chunk_time = datetime.now().isoformat()

mapping: Dict[str, FinishReason] = {
"length": "length", # Direct match
"eos_token": "stop", # EOS token means natural stop
"stop_sequence": "stop", # Stop sequence means natural stop
}
mapped_finish_reason = (
mapping.get(chunk_metadata["finish_reason"], "stop") if chunk_metadata.get("finish_reason") else None
)
stream_chunk = StreamingChunk(
content=token.text, meta=chunk_metadata, component_info=component_info, index=0, start=len(chunks) == 0
content=token.text,
meta=chunk_metadata,
component_info=component_info,
index=0,
start=len(chunks) == 0,
finish_reason=mapped_finish_reason,
)
chunks.append(stream_chunk)
streaming_callback(stream_chunk)
6 changes: 2 additions & 4 deletions haystack/components/generators/utils.py
@@ -65,7 +65,7 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None:

# End of LLM assistant message so we add two new lines
# This ensures spacing between multiple LLM messages (e.g. Agent) or multiple Tool Call Results
if chunk.meta.get("finish_reason") is not None:
if chunk.finish_reason is not None:
print("\n\n", flush=True, end="")


@@ -121,9 +121,7 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> C
)

# finish_reason can appear in different places so we look for the last one
finish_reasons = [
chunk.meta.get("finish_reason") for chunk in chunks if chunk.meta.get("finish_reason") is not None
]
finish_reasons = [chunk.finish_reason for chunk in chunks if chunk.finish_reason]
finish_reason = finish_reasons[-1] if finish_reasons else None

meta = {
12 changes: 10 additions & 2 deletions haystack/components/tools/tool_invoker.py
@@ -553,7 +553,11 @@ def run(

# We stream one more chunk that contains a finish_reason if tool_messages were generated
if len(tool_messages) > 0 and streaming_callback is not None:
streaming_callback(StreamingChunk(content="", meta={"finish_reason": "tool_call_results"}))
streaming_callback(
StreamingChunk(
content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"}
)
)

return {"tool_messages": tool_messages, "state": state}

@@ -685,7 +689,11 @@ async def run_async(

# We stream one more chunk that contains a finish_reason if tool_messages were generated
if len(tool_messages) > 0 and streaming_callback is not None:
await streaming_callback(StreamingChunk(content="", meta={"finish_reason": "tool_call_results"}))
await streaming_callback(
StreamingChunk(
content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"}
)
)

return {"tool_messages": tool_messages, "state": state}

2 changes: 2 additions & 0 deletions haystack/dataclasses/__init__.py
@@ -17,6 +17,7 @@
"streaming_chunk": [
"AsyncStreamingCallbackT",
"ComponentInfo",
"FinishReason",
"StreamingCallbackT",
"StreamingChunk",
"SyncStreamingCallbackT",
@@ -40,6 +41,7 @@
from .state import State as State
from .streaming_chunk import AsyncStreamingCallbackT as AsyncStreamingCallbackT
from .streaming_chunk import ComponentInfo as ComponentInfo
from .streaming_chunk import FinishReason as FinishReason
from .streaming_chunk import StreamingCallbackT as StreamingCallbackT
from .streaming_chunk import StreamingChunk as StreamingChunk
from .streaming_chunk import SyncStreamingCallbackT as SyncStreamingCallbackT
8 changes: 8 additions & 0 deletions haystack/dataclasses/streaming_chunk.py
@@ -9,6 +9,10 @@
from haystack.dataclasses.chat_message import ToolCallResult
from haystack.utils.asynchronous import is_callable_async_compatible

# Type alias for standard finish_reason values following OpenAI's convention
# plus Haystack-specific value ("tool_call_results")
FinishReason = Literal["stop", "length", "tool_calls", "content_filter", "tool_call_results"]


@dataclass
class ToolCallDelta:
@@ -77,6 +81,9 @@ class StreamingChunk:
chunk.
:param tool_call_result: An optional ToolCallResult object representing the result of a tool call.
:param start: A boolean indicating whether this chunk marks the start of a content block.
:param finish_reason: An optional value indicating the reason the generation finished.
Standard values follow OpenAI's convention: "stop", "length", "tool_calls", "content_filter",
plus Haystack-specific value "tool_call_results".
"""

content: str
@@ -86,6 +93,7 @@ class StreamingChunk:
tool_calls: Optional[List[ToolCallDelta]] = field(default=None)
tool_call_result: Optional[ToolCallResult] = field(default=None)
start: bool = field(default=False)
finish_reason: Optional[FinishReason] = field(default=None)

def __post_init__(self):
fields_set = sum(bool(x) for x in (self.content, self.tool_calls, self.tool_call_result))
@@ -0,0 +1,6 @@
---
features:
- |
Added a dedicated `finish_reason` field to the `StreamingChunk` class to improve type safety and enable more sophisticated streaming UI logic. The field uses a `FinishReason` type alias with the standard values "stop", "length", "tool_calls", and "content_filter", plus the Haystack-specific value "tool_call_results" (used by `ToolInvoker` to indicate that tool execution has completed).
- |
Updated the `ToolInvoker` component to use the new `finish_reason` field when streaming tool results. The component now sets `finish_reason="tool_call_results"` in the final streaming chunk to signal that tool execution has completed, while preserving backward compatibility by also setting the value in `meta["finish_reason"]`.
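To illustrate how downstream code might consume the new field, here is a minimal sketch of a streaming callback that branches on `chunk.finish_reason` instead of reading `chunk.meta["finish_reason"]`. The callback name and the printed messages are invented for illustration and are not part of this PR.

from haystack.dataclasses import StreamingChunk


def example_streaming_callback(chunk: StreamingChunk) -> None:
    # Print incremental content as it arrives.
    if chunk.content:
        print(chunk.content, end="", flush=True)

    # finish_reason is a typed Optional[FinishReason], so plain equality checks are enough.
    if chunk.finish_reason == "tool_call_results":
        print("\n[tool execution completed]")
    elif chunk.finish_reason is not None:
        print(f"\n[generation finished: {chunk.finish_reason}]")

A callback like this can be passed as `streaming_callback` to any component that emits `StreamingChunk` objects.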
1 change: 1 addition & 0 deletions test/components/generators/chat/test_hugging_face_api.py
@@ -711,6 +711,7 @@ def test_convert_hfapi_tool_calls_invalid_type_arguments(self):
"model": "microsoft/Phi-3.5-mini-instruct",
"finish_reason": "stop",
},
finish_reason="stop",
),
[0],
),
3 changes: 2 additions & 1 deletion test/components/generators/chat/test_openai.py
@@ -1143,6 +1143,7 @@ def streaming_chunks():
"received_at": ANY,
"usage": None,
},
finish_reason="tool_calls",
),
StreamingChunk(
content="",
@@ -1174,7 +1175,7 @@ def test_convert_chat_completion_chunk_to_streaming_chunk(self, chat_completion_
chunk=openai_chunk, previous_chunks=previous_chunks
)
assert stream_chunk == haystack_chunk
previous_chunks.append(openai_chunk)
previous_chunks.append(stream_chunk)

def test_handle_stream_response(self, chat_completion_chunks):
openai_chunks = chat_completion_chunks
1 change: 1 addition & 0 deletions test/components/generators/test_utils.py
@@ -266,6 +266,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
"received_at": "2025-02-19T16:02:55.948772",
},
component_info=ComponentInfo(name="test", type="test"),
finish_reason="tool_calls",
),
StreamingChunk(
content="",
52 changes: 52 additions & 0 deletions test/components/tools/test_tool_invoker.py
@@ -203,6 +203,28 @@ def streaming_callback(chunk: StreamingChunk) -> None:
assert tool_call_result.origin == tool_call
assert not tool_call_result.error

def test_run_with_streaming_callback_finish_reason(self, invoker):
streaming_chunks = []

def streaming_callback(chunk: StreamingChunk) -> None:
streaming_chunks.append(chunk)

tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
message = ChatMessage.from_assistant(tool_calls=[tool_call])

result = invoker.run(messages=[message], streaming_callback=streaming_callback)
assert "tool_messages" in result
assert len(result["tool_messages"]) == 1

# Check that we received streaming chunks
assert len(streaming_chunks) >= 2 # At least one for tool result and one for finish reason

# The last chunk should have finish_reason set to "tool_call_results"
final_chunk = streaming_chunks[-1]
assert final_chunk.finish_reason == "tool_call_results"
assert final_chunk.meta["finish_reason"] == "tool_call_results"
assert final_chunk.content == ""

@pytest.mark.asyncio
async def test_run_async_with_streaming_callback(self, thread_executor, weather_tool):
streaming_callback_called = False
@@ -245,6 +267,36 @@ async def streaming_callback(chunk: StreamingChunk) -> None:
# check we called the streaming callback
assert streaming_callback_called

@pytest.mark.asyncio
async def test_run_async_with_streaming_callback_finish_reason(self, thread_executor, weather_tool):
streaming_chunks = []

async def streaming_callback(chunk: StreamingChunk) -> None:
streaming_chunks.append(chunk)

tool_invoker = ToolInvoker(
tools=[weather_tool],
raise_on_failure=True,
convert_result_to_json_string=False,
async_executor=thread_executor,
)

tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
message = ChatMessage.from_assistant(tool_calls=[tool_call])

result = await tool_invoker.run_async(messages=[message], streaming_callback=streaming_callback)
assert "tool_messages" in result
assert len(result["tool_messages"]) == 1

# Check that we received streaming chunks
assert len(streaming_chunks) >= 2 # At least one for tool result and one for finish reason

# The last chunk should have finish_reason set to "tool_call_results"
final_chunk = streaming_chunks[-1]
assert final_chunk.finish_reason == "tool_call_results"
assert final_chunk.meta["finish_reason"] == "tool_call_results"
assert final_chunk.content == ""

def test_run_with_toolset(self, tool_set):
tool_invoker = ToolInvoker(tools=tool_set, raise_on_failure=True, convert_result_to_json_string=False)
tool_call = ToolCall(tool_name="addition_tool", arguments={"num1": 5, "num2": 3})