Skip to content

Commit cff3435

Browse files
YassinNouh21 authored and anakin87 committed
fix: make HuggingFaceAPIChatGenerator convert Tool Call arguments from string (#9303)
* fix: sort imports in hugging_face_api.py
* fix: import logging in hugging_face_api.py
* fix: refactor HuggingFace API tool call handling for improved argument conversion
* Update haystack/components/generators/chat/hugging_face_api.py

  Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
* refinements + tests + relnote
* simplify

---------

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
1 parent e29e882 commit cff3435

3 files changed

Lines changed: 126 additions & 19 deletions

File tree

haystack/components/generators/chat/hugging_face_api.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import json
56
from datetime import datetime
67
from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Union
78

8-
from haystack import component, default_from_dict, default_to_dict
9+
from haystack import component, default_from_dict, default_to_dict, logging
910
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
1011
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
1112
from haystack.lazy_imports import LazyImport
@@ -20,17 +21,65 @@
2021
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
2122
from haystack.utils.url_validation import is_valid_http_url
2223

24+
logger = logging.getLogger(__name__)
25+
2326
with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
2427
from huggingface_hub import (
2528
AsyncInferenceClient,
2629
ChatCompletionInputFunctionDefinition,
2730
ChatCompletionInputTool,
2831
ChatCompletionOutput,
32+
ChatCompletionOutputToolCall,
2933
ChatCompletionStreamOutput,
3034
InferenceClient,
3135
)
3236

3337

38+
def _convert_hfapi_tool_calls(hfapi_tool_calls: Optional[List["ChatCompletionOutputToolCall"]]) -> List[ToolCall]:
    """
    Convert HuggingFace API tool calls to a list of Haystack ToolCall.

    The arguments of each tool call are accepted either as a dict or as a JSON string that decodes to a dict.
    An empty dict is valid (a tool that takes no parameters) and is kept. Any tool call whose arguments cannot
    be turned into a dict is skipped and a warning is logged; no exception is raised.

    :param hfapi_tool_calls: The HuggingFace API tool calls to convert. May be `None` or empty.
    :returns: A list of ToolCall objects, in the same order as the valid input tool calls.
    """
    if not hfapi_tool_calls:
        return []

    tool_calls = []

    for hfapi_tc in hfapi_tool_calls:
        hf_arguments = hfapi_tc.function.arguments

        if isinstance(hf_arguments, dict):
            arguments = hf_arguments
        elif isinstance(hf_arguments, str):
            try:
                arguments = json.loads(hf_arguments)
            except json.JSONDecodeError:
                logger.warning(
                    "HuggingFace API returned a malformed JSON string for tool call arguments. This tool call "
                    "will be skipped. Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
                    _id=hfapi_tc.id,
                    _name=hfapi_tc.function.name,
                    _arguments=hf_arguments,
                )
                continue

            # Valid JSON is not necessarily an object: "[1, 2]", "5" or "null" all decode successfully
            # but cannot be used as keyword arguments for a tool.
            if not isinstance(arguments, dict):
                logger.warning(
                    "HuggingFace API returned a JSON string for tool call arguments that does not decode to a dict "
                    "(decoded type: {_type}). This tool call will be skipped. Tool call ID: {_id}, "
                    "Tool name: {_name}, Arguments: {_arguments}",
                    _type=type(arguments).__name__,
                    _id=hfapi_tc.id,
                    _name=hfapi_tc.function.name,
                    _arguments=hf_arguments,
                )
                continue
        else:
            logger.warning(
                "HuggingFace API returned tool call arguments of type {_type}. Valid types are dict and str. This tool "
                "call will be skipped. Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
                # The message references {_type}, so it must be supplied for the placeholder to be filled in.
                _type=type(hf_arguments).__name__,
                _id=hfapi_tc.id,
                _name=hfapi_tc.function.name,
                _arguments=hf_arguments,
            )
            continue

        # Reaching this point means `arguments` is a dict; an empty dict is a legitimate call of a
        # parameterless tool, so it must not be filtered out by a truthiness check.
        tool_calls.append(ToolCall(tool_name=hfapi_tc.function.name, arguments=arguments, id=hfapi_tc.id))

    return tool_calls
81+
82+
3483
@component
3584
class HuggingFaceAPIChatGenerator:
3685
"""
@@ -403,14 +452,8 @@ def _run_non_streaming(
403452
choice = api_chat_output.choices[0]
404453

405454
text = choice.message.content
406-
tool_calls = []
407455

408-
if hfapi_tool_calls := choice.message.tool_calls:
409-
for hfapi_tc in hfapi_tool_calls:
410-
tool_call = ToolCall(
411-
tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
412-
)
413-
tool_calls.append(tool_call)
456+
tool_calls = _convert_hfapi_tool_calls(choice.message.tool_calls)
414457

415458
meta: Dict[str, Any] = {
416459
"model": self._client.model,
@@ -486,14 +529,8 @@ async def _run_non_streaming_async(
486529
choice = api_chat_output.choices[0]
487530

488531
text = choice.message.content
489-
tool_calls = []
490532

491-
if hfapi_tool_calls := choice.message.tool_calls:
492-
for hfapi_tc in hfapi_tool_calls:
493-
tool_call = ToolCall(
494-
tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
495-
)
496-
tool_calls.append(tool_call)
533+
tool_calls = _convert_hfapi_tool_calls(choice.message.tool_calls)
497534

498535
meta: Dict[str, Any] = {
499536
"model": self._async_client.model,
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
fixes:
3+
- |
4+
The `HuggingFaceAPIChatGenerator` now checks the type of the `arguments` variable in the tool calls returned by the
5+
Hugging Face API. If `arguments` is a JSON string, it is parsed into a dictionary.
6+
Previously, the `arguments` type was not checked, which sometimes led to failures later in the tool workflow.

test/components/generators/chat/test_hugging_face_api.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
)
2424
from huggingface_hub.utils import RepositoryNotFoundError
2525

26-
from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator
26+
from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator, _convert_hfapi_tool_calls
2727
from haystack.tools import Tool
2828
from haystack.dataclasses import ChatMessage, ToolCall
2929
from haystack.tools.toolset import Toolset
@@ -573,6 +573,73 @@ def test_run_with_tools(self, mock_check_valid_model, tools):
573573
"usage": {"completion_tokens": 30, "prompt_tokens": 426},
574574
}
575575

576+
def test_convert_hfapi_tool_calls_empty(self):
    """Both ``None`` and an empty list must convert to an empty result."""
    # First the missing value, then the empty list, in the same order as two separate checks would run.
    for empty_input in (None, []):
        converted = _convert_hfapi_tool_calls(empty_input)
        assert len(converted) == 0
584+
585+
def test_convert_hfapi_tool_calls_dict_arguments(self):
    """Arguments that already arrive as a dict are carried over unchanged."""
    function_definition = ChatCompletionOutputFunctionDefinition(
        arguments={"city": "Paris"}, name="weather", description=None
    )
    hfapi_tool_call = ChatCompletionOutputToolCall(function=function_definition, id="0", type="function")

    converted = _convert_hfapi_tool_calls([hfapi_tool_call])

    assert len(converted) == 1
    first = converted[0]
    assert first.tool_name == "weather"
    assert first.arguments == {"city": "Paris"}
    assert first.id == "0"
600+
601+
def test_convert_hfapi_tool_calls_str_arguments(self):
    """Arguments delivered as a JSON string are parsed into a dict."""
    function_definition = ChatCompletionOutputFunctionDefinition(
        arguments='{"city": "Paris"}', name="weather", description=None
    )
    hfapi_tool_call = ChatCompletionOutputToolCall(function=function_definition, id="0", type="function")

    converted = _convert_hfapi_tool_calls([hfapi_tool_call])

    assert len(converted) == 1
    first = converted[0]
    assert first.tool_name == "weather"
    assert first.arguments == {"city": "Paris"}
    assert first.id == "0"
616+
617+
def test_convert_hfapi_tool_calls_invalid_str_arguments(self):
    """A string that is not valid JSON causes the tool call to be dropped."""
    function_definition = ChatCompletionOutputFunctionDefinition(
        arguments="not a valid JSON string", name="weather", description=None
    )
    hfapi_tool_call = ChatCompletionOutputToolCall(function=function_definition, id="0", type="function")

    converted = _convert_hfapi_tool_calls([hfapi_tool_call])

    assert len(converted) == 0
629+
630+
def test_convert_hfapi_tool_calls_invalid_type_arguments(self):
    """Arguments that are neither a dict nor a string cause the tool call to be dropped."""
    function_definition = ChatCompletionOutputFunctionDefinition(
        arguments=["this", "is", "a", "list"], name="weather", description=None
    )
    hfapi_tool_call = ChatCompletionOutputToolCall(function=function_definition, id="0", type="function")

    converted = _convert_hfapi_tool_calls([hfapi_tool_call])

    assert len(converted) == 0
642+
576643
@pytest.mark.integration
577644
@pytest.mark.skipif(
578645
not os.environ.get("HF_API_TOKEN", None),
@@ -639,9 +706,6 @@ def test_live_run_serverless_streaming(self):
639706
not os.environ.get("HF_API_TOKEN", None),
640707
reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
641708
)
642-
@pytest.mark.xfail(
643-
reason="The Hugging Face API can be unstable and this test may fail intermittently", strict=False
644-
)
645709
def test_live_run_with_tools(self, tools):
646710
"""
647711
We test the round trip: generate tool call, pass tool message, generate response.

0 commit comments

Comments
 (0)