|
10 | 10 | from haystack.dataclasses import StreamingChunk |
11 | 11 | from haystack.utils.auth import Secret |
12 | 12 | from haystack.utils.hf import HFGenerationAPIType |
| 13 | + |
13 | 14 | from huggingface_hub import ( |
14 | 15 | ChatCompletionOutput, |
15 | 16 | ChatCompletionOutputComplete, |
|
21 | 22 | ChatCompletionStreamOutputChoice, |
22 | 23 | ChatCompletionStreamOutputDelta, |
23 | 24 | ) |
24 | | -from huggingface_hub.utils import RepositoryNotFoundError |
| 25 | +from huggingface_hub.errors import RepositoryNotFoundError |
| 26 | + |
| 27 | +from haystack.components.generators.chat.hugging_face_api import ( |
| 28 | + HuggingFaceAPIChatGenerator, |
| 29 | + _convert_hfapi_tool_calls, |
| 30 | + _convert_tools_to_hfapi_tools, |
| 31 | +) |
25 | 32 |
|
26 | | -from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator, _convert_hfapi_tool_calls |
27 | 33 | from haystack.tools import Tool |
28 | 34 | from haystack.dataclasses import ChatMessage, ToolCall |
29 | 35 | from haystack.tools.toolset import Toolset |
@@ -976,3 +982,41 @@ def test_to_dict_with_toolset(self, mock_check_valid_model, tools): |
976 | 982 | }, |
977 | 983 | } |
978 | 984 | assert data["init_parameters"]["tools"] == expected_tools_data |
| 985 | + |
| 986 | + def test_convert_tools_to_hfapi_tools(self): |
| 987 | + assert _convert_tools_to_hfapi_tools(None) is None |
| 988 | + assert _convert_tools_to_hfapi_tools([]) is None |
| 989 | + |
| 990 | + tool = Tool( |
| 991 | + name="weather", |
| 992 | + description="useful to determine the weather in a given location", |
| 993 | + parameters={"city": {"type": "string"}}, |
| 994 | + function=get_weather, |
| 995 | + ) |
| 996 | + hf_tools = _convert_tools_to_hfapi_tools([tool]) |
| 997 | + assert len(hf_tools) == 1 |
| 998 | + assert hf_tools[0].type == "function" |
| 999 | + assert hf_tools[0].function.name == "weather" |
| 1000 | + assert hf_tools[0].function.description == "useful to determine the weather in a given location" |
| 1001 | + assert hf_tools[0].function.parameters == {"city": {"type": "string"}} |
| 1002 | + |
| 1003 | + def test_convert_tools_to_hfapi_tools_legacy(self): |
| 1004 | + # this satisfies the check hasattr(ChatCompletionInputFunctionDefinition, "arguments") |
| 1005 | + mock_class = MagicMock() |
| 1006 | + |
| 1007 | + with patch( |
| 1008 | + "haystack.components.generators.chat.hugging_face_api.ChatCompletionInputFunctionDefinition", mock_class |
| 1009 | + ): |
| 1010 | + tool = Tool( |
| 1011 | + name="weather", |
| 1012 | + description="useful to determine the weather in a given location", |
| 1013 | + parameters={"city": {"type": "string"}}, |
| 1014 | + function=get_weather, |
| 1015 | + ) |
| 1016 | + _convert_tools_to_hfapi_tools([tool]) |
| 1017 | + |
| 1018 | + mock_class.assert_called_once_with( |
| 1019 | + name="weather", |
| 1020 | + arguments={"city": {"type": "string"}}, |
| 1021 | + description="useful to determine the weather in a given location", |
| 1022 | + ) |
0 commit comments