77import concurrent .futures
88import asyncio
99from loguru import logger
10- from llms_wrapper .llms import any2message
10+ from copy import deepcopy
11+ from llms_wrapper .llms import any2message , LLM
1112from llms_wrapper .log import configure_logging
1213
1314# Implementation notes:
@@ -46,7 +47,14 @@ class ChatbotError(Exception):
4647 pass
4748
4849class SerialChatbot :
49- def __init__ (self , llm , config = None , initial_message = None , message_template = None ):
50+ def __init__ (
51+ self ,
52+ llm ,
53+ config = None ,
54+ initial_message = None ,
55+ message_template = None ,
56+ max_messages : int = 9999999 ,
57+ ):
5058 """
5159 Initialize the SerialChatbot with an LLM, configuration, and optional initial message and template.
5260
@@ -65,6 +73,9 @@ def __init__(self, llm, config=None, initial_message=None, message_template=None
6573 config: The full configuration object or None.
6674 initial_message: The LLM Message to send initially to the LLM, if None, send nothing.
6775 message_template: The prompt template to use if None, just send messages as role user.
76+ max_messages: Optional maximum number of messages to keep in the chat history. If None, no limit. If
77+ the number of messages exceeds this limit, the method compact_messages() is called to replace
78+ the messages with a compacted version.
6879 """
6980 self .llm = llm
7081 self .config = config
@@ -75,6 +86,7 @@ def __init__(self, llm, config=None, initial_message=None, message_template=None
7586 else :
7687 self .initial_message = None
7788 self .message_template = message_template
89+ self .max_messages = max_messages
7890
7991 def reply (
8092 self ,
@@ -110,6 +122,40 @@ def reply(
110122 self .llm_messages .append ({"role" : "assistant" , "content" : answer })
111123 return dict (answer = answer , error = None , is_ok = True , message = message , metadata = metadata , response = ret )
112124
def set_llm(self, llm: LLM):
    """
    Replace the LLM this chatbot uses for generating responses.

    Args:
        llm: The LLM object to store on the instance as ``self.llm``.
    """
    self.llm = llm
130+
def clear_history(self):
    """
    Reset the chat history.

    All messages are dropped. If an initial message was configured, the history is
    re-initialized with an independent deep copy of it, otherwise it becomes empty.
    """
    initial = self.initial_message
    # deep copy so that later changes to the history never leak into the stored initial message
    self.llm_messages = [] if initial is None else deepcopy(initial)
140+
def append_messages(self, messages: list[dict]):
    """
    Append messages to the chat history and shorten the history if necessary.

    Args:
        messages: The message dicts to add to the end of ``self.llm_messages``.

    If ``self.max_messages`` is None the history is unbounded and never compacted.
    Otherwise compact_messages() is called whenever the history grows beyond the limit.
    """
    self.llm_messages.extend(messages)
    limit = self.max_messages
    # None means "no limit"; comparing an int against None would raise a TypeError
    if limit is not None and len(self.llm_messages) > limit:
        self.compact_messages()
147+
def compact_messages(self):
    """
    Default strategy for compacting the chat history to at most max_messages entries.

    The initial message(s), if any, are always kept at the start of the history and the
    remaining slots are filled with the most recent messages. Without an initial message
    the last max_messages messages are kept. Nothing happens if max_messages is None or
    the history already fits. If the initial message(s) alone exceed max_messages, only
    the initial message(s) are kept.
    """
    limit = self.max_messages
    if limit is None or len(self.llm_messages) <= limit:
        return
    initial = self.initial_message
    if initial is None:
        head = []
    elif isinstance(initial, list):
        # the initial message is stored as a list of messages: splice it in, do not nest it
        head = deepcopy(initial)
    else:
        head = [deepcopy(initial)]
    n_recent = max(limit - len(head), 0)
    # explicit check: a slice starting at -0 would return the whole history instead of nothing
    recent = self.llm_messages[-n_recent:] if n_recent > 0 else []
    self.llm_messages = head + recent
158+
113159
114160class FlexibleChatbot :
115161 """
@@ -917,6 +963,56 @@ def run_sync_example(schbot: SerialChatbot):
917963 else :
918964 logger .warning ("Chatbot thread was not alive when stop was called." )
919965
966+ # ===================================================================================
967+ # Example sync chatbot implementation for interacting with LLMs in a chat.
968+ # ===================================================================================
969+
class SimpleSerialChatbot(SerialChatbot):
    """
    Serial chatbot whose reply function is limited to plain string requests.

    In addition to the LLM message history maintained by the base class, the instance keeps
    ``history``, a list of (request, answer-or-error) tuples with one entry per reply() call.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the base chatbot with the given arguments and start with an empty history."""
        super().__init__(*args, **kwargs)
        # list of tuples containing user requests and responses (the answer, or the error)
        self.history: list = []

    def clear_history(self):
        """Clear the LLM message history (see base class) and the request/response history."""
        super().clear_history()
        self.history = []

    def reply(self, request: str) -> dict:
        """
        Send a request to the LLM together with the chat history so far and return the result.

        Args:
            request: The user request as a string.

        Returns:
            A dict with the keys ``error`` (the error reported by the LLM or None), ``answer``
            (the answer or None on error), ``is_ok``, ``cost``, ``n_prompt_tokens``,
            ``n_completion_tokens`` (each 0 if not reported) and ``response`` (the
            unmodified result of the LLM query).
        """
        new_messages = any2message(request)
        logger.debug(f"SimpleSerialChatbot: reply called with request: {request}, new_messages: {new_messages}")
        self.append_messages(new_messages)
        logger.debug(f"SimpleSerialChatbot: reply: llm_messages before query: {self.llm_messages}")
        response = self.llm.query(self.llm_messages, return_cost=True)
        logger.debug(f"SimpleSerialChatbot: reply: response is {response}")
        # usage information is reported the same way for successful and failed queries
        usage = dict(
            cost=response.get("cost", 0),
            n_prompt_tokens=response.get("n_prompt_tokens", 0),
            n_completion_tokens=response.get("n_completion_tokens", 0),
        )
        error = response.get("error")
        if error:
            self.history.append((request, error))
            # keep the user/assistant alternation in the LLM history intact after a failure
            self.append_messages([dict(role="assistant", content=f"Error: {error}")])
            return dict(error=error, answer=None, is_ok=False, response=response, **usage)
        answer = response["answer"]
        # record successful exchanges too, so history really holds all requests and responses
        self.history.append((request, answer))
        self.append_messages([dict(role="assistant", content=answer)])
        return dict(error=None, answer=answer, is_ok=True, response=response, **usage)
1014+
1015+
9201016
9211017if __name__ == "__main__" :
9221018
0 commit comments