# filter.py — 128 lines (94 loc) · 3.93 KB
#%%
import json

def _load_institutes(path):
    # Read one benchmark JSON file and return its "institutes" list.
    with open(path, 'r', encoding='utf-8') as fh:
        return json.load(fh)["institutes"]

# Attribute names, their textual descriptions, and few-shot examples.
extractions = _load_institutes('../benchmark/extractions.json')
descriptions = _load_institutes('../benchmark/descriptions.json')
examples = _load_institutes('./examples.json')
print(extractions[0], descriptions[0], examples[0])

# The benchmark query file holds several SQL statements separated by a
# dashed divider; the statement itself is the last paragraph of each block.
with open('../benchmark/Query/institutes/SFW.sql', 'r', encoding='utf-8') as f:
    content = f.read()
sql_blocks = content.split('--------------------------------------------------')
first_10_sql = [block.split('\n\n')[-1] for block in sql_blocks[:-1]]
# %%
#sql
import re
def anlysis_sfw_sql(sql, vocab=None):
    """Parse one SELECT-FROM-WHERE SQL statement against a known attribute list.

    Args:
        sql: a statement of the form ``SELECT cols FROM table WHERE cond;``.
        vocab: ordered list of known lower-case attribute names. Defaults to
            the module-level ``extractions`` list (backward compatible).

    Returns:
        (select_indices, where, attr_indices):
            select_indices -- positions in ``vocab`` of the SELECTed columns,
            where          -- raw WHERE-clause text (without the trailing ';'),
            attr_indices   -- positions in ``vocab`` of attributes compared
                              in the WHERE clause.

    Raises:
        AttributeError: when the statement lacks a ``SELECT ... FROM`` or
            ``WHERE ...;`` part (unchanged from the original behavior).
    """
    if vocab is None:
        vocab = extractions
    select_clause = re.search(r'SELECT (.+?)\s+FROM', sql, re.I).group(1).strip()
    # Columns are lowercased here once; the original lowered them a second
    # time during the lookup below.
    select_cols = [c.strip().lower() for c in select_clause.split(',')]
    select_indices = [vocab.index(col) for col in select_cols if col in vocab]
    where = re.search(r'WHERE (.*);', sql, re.I).group(1)
    # Identifiers appearing immediately left of a comparison operator.
    attr_names = re.findall(
        r'([A-Za-z_][A-Za-z0-9_]*)\s*(?:==|=|!=|<>|>=|<=|>|<)',
        where
    )
    # Lowercase each attribute once (the original called .lower() twice
    # per name: once for the membership test, once for .index()).
    attr_indices = [
        vocab.index(name)
        for name in (a.lower() for a in attr_names)
        if name in vocab
    ]
    return select_indices, where, attr_indices
#%%
import pandas as pd
import lotus
from lotus.types import CascadeArgs, ProxyModel
from lotus.models import LM
import os
import time
import numpy as np

# Credentials are injected here; left blank in the committed script.
os.environ["OPENAI_API_BASE"] = ""
os.environ["OPENAI_API_KEY"] = ""

csv_path = "../benchmark/ground_truth/institutes.csv"
base_dir = "../benchmark/datasets/institutes"

df = pd.read_csv(csv_path)
ids = df["ID"].dropna().astype(str).tolist()

def _read_context(identifier):
    # Return the stripped text for one institute, or '' when its file is missing.
    path = os.path.join(base_dir, f"{identifier}.txt")
    if not os.path.exists(path):
        return ''
    with open(path, "r", encoding="utf-8") as fh:
        return fh.read().strip()

# One context string per ground-truth ID, in the same order as `ids`.
data = {'context': [_read_context(id_value) for id_value in ids]}
print(len(data['context']))
# %%
# Main experiment loop: for each benchmark SQL query, semantically filter
# the institute contexts with the WHERE clause, extract each SELECTed
# attribute from the survivors, and write one results.csv per query.
for j in range(len(sql_blocks)-1):
    # Fresh LM handle per query so usage accounting starts from zero.
    lm = LM(model="gpt-4.1-mini")
    lotus.settings.configure(lm=lm)
    folder = f'./results/institutes/SFW/SQL{j}'
    os.makedirs(folder, exist_ok=True)
    sql = first_10_sql[j]
    select_indices, where, attr_indices = anlysis_sfw_sql(sql)
    print(select_indices, where, attr_indices)
    # NOTE(review): rebinds `df`, shadowing the ground-truth CSV DataFrame
    # loaded earlier in the file.
    df = pd.DataFrame(data)
    # Filter prompt = document text + the description of every attribute
    # referenced in WHERE + the raw WHERE-clause text.
    user_instruction = "{context}."
    for i in attr_indices:
        user_instruction += descriptions[i]
    user_instruction += where
    filtered_df = df.sem_filter(
        user_instruction, strategy="Cot", return_all=True, return_explanations=False
    )
    print(filtered_df)
    print('#########')
    # Assumes sem_filter(return_all=True) adds a boolean "filter_label"
    # column -- TODO confirm against the installed lotus version.
    filtered_indices = filtered_df[filtered_df["filter_label"] == True].index
    contents = [data['context'][i] for i in filtered_indices]
    filter_data = {'context':contents}
    df_filter = pd.DataFrame(filter_data)
    df_data = {"ID":[ids[i] for i in filtered_indices]}
    # Extract each SELECTed attribute from the surviving contexts via
    # few-shot semantic mapping.
    for i in select_indices:
        att = extractions[i]
        description = descriptions[i]
        example = examples[i]
        # print(example)
        examples_df = pd.DataFrame(example)
        user_instruction = "What" + att + "in {context}?" + description + "If there are multiple values, separate them with '||' and leave empty if not applicable. Please keep each extracted value concise and avoid lengthy content."
        df_test = df_filter.sem_map(user_instruction, examples=examples_df)
        # df_test = df_filter.sem_map(user_instruction)
        df_data[att] = df_test['_map'].tolist()
    # Rough LLM-call count: one filter call per document plus one map call
    # per (selected column, surviving document) pair.
    call = len(ids)+len(select_indices)*len(filtered_indices)
    print(f'LLM-call:{call}')
    print('-------')
    # lm.print_total_usage()
    df_final = pd.DataFrame(df_data)
    # Cells the model marked "empty" become NaN. NOTE(review): whole-frame
    # DataFrame.map requires pandas >= 2.1 (formerly applymap), and the
    # substring test also nulls any value merely containing "empty".
    df_final = df_final.map(lambda x: np.nan if isinstance(x, str) and "empty" in x else x)
    # print(df_final.iloc[0])
    df_final.to_csv(folder+'/results.csv', index=False, encoding='utf-8-sig')
# %%