diff --git a/Cargo.lock b/Cargo.lock index 4077c0853..2a9a24e30 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5585,7 +5585,11 @@ version = "0.0.0" dependencies = [ "anyhow", "flate2", + "pgls_cli", + "pgls_workspace", + "serde_json", "time", + "tokio", "write-json", "xflags", "xshell", diff --git a/crates/pgls_cli/src/service/mod.rs b/crates/pgls_cli/src/service/mod.rs index c676f9d3e..917d9352e 100644 --- a/crates/pgls_cli/src/service/mod.rs +++ b/crates/pgls_cli/src/service/mod.rs @@ -54,7 +54,12 @@ pub(crate) use self::unix::{ensure_daemon, open_socket, print_socket, run_daemon /// [WorkspaceTransport] instance if the socket is currently active pub fn open_transport(runtime: Runtime) -> io::Result> { match runtime.block_on(open_socket()) { - Ok(Some((read, write))) => Ok(Some(SocketTransport::open(runtime, read, write))), + Ok(Some((read, write))) => Ok(Some(SocketTransport::open_with_timeout( + runtime, + read, + write, + DEFAULT_REQUEST_TIMEOUT, + ))), Ok(None) => Ok(None), Err(err) => Err(err), } @@ -99,8 +104,11 @@ pub struct SocketTransport { runtime: Runtime, write_send: Sender<(Vec, bool)>, pending_requests: PendingRequests, + request_timeout: Duration, } +const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(15); + /// Stores a handle to the map of pending requests, and clears the map /// automatically when the handle is dropped #[derive(Clone, Default)] @@ -131,7 +139,12 @@ impl Drop for PendingRequests { } impl SocketTransport { - pub fn open(runtime: Runtime, socket_read: R, socket_write: W) -> Self + pub fn open_with_timeout( + runtime: Runtime, + socket_read: R, + socket_write: W, + request_timeout: Duration, + ) -> Self where R: AsyncRead + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static, @@ -172,6 +185,7 @@ impl SocketTransport { runtime, write_send, pending_requests: pending_requests_2, + request_timeout, } } } @@ -185,27 +199,30 @@ impl WorkspaceTransport for SocketTransport { P: Serialize, R: DeserializeOwned, { + let 
request_id = request.id; let (send, recv) = oneshot::channel(); - - self.pending_requests.insert(request.id, send); - let is_shutdown = request.method == "pgls/shutdown"; let request = JsonRpcRequest { jsonrpc: Cow::Borrowed("2.0"), - id: request.id, + id: request_id, method: Cow::Borrowed(request.method), params: request.params, }; - let request = to_vec(&request).map_err(|err| { - TransportError::SerdeError(format!( - "failed to serialize {} into byte buffer: {err}", - type_name::
<P>
() - )) - })?; + let request = match to_vec(&request) { + Ok(request) => request, + Err(err) => { + return Err(TransportError::SerdeError(format!( + "failed to serialize {} into byte buffer: {err}", + type_name::
<P>
() + ))); + } + }; - let response = self.runtime.block_on(async move { + self.pending_requests.insert(request_id, send); + + let response = match self.runtime.block_on(async move { self.write_send .send((request, is_shutdown)) .await @@ -219,11 +236,17 @@ impl WorkspaceTransport for SocketTransport { Err(_) => Err(TransportError::ChannelClosed), } } - _ = sleep(Duration::from_secs(15)) => { + _ = sleep(self.request_timeout) => { Err(TransportError::Timeout) } } - })?; + }) { + Ok(response) => response, + Err(err) => { + self.pending_requests.remove(&request_id); + return Err(err); + } + }; let response = response.get(); let result = from_str(response).map_err(|err| { @@ -472,3 +495,86 @@ impl FromStr for TransportHeader { } } } + +#[cfg(test)] +mod tests { + use std::fmt; + use std::time::Duration; + + use pgls_workspace::TransportError; + use pgls_workspace::workspace::{TransportRequest, WorkspaceTransport}; + use serde::Serialize; + use serde::ser::{Error as SerError, Serializer}; + use serde_json::Value; + use tokio::io::{duplex, split}; + use tokio::runtime::Runtime; + + use super::SocketTransport; + + struct FailingParams; + + impl Serialize for FailingParams { + fn serialize(&self, _serializer: S) -> Result + where + S: Serializer, + { + Err(S::Error::custom("expected serialization failure")) + } + } + + impl fmt::Debug for FailingParams { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("FailingParams") + } + } + + fn disconnected_transport() -> SocketTransport { + let runtime = Runtime::new().expect("failed to create tokio runtime"); + let (stream, peer) = duplex(1024); + drop(peer); + let (read, write) = split(stream); + SocketTransport::open_with_timeout(runtime, read, write, Duration::from_millis(50)) + } + + #[test] + fn request_does_not_retain_pending_entries_when_serialization_fails() { + let transport = disconnected_transport(); + + let result: Result = transport.request(TransportRequest { + id: 1, + method: 
"pgls/get_file_content", + params: FailingParams, + }); + + assert!(matches!(result, Err(TransportError::SerdeError(_)))); + assert_eq!( + transport.pending_requests.len(), + 0, + "pending request should be cleaned up on serialization failure" + ); + } + + #[test] + fn request_does_not_retain_pending_entries_on_timeout_or_channel_close() { + let transport = disconnected_transport(); + + let result: Result = transport.request(TransportRequest { + id: 2, + method: "pgls/get_file_content", + params: (), + }); + + assert!( + matches!( + result, + Err(TransportError::Timeout | TransportError::ChannelClosed) + ), + "expected timeout or channel-closed error, got {result:?}" + ); + assert_eq!( + transport.pending_requests.len(), + 0, + "pending request should be cleaned up on timeout/channel-close" + ); + } +} diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml index fe2e0ec63..c61cdaabc 100644 --- a/xtask/Cargo.toml +++ b/xtask/Cargo.toml @@ -7,11 +7,15 @@ rust-version.workspace = true version = "0.0.0" [dependencies] -anyhow = "1.0.62" -flate2 = "1.0.24" -time = { version = "0.3", default-features = false } -write-json = "0.1.2" -xflags = "0.3.0" -xshell = "0.2.2" -zip = { version = "0.6", default-features = false, features = ["deflate", "time"] } +anyhow = "1.0.62" +flate2 = "1.0.24" +pgls_cli = { path = "../crates/pgls_cli" } +pgls_workspace = { path = "../crates/pgls_workspace" } +serde_json = "1.0.114" +time = { version = "0.3", default-features = false } +tokio = { version = "1.40.0", features = ["rt-multi-thread", "net"] } +write-json = "0.1.2" +xflags = "0.3.0" +xshell = "0.2.2" +zip = { version = "0.6", default-features = false, features = ["deflate", "time"] } # Avoid adding more dependencies to this crate diff --git a/xtask/src/flags.rs b/xtask/src/flags.rs index 4ed3da40b..6b11f06ca 100644 --- a/xtask/src/flags.rs +++ b/xtask/src/flags.rs @@ -18,6 +18,16 @@ xflags::xflags! { /// Install only the language server. 
optional --server } + + /// Run a simple macOS leak check against an isolated language server process. + cmd leak-check { + /// Number of open/change/close LSP cycles to run. + optional --iterations n: usize + /// Pause between cycles in milliseconds. + optional --pause-ms n: u64 + /// Probe to run: lsp | cli-timeout | both + optional --probe name: String + } } } @@ -32,6 +42,7 @@ pub struct Xtask { #[derive(Debug)] pub enum XtaskCmd { Install(Install), + LeakCheck(LeakCheck), } #[derive(Debug)] @@ -41,6 +52,13 @@ pub struct Install { pub server: bool, } +#[derive(Debug)] +pub struct LeakCheck { + pub iterations: Option, + pub pause_ms: Option, + pub probe: Option, +} + impl Xtask { #[allow(dead_code)] pub fn from_env_or_exit() -> Self { diff --git a/xtask/src/leak_check.rs b/xtask/src/leak_check.rs new file mode 100644 index 000000000..89cf878b9 --- /dev/null +++ b/xtask/src/leak_check.rs @@ -0,0 +1,353 @@ +use std::io::Write; +use std::process::{Child, ChildStdin, Command, Stdio}; +use std::thread; +use std::time::{Duration, Instant}; + +use anyhow::{bail, Context}; +#[cfg(unix)] +use pgls_cli::SocketTransport; +#[cfg(unix)] +use pgls_workspace::workspace::{TransportRequest, WorkspaceTransport}; +use serde_json::Value; +#[cfg(unix)] +use tokio::net::UnixStream; +use xshell::Shell; + +use crate::flags; + +const DEFAULT_ITERATIONS: usize = 200; +const DEFAULT_PAUSE_MS: u64 = 20; + +impl flags::LeakCheck { + pub(crate) fn run(self, _sh: &Shell) -> anyhow::Result<()> { + if !cfg!(target_os = "macos") { + bail!("`xtask leak-check` is currently implemented only for macOS (`leaks` tool)."); + } + + if Command::new("leaks").arg("--help").output().is_err() { + bail!( + "`leaks` not found — install Xcode Command Line Tools (`xcode-select --install`)" + ); + } + + let iterations = self.iterations.unwrap_or(DEFAULT_ITERATIONS); + let pause = Duration::from_millis(self.pause_ms.unwrap_or(DEFAULT_PAUSE_MS)); + let probe = self.probe.unwrap_or_else(|| "lsp".to_string()); + + 
match probe.as_str() { + "lsp" => run_lsp_probe(iterations, pause), + "cli-timeout" => run_cli_timeout_probe(iterations), + "both" => { + run_lsp_probe(iterations, pause)?; + run_cli_timeout_probe(iterations) + } + other => bail!("invalid --probe value `{other}` (expected: lsp | cli-timeout | both)"), + } + } +} + +fn run_lsp_probe(iterations: usize, pause: Duration) -> anyhow::Result<()> { + let status = Command::new("cargo") + .arg("build") + .arg("-p") + .arg("pgls_cli") + .status() + .context("failed to execute cargo build for pgls_cli")?; + if !status.success() { + bail!("failed to build pgls_cli binary"); + } + + let root = std::env::current_dir().context("failed to get current directory")?; + let binary = root.join("target/debug/postgres-language-server"); + + if !binary.exists() { + bail!("binary not found at {}", binary.display()); + } + + let server = ProcessGuard::spawn( + Command::new(&binary) + .arg("__run_server") + .arg("--stop-on-disconnect") + .arg("--log-level=error") + .arg("--log-kind=hierarchical") + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::null()), + "server", + )?; + + wait_for_socket(server.pid())?; + + let mut proxy = ProcessGuard::spawn( + Command::new(&binary) + .arg("lsp-proxy") + .arg("--stdio") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::null()), + "lsp-proxy", + )?; + + let mut proxy_stdin = proxy + .stdin() + .context("failed to capture lsp-proxy stdin pipe")?; + proxy.start_stdout_drain()?; + + run_lsp_churn(&mut proxy_stdin, iterations, pause)?; + + // Give background diagnostics/tasks a chance to settle. + thread::sleep(Duration::from_millis(300)); + + let leaks_output = run_leaks(server.pid())?; + let pass = leaks_output + .to_lowercase() + .contains("0 leaks for 0 total leaked bytes"); + + // Ask LSP to shutdown before cleanup. Ignore failures, cleanup guard is authoritative. 
+ let _ = send_shutdown_and_exit(&mut proxy_stdin); + drop(proxy_stdin); + + if pass { + println!( + "LEAK_CHECK[LSP]: PASS (pid={}, iterations={iterations})", + server.pid() + ); + return Ok(()); + } + + println!("LEAK_CHECK[LSP]: FAIL (pid={})", server.pid()); + println!("{leaks_output}"); + bail!("leaks reported potential leaked allocations in lsp probe"); +} + +#[cfg(unix)] +fn run_cli_timeout_probe(iterations: usize) -> anyhow::Result<()> { + let rss_start = current_rss_kb(std::process::id())?; + + let runtime = tokio::runtime::Runtime::new().context("failed to create tokio runtime")?; + let _enter = runtime.enter(); + let (stream_a, stream_b) = + UnixStream::pair().context("failed to create unix socket pair for timeout probe")?; + drop(stream_b); + + let (read, write) = stream_a.into_split(); + let transport = + SocketTransport::open_with_timeout(runtime, read, write, Duration::from_millis(2)); + + let mut channel_closed = 0usize; + let mut timed_out = 0usize; + let mut other_errors = 0usize; + + for i in 0..iterations { + let request = TransportRequest { + id: i as u64, + method: "pgls/get_file_content", + params: (), + }; + + let result: Result = transport.request(request); + match result { + Err(pgls_workspace::TransportError::ChannelClosed) => channel_closed += 1, + Err(pgls_workspace::TransportError::Timeout) => timed_out += 1, + Err(_) => other_errors += 1, + Ok(_) => {} + } + } + + // Keep transport alive until after RSS sampling so retained map growth is visible. + let rss_end = current_rss_kb(std::process::id())?; + let rss_delta = rss_end.saturating_sub(rss_start); + + println!( + "LEAK_CHECK[CLI_TIMEOUT]: requests={iterations} channel_closed={channel_closed} timed_out={timed_out} other_errors={other_errors} rss_start_kb={rss_start} rss_end_kb={rss_end} rss_delta_kb={rss_delta}" + ); + + // Heuristic threshold for a strong "likely leak/retention" signal. 
+ let fail_threshold_kb: u64 = 20_000; + if rss_delta >= fail_threshold_kb { + bail!( + "CLI timeout probe shows strong retained-memory growth (delta={rss_delta} KB >= {fail_threshold_kb} KB)" + ); + } + + Ok(()) +} + +#[cfg(not(unix))] +fn run_cli_timeout_probe(_iterations: usize) -> anyhow::Result<()> { + bail!("cli-timeout probe requires unix (UnixStream)"); +} + +struct ProcessGuard { + child: Child, + name: &'static str, + stdout_drain: Option>, +} + +impl ProcessGuard { + fn spawn(command: &mut Command, name: &'static str) -> anyhow::Result { + let child = command + .spawn() + .with_context(|| format!("failed to spawn {name}"))?; + Ok(Self { + child, + name, + stdout_drain: None, + }) + } + + fn pid(&self) -> u32 { + self.child.id() + } + + fn stdin(&mut self) -> Option { + self.child.stdin.take() + } + + fn start_stdout_drain(&mut self) -> anyhow::Result<()> { + let stdout = self + .child + .stdout + .take() + .context("failed to capture child stdout")?; + + let handle = thread::spawn(move || { + let mut reader = std::io::BufReader::new(stdout); + let mut sink = std::io::sink(); + let _ = std::io::copy(&mut reader, &mut sink); + }); + + self.stdout_drain = Some(handle); + Ok(()) + } +} + +impl Drop for ProcessGuard { + fn drop(&mut self) { + if let Ok(None) = self.child.try_wait() { + let _ = self.child.kill(); + } + let _ = self.child.wait(); + if let Some(handle) = self.stdout_drain.take() { + let _ = handle.join(); + } + let _ = self.name; + } +} + +fn wait_for_socket(pid: u32) -> anyhow::Result<()> { + let deadline = Instant::now() + Duration::from_secs(10); + while Instant::now() < deadline { + let output = Command::new("lsof") + .arg("-p") + .arg(pid.to_string()) + .output() + .context("failed to run `lsof` while waiting for socket")?; + + let combined = format!( + "{}\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + + if combined.contains("pgls-socket-") { + return Ok(()); + } + 
thread::sleep(Duration::from_millis(50)); + } + + bail!("timed out waiting for server socket to become ready"); +} + +fn run_lsp_churn(stdin: &mut ChildStdin, iterations: usize, pause: Duration) -> anyhow::Result<()> { + let uri = "file:///tmp/pgls-leak-check.sql"; + + send_lsp_json( + stdin, + serde_json::json!({"jsonrpc":"2.0","id":1,"method":"initialize","params":{"capabilities":{},"rootUri":null}}), + )?; + send_lsp_json( + stdin, + serde_json::json!({"jsonrpc":"2.0","method":"initialized","params":{}}), + )?; + + for i in 0..iterations { + let open_text = format!("select {i} as value;"); + let changed_text = format!("select {i} as value, {} as extra;", i + 1); + + send_lsp_json( + stdin, + serde_json::json!({"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"uri":uri,"languageId":"sql","version":1,"text":open_text}}}), + )?; + + send_lsp_json( + stdin, + serde_json::json!({"jsonrpc":"2.0","method":"textDocument/didChange","params":{"textDocument":{"uri":uri,"version":2},"contentChanges":[{"text":changed_text}]}}), + )?; + + send_lsp_json( + stdin, + serde_json::json!({"jsonrpc":"2.0","method":"textDocument/didClose","params":{"textDocument":{"uri":uri}}}), + )?; + + thread::sleep(pause); + } + + Ok(()) +} + +fn send_shutdown_and_exit(stdin: &mut ChildStdin) -> anyhow::Result<()> { + send_lsp_json( + stdin, + serde_json::json!({"jsonrpc":"2.0","id":2,"method":"shutdown","params":null}), + )?; + send_lsp_json( + stdin, + serde_json::json!({"jsonrpc":"2.0","method":"exit","params":null}), + )?; + Ok(()) +} + +fn send_lsp_json(stdin: &mut ChildStdin, value: Value) -> anyhow::Result<()> { + let payload = serde_json::to_string(&value).context("failed to serialize LSP message")?; + let header = format!("Content-Length: {}\r\n\r\n", payload.len()); + stdin + .write_all(header.as_bytes()) + .context("failed to write LSP header")?; + stdin + .write_all(payload.as_bytes()) + .context("failed to write LSP payload")?; + stdin.flush().context("failed 
to flush LSP payload") +} + +fn run_leaks(pid: u32) -> anyhow::Result { + let output = Command::new("leaks") + .arg(pid.to_string()) + .output() + .context("failed to run `leaks`")?; + + let combined = format!( + "{}\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + + Ok(combined) +} + +fn current_rss_kb(pid: u32) -> anyhow::Result { + let output = Command::new("ps") + .arg("-o") + .arg("rss=") + .arg("-p") + .arg(pid.to_string()) + .output() + .context("failed to run `ps` for rss sampling")?; + + let value = String::from_utf8_lossy(&output.stdout); + let trimmed = value.trim(); + let rss = trimmed + .parse::() + .with_context(|| format!("failed to parse rss value from `{trimmed}`"))?; + Ok(rss) +} diff --git a/xtask/src/main.rs b/xtask/src/main.rs index 282470849..8e5fbf77a 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -17,6 +17,7 @@ mod flags; mod install; +mod leak_check; use std::{ env, @@ -32,6 +33,7 @@ fn main() -> anyhow::Result<()> { match flags.subcommand { flags::XtaskCmd::Install(cmd) => cmd.run(sh), + flags::XtaskCmd::LeakCheck(cmd) => cmd.run(sh), } }