Skip to content
Merged
51 changes: 42 additions & 9 deletions netwatch/src/interfaces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,16 +417,10 @@ fn prefixes_major_equal(a: impl Iterator<Item = IpNet>, b: impl Iterator<Item =
true
}

let a = a.filter(is_interesting);
let b = b.filter(is_interesting);
let a: Vec<_> = a.filter(is_interesting).collect();
let b: Vec<_> = b.filter(is_interesting).collect();

for (a, b) in a.zip(b) {
if a != b {
return false;
}
}

true
a == b
Comment thread
matheus23 marked this conversation as resolved.
Outdated
Comment thread
matheus23 marked this conversation as resolved.
Outdated
}

#[cfg(test)]
Expand All @@ -449,6 +443,45 @@ mod tests {
println!("home router: {home_router:#?}");
}

#[test]
fn test_prefixes_major_equal() {
use std::net::Ipv4Addr;

let a1 = IpNet::V4(Ipv4Net::new(Ipv4Addr::new(192, 168, 0, 1), 24).unwrap());
let a2 = IpNet::V4(Ipv4Net::new(Ipv4Addr::new(10, 0, 0, 1), 8).unwrap());
let a3 = IpNet::V4(Ipv4Net::new(Ipv4Addr::new(172, 16, 0, 1), 16).unwrap());

// equal lists
assert!(prefixes_major_equal(
vec![a1.clone(), a2.clone()].into_iter(),
vec![a1.clone(), a2.clone()].into_iter(),
));

// both empty
assert!(prefixes_major_equal(
std::iter::empty(),
std::iter::empty(),
));

// different prefixes
assert!(!prefixes_major_equal(
vec![a1.clone()].into_iter(),
vec![a2.clone()].into_iter(),
));

// a has extra prefix
assert!(!prefixes_major_equal(
vec![a1.clone(), a2.clone(), a3.clone()].into_iter(),
vec![a1.clone(), a2.clone()].into_iter(),
));

// b has extra prefix
assert!(!prefixes_major_equal(
vec![a1.clone(), a2.clone()].into_iter(),
vec![a1.clone(), a2.clone(), a3.clone()].into_iter(),
));
}

#[test]
fn test_is_usable_v6() {
let loopback = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0x1);
Expand Down
23 changes: 22 additions & 1 deletion netwatch/src/interfaces/bsd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,8 @@ pub fn parse_rib(typ: RIBType, data: &[u8]) -> Result<Vec<WireMessage>, RouteErr
ensure!(l != 0, RouteError::InvalidMessage);
ensure!(b.len() >= l as usize, RouteError::MessageTooShort);
if b[2] as i32 != ROUTING_STACK.rtm_version {
// b = b[l:];
b = &b[l as usize..];
nskips += 1;
continue;
}
match ROUTING_STACK.wire_formats.get(&(b[3] as i32)) {
Expand Down Expand Up @@ -1015,6 +1016,26 @@ fn parse_default_addr(b: &[u8]) -> Result<Addr, RouteError> {
mod tests {
use super::*;

#[test]
fn test_parse_rib_skips_version_mismatch() {
let wrong_version = (ROUTING_STACK.rtm_version as u8).wrapping_add(1);
let msg_len: u16 = 8;
let mut buf = vec![0u8; msg_len as usize];
buf[..2].copy_from_slice(&msg_len.to_ne_bytes());
buf[2] = wrong_version;
buf[3] = 0; // arbitrary type

#[cfg(any(target_os = "macos", target_os = "ios"))]
let rib_type = libc::NET_RT_IFLIST2;
#[cfg(any(target_os = "freebsd", target_os = "netbsd"))]
let rib_type = libc::NET_RT_IFLIST;
#[cfg(target_os = "openbsd")]
let rib_type = libc::NET_RT_IFLIST;

let msgs = parse_rib(rib_type, &buf).unwrap();
assert!(msgs.is_empty(), "version-mismatched message should be skipped");
}

#[test]
fn test_fetch_parse_routing_table() {
let rib_raw = fetch_routing_table().unwrap();
Expand Down