Skip to content

Commit ae6e38b

Browse files
authored
Update tools param to ToolsType (#2438)
1 parent d636a9e commit ae6e38b

3 files changed

Lines changed: 68 additions & 15 deletions

File tree

integrations/llama_cpp/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ classifiers = [
2626
"Programming Language :: Python :: Implementation :: CPython",
2727
"Programming Language :: Python :: Implementation :: PyPy",
2828
]
29-
dependencies = ["haystack-ai>=2.16.1", "llama-cpp-python>=0.2.87"]
29+
dependencies = ["haystack-ai>=2.19.0", "llama-cpp-python>=0.2.87"]
3030

3131
# On macOS GitHub runners, we use a custom index to download pre-built wheels.
3232
# Installing from source might fail due to missing dependencies (CMake fails with "OpenMP not found")

integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
)
1717
from haystack.dataclasses.streaming_chunk import FinishReason, StreamingChunk, SyncStreamingCallbackT
1818
from haystack.tools import (
19-
Tool,
20-
Toolset,
19+
ToolsType,
2120
_check_duplicate_tool_names,
2221
deserialize_tools_or_toolset_inplace,
22+
flatten_tools_or_toolsets,
2323
serialize_tools_or_toolset,
2424
)
2525
from haystack.utils import deserialize_callable, serialize_callable
@@ -196,7 +196,7 @@ def __init__(
196196
model_kwargs: Optional[Dict[str, Any]] = None,
197197
generation_kwargs: Optional[Dict[str, Any]] = None,
198198
*,
199-
tools: Optional[Union[List[Tool], Toolset]] = None,
199+
tools: Optional[ToolsType] = None,
200200
streaming_callback: Optional[StreamingCallbackT] = None,
201201
chat_handler_name: Optional[str] = None,
202202
model_clip_path: Optional[str] = None,
@@ -215,8 +215,8 @@ def __init__(
215215
For more information on the available kwargs, see
216216
[llama.cpp documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion).
217217
:param tools:
218-
A list of tools or a Toolset for which the model can prepare calls.
219-
This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
218+
A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
219+
Each tool should have a unique name.
220220
:param streaming_callback: A callback function that is called when a new token is received from the stream.
221221
:param chat_handler_name: Name of the chat handler for multimodal models.
222222
Common options include: "Llava16ChatHandler", "MoondreamChatHandler", "Qwen25VLChatHandler".
@@ -235,7 +235,7 @@ def __init__(
235235
model_kwargs.setdefault("n_ctx", n_ctx)
236236
model_kwargs.setdefault("n_batch", n_batch)
237237

238-
_check_duplicate_tool_names(list(tools or []))
238+
_check_duplicate_tool_names(flatten_tools_or_toolsets(tools))
239239

240240
handler: Optional[Llava15ChatHandler] = None
241241
# Validate multimodal requirements
@@ -325,7 +325,7 @@ def run(
325325
messages: List[ChatMessage],
326326
generation_kwargs: Optional[Dict[str, Any]] = None,
327327
*,
328-
tools: Optional[Union[List[Tool], Toolset]] = None,
328+
tools: Optional[ToolsType] = None,
329329
streaming_callback: Optional[StreamingCallbackT] = None,
330330
) -> Dict[str, List[ChatMessage]]:
331331
"""
@@ -337,8 +337,9 @@ def run(
337337
For more information on the available kwargs, see
338338
[llama.cpp documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion).
339339
:param tools:
340-
A list of tools or a Toolset for which the model can prepare calls. If set, it will override the `tools`
341-
parameter set during component initialization.
340+
A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
341+
Each tool should have a unique name. If set, it will override the `tools` parameter set during
342+
component initialization.
342343
:param streaming_callback: A callback function that is called when a new token is received from the stream.
343344
If set, it will override the `streaming_callback` parameter set during component initialization.
344345
:returns: A dictionary with the following keys:
@@ -355,13 +356,12 @@ def run(
355356
formatted_messages = [_convert_message_to_llamacpp_format(msg) for msg in messages]
356357

357358
tools = tools or self.tools
358-
if isinstance(tools, Toolset):
359-
tools = list(tools)
360-
_check_duplicate_tool_names(tools)
359+
flattened_tools = flatten_tools_or_toolsets(tools)
360+
_check_duplicate_tool_names(flattened_tools)
361361

362362
llamacpp_tools: List[ChatCompletionTool] = []
363-
if tools:
364-
for t in tools:
363+
if flattened_tools:
364+
for t in flattened_tools:
365365
llamacpp_tools.append(
366366
{
367367
"type": "function",

integrations/llama_cpp/tests/test_chat_generator.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,59 @@ def test_init_with_toolset(self, temperature_tool):
723723
generator = LlamaCppChatGenerator(model="test_model.gguf", tools=toolset)
724724
assert generator.tools == toolset
725725

726+
def test_init_with_mixed_tools(self, temperature_tool):
727+
"""Test initialization with mixed Tool and Toolset objects."""
728+
729+
def population(city: str):
730+
"""Get population for a given city."""
731+
return f"The population of {city} is 2.2 million"
732+
733+
population_tool = create_tool_from_function(population)
734+
toolset = Toolset([population_tool])
735+
736+
generator = LlamaCppChatGenerator(model="test_model.gguf", tools=[temperature_tool, toolset])
737+
assert generator.tools == [temperature_tool, toolset]
738+
739+
def test_run_with_mixed_tools(self, temperature_tool):
740+
"""Test run method with mixed Tool and Toolset objects."""
741+
742+
def population(city: str):
743+
"""Get population for a given city."""
744+
return f"The population of {city} is 2.2 million"
745+
746+
population_tool = create_tool_from_function(population)
747+
toolset = Toolset([population_tool])
748+
749+
generator = LlamaCppChatGenerator(model="test_model.gguf")
750+
751+
# Mock the model
752+
mock_model = MagicMock()
753+
mock_response = {
754+
"choices": [{"message": {"content": "Generated text"}, "index": 0, "finish_reason": "stop"}],
755+
"id": "test_id",
756+
"model": "test_model",
757+
"created": 1234567890,
758+
"usage": {"prompt_tokens": 10, "completion_tokens": 5},
759+
}
760+
mock_model.create_chat_completion.return_value = mock_response
761+
generator._model = mock_model
762+
763+
generator.run(
764+
messages=[ChatMessage.from_user("What's the weather in Paris and population of Berlin?")],
765+
tools=[temperature_tool, toolset],
766+
)
767+
768+
# Verify the model was called with the correct tools
769+
mock_model.create_chat_completion.assert_called_once()
770+
call_args = mock_model.create_chat_completion.call_args[1]
771+
assert "tools" in call_args
772+
assert len(call_args["tools"]) == 2 # Both tools should be flattened
773+
774+
# Verify tool names
775+
tool_names = {tool["function"]["name"] for tool in call_args["tools"]}
776+
assert "get_current_temperature" in tool_names
777+
assert "population" in tool_names
778+
726779
def test_init_with_multimodal_params(self):
727780
"""Test initialization with multimodal parameters."""
728781
generator = LlamaCppChatGenerator(

0 commit comments

Comments
 (0)