@@ -17,10 +17,17 @@ async fn index(_req: Request) -> Response {
1717}
1818
1919#[ derive( Serialize , Deserialize ) ]
20- struct Messages {
20+ struct MessagesRequest {
2121 messages : Vec < String > ,
2222}
2323
24+ #[ derive( Serialize , Deserialize , Debug ) ]
25+ #[ serde( tag = "type" ) ]
26+ enum MessagesResponse {
27+ Success { messages : Vec < String > } ,
28+ Cancelled ,
29+ }
30+
2431async fn load_docs ( paths : Vec < PathBuf > ) -> Vec < String > {
2532 let mut doc_futs = paths
2633 . into_iter ( )
@@ -33,53 +40,82 @@ async fn load_docs(paths: Vec<PathBuf>) -> Vec<String> {
3340 docs
3441}
3542
36- type Payload = ( Arc < Vec < String > > , oneshot:: Sender < Vec < String > > ) ;
43+ type Payload = ( Arc < Vec < String > > , oneshot:: Sender < Option < Vec < String > > > ) ;
3744
38- fn chatbot_thread ( ) -> mpsc:: Sender < Payload > {
39- let ( tx, mut rx) = mpsc:: channel :: < Payload > ( 1024 ) ;
45+ fn chatbot_thread ( ) -> ( mpsc:: Sender < Payload > , mpsc:: Sender < ( ) > ) {
46+ let ( req_tx, mut req_rx) = mpsc:: channel :: < Payload > ( 1024 ) ;
47+ let ( cancel_tx, mut cancel_rx) = mpsc:: channel :: < ( ) > ( 1 ) ;
4048 tokio:: spawn ( async move {
4149 let mut chatbot = chatbot:: Chatbot :: new ( vec ! [ ":-)" . into( ) , "^^" . into( ) ] ) ;
42- while let Some ( ( messages, responder) ) = rx . recv ( ) . await {
50+ while let Some ( ( messages, responder) ) = req_rx . recv ( ) . await {
4351 let doc_paths = chatbot. retrieval_documents ( & messages) ;
4452 let docs = load_docs ( doc_paths) . await ;
45- let response = chatbot. query_chat ( & messages, & docs) . await ;
46- responder. send ( response) . unwrap ( ) ;
53+ let chat_fut = chatbot. query_chat ( & messages, & docs) ;
54+ let cancel_fut = cancel_rx. recv ( ) ;
55+ tokio:: select! {
56+ response = chat_fut => {
57+ responder. send( Some ( response) ) . unwrap( ) ;
58+ }
59+ _ = cancel_fut => {
60+ responder. send( None ) . unwrap( ) ;
61+ }
62+ }
4763 }
4864 } ) ;
49- tx
65+ ( req_tx , cancel_tx )
5066}
5167
52- async fn query_chat ( messages : & Arc < Vec < String > > ) -> Vec < String > {
53- static SENDER : LazyLock < mpsc :: Sender < Payload > > = LazyLock :: new ( chatbot_thread) ;
68+ static CHATBOT_THREAD : LazyLock < ( mpsc :: Sender < Payload > , mpsc :: Sender < ( ) > ) > =
69+ LazyLock :: new ( chatbot_thread) ;
5470
71+ async fn query_chat ( messages : & Arc < Vec < String > > ) -> Option < Vec < String > > {
5572 let ( tx, rx) = oneshot:: channel ( ) ;
56- SENDER . send ( ( Arc :: clone ( messages) , tx) ) . await . unwrap ( ) ;
73+ CHATBOT_THREAD
74+ . 0
75+ . send ( ( Arc :: clone ( messages) , tx) )
76+ . await
77+ . unwrap ( ) ;
5778 rx. await . unwrap ( )
5879}
5980
81+ async fn cancel ( _req : Request ) -> Response {
82+ CHATBOT_THREAD . 1 . send ( ( ) ) . await . unwrap ( ) ;
83+ Ok ( Content :: Html ( "success" . into ( ) ) )
84+ }
85+
6086async fn chat ( req : Request ) -> Response {
6187 let Request :: Post ( body) = req else {
6288 return Err ( StatusCode :: METHOD_NOT_ALLOWED ) ;
6389 } ;
64- let Ok ( mut data) = serde_json:: from_str :: < Messages > ( & body) else {
90+ let Ok ( mut data) = serde_json:: from_str :: < MessagesRequest > ( & body) else {
6591 return Err ( StatusCode :: INTERNAL_SERVER_ERROR ) ;
6692 } ;
6793
6894 let messages = Arc :: new ( data. messages ) ;
69- let ( i, mut responses ) = join ! ( chatbot:: gen_random_number( ) , query_chat( & messages) ) ;
95+ let ( i, responses_opt ) = join ! ( chatbot:: gen_random_number( ) , query_chat( & messages) ) ;
7096
71- let response = responses. remove ( i % responses. len ( ) ) ;
72- data. messages = Arc :: into_inner ( messages) . unwrap ( ) ;
73- data. messages . push ( response) ;
97+ let response = match responses_opt {
98+ Some ( mut responses) => {
99+ let response = responses. remove ( i % responses. len ( ) ) ;
100+ data. messages = Arc :: into_inner ( messages) . unwrap ( ) ;
101+ data. messages . push ( response) ;
74102
75- Ok ( Content :: Json ( serde_json:: to_string ( & data) . unwrap ( ) ) )
103+ MessagesResponse :: Success {
104+ messages : data. messages ,
105+ }
106+ }
107+ None => MessagesResponse :: Cancelled ,
108+ } ;
109+
110+ Ok ( Content :: Json ( serde_json:: to_string ( & response) . unwrap ( ) ) )
76111}
77112
78113#[ tokio:: main]
79114async fn main ( ) {
80115 miniserve:: Server :: new ( )
81116 . route ( "/" , index)
82117 . route ( "/chat" , chat)
118+ . route ( "/cancel" , cancel)
83119 . run ( )
84120 . await
85121}
0 commit comments