diff --git a/codex-rs/utils/pty/src/pipe.rs b/codex-rs/utils/pty/src/pipe.rs index 3a9b62d9b7b0..caa12974f6d7 100644 --- a/codex-rs/utils/pty/src/pipe.rs +++ b/codex-rs/utils/pty/src/pipe.rs @@ -1,38 +1,56 @@ use std::collections::HashMap; +#[cfg(not(windows))] use std::io; +#[cfg(not(windows))] use std::io::ErrorKind; use std::path::Path; +#[cfg(not(windows))] use std::process::Stdio; +#[cfg(not(windows))] use std::sync::Arc; +#[cfg(not(windows))] use std::sync::Mutex as StdMutex; +#[cfg(not(windows))] use std::sync::atomic::AtomicBool; use anyhow::Result; +#[cfg(not(windows))] use tokio::io::AsyncRead; +#[cfg(not(windows))] use tokio::io::AsyncReadExt; +#[cfg(not(windows))] use tokio::io::AsyncWriteExt; +#[cfg(not(windows))] use tokio::io::BufReader; +#[cfg(not(windows))] use tokio::process::Command; +#[cfg(not(windows))] use tokio::sync::mpsc; +#[cfg(not(windows))] use tokio::sync::oneshot; +#[cfg(not(windows))] use tokio::task::JoinHandle; +#[cfg(not(windows))] use crate::process::ChildTerminator; +#[cfg(not(windows))] use crate::process::ProcessHandle; +#[cfg(not(windows))] use crate::process::ProcessSignal; use crate::process::SpawnedProcess; +#[cfg(not(windows))] use crate::process::exit_code_from_status; #[cfg(target_os = "linux")] use libc; +#[cfg(not(windows))] struct PipeChildTerminator { - #[cfg(windows)] - pid: u32, #[cfg(unix)] process_group_id: u32, } +#[cfg(not(windows))] impl ChildTerminator for PipeChildTerminator { fn signal(&mut self, signal: ProcessSignal) -> io::Result<()> { match signal { @@ -56,11 +74,6 @@ impl ChildTerminator for PipeChildTerminator { crate::process_group::kill_process_group(self.process_group_id) } - #[cfg(windows)] - { - kill_process(self.pid) - } - #[cfg(not(any(unix, windows)))] { Ok(()) @@ -68,24 +81,7 @@ impl ChildTerminator for PipeChildTerminator { } } -#[cfg(windows)] -fn kill_process(pid: u32) -> io::Result<()> { - unsafe { - let handle = winapi::um::processthreadsapi::OpenProcess( - winapi::um::winnt::PROCESS_TERMINATE, - 0, - pid, - ); - if handle.is_null() { - return Err(io::Error::last_os_error()); - } - let success = winapi::um::processthreadsapi::TerminateProcess(handle, 1); - let err = io::Error::last_os_error(); - winapi::um::handleapi::CloseHandle(handle); - if success == 0 { Err(err) } else { Ok(()) } - } -} - +#[cfg(not(windows))] async fn read_output_stream(mut reader: R, output_tx: mpsc::Sender>) where R: AsyncRead + Unpin, @@ -103,12 +99,14 @@ where } } +#[cfg(not(windows))] #[derive(Clone, Copy)] enum PipeStdinMode { Piped, Null, } +#[cfg(not(windows))] async fn spawn_process_with_stdin_mode( program: &str, args: &[String], @@ -240,8 +238,6 @@ async fn spawn_process_with_stdin_mode( let handle = ProcessHandle::new( writer_tx, Box::new(PipeChildTerminator { - #[cfg(windows)] - pid, #[cfg(unix)] process_group_id, }), @@ -271,7 +267,16 @@ pub async fn spawn_process( env: &HashMap, arg0: &Option, ) -> Result { - spawn_process_with_stdin_mode(program, args, cwd, env, arg0, PipeStdinMode::Piped, &[]).await + #[cfg(windows)] + { + let _ = arg0; + crate::win::pipe::spawn_process(program, args, cwd, env).await + } + #[cfg(not(windows))] + { + spawn_process_with_stdin_mode(program, args, cwd, env, arg0, PipeStdinMode::Piped, &[]) + .await + } } /// Spawn a process using regular pipes, but close stdin immediately. @@ -295,14 +300,22 @@ pub async fn spawn_process_no_stdin_with_inherited_fds( arg0: &Option, inherited_fds: &[i32], ) -> Result { - spawn_process_with_stdin_mode( - program, - args, - cwd, - env, - arg0, - PipeStdinMode::Null, - inherited_fds, - ) - .await + #[cfg(windows)] + { + let _ = (arg0, inherited_fds); + crate::win::pipe::spawn_process_no_stdin(program, args, cwd, env).await + } + #[cfg(not(windows))] + { + spawn_process_with_stdin_mode( + program, + args, + cwd, + env, + arg0, + PipeStdinMode::Null, + inherited_fds, + ) + .await + } } diff --git a/codex-rs/utils/pty/src/process.rs b/codex-rs/utils/pty/src/process.rs index fdd693f43d6c..307d9e2f0050 100644 --- a/codex-rs/utils/pty/src/process.rs +++ b/codex-rs/utils/pty/src/process.rs @@ -2,6 +2,7 @@ use core::fmt; use std::io; #[cfg(unix)] use std::os::fd::RawFd; +#[cfg(not(windows))] use std::process::ExitStatus; use std::sync::Arc; use std::sync::Mutex as StdMutex; @@ -32,6 +33,7 @@ pub(crate) fn unsupported_signal(signal: ProcessSignal) -> io::Error { } } +#[cfg(not(windows))] pub(crate) fn exit_code_from_status(status: ExitStatus) -> i32 { if let Some(code) = status.code() { return code; diff --git a/codex-rs/utils/pty/src/tests.rs b/codex-rs/utils/pty/src/tests.rs index 7da410e3e467..41241093a321 100644 --- a/codex-rs/utils/pty/src/tests.rs +++ b/codex-rs/utils/pty/src/tests.rs @@ -28,6 +28,10 @@ mod windows_job_test_support; #[path = "windows_conpty_job_tests.rs"] mod windows_conpty_job_tests; +#[cfg(windows)] +#[path = "windows_pipe_job_tests.rs"] +mod windows_pipe_job_tests; + fn find_python() -> Option { for candidate in ["python3", "python"] { if let Ok(output) = std::process::Command::new(candidate) diff --git a/codex-rs/utils/pty/src/win/mod.rs b/codex-rs/utils/pty/src/win/mod.rs index d91152e5db99..954ce06c631d 100644 --- a/codex-rs/utils/pty/src/win/mod.rs +++ b/codex-rs/utils/pty/src/win/mod.rs @@ -40,10 +40,10 @@ use winapi::um::synchapi::WaitForSingleObject; use winapi::um::winbase::INFINITE; use winapi::um::winbase::WAIT_OBJECT_0; -#[cfg(test)] mod command; pub(crate) mod conpty; mod job; +pub(crate) mod pipe; mod procthreadattr; mod psuedocon; diff --git a/codex-rs/utils/pty/src/win/pipe.rs b/codex-rs/utils/pty/src/win/pipe.rs new file mode 100644 index 000000000000..b991b8eafae4 --- /dev/null +++ b/codex-rs/utils/pty/src/win/pipe.rs @@ -0,0 +1,309 @@ +use std::collections::HashMap; +use std::fs::File; +use std::io; +use std::io::ErrorKind; +use std::io::Read; +use std::io::Write; +use std::mem; +use std::os::windows::io::AsRawHandle; +use std::os::windows::io::FromRawHandle; +use std::os::windows::io::OwnedHandle; +use std::path::Path; +use std::ptr; +use std::sync::Arc; +use std::sync::Mutex as StdMutex; +use std::sync::atomic::AtomicBool; + +use anyhow::Context as _; +use anyhow::Result; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio::task::JoinHandle; +use winapi::shared::minwindef::FALSE; +use winapi::shared::minwindef::TRUE; +use winapi::um::handleapi::SetHandleInformation; +use winapi::um::minwinbase::SECURITY_ATTRIBUTES; +use winapi::um::namedpipeapi::CreatePipe; +use winapi::um::processthreadsapi::CreateProcessW; +use winapi::um::processthreadsapi::GetExitCodeProcess; +use winapi::um::processthreadsapi::PROCESS_INFORMATION; +use winapi::um::synchapi::WaitForSingleObject; +use winapi::um::winbase::CREATE_SUSPENDED; +use winapi::um::winbase::CREATE_UNICODE_ENVIRONMENT; +use winapi::um::winbase::EXTENDED_STARTUPINFO_PRESENT; +use winapi::um::winbase::HANDLE_FLAG_INHERIT; +use winapi::um::winbase::INFINITE; +use winapi::um::winbase::STARTF_USESTDHANDLES; +use winapi::um::winbase::STARTUPINFOEXW; +use winapi::um::winbase::WAIT_OBJECT_0; +use winapi::um::winnt::HANDLE; + +use super::KillOnCloseJob; +use super::SuspendedProcess; +use super::command::prepare_command; +use super::procthreadattr::ProcThreadAttributeList; +use crate::process::ChildTerminator; +use crate::process::ProcessHandle; +use crate::process::ProcessSignal; +use crate::process::SpawnedProcess; + +struct JobTerminator { + controller: KillOnCloseJob, +} + +impl ChildTerminator for JobTerminator { + fn signal(&mut self, signal: ProcessSignal) -> io::Result<()> { + Err(crate::process::unsupported_signal(signal)) + } + + fn kill(&mut self) -> io::Result<()> { + self.controller.terminate_and_close(/*exit_code*/ 1) + } +} + +#[derive(Clone, Copy)] +enum StdinMode { + Piped, + Null, +} + +enum ParentPipeEnd { + Reads, + Writes, +} + +struct PipeEnds { + parent: OwnedHandle, + child: OwnedHandle, +} + +pub(crate) async fn spawn_process( + program: &str, + args: &[String], + cwd: &Path, + env: &HashMap, +) -> Result { + spawn_process_with_stdin_mode(program, args, cwd, env, StdinMode::Piped).await +} + +pub(crate) async fn spawn_process_no_stdin( + program: &str, + args: &[String], + cwd: &Path, + env: &HashMap, +) -> Result { + spawn_process_with_stdin_mode(program, args, cwd, env, StdinMode::Null).await +} + +async fn spawn_process_with_stdin_mode( + program: &str, + args: &[String], + cwd: &Path, + env: &HashMap, + stdin_mode: StdinMode, +) -> Result { + let mut command = prepare_command(program, args, cwd, env)?; + let stdin = create_pipe(ParentPipeEnd::Writes).context("failed to create stdin pipe")?; + let stdout = create_pipe(ParentPipeEnd::Reads).context("failed to create stdout pipe")?; + let stderr = create_pipe(ParentPipeEnd::Reads).context("failed to create stderr pipe")?; + + let mut startup: STARTUPINFOEXW = unsafe { mem::zeroed() }; + startup.StartupInfo.cb = mem::size_of::() as u32; + startup.StartupInfo.dwFlags = STARTF_USESTDHANDLES; + startup.StartupInfo.hStdInput = stdin.child.as_raw_handle().cast(); + startup.StartupInfo.hStdOutput = stdout.child.as_raw_handle().cast(); + startup.StartupInfo.hStdError = stderr.child.as_raw_handle().cast(); + + let mut child_handles = [ + stdin.child.as_raw_handle().cast(), + stdout.child.as_raw_handle().cast(), + stderr.child.as_raw_handle().cast(), + ]; + let mut attributes = ProcThreadAttributeList::with_capacity(/*num_attributes*/ 1)?; + attributes.set_handle_list(&mut child_handles)?; + startup.lpAttributeList = attributes.as_mut_ptr(); + + let job = KillOnCloseJob::new().context("failed to create process job")?; + let mut process_information: PROCESS_INFORMATION = unsafe { mem::zeroed() }; + let created = unsafe { + CreateProcessW( + command.application.as_ptr(), + command.command_line.as_mut_ptr(), + ptr::null_mut(), + ptr::null_mut(), + TRUE, + CREATE_SUSPENDED | CREATE_UNICODE_ENVIRONMENT | EXTENDED_STARTUPINFO_PRESENT, + command.environment.as_mut_ptr().cast(), + command.current_directory.as_ptr(), + &mut startup.StartupInfo, + &mut process_information, + ) + }; + if created == FALSE { + return Err(io::Error::last_os_error()).context("failed to create pipe child process"); + } + + let suspended = unsafe { + SuspendedProcess::from_raw_handles( + process_information.hProcess.cast(), + process_information.hThread.cast(), + process_information.dwProcessId, + ) + }; + drop(stdin.child); + drop(stdout.child); + drop(stderr.child); + let stdin_parent = match stdin_mode { + StdinMode::Piped => Some(stdin.parent), + StdinMode::Null => { + drop(stdin.parent); + None + } + }; + let process = suspended + .assign_and_resume(job) + .context("failed to contain and resume pipe child process")?; + let controller = process.controller(); + + let (writer_tx, writer_rx) = mpsc::channel::>(128); + let (stdout_tx, stdout_rx) = mpsc::channel::>(128); + let (stderr_tx, stderr_rx) = mpsc::channel::>(128); + + let writer_handle = match stdin_parent { + Some(stdin_parent) => spawn_writer(File::from(stdin_parent), writer_rx), + None => { + drop(writer_rx); + tokio::spawn(async {}) + } + }; + let stdout_reader = spawn_reader(File::from(stdout.parent), stdout_tx); + let stderr_reader = spawn_reader(File::from(stderr.parent), stderr_tx); + let reader_abort_handles = vec![stdout_reader.abort_handle(), stderr_reader.abort_handle()]; + let reader_handle = tokio::spawn(async move { + let _ = stdout_reader.await; + let _ = stderr_reader.await; + }); + + let (exit_tx, exit_rx) = oneshot::channel::(); + let exit_status = Arc::new(AtomicBool::new(false)); + let wait_exit_status = Arc::clone(&exit_status); + let exit_code = Arc::new(StdMutex::new(None)); + let wait_exit_code = Arc::clone(&exit_code); + let wait_handle: JoinHandle<()> = tokio::task::spawn_blocking(move || { + let code = wait_for_process(&process).unwrap_or(-1); + let _ = process.controller().close(); + wait_exit_status.store(true, std::sync::atomic::Ordering::SeqCst); + if let Ok(mut guard) = wait_exit_code.lock() { + *guard = Some(code); + } + let _ = exit_tx.send(code); + }); + + let handle = ProcessHandle::new( + writer_tx, + Box::new(JobTerminator { controller }), + reader_handle, + reader_abort_handles, + writer_handle, + wait_handle, + exit_status, + exit_code, + /*pty_handles*/ None, + /*resizer*/ None, + ); + + Ok(SpawnedProcess { + session: handle, + stdout_rx, + stderr_rx, + exit_rx, + }) +} + +fn create_pipe(parent_end: ParentPipeEnd) -> io::Result { + let mut attributes: SECURITY_ATTRIBUTES = unsafe { mem::zeroed() }; + attributes.nLength = mem::size_of::() as u32; + attributes.bInheritHandle = TRUE; + let mut read_handle: HANDLE = ptr::null_mut(); + let mut write_handle: HANDLE = ptr::null_mut(); + let created = unsafe { + CreatePipe( + &mut read_handle, + &mut write_handle, + &mut attributes, + /*nSize*/ 0, + ) + }; + if created == FALSE { + return Err(io::Error::last_os_error()); + } + + let read_handle = unsafe { OwnedHandle::from_raw_handle(read_handle.cast()) }; + let write_handle = unsafe { OwnedHandle::from_raw_handle(write_handle.cast()) }; + let (parent, child) = match parent_end { + ParentPipeEnd::Reads => (read_handle, write_handle), + ParentPipeEnd::Writes => (write_handle, read_handle), + }; + let result = unsafe { + SetHandleInformation( + parent.as_raw_handle().cast(), + HANDLE_FLAG_INHERIT, + /*dwFlags*/ 0, + ) + }; + if result == FALSE { + return Err(io::Error::last_os_error()); + } + + Ok(PipeEnds { parent, child }) +} + +fn spawn_writer(mut writer: File, mut input_rx: mpsc::Receiver>) -> JoinHandle<()> { + tokio::task::spawn_blocking(move || { + while let Some(bytes) = input_rx.blocking_recv() { + if writer + .write_all(&bytes) + .and_then(|()| writer.flush()) + .is_err() + { + break; + } + } + }) +} + +fn spawn_reader(mut reader: File, output_tx: mpsc::Sender>) -> JoinHandle<()> { + tokio::task::spawn_blocking(move || { + let mut buffer = vec![0u8; 8_192]; + loop { + match reader.read(&mut buffer) { + Ok(0) => break, + Ok(bytes_read) => { + if output_tx + .blocking_send(buffer[..bytes_read].to_vec()) + .is_err() + { + break; + } + } + Err(error) if error.kind() == ErrorKind::Interrupted => continue, + Err(_) => break, + } + } + }) +} + +fn wait_for_process(process: &super::JobProcess) -> io::Result { + let wait_result = unsafe { WaitForSingleObject(process.as_raw_handle().cast(), INFINITE) }; + if wait_result != WAIT_OBJECT_0 { + return Err(io::Error::last_os_error()); + } + + let mut exit_code = 0; + let result = unsafe { GetExitCodeProcess(process.as_raw_handle().cast(), &mut exit_code) }; + if result == FALSE { + Err(io::Error::last_os_error()) + } else { + Ok(exit_code as i32) + } +} diff --git a/codex-rs/utils/pty/src/win/procthreadattr.rs b/codex-rs/utils/pty/src/win/procthreadattr.rs index c7cf68fed9aa..acca482c86b3 100644 --- a/codex-rs/utils/pty/src/win/procthreadattr.rs +++ b/codex-rs/utils/pty/src/win/procthreadattr.rs @@ -26,7 +26,9 @@ use std::mem; use std::ptr; use winapi::shared::minwindef::DWORD; use winapi::um::processthreadsapi::*; +use winapi::um::winnt::HANDLE; +const PROC_THREAD_ATTRIBUTE_HANDLE_LIST: usize = 0x00020002; const PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE: usize = 0x00020016; pub struct ProcThreadAttributeList { @@ -82,6 +84,26 @@ impl ProcThreadAttributeList { ); Ok(()) } + + pub fn set_handle_list(&mut self, handles: &mut [HANDLE]) -> Result<(), Error> { + let res = unsafe { + UpdateProcThreadAttribute( + self.as_mut_ptr(), + 0, + PROC_THREAD_ATTRIBUTE_HANDLE_LIST, + handles.as_mut_ptr().cast(), + mem::size_of_val(handles), + ptr::null_mut(), + ptr::null_mut(), + ) + }; + ensure!( + res != 0, + "UpdateProcThreadAttribute for handle list failed: {}", + IoError::last_os_error() + ); + Ok(()) + } } impl Drop for ProcThreadAttributeList { diff --git a/codex-rs/utils/pty/src/windows_pipe_job_tests.rs b/codex-rs/utils/pty/src/windows_pipe_job_tests.rs new file mode 100644 index 000000000000..010393622dcd --- /dev/null +++ b/codex-rs/utils/pty/src/windows_pipe_job_tests.rs @@ -0,0 +1,99 @@ +use super::collect_split_output; +use super::windows_job_test_support::TestDirectory; +use super::windows_job_test_support::wait_for_path; +use super::windows_job_test_support::write_descendant_scripts; +use crate::SpawnedProcess; +use crate::spawn_pipe_process_no_stdin; +use std::collections::HashMap; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn raw_pipe_root_exit_closes_job_and_inherited_output() -> anyhow::Result<()> { + let directory = TestDirectory::new("pipe-root-exit")?; + let (root, ready, escaped) = write_descendant_scripts(&directory, true)?; + let env: HashMap = std::env::vars().collect(); + let spawned = spawn_pipe_process_no_stdin( + root.to_string_lossy().as_ref(), + &[], + &directory.path, + &env, + &None, + ) + .await?; + let SpawnedProcess { + session: _session, + stdout_rx, + stderr_rx, + exit_rx, + } = spawned; + let stdout_task = tokio::spawn(collect_split_output(stdout_rx)); + let stderr_task = tokio::spawn(collect_split_output(stderr_rx)); + let timeout = tokio::time::Duration::from_secs(10); + let exit_code = tokio::time::timeout(timeout, exit_rx).await??; + let stdout = tokio::time::timeout(timeout, stdout_task).await??; + let _stderr = tokio::time::timeout(timeout, stderr_task).await??; + + assert_eq!(exit_code, 37); + assert!(ready.exists()); + assert!(!escaped.exists()); + assert!(String::from_utf8_lossy(&stdout).contains("inherited-grandchild-ready")); + assert!(!String::from_utf8_lossy(&stdout).contains("inherited-grandchild-escaped")); + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn raw_pipe_explicit_termination_kills_descendants() -> anyhow::Result<()> { + let directory = TestDirectory::new("pipe-terminate")?; + let (root, ready, escaped) = write_descendant_scripts(&directory, false)?; + let env: HashMap = std::env::vars().collect(); + let spawned = spawn_pipe_process_no_stdin( + root.to_string_lossy().as_ref(), + &[], + &directory.path, + &env, + &None, + ) + .await?; + let SpawnedProcess { + session, + stdout_rx, + stderr_rx, + exit_rx, + } = spawned; + let stdout_task = tokio::spawn(collect_split_output(stdout_rx)); + let stderr_task = tokio::spawn(collect_split_output(stderr_rx)); + wait_for_path(&ready).await?; + session.request_terminate(); + let timeout = tokio::time::Duration::from_secs(10); + let _exit_code = tokio::time::timeout(timeout, exit_rx).await??; + let stdout = tokio::time::timeout(timeout, stdout_task).await??; + let _stderr = tokio::time::timeout(timeout, stderr_task).await??; + + assert!(!escaped.exists()); + assert!(String::from_utf8_lossy(&stdout).contains("inherited-grandchild-ready")); + assert!(!String::from_utf8_lossy(&stdout).contains("inherited-grandchild-escaped")); + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn dropping_raw_pipe_session_kills_descendants() -> anyhow::Result<()> { + let directory = TestDirectory::new("pipe-drop")?; + let (root, ready, escaped) = write_descendant_scripts(&directory, false)?; + let env: HashMap = std::env::vars().collect(); + let spawned = spawn_pipe_process_no_stdin( + root.to_string_lossy().as_ref(), + &[], + &directory.path, + &env, + &None, + ) + .await?; + wait_for_path(&ready).await?; + drop(spawned.session); + drop(spawned.stdout_rx); + drop(spawned.stderr_rx); + drop(spawned.exit_rx); + tokio::time::sleep(tokio::time::Duration::from_secs(4)).await; + + assert!(!escaped.exists()); + Ok(()) +}