Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions netwatch/src/netmon/bsd.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#[cfg(any(target_os = "macos", target_os = "ios"))]
use libc::{RTAX_DST, RTAX_IFP};
use n0_error::stack_error;
use n0_future::{
task::AbortOnDropHandle,
time::{self, Duration},
};
use tokio::{io::AsyncReadExt, sync::mpsc};
use tokio_util::task::AbortOnDropHandle;
use tracing::{trace, warn};

use super::actor::NetworkMessage;
Expand Down Expand Up @@ -39,19 +42,24 @@ impl RouteMonitor {
let handle = tokio::task::spawn(async move {
trace!("AF_ROUTE monitor started");

// TODO: cleaner shutdown
let mut buffer = vec![0u8; 2048];
let mut backoff = Duration::from_secs(1);
const MAX_BACKOFF: Duration = Duration::from_secs(30);

loop {
match socket.read(&mut buffer).await {
Ok(read) => {
backoff = Duration::from_secs(1);
trace!("AF_ROUTE: read {} bytes", read);
match super::super::interfaces::bsd::parse_rib(
libc::NET_RT_DUMP,
&buffer[..read],
) {
Ok(msgs) => {
if contains_interesting_message(&msgs) {
sender.send(NetworkMessage::Change).await.ok();
if contains_interesting_message(&msgs)
&& sender.send(NetworkMessage::Change).await.is_err()
{
break;
}
}
Err(err) => {
Expand All @@ -61,15 +69,15 @@ impl RouteMonitor {
}
Err(err) => {
warn!("AF_ROUTE: error reading: {:?}", err);
// recreate socket, as it is likely in an invalid state
// TODO: distinguish between different errors?
time::sleep(backoff).await;
match create_socket() {
Ok(new_socket) => {
socket = new_socket;
backoff = Duration::from_secs(1);
}
Err(err) => {
warn!("AF_ROUTE: unable to bind a new socket: {:?}", err);
// TODO: what to do here?
warn!("AF_ROUTE: unable to recreate socket: {:?}", err);
backoff = (backoff * 2).min(MAX_BACKOFF);
}
}
}
Expand Down
248 changes: 141 additions & 107 deletions netwatch/src/netmon/linux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,23 @@ use libc::{
RTNLGRP_IPV6_ROUTE, RTNLGRP_IPV6_RULE,
};
use n0_error::stack_error;
use n0_future::StreamExt;
use netlink_packet_core::NetlinkPayload;
use n0_future::{
Stream, StreamExt,
task::AbortOnDropHandle,
time::{self, Duration},
};
use netlink_packet_core::{NetlinkMessage, NetlinkPayload};
use netlink_packet_route::{RouteNetlinkMessage, address, route};
use netlink_sys::{AsyncSocket, SocketAddr};
use tokio::{sync::mpsc, task::JoinHandle};
use tokio::sync::mpsc;
use tracing::{trace, warn};

use super::actor::NetworkMessage;
use crate::ip::is_link_local;

#[derive(Debug)]
pub(super) struct RouteMonitor {
conn_handle: JoinHandle<()>,
handle: JoinHandle<()>,
}

impl Drop for RouteMonitor {
fn drop(&mut self) {
self.handle.abort();
self.conn_handle.abort();
}
_handle: AbortOnDropHandle<()>,
}

#[stack_error(derive, add_meta, from_sources, std_sources)]
Expand All @@ -53,122 +49,160 @@ macro_rules! get_nla {
};
}

impl RouteMonitor {
pub(super) fn new(sender: mpsc::Sender<NetworkMessage>) -> Result<Self, Error> {
use netlink_sys::protocols::NETLINK_ROUTE;
#[allow(clippy::type_complexity)]
fn setup_netlink() -> std::io::Result<(
AbortOnDropHandle<()>,
impl Stream<Item = (NetlinkMessage<RouteNetlinkMessage>, SocketAddr)>,
)> {
use netlink_sys::protocols::NETLINK_ROUTE;

let (mut conn, _handle, mut messages) = netlink_proto::new_connection::<
netlink_packet_route::RouteNetlinkMessage,
>(NETLINK_ROUTE)?;
let (mut conn, _handle, messages) =
netlink_proto::new_connection::<RouteNetlinkMessage>(NETLINK_ROUTE)?;

// Specify flags to listen on.
let groups = nl_mgrp(RTNLGRP_IPV4_IFADDR)
| nl_mgrp(RTNLGRP_IPV6_IFADDR)
| nl_mgrp(RTNLGRP_IPV4_ROUTE)
| nl_mgrp(RTNLGRP_IPV6_ROUTE)
| nl_mgrp(RTNLGRP_IPV4_RULE)
| nl_mgrp(RTNLGRP_IPV6_RULE);
let groups = nl_mgrp(RTNLGRP_IPV4_IFADDR)
| nl_mgrp(RTNLGRP_IPV6_IFADDR)
| nl_mgrp(RTNLGRP_IPV4_ROUTE)
| nl_mgrp(RTNLGRP_IPV6_ROUTE)
| nl_mgrp(RTNLGRP_IPV4_RULE)
| nl_mgrp(RTNLGRP_IPV6_RULE);

let addr = SocketAddr::new(0, groups);
conn.socket_mut().socket_mut().bind(&addr)?;
let addr = SocketAddr::new(0, groups);
conn.socket_mut().socket_mut().bind(&addr)?;

let conn_handle = tokio::task::spawn(conn);
let conn_handle = AbortOnDropHandle::new(tokio::task::spawn(conn));

let handle = tokio::task::spawn(async move {
// let mut addr_cache: HashMap<u32, HashSet<Vec<u8>>> = HashMap::new();
let mut addr_cache: HashMap<u32, HashSet<IpAddr>> = HashMap::new();
Ok((conn_handle, messages))
}

while let Some((message, _)) = messages.next().await {
match message.payload {
NetlinkPayload::Error(err) => {
warn!("error reading netlink payload: {:?}", err);
/// Returns `true` if the connection was lost (should reconnect),
/// `false` if the sender is gone (should shut down).
async fn process_messages(
sender: &mpsc::Sender<NetworkMessage>,
messages: &mut (impl Stream<Item = (NetlinkMessage<RouteNetlinkMessage>, SocketAddr)> + Unpin),
) -> bool {
let mut addr_cache: HashMap<u32, HashSet<IpAddr>> = HashMap::new();

while let Some((message, _)) = messages.next().await {
match message.payload {
NetlinkPayload::Error(err) => {
warn!("error reading netlink payload: {:?}", err);
}
NetlinkPayload::Done(_) => {
trace!("done received, reconnecting");
return true;
}
NetlinkPayload::InnerMessage(msg) => match msg {
RouteNetlinkMessage::NewAddress(msg) => {
trace!("NEWADDR: {:?}", msg);
let addrs = addr_cache.entry(msg.header.index).or_default();
if let Some(addr) = get_nla!(msg, address::AddressAttribute::Address) {
if addrs.contains(addr) {
continue;
} else {
addrs.insert(*addr);
if sender.send(NetworkMessage::Change).await.is_err() {
return false;
}
}
}
}
RouteNetlinkMessage::DelAddress(msg) => {
trace!("DELADDR: {:?}", msg);
let addrs = addr_cache.entry(msg.header.index).or_default();
if let Some(addr) = get_nla!(msg, address::AddressAttribute::Address) {
addrs.remove(addr);
}
NetlinkPayload::Done(_) => {
trace!("done received, exiting");
break;
if sender.send(NetworkMessage::Change).await.is_err() {
return false;
}
NetlinkPayload::InnerMessage(msg) => match msg {
RouteNetlinkMessage::NewAddress(msg) => {
trace!("NEWADDR: {:?}", msg);
let addrs = addr_cache.entry(msg.header.index).or_default();
if let Some(addr) = get_nla!(msg, address::AddressAttribute::Address) {
if addrs.contains(addr) {
// already cached
}
RouteNetlinkMessage::NewRoute(msg) | RouteNetlinkMessage::DelRoute(msg) => {
trace!("ROUTE:: {:?}", msg);

let table = get_nla!(msg, route::RouteAttribute::Table)
.copied()
.unwrap_or_default();
if let Some(dst) = get_nla!(msg, route::RouteAttribute::Destination) {
match dst {
route::RouteAddress::Inet(addr) => {
if (table == 255 || table == 254)
&& (addr.is_multicast() || is_link_local(IpAddr::V4(*addr)))
{
continue;
} else {
addrs.insert(*addr);
sender.send(NetworkMessage::Change).await.ok();
}
}
}
RouteNetlinkMessage::DelAddress(msg) => {
trace!("DELADDR: {:?}", msg);
let addrs = addr_cache.entry(msg.header.index).or_default();
if let Some(addr) = get_nla!(msg, address::AddressAttribute::Address) {
addrs.remove(addr);
}
sender.send(NetworkMessage::Change).await.ok();
}
RouteNetlinkMessage::NewRoute(msg) | RouteNetlinkMessage::DelRoute(msg) => {
trace!("ROUTE:: {:?}", msg);

// Ignore the following messages
let table = get_nla!(msg, route::RouteAttribute::Table)
.copied()
.unwrap_or_default();
if let Some(dst) = get_nla!(msg, route::RouteAttribute::Destination) {
match dst {
route::RouteAddress::Inet(addr) => {
if (table == 255 || table == 254)
&& (addr.is_multicast()
|| is_link_local(IpAddr::V4(*addr)))
{
continue;
}
}
route::RouteAddress::Inet6(addr) => {
if (table == 255 || table == 254)
&& (addr.is_multicast()
|| is_link_local(IpAddr::V6(*addr)))
{
continue;
}
}
_ => {}
route::RouteAddress::Inet6(addr) => {
if (table == 255 || table == 254)
&& (addr.is_multicast() || is_link_local(IpAddr::V6(*addr)))
{
continue;
}
}
sender.send(NetworkMessage::Change).await.ok();
}
RouteNetlinkMessage::NewRule(msg) => {
trace!("NEWRULE: {:?}", msg);
sender.send(NetworkMessage::Change).await.ok();
}
RouteNetlinkMessage::DelRule(msg) => {
trace!("DELRULE: {:?}", msg);
sender.send(NetworkMessage::Change).await.ok();
_ => {}
}
RouteNetlinkMessage::NewLink(msg) => {
trace!("NEWLINK: {:?}", msg);
// ignored atm
}
RouteNetlinkMessage::DelLink(msg) => {
trace!("DELLINK: {:?}", msg);
// ignored atm
}
msg => {
trace!("unhandled: {:?}", msg);
}
if sender.send(NetworkMessage::Change).await.is_err() {
return false;
}
}
RouteNetlinkMessage::NewRule(msg) => {
trace!("NEWRULE: {:?}", msg);
if sender.send(NetworkMessage::Change).await.is_err() {
return false;
}
}
RouteNetlinkMessage::DelRule(msg) => {
trace!("DELRULE: {:?}", msg);
if sender.send(NetworkMessage::Change).await.is_err() {
return false;
}
}
RouteNetlinkMessage::NewLink(msg) => {
trace!("NEWLINK: {:?}", msg);
}
RouteNetlinkMessage::DelLink(msg) => {
trace!("DELLINK: {:?}", msg);
}
msg => {
trace!("unhandled: {:?}", msg);
}
},
_ => {}
}
}

// Stream ended — connection lost
true
}

impl RouteMonitor {
pub(super) fn new(sender: mpsc::Sender<NetworkMessage>) -> Result<Self, Error> {
let handle = tokio::task::spawn(async move {
let mut backoff = Duration::from_secs(1);
const MAX_BACKOFF: Duration = Duration::from_secs(30);

loop {
match setup_netlink() {
Ok((_conn_handle, mut messages)) => {
backoff = Duration::from_secs(1);
let should_reconnect = process_messages(&sender, &mut messages).await;
// _conn_handle dropped here, aborting the connection task
if !should_reconnect {
break;
}
},
_ => {
// ignore other types
warn!("netlink connection lost, reconnecting");
}
Err(err) => {
warn!("failed to setup netlink: {:?}", err);
}
}
time::sleep(backoff).await;
backoff = (backoff * 2).min(MAX_BACKOFF);
}
});

Ok(RouteMonitor {
handle,
conn_handle,
_handle: AbortOnDropHandle::new(handle),
})
}
}
Expand Down
Loading