Skip to content

Commit 5e8b436

Browse files
committed
Refactor: Update environment variable handling and improve conversation memory management
1 parent 25bcdd7 commit 5e8b436

1 file changed

Lines changed: 44 additions & 73 deletions

File tree

LocalBot.py

Lines changed: 44 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import os, time, random, asyncio, aiohttp, json, lyricsgenius, discord, wolframalpha
1+
import os, random, asyncio, aiohttp, json, lyricsgenius, discord, wolframalpha
22
from discord.ext import commands, tasks
33
import yt_dlp as youtube_dl
44
from typing import Optional
@@ -12,7 +12,7 @@
1212
GENIUS_TOKEN = os.getenv("GENIUS_TOKEN")
1313
WOLF = os.getenv("WOLF")
1414
WEATHER_API_KEY = os.getenv("WEATHER_API_KEY")
15-
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
15+
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
1616

1717
intents = discord.Intents.default()
1818
intents.message_content = True
@@ -91,8 +91,7 @@
9191
Note: Process user messages in format "username: message" but respond to message content only.
9292
"""
9393

94-
# Initialize Gemini client
95-
gemini_client = genai.Client(api_key=GOOGLE_API_KEY) if GOOGLE_API_KEY else None
94+
gemini_client = genai.Client(api_key=GEMINI_API_KEY) if GEMINI_API_KEY else None
9695

9796
ytdl_format_options = {
9897
"format": "bestaudio/best",
@@ -172,61 +171,63 @@ async def generate_chat_completion(
172171
global conversation_memory
173172
context_key = server_id if server_id else f"DM-{channel_id}-{user_id}"
174173

175-
# Initialize memory if it doesn't exist for this context
176174
if context_key not in conversation_memory:
177175
conversation_memory[context_key] = ConversationBufferWindowMemory(
178-
k=10, memory_key="chat_history", return_messages=True
176+
k=42, memory_key="chat_history", return_messages=True
179177
)
180178

181179
memory = conversation_memory[context_key]
182180

183-
# Check if Gemini client is available
184181
if not gemini_client:
185182
await send_response(
186183
ctx,
187-
"Gemini client is not configured. Please check your GOOGLE_API_KEY.",
184+
"Gemini client is not configured. Please check your GEMINI_API_KEY or GOOGLE_API_KEY.",
188185
)
189186
return None
190187

191-
# Build conversation history for Gemini
192188
conversation_history = []
193-
194-
# Add system instruction
195189
system_content = system_prompt
196-
197-
# Get chat history from memory
198190
chat_history = memory.chat_memory.messages
199191

200-
# Convert LangChain messages to Gemini format
201192
for message in chat_history:
202193
if hasattr(message, "content"):
203194
if hasattr(message, "type") and message.type == "human":
204195
conversation_history.append(f"User: {message.content}")
205196
elif hasattr(message, "type") and message.type == "ai":
206197
conversation_history.append(f"Assistant: {message.content}")
207198

208-
# Add current prompt
209199
conversation_history.append(prompt)
210200

211-
# Combine system prompt with conversation
212201
full_prompt = f"{system_content}\n\nConversation:\n" + "\n".join(
213202
conversation_history
214203
)
215204

216-
# Generate response using Gemini
205+
contents = [
206+
types.Content(
207+
role="user",
208+
parts=[
209+
types.Part.from_text(text=full_prompt),
210+
],
211+
),
212+
]
213+
214+
generate_content_config = types.GenerateContentConfig(
215+
max_output_tokens=1024,
216+
temperature=0.7,
217+
thinking_config=types.ThinkingConfig(
218+
thinking_budget=0,
219+
),
220+
response_mime_type="text/plain",
221+
)
222+
217223
response = gemini_client.models.generate_content(
218224
model="gemini-2.5-flash",
219-
contents=full_prompt,
220-
config=types.GenerateContentConfig(
221-
temperature=0.7,
222-
max_output_tokens=1024,
223-
),
225+
contents=contents,
226+
config=generate_content_config,
224227
)
225228

226-
# Extract response text
227229
response_text = response.text if hasattr(response, "text") else str(response)
228230

229-
# Update memory with the new interaction
230231
memory.chat_memory.add_user_message(prompt)
231232
memory.chat_memory.add_ai_message(response_text)
232233

@@ -246,12 +247,10 @@ async def handle_tool_call(
246247
) -> str:
247248
"""Handle tool calls embedded in the response."""
248249
try:
249-
# Extract tool call
250250
start = tool_call_text.index("<tool_calls>") + len("<tool_calls>")
251251
end = tool_call_text.index("</tool_calls>")
252252
tool_call_json = tool_call_text[start:end].strip()
253253

254-
# Parse tool call
255254
try:
256255
tool_call = json.loads(tool_call_json)
257256
except json.JSONDecodeError:
@@ -263,7 +262,6 @@ async def handle_tool_call(
263262
if not tool_name:
264263
raise ValueError("Tool name not specified")
265264

266-
# Define available tools with type hints
267265
tool_actions = {
268266
"cat": lambda: cat(ctx, from_tool_call=send_directly),
269267
"dog": lambda: dog(ctx, from_tool_call=send_directly),
@@ -299,13 +297,11 @@ async def handle_tool_call(
299297
"whats_new": lambda: whats_new(ctx, from_tool_call=send_directly),
300298
}
301299

302-
# Execute tool
303300
if tool_name not in tool_actions:
304301
raise ValueError(f"Unknown tool: {tool_name}")
305302

306303
result = await tool_actions[tool_name]()
307304

308-
# Update memory with tool result
309305
if result and memory:
310306
memory.chat_memory.messages[-1].content += f"\nTool result: {result}"
311307

@@ -326,30 +322,24 @@ async def handle_tool_call(
326322

327323

328324
statuses = [
329-
# Interactive Features
330325
"Ask me anything! 💭",
331326
"Weather forecasts 🌤️",
332327
"Playing music 🎵",
333-
# Games
334328
"Number guessing 🎲",
335329
"Rolling dice 🎯",
336330
"Flipping coins 🪙",
337-
# Helper Features
338331
"Managing messages 📝",
339332
"Fetching lyrics 🎤",
340333
"Calculating math 🔢",
341334
"Sharing knowledge 📚",
342-
# Fun Statuses
343335
"Running on local power 🔋",
344336
"Processing requests ⚡",
345337
"Thinking in binary 🤖",
346338
"Learning new tricks 🎓",
347-
# Friendly Messages
348339
"Here to help! 👋",
349340
"Chat with me 💬",
350341
"Ready for commands ⌨️",
351342
"Local assistant 🤝",
352-
# System Status
353343
"Online and active ✨",
354344
"Fast responses ⚡",
355345
"24/7 Service 🕒",
@@ -501,20 +491,16 @@ async def chat(ctx, *, message):
501491
username = ctx.author.display_name
502492

503493
message = username + ": " + message
504-
505-
# Get the memory context key
506494
context_key = server_id if server_id else f"DM-{channel_id}-{user_id}"
507495

508-
# Make sure memory is initialized
509496
global conversation_memory
510497
if context_key not in conversation_memory:
511498
conversation_memory[context_key] = ConversationBufferWindowMemory(
512-
k=10, memory_key="chat_history", return_messages=True
499+
k=42, memory_key="chat_history", return_messages=True
513500
)
514501

515502
memory = conversation_memory[context_key]
516503

517-
# Initial response generation
518504
response = await generate_chat_completion(
519505
ctx=ctx,
520506
prompt=message,
@@ -523,20 +509,18 @@ async def chat(ctx, *, message):
523509
user_id=user_id,
524510
)
525511

526-
# Process all tool calls in sequence
527512
tool_was_used = False
528-
if response: # Check if response is not None
513+
if response:
529514
while "<tool_calls>" in response and "</tool_calls>" in response:
530515
tool_was_used = True
531-
# Extract the tool call
516+
532517
tool_start = response.find("<tool_calls>")
533518
tool_end = response.find("</tool_calls>") + len("</tool_calls>")
534519
tool_call_text = response[tool_start:tool_end]
535520

536-
# Save text before the tool call - this may contain context that needs to be sent
537521
pre_tool_text = response[:tool_start].strip()
538522
if pre_tool_text:
539-
# Only send preceding text for the first tool in a sequence
523+
540524
first_tool_call = not any(
541525
"<tool_calls>" in msg.content
542526
for msg in memory.chat_memory.messages
@@ -545,24 +529,22 @@ async def chat(ctx, *, message):
545529
if first_tool_call:
546530
await ctx.reply(pre_tool_text)
547531

548-
# Execute the tool and get results
549532
tool_result = await handle_tool_call(
550533
ctx,
551534
tool_call_text,
552535
memory,
553-
send_directly=True, # Prevent tools from sending their own message
536+
send_directly=True,
554537
)
555538

556-
print(f"Tool result: {tool_result}") # Debug print
539+
print(f"Tool result: {tool_result}")
557540

558-
# If there's more text after this tool call, check if it contains another tool call
559541
remaining_text = response[tool_end:].strip()
560542

561543
if "<tool_calls>" in remaining_text:
562-
# There's another tool call, continue processing
544+
563545
response = remaining_text
564546
else:
565-
# No more tool calls, generate final response with all tool results
547+
566548
followup_prompt = (
567549
f"You used one or more tools to answer the user's question. "
568550
f"The last tool result was: {tool_result}. "
@@ -579,11 +561,9 @@ async def chat(ctx, *, message):
579561
is_tool_followup=True,
580562
)
581563

582-
# Send the complete response
583564
await send_complete_response(ctx, followup_response)
584-
break # Exit the loop as we've processed all tools and sent a response
565+
break
585566

586-
# If no tool calls were found, send the response directly
587567
if not tool_was_used:
588568
await send_complete_response(ctx, response)
589569

@@ -684,7 +664,7 @@ async def play_song(ctx, info, filename):
684664

685665
@bot.slash_command(description="Play a song or playlist from YouTube.")
686666
async def play(ctx, *, query):
687-
# Check if user is in a voice channel
667+
688668
if not ctx.author.voice:
689669
await ctx.response.send_message(
690670
"You need to be in a voice channel to play music.", ephemeral=True
@@ -693,13 +673,12 @@ async def play(ctx, *, query):
693673

694674
state = await get_server_state(ctx.guild.id)
695675

696-
# Join the user's voice channel if not already connected
697676
if not ctx.voice_client:
698677
channel = ctx.author.voice.channel
699678
await channel.connect()
700679
await ctx.response.defer()
701680
else:
702-
# If already in a voice channel but it's different from the user's, move to user's channel
681+
703682
if ctx.voice_client.channel != ctx.author.voice.channel:
704683
await ctx.voice_client.disconnect()
705684
channel = ctx.author.voice.channel
@@ -800,7 +779,7 @@ async def pin(ctx):
800779
async def weather(ctx, city: str, from_tool_call: bool = False) -> str:
801780
"""Get current weather for a city using OpenWeatherMap API."""
802781
try:
803-
# Use direct city query instead of geocoding
782+
804783
weather_url = (
805784
f"https://api.openweathermap.org/data/2.5/weather?"
806785
f"q={city}&units=metric&appid={WEATHER_API_KEY}"
@@ -818,18 +797,15 @@ async def weather(ctx, city: str, from_tool_call: bool = False) -> str:
818797

819798
data = await response.json()
820799

821-
# Extract weather data
822800
temp = data["main"]["temp"]
823801
feels_like = data["main"]["feels_like"]
824802
humidity = data["main"]["humidity"]
825803
wind_speed = data["wind"]["speed"]
826804
weather_desc = data["weather"][0]["description"]
827805
pressure = data["main"]["pressure"]
828806

829-
# Get country code and combine with city name
830807
location_name = f"{data['name']}, {data['sys']['country']}"
831808

832-
# Select weather emoji based on weather condition code
833809
weather_id = data["weather"][0]["id"]
834810
weather_emoji = "🌈" # default
835811
if weather_id < 300:
@@ -877,7 +853,7 @@ async def getweather(ctx, *, city: str):
877853
async def music_play(ctx, query: str, from_tool_call: bool = False) -> str:
878854
"""Play music in a voice channel via the chat interface."""
879855
try:
880-
# Check if user is in a voice channel
856+
881857
if not ctx.author.voice:
882858
message = "You need to be in a voice channel to play music."
883859
if not from_tool_call:
@@ -886,7 +862,6 @@ async def music_play(ctx, query: str, from_tool_call: bool = False) -> str:
886862

887863
state = await get_server_state(ctx.guild.id)
888864

889-
# Join the user's voice channel if not already connected
890865
if not ctx.voice_client:
891866
channel = ctx.author.voice.channel
892867
try:
@@ -897,7 +872,7 @@ async def music_play(ctx, query: str, from_tool_call: bool = False) -> str:
897872
await ctx.reply(error_msg)
898873
return error_msg
899874
else:
900-
# If already in a different voice channel, move to user's channel
875+
901876
if ctx.voice_client.channel != ctx.author.voice.channel:
902877
await ctx.voice_client.disconnect()
903878
channel = ctx.author.voice.channel
@@ -916,18 +891,17 @@ async def music_play(ctx, query: str, from_tool_call: bool = False) -> str:
916891
else:
917892
url = f"ytsearch:{query}"
918893

919-
# Status message
920894
status_message = f"Searching for '{query}'..."
921895
if not from_tool_call:
922896
await ctx.reply(status_message)
923897

924898
try:
925-
# Extract info and download
899+
926900
info = ydl.extract_info(url, download=True)
927-
result = "No songs found." # Default value
901+
result = "No songs found."
928902

929903
if info and "entries" in info:
930-
# It's a playlist
904+
931905
songs_added = 0
932906
for entry in info["entries"]:
933907
filename = ydl.prepare_filename(entry)
@@ -937,19 +911,17 @@ async def music_play(ctx, query: str, from_tool_call: bool = False) -> str:
937911
songs_added += 1
938912
result = f"Added {songs_added} songs to the queue from playlist."
939913
elif info:
940-
# It's a single song
914+
941915
filename = ydl.prepare_filename(info)
942916
state["playlist_queue"].append({"info": info, "filename": filename})
943917
result = f"Added '{info['title']}' to the queue."
944918

945-
# Start playing if nothing is currently playing
946919
if not state["current_song"] and state["playlist_queue"]:
947920
next_song = state["playlist_queue"].pop(0)
948921
title = next_song["info"]["title"]
949-
# Direct message handling instead of using followup
922+
950923
await ctx.send(f"Now playing: {title}")
951924

952-
# Set up playback
953925
ctx.voice_client.play(
954926
discord.FFmpegPCMAudio(
955927
next_song["filename"],
@@ -1012,7 +984,6 @@ async def whats_new(ctx, from_tool_call=False):
1012984
with open(whatsnew_path, "r") as file:
1013985
content = file.read()
1014986

1015-
# If the content is too long, summarize or truncate it
1016987
if len(content) > 1900:
1017988
content = content[:1900] + "\n\n... (truncated)"
1018989

0 commit comments

Comments
 (0)