diff --git a/netwatch/src/netmon/bsd.rs b/netwatch/src/netmon/bsd.rs index 0c6e5aaf..16146d1d 100644 --- a/netwatch/src/netmon/bsd.rs +++ b/netwatch/src/netmon/bsd.rs @@ -1,8 +1,11 @@ #[cfg(any(target_os = "macos", target_os = "ios"))] use libc::{RTAX_DST, RTAX_IFP}; use n0_error::stack_error; +use n0_future::{ + task::AbortOnDropHandle, + time::{self, Duration}, +}; use tokio::{io::AsyncReadExt, sync::mpsc}; -use tokio_util::task::AbortOnDropHandle; use tracing::{trace, warn}; use super::actor::NetworkMessage; @@ -39,19 +42,24 @@ impl RouteMonitor { let handle = tokio::task::spawn(async move { trace!("AF_ROUTE monitor started"); - // TODO: cleaner shutdown let mut buffer = vec![0u8; 2048]; + let mut backoff = Duration::from_secs(1); + const MAX_BACKOFF: Duration = Duration::from_secs(30); + loop { match socket.read(&mut buffer).await { Ok(read) => { + backoff = Duration::from_secs(1); trace!("AF_ROUTE: read {} bytes", read); match super::super::interfaces::bsd::parse_rib( libc::NET_RT_DUMP, &buffer[..read], ) { Ok(msgs) => { - if contains_interesting_message(&msgs) { - sender.send(NetworkMessage::Change).await.ok(); + if contains_interesting_message(&msgs) + && sender.send(NetworkMessage::Change).await.is_err() + { + break; } } Err(err) => { @@ -61,15 +69,15 @@ impl RouteMonitor { } Err(err) => { warn!("AF_ROUTE: error reading: {:?}", err); - // recreate socket, as it is likely in an invalid state - // TODO: distinguish between different errors? + time::sleep(backoff).await; match create_socket() { Ok(new_socket) => { socket = new_socket; + backoff = Duration::from_secs(1); } Err(err) => { - warn!("AF_ROUTE: unable to bind a new socket: {:?}", err); - // TODO: what to do here? + warn!("AF_ROUTE: unable to recreate socket: {:?}", err); + backoff = (backoff * 2).min(MAX_BACKOFF); } } } diff --git a/netwatch/src/netmon/linux.rs b/netwatch/src/netmon/linux.rs index 026ab879..1af1f009 100644 --- a/netwatch/src/netmon/linux.rs +++ b/netwatch/src/netmon/linux.rs @@ -8,11 +8,15 @@ use libc::{ RTNLGRP_IPV6_ROUTE, RTNLGRP_IPV6_RULE, }; use n0_error::stack_error; -use n0_future::StreamExt; -use netlink_packet_core::NetlinkPayload; +use n0_future::{ + Stream, StreamExt, + task::AbortOnDropHandle, + time::{self, Duration}, +}; +use netlink_packet_core::{NetlinkMessage, NetlinkPayload}; use netlink_packet_route::{RouteNetlinkMessage, address, route}; use netlink_sys::{AsyncSocket, SocketAddr}; -use tokio::{sync::mpsc, task::JoinHandle}; +use tokio::sync::mpsc; use tracing::{trace, warn}; use super::actor::NetworkMessage; @@ -20,15 +24,7 @@ use crate::ip::is_link_local; #[derive(Debug)] pub(super) struct RouteMonitor { - conn_handle: JoinHandle<()>, - handle: JoinHandle<()>, -} - -impl Drop for RouteMonitor { - fn drop(&mut self) { - self.handle.abort(); - self.conn_handle.abort(); - } + _handle: AbortOnDropHandle<()>, } #[stack_error(derive, add_meta, from_sources, std_sources)] @@ -53,122 +49,160 @@ macro_rules! get_nla { }; } -impl RouteMonitor { - pub(super) fn new(sender: mpsc::Sender) -> Result { - use netlink_sys::protocols::NETLINK_ROUTE; +#[allow(clippy::type_complexity)] +fn setup_netlink() -> std::io::Result<( + AbortOnDropHandle<()>, + impl Stream, SocketAddr)>, +)> { + use netlink_sys::protocols::NETLINK_ROUTE; - let (mut conn, _handle, mut messages) = netlink_proto::new_connection::< - netlink_packet_route::RouteNetlinkMessage, - >(NETLINK_ROUTE)?; + let (mut conn, _handle, messages) = + netlink_proto::new_connection::(NETLINK_ROUTE)?; - // Specify flags to listen on. - let groups = nl_mgrp(RTNLGRP_IPV4_IFADDR) - | nl_mgrp(RTNLGRP_IPV6_IFADDR) - | nl_mgrp(RTNLGRP_IPV4_ROUTE) - | nl_mgrp(RTNLGRP_IPV6_ROUTE) - | nl_mgrp(RTNLGRP_IPV4_RULE) - | nl_mgrp(RTNLGRP_IPV6_RULE); + let groups = nl_mgrp(RTNLGRP_IPV4_IFADDR) + | nl_mgrp(RTNLGRP_IPV6_IFADDR) + | nl_mgrp(RTNLGRP_IPV4_ROUTE) + | nl_mgrp(RTNLGRP_IPV6_ROUTE) + | nl_mgrp(RTNLGRP_IPV4_RULE) + | nl_mgrp(RTNLGRP_IPV6_RULE); - let addr = SocketAddr::new(0, groups); - conn.socket_mut().socket_mut().bind(&addr)?; + let addr = SocketAddr::new(0, groups); + conn.socket_mut().socket_mut().bind(&addr)?; - let conn_handle = tokio::task::spawn(conn); + let conn_handle = AbortOnDropHandle::new(tokio::task::spawn(conn)); - let handle = tokio::task::spawn(async move { - // let mut addr_cache: HashMap>> = HashMap::new(); - let mut addr_cache: HashMap> = HashMap::new(); + Ok((conn_handle, messages)) +} - while let Some((message, _)) = messages.next().await { - match message.payload { - NetlinkPayload::Error(err) => { - warn!("error reading netlink payload: {:?}", err); +/// Returns `true` if the connection was lost (should reconnect), +/// `false` if the sender is gone (should shut down). +async fn process_messages( + sender: &mpsc::Sender, + messages: &mut (impl Stream, SocketAddr)> + Unpin), +) -> bool { + let mut addr_cache: HashMap> = HashMap::new(); + + while let Some((message, _)) = messages.next().await { + match message.payload { + NetlinkPayload::Error(err) => { + warn!("error reading netlink payload: {:?}", err); + } + NetlinkPayload::Done(_) => { + trace!("done received, reconnecting"); + return true; + } + NetlinkPayload::InnerMessage(msg) => match msg { + RouteNetlinkMessage::NewAddress(msg) => { + trace!("NEWADDR: {:?}", msg); + let addrs = addr_cache.entry(msg.header.index).or_default(); + if let Some(addr) = get_nla!(msg, address::AddressAttribute::Address) { + if addrs.contains(addr) { + continue; + } else { + addrs.insert(*addr); + if sender.send(NetworkMessage::Change).await.is_err() { + return false; + } + } + } + } + RouteNetlinkMessage::DelAddress(msg) => { + trace!("DELADDR: {:?}", msg); + let addrs = addr_cache.entry(msg.header.index).or_default(); + if let Some(addr) = get_nla!(msg, address::AddressAttribute::Address) { + addrs.remove(addr); } - NetlinkPayload::Done(_) => { - trace!("done received, exiting"); - break; + if sender.send(NetworkMessage::Change).await.is_err() { + return false; } - NetlinkPayload::InnerMessage(msg) => match msg { - RouteNetlinkMessage::NewAddress(msg) => { - trace!("NEWADDR: {:?}", msg); - let addrs = addr_cache.entry(msg.header.index).or_default(); - if let Some(addr) = get_nla!(msg, address::AddressAttribute::Address) { - if addrs.contains(addr) { - // already cached + } + RouteNetlinkMessage::NewRoute(msg) | RouteNetlinkMessage::DelRoute(msg) => { + trace!("ROUTE:: {:?}", msg); + + let table = get_nla!(msg, route::RouteAttribute::Table) + .copied() + .unwrap_or_default(); + if let Some(dst) = get_nla!(msg, route::RouteAttribute::Destination) { + match dst { + route::RouteAddress::Inet(addr) => { + if (table == 255 || table == 254) + && (addr.is_multicast() || is_link_local(IpAddr::V4(*addr))) + { continue; - } else { - addrs.insert(*addr); - sender.send(NetworkMessage::Change).await.ok(); } } - } - RouteNetlinkMessage::DelAddress(msg) => { - trace!("DELADDR: {:?}", msg); - let addrs = addr_cache.entry(msg.header.index).or_default(); - if let Some(addr) = get_nla!(msg, address::AddressAttribute::Address) { - addrs.remove(addr); - } - sender.send(NetworkMessage::Change).await.ok(); - } - RouteNetlinkMessage::NewRoute(msg) | RouteNetlinkMessage::DelRoute(msg) => { - trace!("ROUTE:: {:?}", msg); - - // Ignore the following messages - let table = get_nla!(msg, route::RouteAttribute::Table) - .copied() - .unwrap_or_default(); - if let Some(dst) = get_nla!(msg, route::RouteAttribute::Destination) { - match dst { - route::RouteAddress::Inet(addr) => { - if (table == 255 || table == 254) - && (addr.is_multicast() - || is_link_local(IpAddr::V4(*addr))) - { - continue; - } - } - route::RouteAddress::Inet6(addr) => { - if (table == 255 || table == 254) - && (addr.is_multicast() - || is_link_local(IpAddr::V6(*addr))) - { - continue; - } - } - _ => {} + route::RouteAddress::Inet6(addr) => { + if (table == 255 || table == 254) + && (addr.is_multicast() || is_link_local(IpAddr::V6(*addr))) + { + continue; } } - sender.send(NetworkMessage::Change).await.ok(); - } - RouteNetlinkMessage::NewRule(msg) => { - trace!("NEWRULE: {:?}", msg); - sender.send(NetworkMessage::Change).await.ok(); - } - RouteNetlinkMessage::DelRule(msg) => { - trace!("DELRULE: {:?}", msg); - sender.send(NetworkMessage::Change).await.ok(); + _ => {} } - RouteNetlinkMessage::NewLink(msg) => { - trace!("NEWLINK: {:?}", msg); - // ignored atm - } - RouteNetlinkMessage::DelLink(msg) => { - trace!("DELLINK: {:?}", msg); - // ignored atm - } - msg => { - trace!("unhandled: {:?}", msg); + } + if sender.send(NetworkMessage::Change).await.is_err() { + return false; + } + } + RouteNetlinkMessage::NewRule(msg) => { + trace!("NEWRULE: {:?}", msg); + if sender.send(NetworkMessage::Change).await.is_err() { + return false; + } + } + RouteNetlinkMessage::DelRule(msg) => { + trace!("DELRULE: {:?}", msg); + if sender.send(NetworkMessage::Change).await.is_err() { + return false; + } + } + RouteNetlinkMessage::NewLink(msg) => { + trace!("NEWLINK: {:?}", msg); + } + RouteNetlinkMessage::DelLink(msg) => { + trace!("DELLINK: {:?}", msg); + } + msg => { + trace!("unhandled: {:?}", msg); + } + }, + _ => {} + } + } + + // Stream ended — connection lost + true +} + +impl RouteMonitor { + pub(super) fn new(sender: mpsc::Sender) -> Result { + let handle = tokio::task::spawn(async move { + let mut backoff = Duration::from_secs(1); + const MAX_BACKOFF: Duration = Duration::from_secs(30); + + loop { + match setup_netlink() { + Ok((_conn_handle, mut messages)) => { + backoff = Duration::from_secs(1); + let should_reconnect = process_messages(&sender, &mut messages).await; + // _conn_handle dropped here, aborting the connection task + if !should_reconnect { + break; } - }, - _ => { - // ignore other types + warn!("netlink connection lost, reconnecting"); + } + Err(err) => { + warn!("failed to setup netlink: {:?}", err); } } + time::sleep(backoff).await; + backoff = (backoff * 2).min(MAX_BACKOFF); } }); Ok(RouteMonitor { - handle, - conn_handle, + _handle: AbortOnDropHandle::new(handle), }) } }