Skip to content

Commit e882c5c

Browse files
authored
Feature: Add feedback reaction button and add chat history v2 (#112)
* frontend: Add feedback reaction buttons and improve chat history handling * change to safe markdown syntax * Update frontend/utils/feedback.py --------- Signed-off-by: error9098x <provantablack@gmail.com> Signed-off-by: Jack Luar <jluar@precisioninno.com>
1 parent 423d70e commit e882c5c

4 files changed

Lines changed: 167 additions & 89 deletions

File tree

frontend/requirements-test.txt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
streamlit==1.37.0
1+
streamlit==1.40.2
22
requests==2.32.3
33
requests-oauthlib==2.0.0
4-
Pillow==10.3.0
4+
Pillow==11.0.0
55
pytz==2024.1
66
google-auth==2.30.0
77
google-auth-httplib2==0.2.0
@@ -13,4 +13,5 @@ flask==3.0.3
1313
types-pytz==2024.1.0.20240417
1414
types-requests==2.32.0.20240622
1515
pre-commit==3.7.1
16-
ruff==0.5.1
16+
ruff==0.5.1
17+
mypy==1.10.1

frontend/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
streamlit==1.37.0
1+
streamlit==1.40.2
22
requests==2.32.3
33
requests-oauthlib==2.0.0
4-
Pillow==10.3.0
4+
Pillow==11.0.0
55
pytz==2024.1
66
google-auth==2.30.0
77
google-auth-httplib2==0.2.0

frontend/streamlit_app.py

Lines changed: 141 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
import os
77
import ast
88
from PIL import Image
9-
from utils.feedback import show_feedback_form
9+
from utils.feedback import (
10+
show_feedback_form,
11+
submit_feedback_to_google_sheet,
12+
get_git_commit_hash,
13+
)
1014
from dotenv import load_dotenv
1115
from typing import Callable, Any
1216

@@ -22,6 +26,24 @@ def wrapper(*args: Any, **kwargs: Any) -> tuple[Any, float]:
2226
return wrapper
2327

2428

29+
def translate_chat_history_to_api(chat_history, max_pairs=4):
30+
api_format = []
31+
relevant_history = [
32+
msg for msg in chat_history[1:] if msg["role"] in ["user", "ai"]
33+
]
34+
35+
i = len(relevant_history) - 1
36+
while i > 0 and len(api_format) < max_pairs:
37+
ai_msg = relevant_history[i]
38+
user_msg = relevant_history[i - 1]
39+
if ai_msg["role"] == "ai" and user_msg["role"] == "user":
40+
api_format.insert(0, {"User": user_msg["content"], "AI": ai_msg["content"]})
41+
i -= 2
42+
else:
43+
i -= 1
44+
return api_format
45+
46+
2547
@measure_response_time
2648
def response_generator(user_input: str) -> tuple[str, str] | tuple[None, None]:
2749
"""
@@ -34,74 +56,47 @@ def response_generator(user_input: str) -> tuple[str, str] | tuple[None, None]:
3456
- tuple: Contains the AI response and sources.
3557
"""
3658
url = f"{st.session_state.base_url}{st.session_state.selected_endpoint}"
37-
3859
headers = {"accept": "application/json", "Content-Type": "application/json"}
39-
40-
payload = {"query": user_input, "list_sources": True, "list_context": True}
41-
60+
chat_history = translate_chat_history_to_api(st.session_state.chat_history)
61+
payload = {
62+
"query": user_input,
63+
"list_sources": True,
64+
"list_context": True,
65+
"chat_history": chat_history,
66+
}
4267
try:
4368
response = requests.post(url, headers=headers, json=payload)
4469
response.raise_for_status()
45-
46-
try:
47-
data = response.json()
48-
if not isinstance(data, dict):
49-
st.error("Invalid response format")
50-
return None, None
51-
except ValueError:
52-
st.error("Failed to decode JSON response")
70+
data = response.json()
71+
if not isinstance(data, dict):
72+
st.error("Invalid response format")
5373
return None, None
54-
5574
sources = data.get("sources", "")
5675
st.session_state.metadata[user_input] = {
5776
"sources": sources,
5877
"context": data.get("context", ""),
5978
}
60-
6179
return data.get("response", ""), sources
62-
6380
except requests.exceptions.RequestException as e:
6481
st.error(f"Request failed: {e}")
6582
return None, None
6683

6784

68-
def fetch_endpoints() -> tuple[str, list[str]]:
69-
base_url = os.getenv("CHAT_ENDPOINT", "http://localhost:8000")
70-
url = f"{base_url}/chains/listAll"
71-
try:
72-
response = requests.get(url)
73-
response.raise_for_status()
74-
endpoints = response.json()
75-
return base_url, endpoints
76-
except requests.exceptions.RequestException as e:
77-
st.error(f"Failed to fetch endpoints: {e}")
78-
return base_url, []
79-
80-
8185
def main() -> None:
8286
load_dotenv()
83-
8487
img = Image.open("assets/or_logo.png")
8588
st.set_page_config(page_title="OR Assistant", page_icon=img)
8689

8790
deployment_time = datetime.datetime.now(pytz.timezone("UTC"))
88-
st.info(f"Deployment time: {deployment_time.strftime('%m/%d/%Y %H:%M')} UTC")
91+
st.info(f'Deployment time: {deployment_time.strftime("%m/%d/%Y %H:%M")} UTC')
8992

9093
st.title("OR Assistant")
9194

92-
base_url, endpoints = fetch_endpoints()
93-
94-
selected_endpoint = st.selectbox(
95-
"Select preferred endpoint",
96-
options=endpoints,
97-
index=0,
98-
format_func=lambda x: x.split("/")[-1].capitalize(),
99-
)
95+
base_url = os.getenv("CHAT_ENDPOINT", "http://localhost:8000")
96+
selected_endpoint = "/graphs/agent-retriever"
10097

10198
if "selected_endpoint" not in st.session_state:
10299
st.session_state.selected_endpoint = selected_endpoint
103-
else:
104-
st.session_state.selected_endpoint = selected_endpoint
105100

106101
if "base_url" not in st.session_state:
107102
st.session_state.base_url = base_url
@@ -115,6 +110,8 @@ def main() -> None:
115110
st.session_state.chat_history = []
116111
if "metadata" not in st.session_state:
117112
st.session_state.metadata = {}
113+
if "sources" not in st.session_state:
114+
st.session_state.sources = {}
118115

119116
if not st.session_state.chat_history:
120117
st.session_state.chat_history.append(
@@ -124,10 +121,42 @@ def main() -> None:
124121
}
125122
)
126123

127-
for message in st.session_state.chat_history:
124+
for idx, message in enumerate(st.session_state.chat_history):
128125
with st.chat_message(message["role"]):
129126
st.markdown(message["content"])
130127

128+
if message["role"] == "ai" and idx > 0:
129+
user_message = st.session_state.chat_history[idx - 1]
130+
if user_message["role"] == "user":
131+
user_input = user_message["content"]
132+
sources = st.session_state.sources.get(user_input)
133+
with st.expander("Sources:"):
134+
try:
135+
if sources:
136+
if isinstance(sources, str):
137+
cleaned_sources = sources.replace("{", "[").replace(
138+
"}", "]"
139+
)
140+
parsed_sources = ast.literal_eval(cleaned_sources)
141+
else:
142+
parsed_sources = sources
143+
if (
144+
isinstance(parsed_sources, (list, set))
145+
and parsed_sources
146+
):
147+
sources_list = "\n".join(
148+
f"- [{link}]({link})"
149+
for link in parsed_sources
150+
if link.strip()
151+
)
152+
st.markdown(sources_list)
153+
else:
154+
st.markdown("No Sources Attached.")
155+
else:
156+
st.markdown("No Sources Attached.")
157+
except (ValueError, SyntaxError) as e:
158+
st.markdown(f"Failed to parse sources: {e}")
159+
131160
user_input = st.chat_input("Enter your queries ...")
132161

133162
if user_input:
@@ -146,62 +175,69 @@ def main() -> None:
146175
):
147176
response, sources = response_tuple
148177
if response is not None:
149-
response_buffer = ""
178+
response_buffer = response
150179

151180
with st.chat_message("ai"):
152181
message_placeholder = st.empty()
153-
154-
response_buffer = ""
155-
for chunk in response.split(" "):
156-
response_buffer += chunk + " "
157-
if chunk.endswith("\n"):
158-
response_buffer += " "
159-
message_placeholder.markdown(response_buffer)
160-
time.sleep(0.05)
161-
162182
message_placeholder.markdown(response_buffer)
163183

184+
# Display response time
164185
response_time_text = (
165186
f"Response Time: {response_time / 1000:.2f} seconds"
166187
)
167-
response_time_colored = f":{'green' if response_time < 5000 else 'orange' if response_time < 10000 else 'red'}[{response_time_text}]"
168-
st.markdown(response_time_colored)
188+
if response_time < 5000:
189+
color = "green"
190+
elif response_time < 10000:
191+
color = "orange"
192+
else:
193+
color = "red"
194+
st.markdown(f":{color}[{response_time_text}]")
195+
169196
st.session_state.chat_history.append(
170197
{
171198
"content": response_buffer,
172199
"role": "ai",
173200
}
174201
)
175202

176-
if sources:
177-
with st.expander("Sources:"):
178-
try:
203+
st.session_state.sources[user_input] = sources
204+
205+
with st.expander("Sources:"):
206+
try:
207+
if sources:
179208
if isinstance(sources, str):
180209
cleaned_sources = sources.replace("{", "[").replace(
181210
"}", "]"
182211
)
183212
parsed_sources = ast.literal_eval(cleaned_sources)
184213
else:
185214
parsed_sources = sources
186-
if isinstance(parsed_sources, (list, set)):
215+
if (
216+
isinstance(parsed_sources, (list, set))
217+
and parsed_sources
218+
):
187219
sources_list = "\n".join(
188220
f"- [{link}]({link})"
189221
for link in parsed_sources
190222
if link.strip()
191223
)
192224
st.markdown(sources_list)
193225
else:
194-
st.markdown("No valid sources found.")
195-
except (ValueError, SyntaxError) as e:
196-
st.markdown(f"Failed to parse sources: {e}")
197-
else:
198-
st.error("Invalid response from the API")
199-
226+
st.markdown("No Sources Attached.")
227+
else:
228+
st.markdown("No Sources Attached.")
229+
except (ValueError, SyntaxError) as e:
230+
st.markdown(f"Failed to parse sources: {e}")
231+
else:
232+
st.error("Invalid response from the API")
233+
234+
# Reaction buttons and feedback form
200235
question_dict = {
201236
interaction["content"]: i
202237
for i, interaction in enumerate(st.session_state.chat_history)
203238
if interaction["role"] == "user"
204239
}
240+
205241
if question_dict and os.getenv("FEEDBACK_SHEET_ID"):
206242
if "feedback_button" not in st.session_state:
207243
st.session_state.feedback_button = False
@@ -212,10 +248,47 @@ def update_state() -> None:
212248
"""
213249
st.session_state.feedback_button = True
214250

215-
if (
216-
st.button("Feedback", on_click=update_state)
217-
or st.session_state.feedback_button
218-
):
251+
# Display reaction buttons
252+
col1, col2, col3 = st.columns([1, 1, 2])
253+
with col1:
254+
thumbs_up = st.button("👍", key="thumbs_up")
255+
with col2:
256+
thumbs_down = st.button("👎", key="thumbs_down")
257+
with col3:
258+
feedback_clicked = st.button("Feedback", on_click=update_state)
259+
260+
# Handle thumbs up and thumbs down reactions
261+
if thumbs_up or thumbs_down:
262+
try:
263+
selected_question = st.session_state.chat_history[-2][
264+
"content"
265+
] # Last user question
266+
gen_ans = st.session_state.chat_history[-1][
267+
"content"
268+
] # Last AI response
269+
sources = st.session_state.metadata.get(selected_question, {}).get(
270+
"sources", ["N/A"]
271+
)
272+
context = st.session_state.metadata.get(selected_question, {}).get(
273+
"context", ["N/A"]
274+
)
275+
reaction = "upvote" if thumbs_up else "downvote"
276+
277+
submit_feedback_to_google_sheet(
278+
question=selected_question,
279+
answer=gen_ans,
280+
sources=sources if isinstance(sources, list) else [sources],
281+
context=context if isinstance(context, list) else [context],
282+
issue="", # Leave issue blank
283+
version=os.getenv("RAG_VERSION", get_git_commit_hash()),
284+
reaction=reaction, # Pass the reaction
285+
)
286+
st.success("Thank you for your feedback!")
287+
except Exception as e:
288+
st.error(f"Failed to submit feedback: {e}")
289+
290+
# Feedback form logic
291+
if feedback_clicked or st.session_state.feedback_button:
219292
try:
220293
show_feedback_form(
221294
question_dict,

0 commit comments

Comments
 (0)