Skip to content

Commit a5a98d6

Browse files
committed
add simple web chatbot
1 parent 072192a commit a5a98d6

7 files changed

Lines changed: 782 additions & 5 deletions

File tree

llms_wrapper/chatbot.py

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import concurrent.futures
88
import asyncio
99
from loguru import logger
10-
from llms_wrapper.llms import any2message
10+
from copy import deepcopy
11+
from llms_wrapper.llms import any2message, LLM
1112
from llms_wrapper.log import configure_logging
1213

1314
# Implementation notes:
@@ -46,7 +47,14 @@ class ChatbotError(Exception):
4647
pass
4748

4849
class SerialChatbot:
49-
def __init__(self, llm, config=None, initial_message=None, message_template=None):
50+
def __init__(
51+
self,
52+
llm,
53+
config=None,
54+
initial_message=None,
55+
message_template=None,
56+
max_messages: int = 9999999,
57+
):
5058
"""
5159
Initialize the SerialChatbot with an LLM, configuration, and optional initial message and template.
5260
@@ -65,6 +73,9 @@ def __init__(self, llm, config=None, initial_message=None, message_template=None
6573
config: The full configuration object or None.
6674
initial_message: The LLM Message to send initially to the LLM, if None, send nothing.
6775
message_template: The prompt template to use if None, just send messages as role user.
76+
max_messages: Optional maximum number of messages to keep in the chat history. If None, no limit. If
77+
the number of messages exceeds this limit, the method compact_messages() is called to replace
78+
the messages with a compacted version.
6879
"""
6980
self.llm = llm
7081
self.config = config
@@ -75,6 +86,7 @@ def __init__(self, llm, config=None, initial_message=None, message_template=None
7586
else:
7687
self.initial_message = None
7788
self.message_template = message_template
89+
self.max_messages = max_messages
7890

7991
def reply(
8092
self,
@@ -110,6 +122,40 @@ def reply(
110122
self.llm_messages.append({"role": "assistant", "content": answer})
111123
return dict(answer=answer, error=None, is_ok=True, message=message, metadata=metadata, response=ret)
112124

125+
def set_llm(self, llm: LLM):
126+
"""
127+
Set the LLM to use for generating responses.
128+
"""
129+
self.llm = llm
130+
131+
def clear_history(self):
132+
"""
133+
Clear the chat history. This removes all messages from the chat history and re-initiralizes with
134+
the initial message, if any.
135+
"""
136+
if self.initial_message is not None:
137+
self.llm_messages = deepcopy(self.initial_message)
138+
else:
139+
self.llm_messages = []
140+
141+
def append_messages(self, messages: list[dict]):
142+
"""Append messages to the chat history and shorten the history if necessary"""
143+
self.llm_messages.extend(messages)
144+
# shorten the messages if necessary
145+
if len(self.llm_messages) > self.max_messages:
146+
self.compact_messages()
147+
148+
def compact_messages(self):
149+
"""
150+
Default strategy for compacting messages in the chat history. This will keep max_messages
151+
last messages if there was no initial message, or max_messages - 1 if there was an initial message.
152+
"""
153+
cur_messages = self.llm_messages
154+
if self.initial_message is not None:
155+
self.llm_messages = [self.initial_message] + cur_messages[-(self.max_messages - 1):]
156+
else:
157+
self.llm_messages = cur_messages[-self.max_messages:]
158+
113159

114160
class FlexibleChatbot:
115161
"""
@@ -917,6 +963,56 @@ def run_sync_example(schbot: SerialChatbot):
917963
else:
918964
logger.warning("Chatbot thread was not alive when stop was called.")
919965

966+
# ===================================================================================
967+
# Example sync chatbot implementation for interacting with LLMs in a chat.
968+
# ===================================================================================
969+
970+
class SimpleSerialChatbot(SerialChatbot):
971+
"""
972+
This implementation limits the reply function to just string messages.
973+
"""
974+
def __init__(
975+
self, *args, **kwargs):
976+
super().__init__(*args, **kwargs)
977+
# list of tuples containing user requests and responses
978+
self.history: list = []
979+
980+
def clear_history(self):
981+
"""Clear the chat history"""
982+
super().clear_history()
983+
self.history = []
984+
985+
def reply(self, request: str) -> dict:
986+
new_messages = any2message(request)
987+
logger.debug(f"SimpleSerialChatbot: reply called with request: {request}, new_messages: {new_messages}")
988+
self.append_messages(new_messages)
989+
logger.debug(f"SimpleSerialChatbot: reply: llm_messages before query: {self.llm_messages}")
990+
response = self.llm.query(self.llm_messages, return_cost=True)
991+
logger.debug(f"SimpleSerialChatbot: reply: response is {response}")
992+
if response.get("error"):
993+
error = response["error"]
994+
self.history.append((request, error))
995+
self.append_messages([dict(role="assistant", content=f"Error: {error}")])
996+
return dict(
997+
error=response["error"],
998+
answer=None, is_ok=False,
999+
cost=response.get("cost", 0),
1000+
n_prompt_tokens=response.get("n_prompt_tokens", 0),
1001+
n_completion_tokens=response.get("n_completion_tokens", 0),
1002+
response=response)
1003+
else:
1004+
self.append_messages([dict(role="assistant", content=response["answer"])])
1005+
return dict(
1006+
error=None,
1007+
answer=response["answer"],
1008+
cost=response.get("cost", 0),
1009+
n_prompt_tokens=response.get("n_prompt_tokens", 0),
1010+
n_completion_tokens=response.get("n_completion_tokens", 0),
1011+
is_ok=True,
1012+
response=response,
1013+
)
1014+
1015+
9201016

9211017
if __name__ == "__main__":
9221018

0 commit comments

Comments
 (0)