Skip to content

Commit 42ff42d

Browse files
anakin87sjrl
andauthored
fix: fix Ollama types + add py.typed (#1922)
* fix: fix Ollama types + add py.typed * Update integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> * better typing for embedders --------- Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>
1 parent d4fe7c2 commit 42ff42d

8 files changed

Lines changed: 35 additions & 36 deletions

File tree

.github/workflows/ollama.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,9 @@ jobs:
8484
- name: Install Hatch
8585
run: pip install --upgrade hatch
8686

87-
# TODO: Once this integration is properly typed, use hatch run test:types
88-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
8987
- name: Lint
9088
if: matrix.python-version == '3.9' && runner.os == 'Linux'
91-
run: hatch run fmt-check && hatch run lint:typing
89+
run: hatch run fmt-check && hatch run test:types
9290

9391
- name: Generate docs
9492
if: matrix.python-version == '3.9' && runner.os == 'Linux'

integrations/ollama/pyproject.toml

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,14 @@ integration = 'pytest -m "integration" {args:tests}'
7171
all = 'pytest {args:tests}'
7272
cov-retry = 'all --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x'
7373

74-
types = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
74+
types = """mypy -p haystack_integrations.components.embedders.ollama \
75+
-p haystack_integrations.components.generators.ollama {args}"""
7576

76-
# TODO: remove lint environment once this integration is properly typed
77-
# test environment should be used instead
78-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
79-
[tool.hatch.envs.lint]
80-
installer = "uv"
81-
detached = true
82-
dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"]
83-
84-
[tool.hatch.envs.lint.scripts]
85-
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
77+
[tool.mypy]
78+
install_types = true
79+
non_interactive = true
80+
check_untyped_defs = true
81+
disallow_incomplete_defs = true
8682

8783
[tool.hatch.metadata]
8884
allow-direct-references = true
@@ -173,7 +169,3 @@ log_cli = true
173169
addopts = ["--import-mode=importlib"]
174170
asyncio_mode = "auto"
175171
asyncio_default_fixture_loop_scope = "class"
176-
177-
[[tool.mypy.overrides]]
178-
module = ["haystack.*", "haystack_integrations.*", "pytest.*", "ollama.*", "pydantic.*"]
179-
ignore_missing_imports = true

integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import asyncio
2-
from typing import Any, Dict, List, Optional
2+
from typing import Any, Dict, List, Optional, Union
33

44
from haystack import Document, component
55
from tqdm import tqdm
@@ -114,7 +114,7 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
114114

115115
def _embed_batch(
116116
self, texts_to_embed: List[str], batch_size: int, generation_kwargs: Optional[Dict[str, Any]] = None
117-
):
117+
) -> List[List[float]]:
118118
"""
119119
Internal method to embed a batch of texts.
120120
"""
@@ -132,7 +132,7 @@ def _embed_batch(
132132

133133
async def _embed_batch_async(
134134
self, texts_to_embed: List[str], batch_size: int, generation_kwargs: Optional[Dict[str, Any]] = None
135-
):
135+
) -> List[List[float]]:
136136
"""
137137
Internal method to embed a batch of texts asynchronously.
138138
"""
@@ -160,7 +160,9 @@ async def _embed_batch_async(
160160
return all_embeddings
161161

162162
@component.output_types(documents=List[Document], meta=Dict[str, Any])
163-
def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None):
163+
def run(
164+
self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None
165+
) -> Dict[str, Union[List[Document], Dict[str, Any]]]:
164166
"""
165167
Runs an Ollama Model to compute embeddings of the provided documents.
166168
@@ -193,7 +195,9 @@ def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, A
193195
return {"documents": documents, "meta": {"model": self.model}}
194196

195197
@component.output_types(documents=List[Document], meta=Dict[str, Any])
196-
async def run_async(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None):
198+
async def run_async(
199+
self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None
200+
) -> Dict[str, Union[List[Document], Dict[str, Any]]]:
197201
"""
198202
Asynchronously run an Ollama Model to compute embeddings of the provided documents.
199203

integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Optional
1+
from typing import Any, Dict, List, Optional, Union
22

33
from haystack import component
44

@@ -49,7 +49,9 @@ def __init__(
4949
self._async_client = AsyncClient(host=self.url, timeout=self.timeout)
5050

5151
@component.output_types(embedding=List[float], meta=Dict[str, Any])
52-
def run(self, text: str, generation_kwargs: Optional[Dict[str, Any]] = None):
52+
def run(
53+
self, text: str, generation_kwargs: Optional[Dict[str, Any]] = None
54+
) -> Dict[str, Union[List[float], Dict[str, Any]]]:
5355
"""
5456
Runs an Ollama Model to compute embeddings of the provided text.
5557
@@ -69,7 +71,9 @@ def run(self, text: str, generation_kwargs: Optional[Dict[str, Any]] = None):
6971
return result
7072

7173
@component.output_types(embedding=List[float], meta=Dict[str, Any])
72-
async def run_async(self, text: str, generation_kwargs: Optional[Dict[str, Any]] = None):
74+
async def run_async(
75+
self, text: str, generation_kwargs: Optional[Dict[str, Any]] = None
76+
) -> Dict[str, Union[List[float], Dict[str, Any]]]:
7377
"""
7478
Asynchronously run an Ollama Model to compute embeddings of the provided text.
7579

integrations/ollama/src/haystack_integrations/components/embedders/py.typed

Whitespace-only changes.

integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def __init__(
220220
- JSON Schema: The response is formatted as a JSON object
221221
that adheres to the specified JSON Schema. (needs Ollama ≥ 0.1.34)
222222
"""
223-
_check_duplicate_tool_names(tools)
223+
_check_duplicate_tool_names(list(tools or []))
224224

225225
self.model = model
226226
self.url = url
@@ -361,10 +361,11 @@ def _handle_streaming_response(
361361
# Compose final reply
362362
text = "".join(c.content for c in chunks)
363363

364-
tool_calls = [
365-
ToolCall(tool_name=name_by_id[tool_call_id], arguments=arg_by_id.get(tool_call_id))
366-
for tool_call_id in id_order
367-
]
364+
tool_calls = []
365+
for tool_call_id in id_order:
366+
arguments = arg_by_id.get(tool_call_id, {})
367+
assert isinstance(arguments, dict) # final arguments are a dictionary # noqa: S101
368+
tool_calls.append(ToolCall(tool_name=name_by_id[tool_call_id], arguments=arguments))
368369

369370
reply = ChatMessage.from_assistant(
370371
text=text,
@@ -381,7 +382,7 @@ def run(
381382
tools: Optional[Union[List[Tool], Toolset]] = None,
382383
*,
383384
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
384-
):
385+
) -> Dict[str, List[ChatMessage]]:
385386
"""
386387
Runs an Ollama Model on a given chat history.
387388
@@ -405,7 +406,7 @@ def run(
405406
"""
406407
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
407408
tools = tools or self.tools
408-
_check_duplicate_tool_names(tools)
409+
_check_duplicate_tool_names(list(tools or []))
409410

410411
# Convert Toolset → list[Tool] for JSON serialization
411412
if isinstance(tools, Toolset):
@@ -421,7 +422,7 @@ def run(
421422
model=self.model,
422423
messages=ollama_messages,
423424
tools=ollama_tools,
424-
stream=is_stream,
425+
stream=is_stream, # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
425426
keep_alive=self.keep_alive,
426427
options=generation_kwargs,
427428
format=self.response_format,

integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def run(
239239
generation_kwargs: Optional[Dict[str, Any]] = None,
240240
*,
241241
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
242-
):
242+
) -> Dict[str, List[Any]]:
243243
"""
244244
Runs an Ollama Model on the given prompt.
245245
@@ -263,7 +263,7 @@ def run(
263263
response = self._client.generate(
264264
model=self.model,
265265
prompt=prompt,
266-
stream=stream,
266+
stream=stream, # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
267267
keep_alive=self.keep_alive,
268268
options=generation_kwargs,
269269
)

integrations/ollama/src/haystack_integrations/components/generators/py.typed

Whitespace-only changes.

0 commit comments

Comments
 (0)