diff --git a/app/src/antivirus/windows.rs b/app/src/antivirus/windows.rs index 2b4d73951f..a7f864f5d1 100644 --- a/app/src/antivirus/windows.rs +++ b/app/src/antivirus/windows.rs @@ -1,3 +1,6 @@ +use std::marker::PhantomData; +use std::rc::Rc; + use warp_core::send_telemetry_from_ctx; use warpui::ModelContext; use windows::Win32::System::Com::{ @@ -11,9 +14,9 @@ use crate::antivirus::{AntivirusInfo, AntivirusInfoEvent}; impl AntivirusInfo { #[cfg(windows)] pub(super) async fn scan() -> anyhow::Result> { - unsafe { - com_initialized()?; + let _com = ComGuard::new()?; + unsafe { // Read out all of the registered antivirus products. let pl: IWSCProductList = CoCreateInstance(&WSCProductList, None, CLSCTX_ALL)?; pl.Initialize(WSC_SECURITY_PROVIDER_ANTIVIRUS)?; @@ -68,35 +71,35 @@ impl AntivirusInfo { } } -/// Helper struct to properly enforce reference counting of the Windows COM library. +/// RAII guard that initializes the Windows COM library for the current thread and uninitializes it +/// when dropped. /// /// Per the Windows docs (https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-coinitializeex) -/// each call to [`CoInitializeEx`] must be paired with a call to [`CoUninitialize`] in order for -/// the COM library to be gracefully uninitialized. +/// each successful call to [`CoInitializeEx`] must be paired with a call to [`CoUninitialize`] on +/// the same thread. // TODO(alokedesai): Move this to a shared place in `core` so we can use it in other places in the // app. -struct ComInitialized; - -impl Drop for ComInitialized { - fn drop(&mut self) { - unsafe { CoUninitialize() }; - } +struct ComGuard { + // Tie the guard to its creating thread so `CoUninitialize` only runs there. + not_send_or_sync: PhantomData>, } -thread_local! { - static COM_INITIALIZED: Result = { - unsafe { - CoInitializeEx(None, COINIT_APARTMENTTHREADED).ok()?; - Ok(ComInitialized) - } - }; +impl ComGuard { + /// Initializes COM as a single-threaded apartment for the current thread. + fn new() -> Result { + // SAFETY: balanced by `CoUninitialize` in `Drop`, on the same thread. + unsafe { CoInitializeEx(None, COINIT_APARTMENTTHREADED).ok()? }; + Ok(Self { + not_send_or_sync: PhantomData, + }) + } } -fn com_initialized() -> Result<(), windows::core::Error> { - COM_INITIALIZED.with(|initialized| { - initialized - .as_ref() - .map(|_| ()) - .map_err(|error| error.clone()) - }) +impl Drop for ComGuard { + fn drop(&mut self) { + // SAFETY: `new` succeeded on this thread (otherwise this guard would not exist), and this + // runs in the synchronous `scan` scope rather than a TLS destructor, so it is safe to + // balance the init here. + unsafe { CoUninitialize() }; + } }