Skip to content

Commit fe32d53

Browse files
authored
Add files via upload
1 parent f71a756 commit fe32d53

9 files changed

Lines changed: 68195 additions & 0 deletions

eval/dynamic_update.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
from utils import gpt_summarize, generate_id, get_timestamp, gpt_update_profile, gpt_generate_multi_summary
2+
3+
class DynamicUpdate:
    """Coordinate updates across short-, mid- and long-term memory stores.

    The collaborators are duck-typed; this class only relies on the calls
    it makes:

    * ``short_term_memory``: ``add_qa_pair``, ``is_full``, ``pop_oldest``
    * ``mid_term_memory``: ``get_page_by_id``, ``update_page_connections``,
      ``insert_pages_into_session``
    * ``long_term_memory``: ``update_user_profile``, ``add_knowledge``
    * ``client``: ``chat_completion(model=, messages=, temperature=,
      max_tokens=)`` whose return value is treated as text
    """

    # Prompt templates are assembled line by line so the text sent to the
    # model carries no source-code indentation.
    _CONTINUITY_PROMPT = (
        "Determine if these two conversation pages are continuous (true continuation without topic shift).\n"
        "Return ONLY \"true\" or \"false\".\n"
        "\n"
        "Previous Page:\n"
        "User: {prev_user}\n"
        "Assistant: {prev_agent}\n"
        "\n"
        "Current Page:\n"
        "User: {curr_user}\n"
        "Assistant: {curr_agent}\n"
        "\n"
        "Continuous?"
    )

    _META_PROMPT = (
        "Update the conversation meta-summary by incorporating the new dialogue while maintaining continuity.\n"
        "\n"
        "Guidelines:\n"
        "1. Start from the previous meta-summary (if exists)\n"
        "2. Add/update information based on the new dialogue\n"
        "3. Keep it concise (1-2 sentences max)\n"
        "4. Maintain context coherence\n"
        "\n"
        "Previous Meta-summary: {last_meta}\n"
        "New Dialogue:\n"
        "{new_dialogue}\n"
        "\n"
        "Updated Meta-summary:"
    )

    _META_SYSTEM_PROMPT = (
        "You are a conversation meta-summary updater. Your task is to:\n"
        "1. Preserve relevant context from previous meta-summary\n"
        "2. Integrate new information from current dialogue\n"
        "3. Output ONLY the updated summary (no explanations)"
    )

    def __init__(self, short_term_memory, mid_term_memory, long_term_memory, topic_similarity_threshold=0.8, client=None):
        """Store the memory back-ends, the similarity threshold and the LLM client.

        :param short_term_memory: store that buffers the most recent QA pairs
        :param mid_term_memory: store that holds evicted pages
        :param long_term_memory: store for user profile and knowledge
        :param topic_similarity_threshold: forwarded unchanged to
            ``mid_term_memory.insert_pages_into_session``
        :param client: chat client used by the private prompt helpers
        """
        self.short_term_memory = short_term_memory
        self.mid_term_memory = mid_term_memory
        self.long_term_memory = long_term_memory
        self.topic_similarity_threshold = topic_similarity_threshold
        self.client = client
        # Most recently evicted page; used as the "previous page" when the
        # next evicted page is checked for continuity.
        self.last_evicted_page = None

    def _is_conversation_continuing(self, previous_page, current_page):
        """Ask the model whether ``current_page`` continues ``previous_page``.

        :param previous_page: page dict of the previous turn, or a falsy value
        :param current_page: page dict of the current turn
        :return: ``True`` only when the model's reply is "true"; ``False``
            when there is no previous page or for any other reply
        """
        if not previous_page:
            return False

        prompt = self._CONTINUITY_PROMPT.format(
            prev_user=previous_page.get("user_input", ""),
            prev_agent=previous_page.get("agent_response", ""),
            curr_user=current_page.get("user_input", ""),
            curr_agent=current_page.get("agent_response", ""),
        )

        messages = [
            {"role": "system", "content": "You are a conversation continuity detector. Return ONLY 'true' or 'false'."},
            {"role": "user", "content": prompt},
        ]

        response = self.client.chat_completion(
            model="gpt-4o-mini",
            messages=messages,
            temperature=0.0,
            max_tokens=10,
        )

        # Tolerate an empty reply and decorations such as quotes or a
        # trailing period ('"True."') instead of misreading them as "false".
        verdict = (response or "").strip(" \t\r\n\"'`.!").lower()
        return verdict == "true"

    def _generate_meta_info(self, last_page_meta, current_page):
        """Generate a new meta-info summary from the previous one and the current page.

        :param last_page_meta: meta-info of the previous page (falsy if none)
        :param current_page: page dict holding the current dialogue turn
        :return: the updated meta-info text, stripped of surrounding
            whitespace; an empty string if the client returns nothing
        """
        current_conversation = f"User: {current_page.get('user_input', '')}\nAssistant: {current_page.get('agent_response', '')}"

        prompt = self._META_PROMPT.format(
            last_meta=last_page_meta if last_page_meta else "None",
            new_dialogue=current_conversation,
        )

        messages = [
            {"role": "system", "content": self._META_SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]

        response = self.client.chat_completion(
            model="gpt-4o-mini",
            messages=messages,
            temperature=0.3,
            max_tokens=100,
        )
        return (response or "").strip()

    def _walk_links(self, start_page, link_key, visited):
        """Follow ``link_key`` ids from ``start_page`` and collect the pages found.

        Stops at a missing id, an id the mid-term store cannot resolve, or an
        id already in ``visited`` (which guards against circular links).
        ``visited`` is extended in place.
        """
        chain = []
        next_id = start_page.get(link_key)
        while next_id and next_id not in visited:
            page = self.mid_term_memory.get_page_by_id(next_id)
            if not page:
                break
            visited.add(next_id)
            chain.append(page)
            next_id = page.get(link_key)
        return chain

    def _update_connected_pages(self, page_id, new_meta_info):
        """Propagate ``new_meta_info`` along the chain of pages linked to ``page_id``.

        The page identified by ``page_id`` and every page reachable through
        its ``pre_page`` / ``next_page`` links receive the new meta-info.
        Nothing happens when the mid-term store does not know ``page_id``.

        :param page_id: id of the page whose chain should be updated
        :param new_meta_info: meta-info text to store on each page
        """
        anchor = self.mid_term_memory.get_page_by_id(page_id)
        if not anchor:
            return

        visited = {page_id}
        # Collected nearest-first, then reversed so the oldest page comes first.
        predecessors = self._walk_links(anchor, "pre_page", visited)
        predecessors.reverse()
        successors = self._walk_links(anchor, "next_page", visited)

        # The anchor belongs to the chain too; keep its summary in sync.
        anchor["meta_info"] = new_meta_info

        for page in predecessors + successors:
            page["meta_info"] = new_meta_info
            # NOTE(review): update_page_connections is defined elsewhere and is
            # handed this page's own neighbour ids - confirm it persists the
            # change rather than re-linking the two neighbours to each other.
            self.mid_term_memory.update_page_connections(page.get("pre_page"), page.get("next_page"))

    def update_short_term(self, message):
        """Append one QA pair to short-term memory."""
        self.short_term_memory.add_qa_pair(message)

    def bulk_evict_and_update_mid_term(self):
        """Drain a full short-term memory and insert the result into mid-term memory.

        Each evicted QA pair that has both a user input and an agent response
        becomes a page. Consecutive pages judged continuous by the model are
        linked through ``pre_page`` / ``next_page`` and share an updated
        meta-info. The pages are then inserted once per sub-topic summary
        returned by ``gpt_generate_multi_summary``.
        """
        # 1. Remove content from short-term memory.
        evicted = []
        while self.short_term_memory.is_full():
            msg = self.short_term_memory.pop_oldest()
            if msg and msg.get("user_input") and msg.get("agent_response"):
                evicted.append(msg)

        if not evicted:
            return

        # 2. Build the basic page structure and handle continuity.
        pages = []
        for qa in evicted:
            page = {
                "page_id": generate_id("page"),
                "user_input": qa.get("user_input", ""),
                "agent_response": qa.get("agent_response", ""),
                "timestamp": qa.get("timestamp"),
                "preloaded": False,
                "analyzed": False,
                "pre_page": None,
                "next_page": None,
                "meta_info": None,
            }

            # Continuity check; this is already False when there is no
            # previously evicted page.
            previous = self.last_evicted_page
            if self._is_conversation_continuing(previous, page):
                page["pre_page"] = previous["page_id"]
                previous["next_page"] = page["page_id"]

                # Update the meta-info along the chain.
                new_meta_info = self._generate_meta_info(previous.get("meta_info"), page)
                page["meta_info"] = new_meta_info
                self._update_connected_pages(page["pre_page"], new_meta_info)
            else:
                page["meta_info"] = self._generate_meta_info(None, page)

            pages.append(page)
            self.last_evicted_page = page

        # 3. Concatenate all user inputs for topic analysis.
        input_text = "\n".join([f"User: {page.get('user_input','')}\n" for page in pages])
        print("动态更新:调用 GPT 生成多子主题摘要...")
        multi_summary = gpt_generate_multi_summary(input_text, self.client)

        # A missing or malformed summary result must not raise after the pages
        # have already been removed from short-term memory.
        # NOTE(review): when no summaries come back the evicted pages are not
        # inserted anywhere - confirm whether a fallback insert is wanted.
        summaries = multi_summary.get("summaries", []) if isinstance(multi_summary, dict) else []

        # 4. Insert into mid-term memory, once per sub-topic.
        # NOTE(review): every sub-topic receives the complete ``pages`` list -
        # confirm insert_pages_into_session tolerates the repeated pages.
        for summary_dict in summaries or []:
            sub_summary = summary_dict.get("content", "")
            sub_key_words = summary_dict.get("keywords", [])

            print(f"动态更新:处理子主题【{summary_dict.get('theme','')}】,插入中期记忆...")
            self.mid_term_memory.insert_pages_into_session(
                sub_summary,
                sub_key_words,
                pages,  # the fully processed pages
                self.topic_similarity_threshold,
            )

    def update_long_term(self, user_id, new_profile_data, knowledge_text):
        """Update the user's profile and add ``knowledge_text`` to long-term memory."""
        print("动态更新:更新长期记忆中的用户画像和私有数据...")
        self.long_term_memory.update_user_profile(user_id, new_profile_data)
        self.long_term_memory.add_knowledge(knowledge_text)

eval/evalution_loco.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import json
2+
import re
3+
from typing import List, Dict
4+
from collections import defaultdict
5+
import statistics
6+
7+
def simple_tokenize(text: str) -> List[str]:
    """Lower-case ``text`` and split it into word tokens.

    Non-string values are converted with ``str()`` first, so a numeric
    answer such as ``0`` yields ``["0"]`` instead of being dropped.

    :param text: value to tokenize; ``None`` is treated as empty
    :return: list of ``\\w+`` runs in order of appearance (punctuation removed)
    """
    # Only None means "no answer"; other falsy values (0, False) are real
    # content once stringified.
    if text is None:
        return []

    # Word-boundary match strips punctuation and splits on whitespace.
    return re.findall(r'\b\w+\b', str(text).lower())
17+
18+
def calculate_f1(prediction: str, reference: str) -> float:
    """Return the token-level F1 score of ``prediction`` against ``reference``.

    Both texts are tokenized with ``simple_tokenize`` and compared as sets,
    so repeated tokens count once.

    :param prediction: system answer
    :param reference: ground-truth answer
    :return: F1 in ``[0.0, 1.0]``; ``0.0`` when the texts share no token
        (including when either one has no tokens)
    """
    pred_tokens = set(simple_tokenize(prediction))
    ref_tokens = set(simple_tokenize(reference))

    overlap = len(pred_tokens & ref_tokens)
    # No shared token means precision and recall are both zero; return a
    # float so the result type matches the annotation on every path.
    if overlap == 0:
        return 0.0

    # A non-empty overlap guarantees both token sets are non-empty.
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * (precision * recall) / (precision + recall)
37+
38+
def load_data(file_path: str) -> List[Dict]:
    """Read the UTF-8 JSON file at ``file_path`` and return its parsed content."""
    with open(file_path, 'r', encoding='utf-8') as handle:
        return json.load(handle)
43+
44+
def main(file_path: str):
    """Print the mean F1 score of each category found in the results file.

    Every sample loaded from ``file_path`` must provide the keys
    ``category``, ``system_answer`` and ``original_answer``.
    """
    scores_by_category = defaultdict(list)

    # Score each sample and bucket the result under its category.
    for record in load_data(file_path):
        category = record['category']
        predicted = record['system_answer']
        expected = record['original_answer']
        scores_by_category[category].append(calculate_f1(predicted, expected))

    # Report one averaged line per category, in first-seen order.
    for category in scores_by_category:
        avg_f1 = statistics.mean(scores_by_category[category])
        print(f"Category {category}: Average F1 Score = {avg_f1:.4f}")
68+
69+
if __name__ == "__main__":
    # Results file produced by main_loco_parse.py.
    main("all_loco_results.json")

0 commit comments

Comments
 (0)