Skip to content

Commit a0cd909

Browse files
committed
chore: Add type checking to test/tools
1 parent a4354a3 commit a0cd909

11 files changed

Lines changed: 98 additions & 59 deletions

haystack/tools/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
# - list[Tool]: Most common pattern - list of Tool objects
2121
# - list[Toolset]: Less common pattern - list of Toolset objects
2222
# - list[Union[Tool, Toolset]]: Mixing Tools and Toolsets in one list
23+
# - list[ComponentTool]: List of ComponentTool objects
24+
# - list[PipelineTool]: List of PipelineTool objects
2325
# - Toolset: Single Toolset (not in a list)
24-
ToolsType = list[Tool] | list[Toolset] | list[Tool | Toolset] | Toolset
26+
ToolsType = list[Tool] | list[Toolset] | list[Tool | Toolset] | list[ComponentTool] | list[PipelineTool] | Toolset
2527

2628
__all__ = [
2729
"_check_duplicate_tool_names",

haystack/tools/from_function.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import inspect
66
from collections.abc import Callable
7-
from typing import Any
7+
from typing import Any, overload
88

99
from pydantic import create_model
1010

@@ -188,6 +188,30 @@ def get_weather(
188188
)
189189

190190

191+
@overload
192+
def tool(
193+
function: Callable,
194+
*,
195+
name: str | None = None,
196+
description: str | None = None,
197+
inputs_from_state: dict[str, str] | None = None,
198+
outputs_to_state: dict[str, dict[str, Any]] | None = None,
199+
outputs_to_string: dict[str, Any] | None = None,
200+
) -> Tool: ...
201+
202+
203+
@overload
204+
def tool(
205+
function: None = None,
206+
*,
207+
name: str | None = None,
208+
description: str | None = None,
209+
inputs_from_state: dict[str, str] | None = None,
210+
outputs_to_state: dict[str, dict[str, Any]] | None = None,
211+
outputs_to_string: dict[str, Any] | None = None,
212+
) -> Callable[[Callable], Tool]: ...
213+
214+
191215
def tool(
192216
function: Callable | None = None,
193217
*,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ integration-only-slow = 'pytest --maxfail=5 -m "integration and slow" {args:test
163163
all = 'pytest {args:test}'
164164

165165
# TODO We want to eventually type the whole test folder
166-
types = "mypy --install-types --non-interactive --cache-dir=.mypy_cache/ {args:haystack test/core/ test/marshal/ test/testing/ test/tracing/ test/human_in_the_loop test/evaluation test/document_stores test/dataclasses}"
166+
types = "mypy --install-types --non-interactive --cache-dir=.mypy_cache/ {args:haystack test/core/ test/marshal/ test/testing/ test/tracing/ test/tools/ test/human_in_the_loop test/evaluation test/document_stores test/dataclasses}"
167167

168168
[tool.hatch.envs.e2e]
169169
template = "test"

test/tools/test_component_tool.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ def run(self, messages: list[ChatMessage]) -> dict[str, str]:
4646
class SimpleComponent:
4747
"""A simple component that generates text."""
4848

49+
def warm_up(self):
50+
"""
51+
Prepare the component for use.
52+
"""
53+
4954
@component.output_types(reply=str)
5055
def run(self, text: str) -> dict[str, str]:
5156
"""
@@ -143,7 +148,7 @@ def run(self, documents: list[Document], top_k: int = 5) -> dict[str, str]:
143148
:param top_k: The number of top documents to concatenate
144149
:returns: Dictionary containing the concatenated document contents
145150
"""
146-
return {"concatenated": "\n".join(doc.content for doc in documents[:top_k])}
151+
return {"concatenated": "\n".join(doc.content for doc in documents[:top_k] if doc.content)}
147152

148153

149154
@component
@@ -215,7 +220,7 @@ def test_from_component_with_inputs_from_state_different_names(self):
215220
def test_from_component_with_invalid_inputs_from_state_nested_dict(self):
216221
"""Test that ComponentTool rejects nested dict format for inputs_from_state"""
217222
with pytest.raises(TypeError, match="must be str, not dict"):
218-
ComponentTool(component=SimpleComponent(), inputs_from_state={"documents": {"source": "documents"}})
223+
ComponentTool(component=SimpleComponent(), inputs_from_state={"documents": {"source": "documents"}}) # type: ignore[dict-item]
219224

220225
def test_from_component_with_outputs_to_state(self):
221226
tool = ComponentTool(component=SimpleComponent(), outputs_to_state={"replies": {"source": "reply"}})
@@ -369,13 +374,13 @@ def test_from_component_with_dynamic_input_types(self):
369374

370375
def test_from_component_with_invalid_component(self):
371376
class NotAComponent:
372-
def foo(self, text: str):
377+
def foo(self, text: str) -> dict[str, str]:
373378
return {"reply": f"Hello, {text}!"}
374379

375380
not_a_component = NotAComponent()
376381

377382
with pytest.raises(TypeError):
378-
ComponentTool(component=not_a_component, name="invalid_tool", description="This should fail")
383+
ComponentTool(component=not_a_component, name="invalid_tool", description="This should fail") # type: ignore[arg-type]
379384

380385
def test_component_invoker_with_chat_message_input(self):
381386
tool = ComponentTool(
@@ -392,7 +397,7 @@ class AnnotatedComponent:
392397
"""An annotated component with descriptive parameter docstrings."""
393398

394399
@component.output_types(result=str)
395-
def run(self, text: str, number: int = 42):
400+
def run(self, text: str, number: int = 42) -> dict[str, str]:
396401
"""
397402
Process inputs and return result.
398403
@@ -447,7 +452,7 @@ class ComponentA:
447452
"""Component A with descriptive docstrings."""
448453

449454
@component.output_types(output_a=str)
450-
def run(self, query: str):
455+
def run(self, query: str) -> dict[str, str]:
451456
"""
452457
Process query in component A.
453458
@@ -460,7 +465,7 @@ class ComponentB:
460465
"""Component B with descriptive docstrings."""
461466

462467
@component.output_types(output_b=str)
463-
def run(self, text: str):
468+
def run(self, text: str) -> dict[str, str]:
464469
"""
465470
Process text in component B.
466471
@@ -503,20 +508,20 @@ def run(self, text: str):
503508

504509
def test_warm_up_is_idempotent(self):
505510
"""Test that calling warm_up multiple times only warms up the component once."""
506-
from unittest.mock import MagicMock
511+
from unittest.mock import MagicMock, patch
507512

508513
component = SimpleComponent()
509-
component.warm_up = MagicMock()
510514

511515
tool = ComponentTool(component=component)
512516

513-
# Call warm_up multiple times
514-
tool.warm_up()
515-
tool.warm_up()
516-
tool.warm_up()
517+
with patch.object(component, "warm_up", MagicMock()) as mock_warm_up:
518+
# Call warm_up multiple times
519+
tool.warm_up()
520+
tool.warm_up()
521+
tool.warm_up()
517522

518-
# Component's warm_up should only be called once
519-
component.warm_up.assert_called_once()
523+
# Component's warm_up should only be called once
524+
mock_warm_up.assert_called_once()
520525

521526
def test_from_component_with_callable_params_skipped(self, monkeypatch):
522527
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")

test/tools/test_from_function.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,15 +87,15 @@ def function_with_annotations(
8787

8888

8989
def test_from_function_missing_type_hint():
90-
def function_missing_type_hint(city) -> str:
90+
def function_missing_type_hint(city) -> str: # type: ignore[no-untyped-def]
9191
return f"Weather report for {city}: 20°C, sunny"
9292

9393
with pytest.raises(ValueError):
9494
create_tool_from_function(function=function_missing_type_hint)
9595

9696

9797
def test_from_function_schema_generation_error():
98-
def function_with_invalid_type_hint(city: "invalid") -> str: # noqa: F821
98+
def function_with_invalid_type_hint(city: "invalid") -> str: # type: ignore[name-defined] # noqa: F821
9999
return f"Weather report for {city}: 20°C, sunny"
100100

101101
with pytest.raises(SchemaGenerationError):

test/tools/test_pipeline_tool.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_init_invalid_pipeline(self):
9898
with pytest.raises(
9999
TypeError, match="The 'pipeline' parameter must be an instance of Pipeline or AsyncPipeline."
100100
):
101-
PipelineTool(pipeline="invalid_pipeline", name="test_tool", description="A test tool")
101+
PipelineTool(pipeline="invalid_pipeline", name="test_tool", description="A test tool") # type: ignore[arg-type]
102102

103103
def test_to_dict(self, sample_pipeline, sample_pipeline_dict):
104104
tool = PipelineTool(
@@ -381,7 +381,7 @@ def test_pipeline_tool_with_invalid_inputs_from_state_nested_dict(self, sample_p
381381
output_mapping={"ranker.documents": "documents"},
382382
name="test_tool",
383383
description="A test tool",
384-
inputs_from_state={"user_query": {"source": "query"}},
384+
inputs_from_state={"user_query": {"source": "query"}}, # type: ignore[dict-item]
385385
)
386386

387387
def test_pipeline_tool_with_valid_outputs_to_state(self, sample_pipeline):

test/tools/test_searchable_toolset.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55

66
import os
7+
from collections.abc import Callable
8+
from typing import Any
79

810
import pytest
911

@@ -82,30 +84,28 @@ def small_catalog(weather_tool, add_tool, multiply_tool):
8284
@pytest.fixture
8385
def large_catalog():
8486
"""Larger catalog that requires discovery (>= 8 tools)."""
85-
return [
86-
create_tool_from_function(fn)
87-
for fn in [
88-
get_weather,
89-
add_numbers,
90-
multiply_numbers,
91-
get_stock_price,
92-
search_database,
93-
send_email,
94-
calculate_tax,
95-
convert_currency,
96-
]
87+
functions: list[Callable[..., Any]] = [
88+
get_weather,
89+
add_numbers,
90+
multiply_numbers,
91+
get_stock_price,
92+
search_database,
93+
send_email,
94+
calculate_tax,
95+
convert_currency,
9796
]
97+
return [create_tool_from_function(fn) for fn in functions]
9898

9999

100100
class TestSearchableToolset:
101101
def test_init_with_invalid_catalog(self):
102102
with pytest.raises(TypeError):
103-
SearchableToolset(catalog=123)
103+
SearchableToolset(catalog=123) # type: ignore[arg-type]
104104
with pytest.raises(TypeError):
105-
SearchableToolset(catalog=[123])
105+
SearchableToolset(catalog=[123]) # type: ignore[arg-type]
106106
with pytest.raises(TypeError):
107107
SearchableToolset(
108-
catalog=Tool(
108+
catalog=Tool( # type: ignore[arg-type]
109109
name="test",
110110
description="test",
111111
parameters={"type": "object", "properties": {}},
@@ -132,6 +132,7 @@ def test_not_implemented_methods(self):
132132
def test_clear(self, large_catalog):
133133
toolset = SearchableToolset(catalog=large_catalog)
134134
toolset.warm_up()
135+
assert toolset._bootstrap_tool is not None
135136
toolset._bootstrap_tool.invoke(tool_keywords="weather temperature city")
136137
assert len(toolset._discovered_tools) > 0
137138
toolset.clear()
@@ -187,7 +188,7 @@ def test_passthrough_contains_by_tool_invalid_type(self, small_catalog):
187188
toolset.warm_up()
188189

189190
with pytest.raises(TypeError):
190-
123 in toolset # noqa: B015
191+
123 in toolset # type: ignore[operator] # noqa: B015
191192

192193
def test_custom_search_threshold(self, large_catalog):
193194
"""Test that custom search_threshold changes passthrough behavior."""
@@ -318,6 +319,7 @@ def test_contains_bootstrap_tool(self, large_catalog):
318319
toolset.warm_up()
319320

320321
assert "search_tools" in toolset
322+
assert toolset._bootstrap_tool is not None
321323
assert toolset._bootstrap_tool in toolset
322324

323325
def test_contains_discovered_tool(self, large_catalog):
@@ -680,7 +682,7 @@ def warm_up(self) -> None:
680682
for i in range(5)
681683
]
682684

683-
toolset = SearchableToolset(catalog=[LazyToolset()] + eager_tools)
685+
toolset = SearchableToolset(catalog=[LazyToolset()] + eager_tools) # type: ignore[arg-type]
684686
toolset.warm_up()
685687

686688
# Should have 5 lazy + 5 eager = 10 tools

test/tools/test_serde_utils.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from typing import Any
6+
57
import pytest
68

79
from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
@@ -91,7 +93,7 @@ def test_deserialize_tools_inplace(self):
9193
name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
9294
)
9395

94-
data = {"tools": [tool.to_dict()]}
96+
data: dict[str, Any] = {"tools": [tool.to_dict()]}
9597
deserialize_tools_or_toolset_inplace(data)
9698
assert data["tools"] == [tool]
9799

@@ -104,7 +106,7 @@ def test_deserialize_tools_inplace(self):
104106
assert data == {"no_tools": 123}
105107

106108
def test_deserialize_tools_inplace_failures(self):
107-
data = {"key": "value"}
109+
data: dict[str, Any] = {"key": "value"}
108110
deserialize_tools_or_toolset_inplace(data)
109111
assert data == {"key": "value"}
110112

@@ -186,7 +188,8 @@ def test_deserialize_list_of_toolsets_inplace(self):
186188

187189
assert isinstance(data["tools"], list)
188190
assert len(data["tools"]) == 2
189-
assert all(isinstance(ts, Toolset) for ts in data["tools"])
191+
assert isinstance(data["tools"][0], Toolset)
192+
assert isinstance(data["tools"][1], Toolset)
190193
assert data["tools"][0][0].name == "weather"
191194
assert data["tools"][1][0].name == "calculator"
192195

@@ -201,7 +204,8 @@ def test_serialize_mixed_list_tools_and_toolsets(self):
201204

202205
toolset = Toolset([tool2])
203206

204-
data = serialize_tools_or_toolset([tool1, toolset])
207+
tools: list[Tool | Toolset] = [tool1, toolset]
208+
data = serialize_tools_or_toolset(tools)
205209

206210
assert isinstance(data, list)
207211
assert len(data) == 2
@@ -230,7 +234,8 @@ def test_serialize_mixed_list_multiple_tools_and_toolsets(self):
230234

231235
toolset = Toolset([tool4, tool5])
232236

233-
data = serialize_tools_or_toolset([tool1, tool2, toolset, tool3])
237+
tools: list[Tool | Toolset] = [tool1, tool2, toolset, tool3]
238+
data = serialize_tools_or_toolset(tools)
234239

235240
assert isinstance(data, list)
236241
assert len(data) == 4

test/tools/test_tool.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test_init_invalid_output_structure_config_not_dict(self):
7777
description="irrelevant",
7878
parameters={"type": "object", "properties": {"city": {"type": "string"}}},
7979
function=get_weather_report,
80-
outputs_to_state={"documents": ["some_value"]},
80+
outputs_to_state={"documents": ["some_value"]}, # type: ignore[dict-item]
8181
)
8282

8383
@pytest.mark.parametrize(
@@ -258,7 +258,7 @@ def test_inputs_from_state_validation_with_non_string_value(self):
258258
description="Get weather report",
259259
parameters=parameters,
260260
function=get_weather_report,
261-
inputs_from_state={"state_key": {"source": "city"}},
261+
inputs_from_state={"state_key": {"source": "city"}}, # type: ignore[dict-item]
262262
)
263263

264264
def test_inputs_from_state_validation_with_valid_parameter(self):

0 commit comments

Comments
 (0)