2222)
2323from haystack .utils .callable_serialization import deserialize_callable , serialize_callable
2424from pydantic .json_schema import JsonSchemaValue
25+ from tenacity import retry , retry_if_exception_type , stop_after_attempt , wait_exponential
2526
2627from ollama import AsyncClient , ChatResponse , Client
2728
@@ -216,6 +217,7 @@ def __init__(
216217 url : str = "http://localhost:11434" ,
217218 generation_kwargs : dict [str , Any ] | None = None ,
218219 timeout : int = 120 ,
220+ max_retries : int = 0 ,
219221 keep_alive : float | str | None = None ,
220222 streaming_callback : Callable [[StreamingChunk ], None ] | None = None ,
221223 tools : ToolsType | None = None ,
@@ -233,6 +235,8 @@ def __init__(
233235 [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
234236 :param timeout:
235237 The number of seconds before throwing a timeout error from the Ollama API.
238+ :param max_retries:
239+ Maximum number of retries to attempt for failed requests, with exponential backoff between attempts.
236240 :param think:
237241 If True, the model will "think" before producing a response.
238242 Only [thinking models](https://ollama.com/search?c=thinking) support this feature.
@@ -268,6 +272,7 @@ def __init__(
268272 self .url = url
269273 self .generation_kwargs = generation_kwargs or {}
270274 self .timeout = timeout
275+ self .max_retries = max_retries
271276 self .keep_alive = keep_alive
272277 self .streaming_callback = streaming_callback
273278 self .tools = tools # Store original tools for serialization
@@ -292,6 +297,7 @@ def to_dict(self) -> dict[str, Any]:
292297 url = self .url ,
293298 generation_kwargs = self .generation_kwargs ,
294299 timeout = self .timeout ,
300+ max_retries = self .max_retries ,
295301 keep_alive = self .keep_alive ,
296302 streaming_callback = callback_name ,
297303 tools = serialize_tools_or_toolset (self .tools ),
@@ -518,16 +524,25 @@ def run(
518524
519525 ollama_messages = [_convert_chatmessage_to_ollama_format (m ) for m in messages ]
520526
521- response = self ._client .chat (
522- model = self .model ,
523- messages = ollama_messages ,
524- tools = ollama_tools ,
525- stream = is_stream , # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
526- keep_alive = self .keep_alive ,
527- options = generation_kwargs ,
528- format = self .response_format ,
529- think = self .think ,
527+ @retry (
528+ reraise = True ,
529+ stop = stop_after_attempt (self .max_retries + 1 ),
530+ retry = retry_if_exception_type (Exception ),
531+ wait = wait_exponential (),
530532 )
533+ def chat_with_retry () -> ChatResponse | Iterator [ChatResponse ]:
534+ return self ._client .chat (
535+ model = self .model ,
536+ messages = ollama_messages ,
537+ tools = ollama_tools ,
538+ stream = is_stream , # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
539+ keep_alive = self .keep_alive ,
540+ options = generation_kwargs ,
541+ format = self .response_format ,
542+ think = self .think ,
543+ )
544+
545+ response = chat_with_retry ()
531546
532547 if isinstance (response , Iterator ):
533548 return self ._handle_streaming_response (response_iter = response , callback = callback )
@@ -579,16 +594,25 @@ async def run_async(
579594
580595 ollama_messages = [_convert_chatmessage_to_ollama_format (m ) for m in messages ]
581596
582- response = await self ._async_client .chat (
583- model = self .model ,
584- messages = ollama_messages ,
585- tools = ollama_tools ,
586- stream = is_stream , # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
587- keep_alive = self .keep_alive ,
588- options = generation_kwargs ,
589- format = self .response_format ,
590- think = self .think ,
597+ @retry (
598+ reraise = True ,
599+ stop = stop_after_attempt (self .max_retries + 1 ),
600+ retry = retry_if_exception_type (Exception ),
601+ wait = wait_exponential (),
591602 )
603+ async def chat_with_retry () -> ChatResponse | AsyncIterator [ChatResponse ]:
604+ return await self ._async_client .chat (
605+ model = self .model ,
606+ messages = ollama_messages ,
607+ tools = ollama_tools ,
608+ stream = is_stream , # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
609+ keep_alive = self .keep_alive ,
610+ options = generation_kwargs ,
611+ format = self .response_format ,
612+ think = self .think ,
613+ )
614+
615+ response = await chat_with_retry ()
592616
593617 if isinstance (response , AsyncIterator ):
594618 # response is an async iterator for streaming
0 commit comments