Skip to content

Commit c692881

Browse files
committed
Protect r_task() from panics and detect cancelled handlers
1 parent b06b7da commit c692881

2 files changed

Lines changed: 64 additions & 10 deletions

File tree

crates/ark/src/lsp/main_loop.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ use stdext::result::ResultExt;
2929
use tokio::sync::mpsc;
3030
use tokio::sync::mpsc::unbounded_channel as tokio_unbounded_channel;
3131
use tokio::task::JoinHandle;
32+
use tower_lsp::jsonrpc;
3233
use tower_lsp::lsp_types;
3334
use tower_lsp::lsp_types::Diagnostic;
3435
use tower_lsp::lsp_types::MessageType;
@@ -39,6 +40,7 @@ use super::backend::RequestResponse;
3940
use crate::console::ConsoleNotification;
4041
use crate::lsp;
4142
use crate::lsp::ark_file::ArkFile;
43+
use crate::lsp::backend::LspError;
4244
use crate::lsp::backend::LspMessage;
4345
use crate::lsp::backend::LspNotification;
4446
use crate::lsp::backend::LspRequest;
@@ -602,6 +604,12 @@ fn respond<T>(
602604
let response = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(response)) {
603605
Ok(Ok(t)) => RequestResponse::Result(Ok(into_lsp_response(t))),
604606
Ok(Err(e)) => RequestResponse::Result(Err(e)),
607+
Err(err) if err.downcast_ref::<salsa::Cancelled>().is_some() => {
608+
// A salsa write cancelled an oak query while the handler ran.
609+
// Report `ContentModified` so the client knows the content moved
610+
// under us and re-requests.
611+
RequestResponse::Result(Err(LspError::JsonRpc(jsonrpc::Error::content_modified())))
612+
},
605613
Err(err) => {
606614
// Set global crash flag to disable the LSP
607615
LSP_HAS_CRASHED.store(true, Ordering::Release);
@@ -1048,11 +1056,17 @@ mod tests {
10481056
use aether_path::FilePath;
10491057
use oak_scan::DbScan;
10501058
use salsa::Database;
1059+
use tower_lsp::jsonrpc;
10511060
use url::Url;
10521061

10531062
use super::catch_cancellation;
10541063
use super::refresh_diagnostics;
1064+
use super::respond;
1065+
use super::tokio_unbounded_channel;
10551066
use super::RefreshDiagnosticsTask;
1067+
use crate::lsp::backend::LspError;
1068+
use crate::lsp::backend::LspResponse;
1069+
use crate::lsp::backend::RequestResponse;
10561070
use crate::lsp::state::WorldState;
10571071

10581072
/// A salsa cancellation during the pass is swallowed into `None` by
@@ -1084,6 +1098,41 @@ mod tests {
10841098
assert!(catch_cancellation(|| refresh_diagnostics(task)).is_none());
10851099
}
10861100

1101+
/// A `salsa::Cancelled` re-raised out of a request handler (by `r_task`,
1102+
/// after catching it on the R thread) must not crash the LSP. `respond`
1103+
/// recognises the payload and answers `ContentModified` so the client
1104+
/// re-requests, rather than taking the panic-is-a-crash path.
1105+
#[test]
1106+
fn test_cancelled_request_reports_content_modified() {
1107+
let mut state = WorldState::default();
1108+
let uri = Url::parse("file:///test.R").unwrap();
1109+
let file = state
1110+
.db
1111+
.upsert_editor(FilePath::from_url(&uri), "foo".to_string());
1112+
state.insert_ark_file(uri.clone(), file, None);
1113+
1114+
let file = state.ark_file(&uri).unwrap();
1115+
let snapshot = state.diagnostics_snapshot();
1116+
snapshot.db.cancellation_token().cancel();
1117+
1118+
let (response_tx, mut response_rx) = tokio_unbounded_channel::<RequestResponse>();
1119+
respond(
1120+
response_tx,
1121+
|| {
1122+
let _ = file.tree_sitter(&snapshot.db);
1123+
Ok(LspResponse::Hover(None))
1124+
},
1125+
|response| response,
1126+
)
1127+
.unwrap();
1128+
1129+
let response = response_rx.try_recv().unwrap();
1130+
let RequestResponse::Result(Err(LspError::JsonRpc(error))) = response else {
1131+
panic!("Expected a jsonrpc error response");
1132+
};
1133+
assert_eq!(error.code, jsonrpc::ErrorCode::ContentModified);
1134+
}
1135+
10871136
/// The central diagnostics refresh keys off the oak revision advancing
10881137
/// across a loop tick, so an oak write must bump the revision. This pins
10891138
/// that assumption: if a salsa upgrade changed it, the refresh would

crates/ark/src/r_task.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,6 @@ impl RTaskStartInfo {
191191
// running, so borrowing is allowed even though we send it to another
192192
// thread. See also `Crossbeam::thread::ScopedThreadBuilder` (from which
193193
// `r_task()` is adapted) for a similar approach.
194-
//
195-
// Don't run oak (salsa) queries inside `f`. `f` executes on the R thread, and a
196-
// `salsa::Cancelled` unwind there would cross R's C frames, which is UB. Pull
197-
// whatever you need out of oak before the `r_task()`, on the calling thread.
198194

199195
pub fn r_task<'env, F, T>(f: F) -> T
200196
where
@@ -222,13 +218,17 @@ where
222218
// Instead of scoping the task with a thread join, we send it on the R
223219
// thread and block the thread until a completion channel wakes us up.
224220

225-
// The result of `f` will be stored here.
226-
let result = SharedOption::default();
221+
// Stores the outcome of `f`. We catch any unwind on the R thread instead of
222+
// letting it escape the closure: the closure runs inside `r_sandbox`'s
223+
// `try_catch`, and a Rust unwind crossing those C frames is UB. The payload
224+
// is ferried back and re-raised below, on the calling thread.
225+
let result: SharedOption<std::thread::Result<T>> = SharedOption::default();
227226

228227
{
229228
let result = Arc::clone(&result);
230229
let closure = move || {
231-
*result.lock().unwrap() = Some(f());
230+
let caught = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
231+
*result.lock().unwrap() = Some(caught);
232232
};
233233

234234
// Move `f` to heap and erase its lifetime so we can send it to
@@ -286,9 +286,14 @@ where
286286
}
287287
}
288288

289-
// Retrieve closure result from the synchronized shared option.
290-
// If we get here without panicking we know the result was assigned.
291-
return result.lock().unwrap().take().unwrap();
289+
// The closure ran to completion: it caught its own unwind, and an R-level
290+
// error would have panicked above. Re-raise on this thread any panic the
291+
// closure caught on the R thread.
292+
let caught = result.lock().unwrap().take().unwrap();
293+
match caught {
294+
Ok(value) => value,
295+
Err(payload) => std::panic::resume_unwind(payload),
296+
}
292297
}
293298

294299
/// An async task to be run on the R thread.

0 commit comments

Comments
 (0)