Skip to content

Commit af9aac2

Browse files
sjrl and Amnah199 authored
chore!: Update finish reason in output of HuggingFaceAPIChatGenerator to match between stream and non-stream modes (#9686)
* Update finish reason
* Fix unit test
* Add reno
* Update releasenotes/notes/update-finish-reason-hf-api-chat-gen-c700042a079733e8.yaml (Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>)
* Update async as well
* Fix unit test

---------

Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>
1 parent 03d9f0f commit af9aac2

3 files changed

Lines changed: 32 additions & 6 deletions

File tree

haystack/components/generators/chat/hugging_face_api.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
ChatCompletionInputStreamOptions,
4141
ChatCompletionInputTool,
4242
ChatCompletionOutput,
43+
ChatCompletionOutputComplete,
4344
ChatCompletionOutputToolCall,
4445
ChatCompletionStreamOutput,
4546
ChatCompletionStreamOutputChoice,
@@ -112,7 +113,9 @@ def _convert_tools_to_hfapi_tools(
112113
return hf_tools
113114

114115

115-
def _map_hf_finish_reason_to_haystack(choice: "ChatCompletionStreamOutputChoice") -> Optional[FinishReason]:
116+
def _map_hf_finish_reason_to_haystack(
117+
choice: Union["ChatCompletionStreamOutputChoice", "ChatCompletionOutputComplete"],
118+
) -> Optional[FinishReason]:
116119
"""
117120
Map HuggingFace finish reasons to Haystack FinishReason literals.
118121
@@ -133,7 +136,10 @@ def _map_hf_finish_reason_to_haystack(choice: "ChatCompletionStreamOutputChoice"
133136
return None
134137

135138
# Check if this choice contains tool call information
136-
has_tool_calls = choice.delta.tool_calls is not None or choice.delta.tool_call_id is not None
139+
if isinstance(choice, ChatCompletionStreamOutputChoice):
140+
has_tool_calls = choice.delta.tool_calls is not None or choice.delta.tool_call_id is not None
141+
else:
142+
has_tool_calls = choice.message.tool_calls is not None or choice.message.tool_call_id is not None
137143

138144
# If we detect tool calls, override the finish reason
139145
if has_tool_calls:
@@ -565,9 +571,10 @@ def _run_non_streaming(
565571

566572
tool_calls = _convert_hfapi_tool_calls(choice.message.tool_calls)
567573

574+
mapped_finish_reason = _map_hf_finish_reason_to_haystack(choice) if choice.finish_reason else None
568575
meta: dict[str, Any] = {
569576
"model": self._client.model,
570-
"finish_reason": choice.finish_reason,
577+
"finish_reason": mapped_finish_reason,
571578
"index": choice.index,
572579
}
573580

@@ -629,9 +636,10 @@ async def _run_non_streaming_async(
629636

630637
tool_calls = _convert_hfapi_tool_calls(choice.message.tool_calls)
631638

639+
mapped_finish_reason = _map_hf_finish_reason_to_haystack(choice) if choice.finish_reason else None
632640
meta: dict[str, Any] = {
633641
"model": self._async_client.model,
634-
"finish_reason": choice.finish_reason,
642+
"finish_reason": mapped_finish_reason,
635643
"index": choice.index,
636644
}
637645

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
---
2+
upgrade:
3+
- |
4+
The `finish_reason` field behavior in `HuggingFaceAPIChatGenerator` has been
5+
updated. Previously, the new `finish_reason` mapping (introduced in Haystack 2.15.0 release) was only applied when streaming was enabled. When streaming was disabled,
6+
the old `finish_reason` was still returned. This change ensures the updated
7+
`finish_reason` values are consistently returned regardless of streaming mode.
8+
9+
**How to know if you're affected:**
10+
If you rely on `finish_reason` in responses from `HuggingFaceAPIChatGenerator`
11+
with streaming disabled, you may see different values after this upgrade.
12+
13+
**What to do:**
14+
Review the updated mapping:
15+
- `length` → `length`
16+
- `eos_token` → `stop`
17+
- `stop_sequence` → `stop`
18+
- If tool calls are present → `tool_calls`

test/components/generators/chat/test_hugging_face_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ def test_run_with_tools(self, mock_check_valid_model, tools):
587587
assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"}
588588
assert response["replies"][0].tool_calls[0].id == "0"
589589
assert response["replies"][0].meta == {
590-
"finish_reason": "stop",
590+
"finish_reason": "tool_calls",
591591
"index": 0,
592592
"model": "meta-llama/Llama-3.1-70B-Instruct",
593593
"usage": {"completion_tokens": 30, "prompt_tokens": 426},
@@ -1040,7 +1040,7 @@ async def test_run_async_with_tools(self, tools, mock_check_valid_model):
10401040
assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"}
10411041
assert response["replies"][0].tool_calls[0].id == "0"
10421042
assert response["replies"][0].meta == {
1043-
"finish_reason": "stop",
1043+
"finish_reason": "tool_calls",
10441044
"index": 0,
10451045
"model": "meta-llama/Llama-3.1-70B-Instruct",
10461046
"usage": {"completion_tokens": 30, "prompt_tokens": 426},

0 commit comments

Comments (0)