Skip to content

Commit 8ef93da

Browse files
authored
server: add retry support (#1032)
1 parent f627cf2 commit 8ef93da

2 files changed

Lines changed: 185 additions & 92 deletions

File tree

crates/squawk_server/src/dispatch.rs

Lines changed: 113 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,16 @@ use std::panic::UnwindSafe;
44

55
use anyhow::Result;
66
use log::{error, info};
7-
use lsp_server::{RequestId, Response};
7+
use lsp_server::{Request, Response};
88
use lsp_types::{notification::Notification as LspNotification, request::Request as LspRequest};
9+
use salsa::Cancelled;
10+
use serde::{Serialize, de::DeserializeOwned};
911
use squawk_thread::ThreadIntent;
1012

11-
use crate::global_state::{GlobalState, Snapshot};
13+
use crate::{
14+
global_state::{GlobalState, Snapshot, TaskResult},
15+
panic::PanicError,
16+
};
1217

1318
pub(crate) struct RequestDispatcher<'a> {
1419
req: Option<lsp_server::Request>,
@@ -23,15 +28,14 @@ impl<'a> RequestDispatcher<'a> {
2328
}
2429
}
2530

26-
fn parse<R>(&mut self) -> Option<(lsp_server::RequestId, R::Params)>
31+
fn parse<R>(&mut self) -> Option<(Request, R::Params)>
2732
where
2833
R: LspRequest,
2934
{
3035
let req = self.req.take_if(|req| req.method.as_str() == R::METHOD)?;
3136
let id = req.id.clone();
32-
33-
match req.extract(R::METHOD) {
34-
Ok((id, params)) => Some((id, params)),
37+
match from_json(R::METHOD, &req.params) {
38+
Ok(params) => Some((req, params)),
3539
Err(err) => {
3640
let response = Response::new_err(
3741
id,
@@ -44,6 +48,26 @@ impl<'a> RequestDispatcher<'a> {
4448
}
4549
}
4650

51+
pub(crate) fn on_sync<R>(
52+
mut self,
53+
handler: fn(&Snapshot, R::Params) -> Result<R::Result>,
54+
) -> Self
55+
where
56+
R: LspRequest,
57+
R::Params: Send + 'static + UnwindSafe,
58+
{
59+
let Some((request, params)) = self.parse::<R>() else {
60+
return self;
61+
};
62+
63+
let snapshot = self.global_state.snapshot();
64+
let result = crate::panic::catch_unwind(|| handler(&snapshot, params));
65+
if let Ok(response) = thread_result_to_response::<R>(request.id.clone(), result) {
66+
self.global_state.respond(response);
67+
}
68+
self
69+
}
70+
4771
pub(crate) fn on_sync_mut<R>(
4872
mut self,
4973
handler: fn(&mut GlobalState, R::Params) -> Result<R::Result>,
@@ -52,34 +76,40 @@ impl<'a> RequestDispatcher<'a> {
5276
R: LspRequest,
5377
R::Params: Send + 'static,
5478
{
55-
if let Some((id, params)) = self.parse::<R>() {
56-
let result = handler(self.global_state, params);
57-
let response = result_to_response::<R>(id, result);
79+
let Some((request, params)) = self.parse::<R>() else {
80+
return self;
81+
};
82+
83+
let result = handler(self.global_state, params);
84+
if let Ok(response) = result_to_response::<R>(request.id.clone(), result) {
5885
self.global_state.respond(response);
5986
}
6087
self
6188
}
6289

63-
pub(crate) fn on<R>(self, handler: fn(&Snapshot, R::Params) -> Result<R::Result>) -> Self
90+
pub(crate) fn on<const ALLOW_RETRYING: bool, R>(
91+
self,
92+
handler: fn(&Snapshot, R::Params) -> Result<R::Result>,
93+
) -> Self
6494
where
6595
R: LspRequest,
6696
R::Params: Send + 'static + UnwindSafe,
6797
{
68-
self.on_with_thread_intent::<R>(ThreadIntent::Worker, handler)
98+
self.on_with_thread_intent::<ALLOW_RETRYING, R>(ThreadIntent::Worker, handler)
6999
}
70100

71-
pub(crate) fn on_latency_sensitive<R>(
101+
pub(crate) fn on_latency_sensitive<const ALLOW_RETRYING: bool, R>(
72102
self,
73103
handler: fn(&Snapshot, R::Params) -> Result<R::Result>,
74104
) -> Self
75105
where
76106
R: LspRequest,
77107
R::Params: Send + 'static + UnwindSafe,
78108
{
79-
self.on_with_thread_intent::<R>(ThreadIntent::LatencySensitive, handler)
109+
self.on_with_thread_intent::<ALLOW_RETRYING, R>(ThreadIntent::LatencySensitive, handler)
80110
}
81111

82-
fn on_with_thread_intent<R>(
112+
fn on_with_thread_intent<const ALLOW_RETRYING: bool, R>(
83113
mut self,
84114
intent: ThreadIntent,
85115
handler: fn(&Snapshot, R::Params) -> Result<R::Result>,
@@ -88,15 +118,20 @@ impl<'a> RequestDispatcher<'a> {
88118
R: LspRequest,
89119
R::Params: Send + 'static + UnwindSafe,
90120
{
91-
if let Some((id, params)) = self.parse::<R>() {
121+
if let Some((request, params)) = self.parse::<R>() {
92122
let snapshot = self.global_state.snapshot();
93123

94124
self.global_state.task_pool.handle.spawn(intent, move || {
95-
crate::panic::catch_unwind(|| {
96-
let result = handler(&snapshot, params);
97-
result_to_response::<R>(id.clone(), result)
98-
})
99-
.unwrap_or_else(|error| panic_response(id, &error))
125+
let result = crate::panic::catch_unwind(|| handler(&snapshot, params));
126+
match thread_result_to_response::<R>(request.id.clone(), result) {
127+
Ok(response) => TaskResult::Response(response),
128+
Err(_cancelled) if ALLOW_RETRYING => TaskResult::Retry(request),
129+
Err(_cancelled) => TaskResult::Response(Response::new_err(
130+
request.id,
131+
lsp_server::ErrorCode::ContentModified as i32,
132+
"content modified".to_owned(),
133+
)),
134+
}
100135
});
101136
}
102137

@@ -110,45 +145,75 @@ impl<'a> RequestDispatcher<'a> {
110145
}
111146
}
112147

113-
fn panic_response(id: RequestId, error: &crate::panic::PanicError) -> Response {
114-
// Check if the request was canceled due to some modifications to the salsa database.
115-
if error.payload.downcast_ref::<salsa::Cancelled>().is_some() {
116-
// TODO: trigger retries when we have that setup, we'll reenque the task
117-
log::debug!(
118-
"request id={} was cancelled by salsa, sending content modified",
119-
id
120-
);
121-
Response::new_err(
122-
id,
123-
lsp_server::ErrorCode::ContentModified as i32,
124-
"content modified".to_string(),
125-
)
126-
} else {
127-
Response::new_err(
128-
id,
129-
lsp_server::ErrorCode::InternalError as i32,
130-
"request handler error".to_string(),
131-
)
148+
fn thread_result_to_response<R>(
149+
id: lsp_server::RequestId,
150+
result: Result<anyhow::Result<R::Result>, PanicError>,
151+
) -> Result<lsp_server::Response, PanicError>
152+
where
153+
R: lsp_types::request::Request,
154+
R::Params: DeserializeOwned,
155+
R::Result: Serialize,
156+
{
157+
match result {
158+
Ok(handler_result) => match handler_result {
159+
Ok(result) => Ok(Response::new_ok(id, result)),
160+
Err(error) => Ok(Response::new_err(
161+
id,
162+
lsp_server::ErrorCode::InternalError as i32,
163+
error.to_string(),
164+
)),
165+
},
166+
Err(panic) => {
167+
// Check if the request was canceled due to some modifications to the salsa database.
168+
if panic.payload.downcast_ref::<salsa::Cancelled>().is_some() {
169+
log::debug!(
170+
"request id={} was cancelled by salsa, sending content modified",
171+
id
172+
);
173+
Err(panic)
174+
} else {
175+
let error = panic.to_string();
176+
// we don't retry non-salsa cancellation panics
177+
Ok(Response::new_err(
178+
id,
179+
lsp_server::ErrorCode::InternalError as i32,
180+
format!("request handler error: {error}"),
181+
))
182+
}
183+
}
132184
}
133185
}
134186

135-
fn result_to_response<R>(id: RequestId, result: Result<R::Result>) -> Response
187+
fn result_to_response<R>(
188+
id: lsp_server::RequestId,
189+
result: anyhow::Result<R::Result>,
190+
) -> std::result::Result<lsp_server::Response, Cancelled>
136191
where
137-
R: LspRequest,
192+
R: lsp_types::request::Request,
138193
{
139194
match result {
140-
Ok(result) => Response::new_ok(id, result),
141-
Err(err) => {
142-
error!("Request handler failed: {err}");
143-
Response::new_err(
195+
Ok(resp) => Ok(lsp_server::Response::new_ok(id, &resp)),
196+
Err(e) => match e.downcast::<Cancelled>() {
197+
Ok(cancelled) => Err(cancelled),
198+
Err(e) => Ok(lsp_server::Response::new_err(
144199
id,
145200
lsp_server::ErrorCode::InternalError as i32,
146-
err.to_string(),
147-
)
148-
}
201+
e.to_string(),
202+
)),
203+
},
149204
}
150205
}
151206

207+
// lsp-server has req.extract(R::METHOD), but it doesn't work for us due to
208+
// ownership so we use this instead.
209+
pub fn from_json<T: DeserializeOwned>(
210+
what: &'static str,
211+
json: &serde_json::Value,
212+
) -> anyhow::Result<T> {
213+
serde_json::from_value(json.clone())
214+
.map_err(|e| anyhow::format_err!("Failed to deserialize {what}: {e}; {json}"))
215+
}
216+
152217
pub(crate) struct NotificationDispatcher<'a> {
153218
notif: Option<lsp_server::Notification>,
154219
state: &'a mut GlobalState,

0 commit comments

Comments
 (0)