diff --git a/codex-rs/core/src/exec.rs b/codex-rs/core/src/exec.rs index fd5cd7bcdc6c..2737768cf39e 100644 --- a/codex-rs/core/src/exec.rs +++ b/codex-rs/core/src/exec.rs @@ -222,6 +222,17 @@ impl ExecExpiration { } } + #[cfg_attr(not(target_os = "windows"), allow(dead_code))] + pub(crate) fn cancellation_token(&self) -> Option { + match self { + ExecExpiration::Timeout(_) | ExecExpiration::DefaultTimeout => None, + ExecExpiration::Cancellation(cancellation) + | ExecExpiration::TimeoutOrCancellation { cancellation, .. } => { + Some(cancellation.clone()) + } + } + } + pub(crate) fn with_cancellation(self, cancellation: CancellationToken) -> Self { match self { ExecExpiration::Timeout(timeout) => ExecExpiration::TimeoutOrCancellation { @@ -581,13 +592,21 @@ async fn exec_windows_sandbox( network.apply_to_env(&mut env); } - // TODO(iceweasel-oai): run_windows_sandbox_capture should support all - // variants of ExecExpiration, not just timeout. + // Windows sandbox capture still receives timeout and cancellation separately. let timeout_ms = if capture_policy.uses_expiration() { expiration.timeout_ms() } else { None }; + let cancellation = if capture_policy.uses_expiration() { + expiration.cancellation_token().map(|token| { + codex_windows_sandbox::WindowsSandboxCancellationToken::new(move || { + token.is_cancelled() + }) + }) + } else { + None + }; let policy_str = serde_json::to_string(sandbox_policy).map_err(|err| { CodexErr::Io(io::Error::other(format!( @@ -639,6 +658,7 @@ async fn exec_windows_sandbox( cwd: &cwd, env_map: env, timeout_ms, + cancellation, use_private_desktop: windows_sandbox_private_desktop, proxy_enforced, read_roots_override: elevated_read_roots_override.as_deref(), @@ -657,6 +677,7 @@ async fn exec_windows_sandbox( &cwd, env, timeout_ms, + cancellation, &additional_deny_write_paths, windows_sandbox_private_desktop, ) diff --git a/codex-rs/windows-sandbox-rs/src/elevated_impl.rs b/codex-rs/windows-sandbox-rs/src/elevated_impl.rs index 2c1c4f79cec6..3d7bdbfb1596 100644 --- a/codex-rs/windows-sandbox-rs/src/elevated_impl.rs +++ b/codex-rs/windows-sandbox-rs/src/elevated_impl.rs @@ -10,6 +10,7 @@ pub struct ElevatedSandboxCaptureRequest<'a> { pub cwd: &'a Path, pub env_map: HashMap, pub timeout_ms: Option, + pub cancellation: Option, pub use_private_desktop: bool, pub proxy_enforced: bool, pub read_roots_override: Option<&'a [PathBuf]>, @@ -26,11 +27,14 @@ mod windows_impl { use crate::env::inherit_path_env; use crate::env::normalize_null_device_env; use crate::identity::require_logon_sandbox_creds; + use crate::ipc_framed::EmptyPayload; + use crate::ipc_framed::FramedMessage; use crate::ipc_framed::Message; use crate::ipc_framed::OutputStream; use crate::ipc_framed::SpawnRequest; use crate::ipc_framed::decode_bytes; use crate::ipc_framed::read_frame; + use crate::ipc_framed::write_frame; use crate::logging::log_failure; use crate::logging::log_start; use crate::logging::log_success; @@ -40,8 +44,13 @@ mod windows_impl { use crate::token::convert_string_sid_to_sid; use anyhow::Result; use std::collections::HashMap; + use std::fs::File; use std::path::Path; use std::path::PathBuf; + use std::sync::Arc; + use std::sync::atomic::AtomicBool; + use std::sync::atomic::Ordering; + use std::time::Duration; /// Ensures the parent directory of a path exists before writing to it. /// Walks upward from `start` to locate the git worktree root, following gitfile redirects. @@ -106,6 +115,36 @@ mod windows_impl { pub use crate::windows_impl::CaptureResult; + fn spawn_cancel_writer( + pipe_write: &File, + cancellation: Option, + ) -> Result, Arc)>> { + let Some(cancellation) = cancellation else { + return Ok(None); + }; + let mut pipe_write = pipe_write.try_clone()?; + let done = Arc::new(AtomicBool::new(false)); + let done_for_thread = Arc::clone(&done); + let handle = std::thread::spawn(move || { + while !done_for_thread.load(Ordering::SeqCst) { + if cancellation.is_cancelled() { + let _ = write_frame( + &mut pipe_write, + &FramedMessage { + version: 1, + message: Message::Terminate { + payload: EmptyPayload::default(), + }, + }, + ); + break; + } + std::thread::park_timeout(Duration::from_millis(50)); + } + }); + Ok(Some((handle, done))) + } + /// Launches the command runner under the sandbox user and captures its output. #[allow(clippy::too_many_arguments)] pub fn run_windows_sandbox_capture( @@ -119,6 +158,7 @@ mod windows_impl { cwd, mut env_map, timeout_ms, + cancellation, use_private_desktop, proxy_enforced, read_roots_override, @@ -206,33 +246,45 @@ mod windows_impl { spawn_request, )?; let (pipe_write, mut pipe_read) = transport.into_files(); - drop(pipe_write); + let cancel_writer = spawn_cancel_writer(&pipe_write, cancellation)?; let mut stdout = Vec::new(); let mut stderr = Vec::new(); - let (exit_code, timed_out) = loop { - let msg = read_frame(&mut pipe_read)? - .ok_or_else(|| anyhow::anyhow!("runner pipe closed before exit"))?; + let result = loop { + let msg = match read_frame(&mut pipe_read) { + Ok(Some(msg)) => msg, + Ok(None) => break Err(anyhow::anyhow!("runner pipe closed before exit")), + Err(err) => break Err(err), + }; match msg.message { Message::SpawnReady { .. } => {} - Message::Output { payload } => { - let bytes = decode_bytes(&payload.data_b64)?; - match payload.stream { + Message::Output { payload } => match decode_bytes(&payload.data_b64) { + Ok(bytes) => match payload.stream { OutputStream::Stdout => stdout.extend_from_slice(&bytes), OutputStream::Stderr => stderr.extend_from_slice(&bytes), + }, + Err(err) => { + break Err(err); } - } - Message::Exit { payload } => break (payload.exit_code, payload.timed_out), + }, + Message::Exit { payload } => break Ok((payload.exit_code, payload.timed_out)), Message::Error { payload } => { - return Err(anyhow::anyhow!("runner error: {}", payload.message)); + break Err(anyhow::anyhow!("runner error: {}", payload.message)); } other => { - return Err(anyhow::anyhow!( + break Err(anyhow::anyhow!( "unexpected runner message during capture: {other:?}" )); } } }; + if let Some((cancel_handle, done)) = cancel_writer { + done.store(true, Ordering::SeqCst); + cancel_handle.thread().unpark(); + let _ = cancel_handle.join(); + } + drop(pipe_write); + let (exit_code, timed_out) = result?; if exit_code == 0 { log_success(&command, logs_base_dir); diff --git a/codex-rs/windows-sandbox-rs/src/lib.rs b/codex-rs/windows-sandbox-rs/src/lib.rs index fc68f1ab4498..d607e199bff6 100644 --- a/codex-rs/windows-sandbox-rs/src/lib.rs +++ b/codex-rs/windows-sandbox-rs/src/lib.rs @@ -5,6 +5,36 @@ #[cfg(any(target_os = "windows", test))] mod ssh_config_dependencies; +use std::fmt; +use std::sync::Arc; + +/// Cancellation hook used by Windows sandbox capture backends. +#[derive(Clone)] +pub struct WindowsSandboxCancellationToken { + is_cancelled: Arc bool + Send + Sync>, +} + +impl WindowsSandboxCancellationToken { + /// Creates a token backed by a cancellation predicate. + pub fn new(is_cancelled: impl Fn() -> bool + Send + Sync + 'static) -> Self { + Self { + is_cancelled: Arc::new(is_cancelled), + } + } + + /// Returns whether the caller has requested cancellation. + pub fn is_cancelled(&self) -> bool { + (self.is_cancelled)() + } +} + +impl fmt::Debug for WindowsSandboxCancellationToken { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WindowsSandboxCancellationToken") + .finish_non_exhaustive() + } +} + macro_rules! windows_modules { ($($name:ident),+ $(,)?) => { $(#[cfg(target_os = "windows")] mod $name;)+ @@ -241,6 +271,7 @@ pub use stub::run_windows_sandbox_legacy_preflight; #[cfg(target_os = "windows")] mod windows_impl { + use super::WindowsSandboxCancellationToken; use super::acl::add_allow_ace; use super::acl::add_deny_write_ace; use super::acl::allow_null_device; @@ -266,6 +297,8 @@ mod windows_impl { use std::path::Path; use std::path::PathBuf; use std::ptr; + use std::time::Duration; + use std::time::Instant; use windows_sys::Win32::Foundation::CloseHandle; use windows_sys::Win32::Foundation::GetLastError; use windows_sys::Win32::Foundation::HANDLE; @@ -278,6 +311,50 @@ mod windows_impl { type PipeHandles = ((HANDLE, HANDLE), (HANDLE, HANDLE), (HANDLE, HANDLE)); + enum WaitOutcome { + Exited, + TimedOut, + Cancelled, + } + + fn wait_for_process( + process: HANDLE, + timeout_ms: Option, + cancellation: Option<&WindowsSandboxCancellationToken>, + ) -> WaitOutcome { + let Some(cancellation) = cancellation else { + let timeout = timeout_ms.map(|ms| ms as u32).unwrap_or(INFINITE); + let res = unsafe { WaitForSingleObject(process, timeout) }; + return if res == 0x0000_0102 { + WaitOutcome::TimedOut + } else { + WaitOutcome::Exited + }; + }; + + let deadline = timeout_ms.map(|ms| Instant::now() + Duration::from_millis(ms)); + loop { + if cancellation.is_cancelled() { + return WaitOutcome::Cancelled; + } + let wait_ms = match deadline { + Some(deadline) => { + let remaining = deadline.saturating_duration_since(Instant::now()); + if remaining.is_zero() { + return WaitOutcome::TimedOut; + } + remaining.min(Duration::from_millis(50)).as_millis() as u32 + } + None => 50, + }; + let res = unsafe { WaitForSingleObject(process, wait_ms) }; + if res == 0x0000_0102 { + continue; + } + return WaitOutcome::Exited; + } + } + unsafe fn setup_stdio_pipes() -> io::Result { let mut in_r: HANDLE = 0; let mut in_w: HANDLE = 0; @@ -322,6 +399,7 @@ mod windows_impl { cwd: &Path, env_map: HashMap, timeout_ms: Option, + cancellation: Option, use_private_desktop: bool, ) -> Result { run_windows_sandbox_capture_with_extra_deny_write_paths( @@ -332,6 +410,7 @@ mod windows_impl { cwd, env_map, timeout_ms, + cancellation, &[], use_private_desktop, ) @@ -346,6 +425,7 @@ mod windows_impl { cwd: &Path, mut env_map: HashMap, timeout_ms: Option, + cancellation: Option, additional_deny_write_paths: &[PathBuf], use_private_desktop: bool, ) -> Result { @@ -539,11 +619,11 @@ mod windows_impl { let _ = tx_err.send(buf); }); - let timeout = timeout_ms.map(|ms| ms as u32).unwrap_or(INFINITE); - let res = unsafe { WaitForSingleObject(pi.hProcess, timeout) }; - let timed_out = res == 0x0000_0102; + let wait_outcome = wait_for_process(pi.hProcess, timeout_ms, cancellation.as_ref()); + let timed_out = matches!(wait_outcome, WaitOutcome::TimedOut); + let cancelled = matches!(wait_outcome, WaitOutcome::Cancelled); let mut exit_code_u32: u32 = 1; - if !timed_out { + if !timed_out && !cancelled { unsafe { GetExitCodeProcess(pi.hProcess, &mut exit_code_u32); } @@ -676,6 +756,7 @@ mod windows_impl { #[cfg(not(target_os = "windows"))] mod stub { + use super::WindowsSandboxCancellationToken; use anyhow::Result; use anyhow::bail; use codex_protocol::protocol::SandboxPolicy; @@ -699,6 +780,7 @@ mod stub { _cwd: &Path, _env_map: HashMap, _timeout_ms: Option, + _cancellation: Option, _use_private_desktop: bool, ) -> Result { bail!("Windows sandbox is only available on Windows") diff --git a/codex-rs/windows-sandbox-rs/src/unified_exec/tests.rs b/codex-rs/windows-sandbox-rs/src/unified_exec/tests.rs index b0530a4fb465..d3e777e51dd2 100644 --- a/codex-rs/windows-sandbox-rs/src/unified_exec/tests.rs +++ b/codex-rs/windows-sandbox-rs/src/unified_exec/tests.rs @@ -1,6 +1,7 @@ #![cfg(target_os = "windows")] use super::spawn_windows_sandbox_session_legacy; +use crate::WindowsSandboxCancellationToken; use crate::ipc_framed::Message; use crate::ipc_framed::decode_bytes; use crate::ipc_framed::read_frame; @@ -14,8 +15,10 @@ use std::io::Seek; use std::io::SeekFrom; use std::path::Path; use std::path::PathBuf; +use std::sync::Arc; use std::sync::Mutex; use std::sync::MutexGuard; +use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicU64; use std::sync::atomic::Ordering; use std::time::Duration; @@ -379,6 +382,7 @@ fn legacy_capture_powershell_emits_output() { cwd.as_path(), HashMap::new(), Some(10_000), + /*cancellation*/ None, /*use_private_desktop*/ true, ) .expect("run legacy capture powershell"); @@ -394,6 +398,54 @@ fn legacy_capture_powershell_emits_output() { ); } +#[test] +fn legacy_capture_cancellation_is_not_reported_as_timeout() { + let Some(pwsh) = pwsh_path() else { + return; + }; + let cwd = sandbox_cwd(); + let codex_home = sandbox_home("legacy-capture-cancel"); + let cancelled = Arc::new(AtomicBool::new(false)); + let cancelled_for_token = Arc::clone(&cancelled); + let cancellation = + WindowsSandboxCancellationToken::new(move || cancelled_for_token.load(Ordering::SeqCst)); + let cancelled_for_thread = Arc::clone(&cancelled); + let cancel_thread = std::thread::spawn(move || { + std::thread::sleep(Duration::from_millis(200)); + cancelled_for_thread.store(true, Ordering::SeqCst); + }); + + let started_at = Instant::now(); + let result = run_windows_sandbox_capture( + "workspace-write", + cwd.as_path(), + codex_home.path(), + vec![ + pwsh.display().to_string(), + "-NoProfile".to_string(), + "-Command".to_string(), + "Start-Sleep -Seconds 30".to_string(), + ], + cwd.as_path(), + HashMap::new(), + Some(30_000), + Some(cancellation), + /*use_private_desktop*/ true, + ) + .expect("run legacy capture powershell with cancellation"); + cancel_thread.join().expect("cancel thread should finish"); + + assert!( + started_at.elapsed() < Duration::from_secs(10), + "cancellation should end capture before the timeout" + ); + assert!( + !result.timed_out, + "cancellation should not be reported as a timeout" + ); + assert_ne!(result.exit_code, 0); +} + #[test] fn legacy_tty_powershell_emits_output_and_accepts_input() { let Some(pwsh) = pwsh_path() else {