From acde67e3f7f77db2cf0f125b4fef0d63d643a0e8 Mon Sep 17 00:00:00 2001 From: escesare Date: Wed, 19 Apr 2023 17:37:45 -0400 Subject: [PATCH] Only retry transient OpenAI errors --- promptify/models/nlp/text2text/base_model.py | 4 ++- .../models/nlp/text2text/openai_complete.py | 31 ++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/promptify/models/nlp/text2text/base_model.py b/promptify/models/nlp/text2text/base_model.py index 76557c2..a4d32a6 100644 --- a/promptify/models/nlp/text2text/base_model.py +++ b/promptify/models/nlp/text2text/base_model.py @@ -1,5 +1,6 @@ from abc import ABCMeta, abstractmethod -from typing import List, Optional, Union, Dict +from typing import Dict, List, Optional, Union + import tenacity @@ -384,6 +385,7 @@ def _retry_decorator(self): multiplier=0.3, exp_base=3, max=self.api_wait ), stop=tenacity.stop_after_attempt(self.api_retry), + reraise=True, ) def execute_with_retry(self, *args, **kwargs): diff --git a/promptify/models/nlp/text2text/openai_complete.py b/promptify/models/nlp/text2text/openai_complete.py index 79344c3..80288d2 100644 --- a/promptify/models/nlp/text2text/openai_complete.py +++ b/promptify/models/nlp/text2text/openai_complete.py @@ -1,8 +1,11 @@ from typing import Dict, List, Optional, Tuple, Union + import openai +import tenacity import tiktoken -from promptify.parser.parser import Parser + from promptify.models.nlp.text2text.base_model import Model +from promptify.parser.parser import Parser class OpenAI(Model): @@ -308,6 +311,32 @@ def model_output(self, response: Dict, max_completion_length: int) -> Dict: data["parsed"] = self.parser.fit(data["text"], max_completion_length) return data + def _retry_decorator(self): + """ + Decorator function for retrying API requests if they fail. + + Returns + ------- + tenacity.Retrying + A decorator function for retrying API requests. + + Notes + ----- + This method is a decorator function for retrying API requests using tenacity. + """ + + return tenacity.retry( + wait=tenacity.wait_random_exponential( + multiplier=0.3, exp_base=3, max=self.api_wait + ), + stop=tenacity.stop_after_attempt(self.api_retry), + retry=tenacity.retry_if_exception_type( + (openai.error.APIError, openai.error.TryAgain, openai.error.Timeout, + openai.error.APIConnectionError, openai.error.RateLimitError, + openai.error.ServiceUnavailableError, )), + reraise=True, + ) + def get_parameters( self, ) -> Dict[str, Union[str, int, float, List[str], Dict[str, int]]]: