diff --git a/noq-proto/src/connection/mod.rs b/noq-proto/src/connection/mod.rs index 208d152c3..b604d5838 100644 --- a/noq-proto/src/connection/mod.rs +++ b/noq-proto/src/connection/mod.rs @@ -2094,15 +2094,14 @@ impl Connection { builder.finish(self, now); // Mark as sent after packet build succeeds. - self.n0_nat_traversal - .mark_probe_sent((remote.ip(), remote.port()), token); + self.n0_nat_traversal.mark_probe_sent(remote, token); let size = buf.len(); self.path_stats.for_path(path_id).udp_tx.on_sent(1, size); trace!(dst = ?remote, len = buf.len(), "sending off-path NAT probe"); Some(Transmit { - destination: remote, + destination: remote.into(), size, ecn: None, segment_size: None, diff --git a/noq-proto/src/n0_nat_traversal.rs b/noq-proto/src/n0_nat_traversal.rs index d238f8ea6..bc55cc6eb 100644 --- a/noq-proto/src/n0_nat_traversal.rs +++ b/noq-proto/src/n0_nat_traversal.rs @@ -2,6 +2,7 @@ use std::{ collections::hash_map::Entry, + fmt::Display, net::{IpAddr, SocketAddr}, }; @@ -14,8 +15,84 @@ use crate::{ frame::{AddAddress, ReachOut, RemoveAddress}, }; +/// An IP & port. +/// +/// Invariant: This value should always be in the ip family that the local +/// socket operates in. +/// E.g. if the local socket is ipv4, then all `IpPort`s should only have +/// IPv4 addresses, and if the socket supports ipv6, then all `IpPort`s +/// should be IPv6 addresses or IPv6-mapped IPv4 addresses. +/// +/// See also [`map_to_local_socket_family`], which powers this conversion. type IpPort = (IpAddr, u16); +/// An IP & port in canonical form. +/// +/// Avoids using ipv6-mapped ipv4 addresses. +/// This is the primary type used to send ip addresses around remotely +/// and the primary type used to canonicalize received addresses. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub(crate) struct CanonicalIpPort { + canonical_ip: IpAddr, + port: u16, +} + +impl CanonicalIpPort { + pub(crate) fn ip(&self) -> IpAddr { + self.canonical_ip + } + + pub(crate) fn port(&self) -> u16 { + self.port + } + + /// Converts this into a local-socket-family-mapped IP & port. + /// + /// Instead of using ipv4 and ipv6 addresses, this tries to match `ipv6`, which + /// should indicate whether the local socket supports ipv6 or not. + /// + /// If ipv6 is supported, all ipv4 addresses are mapped using ipv6-mapped ipv4 + /// addresses. + /// If ipv6 is not supported, then this returns `None` for ipv6 addresses. + /// + /// See also [`map_to_local_socket_family`]. + pub(crate) fn as_local_socket_family(&self, ipv6: bool) -> Option { + Some(( + map_to_local_socket_family(self.canonical_ip, ipv6)?, + self.port, + )) + } + + /// Returns this address as-is with the canonical IP used in a `SocketAddr`. + pub(crate) fn as_canonical_addr(&self) -> SocketAddr { + SocketAddr::new(self.canonical_ip, self.port) + } +} + +impl Display for CanonicalIpPort { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.as_canonical_addr().fmt(f) + } +} + +impl From for CanonicalIpPort { + fn from(addr: SocketAddr) -> Self { + Self { + canonical_ip: addr.ip().to_canonical(), + port: addr.port(), + } + } +} + +impl From for CanonicalIpPort { + fn from((ip, port): IpPort) -> Self { + Self { + canonical_ip: ip.to_canonical(), + port, + } + } +} + /// Errors that the nat traversal state might encounter. #[derive(Debug, thiserror::Error)] pub enum Error { @@ -150,14 +227,12 @@ impl State { Self::ClientSide(client_state) => Ok(client_state .local_addresses .iter() - .copied() - .map(Into::into) + .map(CanonicalIpPort::as_canonical_addr) .collect()), Self::ServerSide(server_state) => Ok(server_state .local_addresses .keys() - .copied() - .map(Into::into) + .map(CanonicalIpPort::as_canonical_addr) .collect()), } } @@ -165,7 +240,7 @@ impl State { /// Returns the next ready probe's address. /// /// If this is actually sent you must call [`Self::mark_probe_sent`]. - pub(crate) fn next_probe_addr(&self) -> Option { + pub(crate) fn next_probe_addr(&self) -> Option { match self { Self::NotNegotiated => None, Self::ClientSide(state) => state.next_probe_addr(), @@ -226,7 +301,7 @@ pub(crate) struct ClientState { /// They are indexed by their ADD_ADDRESS sequence id and stored in **canonical /// form**. Not in the socket-native form as usual. This because we need to store them /// so we have the correct sequence IDs. - remote_addresses: FxHashMap, + remote_addresses: FxHashMap, /// Candidate addresses for the local endpoint. /// /// These are addresses on which we are potentially reachable, to use for NAT traversal @@ -234,7 +309,7 @@ pub(crate) struct ClientState { /// /// They are stored in **canonical form**, not in socket-native form as usual. We may /// nave a reflexive address that is IPv6 even if our local socket can only handle IPv4. - local_addresses: FxHashSet, + local_addresses: FxHashSet, /// Current nat traversal round. round: VarInt, /// The data of PATH_CHALLENGE frames sent in probes. @@ -275,7 +350,7 @@ impl ClientState { } fn add_local_address(&mut self, address: SocketAddr) -> Result<(), Error> { - let address = (address.ip().to_canonical(), address.port()); + let address = CanonicalIpPort::from(address); if self.local_addresses.len() < self.max_local_addresses { self.local_addresses.insert(address); Ok(()) @@ -289,7 +364,7 @@ impl ClientState { } fn remove_local_address(&mut self, address: &IpPort) { - let address = (address.0.to_canonical(), address.1); + let address = CanonicalIpPort::from(*address); self.local_addresses.remove(&address); } @@ -327,12 +402,12 @@ impl ClientState { // Enqueue the NAT probes to known remote addresses. self.remote_addresses .values_mut() - .for_each(|((ip, port), state)| { - if let Some(ip) = map_to_local_socket_family(*ip, ipv6) { - self.pending_probes.insert((ip, *port)); + .for_each(|(ip_port, state)| { + if let Some(ip_port) = ip_port.as_local_socket_family(ipv6) { + self.pending_probes.insert(ip_port); *state = ProbeState::Active(MAX_NAT_PROBE_ATTEMPTS - 1); } else { - trace!(?ip, "not using IPv6 NAT candidate for IPv4 socket"); + trace!(%ip_port, "not using IPv6 NAT candidate for IPv4 socket"); *state = ProbeState::Active(0); } }); @@ -347,10 +422,10 @@ impl ClientState { let reach_out_frames: PendingReachOutFrames = self .local_addresses .iter() - .map(|&(ip, port)| ReachOut { + .map(|ip_port| ReachOut { round: self.round, - ip, - port, + ip: ip_port.ip(), + port: ip_port.port(), }) .collect(); @@ -372,13 +447,13 @@ impl ClientState { pub(crate) fn queue_retries(&mut self, ipv6: bool) -> bool { self.remote_addresses .values_mut() - .for_each(|(addr, state)| match state { + .for_each(|(ip_port, state)| match state { ProbeState::Active(remaining) if *remaining > 0 => { *remaining -= 1; - if let Some(ip) = map_to_local_socket_family(addr.0, ipv6) { - self.pending_probes.insert((ip, addr.1)); + if let Some(ip_port) = ip_port.as_local_socket_family(ipv6) { + self.pending_probes.insert(ip_port); } else { - trace!(?addr, "skipping IPv6 NAT candidate for IPv4 socket"); + trace!(%ip_port, "skipping IPv6 NAT candidate for IPv4 socket"); *remaining = 0; } } @@ -390,8 +465,8 @@ impl ClientState { /// Returns the next ready probe's address. /// /// If this is actually sent you must call [`Self::mark_probe_sent`]. - fn next_probe_addr(&self) -> Option { - self.pending_probes.iter().next().map(|addr| (*addr).into()) + fn next_probe_addr(&self) -> Option { + self.pending_probes.iter().next().copied() } /// Marks a probe as sent to the address with the challenge. @@ -400,7 +475,7 @@ impl ClientState { self.sent_challenges.insert(challenge, remote); } - /// Adds an address to the remote set + /// Adds an address to the remote set. /// /// On success returns the address if it was new to the set. It will error when the set /// has no capacity for the address. @@ -416,7 +491,7 @@ impl ClientState { add_addr: AddAddress, ) -> Result, Error> { let AddAddress { seq_no, ip, port } = add_addr; - let address = (ip.to_canonical(), port); + let address = CanonicalIpPort::from((ip, port)); let allow_new = self.remote_addresses.len() < self.max_remote_addresses; match self.remote_addresses.entry(seq_no) { Entry::Occupied(mut occupied_entry) => { @@ -426,11 +501,11 @@ impl ClientState { } // The value might be different. This should not happen, but we assume that the new // address is more recent than the previous, and thus worth updating - Ok(is_update.then_some(address.into())) + Ok(is_update.then_some(address.as_canonical_addr())) } Entry::Vacant(vacant_entry) if allow_new => { vacant_entry.insert((address, ProbeState::Active(MAX_NAT_PROBE_ATTEMPTS))); - Ok(Some(address.into())) + Ok(Some(address.as_canonical_addr())) } _ => Err(Error::TooManyAddresses), } @@ -445,7 +520,7 @@ impl ClientState { ) -> Option { self.remote_addresses .remove(&remove_addr.seq_no) - .map(|(address, _)| address.into()) + .map(|(address, _)| address.as_canonical_addr()) } /// Checks that a received remote address is valid. @@ -454,14 +529,14 @@ impl ClientState { pub(crate) fn check_remote_address(&self, add_addr: &AddAddress) -> bool { match self.remote_addresses.get(&add_addr.seq_no) { None => true, - Some((existing, _)) => existing == &add_addr.ip_port(), + Some((existing, _)) => *existing == CanonicalIpPort::from(add_addr.ip_port()), } } pub(crate) fn get_remote_nat_traversal_addresses(&self) -> Vec { self.remote_addresses .values() - .map(|(address, _)| (*address).into()) + .map(|(address, _)| (*address).as_canonical_addr()) .collect() } @@ -476,7 +551,7 @@ impl ClientState { entry.remove(); // self.remote_addresses is stored in canonical form. - let remote = (remote.0.to_canonical(), remote.1); + let remote = CanonicalIpPort::from(remote); // TODO: linear search is sad. if let Some(seq) = self .remote_addresses @@ -505,7 +580,7 @@ impl ClientState { ?network_path.remote, expected_remote = ?entry.get(), challenge = %display(format_args!("0x{challenge:x}")), - "PATH_RESPONSE matched a NAT traversal probe but mismatching addr XXXX", + "PATH_RESPONSE matched a NAT traversal probe but mismatching addr", ) } } @@ -570,7 +645,7 @@ pub(crate) struct ServerState { /// /// They are stored in **canonical form**, not in socket-native form as usual. We may /// nave a reflexive address that is IPv6 even if our local socket can only handle IPv4. - local_addresses: FxHashMap, + local_addresses: FxHashMap, /// The next id to use for local addresses sent to the client. next_local_addr_id: VarInt, /// Current nat traversal round @@ -613,7 +688,7 @@ impl ServerState { } fn add_local_address(&mut self, address: SocketAddr) -> Result, Error> { - let address = (address.ip().to_canonical(), address.port()); + let address = CanonicalIpPort::from(address); let allow_new = self.local_addresses.len() < self.max_local_addresses; match self.local_addresses.entry(address) { Entry::Occupied(_) => Ok(None), @@ -621,14 +696,17 @@ impl ServerState { let id = self.next_local_addr_id; self.next_local_addr_id = self.next_local_addr_id.saturating_add(1u8); vacant_entry.insert(id); - Ok(Some(AddAddress::new(address, id))) + Ok(Some(AddAddress::new((address.ip(), address.port()), id))) } _ => Err(Error::TooManyAddresses), } } fn remove_local_address(&mut self, address: &IpPort) -> Option { - self.local_addresses.remove(address).map(RemoveAddress::new) + let address = CanonicalIpPort::from(*address); + self.local_addresses + .remove(&address) + .map(RemoveAddress::new) } /// Returns the current NAT traversal round number. @@ -696,8 +774,8 @@ impl ServerState { /// Returns the next ready probe's address. /// /// If this is actually sent you must call [`Self::mark_probe_sent`]. - fn next_probe_addr(&self) -> Option { - self.pending_probes.iter().next().map(|addr| (*addr).into()) + fn next_probe_addr(&self) -> Option { + self.pending_probes.iter().next().cloned() } /// Marks a probe as sent to the address with the challenge. @@ -782,7 +860,7 @@ mod tests { let mut send_probe = |state: &mut ServerState| { let remote = state.next_probe_addr().unwrap(); challenge += 1; - state.mark_probe_sent((remote.ip(), remote.port()), challenge); + state.mark_probe_sent(remote, challenge); }; send_probe(&mut state);