Skip to content

Commit ae3c423

Browse files
committed
add max_retries to chat generator
1 parent 24935e4 commit ae3c423

3 files changed

Lines changed: 81 additions & 19 deletions

File tree

integrations/ollama/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ classifiers = [
2626
"Programming Language :: Python :: Implementation :: CPython",
2727
"Programming Language :: Python :: Implementation :: PyPy",
2828
]
29-
dependencies = ["haystack-ai>=2.22.0", "ollama>=0.5.0", "pydantic"]
29+
dependencies = ["haystack-ai>=2.22.0", "ollama>=0.5.0", "pydantic", "tenacity"]
3030

3131
[project.urls]
3232
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ollama#readme"

integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
2424
from pydantic.json_schema import JsonSchemaValue
25+
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
2526

2627
from ollama import AsyncClient, ChatResponse, Client
2728

@@ -216,6 +217,7 @@ def __init__(
216217
url: str = "http://localhost:11434",
217218
generation_kwargs: dict[str, Any] | None = None,
218219
timeout: int = 120,
220+
max_retries: int = 0,
219221
keep_alive: float | str | None = None,
220222
streaming_callback: Callable[[StreamingChunk], None] | None = None,
221223
tools: ToolsType | None = None,
@@ -233,6 +235,8 @@ def __init__(
233235
[Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
234236
:param timeout:
235237
The number of seconds before throwing a timeout error from the Ollama API.
238+
:param max_retries:
239+
Maximum number of retries to attempt for failed requests. Defaults to 0, meaning no retries (each request is attempted exactly once).
236240
:param think:
237241
If True, the model will "think" before producing a response.
238242
Only [thinking models](https://ollama.com/search?c=thinking) support this feature.
@@ -268,6 +272,7 @@ def __init__(
268272
self.url = url
269273
self.generation_kwargs = generation_kwargs or {}
270274
self.timeout = timeout
275+
self.max_retries = max_retries
271276
self.keep_alive = keep_alive
272277
self.streaming_callback = streaming_callback
273278
self.tools = tools # Store original tools for serialization
@@ -292,6 +297,7 @@ def to_dict(self) -> dict[str, Any]:
292297
url=self.url,
293298
generation_kwargs=self.generation_kwargs,
294299
timeout=self.timeout,
300+
max_retries=self.max_retries,
295301
keep_alive=self.keep_alive,
296302
streaming_callback=callback_name,
297303
tools=serialize_tools_or_toolset(self.tools),
@@ -518,16 +524,25 @@ def run(
518524

519525
ollama_messages = [_convert_chatmessage_to_ollama_format(m) for m in messages]
520526

521-
response = self._client.chat(
522-
model=self.model,
523-
messages=ollama_messages,
524-
tools=ollama_tools,
525-
stream=is_stream, # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
526-
keep_alive=self.keep_alive,
527-
options=generation_kwargs,
528-
format=self.response_format,
529-
think=self.think,
527+
@retry(
528+
reraise=True,
529+
stop=stop_after_attempt(self.max_retries + 1),
530+
retry=retry_if_exception_type(Exception),
531+
wait=wait_exponential(),
530532
)
533+
def chat_with_retry() -> ChatResponse | Iterator[ChatResponse]:
534+
return self._client.chat(
535+
model=self.model,
536+
messages=ollama_messages,
537+
tools=ollama_tools,
538+
stream=is_stream, # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
539+
keep_alive=self.keep_alive,
540+
options=generation_kwargs,
541+
format=self.response_format,
542+
think=self.think,
543+
)
544+
545+
response = chat_with_retry()
531546

532547
if isinstance(response, Iterator):
533548
return self._handle_streaming_response(response_iter=response, callback=callback)
@@ -579,16 +594,25 @@ async def run_async(
579594

580595
ollama_messages = [_convert_chatmessage_to_ollama_format(m) for m in messages]
581596

582-
response = await self._async_client.chat(
583-
model=self.model,
584-
messages=ollama_messages,
585-
tools=ollama_tools,
586-
stream=is_stream, # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
587-
keep_alive=self.keep_alive,
588-
options=generation_kwargs,
589-
format=self.response_format,
590-
think=self.think,
597+
@retry(
598+
reraise=True,
599+
stop=stop_after_attempt(self.max_retries + 1),
600+
retry=retry_if_exception_type(Exception),
601+
wait=wait_exponential(),
591602
)
603+
async def chat_with_retry() -> ChatResponse | AsyncIterator[ChatResponse]:
604+
return await self._async_client.chat(
605+
model=self.model,
606+
messages=ollama_messages,
607+
tools=ollama_tools,
608+
stream=is_stream, # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
609+
keep_alive=self.keep_alive,
610+
options=generation_kwargs,
611+
format=self.response_format,
612+
think=self.think,
613+
)
614+
615+
response = await chat_with_retry()
592616

593617
if isinstance(response, AsyncIterator):
594618
# response is an async iterator for streaming

integrations/ollama/tests/test_chat_generator.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,7 @@ def test_init_default(self):
518518
assert component.url == "http://localhost:11434"
519519
assert component.generation_kwargs == {}
520520
assert component.timeout == 120
521+
assert component.max_retries == 0
521522
assert component.streaming_callback is None
522523
assert component.tools is None
523524
assert component.keep_alive is None
@@ -529,6 +530,7 @@ def test_init(self, tools):
529530
url="http://my-custom-endpoint:11434",
530531
generation_kwargs={"temperature": 0.5},
531532
timeout=5,
533+
max_retries=2,
532534
keep_alive="10m",
533535
streaming_callback=print_streaming_chunk,
534536
tools=tools,
@@ -539,6 +541,7 @@ def test_init(self, tools):
539541
assert component.url == "http://my-custom-endpoint:11434"
540542
assert component.generation_kwargs == {"temperature": 0.5}
541543
assert component.timeout == 5
544+
assert component.max_retries == 2
542545
assert component.keep_alive == "10m"
543546
assert component.streaming_callback is print_streaming_chunk
544547
assert component.tools == tools
@@ -603,6 +606,7 @@ def test_to_dict(self):
603606
"type": "haystack_integrations.components.generators.ollama.chat.chat_generator.OllamaChatGenerator",
604607
"init_parameters": {
605608
"timeout": 120,
609+
"max_retries": 0,
606610
"model": "llama2",
607611
"url": "custom_url",
608612
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
@@ -650,6 +654,7 @@ def test_from_dict(self):
650654
"type": "haystack_integrations.components.generators.ollama.chat.chat_generator.OllamaChatGenerator",
651655
"init_parameters": {
652656
"timeout": 120,
657+
"max_retries": 0,
653658
"model": "llama2",
654659
"url": "custom_url",
655660
"keep_alive": "5m",
@@ -689,6 +694,7 @@ def test_from_dict(self):
689694
"some_test_param": "test-params",
690695
}
691696
assert component.timeout == 120
697+
assert component.max_retries == 0
692698
assert component.tools == [tool]
693699
assert component.response_format == {
694700
"type": "object",
@@ -790,6 +796,38 @@ def test_run(self, mock_client):
790796
assert result["replies"][0].text == "Fine. How can I help you today?"
791797
assert result["replies"][0].role == "assistant"
792798

799+
@patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
800+
def test_run_retries_after_failure(self, mock_client):
801+
generator = OllamaChatGenerator(max_retries=1)
802+
803+
mock_response = ChatResponse(
804+
model="qwen3:0.6b",
805+
created_at="2023-12-12T14:13:43.416799Z",
806+
message={"role": "assistant", "content": "Recovered after retry"},
807+
done=True,
808+
prompt_eval_count=1,
809+
eval_count=2,
810+
)
811+
812+
mock_client_instance = mock_client.return_value
813+
mock_client_instance.chat.side_effect = [RuntimeError("temporary failure"), mock_response]
814+
815+
result = generator.run(messages=[ChatMessage.from_user("Hello!")])
816+
817+
assert mock_client_instance.chat.call_count == 2
818+
assert result["replies"][0].text == "Recovered after retry"
819+
820+
@patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
821+
def test_run_raises_after_retry_exhausted(self, mock_client):
822+
generator = OllamaChatGenerator(max_retries=1)
823+
mock_client_instance = mock_client.return_value
824+
mock_client_instance.chat.side_effect = RuntimeError("persistent failure")
825+
826+
with pytest.raises(RuntimeError, match="persistent failure"):
827+
generator.run(messages=[ChatMessage.from_user("Hello!")])
828+
829+
assert mock_client_instance.chat.call_count == 2
830+
793831
@patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
794832
def test_run_streaming(self, mock_client):
795833
collected_chunks = []

0 commit comments

Comments
 (0)