Skip to content

Commit eba79ad

Browse files
authored
feat: Ollama - accept str as ChatGenerator input; deprecate generator; rm generator example (#3388)
1 parent 56aee49 commit eba79ad

6 files changed

Lines changed: 74 additions & 60 deletions

File tree

integrations/ollama/examples/generator_example.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

integrations/ollama/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ classifiers = [
2727
"Programming Language :: Python :: Implementation :: CPython",
2828
"Programming Language :: Python :: Implementation :: PyPy",
2929
]
30-
dependencies = ["haystack-ai>=2.22.0", "ollama>=0.5.4", "pydantic>=2.12.0", "tenacity>=8.2.3"]
30+
dependencies = ["haystack-ai>=2.30.0", "ollama>=0.5.4", "pydantic>=2.12.0", "tenacity>=8.2.3"]
3131

3232
[project.urls]
3333
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ollama#readme"

integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any, Literal
55

66
from haystack import component, default_from_dict, default_to_dict
7+
from haystack.components.generators.utils import _normalize_messages
78
from haystack.dataclasses import (
89
AsyncStreamingCallbackT,
910
ChatMessage,
@@ -577,7 +578,7 @@ async def _chat_async(
577578
@component.output_types(replies=list[ChatMessage])
578579
def run(
579580
self,
580-
messages: list[ChatMessage],
581+
messages: list[ChatMessage] | str,
581582
generation_kwargs: dict[str, Any] | None = None,
582583
tools: ToolsType | None = None,
583584
*,
@@ -587,7 +588,8 @@ def run(
587588
Runs an Ollama Model on a given chat history.
588589
589590
:param messages:
590-
A list of ChatMessage instances representing the input messages.
591+
A list of ChatMessage instances representing the input messages. If a string is provided, it is converted
592+
to a list containing a ChatMessage with user role.
591593
:param generation_kwargs:
592594
Per-call overrides for Ollama inference options.
593595
These are merged on top of the instance-level `generation_kwargs`.
@@ -603,6 +605,7 @@ def run(
603605
:returns: A dictionary with the following keys:
604606
- `replies`: A list of ChatMessages containing the model's response
605607
"""
608+
messages = _normalize_messages(messages)
606609

607610
# Validate and select the streaming callback
608611
callback = select_streaming_callback(
@@ -636,7 +639,7 @@ def run(
636639
@component.output_types(replies=list[ChatMessage])
637640
async def run_async(
638641
self,
639-
messages: list[ChatMessage],
642+
messages: list[ChatMessage] | str,
640643
generation_kwargs: dict[str, Any] | None = None,
641644
tools: ToolsType | None = None,
642645
*,
@@ -646,7 +649,8 @@ async def run_async(
646649
Async version of run. Runs an Ollama Model on a given chat history.
647650
648651
:param messages:
649-
A list of ChatMessage instances representing the input messages.
652+
A list of ChatMessage instances representing the input messages. If a string is provided, it is converted
653+
to a list containing a ChatMessage with user role.
650654
:param generation_kwargs:
651655
Per-call overrides for Ollama inference options.
652656
These are merged on top of the instance-level `generation_kwargs`.
@@ -659,6 +663,8 @@ async def run_async(
659663
:returns: A dictionary with the following keys:
660664
- `replies`: A list of ChatMessages containing the model's response
661665
"""
666+
messages = _normalize_messages(messages)
667+
662668
# Validate and select the streaming callback
663669
callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)
664670

integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from collections.abc import Callable
23
from typing import Any
34

@@ -139,6 +140,13 @@ def __init__(
139140
- any negative number which will keep the model loaded in memory (e.g. -1 or "-1m")
140141
- '0' which will unload the model immediately after generating a response.
141142
"""
143+
warnings.warn(
144+
"The `OllamaGenerator` component is deprecated and will be removed in a future version. "
145+
"Use `OllamaChatGenerator` instead, which now also supports string inputs.",
146+
FutureWarning,
147+
stacklevel=2,
148+
)
149+
142150
self.timeout = timeout
143151
self.raw = raw
144152
self.template = template

integrations/ollama/tests/test_chat_generator.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,6 +1007,57 @@ def test_run(self, mock_client):
10071007
assert result["replies"][0].text == "Fine. How can I help you today?"
10081008
assert result["replies"][0].role == "assistant"
10091009

1010+
@patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
1011+
def test_run_with_string_input(self, mock_client):
1012+
generator = OllamaChatGenerator()
1013+
1014+
mock_response = ChatResponse(
1015+
model="qwen3:0.6b",
1016+
created_at="2023-12-12T14:13:43.416799Z",
1017+
message={"role": "assistant", "content": "Paris"},
1018+
done=True,
1019+
prompt_eval_count=1,
1020+
eval_count=1,
1021+
)
1022+
1023+
mock_client_instance = mock_client.return_value
1024+
mock_client_instance.chat.return_value = mock_response
1025+
1026+
result = generator.run("What's the capital of France?")
1027+
1028+
_, kwargs = mock_client_instance.chat.call_args
1029+
assert kwargs["messages"] == [{"role": "user", "content": "What's the capital of France?"}]
1030+
1031+
assert isinstance(result["replies"], list)
1032+
assert len(result["replies"]) == 1
1033+
assert isinstance(result["replies"][0], ChatMessage)
1034+
1035+
@pytest.mark.asyncio
1036+
@patch("haystack_integrations.components.generators.ollama.chat.chat_generator.AsyncClient")
1037+
async def test_run_async_with_string_input(self, mock_async_client):
1038+
generator = OllamaChatGenerator()
1039+
1040+
mock_response = ChatResponse(
1041+
model="qwen3:0.6b",
1042+
created_at="2023-12-12T14:13:43.416799Z",
1043+
message={"role": "assistant", "content": "Paris"},
1044+
done=True,
1045+
prompt_eval_count=1,
1046+
eval_count=1,
1047+
)
1048+
1049+
mock_async_client_instance = mock_async_client.return_value
1050+
mock_async_client_instance.chat = AsyncMock(return_value=mock_response)
1051+
1052+
result = await generator.run_async("What's the capital of France?")
1053+
1054+
_, kwargs = mock_async_client_instance.chat.call_args
1055+
assert kwargs["messages"] == [{"role": "user", "content": "What's the capital of France?"}]
1056+
1057+
assert isinstance(result["replies"], list)
1058+
assert len(result["replies"]) == 1
1059+
assert isinstance(result["replies"][0], ChatMessage)
1060+
10101061
@patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
10111062
def test_run_retries_after_failure(self, mock_client):
10121063
generator = OllamaChatGenerator(max_retries=1)

uv.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,7 @@
22
# attacks via compromised dependencies. uv resolves this relative to the current clock at
33
# install/lock time, so no manual date updates are needed.
44
exclude-newer = "24 hours"
5+
6+
# haystack-ai is a first-party dependency
7+
[exclude-newer-package]
8+
haystack-ai = false

0 commit comments

Comments
 (0)