@@ -90,6 +90,65 @@ def retry(
9090 raise ParseError (f"Could not parse a valid value after { n_retry } retries." )
9191
9292
93+ def retry_multiple (
94+ chat : "ChatModel" ,
95+ messages : "Discussion" ,
96+ n_retry : int ,
97+ parser : callable ,
98+ log : bool = True ,
99+ num_samples : int = 1 ,
100+ ):
101+ """Retry querying the chat models with the response from the parser until it
102+ returns a valid value.
103+
104+ If the answer is not valid, it will retry and append to the chat the retry
105+ message. It will stop after `n_retry`.
106+
107+ Note, each retry has to resend the whole prompt to the API. This can be slow
108+ and expensive.
109+
110+ Args:
111+ chat (ChatModel): a ChatModel object taking a list of messages and
112+ returning a list of answers, all in OpenAI format.
113+ messages (list): the list of messages so far. This list will be modified with
114+ the new messages and the retry messages.
115+ n_retry (int): the maximum number of sequential retries.
116+ parser (callable): a function taking a message and retruning a parsed value,
117+ or raising a ParseError
118+ log (bool): whether to log the retry messages.
119+ num_samples (int): the number of samples to generate from the model.
120+
121+ Returns:
122+ list[dict]: the parsed value, with a string at key "action".
123+
124+ Raises:
125+ ParseError: if the parser could not parse the response after n_retry retries.
126+ """
127+ tries = 0
128+ while tries < n_retry :
129+ answer_list = chat (messages , num_samples = num_samples )
130+ # TODO: could we change this to not use inplace modifications ?
131+ messages .append (answer )
132+ parsed_answers = []
133+ errors = []
134+ for answer in answer_list :
135+ try :
136+ parsed_answers .append (parser (answer ["content" ]))
137+ except ParseError as parsing_error :
138+ errors .append (str (parsing_error ))
139+ # if we have a valid answer, return it
140+ if parsed_answers :
141+ return parsed_answers , tries
142+ else :
143+ tries += 1
144+ if log :
145+ msg = f"Query failed. Retrying { tries } /{ n_retry } .\n [LLM]:\n { answer ['content' ]} \n [User]:\n { str (errors )} "
146+ logging .info (msg )
147+ messages .append (dict (role = "user" , content = str (errors )))
148+
149+ raise ParseError (f"Could not parse a valid value after { n_retry } retries." )
150+
151+
93152def truncate_tokens (text , max_tokens = 8000 , start = 0 , model_name = "gpt-4" ):
94153 """Use tiktoken to truncate a text to a maximum number of tokens."""
95154 enc = tiktoken .encoding_for_model (model_name )
0 commit comments