Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 46 additions & 4 deletions resources/windows_firewall/src/firewall.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,36 @@ use windows::Win32::Foundation::{S_FALSE, VARIANT_BOOL};
use windows::Win32::NetworkManagement::WindowsFirewall::*;
use windows::Win32::System::Com::{CLSCTX_INPROC_SERVER, CoCreateInstance, CoInitializeEx, CoUninitialize, IDispatch, COINIT_APARTMENTTHREADED};
use windows::Win32::System::Ole::IEnumVARIANT;
use windows::Win32::System::Variant::VARIANT;
use windows::Win32::System::Variant::{VARIANT, VariantClear};

use crate::types::{FirewallError, FirewallRule, FirewallRuleList, RuleAction, RuleDirection};
use crate::util::matches_any_filter;

/// RAII wrapper for VARIANT that automatically calls VariantClear on drop
struct SafeVariant(VARIANT);

impl SafeVariant {
fn new() -> Self {
Self(VARIANT::default())
}

fn as_mut_ptr(&mut self) -> *mut VARIANT {
&mut self.0
}

fn as_ref(&self) -> &VARIANT {
&self.0
}
}

impl Drop for SafeVariant {
fn drop(&mut self) {
if let Err(e) = unsafe { VariantClear(&mut self.0) } {
crate::write_error(&format!("Warning: VariantClear failed with HRESULT: {:#010x}", e.code().0));
}
}
}

struct ComGuard;

impl ComGuard {
Expand Down Expand Up @@ -55,20 +80,23 @@ impl FirewallStore {
let mut results = Vec::new();
loop {
let mut fetched = 0u32;
let mut variant = [VARIANT::default()];
let hr = unsafe { enum_variant.Next(&mut variant, &mut fetched) };
let mut safe_variant = SafeVariant::new();
let variant_slice = unsafe { std::slice::from_raw_parts_mut(safe_variant.as_mut_ptr(), 1) };
let hr = unsafe { enum_variant.Next(variant_slice, &mut fetched) };
if hr == S_FALSE || fetched == 0 {
break;
}
hr.ok()
.map_err(|error| t!("firewall.ruleEnumerationFailed", error = error.to_string()).to_string())?;

let dispatch = IDispatch::try_from(&variant[0])
let dispatch = IDispatch::try_from(safe_variant.as_ref())
.map_err(|error: windows::core::Error| t!("firewall.ruleEnumerationFailed", error = error.to_string()).to_string())?;
let rule: INetFwRule = dispatch
.cast()
.map_err(|error| t!("firewall.ruleEnumerationFailed", error = error.to_string()).to_string())?;
results.push(rule);

// SafeVariant will automatically call VariantClear when it goes out of scope
}

Ok(results)
Expand Down Expand Up @@ -169,6 +197,10 @@ fn profiles_from_mask(mask: i32) -> Vec<String> {
}

fn profiles_to_mask(values: &[String]) -> Result<i32, FirewallError> {
if values.is_empty() {
return Ok(NET_FW_PROFILE2_ALL.0);
}

let mut mask = 0;
for value in values {
match value.to_ascii_lowercase().as_str() {
Expand Down Expand Up @@ -197,6 +229,10 @@ fn join_csv(value: &[String]) -> String {
}

fn interface_types_to_string(values: &[String]) -> Result<String, FirewallError> {
if values.is_empty() {
return Ok("All".to_string());
}

let mut normalized = Vec::new();
for value in values {
match value.to_ascii_lowercase().as_str() {
Expand Down Expand Up @@ -269,6 +305,12 @@ fn apply_rule_properties(rule: &INetFwRule, desired: &FirewallRule, existing_pro
// the existing rule's protocol (if updating an existing rule).
let effective_protocol = desired.protocol.or(existing_protocol);

// If effective_protocol is None, read the current protocol from the rule.
let effective_protocol = match effective_protocol {
Some(protocol) => Some(protocol),
None => Some(unsafe { rule.Protocol() }.map_err(&err)?),
};

// Reject port specifications for protocols that don't support them (e.g. ICMP).
// This must be checked regardless of whether the protocol itself was changed,
// because the caller may only be setting local_ports or remote_ports.
Expand Down
2 changes: 1 addition & 1 deletion resources/windows_firewall/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ const EXIT_INVALID_ARGS: i32 = 1;
const EXIT_INVALID_INPUT: i32 = 2;
const EXIT_FIREWALL_ERROR: i32 = 3;

fn write_error(message: &str) {
pub(crate) fn write_error(message: &str) {
eprintln!("{}", serde_json::json!({ "error": message }));
}

Expand Down
Loading