diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index 82e78a61f0..d95e87dd85 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -127,6 +127,71 @@ def _get_run_method_params(instance: "Agent") -> set[str]: return {name for name, p in sig.parameters.items() if p.kind != inspect.Parameter.VAR_KEYWORD} +def _select_tools_by_name(configured_tools: ToolsType, names: list[str]) -> list[Tool | Toolset]: + """ + Select configured tools by name for a single run. + + Standalone Tools are kept when their name is requested. A Toolset that exposes a requested name is replaced by a + per-run `spawn()` (an isolated copy) with the requested names registered as its `_selected_tool_names`, so + dynamic toolsets such as SearchableToolset preserve their behavior (search/lazy-loading) over the selected subset + without mutating the shared, configured Toolset. + + :param configured_tools: The tools configured on the Agent. + :param names: The requested tool names. + :returns: The selected standalone Tools and/or spawned, selection-scoped Toolsets. + :raises ValueError: If no tools were configured, or if any requested name is not a valid tool name. + """ + if not configured_tools: + raise ValueError("No tools were configured for the Agent at initialization.") + + requested_names = set(names) + items: list[Tool | Toolset] = ( + [configured_tools] if isinstance(configured_tools, Toolset) else list(configured_tools) + ) + + # Resolve selectable names per item. For Toolsets we use get_selectable_tools() so dynamic toolsets + # (e.g. SearchableToolset) offer their full catalog by name, not just the tools exposed by iteration. 
+ selectable_per_item: list[tuple[Tool | Toolset, set[str]]] = [] + valid_tool_names: set[str] = set() + for item in items: + item_names = {tool.name for tool in item.get_selectable_tools()} if isinstance(item, Toolset) else {item.name} + selectable_per_item.append((item, item_names)) + valid_tool_names |= item_names + + invalid_tool_names = requested_names - valid_tool_names + if invalid_tool_names: + raise ValueError( + f"The following tool names are not valid: {invalid_tool_names}. Valid tool names are: {valid_tool_names}." + ) + + selected: list[Tool | Toolset] = [] + for item, item_names in selectable_per_item: + matched = requested_names & item_names + if not matched: + continue + if isinstance(item, Toolset): + # Apply the selection to a per-run copy so the shared, configured Toolset is never mutated. + spawned = item.spawn() + spawned._selected_tool_names = matched + selected.append(spawned) + else: + selected.append(item) + return selected + + +def _spawn_tools(tools: ToolsType) -> ToolsType: + """ + Return per-run copies of `tools`, replacing each Toolset with an isolated `spawn()` (Tools are passed through). + + This isolates run-scoped Toolset state (e.g. a SearchableToolset's discovered tools and any active name + selection) so that concurrent runs sharing the same configured Toolset — such as parallel sub-agent tool calls + or concurrent requests against one Agent — don't corrupt each other. + """ + if isinstance(tools, Toolset): + return tools.spawn() + return [item.spawn() if isinstance(item, Toolset) else item for item in tools] + + def _validate_prompt_message_blocks(user_prompt: str | None, system_prompt: str | None) -> None: """ Validate explicit Jinja2 message blocks in Agent prompts. @@ -699,35 +764,26 @@ def _select_tools(self, tools: ToolsType | list[str] | None = None) -> ToolsType or if any provided tool name is not valid. :raises TypeError: If tools is not a list of Tool objects, a Toolset, or a list of tool names (strings). 
""" + # Toolsets are spawned into per-run copies (see _spawn_tools / _select_tools_by_name) so concurrent runs + # sharing the same configured Toolset don't corrupt each other's run-scoped state. if tools is None: - return self.tools + return _spawn_tools(self.tools) if isinstance(tools, list) and all(isinstance(t, str) for t in tools): - if not self.tools: - raise ValueError("No tools were configured for the Agent at initialization.") - available_tools = flatten_tools_or_toolsets(self.tools) - selected_tool_names = cast(list[str], tools) # mypy thinks this could still be list[Tool] or Toolset - valid_tool_names = {tool.name for tool in available_tools} - invalid_tool_names = {name for name in selected_tool_names if name not in valid_tool_names} - if invalid_tool_names: - raise ValueError( - f"The following tool names are not valid: {invalid_tool_names}. " - f"Valid tool names are: {valid_tool_names}." - ) - return [tool for tool in available_tools if tool.name in selected_tool_names] + return _select_tools_by_name(self.tools, cast(list[str], tools)) if isinstance(tools, Toolset): # Per-run tools are not covered by the Agent's own warm_up(), so warm them up here. # warm_up() is expected to be idempotent, so re-warming on every run is cheap. warm_up_tools(tools) - return tools + return _spawn_tools(tools) if isinstance(tools, list): selected = cast(list[Tool | Toolset], tools) # mypy can't narrow the Union type from isinstance check # Per-run tools are not covered by the Agent's own warm_up(), so warm them up here. # warm_up() is expected to be idempotent, so re-warming on every run is cheap. warm_up_tools(selected) - return selected + return _spawn_tools(selected) raise TypeError( "tools must be a list of Tool and/or Toolset objects, a Toolset, or a list of tool names (strings)." 
diff --git a/haystack/tools/searchable_toolset.py b/haystack/tools/searchable_toolset.py index 03b7c42ed9..c8de77769a 100644 --- a/haystack/tools/searchable_toolset.py +++ b/haystack/tools/searchable_toolset.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import copy from collections.abc import Iterator from typing import TYPE_CHECKING, Annotated, Any @@ -174,6 +175,19 @@ def warm_up(self) -> None: self._is_warmed_up = True + def get_selectable_tools(self) -> list[Tool]: + """ + Return the full catalog of tools that can be selected by name. + + Iteration only exposes the search tool plus already-discovered tools, but name-based selection can target + any tool in the catalog, so this returns the entire flattened catalog (warming up first if needed). + + :returns: The flattened catalog of tools. + """ + if not self._is_warmed_up: + self.warm_up() + return list(self._catalog) + def clear(self) -> None: """ Clear all discovered tools. @@ -183,6 +197,27 @@ def clear(self) -> None: """ self._discovered_tools.clear() + def spawn(self) -> "SearchableToolset": + """ + Return an isolated copy for a single run. + + The copy shares the read-only catalog and BM25 index but gets fresh discovered tools and name selection, + plus a bootstrap search tool bound to the copy. This way concurrent runs sharing the same configured + SearchableToolset don't share discovered tools or collide on the active selection. + + :returns: A run-scoped copy of this SearchableToolset. + """ + if not self._is_warmed_up: + self.warm_up() + new = copy.copy(self) + new._discovered_tools = {} + new._selected_tool_names = None + # Rebuild the bootstrap tool so its closure is bound to the copy's discovered tools / selection + # rather than the original's. The document store and catalog are read-only and stay shared. 
+ if not self._passthrough: + new._bootstrap_tool = new._create_search_tool() + return new + def _create_search_tool(self) -> Tool: """Create the search_tools bootstrap tool.""" @@ -213,8 +248,15 @@ def search_tools( "names/descriptions (e.g. 'route weather search')." ) + # Scope the search to the selected subset if active so that top_k applies within the selected tools + filters = None + if self._selected_tool_names is not None: + filters = {"field": "meta.tool_name", "operator": "in", "value": list(self._selected_tool_names)} + # at this point, the toolset has been warmed up, so self._document_store is not None - results = self._document_store.bm25_retrieval(query=tool_keywords, top_k=num_results) # type: ignore[union-attr] + results = self._document_store.bm25_retrieval( # type: ignore[union-attr] + query=tool_keywords, top_k=num_results, filters=filters + ) if not results: return "No tools found matching these keywords. Try different keywords." @@ -249,13 +291,18 @@ def search_tools( return bootstrap_tool + def _is_selected(self, name: str) -> bool: + """Whether a catalog tool name is allowed by the active `_selected_tool_names` filter (None means all).""" + return self._selected_tool_names is None or name in self._selected_tool_names + def __iter__(self) -> Iterator[Tool]: """ Iterate over available tools. - In passthrough mode, yields all catalog tools. - Otherwise, yields bootstrap tool + discovered tools. - Automatically calls warm_up() if needed to ensure bootstrap tool is available. + In passthrough mode, yields all catalog tools. Otherwise, yields the bootstrap search tool plus the + already-discovered tools. If `_selected_tool_names` is set, catalog/discovered tools are restricted to that + set, but the bootstrap search tool is always exposed so search keeps working over the selected subset. + Automatically calls warm_up() if needed to ensure the bootstrap tool is available. 
""" # Unlike base Toolset/MCPToolset, which expose a placeholder tool before warm_up, this toolset materializes # everything (flattened catalog, bootstrap tool, passthrough decision) in warm_up. @@ -264,11 +311,11 @@ def __iter__(self) -> Iterator[Tool]: if not self._is_warmed_up: self.warm_up() if self._passthrough: - yield from self._catalog + yield from (tool for tool in self._catalog if self._is_selected(tool.name)) else: if self._bootstrap_tool is not None: yield self._bootstrap_tool - yield from self._discovered_tools.values() + yield from (tool for tool in self._discovered_tools.values() if self._is_selected(tool.name)) def __len__(self) -> int: """Return the number of currently available tools.""" diff --git a/haystack/tools/toolset.py b/haystack/tools/toolset.py index ff0fe4008c..ceb0ebb2bc 100644 --- a/haystack/tools/toolset.py +++ b/haystack/tools/toolset.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import copy from collections.abc import Iterator from dataclasses import dataclass, field from typing import Any @@ -47,8 +48,8 @@ def subtract(a: Annotated[int, "first number"], b: Annotated[int, "second number ``` 2. Base class for dynamic tool loading: - By subclassing Toolset, you can create implementations that dynamically load tools - from external sources like OpenAPI URLs, MCP servers, or other resources. + By subclassing Toolset, you can create implementations that dynamically load tools from external sources like + OpenAPI URLs, MCP servers, or other resources. Example: ```python @@ -94,15 +95,14 @@ def from_dict(cls, data): agent = Agent(chat_generator=OpenAIChatGenerator(), tools=calculator_toolset) ``` - Toolset implements the collection interface (__iter__, __contains__, __len__, __getitem__), - making it behave like a list of Tools. This makes it compatible with components that expect - iterable tools, such as Agent or Haystack chat generators. 
+ Toolset implements the collection interface (__iter__, __contains__, __len__, __getitem__), making it behave like + a list of Tools. This makes it compatible with components that expect iterable tools, such as Agent or Haystack + chat generators. When implementing a custom Toolset subclass for dynamic tool loading: - Perform the dynamic loading in the __init__ method - Override to_dict() and from_dict() methods if your tools are defined dynamically - - Serialize endpoint descriptors rather than tool instances if your tools - are loaded from external sources + - Serialize endpoint descriptors rather than tool instances if your tools are loaded from external sources """ # Use field() with default_factory to initialize the list @@ -124,15 +124,56 @@ def __post_init__(self) -> None: # Tracks whether warm_up() has already run so subsequent calls become a no-op. self._is_warmed_up = False + # Optional per-run name filter. When set, iteration only yields tools whose name is in this set. + # None means no filtering. Set on a per-run spawn(), so it never leaks across runs. + self._selected_tool_names: set[str] | None = None + def __iter__(self) -> Iterator[Tool]: """ Return an iterator over the Tools in this Toolset. - This allows the Toolset to be used wherever a list of Tools is expected. + This allows the Toolset to be used wherever a list of Tools is expected. If a name filter is active, + only the tools whose names are in it are yielded. :returns: An iterator yielding Tool instances """ - return iter(self.tools) + for tool in self.tools: + if self._selected_tool_names is None or tool.name in self._selected_tool_names: + yield tool + + def get_selectable_tools(self) -> list[Tool]: + """ + Return the full set of tools that can be selected by name, ignoring any active name filter. + + This differs from iteration, which yields only the tools currently exposed (and respects the name filter). 
+ Override this when a Toolset's iteration does not surface every selectable tool, so name-based selection + can still target the full set. + + Warms up the Toolset first if needed, so lazily loaded tools (those a Toolset fetches in `warm_up()`) + are available for selection. + + :returns: The list of tools available for name-based selection. + """ + if not self._is_warmed_up: + self.warm_up() + return list(self.tools) + + def spawn(self) -> "Toolset": + """ + Return an isolated copy of this Toolset for a single run. + + The copy shares this Toolset's read-only state (its tools and any warmed-up resources) but gets fresh + run-scoped state, so concurrent runs that share the same configured Toolset don't corrupt each other (for + example, one run's name selection leaking into another). Warms up first if needed so the copy shares the + warmed state. Subclasses with additional run-scoped state should override this. + + :returns: A run-scoped copy of this Toolset. + """ + if not self._is_warmed_up: + self.warm_up() + new = copy.copy(self) + new._selected_tool_names = None + return new def __contains__(self, item: str | Tool) -> bool: """ @@ -146,9 +187,9 @@ def __contains__(self, item: str | Tool) -> bool: :returns: True if contained, False otherwise """ if isinstance(item, str): - return any(tool.name == item for tool in self.tools) + return any(tool.name == item for tool in self) if isinstance(item, Tool): - return item in self.tools + return any(tool is item or tool == item for tool in self) return False def warm_up(self) -> None: @@ -281,20 +322,20 @@ def __add__(self, other: "Tool | Toolset | list[Tool]") -> "Toolset": def __len__(self) -> int: """ - Return the number of Tools in this Toolset. + Return the number of Tools in this Toolset (respecting any active name filter). :returns: Number of Tools """ - return len(self.tools) + return sum(1 for _ in self) def __getitem__(self, index: int) -> Tool: """ - Get a Tool by index. 
+ Get a Tool by index (respecting any active name filter). :param index: Index of the Tool to get :returns: The Tool at the specified index """ - return self.tools[index] + return list(self)[index] class _ToolsetWrapper(Toolset): @@ -312,9 +353,19 @@ def __init__(self, toolsets: list[Toolset]) -> None: self._is_warmed_up = False def __iter__(self) -> Iterator[Tool]: - """Iterate over all tools from all toolsets.""" + """Iterate over all tools from all toolsets, honoring any active name filter.""" for toolset in self.toolsets: - yield from toolset + for tool in toolset: + if self._selected_tool_names is None or tool.name in self._selected_tool_names: + yield tool + + def get_selectable_tools(self) -> list[Tool]: + """Return every selectable tool across all wrapped toolsets, ignoring any active filter.""" + return [tool for toolset in self.toolsets for tool in toolset.get_selectable_tools()] + + def spawn(self) -> "_ToolsetWrapper": + """Return an isolated copy with each wrapped toolset spawned.""" + return _ToolsetWrapper([toolset.spawn() for toolset in self.toolsets]) def __contains__(self, item: Any) -> bool: """Check if a tool is in any of the toolsets.""" @@ -371,8 +422,8 @@ def from_dict(cls, data: dict[str, Any]) -> "_ToolsetWrapper": return cls(toolsets=toolsets) def __len__(self) -> int: - """Return total number of tools across all toolsets.""" - return sum(len(toolset) for toolset in self.toolsets) + """Return total number of tools across all toolsets (respecting any active name filter).""" + return sum(1 for _ in self) def __getitem__(self, index: int) -> Tool: """Get a tool by index across all toolsets.""" diff --git a/releasenotes/notes/toolset-name-selection-e94fbf226507ec54.yaml b/releasenotes/notes/toolset-name-selection-e94fbf226507ec54.yaml new file mode 100644 index 0000000000..68af278249 --- /dev/null +++ b/releasenotes/notes/toolset-name-selection-e94fbf226507ec54.yaml @@ -0,0 +1,19 @@ +--- +fixes: + - | + Runtime tool-name selection via 
``Agent.run(tools=["tool_a", "tool_b"])`` now resolves correctly when a + ``SearchableToolset`` is configured. Previously the ``SearchableToolset`` was flattened into a single-item + list (just its search tool), which broke its search and lazy-loading behavior. + - | + A ``Toolset`` is no longer mutated in place during an ``Agent`` run. Each run now operates on an isolated + per-run copy of the configured ``Toolset`` (via the new ``Toolset.spawn()`` method). This makes concurrent + runs that share the same ``Toolset`` instance safe: per-run state such as the active tool-name selection, and + a ``SearchableToolset``'s discovered tools, can no longer leak or collide across runs. +enhancements: + - | + ``Toolset`` gained two methods: ``get_selectable_tools()`` returns every tool available for name-based + selection (ignoring any active per-run name selection and whatever iteration currently exposes), and + ``spawn()`` returns an isolated, run-scoped copy of the Toolset. Subclasses with additional run-scoped state + can override ``spawn()``. + - | + ``SearchableToolset`` now exposes its full catalog of tools via ``SearchableToolset.get_selectable_tools()``. 
diff --git a/test/components/agents/test_agent.py b/test/components/agents/test_agent.py index e45e36d620..e358f93e3f 100644 --- a/test/components/agents/test_agent.py +++ b/test/components/agents/test_agent.py @@ -16,7 +16,7 @@ from openai.types.chat import ChatCompletionChunk, chat_completion_chunk from haystack import Document, Pipeline, component, tracing -from haystack.components.agents.agent import Agent, _accumulate_usage +from haystack.components.agents.agent import Agent, _accumulate_usage, _select_tools_by_name from haystack.components.agents.state import merge_lists, replace_values from haystack.components.agents.tool_calling import _run_tool from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder @@ -1575,17 +1575,56 @@ def test_agent_span_has_parent_when_in_pipeline(self, spying_tracer, weather_too assert agent_span.parent_span == agent_component_span -class TestAgentToolSelection: - def test_tool_selection_by_name(self, weather_tool: Tool, component_tool: Tool): - chat_generator = MockChatGenerator() - agent = Agent( - chat_generator=chat_generator, - tools=[weather_tool, component_tool], - system_prompt="This is a system prompt.", - ) - result = agent._select_tools([weather_tool.name]) +class TestSelectToolsByName: + """Tests for the _select_tools_by_name helper (runtime tool-name selection).""" + + def test_selects_standalone_tools_by_name(self, weather_tool: Tool, component_tool: Tool): + result = _select_tools_by_name([weather_tool, component_tool], [weather_tool.name]) assert result == [weather_tool] + def test_raises_on_invalid_name(self, weather_tool: Tool, component_tool: Tool): + with pytest.raises( + ValueError, match="The following tool names are not valid: {'invalid_tool_name'}. Valid tool names are: ." 
+ ): + _select_tools_by_name([weather_tool, component_tool], ["invalid_tool_name"]) + + def test_raises_when_no_tools_configured(self, weather_tool: Tool): + with pytest.raises(ValueError, match="No tools were configured for the Agent at initialization."): + _select_tools_by_name([], [weather_tool.name]) + + def test_returns_isolated_spawn_with_selection(self, weather_tool: Tool, component_tool: Tool): + """A Toolset exposing a requested name is replaced by an isolated spawn carrying the selection. + + The shared, configured Toolset is not mutated. + """ + toolset = Toolset([weather_tool, component_tool]) + + result = _select_tools_by_name([toolset], [weather_tool.name]) + + assert len(result) == 1 + spawned = result[0] + assert isinstance(spawned, Toolset) + assert spawned is not toolset # an isolated per-run copy + assert spawned._selected_tool_names == {weather_tool.name} + assert [tool.name for tool in spawned] == [weather_tool.name] + # The configured toolset is untouched. + assert toolset._selected_tool_names is None + + def test_mixed_standalone_tools_and_toolsets(self, weather_tool: Tool, component_tool: Tool): + toolset = Toolset([weather_tool]) + + result = _select_tools_by_name([component_tool, toolset], [weather_tool.name, component_tool.name]) + + # The standalone tool is passed through; the toolset is replaced by an isolated spawn with the selection. 
+ assert component_tool in result + spawns = [t for t in result if isinstance(t, Toolset)] + assert len(spawns) == 1 + assert spawns[0] is not toolset + assert spawns[0]._selected_tool_names == {weather_tool.name} + assert toolset._selected_tool_names is None + + +class TestAgentToolSelection: def test_tool_selection_new_tool(self, weather_tool: Tool, component_tool: Tool): chat_generator = MockChatGenerator() agent = Agent(chat_generator=chat_generator, tools=[weather_tool], system_prompt="This is a system prompt.") @@ -1602,24 +1641,6 @@ def test_tool_selection_existing_tools(self, weather_tool: Tool, component_tool: result = agent._select_tools(None) assert result == [weather_tool, component_tool] - def test_tool_selection_invalid_tool_name(self, weather_tool: Tool, component_tool: Tool): - chat_generator = MockChatGenerator() - agent = Agent( - chat_generator=chat_generator, - tools=[weather_tool, component_tool], - system_prompt="This is a system prompt.", - ) - with pytest.raises( - ValueError, match=("The following tool names are not valid: {'invalid_tool_name'}. 
Valid tool names are: .") - ): - agent._select_tools(["invalid_tool_name"]) - - def test_tool_selection_no_tools_configured(self, weather_tool: Tool, component_tool: Tool): - chat_generator = MockChatGenerator() - agent = Agent(chat_generator=chat_generator, tools=[], system_prompt="This is a system prompt.") - with pytest.raises(ValueError, match="No tools were configured for the Agent at initialization."): - agent._select_tools([weather_tool.name]) - def test_tool_selection_invalid_type(self, weather_tool: Tool, component_tool: Tool): chat_generator = MockChatGenerator() agent = Agent( diff --git a/test/tools/test_searchable_toolset.py b/test/tools/test_searchable_toolset.py index e3b0def40a..1bf0d5155f 100644 --- a/test/tools/test_searchable_toolset.py +++ b/test/tools/test_searchable_toolset.py @@ -9,10 +9,11 @@ import pytest +from haystack import component from haystack.components.agents import Agent from haystack.components.generators.chat import OpenAIChatGenerator -from haystack.dataclasses import ChatMessage -from haystack.tools import SearchableToolset, Tool, Toolset +from haystack.dataclasses import ChatMessage, ToolCall +from haystack.tools import SearchableToolset, Tool, Toolset, flatten_tools_or_toolsets from haystack.tools.from_function import create_tool_from_function @@ -264,6 +265,32 @@ def test_search_tools_respects_k(self, large_catalog): # Should find exactly 1 tool assert "Found and loaded 1 tool(s):" in result + def test_search_with_selection_scopes_retrieval(self): + """With a name selection, search is scoped to the selected tools so a small top_k still finds them.""" + + def fn(x): + return x + + params = {"type": "object", "properties": {"x": {"type": "string"}}, "required": ["x"]} + # Several decoys match "weather" strongly; the target matches it too but less strongly. 
+ catalog = [ + Tool(name=f"decoy_{i}", description="weather weather weather forecast", parameters=params, function=fn) + for i in range(5) + ] + catalog.append(Tool(name="target_weather", description="weather report", parameters=params, function=fn)) + + toolset = SearchableToolset(catalog=catalog, search_threshold=3, top_k=1) + toolset.warm_up() + assert toolset._bootstrap_tool is not None + toolset._selected_tool_names = {"target_weather"} + + result = toolset._bootstrap_tool.invoke(tool_keywords="weather") + + # Even with top_k=1 and several stronger-matching decoys, the scoped search still finds the selected tool + # (the old approach retrieved top_k across the whole catalog and then dropped non-selected results). + assert "target_weather" in result + assert set(toolset._discovered_tools) == {"target_weather"} + def test_search_tools_no_results(self, large_catalog): """Test search_tools with no matching results.""" toolset = SearchableToolset(catalog=large_catalog) @@ -789,6 +816,149 @@ def test_serialization(self, large_catalog): assert "get_weather" in result +class TestSearchableToolsetAgentToolSelection: + """Deterministic Agent tests for runtime tool-name selection and lazy tool_call_counts.""" + + def test_get_selectable_tools_exposes_full_catalog(self, large_catalog): + """get_selectable_tools() exposes the whole catalog, unlike iteration (search tool + discovered only).""" + toolset = SearchableToolset(catalog=large_catalog, search_threshold=3) + toolset.warm_up() + + # Iteration only exposes the bootstrap search tool before anything is discovered. + assert [tool.name for tool in toolset] == ["search_tools"] + # The catalog, however, is fully available for name-based selection. 
+ assert {tool.name for tool in toolset.get_selectable_tools()} == {tool.name for tool in large_catalog} + + def test_runtime_tool_names_select_isolated_spawn_and_preserve_search(self, large_catalog, monkeypatch): + """Selecting catalog tool names returns an isolated spawn carrying the selection and keeping search active.""" + monkeypatch.setenv("OPENAI_API_KEY", "fake-key") + toolset = SearchableToolset(catalog=large_catalog, search_threshold=3) # 8 tools -> search mode + agent = Agent(chat_generator=OpenAIChatGenerator(), tools=toolset) + + selected = agent._select_tools(["get_weather", "add_numbers"]) + + # An isolated spawn is returned with the selection; the configured toolset is not mutated. + assert len(selected) == 1 + spawned = selected[0] + assert isinstance(spawned, SearchableToolset) + assert spawned is not toolset + assert spawned._selected_tool_names == {"get_weather", "add_numbers"} + assert toolset._selected_tool_names is None + # Search is preserved on the spawn (not dismantled): only the bootstrap tool is exposed up front. + assert [tool.name for tool in spawned] == ["search_tools"] + # And search only discovers tools within the selected subset. + assert spawned._bootstrap_tool is not None + spawned._bootstrap_tool.invoke(tool_keywords="weather add stock multiply") + assert set(spawned._discovered_tools) <= {"get_weather", "add_numbers"} + # The configured toolset's discovered tools are untouched. + assert toolset._discovered_tools == {} + + def test_spawns_have_independent_discovered_tools_and_selection(self, large_catalog): + """Two spawns of one SearchableToolset don't share discovered tools or collide on the active selection.""" + toolset = SearchableToolset(catalog=large_catalog, search_threshold=3) + toolset.warm_up() + + spawn_a = toolset.spawn() + spawn_b = toolset.spawn() + + assert spawn_a is not spawn_b + assert spawn_a is not toolset + # Bootstrap tools are rebound per spawn (not shared with the original or each other). 
+ assert spawn_a._bootstrap_tool is not None + assert spawn_a._bootstrap_tool is not spawn_b._bootstrap_tool + + spawn_a._selected_tool_names = {"get_weather"} + spawn_a._bootstrap_tool.invoke(tool_keywords="weather add stock multiply") + + # Discovery on spawn_a does not leak into spawn_b or the configured toolset. + assert set(spawn_a._discovered_tools) <= {"get_weather"} + assert spawn_b._discovered_tools == {} + assert toolset._discovered_tools == {} + # The selection is likewise isolated. + assert spawn_b._selected_tool_names is None + assert toolset._selected_tool_names is None + + def test_runtime_tool_names_passthrough_exposes_selected(self, large_catalog, monkeypatch): + """In passthrough mode, selecting names exposes exactly those catalog tools directly.""" + monkeypatch.setenv("OPENAI_API_KEY", "fake-key") + toolset = SearchableToolset(catalog=large_catalog, search_threshold=20) # 8 < 20 -> passthrough + agent = Agent(chat_generator=OpenAIChatGenerator(), tools=toolset) + + selected = agent._select_tools(["get_weather", "add_numbers"]) + + assert selected == [toolset] + assert {tool.name for tool in flatten_tools_or_toolsets(selected)} == {"get_weather", "add_numbers"} + + def test_agent_run_with_runtime_tool_names(self, large_catalog): + """An Agent with a SearchableToolset runs with specific catalog tools selected by name on an isolated spawn.""" + toolset = SearchableToolset(catalog=large_catalog, search_threshold=20) # passthrough exposes the selection + + @component + class WeatherCallingGenerator: + invoked = False + + @component.output_types(replies=list[ChatMessage]) + def run(self, messages, tools=None, **kwargs): + # In passthrough mode the selected catalog tool is exposed directly. 
+ assert [tool.name for tool in tools] == ["get_weather"] + if self.invoked: + return {"replies": [ChatMessage.from_assistant("done")]} + self.invoked = True + return { + "replies": [ + ChatMessage.from_assistant( + tool_calls=[ToolCall(tool_name="get_weather", arguments={"city": "Berlin"})] + ) + ] + } + + agent = Agent(chat_generator=WeatherCallingGenerator(), tools=toolset, max_agent_steps=5) + result = agent.run(messages=[ChatMessage.from_user("What's the weather in Berlin?")], tools=["get_weather"]) + + assert result["tool_call_counts"]["get_weather"] == 1 + # The Agent runs against an isolated spawn, so the configured toolset's selection never gets set. + assert toolset._selected_tool_names is None + + def test_discovered_tool_call_counts_added_lazily(self, large_catalog): + """tool_call_counts seeds only the search tool up front; a discovered+called tool is added lazily.""" + toolset = SearchableToolset(catalog=large_catalog, search_threshold=3, top_k=5) + + @component + class SearchThenWeatherGenerator: + step = 0 + + @component.output_types(replies=list[ChatMessage]) + def run(self, messages, tools=None, **kwargs): + self.step += 1 + if self.step == 1: + return { + "replies": [ + ChatMessage.from_assistant( + tool_calls=[ToolCall(tool_name="search_tools", arguments={"tool_keywords": "weather"})] + ) + ] + } + if self.step == 2: + return { + "replies": [ + ChatMessage.from_assistant( + tool_calls=[ToolCall(tool_name="get_weather", arguments={"city": "Berlin"})] + ) + ] + } + return {"replies": [ChatMessage.from_assistant("done")]} + + agent = Agent(chat_generator=SearchThenWeatherGenerator(), tools=toolset, max_agent_steps=6) + result = agent.run(messages=[ChatMessage.from_user("What's the weather in Berlin?")]) + + counts = result["tool_call_counts"] + # search_tools is seeded at init; get_weather is only counted after being discovered and called. 
+ assert counts["search_tools"] == 1 + assert counts["get_weather"] == 1 + # The Agent discovers tools on an isolated spawn, so the configured toolset's discovered tools stay empty. + assert toolset._discovered_tools == {} + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") @pytest.mark.integration class TestSearchableToolsetAgentIntegration: @@ -802,11 +972,13 @@ def test_agent_discovers_and_uses_tools(self, large_catalog): assert len(agent.tools) == 1 result = agent.run(messages=[ChatMessage.from_user("What's the weather in Milan?")]) - assert len(agent.tools) > 1 + # The run discovers tools on an isolated spawn, so agent.tools still exposes only the search tool. + assert len(agent.tools) == 1 assert "messages" in result messages = result["messages"] assert len(messages) > 1 + # Discovery happened during the run: the agent searched and then called the discovered tool. tool_calls = [tool_call for msg in messages if msg.tool_calls for tool_call in msg.tool_calls] assert len(tool_calls) > 1 assert any(tool_call.tool_name == "search_tools" for tool_call in tool_calls) diff --git a/test/tools/test_toolset.py b/test/tools/test_toolset.py index 5a3bcb8e46..139e13d2a5 100644 --- a/test/tools/test_toolset.py +++ b/test/tools/test_toolset.py @@ -431,3 +431,115 @@ def test_plus_returns_new_unwarmed_toolset(self): assert new_tool.warm_up_count == 0 ts2.warm_up() assert new_tool.warm_up_count == 1 + + +class TestToolsetToolSelection: + """Tests for get_selectable_tools(), the name filter, and spawn().""" + + def test_no_filter_yields_all_tools(self, add_tool, multiply_tool): + toolset = Toolset([add_tool, multiply_tool]) + assert toolset._selected_tool_names is None + assert [tool.name for tool in toolset] == ["add", "multiply"] + assert len(toolset) == 2 + + def test_get_selectable_tools_returns_all_tools(self, add_tool, multiply_tool): + toolset = Toolset([add_tool, multiply_tool]) + assert toolset.get_selectable_tools() == [add_tool, 
multiply_tool] + + def test_get_selectable_tools_ignores_active_filter(self, add_tool, multiply_tool): + toolset = Toolset([add_tool, multiply_tool]) + toolset._selected_tool_names = {"add"} + # Iteration is filtered, but get_selectable_tools still returns the full set. + assert [tool.name for tool in toolset] == ["add"] + assert {tool.name for tool in toolset.get_selectable_tools()} == {"add", "multiply"} + + def test_get_selectable_tools_warms_up_lazy_toolset(self, add_tool, multiply_tool): + """get_selectable_tools() warms up a lazy toolset so its lazily loaded tools are available for selection.""" + + class LazyToolset(Toolset): + def __init__(self): + super().__init__([]) # no tools until warm_up + + def warm_up(self): + if self._is_warmed_up: + return + self.tools = [add_tool, multiply_tool] + self._is_warmed_up = True + + toolset = LazyToolset() + assert toolset._is_warmed_up is False + assert toolset.tools == [] # not loaded yet + + selectable = toolset.get_selectable_tools() + + assert toolset._is_warmed_up is True # get_selectable_tools triggered warm_up + assert [tool.name for tool in selectable] == ["add", "multiply"] + + def test_filter_restricts_iteration(self, add_tool, multiply_tool, subtract_tool): + toolset = Toolset([add_tool, multiply_tool, subtract_tool]) + toolset._selected_tool_names = {"add", "subtract"} + assert [tool.name for tool in toolset] == ["add", "subtract"] + + def test_filter_restricts_len(self, add_tool, multiply_tool, subtract_tool): + toolset = Toolset([add_tool, multiply_tool, subtract_tool]) + toolset._selected_tool_names = {"add"} + assert len(toolset) == 1 + + def test_filter_restricts_getitem(self, add_tool, multiply_tool, subtract_tool): + toolset = Toolset([add_tool, multiply_tool, subtract_tool]) + toolset._selected_tool_names = {"subtract"} + assert toolset[0].name == "subtract" + + def test_filter_restricts_contains(self, add_tool, multiply_tool): + toolset = Toolset([add_tool, multiply_tool]) + 
toolset._selected_tool_names = {"add"} + assert "add" in toolset + assert "multiply" not in toolset + assert add_tool in toolset + assert multiply_tool not in toolset + + def test_spawn_returns_isolated_copy(self, add_tool, multiply_tool): + toolset = Toolset([add_tool, multiply_tool]) + + spawned = toolset.spawn() + + assert spawned is not toolset + assert spawned._selected_tool_names is None + # The copy shares the same (read-only) tools. + assert list(spawned.tools) == list(toolset.tools) + + def test_spawn_selection_does_not_leak_to_original(self, add_tool, multiply_tool): + """A per-run selection set on a spawn must not affect the configured toolset or other spawns.""" + toolset = Toolset([add_tool, multiply_tool]) + + spawn_a = toolset.spawn() + spawn_b = toolset.spawn() + spawn_a._selected_tool_names = {"add"} + + # Each run sees only its own selection; the configured toolset stays unfiltered. + assert [tool.name for tool in spawn_a] == ["add"] + assert [tool.name for tool in spawn_b] == ["add", "multiply"] + assert [tool.name for tool in toolset] == ["add", "multiply"] + assert toolset._selected_tool_names is None + + def test_spawn_warms_up_lazy_toolset(self, add_tool, multiply_tool): + """spawn() warms up a lazy toolset so the copy shares the warmed state.""" + + class LazyToolset(Toolset): + def __init__(self): + super().__init__([]) + + def warm_up(self): + if self._is_warmed_up: + return + self.tools = [add_tool, multiply_tool] + self._is_warmed_up = True + + toolset = LazyToolset() + assert toolset._is_warmed_up is False + + spawned = toolset.spawn() + + assert toolset._is_warmed_up is True # spawn triggered warm_up + assert spawned._is_warmed_up is True + assert [tool.name for tool in spawned] == ["add", "multiply"] diff --git a/test/tools/test_toolset_wrapper.py b/test/tools/test_toolset_wrapper.py index 4f4188cca9..f212c05ac8 100644 --- a/test/tools/test_toolset_wrapper.py +++ b/test/tools/test_toolset_wrapper.py @@ -192,3 +192,41 @@ def 
class TestToolsetWrapperToolSelection:
    """Cover `get_selectable_tools()`, the name filter and `spawn()` on the wrapper produced by `Toolset + Toolset`."""

    def test_get_selectable_tools_aggregates_all_toolsets(self, add_tool, multiply_tool, subtract_tool):
        """The selectable catalog of a combined toolset spans every wrapped toolset."""
        left = Toolset([add_tool])
        right = Toolset([multiply_tool, subtract_tool])
        combined = left + right

        selectable_names = {tool.name for tool in combined.get_selectable_tools()}

        assert selectable_names == {"add", "multiply", "subtract"}

    def test_get_selectable_tools_ignores_active_filter(self, add_tool, multiply_tool):
        """An active name selection narrows iteration but not get_selectable_tools()."""
        left = Toolset([add_tool])
        right = Toolset([multiply_tool])
        combined = left + right
        combined._selected_tool_names = {"add"}

        # Plain iteration honours the selection ...
        iterated_names = [tool.name for tool in combined]
        assert iterated_names == ["add"]
        # ... whereas the selectable catalog is still complete.
        selectable_names = {tool.name for tool in combined.get_selectable_tools()}
        assert selectable_names == {"add", "multiply"}

    def test_filter_restricts_iteration_and_len(self, add_tool, multiply_tool, subtract_tool):
        """Both iteration and len() of a combined toolset reflect the selection."""
        left = Toolset([add_tool, multiply_tool])
        right = Toolset([subtract_tool])
        combined = left + right
        combined._selected_tool_names = {"add", "subtract"}

        iterated_names = [tool.name for tool in combined]

        assert iterated_names == ["add", "subtract"]
        assert len(combined) == 2

    def test_spawn_isolates_own_and_child_state(self, add_tool, multiply_tool):
        """A selection applied to a spawned wrapper leaves the source wrapper and its toolsets untouched."""
        first = Toolset([add_tool])
        second = Toolset([multiply_tool])
        configured = first + second

        clone = configured.spawn()

        # The spawn is a different object from the configured wrapper.
        assert clone is not configured
        clone._selected_tool_names = {"add"}
        clone_names = {tool.name for tool in clone}
        assert clone_names == {"add"}
        # Neither the configured wrapper nor the toolsets it was built from picked up the selection.
        assert configured._selected_tool_names is None
        assert first._selected_tool_names is None
        assert second._selected_tool_names is None
        configured_names = {tool.name for tool in configured}
        assert configured_names == {"add", "multiply"}