Skip to content

Commit 47f1483

Browse files
blazickjpDFScience+1-at-298147827574
andauthored
Refactoring System Prompt Management and Test Updates (#18)
* Initial refactor * fixing bugs after changes * System prompt fix * more refactor and bug fixing * prompt management * readme * refactor * System prompt fix * Fix tests * readme update --------- Co-authored-by: DFScience+1-at-298147827574 <jblazick@amazon.com>
1 parent 8e6db6a commit 47f1483

14 files changed

Lines changed: 513 additions & 501 deletions

File tree

backend/agent/agent.py

Lines changed: 32 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
GPT_MODEL = "gpt-4" # or any other chat model you want to use
1919
# GPT_MODEL = "anthropic" # or any other chat model you want to use
2020
MAX_TOKENS = 2000 # or any other number of tokens you want to use
21-
TEMPERATURE = 0.2 # or any other temperature you want to use
21+
TEMPERATURE = 0.75 # or any other temperature you want to use
2222

2323

2424
class FunctionCall(BaseModel):
@@ -75,7 +75,6 @@ def __init__(
7575
self.function_map = {
7676
func.__name__: func for func in callables if func is not None
7777
}
78-
self.files_in_prompt: List[str] = []
7978

8079
def query(self, input: str, command: Optional[str] = None) -> List[str]:
8180
"""
@@ -93,7 +92,8 @@ def query(self, input: str, command: Optional[str] = None) -> List[str]:
9392
message_history = [
9493
Message(**i).to_dict() for i in self.memory_manager.get_messages()
9594
]
96-
function_to_call = FunctionCall()
95+
96+
self.function_to_call = FunctionCall()
9797

9898
keyword_args = {
9999
"model": self.GPT_MODEL,
@@ -121,70 +121,41 @@ def query(self, input: str, command: Optional[str] = None) -> List[str]:
121121
# )
122122
# self.set_files_in_prompt(include_line_numbers=True)
123123
keyword_args["model"] = "gpt-4"
124-
124+
print(f"Calling model: {self.GPT_MODEL}")
125125
for i, chunk in enumerate(self.call_model_streaming(**keyword_args)):
126126
delta = chunk["choices"][0].get("delta", {})
127127
if "function_call" in delta:
128-
if "name" in delta.function_call:
129-
function_to_call.name = delta.function_call["name"]
130-
if "arguments" in delta.function_call:
131-
if function_to_call.name == "Changes" and i == 0:
132-
yield "\n```json\n" + delta.function_call["arguments"]
133-
else:
134-
function_to_call.arguments += delta.function_call["arguments"]
135-
yield delta.function_call["arguments"]
136-
if chunk["choices"][0]["finish_reason"] == "stop" and function_to_call.name:
137-
if function_to_call.name == "Changes":
138-
yield "```\n\n"
139-
print(
140-
f"\n\nFunc Call: {function_to_call.name}\n\n{function_to_call.arguments}"
141-
)
142-
args = self.process_json(function_to_call.arguments)
143-
function_response = self.function_map[function_to_call.name](**args)
144-
print(f"Func Response: {json.dumps(function_response.to_dict())}")
145-
if function_to_call.name == "Changes":
146-
diff = function_response.execute()
147-
# Show the diff back to the user
148-
yield diff
128+
yield from self.process_function_call(delta, i)
129+
if self.should_stop_and_has_function(chunk):
130+
yield from self.execute_function()
149131
else:
150132
yield delta.get("content")
151133

152-
def set_files_in_prompt(self, include_line_numbers: Optional[bool] = None) -> None:
153-
"""
154-
Sets the files in the prompt.
155-
156-
Args:
157-
files (List[File]): A list of files to be set in the prompt.
158-
include_line_numbers (Optional[bool]): Whether to include line numbers in the prompt.
159-
"""
160-
file_contents = self.codebase.get_file_contents()
161-
content = ""
162-
for k, v in file_contents.items():
163-
print(k in self.files_in_prompt)
164-
if k in self.files_in_prompt and include_line_numbers:
165-
v = self._add_line_numbers_to_content(v)
166-
content += f"{k}:\n{v}\n\n"
167-
elif k in self.files_in_prompt:
168-
content += f"{k}:\n{v}\n\n"
169-
170-
self.memory_manager.system_file_contents = content
171-
self.memory_manager.set_system()
172-
return
173-
174-
def _add_line_numbers_to_content(self, content: str) -> str:
175-
"""
176-
Adds line numbers to the given content.
177-
178-
Args:
179-
content (str): The content to add line numbers to.
180-
181-
Returns:
182-
str: The content with line numbers added.
183-
"""
184-
lines = content.split("\n")
185-
for i in range(len(lines)):
186-
lines[i] = f"{i+1} {lines[i]}"
187-
return "\n".join(lines)
134+
def process_function_call(self, delta, i):
135+
function_call = delta["function_call"]
136+
if "name" in function_call:
137+
self.function_to_call.name = function_call["name"]
138+
if "arguments" in function_call:
139+
if self.function_to_call.name == "Changes" and i == 0:
140+
yield "\n```json\n" + function_call["arguments"]
141+
else:
142+
self.function_to_call.arguments += function_call["arguments"]
143+
yield function_call["arguments"]
144+
145+
def should_stop_and_has_function(self, delta):
146+
return (
147+
delta["choices"][0]["finish_reason"] == "stop"
148+
and self.function_to_call.name # noqa 503
149+
)
150+
151+
def execute_function(self):
152+
if self.function_to_call.name == "Changes":
153+
yield "```\n\n"
154+
args = self.process_json(self.function_to_call.arguments)
155+
function_response = self.function_map[self.function_to_call.name](**args)
156+
if self.function_to_call.name == "Changes":
157+
diff = function_response.execute(self.codebase.directory)
158+
yield diff
188159

189160
def process_json(self, args: str) -> str:
190161
"""
@@ -217,25 +188,6 @@ def process_json(self, args: str) -> str:
217188

218189
return json.loads(response_str)
219190

220-
def generate_llama_prompt(self) -> str:
221-
"""
222-
Generates a prompt for the Code Llama model.
223-
224-
Args:
225-
input (str): The input text to be processed by the GPT-3 model.
226-
227-
Returns:
228-
str: The generated prompt.
229-
"""
230-
prompt = f"### System Prompt\n{self.memory_manager.system}\n\n"
231-
for message in self.memory_manager.get_messages():
232-
if message["role"].lower() == "user":
233-
prompt += f"### User Message\n{message['content']}\n\n"
234-
if message["role"].lower() == "assistant":
235-
prompt += f"### Assistant\n{message['content']}\n\n"
236-
237-
return prompt + "### Assistant"
238-
239191
def generate_anthropic_prompt(self) -> str:
240192
"""
241193
Generates a prompt for the Gaive model.
@@ -265,56 +217,7 @@ def call_model_streaming(self, **kwargs):
265217
if self.GPT_MODEL == "gpt-4" or self.GPT_MODEL == "gpt-3.5-turbo":
266218
for chunk in openai.ChatCompletion.create(**kwargs):
267219
yield chunk
268-
if self.GPT_MODEL == "code-llama":
269-
try:
270-
sm_client = boto3.client("sagemaker-runtime")
271-
endpoint = os.getenv("CODELLAMA_ENDPOINT")
272-
if not endpoint:
273-
raise ValueError("CODELLAMA_ENDPOINT environment variable not set")
274-
resp = sm_client.invoke_endpoint_with_response_stream(
275-
EndpointName=endpoint,
276-
Body=json.dumps(
277-
{
278-
"inputs": self.generate_llama_prompt(),
279-
"parameters": {
280-
"max_new_tokens": kwargs["max_tokens"],
281-
},
282-
}
283-
),
284-
ContentType="application/json",
285-
)
286-
except Exception as e:
287-
print(f"Error calling Code Llama: {e}")
288-
yield {
289-
"choices": [
290-
{
291-
"finish_reason": "stop",
292-
"delta": {"content": "Error: " + str(e)},
293-
}
294-
]
295-
}
296-
297-
while True:
298-
try:
299-
chunk = next(iter((resp["Body"])))
300-
bytes_to_send = chunk["PayloadPart"]["Bytes"]
301-
decoded_str = bytes_to_send.decode("utf-8")
302-
cleaned_str = decoded_str.replace(
303-
'{"generated_text": "', ""
304-
).replace('"}', "")
305-
cleaned_str = cleaned_str.encode().decode("unicode_escape")
306-
307-
yield {
308-
"choices": [
309-
{"finish_reason": "stop", "delta": {"content": cleaned_str}}
310-
]
311-
}
312-
except StopIteration:
313-
break
314220

315-
except UnboundLocalError:
316-
print("UnboundLocalError")
317-
break
318221
if self.GPT_MODEL == "anthropic":
319222
print("Calling anthropic")
320223
try:

backend/agent/agent_functions/changes.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
import os
22
import difflib as dl
33

4-
from dotenv import load_dotenv
54
from typing import List, Optional, Tuple
65
from pydantic import Field
76
from openai_function_call import OpenAISchema
87

9-
load_dotenv()
10-
DIRECTORY = os.getenv("PROJECT_DIRECTORY")
11-
128

139
class Change(OpenAISchema):
1410
"""
@@ -76,7 +72,7 @@ def apply_changes(self, changes: List[Change], content: str) -> str:
7672
"""
7773
for change in changes:
7874
new_content = self.replace_part_with_missing_leading_whitespace(
79-
whole_lines=content.split("\n"),
75+
whole_lines=content.splitlines(),
8076
part_lines=change.original.splitlines(),
8177
replace_lines=change.updated.splitlines(),
8278
)
@@ -99,25 +95,21 @@ def count_spaces(self, line: str) -> int:
9995
"""
10096
Counts the leading spaces in a line of code.
10197
"""
102-
spaces = 0
103-
for char in line:
104-
if char == " ":
105-
spaces += 1
106-
else:
107-
break
98+
spaces = len(line) - len(line.lstrip())
10899
return spaces
109100

110-
def execute(self) -> str:
101+
def execute(self, directory) -> str:
111102
"""
112103
Executes the changes on the file and returns a diff.
113104
"""
105+
DIRECTORY = directory
114106
relative_path = self.file_name.lstrip("/")
115107
file_path = os.path.join(DIRECTORY, relative_path)
116108
print(f"Directory: {DIRECTORY}")
117109
print(f"File Path: {file_path}")
118110
try:
119111
with open(file_path, "r") as f:
120-
current_contents = f.read()
112+
current_contents = f.readlines()
121113
# current_contents_with_line_numbers = "\n".join(
122114
# [
123115
# f"{i+1} {line}"
@@ -199,7 +191,7 @@ def replace_part_with_missing_leading_whitespace(
199191
del whole_lines[start:end]
200192
# add new lines and adjust indentation
201193
for i, line in enumerate(replace_lines):
202-
whole_lines.insert(start + i, spaces * " " + line)
194+
whole_lines.insert(start + i, " " * spaces + line)
203195

204196
return "\n".join(whole_lines)
205197

backend/app_setup.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,46 @@
55
from agent.agent import CodingAgent
66
from memory.memory_manager import MemoryManager
77
from database.my_codebase import MyCodebase
8-
from dotenv import load_dotenv
98
from fastapi import FastAPI
109
from fastapi.middleware.cors import CORSMiddleware
1110
from typing import Any, Callable
1211
from pydantic import BaseModel
1312
from agent.agent_functions.changes import Changes
1413
from agent.agent_functions.shell_commands import CommandPlan
1514

16-
IDENTITY = "You are an AI Pair Programmer and a world class python developer. Your role is to assist the Human in developing, debugging, and optimizing their project. Feel free to ask for more details if something isn't clear."
17-
IGNORE_DIRS=['node_modules','.next','.venv','__pycache__','.git']
18-
FILE_EXTENSIONS=['.js','.py','.md']
15+
IDENTITY = """
16+
# MISSION
17+
Act as Professor Synapse🧙🏾‍♂️, a conductor of expert agents. Your job is to support me in accomplishing my goals by finding alignment with me, then calling upon an expert agent perfectly suited to the task by initializing:
18+
19+
**Synapse_CoR** = "[emoji]: I am an expert in [role&domain]. I know [context]. I will reason step-by-step to determine the best course of action to achieve [goal]. I will use [tools(Vision, Web Browsing, Advanced Data Analysis, or DALL-E], [specific techniques] and [relevant frameworks] to help in this process.
20+
21+
Let's accomplish your goal by following these steps:
22+
23+
[3 reasoned steps]
24+
25+
My task ends when [completion].
26+
27+
[first step, question]"
28+
29+
# INSTRUCTIONS
30+
1. 🧙🏾‍♂️ Step back and gather context, relevant information and clarify my goals by asking questions
31+
2. Once confirmed, init Synapse_CoR
32+
3. After init, each output will ALWAYS follow the below format:
33+
-🧙🏾‍♂️: [align on my goal] and end with, "This is very important to me".
34+
-[emoji]: provide an [actionable response or deliverable] and end with an [open ended question], and omit [reasoned steps] and [completion]
35+
4.  Together 🧙🏾‍♂️ and [emoji] support me until goal is complete
36+
37+
# COMMANDS
38+
/start=🧙🏾‍♂️,introduce and begin with step one
39+
/save=🧙🏾‍♂️, #restate goal, #summarize progress, #reason next step
40+
41+
# RULES
42+
-use emojis liberally to express yourself
43+
-Start every output with 🧙🏾‍♂️: or [emoji]: to indicate who is speaking.
44+
-Keep responses actionable and practical for the user
45+
"""
46+
IGNORE_DIRS = ["node_modules", ".next", ".venv", "__pycache__", ".git"]
47+
FILE_EXTENSIONS = [".js", ".py", ".md"]
1948

2049
def create_database_connection() -> connection:
2150
try:
@@ -55,7 +84,13 @@ def setup_memory_manager(**kwargs) -> MemoryManager:
5584

5685

5786
def setup_codebase() -> MyCodebase:
58-
my_codebase = MyCodebase(directory=DIRECTORY, db_connection=DB_CONNECTION, file_extensions=FILE_EXTENSIONS, ignore_dirs=IGNORE_DIRS)
87+
my_codebase = MyCodebase(
88+
directory=DIRECTORY,
89+
db_connection=DB_CONNECTION,
90+
file_extensions=FILE_EXTENSIONS,
91+
ignore_dirs=IGNORE_DIRS,
92+
)
93+
5994
my_codebase.ignore_dirs = IGNORE_DIRS
6095
return my_codebase
6196

backend/database/my_codebase.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import os
22
import datetime
3-
from dotenv import load_dotenv
43
import tiktoken
5-
from typing import Dict
64

75

86
ENCODER = tiktoken.encoding_for_model("gpt-3.5-turbo")
@@ -12,7 +10,13 @@
1210
class MyCodebase:
1311
UPDATE_FULL = False
1412

15-
def __init__(self, directory: str = ".", db_connection=None, ignore_dirs=None, file_extensions=None):
13+
def __init__(
14+
self,
15+
directory: str = ".",
16+
db_connection=None,
17+
ignore_dirs=None,
18+
file_extensions=None,
19+
):
1620
self.directory = directory
1721
self.conn = db_connection
1822
self.cur = self.conn.cursor()
@@ -113,15 +117,6 @@ def create_tables(self) -> None:
113117
except Exception as e:
114118
print(f"Failed to create tables: {e}")
115119

116-
def get_file_contents(self) -> Dict[str, str]:
117-
self.cur.execute("SELECT file_path, text FROM files")
118-
results = self.cur.fetchall()
119-
out = {}
120-
for file_name, text in results:
121-
out.update({os.path.relpath(file_name, self.directory): text})
122-
print(f"\n\nGet File Contents: {out.keys()}")
123-
return out
124-
125120
def tree(self) -> str:
126121
tree = {}
127122
start_from = os.path.basename(self.directory)
@@ -189,6 +184,7 @@ def _is_valid_file(self, file_name):
189184
return (
190185
not file_name.startswith(".")
191186
and not file_name.startswith("_") # noqa 503
187+
and not file_name.endswith(".jsonl") # noqa 503
192188
and any( # noqa 503
193189
file_name.endswith(ext) for ext in self.file_extensions # noqa 503
194190
)

0 commit comments

Comments
 (0)