-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathlong_term_memory.py
More file actions
103 lines (76 loc) · 2.86 KB
/
Copy pathlong_term_memory.py
File metadata and controls
103 lines (76 loc) · 2.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import chromadb
from chromadb.utils import embedding_functions
# Embedding function used to vectorise every document stored in long-term
# memory (Chroma's built-in default embedding function).
embedding_function = embedding_functions.DefaultEmbeddingFunction()

# On-disk Chroma client; the path is relative to the current working directory.
client = chromadb.PersistentClient(path="_memories/long_term_memory")

# The single collection backing long-term memory: created on first use and
# reopened afterwards. "hnsw:space": "cosine" selects cosine distance for the
# collection's HNSW index.
long_term_memory = client.get_or_create_collection(
    name="long_term_memory",
    embedding_function=embedding_function,
    metadata={"hnsw:space": "cosine"},
)
def get_ltm():
    """Return every record stored in long-term memory.

    No ``ids`` or ``where`` filter is passed, so the whole collection is
    read. ``include`` requests both the documents and their metadata;
    embeddings are not requested.

    Returns:
        The result dict produced by the collection's ``get`` call.
    """
    return long_term_memory.get(include=["documents", "metadatas"])
def add_content_to_ltm(content_id, content, author, iteration, virality_score, sentiment_score):
    """Store one piece of content and its scoring metadata in long-term memory.

    The record is keyed by ``content_id``, holds ``content`` as its document,
    and carries the author, iteration and both scores as metadata.

    Any exception raised by the collection's ``add`` call is not propagated:
    it is printed to stdout and the function returns normally.
    """
    record_metadata = {
        "Author": author,
        "Iteration": iteration,
        "Virality Score": virality_score,
        "Sentiment Score": sentiment_score,
    }
    try:
        long_term_memory.add(
            ids=[content_id],
            metadatas=[record_metadata],
            documents=[content],
        )
    except Exception as err:
        # Deliberate best-effort: report the failure and carry on.
        print("Add data to db failed: ", err)
def is_content_in_ltm(content_id):
    """Return True if a record stored under ``content_id`` exists in long-term memory.

    A ``None`` id is reported as "not present". Passing ``ids=None`` to the
    collection's ``get`` applies no id filter at all. An unguarded lookup
    would therefore return True whenever the collection held any record.
    """
    if content_id is None:
        return False
    res = long_term_memory.get(
        ids=content_id,
        include=["documents"],
    )
    return bool(res["documents"])
def _increment_ltm_metadata(content_id, key, delta):
    """Add ``delta`` to the metadata field ``key`` of the record ``content_id``.

    Raises:
        IndexError: if no record with ``content_id`` exists. This is the same
            exception type an unguarded ``res["metadatas"][0]`` would raise,
            but the message names the missing id.
        KeyError: if the record has no ``key`` metadata field.
    """
    res = long_term_memory.get(ids=content_id, include=["metadatas"])
    metadatas = res["metadatas"]
    if not metadatas:
        raise IndexError(f"No long-term memory record with id {content_id!r}")
    current = metadatas[0][key]
    # NOTE(review): only the changed key is sent. This relies on Chroma merging
    # update metadata into the record's existing metadata (keeping "Author",
    # "Iteration", ...); confirm for the pinned chromadb version.
    long_term_memory.update(
        ids=content_id,
        metadatas=[{key: current + delta}],
    )


def modify_ltm_virality_score(content_id, new_virality_score):
    """Add ``new_virality_score`` to the record's stored "Virality Score".

    Despite its name, ``new_virality_score`` is an increment added to the
    current value, not a replacement for it.

    Raises:
        IndexError: if no record with ``content_id`` exists.
    """
    _increment_ltm_metadata(content_id, "Virality Score", new_virality_score)


def modify_ltm_sentiment_score(content_id, new_sentiment_score):
    """Add ``new_sentiment_score`` to the record's stored "Sentiment Score".

    Despite its name, ``new_sentiment_score`` is an increment added to the
    current value, not a replacement for it.

    Raises:
        IndexError: if no record with ``content_id`` exists.
    """
    _increment_ltm_metadata(content_id, "Sentiment Score", new_sentiment_score)
def get_source_agent_from_ltm(content_id):
    """Return the "Author" metadata of the record stored under ``content_id``.

    Only metadata is fetched; the document text is not needed here.

    Raises:
        IndexError: if no record with ``content_id`` exists.
        KeyError: if the record has no "Author" metadata.
    """
    res = long_term_memory.get(ids=content_id, include=["metadatas"])
    metadatas = res["metadatas"]
    if not metadatas:
        raise IndexError(f"No long-term memory record with id {content_id!r}")
    return metadatas[0]["Author"]
def get_feedbacks_from_ltm(agent):
    """Format every long-term-memory record authored by ``agent`` as feedback text.

    Records are matched on their "Author" metadata against the lower-cased
    ``agent.name``. Each record becomes one line of the form
    ``"<document> - Virality Score: <v> - Content Score: <s>"``, where
    ``<s>`` is the record's "Sentiment Score". Lines are joined with
    newlines, so an agent with no records yields an empty string.
    """
    res = long_term_memory.get(
        where={"Author": str(agent.name.lower())},
        include=["documents", "metadatas"],
    )
    # NOTE(review): the label says "Content Score" but the value is the
    # "Sentiment Score" field; confirm the label is intended.
    return "\n".join(
        f"{document} - Virality Score: {metadata['Virality Score']}"
        f" - Content Score: {metadata['Sentiment Score']}"
        for document, metadata in zip(res["documents"], res["metadatas"])
    )
def clear_ltm():
    """Delete every long-term-memory record whose "Iteration" metadata is >= 0.

    NOTE(review): records whose "Iteration" is negative do not match this
    filter, and presumably neither do non-numeric or missing values. Such
    records are left in place; confirm that is intended for a "clear".
    """
    non_negative_iteration = {"Iteration": {"$gte": 0}}
    long_term_memory.delete(where=non_negative_iteration)