Commit 62627fa

feat(ollama): add max_retries to chat generator (#2899)
* add max_retries to chat generator
* fix: refine ollama retry behavior and tests
* fix: tighten ollama retry typing and document retry policy
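
For context, a minimal usage sketch of the new parameter (the model name and prompt are illustrative, not part of this commit):

```python
from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.ollama import OllamaChatGenerator

# max_retries=2 allows up to 3 attempts in total (1 initial + 2 retries)
# for HTTP 429/5xx and connection/timeout errors; the default of 0 keeps
# the previous fail-fast behavior.
generator = OllamaChatGenerator(
    model="llama3.2",  # illustrative model name
    url="http://localhost:11434",
    max_retries=2,
)

result = generator.run(messages=[ChatMessage.from_user("Hello!")])
print(result["replies"][0].text)
```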
1 parent 7a9a4f5 commit 62627fa

3 files changed: 189 additions & 21 deletions

integrations/ollama/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ classifiers = [
   "Programming Language :: Python :: Implementation :: CPython",
   "Programming Language :: Python :: Implementation :: PyPy",
 ]
-dependencies = ["haystack-ai>=2.22.0", "ollama>=0.5.0", "pydantic"]
+dependencies = ["haystack-ai>=2.22.0", "ollama>=0.5.0", "pydantic", "tenacity>=8.2.3"]
 
 [project.urls]
 Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ollama#readme"

integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py

Lines changed: 90 additions & 19 deletions
@@ -22,15 +22,44 @@
 )
 from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
 from pydantic.json_schema import JsonSchemaValue
+from tenacity import RetryCallState, retry, retry_if_exception, wait_exponential
 
-from ollama import AsyncClient, ChatResponse, Client
+from ollama import AsyncClient, ChatResponse, Client, ResponseError
 
 FINISH_REASON_MAPPING: dict[str, FinishReason] = {
     "stop": "stop",
     "tool_calls": "tool_calls",
     # we skip load and unload reasons
 }
 
+HTTP_STATUS_TOO_MANY_REQUESTS = 429
+HTTP_STATUS_SERVER_ERROR_MIN = 500
+HTTP_STATUS_SERVER_ERROR_MAX_EXCLUSIVE = 600
+
+
+def _stop_after_instance_max_retries(retry_state: RetryCallState) -> bool:
+    """
+    Stop retries after `self.max_retries + 1` attempts.
+    """
+    instance = retry_state.args[0]
+    return retry_state.attempt_number >= instance.max_retries + 1
+
+
+def _is_retryable_exception(exc: BaseException) -> bool:
+    """
+    Return True for transient failures that should be retried.
+
+    Retries are attempted for:
+    - HTTP 429 responses
+    - HTTP 5xx responses
+    - transport-level connection/timeout errors
+    """
+    if isinstance(exc, ResponseError):
+        return exc.status_code == HTTP_STATUS_TOO_MANY_REQUESTS or (
+            HTTP_STATUS_SERVER_ERROR_MIN <= exc.status_code < HTTP_STATUS_SERVER_ERROR_MAX_EXCLUSIVE
+        )
+    return isinstance(exc, (ConnectionError, TimeoutError))
+
 
 def _convert_chatmessage_to_ollama_format(message: ChatMessage) -> dict[str, Any]:
     """
@@ -216,6 +245,7 @@ def __init__(
         url: str = "http://localhost:11434",
         generation_kwargs: dict[str, Any] | None = None,
         timeout: int = 120,
+        max_retries: int = 0,
         keep_alive: float | str | None = None,
         streaming_callback: Callable[[StreamingChunk], None] | None = None,
         tools: ToolsType | None = None,
@@ -233,6 +263,9 @@ def __init__(
             [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
         :param timeout:
             The number of seconds before throwing a timeout error from the Ollama API.
+        :param max_retries:
+            Maximum number of retries to attempt for failed requests (HTTP 429, 5xx, connection/timeout errors).
+            Uses exponential backoff between attempts. Set to 0 (default) to disable retries.
         :param think:
             If True, the model will "think" before producing a response.
             Only [thinking models](https://ollama.com/search?c=thinking) support this feature.
@@ -268,6 +301,7 @@ def __init__(
         self.url = url
         self.generation_kwargs = generation_kwargs or {}
         self.timeout = timeout
+        self.max_retries = max_retries
         self.keep_alive = keep_alive
         self.streaming_callback = streaming_callback
         self.tools = tools  # Store original tools for serialization
@@ -292,6 +326,7 @@ def to_dict(self) -> dict[str, Any]:
             url=self.url,
             generation_kwargs=self.generation_kwargs,
             timeout=self.timeout,
+            max_retries=self.max_retries,
             keep_alive=self.keep_alive,
             streaming_callback=callback_name,
             tools=serialize_tools_or_toolset(self.tools),
@@ -469,6 +504,56 @@ async def _handle_streaming_response_async(
 
         return {"replies": [reply]}
 
+    @retry(
+        reraise=True,
+        stop=_stop_after_instance_max_retries,
+        retry=retry_if_exception(_is_retryable_exception),
+        wait=wait_exponential(),
+    )
+    def _chat(
+        self,
+        *,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]] | None,
+        is_stream: bool,
+        generation_kwargs: dict[str, Any],
+    ) -> ChatResponse | Iterator[ChatResponse]:
+        return self._client.chat(
+            model=self.model,
+            messages=messages,
+            tools=tools,
+            stream=is_stream,  # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
+            keep_alive=self.keep_alive,
+            options=generation_kwargs,
+            format=self.response_format,
+            think=self.think,
+        )
+
+    @retry(
+        reraise=True,
+        stop=_stop_after_instance_max_retries,
+        retry=retry_if_exception(_is_retryable_exception),
+        wait=wait_exponential(),
+    )
+    async def _chat_async(
+        self,
+        *,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]] | None,
+        is_stream: bool,
+        generation_kwargs: dict[str, Any],
+    ) -> ChatResponse | AsyncIterator[ChatResponse]:
+        return await self._async_client.chat(
+            model=self.model,
+            messages=messages,
+            tools=tools,
+            stream=is_stream,  # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
+            keep_alive=self.keep_alive,
+            options=generation_kwargs,
+            format=self.response_format,
+            think=self.think,
+        )
+
     @component.output_types(replies=list[ChatMessage])
     def run(
         self,
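
Note that the retry decorators wrap only the `chat()` call itself. With `is_stream=True` the call returns an iterator, and an error raised while consuming that iterator occurs outside the decorated method, so it is not retried. A toy sketch of that boundary (all names illustrative):

```python
from tenacity import retry, retry_if_exception_type, stop_after_attempt

attempts = 0


@retry(reraise=True, retry=retry_if_exception_type(ConnectionError), stop=stop_after_attempt(2))
def open_stream():
    # Failing while *opening* the stream happens inside the decorated
    # call, so tenacity retries it.
    global attempts
    attempts += 1
    if attempts < 2:
        raise ConnectionError("transient on connect")

    def chunks():
        yield "partial"
        # Failing while *consuming* the stream happens in the caller,
        # after the decorated call has returned, so tenacity never sees it.
        raise ConnectionError("mid-stream")

    return chunks()


stream = open_stream()  # succeeds on the second attempt
try:
    for chunk in stream:
        print(chunk)
except ConnectionError as err:
    print(f"not retried: {err}")
```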
@@ -518,15 +603,8 @@ def run(
 
         ollama_messages = [_convert_chatmessage_to_ollama_format(m) for m in messages]
 
-        response = self._client.chat(
-            model=self.model,
-            messages=ollama_messages,
-            tools=ollama_tools,
-            stream=is_stream,  # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
-            keep_alive=self.keep_alive,
-            options=generation_kwargs,
-            format=self.response_format,
-            think=self.think,
+        response = self._chat(
+            messages=ollama_messages, tools=ollama_tools, is_stream=is_stream, generation_kwargs=generation_kwargs
         )
 
         if isinstance(response, Iterator):
@@ -579,15 +657,8 @@ async def run_async(
 
         ollama_messages = [_convert_chatmessage_to_ollama_format(m) for m in messages]
 
-        response = await self._async_client.chat(
-            model=self.model,
-            messages=ollama_messages,
-            tools=ollama_tools,
-            stream=is_stream,  # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
-            keep_alive=self.keep_alive,
-            options=generation_kwargs,
-            format=self.response_format,
-            think=self.think,
+        response = await self._chat_async(
+            messages=ollama_messages, tools=ollama_tools, is_stream=is_stream, generation_kwargs=generation_kwargs
         )
 
         if isinstance(response, AsyncIterator):

integrations/ollama/tests/test_chat_generator.py

Lines changed: 98 additions & 1 deletion
@@ -1,6 +1,6 @@
 import json
 from typing import Annotated
-from unittest.mock import Mock, patch
+from unittest.mock import AsyncMock, Mock, patch
 
 import pytest
 from haystack.components.generators.utils import print_streaming_chunk
@@ -518,6 +518,7 @@ def test_init_default(self):
         assert component.url == "http://localhost:11434"
         assert component.generation_kwargs == {}
         assert component.timeout == 120
+        assert component.max_retries == 0
         assert component.streaming_callback is None
         assert component.tools is None
         assert component.keep_alive is None
@@ -529,6 +530,7 @@ def test_init(self, tools):
             url="http://my-custom-endpoint:11434",
             generation_kwargs={"temperature": 0.5},
             timeout=5,
+            max_retries=2,
             keep_alive="10m",
             streaming_callback=print_streaming_chunk,
             tools=tools,
@@ -539,6 +541,7 @@ def test_init(self, tools):
         assert component.url == "http://my-custom-endpoint:11434"
         assert component.generation_kwargs == {"temperature": 0.5}
         assert component.timeout == 5
+        assert component.max_retries == 2
         assert component.keep_alive == "10m"
         assert component.streaming_callback is print_streaming_chunk
         assert component.tools == tools
@@ -603,6 +606,7 @@ def test_to_dict(self):
             "type": "haystack_integrations.components.generators.ollama.chat.chat_generator.OllamaChatGenerator",
             "init_parameters": {
                 "timeout": 120,
+                "max_retries": 0,
                 "model": "llama2",
                 "url": "custom_url",
                 "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
@@ -650,6 +654,7 @@ def test_from_dict(self):
             "type": "haystack_integrations.components.generators.ollama.chat.chat_generator.OllamaChatGenerator",
             "init_parameters": {
                 "timeout": 120,
+                "max_retries": 0,
                 "model": "llama2",
                 "url": "custom_url",
                 "keep_alive": "5m",
@@ -689,6 +694,7 @@ def test_from_dict(self):
             "some_test_param": "test-params",
         }
         assert component.timeout == 120
+        assert component.max_retries == 0
         assert component.tools == [tool]
         assert component.response_format == {
             "type": "object",
@@ -790,6 +796,97 @@ def test_run(self, mock_client):
         assert result["replies"][0].text == "Fine. How can I help you today?"
         assert result["replies"][0].role == "assistant"
 
+    @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
+    def test_run_retries_after_failure(self, mock_client):
+        generator = OllamaChatGenerator(max_retries=1)
+
+        mock_response = ChatResponse(
+            model="qwen3:0.6b",
+            created_at="2023-12-12T14:13:43.416799Z",
+            message={"role": "assistant", "content": "Recovered after retry"},
+            done=True,
+            prompt_eval_count=1,
+            eval_count=2,
+        )
+
+        mock_client_instance = mock_client.return_value
+        mock_client_instance.chat.side_effect = [ResponseError("temporary failure", status_code=500), mock_response]
+
+        result = generator.run(messages=[ChatMessage.from_user("Hello!")])
+
+        assert mock_client_instance.chat.call_count == 2
+        assert result["replies"][0].text == "Recovered after retry"
+
+    @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
+    def test_run_raises_after_retry_exhausted(self, mock_client):
+        generator = OllamaChatGenerator(max_retries=1)
+        mock_client_instance = mock_client.return_value
+        mock_client_instance.chat.side_effect = ResponseError("persistent failure", status_code=503)
+
+        with pytest.raises(ResponseError, match="persistent failure"):
+            generator.run(messages=[ChatMessage.from_user("Hello!")])
+
+        assert mock_client_instance.chat.call_count == 2
+
+    @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
+    def test_run_does_not_retry_non_retryable_error(self, mock_client):
+        generator = OllamaChatGenerator(max_retries=2)
+        mock_client_instance = mock_client.return_value
+        mock_client_instance.chat.side_effect = ResponseError("bad request", status_code=400)
+
+        with pytest.raises(ResponseError, match="bad request"):
+            generator.run(messages=[ChatMessage.from_user("Hello!")])
+
+        assert mock_client_instance.chat.call_count == 1
+
+    @pytest.mark.asyncio
+    @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.AsyncClient")
+    async def test_run_async_does_not_retry_non_retryable_error(self, mock_async_client):
+        generator = OllamaChatGenerator(max_retries=2)
+        mock_async_client_instance = mock_async_client.return_value
+        mock_async_client_instance.chat = AsyncMock(side_effect=ResponseError("bad request", status_code=400))
+
+        with pytest.raises(ResponseError, match="bad request"):
+            await generator.run_async(messages=[ChatMessage.from_user("Hello!")])
+
+        assert mock_async_client_instance.chat.call_count == 1
+
+    @pytest.mark.asyncio
+    @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.AsyncClient")
+    async def test_run_async_retries_after_failure(self, mock_async_client):
+        generator = OllamaChatGenerator(max_retries=1)
+
+        mock_response = ChatResponse(
+            model="qwen3:0.6b",
+            created_at="2023-12-12T14:13:43.416799Z",
+            message={"role": "assistant", "content": "Recovered after retry"},
+            done=True,
+            prompt_eval_count=1,
+            eval_count=2,
+        )
+
+        mock_async_client_instance = mock_async_client.return_value
+        mock_async_client_instance.chat = AsyncMock(
+            side_effect=[ResponseError("temporary failure", status_code=500), mock_response]
+        )
+
+        result = await generator.run_async(messages=[ChatMessage.from_user("Hello!")])
+
+        assert mock_async_client_instance.chat.call_count == 2
+        assert result["replies"][0].text == "Recovered after retry"
+
+    @pytest.mark.asyncio
+    @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.AsyncClient")
+    async def test_run_async_raises_after_retry_exhausted(self, mock_async_client):
+        generator = OllamaChatGenerator(max_retries=1)
+        mock_async_client_instance = mock_async_client.return_value
+        mock_async_client_instance.chat = AsyncMock(side_effect=ResponseError("persistent failure", status_code=503))
+
+        with pytest.raises(ResponseError, match="persistent failure"):
+            await generator.run_async(messages=[ChatMessage.from_user("Hello!")])
+
+        assert mock_async_client_instance.chat.call_count == 2
+
     @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
     def test_run_streaming(self, mock_client):
         collected_chunks = []

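To run just the new retry tests locally, something like this should work from a checkout of the repository (the async cases also assume `pytest-asyncio`, which the existing `@pytest.mark.asyncio` tests already require):

```python
import pytest

# Select the new retry tests by keyword; -q keeps the output terse.
pytest.main(["integrations/ollama/tests/test_chat_generator.py", "-k", "retry", "-q"])
```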