Skip to content

Commit 09ee688

Browse files
committed
Cancellation solution
1 parent 0d6cfb9 commit 09ee688

1 file changed

Lines changed: 53 additions & 17 deletions

File tree

crates/server/src/main.rs

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,17 @@ async fn index(_req: Request) -> Response {
1717
}
1818

1919
#[derive(Serialize, Deserialize)]
20-
struct Messages {
20+
struct MessagesRequest {
2121
messages: Vec<String>,
2222
}
2323

24+
#[derive(Serialize, Deserialize, Debug)]
25+
#[serde(tag = "type")]
26+
enum MessagesResponse {
27+
Success { messages: Vec<String> },
28+
Cancelled,
29+
}
30+
2431
async fn load_docs(paths: Vec<PathBuf>) -> Vec<String> {
2532
let mut doc_futs = paths
2633
.into_iter()
@@ -33,53 +40,82 @@ async fn load_docs(paths: Vec<PathBuf>) -> Vec<String> {
3340
docs
3441
}
3542

36-
type Payload = (Arc<Vec<String>>, oneshot::Sender<Vec<String>>);
43+
type Payload = (Arc<Vec<String>>, oneshot::Sender<Option<Vec<String>>>);
3744

38-
fn chatbot_thread() -> mpsc::Sender<Payload> {
39-
let (tx, mut rx) = mpsc::channel::<Payload>(1024);
45+
fn chatbot_thread() -> (mpsc::Sender<Payload>, mpsc::Sender<()>) {
46+
let (req_tx, mut req_rx) = mpsc::channel::<Payload>(1024);
47+
let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
4048
tokio::spawn(async move {
4149
let mut chatbot = chatbot::Chatbot::new(vec![":-)".into(), "^^".into()]);
42-
while let Some((messages, responder)) = rx.recv().await {
50+
while let Some((messages, responder)) = req_rx.recv().await {
4351
let doc_paths = chatbot.retrieval_documents(&messages);
4452
let docs = load_docs(doc_paths).await;
45-
let response = chatbot.query_chat(&messages, &docs).await;
46-
responder.send(response).unwrap();
53+
let chat_fut = chatbot.query_chat(&messages, &docs);
54+
let cancel_fut = cancel_rx.recv();
55+
tokio::select! {
56+
response = chat_fut => {
57+
responder.send(Some(response)).unwrap();
58+
}
59+
_ = cancel_fut => {
60+
responder.send(None).unwrap();
61+
}
62+
}
4763
}
4864
});
49-
tx
65+
(req_tx, cancel_tx)
5066
}
5167

52-
async fn query_chat(messages: &Arc<Vec<String>>) -> Vec<String> {
53-
static SENDER: LazyLock<mpsc::Sender<Payload>> = LazyLock::new(chatbot_thread);
68+
static CHATBOT_THREAD: LazyLock<(mpsc::Sender<Payload>, mpsc::Sender<()>)> =
69+
LazyLock::new(chatbot_thread);
5470

71+
async fn query_chat(messages: &Arc<Vec<String>>) -> Option<Vec<String>> {
5572
let (tx, rx) = oneshot::channel();
56-
SENDER.send((Arc::clone(messages), tx)).await.unwrap();
73+
CHATBOT_THREAD
74+
.0
75+
.send((Arc::clone(messages), tx))
76+
.await
77+
.unwrap();
5778
rx.await.unwrap()
5879
}
5980

81+
async fn cancel(_req: Request) -> Response {
82+
CHATBOT_THREAD.1.send(()).await.unwrap();
83+
Ok(Content::Html("success".into()))
84+
}
85+
6086
async fn chat(req: Request) -> Response {
6187
let Request::Post(body) = req else {
6288
return Err(StatusCode::METHOD_NOT_ALLOWED);
6389
};
64-
let Ok(mut data) = serde_json::from_str::<Messages>(&body) else {
90+
let Ok(mut data) = serde_json::from_str::<MessagesRequest>(&body) else {
6591
return Err(StatusCode::INTERNAL_SERVER_ERROR);
6692
};
6793

6894
let messages = Arc::new(data.messages);
69-
let (i, mut responses) = join!(chatbot::gen_random_number(), query_chat(&messages));
95+
let (i, responses_opt) = join!(chatbot::gen_random_number(), query_chat(&messages));
7096

71-
let response = responses.remove(i % responses.len());
72-
data.messages = Arc::into_inner(messages).unwrap();
73-
data.messages.push(response);
97+
let response = match responses_opt {
98+
Some(mut responses) => {
99+
let response = responses.remove(i % responses.len());
100+
data.messages = Arc::into_inner(messages).unwrap();
101+
data.messages.push(response);
74102

75-
Ok(Content::Json(serde_json::to_string(&data).unwrap()))
103+
MessagesResponse::Success {
104+
messages: data.messages,
105+
}
106+
}
107+
None => MessagesResponse::Cancelled,
108+
};
109+
110+
Ok(Content::Json(serde_json::to_string(&response).unwrap()))
76111
}
77112

78113
#[tokio::main]
79114
async fn main() {
80115
miniserve::Server::new()
81116
.route("/", index)
82117
.route("/chat", chat)
118+
.route("/cancel", cancel)
83119
.run()
84120
.await
85121
}

0 commit comments

Comments
 (0)