Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 28 additions & 25 deletions app/src/antivirus/windows.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::marker::PhantomData;
use std::rc::Rc;

use warp_core::send_telemetry_from_ctx;
use warpui::ModelContext;
use windows::Win32::System::Com::{
Expand All @@ -11,9 +14,9 @@ use crate::antivirus::{AntivirusInfo, AntivirusInfoEvent};
impl AntivirusInfo {
#[cfg(windows)]
pub(super) async fn scan() -> anyhow::Result<Option<String>> {
unsafe {
com_initialized()?;
let _com = ComGuard::new()?;

unsafe {
// Read out all of the registered antivirus products.
let pl: IWSCProductList = CoCreateInstance(&WSCProductList, None, CLSCTX_ALL)?;
pl.Initialize(WSC_SECURITY_PROVIDER_ANTIVIRUS)?;
Expand Down Expand Up @@ -68,35 +71,35 @@ impl AntivirusInfo {
}
}

/// Helper struct to properly enforce reference counting of the Windows COM library.
/// RAII guard that initializes the Windows COM library for the current thread and uninitializes it
/// when dropped.
///
/// Per the Windows docs (https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-coinitializeex)
/// each call to [`CoInitializeEx`] must be paired with a call to [`CoUninitialize`] in order for
/// the COM library to be gracefully uninitialized.
/// each successful call to [`CoInitializeEx`] must be paired with a call to [`CoUninitialize`] on
/// the same thread.
// TODO(alokedesai): Move this to a shared place in `core` so we can use it in other places in the
// app.
struct ComInitialized;

impl Drop for ComInitialized {
fn drop(&mut self) {
unsafe { CoUninitialize() };
}
struct ComGuard {
// Tie the guard to its creating thread so `CoUninitialize` only runs there.
not_send_or_sync: PhantomData<Rc<()>>,
}

thread_local! {
static COM_INITIALIZED: Result<ComInitialized, windows::core::Error> = {
unsafe {
CoInitializeEx(None, COINIT_APARTMENTTHREADED).ok()?;
Ok(ComInitialized)
}
};
impl ComGuard {
/// Initializes COM as a single-threaded apartment for the current thread.
fn new() -> Result<Self, windows::core::Error> {
// SAFETY: balanced by `CoUninitialize` in `Drop`, on the same thread.
unsafe { CoInitializeEx(None, COINIT_APARTMENTTHREADED).ok()? };
Ok(Self {
not_send_or_sync: PhantomData,
})
}
}

fn com_initialized() -> Result<(), windows::core::Error> {
COM_INITIALIZED.with(|initialized| {
initialized
.as_ref()
.map(|_| ())
.map_err(|error| error.clone())
})
impl Drop for ComGuard {
fn drop(&mut self) {
// SAFETY: `new` succeeded on this thread (otherwise this guard would not exist), and this
// runs in the synchronous `scan` scope rather than a TLS destructor, so it is safe to
// balance the init here.
unsafe { CoUninitialize() };
}
}
Loading