diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py
index c12a5192bd..6526fdad90 100644
--- a/haystack/components/generators/chat/hugging_face_api.py
+++ b/haystack/components/generators/chat/hugging_face_api.py
@@ -40,6 +40,7 @@
     ChatCompletionInputStreamOptions,
     ChatCompletionInputTool,
     ChatCompletionOutput,
+    ChatCompletionOutputComplete,
     ChatCompletionOutputToolCall,
     ChatCompletionStreamOutput,
     ChatCompletionStreamOutputChoice,
@@ -112,7 +113,9 @@ def _convert_tools_to_hfapi_tools(
     return hf_tools


-def _map_hf_finish_reason_to_haystack(choice: "ChatCompletionStreamOutputChoice") -> Optional[FinishReason]:
+def _map_hf_finish_reason_to_haystack(
+    choice: Union["ChatCompletionStreamOutputChoice", "ChatCompletionOutputComplete"],
+) -> Optional[FinishReason]:
     """
     Map HuggingFace finish reasons to Haystack FinishReason literals.

@@ -133,7 +136,10 @@ def _map_hf_finish_reason_to_haystack(choice: "ChatCompletionStreamOutputChoice"
         return None

     # Check if this choice contains tool call information
-    has_tool_calls = choice.delta.tool_calls is not None or choice.delta.tool_call_id is not None
+    if isinstance(choice, ChatCompletionStreamOutputChoice):
+        has_tool_calls = choice.delta.tool_calls is not None or choice.delta.tool_call_id is not None
+    else:
+        has_tool_calls = choice.message.tool_calls is not None or choice.message.tool_call_id is not None

     # If we detect tool calls, override the finish reason
     if has_tool_calls:
@@ -565,9 +571,10 @@ def _run_non_streaming(

         tool_calls = _convert_hfapi_tool_calls(choice.message.tool_calls)

+        mapped_finish_reason = _map_hf_finish_reason_to_haystack(choice) if choice.finish_reason else None
         meta: dict[str, Any] = {
             "model": self._client.model,
-            "finish_reason": choice.finish_reason,
+            "finish_reason": mapped_finish_reason,
             "index": choice.index,
         }

@@ -629,9 +636,10 @@ async def _run_non_streaming_async(

         tool_calls = _convert_hfapi_tool_calls(choice.message.tool_calls)

+        mapped_finish_reason = _map_hf_finish_reason_to_haystack(choice) if choice.finish_reason else None
         meta: dict[str, Any] = {
             "model": self._async_client.model,
-            "finish_reason": choice.finish_reason,
+            "finish_reason": mapped_finish_reason,
             "index": choice.index,
         }

diff --git a/releasenotes/notes/update-finish-reason-hf-api-chat-gen-c700042a079733e8.yaml b/releasenotes/notes/update-finish-reason-hf-api-chat-gen-c700042a079733e8.yaml
new file mode 100644
index 0000000000..54c0438ddc
--- /dev/null
+++ b/releasenotes/notes/update-finish-reason-hf-api-chat-gen-c700042a079733e8.yaml
@@ -0,0 +1,18 @@
+---
+upgrade:
+  - |
+    The `finish_reason` field behavior in `HuggingFaceAPIChatGenerator` has been
+    updated. Previously, the new `finish_reason` mapping (introduced in the Haystack 2.15.0 release) was applied only when streaming was enabled. When streaming was disabled,
+    the old `finish_reason` was still returned. This change ensures the updated
+    `finish_reason` values are consistently returned regardless of streaming mode.
+
+    **How to know if you're affected:**
+    If you rely on `finish_reason` in responses from `HuggingFaceAPIChatGenerator`
+    with streaming disabled, you may see different values after this upgrade.
+
+    **What to do:**
+    Review the updated mapping:
+    - `length` → `length`
+    - `eos_token` → `stop`
+    - `stop_sequence` → `stop`
+    - If tool calls are present → `tool_calls`
diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py
index 7f6e844aaa..f61d9bf01c 100644
--- a/test/components/generators/chat/test_hugging_face_api.py
+++ b/test/components/generators/chat/test_hugging_face_api.py
@@ -587,7 +587,7 @@ def test_run_with_tools(self, mock_check_valid_model, tools):
         assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"}
         assert response["replies"][0].tool_calls[0].id == "0"
         assert response["replies"][0].meta == {
-            "finish_reason": "stop",
+            "finish_reason": "tool_calls",
             "index": 0,
             "model": "meta-llama/Llama-3.1-70B-Instruct",
             "usage": {"completion_tokens": 30, "prompt_tokens": 426},
@@ -1040,7 +1040,7 @@ async def test_run_async_with_tools(self, tools, mock_check_valid_model):
         assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"}
         assert response["replies"][0].tool_calls[0].id == "0"
         assert response["replies"][0].meta == {
-            "finish_reason": "stop",
+            "finish_reason": "tool_calls",
             "index": 0,
             "model": "meta-llama/Llama-3.1-70B-Instruct",
             "usage": {"completion_tokens": 30, "prompt_tokens": 426},
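
For reviewers, a minimal stand-alone sketch of the mapping this diff makes consistent across streaming and non-streaming modes. The names `map_finish_reason` and `FINISH_REASON_MAPPING` are hypothetical stand-ins for illustration only; the real logic lives in `_map_hf_finish_reason_to_haystack` above.

```python
from typing import Optional

# Hypothetical sketch of the mapping described in the release note;
# not the module's internal implementation.
FINISH_REASON_MAPPING = {"length": "length", "eos_token": "stop", "stop_sequence": "stop"}


def map_finish_reason(hf_reason: Optional[str], has_tool_calls: bool) -> Optional[str]:
    if hf_reason is None:
        return None
    if has_tool_calls:
        # Tool calls override the raw finish reason, mirroring the
        # has_tool_calls branch in _map_hf_finish_reason_to_haystack.
        return "tool_calls"
    return FINISH_REASON_MAPPING.get(hf_reason, hf_reason)


# After this change, both streaming and non-streaming paths yield these values:
assert map_finish_reason("eos_token", has_tool_calls=False) == "stop"
assert map_finish_reason("stop_sequence", has_tool_calls=False) == "stop"
assert map_finish_reason("length", has_tool_calls=False) == "length"
assert map_finish_reason("stop_sequence", has_tool_calls=True) == "tool_calls"
```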