diff --git a/resources/windows_firewall/src/firewall.rs b/resources/windows_firewall/src/firewall.rs index bf75cd2aa..5c9556178 100644 --- a/resources/windows_firewall/src/firewall.rs +++ b/resources/windows_firewall/src/firewall.rs @@ -8,11 +8,36 @@ use windows::Win32::Foundation::{S_FALSE, VARIANT_BOOL}; use windows::Win32::NetworkManagement::WindowsFirewall::*; use windows::Win32::System::Com::{CLSCTX_INPROC_SERVER, CoCreateInstance, CoInitializeEx, CoUninitialize, IDispatch, COINIT_APARTMENTTHREADED}; use windows::Win32::System::Ole::IEnumVARIANT; -use windows::Win32::System::Variant::VARIANT; +use windows::Win32::System::Variant::{VARIANT, VariantClear}; use crate::types::{FirewallError, FirewallRule, FirewallRuleList, RuleAction, RuleDirection}; use crate::util::matches_any_filter; +/// RAII wrapper for VARIANT that automatically calls VariantClear on drop +struct SafeVariant(VARIANT); + +impl SafeVariant { + fn new() -> Self { + Self(VARIANT::default()) + } + + fn as_mut_ptr(&mut self) -> *mut VARIANT { + &mut self.0 + } + + fn as_ref(&self) -> &VARIANT { + &self.0 + } +} + +impl Drop for SafeVariant { + fn drop(&mut self) { + if let Err(e) = unsafe { VariantClear(&mut self.0) } { + crate::write_error(&format!("Warning: VariantClear failed with HRESULT: {:#010x}", e.code().0)); + } + } +} + struct ComGuard; impl ComGuard { @@ -55,20 +80,23 @@ impl FirewallStore { let mut results = Vec::new(); loop { let mut fetched = 0u32; - let mut variant = [VARIANT::default()]; - let hr = unsafe { enum_variant.Next(&mut variant, &mut fetched) }; + let mut safe_variant = SafeVariant::new(); + let variant_slice = unsafe { std::slice::from_raw_parts_mut(safe_variant.as_mut_ptr(), 1) }; + let hr = unsafe { enum_variant.Next(variant_slice, &mut fetched) }; if hr == S_FALSE || fetched == 0 { break; } hr.ok() .map_err(|error| t!("firewall.ruleEnumerationFailed", error = error.to_string()).to_string())?; - let dispatch = IDispatch::try_from(&variant[0]) + let dispatch = IDispatch::try_from(safe_variant.as_ref()) .map_err(|error: windows::core::Error| t!("firewall.ruleEnumerationFailed", error = error.to_string()).to_string())?; let rule: INetFwRule = dispatch .cast() .map_err(|error| t!("firewall.ruleEnumerationFailed", error = error.to_string()).to_string())?; results.push(rule); + + // SafeVariant will automatically call VariantClear when it goes out of scope } Ok(results) @@ -169,6 +197,10 @@ fn profiles_from_mask(mask: i32) -> Vec { } fn profiles_to_mask(values: &[String]) -> Result { + if values.is_empty() { + return Ok(NET_FW_PROFILE2_ALL.0); + } + let mut mask = 0; for value in values { match value.to_ascii_lowercase().as_str() { @@ -197,6 +229,10 @@ fn join_csv(value: &[String]) -> String { } fn interface_types_to_string(values: &[String]) -> Result { + if values.is_empty() { + return Ok("All".to_string()); + } + let mut normalized = Vec::new(); for value in values { match value.to_ascii_lowercase().as_str() { @@ -269,6 +305,12 @@ fn apply_rule_properties(rule: &INetFwRule, desired: &FirewallRule, existing_pro // the existing rule's protocol (if updating an existing rule). let effective_protocol = desired.protocol.or(existing_protocol); + // If effective_protocol is None, read the current protocol from the rule. + let effective_protocol = match effective_protocol { + Some(protocol) => Some(protocol), + None => Some(unsafe { rule.Protocol() }.map_err(&err)?), + }; + // Reject port specifications for protocols that don't support them (e.g. ICMP). // This must be checked regardless of whether the protocol itself was changed, // because the caller may only be setting local_ports or remote_ports. diff --git a/resources/windows_firewall/src/main.rs b/resources/windows_firewall/src/main.rs index c18093a95..3036ed2f1 100644 --- a/resources/windows_firewall/src/main.rs +++ b/resources/windows_firewall/src/main.rs @@ -19,7 +19,7 @@ const EXIT_INVALID_ARGS: i32 = 1; const EXIT_INVALID_INPUT: i32 = 2; const EXIT_FIREWALL_ERROR: i32 = 3; -fn write_error(message: &str) { +pub(crate) fn write_error(message: &str) { eprintln!("{}", serde_json::json!({ "error": message })); }