Skip to content

Commit b5fbb9f

Browse files
authored
Merge pull request #73 from reactome/prompts-inputs-updates
Prompting improvements for retrieval
2 parents 3ba1f95 + 74e19db commit b5fbb9f

7 files changed

Lines changed: 69 additions & 35 deletions

File tree

.gitignore

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,13 @@ cython_debug/
159159
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
160160
#.idea/
161161

162+
.files/
163+
.ruff_cache/
164+
reactome_github/
165+
reactomegit/
166+
get-pip.py
167+
.DS_Store
168+
162169
.chainlit/translations/*
163170
!.chainlit/translations/en-US.json
164171
csv_files/

bin/chat-fastapi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ async def landing_page():
279279
<div class="button-container">
280280
<a class="button" href="$CHAINLIT_URL/chat/guest/" target="_blank">Guest Access</a>
281281
<a class="button" href="$CHAINLIT_URL/chat/personal/" target="_blank">Log In</a>
282-
<a class="button feedback-button" href="https://docs.google.com/forms/d/e/1FAIpQLSeWajgdJGV2gETj2bo-_jqU54Ryy6d7acJkvMo-KkflYUmfTg/viewform" target="_blank">Feedback</a>
282+
<a class="button feedback-button" href="https://forms.gle/Rvzb8EA73yZs7wd38" target="_blank">Feedback</a>
283283
</div>
284284
285285
<p class="left-justified">Choose <strong>Guest Access</strong> to try the chatbot out. <strong>Log In</strong> will give an increased query allowance and securely stores your chat history so you can save and continue conversations.</p>

src/conversational_chain/graph.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ class AdditionalContent(TypedDict):
3737

3838

3939
class ChatState(TypedDict):
40-
input: str # User input text
41-
query: str # LLM-generated query from user input
40+
user_input: str # User input text
41+
rephrased_input: str # LLM-generated query from user input
4242
chat_history: Annotated[list[BaseMessage], add_messages]
4343
context: list[Document]
4444
answer: str # primary LLM response that is streamed to the user
@@ -103,21 +103,22 @@ async def preprocess(
103103
self, state: ChatState, config: RunnableConfig
104104
) -> dict[str, str]:
105105
query: str = await self.rephrase_chain.ainvoke(state, config)
106-
return {"query": query}
106+
return {"rephrased_input": query}
107107

108108
async def call_model(
109109
self, state: ChatState, config: RunnableConfig
110110
) -> dict[str, Any]:
111111
result: dict[str, Any] = await self.rag_chain.ainvoke(
112112
{
113-
"input": state["query"],
113+
"input": state["rephrased_input"],
114+
"user_input": state["user_input"],
114115
"chat_history": state["chat_history"],
115116
},
116117
config,
117118
)
118119
return {
119120
"chat_history": [
120-
HumanMessage(state["input"]),
121+
HumanMessage(state["user_input"]),
121122
AIMessage(result["answer"]),
122123
],
123124
"context": result["context"],
@@ -130,7 +131,7 @@ async def postprocess(
130131
search_results: list[WebSearchResult] = []
131132
if config["configurable"]["enable_postprocess"]:
132133
result: dict[str, Any] = await self.search_workflow.ainvoke(
133-
{"question": state["query"], "generation": state["answer"]},
134+
{"question": state["rephrased_input"], "generation": state["answer"]},
134135
config=RunnableConfig(callbacks=config["callbacks"]),
135136
)
136137
search_results = result["search_results"]
@@ -149,7 +150,7 @@ async def ainvoke(
149150
if self.graph is None:
150151
self.graph = await self.initialize()
151152
result: dict[str, Any] = await self.graph.ainvoke(
152-
{"input": user_input},
153+
{"user_input": user_input},
153154
config=RunnableConfig(
154155
callbacks=callbacks,
155156
configurable={

src/evaluation/evaluator.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from langchain_community.retrievers import BM25Retriever
1212
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
1313
from ragas import evaluate
14-
from ragas.metrics import answer_relevancy, context_utilization, faithfulness
14+
from ragas.metrics import (answer_relevancy, context_recall,
15+
context_utilization, faithfulness)
1516

1617
from conversational_chain.chain import create_rag_chain
1718
from reactome.metadata_info import descriptions_info, field_info
@@ -34,6 +35,12 @@ def parse_arguments():
3435
default="gpt-4o-mini",
3536
help="Language model to use for evaluation",
3637
)
38+
parser.add_argument(
39+
"--rag_type",
40+
choices=["basic", "advanced"],
41+
required=True,
42+
help="Type of RAG system to use for evaluation",
43+
)
3744
return parser.parse_args()
3845

3946

@@ -50,8 +57,8 @@ def load_dataset(testset_path):
5057
raise ValueError(f"Error reading the Excel file: {e}")
5158

5259

53-
def initialize_rag_chain(embeddings_directory, model_name):
54-
"""Initialize the RAG chain system."""
60+
def initialize_rag_chain_with_memory(embeddings_directory, model_name, rag_type):
61+
"""Initialize the RAGChainWithMemory system."""
5562
llm = ChatOpenAI(temperature=0.0, verbose=True, model=model_name)
5663
retriever_list = []
5764

@@ -60,7 +67,7 @@ def initialize_rag_chain(embeddings_directory, model_name):
6067
)
6168
data = loader.load()
6269
bm25_retriever = BM25Retriever.from_documents(data)
63-
bm25_retriever.k = 15
70+
bm25_retriever.k = 7
6471

6572
# Set up vectorstore SelfQuery retriever
6673
embedding = OpenAIEmbeddings(model="text-embedding-3-large")
@@ -69,17 +76,22 @@ def initialize_rag_chain(embeddings_directory, model_name):
6976
embedding_function=embedding,
7077
)
7178

79+
vectordb_retriever = vectordb.as_retriever(search_kwargs={"k": 7})
80+
7281
selfq_retriever = SelfQueryRetriever.from_llm(
7382
llm=llm,
7483
vectorstore=vectordb,
7584
document_contents=descriptions_info["summations"],
7685
metadata_field_info=field_info["summations"],
77-
search_kwargs={"k": 15},
86+
search_kwargs={"k": 7},
7887
)
7988
rrf_retriever = EnsembleRetriever(
8089
retrievers=[bm25_retriever, selfq_retriever], weights=[0.2, 0.8]
8190
)
82-
retriever_list.append(rrf_retriever)
91+
if rag_type == "basic":
92+
retriever_list.append(vectordb_retriever)
93+
elif rag_type == "advanced":
94+
retriever_list.append(rrf_retriever)
8395

8496
reactome_retriever = MergerRetriever(retrievers=retriever_list)
8597

@@ -91,11 +103,15 @@ def initialize_rag_chain(embeddings_directory, model_name):
91103

92104

93105
def process_testset(
94-
testset_path, qa_system, embeddings_directory, response_dir, eval_dir, model_name
106+
testset_path,
107+
qa_system,
108+
embeddings_directory,
109+
response_dir,
110+
eval_dir,
111+
model_name,
112+
rag_type,
95113
):
96114
"""Process a single testset file."""
97-
args = parse_arguments()
98-
99115
testset = load_dataset(testset_path)
100116
questions = [item["question"] for item in testset]
101117
ground_truths = [item["ground_truth"] for item in testset]
@@ -108,6 +124,11 @@ def process_testset(
108124
answers.append(response["answer"])
109125
contexts.append([context.page_content for context in response["context"]])
110126

127+
rag_response_dir = os.path.join(response_dir, rag_type)
128+
rag_eval_dir = os.path.join(eval_dir, rag_type)
129+
os.makedirs(rag_response_dir, exist_ok=True)
130+
os.makedirs(rag_eval_dir, exist_ok=True)
131+
111132
# Save responses to an Excel file
112133
data = {
113134
"question": questions,
@@ -117,8 +138,8 @@ def process_testset(
117138
}
118139
df_ans = pd.DataFrame(data)
119140
response_filename = os.path.join(
120-
response_dir,
121-
f"{os.path.splitext(os.path.basename(testset_path))[0]}_{args.model}_responses.xlsx",
141+
rag_response_dir,
142+
f"{os.path.splitext(os.path.basename(testset_path))[0]}_{model_name}_responses_{rag_type}.xlsx",
122143
)
123144
df_ans.to_excel(response_filename, index=False)
124145
print(f"Responses saved to {response_filename}")
@@ -128,13 +149,13 @@ def process_testset(
128149
result = evaluate(
129150
llm=ChatOpenAI(temperature=0.0, verbose=True, model="gpt-4o"),
130151
dataset=dataset,
131-
metrics=[answer_relevancy, context_utilization, faithfulness],
152+
metrics=[answer_relevancy, context_utilization, faithfulness, context_recall],
132153
)
133154

134155
# Save evaluation results to an Excel file
135156
evaluation_filename = os.path.join(
136-
eval_dir,
137-
f"{os.path.splitext(os.path.basename(testset_path))[0]}_{args.model}_evaluation.xlsx",
157+
rag_eval_dir,
158+
f"{os.path.splitext(os.path.basename(testset_path))[0]}_{model_name}_evaluation_{rag_type}.xlsx",
138159
)
139160
df_eval = result.to_pandas()
140161
df_eval.to_excel(evaluation_filename, index=False)
@@ -144,14 +165,17 @@ def process_testset(
144165
def main():
145166
args = parse_arguments()
146167
model_name = args.model
168+
rag_type = args.rag_type
147169
response_dir = os.path.join(args.testset_dir, "response")
148170
eval_dir = os.path.join(args.testset_dir, "evals")
149171
os.makedirs(response_dir, exist_ok=True)
150172
os.makedirs(eval_dir, exist_ok=True)
151173

152174
# Initialize RAG Chain
153175
embeddings_directory = "/Users/hmohammadi/Desktop/react_to_me_github/reactome_chatbot/embeddings/openai/text-embedding-3-large/reactome/Release90/summations"
154-
qa_system = initialize_rag_chain(embeddings_directory, model_name)
176+
qa_system = initialize_rag_chain_with_memory(
177+
embeddings_directory, model_name, rag_type
178+
)
155179

156180
# Iterate over all .xlsx files in the directory
157181
for filename in os.listdir(args.testset_dir):
@@ -166,6 +190,7 @@ def main():
166190
response_dir,
167191
eval_dir,
168192
model_name,
193+
rag_type,
169194
)
170195

171196

src/retreival_chain.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def create_retrieval_chain(
108108
loader = CSVLoader(file_path=reactome_csvs_dir / csv_file_name)
109109
data = loader.load()
110110
bm25_retriever = BM25Retriever.from_documents(data)
111-
bm25_retriever.k = 15
111+
bm25_retriever.k = 10
112112

113113
# set up vectorstore SelfQuery retriever
114114
embedding = embedding_callable()
@@ -123,7 +123,7 @@ def create_retrieval_chain(
123123
vectorstore=vectordb,
124124
document_contents=descriptions_info[subdirectory],
125125
metadata_field_info=field_info[subdirectory],
126-
search_kwargs={"k": 15},
126+
search_kwargs={"k": 10},
127127
)
128128
rrf_retriever = EnsembleRetriever(
129129
retrievers=[bm25_retriever, selfq_retriever], weights=[0.2, 0.8]

src/system_prompt/reactome_prompt.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,22 @@
33
# Contextualize question prompt
44
contextualize_q_system_prompt = """
55
You are an expert in question formulation with deep expertise in molecular biology and experience as a Reactome curator. Your task is to analyze the conversation history and the user’s latest query to fully understand their intent and what they seek to learn.
6-
Reformulate the user’s question into a standalone version that retains its full meaning without requiring prior context. The reformulated question should be:**
7-
Clear, concise, and precise
8-
Optimized for both vector search (semantic meaning) and case-sensitive keyword search
9-
Faithful to the user’s intent and scientific accuracy
10-
If the user’s question is already self-contained and well-formed, return it as is.
6+
If the user's question is not in English, reformulate the question and translate it to English, ensuring the meaning and intent are preserved.
7+
Reformulate the user’s question into a standalone version that retains its full meaning without requiring prior context. The reformulated question should be:
8+
- Clear, concise, and precise
9+
- Optimized for both vector search (semantic meaning) and case-sensitive keyword search
10+
- Faithful to the user’s intent and scientific accuracy
11+
12+
the returned question should always be in English.
13+
If the user’s question is already in English, self-contained and well-formed, return it as is.
1114
Do NOT answer the question or provide explanations.
1215
"""
1316

1417
contextualize_q_prompt = ChatPromptTemplate.from_messages(
1518
[
1619
("system", contextualize_q_system_prompt),
1720
MessagesPlaceholder(variable_name="chat_history"),
18-
("human", "{input}"),
21+
("human", "{user_input}"),
1922
]
2023
)
2124

@@ -46,6 +49,6 @@
4649
[
4750
("system", qa_system_prompt),
4851
MessagesPlaceholder(variable_name="chat_history"),
49-
("user", "Context:\n{context}\n\nQuestion: {input}"),
52+
("user", "Context:\n{context}\n\nQuestion: {user_input}"),
5053
]
5154
)

src/util/chainlit_helpers.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,7 @@ async def message_rate_limited(config: Config | None) -> bool:
8383
quota_message = "Our servers are currently overloaded.\n"
8484
login_uri: str | None = os.getenv("CHAINLIT_URI_LOGIN", "")
8585
if login_uri:
86-
quota_message += (
87-
f"Please [log in]({login_uri}) to continue chatting and enjoy features like saved chat history and fewer limits."
88-
)
86+
quota_message += f"Please [log in]({login_uri}) to continue chatting and enjoy features like saved chat history and fewer limits."
8987
else:
9088
quota_message += "Please try again later."
9189
await send_messages([quota_message])

0 commit comments

Comments
 (0)