66import os
77import ast
88from PIL import Image
9- from utils .feedback import show_feedback_form
9+ from utils .feedback import (
10+ show_feedback_form ,
11+ submit_feedback_to_google_sheet ,
12+ get_git_commit_hash ,
13+ )
1014from dotenv import load_dotenv
1115from typing import Callable , Any
1216
@@ -22,6 +26,24 @@ def wrapper(*args: Any, **kwargs: Any) -> tuple[Any, float]:
2226 return wrapper
2327
2428
29+ def translate_chat_history_to_api (chat_history , max_pairs = 4 ):
30+ api_format = []
31+ relevant_history = [
32+ msg for msg in chat_history [1 :] if msg ["role" ] in ["user" , "ai" ]
33+ ]
34+
35+ i = len (relevant_history ) - 1
36+ while i > 0 and len (api_format ) < max_pairs :
37+ ai_msg = relevant_history [i ]
38+ user_msg = relevant_history [i - 1 ]
39+ if ai_msg ["role" ] == "ai" and user_msg ["role" ] == "user" :
40+ api_format .insert (0 , {"User" : user_msg ["content" ], "AI" : ai_msg ["content" ]})
41+ i -= 2
42+ else :
43+ i -= 1
44+ return api_format
45+
46+
2547@measure_response_time
2648def response_generator (user_input : str ) -> tuple [str , str ] | tuple [None , None ]:
2749 """
@@ -34,74 +56,47 @@ def response_generator(user_input: str) -> tuple[str, str] | tuple[None, None]:
3456 - tuple: Contains the AI response and sources.
3557 """
3658 url = f"{ st .session_state .base_url } { st .session_state .selected_endpoint } "
37-
3859 headers = {"accept" : "application/json" , "Content-Type" : "application/json" }
39-
40- payload = {"query" : user_input , "list_sources" : True , "list_context" : True }
41-
60+ chat_history = translate_chat_history_to_api (st .session_state .chat_history )
61+ payload = {
62+ "query" : user_input ,
63+ "list_sources" : True ,
64+ "list_context" : True ,
65+ "chat_history" : chat_history ,
66+ }
4267 try :
4368 response = requests .post (url , headers = headers , json = payload )
4469 response .raise_for_status ()
45-
46- try :
47- data = response .json ()
48- if not isinstance (data , dict ):
49- st .error ("Invalid response format" )
50- return None , None
51- except ValueError :
52- st .error ("Failed to decode JSON response" )
70+ data = response .json ()
71+ if not isinstance (data , dict ):
72+ st .error ("Invalid response format" )
5373 return None , None
54-
5574 sources = data .get ("sources" , "" )
5675 st .session_state .metadata [user_input ] = {
5776 "sources" : sources ,
5877 "context" : data .get ("context" , "" ),
5978 }
60-
6179 return data .get ("response" , "" ), sources
62-
6380 except requests .exceptions .RequestException as e :
6481 st .error (f"Request failed: { e } " )
6582 return None , None
6683
6784
68- def fetch_endpoints () -> tuple [str , list [str ]]:
69- base_url = os .getenv ("CHAT_ENDPOINT" , "http://localhost:8000" )
70- url = f"{ base_url } /chains/listAll"
71- try :
72- response = requests .get (url )
73- response .raise_for_status ()
74- endpoints = response .json ()
75- return base_url , endpoints
76- except requests .exceptions .RequestException as e :
77- st .error (f"Failed to fetch endpoints: { e } " )
78- return base_url , []
79-
80-
8185def main () -> None :
8286 load_dotenv ()
83-
8487 img = Image .open ("assets/or_logo.png" )
8588 st .set_page_config (page_title = "OR Assistant" , page_icon = img )
8689
8790 deployment_time = datetime .datetime .now (pytz .timezone ("UTC" ))
88- st .info (f" Deployment time: { deployment_time .strftime (' %m/%d/%Y %H:%M' )} UTC" )
91+ st .info (f' Deployment time: { deployment_time .strftime (" %m/%d/%Y %H:%M" )} UTC' )
8992
9093 st .title ("OR Assistant" )
9194
92- base_url , endpoints = fetch_endpoints ()
93-
94- selected_endpoint = st .selectbox (
95- "Select preferred endpoint" ,
96- options = endpoints ,
97- index = 0 ,
98- format_func = lambda x : x .split ("/" )[- 1 ].capitalize (),
99- )
95+ base_url = os .getenv ("CHAT_ENDPOINT" , "http://localhost:8000" )
96+ selected_endpoint = "/graphs/agent-retriever"
10097
10198 if "selected_endpoint" not in st .session_state :
10299 st .session_state .selected_endpoint = selected_endpoint
103- else :
104- st .session_state .selected_endpoint = selected_endpoint
105100
106101 if "base_url" not in st .session_state :
107102 st .session_state .base_url = base_url
@@ -115,6 +110,8 @@ def main() -> None:
115110 st .session_state .chat_history = []
116111 if "metadata" not in st .session_state :
117112 st .session_state .metadata = {}
113+ if "sources" not in st .session_state :
114+ st .session_state .sources = {}
118115
119116 if not st .session_state .chat_history :
120117 st .session_state .chat_history .append (
@@ -124,10 +121,42 @@ def main() -> None:
124121 }
125122 )
126123
127- for message in st .session_state .chat_history :
124+ for idx , message in enumerate ( st .session_state .chat_history ) :
128125 with st .chat_message (message ["role" ]):
129126 st .markdown (message ["content" ])
130127
128+ if message ["role" ] == "ai" and idx > 0 :
129+ user_message = st .session_state .chat_history [idx - 1 ]
130+ if user_message ["role" ] == "user" :
131+ user_input = user_message ["content" ]
132+ sources = st .session_state .sources .get (user_input )
133+ with st .expander ("Sources:" ):
134+ try :
135+ if sources :
136+ if isinstance (sources , str ):
137+ cleaned_sources = sources .replace ("{" , "[" ).replace (
138+ "}" , "]"
139+ )
140+ parsed_sources = ast .literal_eval (cleaned_sources )
141+ else :
142+ parsed_sources = sources
143+ if (
144+ isinstance (parsed_sources , (list , set ))
145+ and parsed_sources
146+ ):
147+ sources_list = "\n " .join (
148+ f"- [{ link } ]({ link } )"
149+ for link in parsed_sources
150+ if link .strip ()
151+ )
152+ st .markdown (sources_list )
153+ else :
154+ st .markdown ("No Sources Attached." )
155+ else :
156+ st .markdown ("No Sources Attached." )
157+ except (ValueError , SyntaxError ) as e :
158+ st .markdown (f"Failed to parse sources: { e } " )
159+
131160 user_input = st .chat_input ("Enter your queries ..." )
132161
133162 if user_input :
@@ -146,62 +175,69 @@ def main() -> None:
146175 ):
147176 response , sources = response_tuple
148177 if response is not None :
149- response_buffer = ""
178+ response_buffer = response
150179
151180 with st .chat_message ("ai" ):
152181 message_placeholder = st .empty ()
153-
154- response_buffer = ""
155- for chunk in response .split (" " ):
156- response_buffer += chunk + " "
157- if chunk .endswith ("\n " ):
158- response_buffer += " "
159- message_placeholder .markdown (response_buffer )
160- time .sleep (0.05 )
161-
162182 message_placeholder .markdown (response_buffer )
163183
184+ # Display response time
164185 response_time_text = (
165186 f"Response Time: { response_time / 1000 :.2f} seconds"
166187 )
167- response_time_colored = f":{ 'green' if response_time < 5000 else 'orange' if response_time < 10000 else 'red' } [{ response_time_text } ]"
168- st .markdown (response_time_colored )
188+ if response_time < 5000 :
189+ color = "green"
190+ elif response_time < 10000 :
191+ color = "orange"
192+ else :
193+ color = "red"
194+ st .markdown (f":{ color } [{ response_time_text } ]" )
195+
169196 st .session_state .chat_history .append (
170197 {
171198 "content" : response_buffer ,
172199 "role" : "ai" ,
173200 }
174201 )
175202
176- if sources :
177- with st .expander ("Sources:" ):
178- try :
203+ st .session_state .sources [user_input ] = sources
204+
205+ with st .expander ("Sources:" ):
206+ try :
207+ if sources :
179208 if isinstance (sources , str ):
180209 cleaned_sources = sources .replace ("{" , "[" ).replace (
181210 "}" , "]"
182211 )
183212 parsed_sources = ast .literal_eval (cleaned_sources )
184213 else :
185214 parsed_sources = sources
186- if isinstance (parsed_sources , (list , set )):
215+ if (
216+ isinstance (parsed_sources , (list , set ))
217+ and parsed_sources
218+ ):
187219 sources_list = "\n " .join (
188220 f"- [{ link } ]({ link } )"
189221 for link in parsed_sources
190222 if link .strip ()
191223 )
192224 st .markdown (sources_list )
193225 else :
194- st .markdown ("No valid sources found." )
195- except (ValueError , SyntaxError ) as e :
196- st .markdown (f"Failed to parse sources: { e } " )
197- else :
198- st .error ("Invalid response from the API" )
199-
226+ st .markdown ("No Sources Attached." )
227+ else :
228+ st .markdown ("No Sources Attached." )
229+ except (ValueError , SyntaxError ) as e :
230+ st .markdown (f"Failed to parse sources: { e } " )
231+ else :
232+ st .error ("Invalid response from the API" )
233+
234+ # Reaction buttons and feedback form
200235 question_dict = {
201236 interaction ["content" ]: i
202237 for i , interaction in enumerate (st .session_state .chat_history )
203238 if interaction ["role" ] == "user"
204239 }
240+
205241 if question_dict and os .getenv ("FEEDBACK_SHEET_ID" ):
206242 if "feedback_button" not in st .session_state :
207243 st .session_state .feedback_button = False
@@ -212,10 +248,47 @@ def update_state() -> None:
212248 """
213249 st .session_state .feedback_button = True
214250
215- if (
216- st .button ("Feedback" , on_click = update_state )
217- or st .session_state .feedback_button
218- ):
251+ # Display reaction buttons
252+ col1 , col2 , col3 = st .columns ([1 , 1 , 2 ])
253+ with col1 :
254+ thumbs_up = st .button ("👍" , key = "thumbs_up" )
255+ with col2 :
256+ thumbs_down = st .button ("👎" , key = "thumbs_down" )
257+ with col3 :
258+ feedback_clicked = st .button ("Feedback" , on_click = update_state )
259+
260+ # Handle thumbs up and thumbs down reactions
261+ if thumbs_up or thumbs_down :
262+ try :
263+ selected_question = st .session_state .chat_history [- 2 ][
264+ "content"
265+ ] # Last user question
266+ gen_ans = st .session_state .chat_history [- 1 ][
267+ "content"
268+ ] # Last AI response
269+ sources = st .session_state .metadata .get (selected_question , {}).get (
270+ "sources" , ["N/A" ]
271+ )
272+ context = st .session_state .metadata .get (selected_question , {}).get (
273+ "context" , ["N/A" ]
274+ )
275+ reaction = "upvote" if thumbs_up else "downvote"
276+
277+ submit_feedback_to_google_sheet (
278+ question = selected_question ,
279+ answer = gen_ans ,
280+ sources = sources if isinstance (sources , list ) else [sources ],
281+ context = context if isinstance (context , list ) else [context ],
282+ issue = "" , # Leave issue blank
283+ version = os .getenv ("RAG_VERSION" , get_git_commit_hash ()),
284+ reaction = reaction , # Pass the reaction
285+ )
286+ st .success ("Thank you for your feedback!" )
287+ except Exception as e :
288+ st .error (f"Failed to submit feedback: { e } " )
289+
290+ # Feedback form logic
291+ if feedback_clicked or st .session_state .feedback_button :
219292 try :
220293 show_feedback_form (
221294 question_dict ,
0 commit comments