2222)
2323from haystack .utils .callable_serialization import deserialize_callable , serialize_callable
2424from pydantic .json_schema import JsonSchemaValue
25+ from tenacity import RetryCallState , retry , retry_if_exception , wait_exponential
2526
26- from ollama import AsyncClient , ChatResponse , Client
27+ from ollama import AsyncClient , ChatResponse , Client , ResponseError
2728
# Maps Ollama `done_reason` values onto Haystack `FinishReason` values.
# Only "stop" and "tool_calls" are translated; Ollama's "load"/"unload"
# reasons are intentionally omitted (they do not correspond to a
# generation-finished condition).
FINISH_REASON_MAPPING: dict[str, FinishReason] = {
    "stop": "stop",
    "tool_calls": "tool_calls",
    # we skip load and unload reasons
}
3334
# HTTP status codes that _is_retryable_exception treats as transient.
HTTP_STATUS_TOO_MANY_REQUESTS = 429  # rate limiting
HTTP_STATUS_SERVER_ERROR_MIN = 500  # start of the 5xx range (inclusive)
HTTP_STATUS_SERVER_ERROR_MAX_EXCLUSIVE = 600  # end of the 5xx range (exclusive)
38+
39+
def _stop_after_instance_max_retries(retry_state: RetryCallState) -> bool:
    """
    Tenacity `stop` predicate bound to the component instance's `max_retries`.

    The decorated callables are instance methods, so the first positional
    argument captured in `retry_state.args` is the component instance itself;
    its `max_retries` attribute determines the attempt budget.

    :param retry_state: Tenacity state for the call being retried.
    :returns: True once `max_retries + 1` attempts have been made
        (i.e. the initial call plus `max_retries` retries).
    """
    owner = retry_state.args[0]
    allowed_attempts = owner.max_retries + 1
    return retry_state.attempt_number >= allowed_attempts
46+
47+
def _is_retryable_exception(exc: BaseException) -> bool:
    """
    Decide whether a failure is transient and worth retrying.

    Retryable failures are:
    - `ResponseError` with HTTP 429 (rate limited)
    - `ResponseError` with any 5xx status (server-side error)
    - builtin `ConnectionError` / `TimeoutError`

    NOTE(review): the ollama client is httpx-based, and httpx's
    `ConnectError`/`TimeoutException` do not subclass the builtins checked
    here — confirm transport errors actually surface as builtin
    `ConnectionError`/`TimeoutError` in practice.

    :param exc: The exception raised by the chat call.
    :returns: True if the call should be retried.
    """
    if isinstance(exc, ResponseError):
        status = exc.status_code
        if status == HTTP_STATUS_TOO_MANY_REQUESTS:
            return True
        return HTTP_STATUS_SERVER_ERROR_MIN <= status < HTTP_STATUS_SERVER_ERROR_MAX_EXCLUSIVE
    return isinstance(exc, (ConnectionError, TimeoutError))
62+
3463
3564def _convert_chatmessage_to_ollama_format (message : ChatMessage ) -> dict [str , Any ]:
3665 """
@@ -216,6 +245,7 @@ def __init__(
216245 url : str = "http://localhost:11434" ,
217246 generation_kwargs : dict [str , Any ] | None = None ,
218247 timeout : int = 120 ,
248+ max_retries : int = 0 ,
219249 keep_alive : float | str | None = None ,
220250 streaming_callback : Callable [[StreamingChunk ], None ] | None = None ,
221251 tools : ToolsType | None = None ,
@@ -233,6 +263,9 @@ def __init__(
233263 [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
234264 :param timeout:
235265 The number of seconds before throwing a timeout error from the Ollama API.
266+ :param max_retries:
267+ Maximum number of retries to attempt for failed requests (HTTP 429, 5xx, connection/timeout errors).
268+ Uses exponential backoff between attempts. Set to 0 (default) to disable retries.
236269 :param think:
237270 If True, the model will "think" before producing a response.
238271 Only [thinking models](https://ollama.com/search?c=thinking) support this feature.
@@ -268,6 +301,7 @@ def __init__(
268301 self .url = url
269302 self .generation_kwargs = generation_kwargs or {}
270303 self .timeout = timeout
304+ self .max_retries = max_retries
271305 self .keep_alive = keep_alive
272306 self .streaming_callback = streaming_callback
273307 self .tools = tools # Store original tools for serialization
@@ -292,6 +326,7 @@ def to_dict(self) -> dict[str, Any]:
292326 url = self .url ,
293327 generation_kwargs = self .generation_kwargs ,
294328 timeout = self .timeout ,
329+ max_retries = self .max_retries ,
295330 keep_alive = self .keep_alive ,
296331 streaming_callback = callback_name ,
297332 tools = serialize_tools_or_toolset (self .tools ),
@@ -469,6 +504,56 @@ async def _handle_streaming_response_async(
469504
470505 return {"replies" : [reply ]}
471506
507+ @retry (
508+ reraise = True ,
509+ stop = _stop_after_instance_max_retries ,
510+ retry = retry_if_exception (_is_retryable_exception ),
511+ wait = wait_exponential (),
512+ )
513+ def _chat (
514+ self ,
515+ * ,
516+ messages : list [dict [str , Any ]],
517+ tools : list [dict [str , Any ]] | None ,
518+ is_stream : bool ,
519+ generation_kwargs : dict [str , Any ],
520+ ) -> ChatResponse | Iterator [ChatResponse ]:
521+ return self ._client .chat (
522+ model = self .model ,
523+ messages = messages ,
524+ tools = tools ,
525+ stream = is_stream , # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
526+ keep_alive = self .keep_alive ,
527+ options = generation_kwargs ,
528+ format = self .response_format ,
529+ think = self .think ,
530+ )
531+
532+ @retry (
533+ reraise = True ,
534+ stop = _stop_after_instance_max_retries ,
535+ retry = retry_if_exception (_is_retryable_exception ),
536+ wait = wait_exponential (),
537+ )
538+ async def _chat_async (
539+ self ,
540+ * ,
541+ messages : list [dict [str , Any ]],
542+ tools : list [dict [str , Any ]] | None ,
543+ is_stream : bool ,
544+ generation_kwargs : dict [str , Any ],
545+ ) -> ChatResponse | AsyncIterator [ChatResponse ]:
546+ return await self ._async_client .chat (
547+ model = self .model ,
548+ messages = messages ,
549+ tools = tools ,
550+ stream = is_stream , # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
551+ keep_alive = self .keep_alive ,
552+ options = generation_kwargs ,
553+ format = self .response_format ,
554+ think = self .think ,
555+ )
556+
472557 @component .output_types (replies = list [ChatMessage ])
473558 def run (
474559 self ,
@@ -518,15 +603,8 @@ def run(
518603
519604 ollama_messages = [_convert_chatmessage_to_ollama_format (m ) for m in messages ]
520605
521- response = self ._client .chat (
522- model = self .model ,
523- messages = ollama_messages ,
524- tools = ollama_tools ,
525- stream = is_stream , # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
526- keep_alive = self .keep_alive ,
527- options = generation_kwargs ,
528- format = self .response_format ,
529- think = self .think ,
606+ response = self ._chat (
607+ messages = ollama_messages , tools = ollama_tools , is_stream = is_stream , generation_kwargs = generation_kwargs
530608 )
531609
532610 if isinstance (response , Iterator ):
@@ -579,15 +657,8 @@ async def run_async(
579657
580658 ollama_messages = [_convert_chatmessage_to_ollama_format (m ) for m in messages ]
581659
582- response = await self ._async_client .chat (
583- model = self .model ,
584- messages = ollama_messages ,
585- tools = ollama_tools ,
586- stream = is_stream , # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
587- keep_alive = self .keep_alive ,
588- options = generation_kwargs ,
589- format = self .response_format ,
590- think = self .think ,
660+ response = await self ._chat_async (
661+ messages = ollama_messages , tools = ollama_tools , is_stream = is_stream , generation_kwargs = generation_kwargs
591662 )
592663
593664 if isinstance (response , AsyncIterator ):
0 commit comments