Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

136 changes: 121 additions & 15 deletions crates/pgls_cli/src/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ pub(crate) use self::unix::{ensure_daemon, open_socket, print_socket, run_daemon
/// [WorkspaceTransport] instance if the socket is currently active
pub fn open_transport(runtime: Runtime) -> io::Result<Option<impl WorkspaceTransport>> {
match runtime.block_on(open_socket()) {
Ok(Some((read, write))) => Ok(Some(SocketTransport::open(runtime, read, write))),
Ok(Some((read, write))) => Ok(Some(SocketTransport::open_with_timeout(
runtime,
read,
write,
DEFAULT_REQUEST_TIMEOUT,
))),
Ok(None) => Ok(None),
Err(err) => Err(err),
}
Expand Down Expand Up @@ -99,8 +104,11 @@ pub struct SocketTransport {
runtime: Runtime,
write_send: Sender<(Vec<u8>, bool)>,
pending_requests: PendingRequests,
request_timeout: Duration,
}

const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(15);

/// Stores a handle to the map of pending requests, and clears the map
/// automatically when the handle is dropped
#[derive(Clone, Default)]
Expand Down Expand Up @@ -131,7 +139,12 @@ impl Drop for PendingRequests {
}

impl SocketTransport {
pub fn open<R, W>(runtime: Runtime, socket_read: R, socket_write: W) -> Self
pub fn open_with_timeout<R, W>(
runtime: Runtime,
socket_read: R,
socket_write: W,
request_timeout: Duration,
) -> Self
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
Expand Down Expand Up @@ -172,6 +185,7 @@ impl SocketTransport {
runtime,
write_send,
pending_requests: pending_requests_2,
request_timeout,
}
}
}
Expand All @@ -185,27 +199,30 @@ impl WorkspaceTransport for SocketTransport {
P: Serialize,
R: DeserializeOwned,
{
let request_id = request.id;
let (send, recv) = oneshot::channel();

self.pending_requests.insert(request.id, send);

let is_shutdown = request.method == "pgls/shutdown";

let request = JsonRpcRequest {
jsonrpc: Cow::Borrowed("2.0"),
id: request.id,
id: request_id,
method: Cow::Borrowed(request.method),
params: request.params,
};

let request = to_vec(&request).map_err(|err| {
TransportError::SerdeError(format!(
"failed to serialize {} into byte buffer: {err}",
type_name::<P>()
))
})?;
let request = match to_vec(&request) {
Ok(request) => request,
Err(err) => {
return Err(TransportError::SerdeError(format!(
"failed to serialize {} into byte buffer: {err}",
type_name::<P>()
)));
}
};

let response = self.runtime.block_on(async move {
self.pending_requests.insert(request_id, send);

let response = match self.runtime.block_on(async move {
self.write_send
.send((request, is_shutdown))
.await
Expand All @@ -219,11 +236,17 @@ impl WorkspaceTransport for SocketTransport {
Err(_) => Err(TransportError::ChannelClosed),
}
}
_ = sleep(Duration::from_secs(15)) => {
_ = sleep(self.request_timeout) => {
Err(TransportError::Timeout)
}
}
})?;
}) {
Ok(response) => response,
Err(err) => {
self.pending_requests.remove(&request_id);
return Err(err);
}
};

let response = response.get();
let result = from_str(response).map_err(|err| {
Expand Down Expand Up @@ -472,3 +495,86 @@ impl FromStr for TransportHeader {
}
}
}

#[cfg(test)]
mod tests {
use std::fmt;
use std::time::Duration;

use pgls_workspace::TransportError;
use pgls_workspace::workspace::{TransportRequest, WorkspaceTransport};
use serde::Serialize;
use serde::ser::{Error as SerError, Serializer};
use serde_json::Value;
use tokio::io::{duplex, split};
use tokio::runtime::Runtime;

use super::SocketTransport;

struct FailingParams;

impl Serialize for FailingParams {
fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
Err(S::Error::custom("expected serialization failure"))
}
}

impl fmt::Debug for FailingParams {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("FailingParams")
}
}

fn disconnected_transport() -> SocketTransport {
let runtime = Runtime::new().expect("failed to create tokio runtime");
let (stream, peer) = duplex(1024);
drop(peer);
let (read, write) = split(stream);
SocketTransport::open_with_timeout(runtime, read, write, Duration::from_millis(50))
}

#[test]
fn request_does_not_retain_pending_entries_when_serialization_fails() {
let transport = disconnected_transport();

let result: Result<Value, TransportError> = transport.request(TransportRequest {
id: 1,
method: "pgls/get_file_content",
params: FailingParams,
});

assert!(matches!(result, Err(TransportError::SerdeError(_))));
assert_eq!(
transport.pending_requests.len(),
0,
"pending request should be cleaned up on serialization failure"
);
}

#[test]
fn request_does_not_retain_pending_entries_on_timeout_or_channel_close() {
let transport = disconnected_transport();

let result: Result<Value, TransportError> = transport.request(TransportRequest {
id: 2,
method: "pgls/get_file_content",
params: (),
});

assert!(
matches!(
result,
Err(TransportError::Timeout | TransportError::ChannelClosed)
),
"expected timeout or channel-closed error, got {result:?}"
);
assert_eq!(
transport.pending_requests.len(),
0,
"pending request should be cleaned up on timeout/channel-close"
);
}
}
18 changes: 11 additions & 7 deletions xtask/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@ rust-version.workspace = true
version = "0.0.0"

[dependencies]
anyhow = "1.0.62"
flate2 = "1.0.24"
time = { version = "0.3", default-features = false }
write-json = "0.1.2"
xflags = "0.3.0"
xshell = "0.2.2"
zip = { version = "0.6", default-features = false, features = ["deflate", "time"] }
anyhow = "1.0.62"
flate2 = "1.0.24"
pgls_cli = { path = "../crates/pgls_cli" }
pgls_workspace = { path = "../crates/pgls_workspace" }
serde_json = "1.0.114"
time = { version = "0.3", default-features = false }
tokio = { version = "1.40.0", features = ["rt-multi-thread", "net"] }
write-json = "0.1.2"
xflags = "0.3.0"
xshell = "0.2.2"
zip = { version = "0.6", default-features = false, features = ["deflate", "time"] }
# Avoid adding more dependencies to this crate
18 changes: 18 additions & 0 deletions xtask/src/flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ xflags::xflags! {
/// Install only the language server.
optional --server
}

/// Run a simple macOS leak check against an isolated language server process.
cmd leak-check {
/// Number of open/change/close LSP cycles to run.
optional --iterations n: usize
/// Pause between cycles in milliseconds.
optional --pause-ms n: u64
/// Probe to run: lsp | cli-timeout | both
optional --probe name: String
}
}
}

Expand All @@ -32,6 +42,7 @@ pub struct Xtask {
#[derive(Debug)]
pub enum XtaskCmd {
Install(Install),
LeakCheck(LeakCheck),
}

#[derive(Debug)]
Expand All @@ -41,6 +52,13 @@ pub struct Install {
pub server: bool,
}

#[derive(Debug)]
pub struct LeakCheck {
pub iterations: Option<usize>,
pub pause_ms: Option<u64>,
pub probe: Option<String>,
}

impl Xtask {
#[allow(dead_code)]
pub fn from_env_or_exit() -> Self {
Expand Down
Loading
Loading