diff --git a/Cargo.toml b/Cargo.toml index 9629bc2..1946a91 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ clap = { version = "4.5.40", features = ["cargo"] } futures-util = "0.3.31" indexmap = { version = "2.14.0", features = ["serde"] } log = { version = "0.4.29", features = ["std"] } +nix = { version = "0.31.3", features = ["feature", "sched"] } rtnetlink = { git = "https://github.com/rust-netlink/rtnetlink" } serde = { version = "1.0", default-features = false, features = ["derive"] } serde_json = "1.0.140" diff --git a/src/error.rs b/src/error.rs index b54b41a..a641a09 100644 --- a/src/error.rs +++ b/src/error.rs @@ -17,6 +17,15 @@ impl From<&str> for CliError { } } +impl From for CliError { + fn from(msg: String) -> Self { + Self { + code: DEFAULT_ERROR_CODE, + msg, + } + } +} + impl std::fmt::Display for CliError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "error {}: {}", self.code, self.msg) diff --git a/src/ip/main.rs b/src/ip/main.rs index cc8c729..8649672 100644 --- a/src/ip/main.rs +++ b/src/ip/main.rs @@ -2,6 +2,7 @@ mod address; mod link; +mod neighbour; #[cfg(test)] mod tests; @@ -10,6 +11,8 @@ use std::io::IsTerminal; use iproute_rs::{CliColor, CliError, OutputFormat, print_result_and_exit}; +use crate::neighbour::NeighbourCommand; + use self::{address::AddressCommand, link::LinkCommand}; #[tokio::main(flavor = "current_thread")] @@ -55,9 +58,17 @@ async fn main() -> Result<(), CliError> { .action(clap::ArgAction::SetTrue) .global(true), ) + .arg( + clap::Arg::new("STATISTICS") + .short('s') + .help("Show object statistics") + .action(clap::ArgAction::SetTrue) + .global(true), + ) .subcommand_required(true) .subcommand(LinkCommand::gen_command()) - .subcommand(AddressCommand::gen_command()); + .subcommand(AddressCommand::gen_command()) + .subcommand(NeighbourCommand::gen_command()); let matches = app.get_matches_mut(); @@ -84,6 +95,10 @@ async fn main() -> Result<(), CliError> { matches.subcommand_matches(AddressCommand::CMD) { print_result_and_exit(AddressCommand::handle(matches).await, fmt); + } else if let Some(matches) = + matches.subcommand_matches(NeighbourCommand::CMD) + { + print_result_and_exit(NeighbourCommand::handle(matches).await, fmt); } else { app.print_help()?; println!(); diff --git a/src/ip/neighbour/cli.rs b/src/ip/neighbour/cli.rs new file mode 100644 index 0000000..5074729 --- /dev/null +++ b/src/ip/neighbour/cli.rs @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT + +use iproute_rs::CliError; + +use super::show::{CliNeighbourInfo, handle_show}; + +pub(crate) struct NeighbourCommand; + +impl NeighbourCommand { + pub(crate) const CMD: &'static str = "neighbour"; + + pub(crate) fn gen_command() -> clap::Command { + clap::Command::new(Self::CMD) + .about("arp/ndp table management") + .alias("neigh") + .alias("neig") + .alias("nei") + .alias("ne") + .alias("n") + .subcommand_required(false) + .subcommand( + clap::Command::new("show") + .about("list neighbour entries") + .alias("list") + .alias("lst") + .alias("ls") + .alias("li") + .alias("l") + .arg( + clap::Arg::new("options") + .action(clap::ArgAction::Append) + .trailing_var_arg(true), + ), + ) + } + + pub(crate) async fn handle( + matches: &clap::ArgMatches, + ) -> Result, CliError> { + if let Some(matches) = matches.subcommand_matches("show") { + let opts = matches + .get_many::("options") + .unwrap_or_default() + .map(String::as_str); + handle_show(opts, matches.get_flag("STATISTICS")).await + } else { + handle_show([].into_iter(), matches.get_flag("STATISTICS")).await + } + } +} diff --git a/src/ip/neighbour/mod.rs b/src/ip/neighbour/mod.rs new file mode 100644 index 0000000..aca63d2 --- /dev/null +++ b/src/ip/neighbour/mod.rs @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT + +mod cli; +mod show; + +#[cfg(test)] +mod tests; + +pub(crate) use self::cli::NeighbourCommand; diff --git a/src/ip/neighbour/show.rs b/src/ip/neighbour/show.rs new file mode 100644 index 0000000..ba0d927 --- /dev/null +++ b/src/ip/neighbour/show.rs @@ -0,0 +1,457 @@ +// SPDX-License-Identifier: MIT + +use std::{ + collections::{BTreeMap, HashMap}, + net::{IpAddr, Ipv4Addr}, + str::FromStr, +}; + +use futures_util::TryStreamExt; +use iproute_rs::{ + CanDisplay, CanOutput, CliColor, mac_to_string, write_with_color, +}; +use rtnetlink::{ + Handle, + packet_route::{ + link::LinkAttribute, + neighbour::{ + NeighbourAddress, NeighbourAttribute, NeighbourFlags, + NeighbourMessage, NeighbourState, + }, + }, +}; +use serde::Serialize; + +use crate::CliError; + +/// Bespoke struct to preserve odd `ip -json` behaviour, +/// where a router-address has a `"router": null`. +#[derive(Copy, Clone, Debug, Default)] +enum IsRouter { + #[default] + NotRouter, + Router, +} + +impl IsRouter { + pub(crate) fn is_a_router(&self) -> bool { + matches!(self, IsRouter::Router) + } + + pub(crate) fn is_not_a_router(&self) -> bool { + matches!(self, IsRouter::NotRouter) + } +} + +impl Serialize for IsRouter { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_none() + } +} + +#[derive(Serialize, Debug)] +pub(crate) struct CliNeighbourInfo { + #[serde(skip)] + family: String, + dst: IpAddr, + dev: String, + #[serde(skip_serializing_if = "Option::is_none")] + lladdr: Option, + + #[serde(skip)] + refcnt: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + used: Option, + #[serde(skip_serializing_if = "Option::is_none")] + confirmed: Option, + #[serde(skip_serializing_if = "Option::is_none")] + updated: Option, + #[serde(skip_serializing_if = "Option::is_none")] + probes: Option, + + #[serde(skip_serializing_if = "IsRouter::is_not_a_router")] + router: IsRouter, + /// TODO: iproute2 emits a JSON array for these; need to figure out in what situation we have more than 1. + state: Vec, +} + +impl std::fmt::Display for CliNeighbourInfo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write_with_color!( + f, + CliColor::address_color(&self.family), + "{}", + self.dst + )?; + write!(f, " dev ")?; + write_with_color!(f, CliColor::IfaceName, "{}", self.dev)?; + if let Some(lladdr) = &self.lladdr { + write!(f, " lladdr ")?; + write_with_color!(f, CliColor::Mac, "{lladdr}")?; + } + + if self.router.is_a_router() { + write!(f, " router")?; + } + + if let Some(refcnt) = self.refcnt + && refcnt != 0 + { + write!(f, " ref {refcnt}")?; + }; + + if let Some(used) = self.used { + let confirmed = self.confirmed.unwrap_or(0); + let updated = self.updated.unwrap_or(0); + write!(f, " used {used}/{confirmed}/{updated}")?; + } + + if let Some(probes) = self.probes { + write!(f, " probes {probes}")?; + } + + for state in &self.state { + write!(f, " {state}")?; + } + + Ok(()) + } +} + +impl CanDisplay for CliNeighbourInfo { + fn gen_string(&self) -> String { + self.to_string() + } +} + +impl CanOutput for CliNeighbourInfo {} + +#[derive(Default, Debug)] +enum NudFilter { + #[default] + Default, + All, + Specified(NeighbourState), +} + +impl FromStr for NudFilter { + type Err = CliError; + + fn from_str(s: &str) -> Result { + if s == "all" { + return Ok(NudFilter::All); + } + let Ok(state) = s.parse::() else { + return Err("Invalid nud `{s}`".into()); + }; + + Ok(NudFilter::Specified(state)) + } +} + +#[derive(Default, Debug)] +enum ControllerFilter<'a> { + #[default] + Unfiltered, + DeviceName(&'a str), + NoController, +} + +#[derive(Default, Debug)] +struct ShowArguments<'a> { + nud_filter: NudFilter, + dev_filter: Option<&'a str>, + controller_filter: ControllerFilter<'a>, + proxy: bool, + unused: bool, + address_filter: Option, +} + +impl<'a> ShowArguments<'a> { + fn from_arguments( + mut arguments: impl Iterator, + ) -> Result { + let mut args = Self::default(); + while let Some(opt) = arguments.next() { + match opt { + "proxy" => { + args.proxy = true; + } + + "unused" => { + args.unused = true; + } + + "dev" => { + let Some(dev_name) = arguments.next() else { + return Err("Missing argument for `dev`".into()); + }; + args.dev_filter = Some(dev_name); + } + + "vrf" => { + let Some(vrf_name) = arguments.next() else { + return Err("Missing argument for `vrf`".into()); + }; + args.controller_filter = + ControllerFilter::DeviceName(vrf_name); + } + + "nomaster" => { + args.controller_filter = ControllerFilter::NoController; + } + + "nud" => { + let Some(nud) = arguments.next() else { + return Err("Missing argument for `nud`".into()); + }; + args.nud_filter = nud.parse()?; + } + + "to" => { + let Some(address) = arguments.next() else { + return Err("Missing argument for `to`".into()); + }; + args.address_filter = Some(parse_address(address)?); + } + + raw => { + args.address_filter = Some(parse_address(raw)?); + } + } + } + + Ok(args) + } +} + +/// Parse an address similarly to `ip neigh show `. +/// It accepts either an IP-address, or a numeric IPv4 (probably network-order). +/// +/// TODO: Check if `0XXXX` is represented as octal. +fn parse_address(address: &str) -> Result { + if let Ok(address) = address.parse() { + return Ok(address); + } + + // Try to parse as an integer and convert to ipv4 + let (address, radix) = if let Some(address) = address.strip_prefix("0x") { + (address, 16) + } else if let Some(address) = address.strip_prefix("0b") { + (address, 2) + } else { + (address, 10) + }; + let address_num = u32::from_str_radix(address, radix) + .map_err(|_| format!("Invalid address `{address}`"))?; + + Ok(Ipv4Addr::from_bits(address_num).into()) +} + +fn parse_nl_msg_to_neighbour( + nl_msg: NeighbourMessage, + interface_names: &BTreeMap, + clocks_per_second: u32, +) -> Result, CliError> { + let family = nl_msg.header.family.to_string(); + let flags = NeighbourFlags::from_bits_retain(nl_msg.header.flags.bits()); + let mut dst = None; + let mut lladdr = None; + let mut confirmed = None; + let mut used = None; + let mut updated = None; + let mut refcnt = None; + let mut probes = None; + + let state = if nl_msg.header.state == NeighbourState::None { + vec![] + } else { + vec![nl_msg.header.state.to_string().to_ascii_uppercase()] + }; + + let dev = interface_names + .get(&nl_msg.header.ifindex) + .cloned() + .unwrap_or_else(|| nl_msg.header.ifindex.to_string()); + + for nla in nl_msg.attributes { + match nla { + NeighbourAttribute::Destination(a) => { + dst = match a { + NeighbourAddress::Inet(addr) => Some(addr.into()), + NeighbourAddress::Inet6(addr) => Some(addr.into()), + _ => None, + }; + } + NeighbourAttribute::LinkLayerAddress(raw_lladdr) => { + lladdr = Some(mac_to_string(&raw_lladdr)); + } + NeighbourAttribute::CacheInfo(info) => { + confirmed = Some(info.confirmed / clocks_per_second); + used = Some(info.used / clocks_per_second); + updated = Some(info.updated / clocks_per_second); + refcnt = Some(info.refcnt); + } + NeighbourAttribute::Probes(probes_) => { + probes = Some(probes_); + } + _ => {} + } + } + + let router = if flags.contains(NeighbourFlags::Router) { + IsRouter::Router + } else { + IsRouter::NotRouter + }; + + let Some(dst) = dst else { + return Ok(None); + }; + + let cli_addr_info = CliNeighbourInfo { + family, + dst, + dev, + lladdr, + refcnt, + used, + confirmed, + updated, + probes, + router, + state, + }; + + Ok(Some(cli_addr_info)) +} + +/// Build a bidrectional-mapping between interface names and their indicies. +/// Optionally retrieves a single link if limited by the user. +async fn get_links( + handle: &Handle, + dev_filter: Option<&str>, +) -> Result<(BTreeMap, HashMap), CliError> { + let mut links_get_handler = handle.link().get(); + + if let Some(dev_filter) = dev_filter { + links_get_handler = links_get_handler.match_name(dev_filter.into()); + } + + let mut links = links_get_handler.execute(); + let mut link_names = BTreeMap::new(); + let mut link_indicies = HashMap::new(); + while let Some(nl_msg) = links.try_next().await? { + let index = nl_msg.header.index; + for attr in nl_msg.attributes { + if let LinkAttribute::IfName(name) = attr { + link_names.insert(index, name.clone()); + link_indicies.insert(name, index); + } + } + } + + Ok((link_names, link_indicies)) +} + +pub(crate) async fn handle_show( + opts: impl Iterator, + show_statistics: bool, +) -> Result, CliError> { + let (connection, handle, _) = rtnetlink::new_connection()?; + + tokio::spawn(connection); + + let args = ShowArguments::from_arguments(opts)?; + let (link_names, link_indicies) = + get_links(&handle, args.dev_filter).await?; + + let mut neighbours_get_handle = handle.neighbours().get(); + if args.proxy { + neighbours_get_handle = neighbours_get_handle.proxies(); + } + if let Some(dev_name) = args.dev_filter { + let dev_index = link_indicies + .get(dev_name) + .ok_or_else(|| format!("Cannot find device \"{dev_name}\""))?; + + neighbours_get_handle + .message_mut() + .attributes + .push(NeighbourAttribute::IfIndex(*dev_index)); + } + let controller_filter = match args.controller_filter { + ControllerFilter::DeviceName(vrf_name) => { + let index = link_indicies.get(vrf_name).ok_or_else(|| { + format!( + "argument \"{vrf_name}\" is wrong: Not a valid VRF name" + ) + })?; + Some(*index) + } + ControllerFilter::NoController => Some(u32::MAX), + _ => None, + }; + neighbours_get_handle + .message_mut() + .attributes + .extend(controller_filter.map(NeighbourAttribute::Controller)); + + let mut neighbours = neighbours_get_handle.execute(); + let mut neighbour_info: Vec = Vec::new(); + + // Retrieve clock resolution (USER_HZ in kernel) for time calculations. + // Typically it is set to a hundreth of a second, but this is overridable + // when compiling the kernel. + let clock = nix::unistd::sysconf(nix::unistd::SysconfVar::CLK_TCK) + .unwrap_or(None) + .unwrap_or(100) as u32; + + while let Some(nl_msg) = neighbours.try_next().await? { + match args.nud_filter { + NudFilter::Default => { + if nl_msg.header.state == NeighbourState::None + || nl_msg.header.state == NeighbourState::Noarp + { + continue; + } + } + NudFilter::Specified(neighbour_state) => { + if nl_msg.header.state != neighbour_state { + continue; + } + } + NudFilter::All => {} + } + + let Some(mut neigh) = + parse_nl_msg_to_neighbour(nl_msg, &link_names, clock)? + else { + continue; + }; + + if let Some(address_filter) = args.address_filter + && neigh.dst != address_filter + { + continue; + } + if args.unused && neigh.refcnt.unwrap_or(0) != 0 { + continue; + } + if !show_statistics { + neigh.refcnt = None; + neigh.used = None; + neigh.confirmed = None; + neigh.updated = None; + neigh.probes = None; + } + + neighbour_info.push(neigh); + } + + Ok(neighbour_info) +} diff --git a/src/ip/neighbour/tests/mod.rs b/src/ip/neighbour/tests/mod.rs new file mode 100644 index 0000000..7739469 --- /dev/null +++ b/src/ip/neighbour/tests/mod.rs @@ -0,0 +1,4 @@ +// SPDX-License-Identifier: MIT + +#[cfg(test)] +mod neighbour; diff --git a/src/ip/neighbour/tests/neighbour.rs b/src/ip/neighbour/tests/neighbour.rs new file mode 100644 index 0000000..842e1cb --- /dev/null +++ b/src/ip/neighbour/tests/neighbour.rs @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: MIT + +use std::{ + net::{IpAddr, Ipv4Addr, Ipv6Addr}, + os::fd::AsFd, +}; + +use nix::sched::CloneFlags; +use rtnetlink::packet_route::neighbour::NeighbourState; + +use crate::tests::{exec_cmd, ip_rs_exec_cmd}; + +#[test] +fn test_neighbour_show() { + let neigh_address1 = Ipv4Addr::new(10, 0, 0, 1).into(); + let neigh_address2 = Ipv4Addr::new(10, 0, 0, 2).into(); + let neigh_address3 = Ipv6Addr::from_bits(0x3000u128).into(); + let lladdr = "AA:AA:AA:AA:AA:AA"; + + in_test_netns(|| { + add_neighbour(neigh_address1, NeighbourState::Reachable, Some(lladdr)); + add_neighbour(neigh_address2, NeighbourState::Reachable, Some(lladdr)); + add_neighbour(neigh_address3, NeighbourState::Reachable, Some(lladdr)); + let expected_output = exec_cmd(&["ip", "neigh", "show"]); + let our_output = ip_rs_exec_cmd(&["neigh", "show"]); + + trimmed_assert_eq(&expected_output, &our_output); + }); +} + +#[test] +fn test_neighbour_show_json() { + let neigh_address1 = Ipv4Addr::new(10, 0, 0, 1).into(); + let neigh_address2 = Ipv4Addr::new(10, 0, 0, 2).into(); + let neigh_address3 = Ipv6Addr::from_bits(0x3000u128).into(); + let lladdr = "AA:AA:AA:AA:AA:AA"; + + in_test_netns(|| { + add_neighbour(neigh_address1, NeighbourState::Reachable, Some(lladdr)); + add_neighbour(neigh_address2, NeighbourState::Reachable, Some(lladdr)); + add_neighbour(neigh_address3, NeighbourState::Reachable, Some(lladdr)); + let expected_output = exec_cmd(&["ip", "-j", "neigh", "show"]); + let our_output = ip_rs_exec_cmd(&["-j", "neigh", "show"]); + + trimmed_assert_eq(&expected_output, &our_output); + }); +} + +#[test] +fn test_neighbour_show_to() { + let neigh_address1 = Ipv4Addr::new(10, 0, 0, 1).into(); + let neigh_address2 = Ipv4Addr::new(10, 0, 0, 2).into(); + let lladdr = "AA:AA:AA:AA:AA:AA"; + + in_test_netns(|| { + add_neighbour(neigh_address1, NeighbourState::Reachable, Some(lladdr)); + add_neighbour(neigh_address2, NeighbourState::Reachable, Some(lladdr)); + // Implicit "to" parameter + let expected_output = + exec_cmd(&["ip", "neigh", "show", &neigh_address1.to_string()]); + let our_output = + ip_rs_exec_cmd(&["neigh", "show", &neigh_address1.to_string()]); + + trimmed_assert_eq(&expected_output, &our_output); + + // Explicit "to" parameter + let expected_output = exec_cmd(&[ + "ip", + "neigh", + "show", + "to", + &neigh_address1.to_string(), + ]); + let our_output = ip_rs_exec_cmd(&[ + "neigh", + "show", + "to", + &neigh_address1.to_string(), + ]); + + trimmed_assert_eq(&expected_output, &our_output); + }); +} + +#[test] +fn test_neighbour_show_nud() { + let neigh_address1 = Ipv4Addr::new(10, 0, 0, 1).into(); + let neigh_address2 = Ipv4Addr::new(10, 0, 0, 2).into(); + let neigh_address3 = Ipv4Addr::new(10, 0, 0, 3).into(); + let neigh_address4 = Ipv4Addr::new(10, 0, 0, 4).into(); + let lladdr = "AA:AA:AA:AA:AA:AA"; + + in_test_netns(|| { + add_neighbour(neigh_address1, NeighbourState::Reachable, Some(lladdr)); + add_neighbour(neigh_address2, NeighbourState::Stale, Some(lladdr)); + add_neighbour(neigh_address3, NeighbourState::Noarp, Some(lladdr)); + add_neighbour(neigh_address4, NeighbourState::None, Some(lladdr)); + + // First, make sure that by default we don't show none/noarp neighs + let expected_output = exec_cmd(&["ip", "neigh", "show"]); + let our_output = ip_rs_exec_cmd(&["neigh", "show"]); + + trimmed_assert_eq(&expected_output, &our_output); + + // Then, ask for them explictly + let expected_output = exec_cmd(&["ip", "neigh", "show", "nud", "none"]); + let our_output = ip_rs_exec_cmd(&["neigh", "show", "nud", "none"]); + trimmed_assert_eq(&expected_output, &our_output); + + let expected_output = + exec_cmd(&["ip", "neigh", "show", "nud", "noarp"]); + let our_output = ip_rs_exec_cmd(&["neigh", "show", "nud", "noarp"]); + trimmed_assert_eq(&expected_output, &our_output); + }); +} + +// TODO: Tests +// - Filter by dev, vrf, and nomaster +// - Filter by pneigh (proxy) +// - Show statistics (need to figure out how to handle time differences) + +fn add_neighbour( + neigh_address: IpAddr, + nud: NeighbourState, + lladdr: Option<&str>, +) { + let neigh_address = neigh_address.to_string(); + let nud = nud.to_string(); + let mut cmd = vec![ + "ip", + "neigh", + "add", + "dev", + "tap0", + &neigh_address, + "nud", + &nud, + ]; + if let Some(lladdr) = lladdr { + cmd.extend(["lladdr", lladdr]); + } + exec_cmd(&cmd); +} + +/// Asserts textual outputs of us and iproute2 are equal, +/// normalizing iproute2 output to remove trailing whitespace. +fn trimmed_assert_eq(expected: &str, actual: &str) { + let expected: Vec<_> = expected.lines().map(|l| l.trim_end()).collect(); + let mut expected = expected.join("\n"); + expected.push('\n'); + + pretty_assertions::assert_eq!(expected, actual); +} + +/// Runs the test body in a dedicated disposable network-namespace. +/// The namespace is created with a single tap-device `tap0`. +/// No need to cleanup anything inside the test; the namespace is deleted afterwards. +fn in_test_netns(test: T) +where + T: FnOnce() + std::panic::UnwindSafe, +{ + // Get reference to old netns, create and enter a new one. + let current_fd = std::fs::File::open("/proc/thread-self/ns/net").unwrap(); + nix::sched::unshare(CloneFlags::CLONE_NEWNET).unwrap(); + + // Create a tap device we can put our neighbours on. + exec_cmd(&["ip", "tuntap", "add", "mode", "tap", "name", "tap0"]); + + let result = std::panic::catch_unwind(|| { + test(); + }); + + // Switch back to old netns + nix::sched::setns(current_fd.as_fd(), CloneFlags::CLONE_NEWNET).unwrap(); + + assert!(result.is_ok()) + + // No need for explicit cleanup, we did not mount our netns anywhere in the filesystem, + // so it will die now that we exited it. +}