Skip to content

Commit ade29fc

Browse files
anakin87sjrl
authored andcommitted
fix: HuggingFaceAPIChatGenerator - make tool conversion compatible with huggingface_hub>=0.31.0 (#9354)
* fix: HuggingFaceAPIChatGenerator - make tool conversion compatible with huggingface_hub>=0.31.0 * relnote
1 parent cff3435 commit ade29fc

3 files changed

Lines changed: 83 additions & 28 deletions

File tree

haystack/components/generators/chat/hugging_face_api.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,26 @@ def _convert_hfapi_tool_calls(hfapi_tool_calls: Optional[List["ChatCompletionOut
8080
return tool_calls
8181

8282

83+
def _convert_tools_to_hfapi_tools(
84+
tools: Optional[Union[List[Tool], Toolset]],
85+
) -> Optional[List["ChatCompletionInputTool"]]:
86+
if not tools:
87+
return None
88+
89+
# huggingface_hub<0.31.0 uses "arguments", huggingface_hub>=0.31.0 uses "parameters"
90+
parameters_name = "arguments" if hasattr(ChatCompletionInputFunctionDefinition, "arguments") else "parameters"
91+
92+
hf_tools = []
93+
for tool in tools:
94+
hf_tools_args = {"name": tool.name, "description": tool.description, parameters_name: tool.parameters}
95+
96+
hf_tools.append(
97+
ChatCompletionInputTool(function=ChatCompletionInputFunctionDefinition(**hf_tools_args), type="function")
98+
)
99+
100+
return hf_tools
101+
102+
83103
@component
84104
class HuggingFaceAPIChatGenerator:
85105
"""
@@ -313,19 +333,11 @@ def run(
313333
if streaming_callback:
314334
return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)
315335

316-
hf_tools = None
317-
if tools:
318-
if isinstance(tools, Toolset):
319-
tools = list(tools)
320-
hf_tools = [
321-
ChatCompletionInputTool(
322-
function=ChatCompletionInputFunctionDefinition(
323-
name=tool.name, description=tool.description, arguments=tool.parameters
324-
),
325-
type="function",
326-
)
327-
for tool in tools
328-
]
336+
if tools and isinstance(tools, Toolset):
337+
tools = list(tools)
338+
339+
hf_tools = _convert_tools_to_hfapi_tools(tools)
340+
329341
return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)
330342

331343
@component.output_types(replies=List[ChatMessage])
@@ -373,19 +385,11 @@ async def run_async(
373385
if streaming_callback:
374386
return await self._run_streaming_async(formatted_messages, generation_kwargs, streaming_callback)
375387

376-
hf_tools = None
377-
if tools:
378-
if isinstance(tools, Toolset):
379-
tools = list(tools)
380-
hf_tools = [
381-
ChatCompletionInputTool(
382-
function=ChatCompletionInputFunctionDefinition(
383-
name=tool.name, description=tool.description, arguments=tool.parameters
384-
),
385-
type="function",
386-
)
387-
for tool in tools
388-
]
388+
if tools and isinstance(tools, Toolset):
389+
tools = list(tools)
390+
391+
hf_tools = _convert_tools_to_hfapi_tools(tools)
392+
389393
return await self._run_non_streaming_async(formatted_messages, generation_kwargs, hf_tools)
390394

391395
def _run_streaming(
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
fixes:
3+
- |
4+
Make internal tool conversion in the HuggingFaceAPICompatibleChatGenerator compatible with huggingface_hub>=0.31.0.
5+
In the huggingface_hub library, `arguments` attribute of `ChatCompletionInputFunctionDefinition` has been renamed to
6+
`parameters`.
7+
Our implementation is compatible with both the legacy version and the new one.

test/components/generators/chat/test_hugging_face_api.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from haystack.dataclasses import StreamingChunk
1111
from haystack.utils.auth import Secret
1212
from haystack.utils.hf import HFGenerationAPIType
13+
1314
from huggingface_hub import (
1415
ChatCompletionOutput,
1516
ChatCompletionOutputComplete,
@@ -21,9 +22,14 @@
2122
ChatCompletionStreamOutputChoice,
2223
ChatCompletionStreamOutputDelta,
2324
)
24-
from huggingface_hub.utils import RepositoryNotFoundError
25+
from huggingface_hub.errors import RepositoryNotFoundError
26+
27+
from haystack.components.generators.chat.hugging_face_api import (
28+
HuggingFaceAPIChatGenerator,
29+
_convert_hfapi_tool_calls,
30+
_convert_tools_to_hfapi_tools,
31+
)
2532

26-
from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator, _convert_hfapi_tool_calls
2733
from haystack.tools import Tool
2834
from haystack.dataclasses import ChatMessage, ToolCall
2935
from haystack.tools.toolset import Toolset
@@ -976,3 +982,41 @@ def test_to_dict_with_toolset(self, mock_check_valid_model, tools):
976982
},
977983
}
978984
assert data["init_parameters"]["tools"] == expected_tools_data
985+
986+
def test_convert_tools_to_hfapi_tools(self):
987+
assert _convert_tools_to_hfapi_tools(None) is None
988+
assert _convert_tools_to_hfapi_tools([]) is None
989+
990+
tool = Tool(
991+
name="weather",
992+
description="useful to determine the weather in a given location",
993+
parameters={"city": {"type": "string"}},
994+
function=get_weather,
995+
)
996+
hf_tools = _convert_tools_to_hfapi_tools([tool])
997+
assert len(hf_tools) == 1
998+
assert hf_tools[0].type == "function"
999+
assert hf_tools[0].function.name == "weather"
1000+
assert hf_tools[0].function.description == "useful to determine the weather in a given location"
1001+
assert hf_tools[0].function.parameters == {"city": {"type": "string"}}
1002+
1003+
def test_convert_tools_to_hfapi_tools_legacy(self):
1004+
# this satisfies the check hasattr(ChatCompletionInputFunctionDefinition, "arguments")
1005+
mock_class = MagicMock()
1006+
1007+
with patch(
1008+
"haystack.components.generators.chat.hugging_face_api.ChatCompletionInputFunctionDefinition", mock_class
1009+
):
1010+
tool = Tool(
1011+
name="weather",
1012+
description="useful to determine the weather in a given location",
1013+
parameters={"city": {"type": "string"}},
1014+
function=get_weather,
1015+
)
1016+
_convert_tools_to_hfapi_tools([tool])
1017+
1018+
mock_class.assert_called_once_with(
1019+
name="weather",
1020+
arguments={"city": {"type": "string"}},
1021+
description="useful to determine the weather in a given location",
1022+
)

0 commit comments

Comments
 (0)