diff --git a/Cargo.lock b/Cargo.lock index c357841c0cb..d3b66a81de5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -658,15 +658,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "concurrent-queue" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" -dependencies = [ - "crossbeam-utils", -] - [[package]] name = "console" version = "0.15.11" @@ -2229,20 +2220,17 @@ version = "0.35.0" dependencies = [ "aead", "anyhow", - "atomic-waker", "axum", "backon", "bytes", "cfg_aliases", "clap", - "concurrent-queue", "crypto_box", "data-encoding", "der", "derive_more", "ed25519-dalek", "futures-buffered", - "futures-lite", "futures-util", "getrandom 0.3.2", "hickory-resolver", @@ -2258,9 +2246,10 @@ dependencies = [ "iroh-relay", "n0-future", "n0-snafu", + "n0-watcher", "nested_enum_utils", "netdev", - "netwatch", + "netwatch 0.6.0", "parse-size", "pin-project", "pkarr", @@ -2331,6 +2320,7 @@ dependencies = [ "iroh-quinn", "n0-future", "n0-snafu", + "n0-watcher", "rand 0.8.5", "rcgen", "rustls", @@ -2423,9 +2413,9 @@ dependencies = [ [[package]] name = "iroh-quinn" -version = "0.13.0" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76c6245c9ed906506ab9185e8d7f64857129aee4f935e899f398a3bd3b70338d" +checksum = "0cde160ebee7aabede6ae887460cd303c8b809054224815addf1469d54a6fcf7" dependencies = [ "bytes", "cfg_aliases", @@ -2825,9 +2815,9 @@ dependencies = [ [[package]] name = "n0-snafu" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6726cb23c307232453032bca1c4fba5a5c60dc36775211e134ab304ab510b548" +checksum = "c4fed465ff57041f29db78a9adc8864296ef93c6c16029f9e192dc303404ebd0" dependencies = [ "anyhow", "btparse", @@ -2836,6 +2826,17 @@ dependencies = [ "tracing-error", ] +[[package]] +name = "n0-watcher" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f216d4ebc5fcf9548244803cbb93f488a2ae160feba3706cd17040d69cf7a368" +dependencies = [ + "derive_more", + "n0-future", + "snafu", +] + [[package]] name = "nested_enum_utils" version = "0.2.2" @@ -2977,6 +2978,41 @@ dependencies = [ "wmi", ] +[[package]] +name = "netwatch" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a829a830199b14989f9bccce6136ab928ab48336ab1f8b9002495dbbbb2edbe" +dependencies = [ + "atomic-waker", + "bytes", + "cfg_aliases", + "derive_more", + "iroh-quinn-udp", + "js-sys", + "libc", + "n0-future", + "n0-watcher", + "nested_enum_utils", + "netdev", + "netlink-packet-core", + "netlink-packet-route 0.23.0", + "netlink-proto", + "netlink-sys", + "pin-project-lite", + "serde", + "snafu", + "socket2", + "time", + "tokio", + "tokio-util", + "tracing", + "web-sys", + "windows 0.59.0", + "windows-result 0.3.2", + "wmi", +] + [[package]] name = "no-std-compat" version = "0.4.1" @@ -3456,7 +3492,7 @@ dependencies = [ "iroh-metrics", "libc", "nested_enum_utils", - "netwatch", + "netwatch 0.5.0", "num_enum", "rand 0.8.5", "serde", @@ -4451,9 +4487,9 @@ checksum = "fad6c857cbab2627dcf01ec85a623ca4e7dcb5691cbaa3d7fb7653671f0d09c9" [[package]] name = "snafu" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "223891c85e2a29c3fe8fb900c1fae5e69c2e42415e3177752e8718475efa5019" +checksum = "320b01e011bf8d5d7a4a4a4be966d9160968935849c83b918827f6a435e7f627" dependencies = [ "backtrace", "snafu-derive", @@ -4461,9 +4497,9 @@ dependencies = [ [[package]] name = "snafu-derive" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03c3c6b7927ffe7ecaa769ee0e3994da3b8cafc8f444578982c83ecb161af917" +checksum = "1961e2ef424c1424204d3a5d6975f934f56b6d50ff5732382d84ebf460e147f7" dependencies = [ "heck", "proc-macro2", diff --git a/iroh-relay/Cargo.toml b/iroh-relay/Cargo.toml index a3681693607..4b66b094d5b 100644 --- a/iroh-relay/Cargo.toml +++ b/iroh-relay/Cargo.toml @@ -42,7 +42,7 @@ postcard = { version = "1", default-features = false, features = [ "use-std", "experimental-derive", ] } -quinn = { package = "iroh-quinn", version = "0.13.0", default-features = false, features = ["rustls-ring"] } +quinn = { package = "iroh-quinn", version = "0.14.0", default-features = false, features = ["rustls-ring"] } quinn-proto = { package = "iroh-quinn-proto", version = "0.13.0" } rand = "0.8" reqwest = { version = "0.12", default-features = false, features = [ diff --git a/iroh/Cargo.toml b/iroh/Cargo.toml index 858c601de31..650e6d40229 100644 --- a/iroh/Cargo.toml +++ b/iroh/Cargo.toml @@ -23,8 +23,6 @@ workspace = true [dependencies] anyhow = "1.0.98" aead = { version = "0.5.2", features = ["bytes", "std"] } -atomic-waker = "1.1.2" -concurrent-queue = "2.5" backon = { version = "1.4" } bytes = "1.7" crypto_box = { version = "0.9.1", features = ["serde", "chacha20"] } @@ -44,13 +42,14 @@ iroh-base = { version = "0.35.0", default-features = false, features = ["key", " iroh-relay = { version = "0.35", path = "../iroh-relay", default-features = false } n0-future = "0.1.2" n0-snafu = "0.2.0" +n0-watcher = "0.2" nested_enum_utils = "0.2.1" -netwatch = { version = "0.5" } +netwatch = { version = "0.6" } pin-project = "1" pkarr = { version = "3.7", default-features = false, features = [ "relays", ] } -quinn = { package = "iroh-quinn", version = "0.13.0", default-features = false, features = ["rustls-ring"] } +quinn = { package = "iroh-quinn", version = "0.14.0", default-features = false, features = ["rustls-ring"] } quinn-proto = { package = "iroh-quinn-proto", version = "0.13.0" } quinn-udp = { package = "iroh-quinn-udp", version = "0.5.7" } rand = "0.8" @@ -104,13 +103,14 @@ tracing-subscriber = { version = "0.3", features = [ indicatif = { version = "0.17", features = ["tokio"], optional = true } parse-size = { version = "=1.0.0", optional = true, features = ['std'] } # pinned version to avoid bumping msrv to 1.81 + # non-wasm-in-browser dependencies [target.'cfg(not(all(target_family = "wasm", target_os = "unknown")))'.dependencies] hickory-resolver = "0.25.1" igd-next = { version = "0.16", features = ["aio_tokio"] } netdev = { version = "0.31.0" } portmapper = { version = "0.5.0", default-features = false } -quinn = { package = "iroh-quinn", version = "0.13.0", default-features = false, features = ["runtime-tokio", "rustls-ring"] } +quinn = { package = "iroh-quinn", version = "0.14.0", default-features = false, features = ["runtime-tokio", "rustls-ring"] } tokio = { version = "1", features = [ "io-util", "macros", @@ -141,7 +141,6 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } [target.'cfg(not(all(target_family = "wasm", target_os = "unknown")))'.dev-dependencies] axum = { version = "0.8" } clap = { version = "4", features = ["derive"] } -futures-lite = "2.6" pretty_assertions = "1.4" rand_chacha = "0.3.1" tokio = { version = "1", features = [ diff --git a/iroh/bench/Cargo.toml b/iroh/bench/Cargo.toml index 2b6a01d55e8..6455ed7f506 100644 --- a/iroh/bench/Cargo.toml +++ b/iroh/bench/Cargo.toml @@ -12,7 +12,8 @@ iroh = { path = ".." } iroh-metrics = "0.34" n0-future = "0.1.1" n0-snafu = "0.2.0" -quinn = { package = "iroh-quinn", version = "0.13" } +n0-watcher = "0.2" +quinn = { package = "iroh-quinn", version = "0.14" } rand = "0.8" rcgen = "0.13" rustls = { version = "0.23", default-features = false, features = ["ring"] } diff --git a/iroh/bench/src/iroh.rs b/iroh/bench/src/iroh.rs index 24de8e518cd..bd1233d7f10 100644 --- a/iroh/bench/src/iroh.rs +++ b/iroh/bench/src/iroh.rs @@ -6,10 +6,10 @@ use std::{ use bytes::Bytes; use iroh::{ endpoint::{Connection, ConnectionError, RecvStream, SendStream, TransportConfig}, - watcher::Watcher as _, Endpoint, NodeAddr, RelayMode, RelayUrl, }; use n0_snafu::{Result, ResultExt}; +use n0_watcher::Watcher as _; use tracing::{trace, warn}; use crate::{ @@ -54,7 +54,7 @@ pub fn server_endpoint( } let addr = ep.bound_sockets(); - let addr = SocketAddr::new("127.0.0.1".parse().unwrap(), addr.0.port()); + let addr = SocketAddr::new("127.0.0.1".parse().unwrap(), addr[0].port()); let mut addr = NodeAddr::new(ep.node_id()).with_direct_addresses([addr]); if let Some(relay_url) = relay_url { addr = addr.with_relay_url(relay_url.clone()); diff --git a/iroh/examples/0rtt.rs b/iroh/examples/0rtt.rs index 945abdaf1b1..73be802a62d 100644 --- a/iroh/examples/0rtt.rs +++ b/iroh/examples/0rtt.rs @@ -3,12 +3,12 @@ use std::{env, future::Future, str::FromStr, time::Instant}; use clap::Parser; use iroh::{ endpoint::{Connecting, Connection}, - watcher::Watcher, SecretKey, }; use iroh_base::ticket::NodeTicket; use n0_future::{future, StreamExt}; use n0_snafu::ResultExt; +use n0_watcher::Watcher; use rand::thread_rng; use tracing::{info, trace}; @@ -150,6 +150,7 @@ async fn accept(_args: Args) -> n0_snafu::Result<()> { println!("Listening on: {:?}", addr); println!("Node ID: {:?}", addr.node_id); println!("Ticket: {}", NodeTicket::from(addr)); + let accept = async move { while let Some(incoming) = endpoint.accept().await { tokio::spawn(async move { diff --git a/iroh/examples/connect-unreliable.rs b/iroh/examples/connect-unreliable.rs index bc40be070c6..4a8933a98b4 100644 --- a/iroh/examples/connect-unreliable.rs +++ b/iroh/examples/connect-unreliable.rs @@ -8,8 +8,9 @@ use std::net::SocketAddr; use clap::Parser; -use iroh::{watcher::Watcher as _, Endpoint, NodeAddr, RelayMode, RelayUrl, SecretKey}; +use iroh::{Endpoint, NodeAddr, RelayMode, RelayUrl, SecretKey}; use n0_snafu::ResultExt; +use n0_watcher::Watcher as _; use tracing::info; // An example ALPN that we are using to communicate over the `Endpoint` diff --git a/iroh/examples/connect.rs b/iroh/examples/connect.rs index faa676cd0a2..864b2ae9bbb 100644 --- a/iroh/examples/connect.rs +++ b/iroh/examples/connect.rs @@ -8,8 +8,9 @@ use std::net::SocketAddr; use clap::Parser; -use iroh::{watcher::Watcher as _, Endpoint, NodeAddr, RelayMode, RelayUrl, SecretKey}; +use iroh::{Endpoint, NodeAddr, RelayMode, RelayUrl, SecretKey}; use n0_snafu::{Result, ResultExt}; +use n0_watcher::Watcher as _; use tracing::info; // An example ALPN that we are using to communicate over the `Endpoint` @@ -67,6 +68,8 @@ async fn main() -> Result<()> { .home_relay() .get() .unwrap() + .first() + .cloned() .expect("should be connected to a relay server, try calling `endpoint.local_endpoints()` or `endpoint.connect()` first, to ensure the endpoint has actually attempted a connection before checking for the connected relay server"); println!("node relay server url: {relay_url}\n"); // Build a `NodeAddr` from the node_id, relay url, and UDP addresses. diff --git a/iroh/examples/echo-no-router.rs b/iroh/examples/echo-no-router.rs index 3b9dc50bc5d..4026e32a7c7 100644 --- a/iroh/examples/echo-no-router.rs +++ b/iroh/examples/echo-no-router.rs @@ -7,8 +7,9 @@ //! //! cargo run --example echo-no-router --features=examples -use iroh::{watcher::Watcher as _, Endpoint, NodeAddr}; +use iroh::{Endpoint, NodeAddr}; use n0_snafu::{Error, Result, ResultExt}; +use n0_watcher::Watcher as _; /// Each protocol is identified by its ALPN string. /// diff --git a/iroh/examples/echo.rs b/iroh/examples/echo.rs index 24684ee4f67..2a1cb7b6d23 100644 --- a/iroh/examples/echo.rs +++ b/iroh/examples/echo.rs @@ -9,10 +9,10 @@ use iroh::{ endpoint::Connection, protocol::{AcceptError, ProtocolHandler, Router}, - watcher::Watcher as _, Endpoint, NodeAddr, }; use n0_snafu::{Result, ResultExt}; +use n0_watcher::Watcher as _; /// Each protocol is identified by its ALPN string. /// diff --git a/iroh/examples/listen-unreliable.rs b/iroh/examples/listen-unreliable.rs index 586fdded8a2..8a4c7e3cc38 100644 --- a/iroh/examples/listen-unreliable.rs +++ b/iroh/examples/listen-unreliable.rs @@ -3,8 +3,9 @@ //! This example uses the default relay servers to attempt to holepunch, and will use that relay server to relay packets if the two devices cannot establish a direct UDP connection. //! run this example from the project root: //! $ cargo run --example listen-unreliable -use iroh::{watcher::Watcher as _, Endpoint, RelayMode, SecretKey}; +use iroh::{Endpoint, RelayMode, SecretKey}; use n0_snafu::{Error, Result, ResultExt}; +use n0_watcher::Watcher as _; use tracing::{info, warn}; // An example ALPN that we are using to communicate over the `Endpoint` diff --git a/iroh/examples/listen.rs b/iroh/examples/listen.rs index ac1690c4cf5..a78fb008403 100644 --- a/iroh/examples/listen.rs +++ b/iroh/examples/listen.rs @@ -5,8 +5,9 @@ //! $ cargo run --example listen use std::time::Duration; -use iroh::{endpoint::ConnectionError, watcher::Watcher as _, Endpoint, RelayMode, SecretKey}; +use iroh::{endpoint::ConnectionError, Endpoint, RelayMode, SecretKey}; use n0_snafu::ResultExt; +use n0_watcher::Watcher as _; use tracing::{debug, info, warn}; // An example ALPN that we are using to communicate over the `Endpoint` diff --git a/iroh/examples/transfer.rs b/iroh/examples/transfer.rs index 21ecad77ce0..371f717ce61 100644 --- a/iroh/examples/transfer.rs +++ b/iroh/examples/transfer.rs @@ -13,12 +13,12 @@ use iroh::{ }, dns::{DnsResolver, N0_DNS_NODE_ORIGIN_PROD, N0_DNS_NODE_ORIGIN_STAGING}, endpoint::ConnectionError, - watcher::Watcher as _, Endpoint, NodeAddr, NodeId, RelayMap, RelayMode, RelayUrl, SecretKey, }; use iroh_base::ticket::NodeTicket; use n0_future::task::AbortOnDropHandle; use n0_snafu::{Result, ResultExt}; +use n0_watcher::Watcher as _; use tokio_stream::StreamExt; use tracing::{info, warn}; use url::Url; @@ -253,6 +253,7 @@ impl EndpointArgs { let relay_url = endpoint .home_relay() .get()? + .pop() .context("Failed to resolve our home relay")?; println!("Our home relay server:\n\t{relay_url}"); } diff --git a/iroh/src/disco.rs b/iroh/src/disco.rs index f33bf2e2e29..fc47f175702 100644 --- a/iroh/src/disco.rs +++ b/iroh/src/disco.rs @@ -30,6 +30,8 @@ use serde::{Deserialize, Serialize}; use snafu::{ensure, Snafu}; use url::Url; +use crate::magicsock::transports; + // TODO: custom magicn /// The 6 byte header of all discovery messages. pub const MAGIC: &str = "TS💬"; // 6 bytes: 0x54 53 f0 9f 92 ac @@ -158,6 +160,15 @@ impl SendAddr { } } +impl From for SendAddr { + fn from(addr: transports::Addr) -> Self { + match addr { + transports::Addr::Ip(addr) => SendAddr::Udp(addr), + transports::Addr::Relay(url, _) => SendAddr::Relay(url), + } + } +} + impl From for SendAddr { fn from(source: SocketAddr) -> Self { SendAddr::Udp(source) diff --git a/iroh/src/discovery.rs b/iroh/src/discovery.rs index 2749d3b4954..62519955729 100644 --- a/iroh/src/discovery.rs +++ b/iroh/src/discovery.rs @@ -605,13 +605,14 @@ mod tests { use iroh_base::{NodeAddr, SecretKey}; use n0_snafu::{Error, Result, ResultExt}; + use n0_watcher::Watcher as _; use quinn::{IdleTimeout, TransportConfig}; use rand::Rng; use tokio_util::task::AbortOnDropHandle; use tracing_test::traced_test; use super::*; - use crate::{endpoint::ConnectOptions, watcher::Watcher as _, RelayMode}; + use crate::{endpoint::ConnectOptions, RelayMode}; type InfoStore = HashMap; @@ -1095,8 +1096,7 @@ mod test_dns_pkarr { .context("wait for on node update")?; // we connect only by node id! - let res = ep2.connect(ep1.node_id(), TEST_ALPN).await; - assert!(res.is_ok(), "connection established"); + let _conn = ep2.connect(ep1.node_id(), TEST_ALPN).await?; Ok(()) } diff --git a/iroh/src/discovery/mdns.rs b/iroh/src/discovery/mdns.rs index 7836f0dba14..da3c29396d0 100644 --- a/iroh/src/discovery/mdns.rs +++ b/iroh/src/discovery/mdns.rs @@ -42,6 +42,7 @@ use n0_future::{ task::{self, AbortOnDropHandle, JoinSet}, time::{self, Duration}, }; +use n0_watcher::{Watchable, Watcher as _}; use swarm_discovery::{Discoverer, DropGuard, IpClass, Peer}; use tokio::sync::mpsc::{self, error::TrySendError}; use tracing::{debug, error, info_span, trace, warn, Instrument}; @@ -49,7 +50,6 @@ use tracing::{debug, error, info_span, trace, warn, Instrument}; use super::DiscoveryError; use crate::{ discovery::{Discovery, DiscoveryItem, NodeData, NodeInfo}, - watcher::{Watchable, Watcher as _}, Endpoint, }; diff --git a/iroh/src/discovery/pkarr.rs b/iroh/src/discovery/pkarr.rs index 58af87afb3c..5e32d6ad10c 100644 --- a/iroh/src/discovery/pkarr.rs +++ b/iroh/src/discovery/pkarr.rs @@ -53,6 +53,7 @@ use n0_future::{ task::{self, AbortOnDropHandle}, time::{self, Duration, Instant}, }; +use n0_watcher::{Disconnected, Watchable, Watcher as _}; use pkarr::{ errors::{PublicKeyError, SignedPacketVerifyError}, SignedPacket, @@ -65,7 +66,6 @@ use super::DiscoveryError; use crate::{ discovery::{Discovery, DiscoveryItem, NodeData}, endpoint::force_staging_infra, - watcher::{self, Disconnected, Watchable, Watcher as _}, Endpoint, }; @@ -247,7 +247,7 @@ struct PublisherService { secret_key: SecretKey, #[debug("PkarrClient")] pkarr_client: PkarrRelayClient, - watcher: watcher::Direct>, + watcher: n0_watcher::Direct>, ttl: u32, republish_interval: Duration, } diff --git a/iroh/src/endpoint.rs b/iroh/src/endpoint.rs index 46b5215f0be..76354bcbc3c 100644 --- a/iroh/src/endpoint.rs +++ b/iroh/src/endpoint.rs @@ -25,6 +25,7 @@ use ed25519_dalek::{pkcs8::DecodePublicKey, VerifyingKey}; use iroh_base::{NodeAddr, NodeId, RelayUrl, SecretKey}; use iroh_relay::RelayMap; use n0_future::{time::Duration, Stream}; +use n0_watcher::Watcher; use nested_enum_utils::common_fields; use pin_project::pin_project; use snafu::{ensure, ResultExt, Snafu}; @@ -43,9 +44,7 @@ use crate::{ magicsock::{self, Handle, NodeIdMappedAddr, OwnAddressSnafu}, metrics::EndpointMetrics, net_report::Report, - tls, - watcher::{self, Watcher}, - RelayProtocol, + tls, RelayProtocol, }; mod rtt_actor; @@ -82,29 +81,6 @@ const DISCOVERY_WAIT_PERIOD: Duration = Duration::from_millis(500); type DiscoveryBuilder = Box Option> + Send + Sync>; -/// A type alias for the return value of [`Endpoint::node_addr`]. -/// -/// This type implements [`Watcher`] with `Value` being an optional [`NodeAddr`]. -/// -/// We return a named type instead of `impl Watcher`, as this allows -/// you to e.g. store the watcher in a struct. -#[cfg(not(wasm_browser))] -pub type NodeAddrWatcher = watcher::Map< - ( - watcher::Direct>>, - watcher::Direct>, - ), - Option, ->; -/// A type alias for the return value of [`Endpoint::node_addr`]. -/// -/// This type implements [`Watcher`] with `Value` being an optional [`NodeAddr`]. -/// -/// We return a named type instead of `impl Watcher`, as this allows -/// you to e.g. store the watcher in a struct. -#[cfg(wasm_browser)] -pub type NodeAddrWatcher = watcher::Map>, Option>; - /// Defines the mode of path selection for all traffic flowing through /// the endpoint. #[cfg(any(test, feature = "test-utils"))] @@ -689,9 +665,10 @@ impl Endpoint { trace!("created magicsock"); debug!(version = env!("CARGO_PKG_VERSION"), "iroh Endpoint created"); + let metrics = msock.metrics.magicsock.clone(); let ep = Self { - msock: msock.clone(), - rtt_actor: Arc::new(rtt_actor::RttHandle::new(msock.metrics.magicsock.clone())), + msock, + rtt_actor: Arc::new(rtt_actor::RttHandle::new(metrics)), static_config: Arc::new(static_config), }; Ok(ep) @@ -950,7 +927,8 @@ impl Endpoint { /// /// ```no_run /// # async fn wrapper() -> n0_snafu::Result { - /// use iroh::{watcher::Watcher, Endpoint}; + /// use iroh::Endpoint; + /// use n0_watcher::Watcher; /// /// let endpoint = Endpoint::builder() /// .alpns(vec![b"my-alpn".to_vec()]) @@ -962,25 +940,22 @@ impl Endpoint { /// # } /// ``` #[cfg(not(wasm_browser))] - pub fn node_addr(&self) -> NodeAddrWatcher { + pub fn node_addr(&self) -> impl n0_watcher::Watcher> { let watch_addrs = self.direct_addresses(); let watch_relay = self.home_relay(); let node_id = self.node_id(); watch_addrs .or(watch_relay) - .map(move |(addrs, relay)| match (addrs, relay) { - (Some(addrs), relay) => Some(NodeAddr::from_parts( + .map(move |(addrs, mut relays)| match addrs { + Some(addrs) => Some(NodeAddr::from_parts( node_id, - relay, + relays.pop(), addrs.into_iter().map(|x| x.addr), )), - (None, Some(relay)) => Some(NodeAddr::from_parts( - node_id, - Some(relay), - std::iter::empty(), - )), - (None, None) => None, + None => relays.pop().map(|relay_url| { + NodeAddr::from_parts(node_id, Some(relay_url), std::iter::empty()) + }), }) .expect("watchable is alive - cannot be disconnected yet") } @@ -991,15 +966,17 @@ impl Endpoint { /// with a [`NodeAddr`] that only contains a relay URL, but no direct addresses, /// as there are no APIs for directly using sockets in browsers. #[cfg(wasm_browser)] - pub fn node_addr(&self) -> NodeAddrWatcher { + pub fn node_addr(&self) -> impl n0_watcher::Watcher> + '_ { // In browsers, there will never be any direct addresses, so we wait // for the home relay instead. This makes the `NodeAddr` have *some* way // of connecting to us. let watch_relay = self.home_relay(); let node_id = self.node_id(); watch_relay - .map(move |relay| { - relay.map(|relay| NodeAddr::from_parts(node_id, Some(relay), std::iter::empty())) + .map(move |mut relays| { + relays.pop().map(|relay_url| { + NodeAddr::from_parts(node_id, Some(relay_url), std::iter::empty()) + }) }) .expect("watchable is alive - cannot be disconnected yet") } @@ -1020,8 +997,9 @@ impl Endpoint { /// /// To wait for a home relay connection to be established, use [`Watcher::initialized`]: /// ```no_run - /// use futures_lite::StreamExt; - /// use iroh::{watcher::Watcher, Endpoint}; + /// use iroh::Endpoint; + /// use n0_future::StreamExt; + /// use n0_watcher::Watcher as _; /// /// # let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap(); /// # rt.block_on(async move { @@ -1029,7 +1007,7 @@ impl Endpoint { /// let _relay_url = mep.home_relay().initialized().await.unwrap(); /// # }); /// ``` - pub fn home_relay(&self) -> watcher::Direct> { + pub fn home_relay(&self) -> impl n0_watcher::Watcher> { self.msock.home_relay() } @@ -1056,8 +1034,9 @@ impl Endpoint { /// /// To get the first set of direct addresses use [`Watcher::initialized`]: /// ```no_run - /// use futures_lite::StreamExt; - /// use iroh::{watcher::Watcher, Endpoint}; + /// use iroh::Endpoint; + /// use n0_future::StreamExt; + /// use n0_watcher::Watcher as _; /// /// # let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap(); /// # rt.block_on(async move { @@ -1067,7 +1046,7 @@ impl Endpoint { /// ``` /// /// [STUN]: https://en.wikipedia.org/wiki/STUN - pub fn direct_addresses(&self) -> watcher::Direct>> { + pub fn direct_addresses(&self) -> n0_watcher::Direct>> { self.msock.direct_addresses() } @@ -1091,8 +1070,9 @@ impl Endpoint { /// /// To get the first report use [`Watcher::initialized`]: /// ```no_run - /// use futures_lite::StreamExt; - /// use iroh::{watcher::Watcher, Endpoint}; + /// use iroh::Endpoint; + /// use n0_future::StreamExt; + /// use n0_watcher::Watcher as _; /// /// # let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap(); /// # rt.block_on(async move { @@ -1101,7 +1081,7 @@ impl Endpoint { /// # }); /// ``` #[doc(hidden)] - pub fn net_report(&self) -> watcher::Direct>> { + pub fn net_report(&self) -> n0_watcher::Direct>> { self.msock.net_report() } @@ -1109,9 +1089,12 @@ impl Endpoint { /// /// The [`Endpoint`] always binds on an IPv4 address and also tries to bind on an IPv6 /// address if available. - #[cfg(not(wasm_browser))] - pub fn bound_sockets(&self) -> (SocketAddr, Option) { - self.msock.local_addr() + pub fn bound_sockets(&self) -> Vec { + self.msock + .local_addr() + .into_iter() + .filter_map(|addr| addr.into_socket_addr()) + .collect() } // # Getter methods for information about other nodes. @@ -1199,10 +1182,8 @@ impl Endpoint { /// recently connected to this node id but previous methods of reaching the node have /// become inaccessible. /// - /// # Errors - /// /// Will return `None` if we do not have any address information for the given `node_id`. - pub fn conn_type(&self, node_id: NodeId) -> Option> { + pub fn conn_type(&self, node_id: NodeId) -> Option> { self.msock.conn_type(node_id) } @@ -2310,6 +2291,7 @@ mod tests { use iroh_relay::http::Protocol; use n0_future::{task::AbortOnDropHandle, StreamExt}; use n0_snafu::{Error, Result, ResultExt}; + use n0_watcher::Watcher; use quinn::ConnectionError; use rand::SeedableRng; use tracing::{error_span, info, info_span, Instrument}; @@ -2319,7 +2301,6 @@ mod tests { use crate::{ endpoint::{ConnectOptions, Connection, ConnectionType, RemoteInfo}, test_utils::{run_relay_server, run_relay_server_with}, - watcher::Watcher, RelayMode, }; @@ -2519,7 +2500,8 @@ mod tests { .bind() .await?; let eps = ep.bound_sockets(); - info!(me = %ep.node_id().fmt_short(), ipv4=%eps.0, ipv6=?eps.1, "server listening on"); + + info!(me = %ep.node_id().fmt_short(), eps = ?eps, "server listening on"); for i in 0..n_clients { let round_start = Instant::now(); info!("[server] round {i}"); @@ -2561,7 +2543,8 @@ mod tests { .bind() .await?; let eps = ep.bound_sockets(); - info!(me = %ep.node_id().fmt_short(), ipv4=%eps.0, ipv6=?eps.1, "client bound"); + + info!(me = %ep.node_id().fmt_short(), eps=?eps, "client bound"); let node_addr = NodeAddr::new(server_node_id).with_relay_url(relay_url); info!(to = ?node_addr, "client connecting"); let conn = ep.connect(node_addr, TEST_ALPN).await.e()?; diff --git a/iroh/src/endpoint/rtt_actor.rs b/iroh/src/endpoint/rtt_actor.rs index f86ebeb9bad..fa0b6847c25 100644 --- a/iroh/src/endpoint/rtt_actor.rs +++ b/iroh/src/endpoint/rtt_actor.rs @@ -10,7 +10,7 @@ use n0_future::{ use tokio::sync::mpsc; use tracing::{debug, info_span, Instrument}; -use crate::{magicsock::ConnectionType, metrics::MagicsockMetrics, watcher}; +use crate::{magicsock::ConnectionType, metrics::MagicsockMetrics}; #[derive(Debug)] pub(super) struct RttHandle { @@ -47,7 +47,7 @@ pub(super) enum RttMessage { /// The connection. connection: quinn::WeakConnectionHandle, /// Path changes for this connection from the magic socket. - conn_type_changes: watcher::Stream>, + conn_type_changes: n0_watcher::Stream>, /// For reporting-only, the Node ID of this connection. node_id: NodeId, }, @@ -67,7 +67,7 @@ struct RttActor { #[derive(Debug)] struct MappedStream { - stream: watcher::Stream>, + stream: n0_watcher::Stream>, node_id: NodeId, /// Reference to the connection. connection: quinn::WeakConnectionHandle, @@ -157,7 +157,7 @@ impl RttActor { fn handle_new_connection( &mut self, connection: quinn::WeakConnectionHandle, - conn_type_changes: watcher::Stream>, + conn_type_changes: n0_watcher::Stream>, node_id: NodeId, ) { self.connection_events.push(MappedStream { diff --git a/iroh/src/lib.rs b/iroh/src/lib.rs index 2f44c3420e6..982c412d5f6 100644 --- a/iroh/src/lib.rs +++ b/iroh/src/lib.rs @@ -195,8 +195,8 @@ //! Every [`Endpoint`] can also accept connections: //! //! ```no_run -//! use futures_lite::StreamExt; //! use iroh::{Endpoint, NodeAddr}; +//! use n0_future::StreamExt; //! use n0_snafu::{Result, ResultExt}; //! //! async fn accept() -> Result<()> { @@ -269,7 +269,6 @@ pub mod endpoint; pub mod metrics; pub mod net_report; pub mod protocol; -pub mod watcher; pub use endpoint::{Endpoint, RelayMode}; pub use iroh_base::{ diff --git a/iroh/src/magicsock.rs b/iroh/src/magicsock.rs index 6bb68259829..c73cd457bca 100644 --- a/iroh/src/magicsock.rs +++ b/iroh/src/magicsock.rs @@ -22,15 +22,13 @@ use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, pin::Pin, sync::{ - atomic::{AtomicBool, AtomicU16, AtomicU64, AtomicUsize, Ordering}, + atomic::{AtomicBool, AtomicU64, Ordering}, Arc, RwLock, }, - task::{Context, Poll, Waker}, + task::{Context, Poll}, }; -use atomic_waker::AtomicWaker; use bytes::Bytes; -use concurrent_queue::ConcurrentQueue; use data_encoding::HEXLOWER; use iroh_base::{NodeAddr, NodeId, PublicKey, RelayUrl, SecretKey}; use iroh_relay::{protos::stun, RelayMap}; @@ -38,34 +36,31 @@ use n0_future::{ boxed::BoxStream, task::{self, JoinSet}, time::{self, Duration, Instant}, - FutureExt, StreamExt, + StreamExt, }; +use n0_watcher::{self, Watchable, Watcher}; use nested_enum_utils::common_fields; -use netwatch::{ - interfaces, - netmon::{self, CallbackToken}, -}; +use netwatch::netmon; #[cfg(not(wasm_browser))] use netwatch::{ip::LocalAddresses, UdpSocket}; use quinn::{AsyncUdpSocket, ServerConfig}; use rand::{seq::SliceRandom, Rng, SeedableRng}; -use relay_actor::RelaySendItem; -use smallvec::{smallvec, SmallVec}; +use smallvec::SmallVec; use snafu::{IntoError, ResultExt, Snafu}; use tokio::sync::{self, mpsc, Mutex}; -use tokio_util::sync::CancellationToken; use tracing::{ debug, error, error_span, event, info, info_span, instrument, trace, trace_span, warn, Instrument, Level, Span, }; +use transports::LocalAddrsWatch; use url::Url; #[cfg(not(wasm_browser))] -use self::udp_conn::UdpConn; +use self::transports::IpTransport; use self::{ metrics::Metrics as MagicsockMetrics, node_map::{NodeMap, PingAction, PingRole, SendPing}, - relay_actor::{RelayActor, RelayActorMessage, RelayRecvDatagram}, + transports::{RelayActorConfig, RelayTransport, Transports, UdpSender}, }; #[cfg(not(wasm_browser))] use crate::dns::DnsResolver; @@ -75,19 +70,17 @@ use crate::endpoint::PathSelection; use crate::net_report::{IpMappedAddr, QuicConfig}; use crate::{ defaults::timeouts::NET_REPORT_TIMEOUT, - disco::{self, CallMeMaybe, SendAddr}, + disco::{self, SendAddr}, discovery::{Discovery, DiscoveryItem, DiscoverySubscribers, NodeData, UserData}, key::{public_ed_box, secret_ed_box, DecryptionError, SharedSecret}, metrics::EndpointMetrics, net_report::{self, IpMappedAddresses, Report, ReportError}, - watcher::{self, Watchable}, }; mod metrics; mod node_map; -mod relay_actor; -#[cfg(not(wasm_browser))] -mod udp_conn; + +pub(crate) mod transports; pub use node_map::Source; @@ -190,18 +183,6 @@ pub(crate) struct MagicSock { actor_sender: mpsc::Sender, /// String representation of the node_id of this node. me: String, - /// Proxy - proxy_url: Option, - /// Queue to receive datagrams from relays for [`AsyncUdpSocket::poll_recv`]. - /// - /// Relay datagrams received by relays are put into this queue and consumed by - /// [`AsyncUdpSocket`]. This queue takes care of the wakers needed by - /// [`AsyncUdpSocket::poll_recv`]. - relay_datagram_recv_queue: Arc, - /// Channel on which to send datagrams via a relay server. - relay_datagram_send_channel: RelayDatagramSendChannelSender, - /// Counter for ordering of [`MagicSock::poll_recv`] polling order. - poll_recv_counter: AtomicUsize, /// The DNS resolver to be used in this magicsock. #[cfg(not(wasm_browser))] @@ -212,10 +193,6 @@ pub(crate) struct MagicSock { /// Encryption key for this node. secret_encryption_key: crypto_box::SecretKey, - /// Sockets & related state - #[cfg(not(wasm_browser))] - sockets: SocketState, - /// Close is in progress (or done) closing: AtomicBool, /// Close was called. @@ -223,10 +200,8 @@ pub(crate) struct MagicSock { /// If the last net_report report, reports IPv6 to be available. ipv6_reported: Arc, - /// None (or zero nodes) means relay is disabled. + /// Zero nodes means relay is disabled. relay_map: RelayMap, - /// Nearest relay node ID; 0 means none/unknown. - my_relay: Watchable>, /// Tracks the networkmap node entity for each node discovery key. node_map: NodeMap, /// Tracks the mapped IP addresses @@ -236,8 +211,8 @@ pub(crate) struct MagicSock { /// The state for an active DiscoKey. disco_secrets: DiscoSecrets, - /// UDP disco (ping) queue - udp_disco_sender: mpsc::Sender<(SocketAddr, PublicKey, disco::Message)>, + /// Disco (ping) queue + disco_sender: mpsc::Sender<(SendAddr, PublicKey, disco::Message)>, /// Optional discovery service discovery: Option>, @@ -258,30 +233,14 @@ pub(crate) struct MagicSock { /// Indicates the direct addr update state. direct_addr_update_state: DirectAddrUpdateState, - /// Skip verification of SSL certificates from relay servers - /// - /// May only be used in tests. - #[cfg(any(test, feature = "test-utils"))] - insecure_skip_relay_cert_verify: bool, - /// Broadcast channel for listening to discovery updates. discovery_subscribers: DiscoverySubscribers, pub(crate) metrics: EndpointMetrics, -} -/// Sockets and related state, grouped together so we can cfg them out for browsers. -#[cfg(not(wasm_browser))] -#[derive(Debug)] -pub(crate) struct SocketState { - /// Port configured for the ipv4 socket. Can be 0 - port: AtomicU16, - /// UDP IPv4 socket - v4: UdpConn, - /// UDP IPv6 socket - v6: Option, - /// Cached version of the Ipv4 and Ipv6 addrs of the current connection. - local_addrs: std::sync::RwLock<(SocketAddr, Option)>, + local_addrs_watch: LocalAddrsWatch, + #[cfg(not(wasm_browser))] + ip_bind_addrs: Vec, } #[allow(missing_docs)] @@ -312,19 +271,13 @@ impl MagicSock { /// /// If `None`, then we are not connected to any relay nodes. pub(crate) fn my_relay(&self) -> Option { - self.my_relay.get() - } - - /// Get the current proxy configuration. - pub(crate) fn proxy_url(&self) -> Option<&Url> { - self.proxy_url.as_ref() - } - - /// Sets the relay node with the best latency. - /// - /// If we are not connected to any relay nodes, set this to `None`. - fn set_my_relay(&self, my_relay: Option) -> Option { - self.my_relay.set(my_relay).unwrap_or_else(|e| e) + self.local_addr().into_iter().find_map(|a| { + if let transports::Addr::Relay(url, _) = a { + Some(url) + } else { + None + } + }) } fn is_closing(&self) -> bool { @@ -339,10 +292,20 @@ impl MagicSock { self.secret_key.public() } - /// Get the cached version of the Ipv4 and Ipv6 addrs of the current connection. + /// Get the cached version of addresses. + pub(crate) fn local_addr(&self) -> Vec { + self.local_addrs_watch.get().expect("disconnected") + } + #[cfg(not(wasm_browser))] - pub(crate) fn local_addr(&self) -> (SocketAddr, Option) { - *self.sockets.local_addrs.read().expect("not poisoned") + fn ip_bind_addrs(&self) -> &[SocketAddr] { + &self.ip_bind_addrs + } + + fn ip_local_addrs(&self) -> impl Iterator { + self.local_addr() + .into_iter() + .filter_map(|addr| addr.into_socket_addr()) } /// Returns `true` if we have at least one candidate address where we can send packets to. @@ -375,9 +338,9 @@ impl MagicSock { /// /// To get the current direct addresses, use [`Watcher::initialized`]. /// - /// [`Watcher`]: crate::watcher::Watcher - /// [`Watcher::initialized`]: crate::watcher::Watcher::initialized - pub(crate) fn direct_addresses(&self) -> watcher::Direct>> { + /// [`Watcher`]: n0_watcher::Watcher + /// [`Watcher::initialized`]: n0_watcher::Watcher::initialized + pub(crate) fn direct_addresses(&self) -> n0_watcher::Direct>> { self.direct_addrs.addrs.watch() } @@ -393,9 +356,9 @@ impl MagicSock { /// /// To get the current `net-report`, use [`Watcher::initialized`]. /// - /// [`Watcher`]: crate::watcher::Watcher - /// [`Watcher::initialized`]: crate::watcher::Watcher::initialized - pub(crate) fn net_report(&self) -> watcher::Direct>> { + /// [`Watcher`]: n0_watcher::Watcher + /// [`Watcher::initialized`]: n0_watcher::Watcher::initialized + pub(crate) fn net_report(&self) -> n0_watcher::Direct>> { self.net_report.watch() } @@ -403,24 +366,33 @@ impl MagicSock { /// /// Note that this can be used to wait for the initial home relay to be known using /// [`Watcher::initialized`]. - /// - /// [`Watcher`]: crate::watcher::Watcher - /// [`Watcher::initialized`]: crate::watcher::Watcher::initialized - pub(crate) fn home_relay(&self) -> watcher::Direct> { - self.my_relay.watch() + pub(crate) fn home_relay(&self) -> impl Watcher> { + let res = self.local_addrs_watch.clone().map(|addrs| { + addrs + .into_iter() + .filter_map(|addr| { + if let transports::Addr::Relay(url, _) = addr { + Some(url) + } else { + None + } + }) + .collect() + }); + res.expect("disconnected") } - /// Returns a [`watcher::Direct`] that reports the [`ConnectionType`] we have to the + /// Returns a [`n0_watcher::Direct`] that reports the [`ConnectionType`] we have to the /// given `node_id`. /// - /// This gets us a copy of the [`watcher::Direct`] for the [`Watchable`] with a [`ConnectionType`] + /// This gets us a copy of the [`n0_watcher::Direct`] for the [`Watchable`] with a [`ConnectionType`] /// that the `NodeMap` stores for each `node_id`'s endpoint. /// /// # Errors /// /// Will return `None` if there is no address information known about the /// given `node_id`. - pub(crate) fn conn_type(&self, node_id: NodeId) -> Option> { + pub(crate) fn conn_type(&self, node_id: NodeId) -> Option> { self.node_map.conn_type(node_id) } @@ -509,17 +481,34 @@ impl MagicSock { .ok(); } - #[cfg(not(wasm_browser))] #[cfg_attr(windows, allow(dead_code))] fn normalized_local_addr(&self) -> io::Result { - let (v4, v6) = self.local_addr(); - let addr = if let Some(v6) = v6 { v6 } else { v4 }; - Ok(addr) + let addrs = self.local_addrs_watch.get().expect("disconnected"); + + let mut ipv4_addr = None; + for addr in addrs { + let Some(addr) = addr.into_socket_addr() else { + continue; + }; + if addr.is_ipv6() { + return Ok(addr); + } + if addr.is_ipv4() && ipv4_addr.is_none() { + ipv4_addr.replace(addr); + } + } + match ipv4_addr { + Some(addr) => Ok(addr), + None => Err(io::Error::other("no valid socket available")), + } } - /// Implementation for AsyncUdpSocket::try_send + /// Searches the `node_map` to determine the current transports to be used. #[instrument(skip_all)] - fn try_send(&self, transmit: &quinn_udp::Transmit) -> io::Result<()> { + fn prepare_send( + &self, + transmit: &quinn_udp::Transmit, + ) -> io::Result> { self.metrics .magicsock .send_data @@ -535,16 +524,12 @@ impl MagicSock { "connection closed", )); } + + let mut active_paths = SmallVec::<[_; 3]>::new(); + match MappedAddr::from(transmit.destination) { MappedAddr::None(dest) => { - error!(%dest, "Cannot convert to a mapped address, voiding transmit."); - // Returning Ok here means we let QUIC timeout. - // Returning an error would immediately fail a connection. - // The philosophy of quinn-udp is that a UDP connection could - // come back at any time or missing should be transient so chooses to let - // these kind of errors time out. See test_try_send_no_send_addr to try - // this out. - Ok(()) + error!(%dest, "Cannot convert to a mapped address."); } MappedAddr::NodeId(dest) => { trace!( @@ -556,124 +541,26 @@ impl MagicSock { // Get the node's relay address and best direct address, as well // as any pings that need to be sent for hole-punching purposes. - let mut transmit = transmit.clone(); match self.node_map.get_send_addrs( dest, self.ipv6_reported.load(Ordering::Relaxed), &self.metrics.magicsock, ) { - Some((node_id, udp_addr, relay_url, msgs)) => { - let mut pings_sent = false; - // If we have pings to send, we *have* to send them out first. - if !msgs.is_empty() { - if let Err(err) = self.try_send_ping_actions(msgs) { - warn!( - node = %node_id.fmt_short(), - "failed to handle ping actions: {err:#}", - ); - } - pings_sent = true; + Some((node_id, udp_addr, relay_url, ping_actions)) => { + if !ping_actions.is_empty() { + self.actor_sender + .try_send(ActorMessage::PingActions(ping_actions)) + .ok(); } - - let mut udp_sent = false; - let mut udp_error: Option = None; - let mut relay_sent = false; - let mut relay_error = None; - - // send udp - #[cfg(not(wasm_browser))] if let Some(addr) = udp_addr { - // rewrite target address - transmit.destination = addr; - match self.try_send_udp(addr, &transmit) { - Ok(()) => { - trace!(node = %node_id.fmt_short(), dst = %addr, - "sent transmit over UDP"); - udp_sent = true; - } - Err(err) => { - // No need to print "WouldBlock" errors to the console - if err.kind() != io::ErrorKind::WouldBlock { - warn!( - node = %node_id.fmt_short(), - dst = %addr, - "failed to send udp: {err:#}" - ); - } - udp_error = Some(err); - } - } + active_paths.push(transports::Addr::from(addr)); } - - // send relay - if let Some(ref relay_url) = relay_url { - match self.try_send_relay(relay_url, node_id, split_packets(&transmit)) - { - Ok(()) => { - relay_sent = true; - } - Err(err) => { - relay_error = Some(err); - } - } - } - - #[cfg(not(wasm_browser))] - let udp_pending = udp_error - .as_ref() - .map(|err| err.kind() == io::ErrorKind::WouldBlock) - .unwrap_or_default(); - #[cfg(wasm_browser)] - let udp_pending = false; - let relay_pending = relay_error - .as_ref() - .map(|err| err.kind() == io::ErrorKind::WouldBlock) - .unwrap_or_default(); - if udp_pending && relay_pending { - // Handle backpressure. - return Err(io::Error::new(io::ErrorKind::WouldBlock, "pending")); - } else { - if relay_sent || udp_sent { - trace!( - node = %node_id.fmt_short(), - send_udp = ?udp_addr, - send_relay = ?relay_url, - "sent transmit", - ); - } else if !pings_sent { - // Returning Ok here means we let QUIC handle a timeout for a lost - // packet, same would happen if we returned any errors. The - // philosophy of quinn-udp is that a UDP connection could come back - // at any time so these errors should be treated as transient and - // are just timeouts. Hence we opt for returning Ok. See - // test_try_send_no_udp_addr_or_relay_url to explore this further. - debug!( - node = %node_id.fmt_short(), - "no UDP or relay paths available for node, voiding transmit", - ); - // We log this as debug instead of error, because this is a - // situation that comes up under normal operation. If this were an - // error log, it would unnecessarily pollute logs. - // This situation happens essentially when `pings_sent` is false, - // `relay_url` is `None`, so `relay_sent` is false, and the UDP - // path is blocking, so `udp_sent` is false and `udp_pending` is - // true. - // Alternatively returning a WouldBlock error here would - // potentially needlessly block sending on the relay path for the - // next datagram. - } - return Ok(()); + if let Some(url) = relay_url { + active_paths.push(transports::Addr::Relay(url, node_id)); } } None => { - error!(%dest, "no NodeState for mapped address, dropping transmit"); - // Returning Ok here means we let QUIC timeout. Returning WouldBlock - // triggers a hot loop. Returning an error would immediately fail a - // connection. The philosophy of quinn-udp is that a UDP connection could - // come back at any time or missing should be transient so chooses to let - // these kind of errors time out. See test_try_send_no_send_addr to try - // this out. - return Ok(()); + error!(%dest, "no NodeState for mapped address"); } } } @@ -687,205 +574,19 @@ impl MagicSock { ); // Check if this is a known IpMappedAddr, and if so, send over UDP - let mut transmit = transmit.clone(); - // Get the socket addr match self.ip_mapped_addrs.get_ip_addr(&dest) { Some(addr) => { - // rewrite target address - transmit.destination = addr; - // send udp - match self.try_send_udp(addr, &transmit) { - Ok(()) => { - trace!(dst = %addr, - "sent IpMapped transmit over UDP"); - } - Err(err) => { - // No need to print "WouldBlock" errors to the console - if err.kind() == io::ErrorKind::WouldBlock { - return Err(io::Error::new( - io::ErrorKind::WouldBlock, - "pending", - )); - } else { - warn!( - dst = %addr, - "failed to send IpMapped message over udp: {err:#}" - ); - } - } - } - return Ok(()); + active_paths.push(transports::Addr::from(addr)); } None => { - error!(%dest, "unknown mapped address, dropping transmit"); - // Returning Ok here means we let QUIC timeout. - // Returning an error would immediately fail a connection. - // The philosophy of quinn-udp is that a UDP connection could - // come back at any time or missing should be transient so chooses to let - // these kind of errors time out. See test_try_send_no_send_addr to try - // this out. - return Ok(()); - } - } - } - } - } - - fn try_send_relay( - &self, - url: &RelayUrl, - node: NodeId, - contents: RelayContents, - ) -> io::Result<()> { - trace!( - node = %node.fmt_short(), - relay_url = %url, - count = contents.len(), - len = contents.iter().map(|c| c.len()).sum::(), - "send relay", - ); - let msg = RelaySendItem { - remote_node: node, - url: url.clone(), - datagrams: contents, - }; - match self.relay_datagram_send_channel.try_send(msg) { - Ok(_) => { - trace!(node = %node.fmt_short(), relay_url = %url, - "send relay: message queued"); - Ok(()) - } - Err(mpsc::error::TrySendError::Closed(_)) => { - error!(node = %node.fmt_short(), relay_url = %url, - "send relay: message dropped, channel to actor is closed"); - Err(io::Error::new( - io::ErrorKind::ConnectionReset, - "channel to actor is closed", - )) - } - Err(mpsc::error::TrySendError::Full(_)) => { - warn!(node = %node.fmt_short(), relay_url = %url, - "send relay: message dropped, channel to actor is full"); - Err(io::Error::new( - io::ErrorKind::WouldBlock, - "channel to actor is full", - )) - } - } - } - - #[cfg(not(wasm_browser))] - fn try_send_udp(&self, addr: SocketAddr, transmit: &quinn_udp::Transmit) -> io::Result<()> { - let conn = self.conn_for_addr(addr)?; - conn.try_send(transmit)?; - let total_bytes: u64 = transmit.contents.len() as u64; - if addr.is_ipv6() { - self.metrics.magicsock.send_ipv6.inc_by(total_bytes); - } else { - self.metrics.magicsock.send_ipv4.inc_by(total_bytes); - } - Ok(()) - } - - #[cfg(not(wasm_browser))] - fn conn_for_addr(&self, addr: SocketAddr) -> io::Result<&UdpConn> { - let sock = match addr { - SocketAddr::V4(_) => &self.sockets.v4, - SocketAddr::V6(_) => self - .sockets - .v6 - .as_ref() - .ok_or(io::Error::other("no IPv6 connection"))?, - }; - Ok(sock) - } - - /// NOTE: Receiving on a [`Self::closed`] socket will return [`Poll::Pending`] indefinitely. - #[instrument(skip_all)] - #[cfg(not(wasm_browser))] - fn poll_recv( - &self, - cx: &mut Context, - bufs: &mut [io::IoSliceMut<'_>], - metas: &mut [quinn_udp::RecvMeta], - ) -> Poll> { - debug_assert_eq!(bufs.len(), metas.len(), "non matching bufs & metas"); - if self.is_closed() { - return Poll::Pending; - } - - // Three macros to help polling: they return if they get a result, execution - // continues if they were Pending and we need to poll others (or finally return - // Pending). - macro_rules! poll_ipv4 { - () => { - match self.sockets.v4.poll_recv(cx, bufs, metas)? { - Poll::Pending | Poll::Ready(0) => {} - Poll::Ready(n) => { - self.process_udp_datagrams(true, &mut bufs[..n], &mut metas[..n]); - return Poll::Ready(Ok(n)); - } - } - }; - } - macro_rules! poll_ipv6 { - () => { - if let Some(ref socket) = self.sockets.v6 { - match socket.poll_recv(cx, bufs, metas)? { - Poll::Pending | Poll::Ready(0) => {} - Poll::Ready(n) => { - self.process_udp_datagrams(false, &mut bufs[..n], &mut metas[..n]); - return Poll::Ready(Ok(n)); - } + error!(%dest, "unknown mapped address"); } } - }; - } - macro_rules! poll_relay { - () => { - match self.poll_recv_relay(cx, bufs, metas) { - Poll::Pending => {} - Poll::Ready(n) => return Poll::Ready(n), - } - }; - } - - let counter = self.poll_recv_counter.fetch_add(1, Ordering::Relaxed); - match counter % 3 { - 0 => { - // order of polling: UDPv4, UDPv6, relay - poll_ipv4!(); - poll_ipv6!(); - poll_relay!(); - Poll::Pending - } - 1 => { - // order of polling: UDPv6, relay, UDPv4 - poll_ipv6!(); - poll_relay!(); - poll_ipv4!(); - Poll::Pending - } - _ => { - // order of polling: relay, UDPv4, UDPv6 - poll_relay!(); - poll_ipv4!(); - poll_ipv6!(); - Poll::Pending } } - } - /// poll_recv in browsers is "just" polling the relay receive path - #[cfg(wasm_browser)] - fn poll_recv( - &self, - cx: &mut Context, - bufs: &mut [io::IoSliceMut<'_>], - metas: &mut [quinn_udp::RecvMeta], - ) -> Poll> { - self.poll_recv_relay(cx, bufs, metas) + Ok(active_paths) } /// Process datagrams received from UDP sockets. @@ -894,14 +595,18 @@ impl MagicSock { /// /// This fixes up the datagrams to use the correct [`NodeIdMappedAddr`] and extracts DISCO /// packets, processing them inside the magic socket. - #[cfg(not(wasm_browser))] - fn process_udp_datagrams( + fn process_datagrams( &self, - from_ipv4: bool, bufs: &mut [io::IoSliceMut<'_>], metas: &mut [quinn_udp::RecvMeta], + source_addrs: &[transports::Addr], ) { debug_assert_eq!(bufs.len(), metas.len(), "non matching bufs & metas"); + debug_assert_eq!( + bufs.len(), + source_addrs.len(), + "non matching bufs & source_addrs" + ); // Adding the IP address we received something on results in Quinn using this // address on the send path to send from. However we let Quinn use a @@ -913,7 +618,7 @@ impl MagicSock { // NodeState/PathState. Then on the send path it should be retrieved from the // NodeState/PathSate together with the send address and substituted at send time. // This is relevant for IPv6 link-local addresses where the OS otherwise does not - // know which intervace to send from. + // know which interface to send from. #[cfg(not(windows))] let dst_ip = self.normalized_local_addr().ok().map(|addr| addr.ip()); // Reasoning for this here: @@ -923,21 +628,25 @@ impl MagicSock { let mut quic_packets_total = 0; - for (meta, buf) in metas.iter_mut().zip(bufs.iter_mut()) { + for ((quinn_meta, buf), source_addr) in metas + .iter_mut() + .zip(bufs.iter_mut()) + .zip(source_addrs.iter()) + { let mut buf_contains_quic_datagrams = false; let mut quic_datagram_count = 0; - if meta.len > meta.stride { - trace!(%meta.len, %meta.stride, "GRO datagram received"); + if quinn_meta.len > quinn_meta.stride { + trace!(%quinn_meta.len, %quinn_meta.stride, "GRO datagram received"); self.metrics.magicsock.recv_gro_datagrams.inc(); } // Chunk through the datagrams in this GRO payload to find disco and stun // packets and forward them to the actor - for datagram in buf[..meta.len].chunks_mut(meta.stride) { - if datagram.len() < meta.stride { + for datagram in buf[..quinn_meta.len].chunks_mut(quinn_meta.stride) { + if datagram.len() < quinn_meta.stride { trace!( len = %datagram.len(), - %meta.stride, + %quinn_meta.stride, "Last GRO datagram smaller than stride", ); } @@ -946,84 +655,109 @@ impl MagicSock { // byte of those packets with zero to make Quinn ignore the packet. This // relies on quinn::EndpointConfig::grease_quic_bit being set to `false`, // which we do in Endpoint::bind. - if stun::is(datagram) { - trace!(src = %meta.addr, len = %meta.stride, "UDP recv: stun packet"); + if source_addr.is_ip() && stun::is(datagram) { + trace!(src = ?source_addr, len = %quinn_meta.stride, "UDP recv: stun packet"); let packet2 = Bytes::copy_from_slice(datagram); - self.net_reporter.receive_stun_packet(packet2, meta.addr); + self.net_reporter.receive_stun_packet( + packet2, + source_addr.clone().into_socket_addr().expect("checked"), + ); datagram[0] = 0u8; } else if let Some((sender, sealed_box)) = disco::source_and_box(datagram) { - trace!(src = %meta.addr, len = %meta.stride, "UDP recv: disco packet"); - self.handle_disco_message( - sender, - sealed_box, - DiscoMessageSource::Udp(meta.addr), - ); + trace!(src = ?source_addr, len = %quinn_meta.stride, "UDP recv: disco packet"); + self.handle_disco_message(sender, sealed_box, source_addr); datagram[0] = 0u8; } else { - trace!(src = %meta.addr, len = %meta.stride, "UDP recv: quic packet"); - if from_ipv4 { - self.metrics - .magicsock - .recv_data_ipv4 - .inc_by(datagram.len() as _); - } else { - self.metrics - .magicsock - .recv_data_ipv6 - .inc_by(datagram.len() as _); + trace!(src = ?source_addr, len = %quinn_meta.stride, "UDP recv: quic packet"); + match source_addr { + transports::Addr::Ip(SocketAddr::V4(..)) => { + self.metrics + .magicsock + .recv_data_ipv4 + .inc_by(datagram.len() as _); + } + transports::Addr::Ip(SocketAddr::V6(..)) => { + self.metrics + .magicsock + .recv_data_ipv6 + .inc_by(datagram.len() as _); + } + transports::Addr::Relay(..) => { + self.metrics + .magicsock + .recv_data_relay + .inc_by(datagram.len() as _); + } } + quic_datagram_count += 1; buf_contains_quic_datagrams = true; - }; + } } if buf_contains_quic_datagrams { - // Update the NodeMap and remap RecvMeta to the NodeIdMappedAddr. - match self.node_map.receive_udp(meta.addr) { - None => { - // Check if this address is mapped to an IpMappedAddr - if let Some(ip_mapped_addr) = - self.ip_mapped_addrs.get_mapped_addr(&meta.addr) - { - trace!( - src = ?meta.addr, - count = %quic_datagram_count, - len = meta.len, - "UDP recv QUIC address discovery packets", - ); - quic_packets_total += quic_datagram_count; - meta.addr = ip_mapped_addr.private_socket_addr(); - } else { - warn!( - src = ?meta.addr, - count = %quic_datagram_count, - len = meta.len, - "UDP recv quic packets: no node state found, skipping", - ); - // If we have no node state for the from addr, set len to 0 to make - // quinn skip the buf completely. - meta.len = 0; + match source_addr { + #[cfg(wasm_browser)] + transports::Addr::Ip(_addr) => { + panic!("cannot use IP based addressing in the browser"); + } + #[cfg(not(wasm_browser))] + transports::Addr::Ip(addr) => { + // UDP + + // Update the NodeMap and remap RecvMeta to the NodeIdMappedAddr. + match self.node_map.receive_udp(*addr) { + None => { + // Check if this address is mapped to an IpMappedAddr + if let Some(ip_mapped_addr) = + self.ip_mapped_addrs.get_mapped_addr(addr) + { + trace!( + src = %addr, + count = %quic_datagram_count, + len = quinn_meta.len, + "UDP recv QUIC address discovery packets", + ); + quic_packets_total += quic_datagram_count; + quinn_meta.addr = ip_mapped_addr.private_socket_addr(); + } else { + warn!( + src = %addr, + count = %quic_datagram_count, + len = quinn_meta.len, + "UDP recv quic packets: no node state found, skipping", + ); + // If we have no node state for the from addr, set len to 0 to make + // quinn skip the buf completely. + quinn_meta.len = 0; + } + } + Some((node_id, quic_mapped_addr)) => { + trace!( + src = %addr, + node = %node_id.fmt_short(), + count = %quic_datagram_count, + len = quinn_meta.len, + "UDP recv quic packets", + ); + quic_packets_total += quic_datagram_count; + quinn_meta.addr = quic_mapped_addr.private_socket_addr(); + } } } - Some((node_id, quic_mapped_addr)) => { - trace!( - src = ?meta.addr, - node = %node_id.fmt_short(), - count = %quic_datagram_count, - len = meta.len, - "UDP recv quic packets", - ); - quic_packets_total += quic_datagram_count; - meta.addr = quic_mapped_addr.private_socket_addr(); + transports::Addr::Relay(src_url, src_node) => { + // Relay + let quic_mapped_addr = self.node_map.receive_relay(src_url, *src_node); + quinn_meta.addr = quic_mapped_addr.private_socket_addr(); } } } else { // If all datagrams in this buf are DISCO or STUN, set len to zero to make // Quinn skip the buf completely. - meta.len = 0; + quinn_meta.len = 0; } // Normalize local_ip - meta.dst_ip = dst_ip; + quinn_meta.dst_ip = dst_ip; } if quic_packets_total > 0 { @@ -1035,146 +769,21 @@ impl MagicSock { } } - #[instrument(skip_all)] - fn poll_recv_relay( - &self, - cx: &mut Context, - bufs: &mut [io::IoSliceMut<'_>], - metas: &mut [quinn_udp::RecvMeta], - ) -> Poll> { - let mut num_msgs = 0; - 'outer: for (buf_out, meta_out) in bufs.iter_mut().zip(metas.iter_mut()) { - if self.is_closed() { - break; - } - - // For each output buffer keep polling the datagrams from the relay until one is - // a QUIC datagram to be placed into the output buffer. Or the channel is empty. - loop { - let recv = match self.relay_datagram_recv_queue.poll_recv(cx) { - Poll::Ready(Ok(recv)) => recv, - Poll::Ready(Err(err)) => { - error!("relay_recv_channel closed: {err:#}"); - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::NotConnected, - "connection closed", - ))); - } - Poll::Pending => { - break 'outer; - } - }; - match self.process_relay_read_result(recv) { - None => { - // Received a DISCO or STUN datagram that was handled internally. - continue; - } - Some((node_id, meta, buf)) => { - self.metrics - .magicsock - .recv_data_relay - .inc_by(buf.len() as _); - trace!( - src = %meta.addr, - node = %node_id.fmt_short(), - count = meta.len / meta.stride, - len = meta.len, - "recv quic packets from relay", - ); - buf_out[..buf.len()].copy_from_slice(&buf); - *meta_out = meta; - num_msgs += 1; - break; - } - } - } - } - - // If we have any msgs to report, they are in the first `num_msgs_total` slots - if num_msgs > 0 { - self.metrics.magicsock.recv_datagrams.inc_by(num_msgs as _); - Poll::Ready(Ok(num_msgs)) - } else { - Poll::Pending - } - } - - /// Process datagrams received from the relay server into incoming Quinn datagrams. - /// - /// This will transform datagrams received from the relay server into Quinn datagrams to - /// receive, adding the [`quinn_udp::RecvMeta`]. - /// - /// If the incoming datagram is a DISCO packet it will be handled internally and `None` - /// is returned. - fn process_relay_read_result( - &self, - dm: RelayRecvDatagram, - ) -> Option<(NodeId, quinn_udp::RecvMeta, Bytes)> { - trace!("process_relay_read {} bytes", dm.buf.len()); - if dm.buf.is_empty() { - warn!("received empty relay packet"); - return None; - } - - if self.handle_relay_disco_message(&dm.buf, &dm.url, dm.src) { - // DISCO messages are handled internally in the MagicSock, do not pass to Quinn. - return None; - } - - let quic_mapped_addr = self.node_map.receive_relay(&dm.url, dm.src); - - // Normalize local_ip - #[cfg(not(any(windows, wasm_browser)))] - let dst_ip = self.normalized_local_addr().ok().map(|addr| addr.ip()); - // Reasoning for this here: - // https://github.com/n0-computer/iroh/pull/2595#issuecomment-2290947319 - #[cfg(any(windows, wasm_browser))] - let dst_ip = None; - - let meta = quinn_udp::RecvMeta { - len: dm.buf.len(), - stride: dm.buf.len(), - addr: quic_mapped_addr.private_socket_addr(), - dst_ip, - ecn: None, - }; - Some((dm.src, meta, dm.buf)) - } - - fn handle_relay_disco_message( - &self, - msg: &[u8], - url: &RelayUrl, - relay_node_src: PublicKey, - ) -> bool { - match disco::source_and_box(msg) { - Some((source, sealed_box)) => { - if relay_node_src != source { - // TODO: return here? - warn!("Received relay disco message from connection for {}, but with message from {}", relay_node_src.fmt_short(), source.fmt_short()); - } - self.handle_disco_message( - source, - sealed_box, - DiscoMessageSource::Relay { - url: url.clone(), - key: relay_node_src, - }, - ); - true - } - None => false, - } - } - /// Handles a discovery message. - #[instrument("disco_in", skip_all, fields(node = %sender.fmt_short(), %src))] - fn handle_disco_message(&self, sender: PublicKey, sealed_box: &[u8], src: DiscoMessageSource) { + #[instrument("disco_in", skip_all, fields(node = %sender.fmt_short(), ?src))] + fn handle_disco_message(&self, sender: PublicKey, sealed_box: &[u8], src: &transports::Addr) { trace!("handle_disco_message start"); if self.is_closed() { return; } + if let transports::Addr::Relay(_, node_id) = src { + if node_id != &sender { + // TODO: return here? + warn!("Received relay disco message from connection for {:?}, but with message from {}", node_id.fmt_short(), sender.fmt_short()); + } + } + // We're now reasonably sure we're expecting communication from // this node, do the heavy crypto lifting to see what they want. let dm = match self.disco_secrets.unseal_and_decode( @@ -1217,12 +826,12 @@ impl MagicSock { } disco::Message::Pong(pong) => { self.metrics.magicsock.recv_disco_pong.inc(); - self.node_map.handle_pong(sender, &src, pong); + self.node_map.handle_pong(sender, src, pong); } disco::Message::CallMeMaybe(cm) => { self.metrics.magicsock.recv_disco_call_me_maybe.inc(); match src { - DiscoMessageSource::Relay { url, .. } => { + transports::Addr::Relay(url, _) => { event!( target: "iroh::_events::call-me-maybe::recv", Level::DEBUG, @@ -1255,22 +864,22 @@ impl MagicSock { } /// Handle a ping message. - fn handle_ping(&self, dm: disco::Ping, sender: NodeId, src: DiscoMessageSource) { + fn handle_ping(&self, dm: disco::Ping, sender: NodeId, src: &transports::Addr) { // Insert the ping into the node map, and return whether a ping with this tx_id was already // received. let addr: SendAddr = src.clone().into(); let handled = self.node_map.handle_ping(sender, addr.clone(), dm.tx_id); match handled.role { PingRole::Duplicate => { - debug!(%src, tx = %HEXLOWER.encode(&dm.tx_id), "received ping: path already confirmed, skip"); + debug!(?src, tx = %HEXLOWER.encode(&dm.tx_id), "received ping: path already confirmed, skip"); return; } PingRole::LikelyHeartbeat => {} PingRole::NewPath => { - debug!(%src, tx = %HEXLOWER.encode(&dm.tx_id), "received ping: new path"); + debug!(?src, tx = %HEXLOWER.encode(&dm.tx_id), "received ping: new path"); } PingRole::Activate => { - debug!(%src, tx = %HEXLOWER.encode(&dm.tx_id), "received ping: path active"); + debug!(?src, tx = %HEXLOWER.encode(&dm.tx_id), "received ping: path active"); } } @@ -1289,7 +898,11 @@ impl MagicSock { txn = ?dm.tx_id, ); - if !self.send_disco_message_queued(addr.clone(), sender, pong) { + if self + .disco_sender + .try_send((addr.clone(), sender, pong)) + .is_err() + { warn!(%addr, "failed to queue pong"); } @@ -1324,19 +937,11 @@ impl MagicSock { tx_id, node_key: self.public_key(), }); - let sent = match dst { - #[cfg(not(wasm_browser))] - SendAddr::Udp(addr) => self - .udp_disco_sender - .try_send((addr, dst_node, msg)) - .is_ok(), - #[cfg(wasm_browser)] - SendAddr::Udp(_) => { - // Ignoring sending pings over UDP. We don't have a UDP socket. - return; - } - SendAddr::Relay(ref url) => self.send_disco_message_relay(url, dst_node, msg), - }; + let sent = self + .disco_sender + .try_send((dst.clone(), dst_node, msg)) + .is_ok(); + if sent { let msg_sender = self.actor_sender.clone(); trace!(%dst, tx = %HEXLOWER.encode(&tx_id), ?purpose, "ping sent (queued)"); @@ -1348,11 +953,7 @@ impl MagicSock { } /// Tries to send the ping actions. - /// - /// Note that on failure the (remaining) ping actions are simply dropped. That's bad! - /// The Endpoint will think a full ping was done and not request a new full-ping for a - /// while. We should probably be buffering the pings. - fn try_send_ping_actions(&self, msgs: Vec) -> io::Result<()> { + async fn send_ping_actions(&self, sender: &UdpSender, msgs: Vec) -> io::Result<()> { for msg in msgs { // Abort sending as soon as we know we are shutting down. if self.is_closing() || self.is_closed() { @@ -1366,160 +967,57 @@ impl MagicSock { self.send_or_queue_call_me_maybe(relay_url, dst_node); } PingAction::SendPing(ping) => { - self.try_send_ping(ping)?; + self.send_ping(sender, ping).await?; } } } Ok(()) } - /// Send a disco message. UDP messages will be queued. - /// - /// If `dst` is [`SendAddr::Relay`], the message will be pushed into the relay client channel. - /// If `dst` is [`SendAddr::Udp`], the message will be pushed into the udp disco send channel. - /// - /// Returns true if the channel had capacity for the message, and false if the message was - /// dropped. - fn send_disco_message_queued( - &self, - dst: SendAddr, - dst_key: PublicKey, - msg: disco::Message, - ) -> bool { - match dst { - SendAddr::Udp(addr) => self.udp_disco_sender.try_send((addr, dst_key, msg)).is_ok(), - SendAddr::Relay(ref url) => self.send_disco_message_relay(url, dst_key, msg), - } - } - /// Send a disco message. UDP messages will be polled to send directly on the UDP socket. - fn try_send_disco_message( + async fn send_disco_message( &self, + sender: &UdpSender, dst: SendAddr, dst_key: PublicKey, msg: disco::Message, ) -> io::Result<()> { - match dst { - #[cfg(not(wasm_browser))] - SendAddr::Udp(addr) => { - self.try_send_disco_message_udp(addr, dst_key, &msg)?; - } - #[cfg(wasm_browser)] - SendAddr::Udp(addr) => { - warn!(?addr, "Asked to send on UDP in browser code"); - } - SendAddr::Relay(ref url) => { - if !self.send_disco_message_relay(url, dst_key, msg) { - return Err(io::Error::other("Relay channel full")); - } - } - } - Ok(()) - } - - fn send_disco_message_relay(&self, url: &RelayUrl, dst: NodeId, msg: disco::Message) -> bool { - debug!(node = %dst.fmt_short(), %url, %msg, "send disco message (relay)"); - let pkt = self.encode_disco_message(dst, &msg); - self.metrics.magicsock.send_disco_relay.inc(); - match self.try_send_relay(url, dst, smallvec![pkt]) { - Ok(()) => { - if let disco::Message::CallMeMaybe(CallMeMaybe { ref my_numbers }) = msg { - event!( - target: "iroh::_events::call-me-maybe::sent", - Level::DEBUG, - remote_node = %dst.fmt_short(), - via = ?url, - addrs = ?my_numbers, - ); - } - self.metrics.magicsock.sent_disco_relay.inc(); - disco_message_sent(&msg, &self.metrics.magicsock); - true - } - Err(_) => false, - } - } - - #[cfg(not(wasm_browser))] - async fn send_disco_message_udp( - &self, - dst: SocketAddr, - dst_node: NodeId, - msg: &disco::Message, - ) -> io::Result<()> { - n0_future::future::poll_fn(move |cx| { - loop { - match self.try_send_disco_message_udp(dst, dst_node, msg) { - Ok(()) => return Poll::Ready(Ok(())), - Err(err) if err.kind() == io::ErrorKind::WouldBlock => { - // This is the socket .try_send_disco_message_udp used. - let sock = self.conn_for_addr(dst)?; - match sock.as_socket_ref().poll_writable(cx) { - Poll::Ready(Ok(())) => continue, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - } - } - Err(err) => return Poll::Ready(Err(err)), - } - } - }) - .await - } + let dst = match dst { + SendAddr::Udp(addr) => transports::Addr::Ip(addr), + SendAddr::Relay(url) => transports::Addr::Relay(url, dst_key), + }; - #[cfg(not(wasm_browser))] - fn try_send_disco_message_udp( - &self, - dst: SocketAddr, - dst_node: NodeId, - msg: &disco::Message, - ) -> std::io::Result<()> { - trace!(%dst, %msg, "send disco message (UDP)"); + trace!(?dst, %msg, "send disco message (UDP)"); if self.is_closed() { return Err(io::Error::new( io::ErrorKind::NotConnected, "connection closed", )); } - let pkt = self.encode_disco_message(dst_node, msg); - // TODO: These metrics will be wrong with the poll impl - // Also - do we need it? I'd say the `sent_disco_udp` below is enough. - self.metrics.magicsock.send_disco_udp.inc(); - let transmit = quinn_udp::Transmit { - destination: dst, + let pkt = self.encode_disco_message(dst_key, &msg); + + let transmit = transports::Transmit { contents: &pkt, ecn: None, segment_size: None, - src_ip: None, // TODO }; - let sent = self.try_send_udp(dst, &transmit); - match sent { + + let dst2 = dst.clone(); + match sender.send(&dst2, None, &transmit).await { Ok(()) => { - trace!(%dst, node = %dst_node.fmt_short(), %msg, "sent disco message"); + trace!(?dst, %msg, "sent disco message"); self.metrics.magicsock.sent_disco_udp.inc(); - disco_message_sent(msg, &self.metrics.magicsock); + disco_message_sent(&msg, &self.metrics.magicsock); Ok(()) } Err(err) => { - warn!(%dst, node = %dst_node.fmt_short(), ?msg, ?err, - "failed to send disco message"); + warn!(?dst, ?msg, ?err, "failed to send disco message"); Err(err) } } } - #[instrument(skip_all)] - async fn handle_ping_actions(&mut self, msgs: Vec) { - // TODO: This used to make sure that all ping actions are sent. Though on the - // poll_send/try_send path we also do fire-and-forget. try_send_ping_actions() - // really should store any unsent pings on the Inner and send them at the next - // possible time. - if let Err(err) = self.try_send_ping_actions(msgs) { - warn!("Not all ping actions were sent: {err:#}"); - } - } - - fn try_send_ping(&self, ping: SendPing) -> io::Result<()> { + async fn send_ping(&self, sender: &UdpSender, ping: SendPing) -> io::Result<()> { let SendPing { id, dst, @@ -1531,8 +1029,10 @@ impl MagicSock { tx_id, node_key: self.public_key(), }); - self.try_send_disco_message(dst.clone(), dst_node, msg)?; - debug!(%dst, tx = %HEXLOWER.encode(&tx_id), ?purpose, "ping sent (polled)"); + + self.send_disco_message(sender, dst.clone(), dst_node, msg) + .await?; + debug!(%dst, tx = %HEXLOWER.encode(&tx_id), ?purpose, "ping sent"); let msg_sender = self.actor_sender.clone(); self.node_map .notify_ping_sent(id, dst.clone(), tx_id, purpose, msg_sender); @@ -1548,7 +1048,11 @@ impl MagicSock { .expect("poisoned") .drain() { - if !self.send_disco_message_relay(&url, public_key, msg.clone()) { + if self + .disco_sender + .try_send((SendAddr::Relay(url), public_key, msg.clone())) + .is_err() + { warn!(node = %public_key.fmt_short(), "relay channel full, dropping call-me-maybe"); } } @@ -1564,7 +1068,11 @@ impl MagicSock { Ok(()) => { let msg = self.direct_addrs.to_call_me_maybe_message(); let msg = disco::Message::CallMeMaybe(msg); - if !self.send_disco_message_relay(url, dst_node, msg) { + if self + .disco_sender + .try_send((SendAddr::Relay(url.clone()), dst_node, msg.clone())) + .is_err() + { warn!(dstkey = %dst_node.fmt_short(), relayurl = %url, "relay channel full, dropping call-me-maybe"); } else { @@ -1600,11 +1108,17 @@ impl MagicSock { if let Some(ref discovery) = self.discovery { let relay_url = self.my_relay(); let direct_addrs = self.direct_addrs.sockaddrs(); + let user_data = self .discovery_user_data .read() .expect("lock poisened") .clone(); + if relay_url.is_none() && direct_addrs.is_empty() && user_data.is_none() { + // do not bother publishing if we don't have any information + return; + } + let data = NodeData::new(relay_url, direct_addrs).with_user_data(user_data); discovery.publish(&data); } @@ -1637,45 +1151,6 @@ impl From for MappedAddr { } } -#[derive(Clone, Debug)] -enum DiscoMessageSource { - Udp(SocketAddr), - Relay { url: RelayUrl, key: PublicKey }, -} - -impl Display for DiscoMessageSource { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - Self::Udp(addr) => write!(f, "Udp({addr})"), - Self::Relay { ref url, key } => write!(f, "Relay({url}, {})", key.fmt_short()), - } - } -} - -impl From for SendAddr { - fn from(value: DiscoMessageSource) -> Self { - match value { - DiscoMessageSource::Udp(addr) => SendAddr::Udp(addr), - DiscoMessageSource::Relay { url, .. } => SendAddr::Relay(url), - } - } -} - -impl From<&DiscoMessageSource> for SendAddr { - fn from(value: &DiscoMessageSource) -> Self { - match value { - DiscoMessageSource::Udp(addr) => SendAddr::Udp(*addr), - DiscoMessageSource::Relay { url, .. } => SendAddr::Relay(url.clone()), - } - } -} - -impl DiscoMessageSource { - fn is_relay(&self) -> bool { - matches!(self, DiscoMessageSource::Relay { .. }) - } -} - /// Manages currently running direct addr discovery, aka net_report runs. /// /// Invariants: @@ -1782,20 +1257,32 @@ impl Handle { metrics, } = opts; + let addr_v4 = addr_v4.unwrap_or_else(|| SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)); + #[cfg(not(wasm_browser))] - let actor_sockets = ActorSocketState::bind(addr_v4, addr_v6, metrics.portmapper.clone()) - .context(BindSocketsSnafu)?; + let (ip_transports, port_mapper) = + bind_ip(addr_v4, addr_v6, &metrics).context(BindSocketsSnafu)?; #[cfg(not(wasm_browser))] - let sockets = actor_sockets - .msock_socket_state() - .context(CreateSocketStateSnafu)?; + let v4_socket = ip_transports + .iter() + .find(|t| t.bind_addr().is_ipv4()) + .expect("must bind a ipv4 socket") + .socket(); + #[cfg(not(wasm_browser))] + let v6_socket = ip_transports.iter().find_map(|t| { + if t.bind_addr().is_ipv6() { + Some(t.socket()) + } else { + None + } + }); let ip_mapped_addrs = IpMappedAddresses::default(); let net_reporter = net_report::Client::new( #[cfg(not(wasm_browser))] - Some(actor_sockets.port_mapper.clone()), + Some(port_mapper.clone()), #[cfg(not(wasm_browser))] dns_resolver.clone(), #[cfg(not(wasm_browser))] @@ -1804,10 +1291,7 @@ impl Handle { ); let (actor_sender, actor_receiver) = mpsc::channel(256); - let (relay_actor_sender, relay_actor_receiver) = mpsc::channel(256); - let (relay_datagram_send_tx, relay_datagram_send_rx) = relay_datagram_send_channel(); - let relay_datagram_recv_queue = Arc::new(RelayDatagramRecvQueue::new()); - let (udp_disco_sender, mut udp_disco_receiver) = mpsc::channel(256); + let (disco_sender, mut disco_receiver) = mpsc::channel(256); // load the node data let node_map = node_map.unwrap_or_default(); @@ -1816,29 +1300,46 @@ impl Handle { #[cfg(not(any(test, feature = "test-utils")))] let node_map = NodeMap::load_from_vec(node_map, &metrics.magicsock); + let my_relay = Watchable::new(None); + let ipv6_reported = Arc::new(AtomicBool::new(false)); + + let relay_transport = RelayTransport::new(RelayActorConfig { + my_relay: my_relay.clone(), + secret_key: secret_key.clone(), + #[cfg(not(wasm_browser))] + dns_resolver: dns_resolver.clone(), + proxy_url: proxy_url.clone(), + ipv6_reported: ipv6_reported.clone(), + #[cfg(any(test, feature = "test-utils"))] + insecure_skip_relay_cert_verify, + metrics: metrics.magicsock.clone(), + protocol: relay_protocol, + }); + let relay_transports = vec![relay_transport]; + let secret_encryption_key = secret_ed_box(secret_key.secret()); + #[cfg(not(wasm_browser))] + let ipv6 = ip_transports.iter().any(|t| t.bind_addr().is_ipv6()); + + #[cfg(not(wasm_browser))] + let transports = Transports::new(ip_transports, relay_transports); + #[cfg(wasm_browser)] + let transports = Transports::new(relay_transports); let msock = Arc::new(MagicSock { me, secret_key, secret_encryption_key, - proxy_url, - #[cfg(not(wasm_browser))] - sockets, closing: AtomicBool::new(false), closed: AtomicBool::new(false), - relay_datagram_recv_queue: relay_datagram_recv_queue.clone(), - relay_datagram_send_channel: relay_datagram_send_tx, - poll_recv_counter: AtomicUsize::new(0), actor_sender: actor_sender.clone(), - ipv6_reported: Arc::new(AtomicBool::new(false)), + ipv6_reported, relay_map, - my_relay: Default::default(), net_reporter: net_reporter.addr(), disco_secrets: DiscoSecrets::default(), node_map, ip_mapped_addrs, - udp_disco_sender, + disco_sender, discovery, discovery_user_data: RwLock::new(discovery_user_data), direct_addrs: Default::default(), @@ -1847,10 +1348,11 @@ impl Handle { direct_addr_update_state: DirectAddrUpdateState::new(), #[cfg(not(wasm_browser))] dns_resolver, - #[cfg(any(test, feature = "test-utils"))] - insecure_skip_relay_cert_verify, discovery_subscribers: DiscoverySubscribers::new(), metrics, + local_addrs_watch: transports.local_addrs_watch(), + #[cfg(not(wasm_browser))] + ip_bind_addrs: transports.ip_bind_addrs(), }); let mut endpoint_config = quinn::EndpointConfig::default(); @@ -1861,10 +1363,18 @@ impl Handle { // the packet if grease_quic_bit is set to false. endpoint_config.grease_quic_bit(false); + let sender1 = transports.create_sender(msock.clone()); + let sender2 = transports.create_sender(msock.clone()); + let local_addrs_watch = transports.local_addrs_watch(); + let network_change_sender = transports.create_network_change_sender(); + let endpoint = quinn::Endpoint::new_with_abstract_socket( endpoint_config, Some(server_config), - msock.clone(), + Box::new(MagicUdpSocket { + socket: msock.clone(), + transports, + }), #[cfg(not(wasm_browser))] Arc::new(quinn::TokioRuntime), #[cfg(wasm_browser)] @@ -1874,23 +1384,12 @@ impl Handle { let mut actor_tasks = JoinSet::default(); - let relay_actor = RelayActor::new(msock.clone(), relay_datagram_recv_queue, relay_protocol); - let relay_actor_cancel_token = relay_actor.cancel_token(); - actor_tasks.spawn( - async move { - relay_actor - .run(relay_actor_receiver, relay_datagram_send_rx) - .await; - } - .instrument(info_span!("relay-actor")), - ); - #[cfg(not(wasm_browser))] let _ = actor_tasks.spawn({ let msock = msock.clone(); async move { - while let Some((dst, dst_key, msg)) = udp_disco_receiver.recv().await { - if let Err(err) = msock.send_disco_message_udp(dst, dst_key, &msg).await { + while let Some((dst, dst_key, msg)) = disco_receiver.recv().await { + if let Err(err) = msock.send_disco_message(&sender1, dst.clone(), dst_key, msg).await { warn!(%dst, node = %dst_key.fmt_short(), ?err, "failed to send disco message (UDP)"); } } @@ -1902,76 +1401,58 @@ impl Handle { .context(CreateNetmonMonitorSnafu)?; let qad_endpoint = endpoint.clone(); - // create a client config for the endpoint to use for QUIC address discovery - let root_store = - rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - let client_config = rustls::client::ClientConfig::builder_with_provider(Arc::new( - rustls::crypto::ring::default_provider(), - )) - .with_safe_default_protocol_versions() - .expect("ring supports these") - .with_root_certificates(root_store) - .with_no_client_auth(); + #[cfg(any(test, feature = "test-utils"))] + let client_config = if insecure_skip_relay_cert_verify { + iroh_relay::client::make_dangerous_client_config() + } else { + default_quic_client_config() + }; + #[cfg(not(any(test, feature = "test-utils")))] + let client_config = default_quic_client_config(); - #[cfg(not(wasm_browser))] - let quic_config = Some(QuicConfig { - ep: qad_endpoint, - client_config, - ipv4: true, - ipv6: actor_sockets.v6.is_some(), - }); - #[cfg(not(wasm_browser))] - let net_report_config = net_report::Options::default() - .stun_v4(Some(actor_sockets.v4.clone())) - .stun_v6(actor_sockets.v6.clone()) - .quic_config(quic_config); - #[cfg(wasm_browser)] let net_report_config = net_report::Options::default(); + #[cfg(not(wasm_browser))] + let net_report_config = net_report_config + .stun_v4(Some(v4_socket)) + .stun_v6(v6_socket) + .quic_config(Some(QuicConfig { + ep: qad_endpoint, + client_config, + ipv4: true, + ipv6, + })); - // Setup network monitoring - let (link_change_s, link_change_r) = mpsc::channel(8); - let netmon_token = network_monitor - .subscribe(move |is_major| { - let link_change_s = link_change_s.clone(); - async move { - link_change_s.send(is_major).await.ok(); - } - .boxed() - }) - .await - .context(SubscribeNetmonMonitorSnafu)?; - - actor_tasks.spawn({ - let msock = msock.clone(); - async move { - let actor = Actor { - msg_receiver: actor_receiver, - msg_sender: actor_sender, - relay_actor_sender, - relay_actor_cancel_token, - msock, - periodic_re_stun_timer: new_re_stun_timer(false), - net_info_last: None, - #[cfg(not(wasm_browser))] - sockets: actor_sockets, - no_v4_send: false, - net_reporter, - network_monitor, - net_report_config, - }; + #[cfg(any(test, feature = "test-utils"))] + let net_report_config = + net_report_config.insecure_skip_relay_cert_verify(insecure_skip_relay_cert_verify); + + let actor = Actor { + msg_receiver: actor_receiver, + msg_sender: actor_sender, + msock: msock.clone(), + periodic_re_stun_timer: new_re_stun_timer(false), + net_info_last: None, + #[cfg(not(wasm_browser))] + port_mapper, + no_v4_send: false, + net_reporter, + network_monitor, + net_report_config, + network_change_sender, + }; + actor_tasks.spawn( + actor + .run(local_addrs_watch, sender2) + .instrument(info_span!("actor")), + ); - actor.run(link_change_r, netmon_token).await; - } - .instrument(info_span!("actor")) - }); + let actor_tasks = Arc::new(Mutex::new(actor_tasks)); - let c = Handle { + Ok(Handle { msock, - actor_tasks: Arc::new(Mutex::new(actor_tasks)), + actor_tasks, endpoint, - }; - - Ok(c) + }) } /// The underlying [`quinn::Endpoint`] @@ -2047,6 +1528,19 @@ impl Handle { } } +fn default_quic_client_config() -> rustls::ClientConfig { + // create a client config for the endpoint to use for QUIC address discovery + let root_store = + rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + rustls::client::ClientConfig::builder_with_provider(Arc::new( + rustls::crypto::ring::default_provider(), + )) + .with_safe_default_protocol_versions() + .expect("ring supports these") + .with_root_certificates(root_store) + .with_no_client_auth() +} + #[derive(Debug, Default)] struct DiscoSecrets(std::sync::Mutex>); @@ -2083,340 +1577,96 @@ impl DiscoSecrets { mut sealed_box: Vec, ) -> Result { self.get(secret, node_id, |secret| secret.open(&mut sealed_box)) - .context(OpenSnafu)?; - disco::Message::from_bytes(&sealed_box).context(ParseSnafu) - } -} - -#[allow(missing_docs)] -#[common_fields({ - backtrace: Option, - #[snafu(implicit)] - span_trace: n0_snafu::SpanTrace, -})] -#[derive(Debug, Snafu)] -#[non_exhaustive] -enum DiscoBoxError { - #[snafu(display("Failed to open crypto box"))] - Open { - #[snafu(source(from(DecryptionError, Box::new)))] - source: Box, - }, - #[snafu(display("Failed to parse disco message"))] - Parse { - #[snafu(source(from(disco::ParseError, Box::new)))] - source: Box, - }, -} - -/// Creates a sender and receiver pair for sending datagrams to the [`RelayActor`]. -/// -/// These includes the waker coordination required to support [`AsyncUdpSocket::try_send`] -/// and [`quinn::UdpPoller::poll_writable`]. -fn relay_datagram_send_channel() -> ( - RelayDatagramSendChannelSender, - RelayDatagramSendChannelReceiver, -) { - let (sender, receiver) = mpsc::channel(256); - let wakers = Arc::new(std::sync::Mutex::new(Vec::new())); - let tx = RelayDatagramSendChannelSender { - sender, - wakers: wakers.clone(), - }; - let rx = RelayDatagramSendChannelReceiver { receiver, wakers }; - (tx, rx) -} - -/// Sender to send datagrams to the [`RelayActor`]. -/// -/// This includes the waker coordination required to support [`AsyncUdpSocket::try_send`] -/// and [`quinn::UdpPoller::poll_writable`]. -#[derive(Debug, Clone)] -struct RelayDatagramSendChannelSender { - sender: mpsc::Sender, - wakers: Arc>>, -} - -impl RelayDatagramSendChannelSender { - fn try_send( - &self, - item: RelaySendItem, - ) -> Result<(), mpsc::error::TrySendError> { - self.sender.try_send(item) - } - - fn poll_writable(&self, cx: &mut Context) -> Poll> { - match self.sender.capacity() { - 0 => { - let mut wakers = self.wakers.lock().expect("poisoned"); - if !wakers.iter().any(|waker| waker.will_wake(cx.waker())) { - wakers.push(cx.waker().clone()); - } - drop(wakers); - if self.sender.capacity() != 0 { - // We "risk" a spurious wake-up in this case, but rather that - // than potentially skipping a receive. - Poll::Ready(Ok(())) - } else { - Poll::Pending - } - } - _ => Poll::Ready(Ok(())), - } - } -} - -/// Receiver to send datagrams to the [`RelayActor`]. -/// -/// This includes the waker coordination required to support [`AsyncUdpSocket::try_send`] -/// and [`quinn::UdpPoller::poll_writable`]. -#[derive(Debug)] -struct RelayDatagramSendChannelReceiver { - receiver: mpsc::Receiver, - wakers: Arc>>, -} - -impl RelayDatagramSendChannelReceiver { - async fn recv(&mut self) -> Option { - let item = self.receiver.recv().await; - let mut wakers = self.wakers.lock().expect("poisoned"); - wakers.drain(..).for_each(Waker::wake); - item + .context(OpenSnafu)?; + disco::Message::from_bytes(&sealed_box).context(ParseSnafu) } } -/// A queue holding [`RelayRecvDatagram`]s that can be polled in async -/// contexts, and wakes up tasks when something adds items using [`try_send`]. -/// -/// This is used to transfer relay datagrams between the [`RelayActor`] -/// and [`MagicSock`]. -/// -/// [`try_send`]: Self::try_send -/// [`RelayActor`]: crate::magicsock::RelayActor -/// [`MagicSock`]: crate::magicsock::MagicSock -#[derive(Debug)] -struct RelayDatagramRecvQueue { - queue: ConcurrentQueue, - waker: AtomicWaker, -} - +#[allow(missing_docs)] +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] #[derive(Debug, Snafu)] #[non_exhaustive] -enum RelayRecvDatagramError { - #[snafu(display("Queue is closed"))] - Closed, +enum DiscoBoxError { + #[snafu(display("Failed to open crypto box"))] + Open { + #[snafu(source(from(DecryptionError, Box::new)))] + source: Box, + }, + #[snafu(display("Failed to parse disco message"))] + Parse { + #[snafu(source(from(disco::ParseError, Box::new)))] + source: Box, + }, } -impl RelayDatagramRecvQueue { - /// Creates a new, empty queue with a fixed size bound of 512 items. - fn new() -> Self { - Self { - queue: ConcurrentQueue::bounded(512), - waker: AtomicWaker::new(), - } - } - - /// Sends an item into this queue and wakes a potential task - /// that's registered its waker with a [`poll_recv`] call. - /// - /// [`poll_recv`]: Self::poll_recv - fn try_send( - &self, - item: RelayRecvDatagram, - ) -> Result<(), concurrent_queue::PushError> { - self.queue.push(item).inspect(|_| { - self.waker.wake(); - }) - } - - /// Polls for new items in the queue. - /// - /// Although this method is available from `&self`, it must not be - /// polled concurrently between tasks. - /// - /// Calling this will replace the current waker used. So if another task - /// waits for this, that task's waker will be replaced and it won't be - /// woken up for new items. - /// - /// The reason this method is made available as `&self` is because - /// the interface for quinn's [`AsyncUdpSocket::poll_recv`] requires us - /// to be able to poll from `&self`. - fn poll_recv( - &self, - cx: &mut Context, - ) -> Poll> { - match self.queue.pop() { - Ok(value) => Poll::Ready(Ok(value)), - Err(concurrent_queue::PopError::Empty) => { - self.waker.register(cx.waker()); - - match self.queue.pop() { - Ok(value) => { - self.waker.take(); - Poll::Ready(Ok(value)) - } - Err(concurrent_queue::PopError::Empty) => Poll::Pending, - Err(concurrent_queue::PopError::Closed) => { - self.waker.take(); - Poll::Ready(Err(ClosedSnafu.build())) - } - } - } - Err(concurrent_queue::PopError::Closed) => Poll::Ready(Err(ClosedSnafu.build())), - } - } +#[derive(Debug)] +struct MagicUdpSocket { + socket: Arc, + transports: Transports, } -impl AsyncUdpSocket for MagicSock { - fn create_io_poller(self: Arc) -> Pin> { - // To do this properly the MagicSock would need a registry of pollers. For each - // node we would look up the poller or create one. Then on each try_send we can - // look up the correct poller and configure it to poll the paths it needs. - // - // Note however that the current quinn impl calls UdpPoller::poll_writable() - // **before** it calls try_send(), as opposed to how it is documented. That is a - // problem as we would not yet know the path that needs to be polled. To avoid such - // ambiguity the API could be changed to a .poll_send(&self, cx: &mut Context, - // io_poller: Pin<&mut dyn UdpPoller>, transmit: &Transmit) -> Poll> - // instead of the existing .try_send() because then we would have control over this. - // - // Right now however we have one single poller behaving the same for each - // connection. It checks all paths and returns Poll::Ready as soon as any path is - // ready. - #[cfg(not(wasm_browser))] - let ipv4_poller = self.sockets.v4.create_io_poller(); - #[cfg(not(wasm_browser))] - let ipv6_poller = self.sockets.v6.as_ref().map(|sock| sock.create_io_poller()); - let relay_sender = self.relay_datagram_send_channel.clone(); - Box::pin(IoPoller { - #[cfg(not(wasm_browser))] - ipv4_poller, - #[cfg(not(wasm_browser))] - ipv6_poller, - relay_sender, - }) - } - - fn try_send(&self, transmit: &quinn_udp::Transmit) -> io::Result<()> { - self.try_send(transmit) +impl AsyncUdpSocket for MagicUdpSocket { + fn create_sender(&self) -> Pin> { + Box::pin(self.transports.create_sender(self.socket.clone())) } /// NOTE: Receiving on a closed socket will return [`Poll::Pending`] indefinitely. fn poll_recv( - &self, + &mut self, cx: &mut Context, bufs: &mut [io::IoSliceMut<'_>], metas: &mut [quinn_udp::RecvMeta], ) -> Poll> { - self.poll_recv(cx, bufs, metas) + self.transports.poll_recv(cx, bufs, metas, &self.socket) } + #[cfg(not(wasm_browser))] fn local_addr(&self) -> io::Result { - #[cfg(not(wasm_browser))] - match &*self.sockets.local_addrs.read().expect("not poisoned") { - (ipv4, None) => { - // Pretend to be IPv6, because our `MappedAddr`s - // need to be IPv6. - let ip: IpAddr = match ipv4.ip() { - IpAddr::V4(ip) => ip.to_ipv6_mapped().into(), - IpAddr::V6(ip) => ip.into(), - }; - Ok(SocketAddr::new(ip, ipv4.port())) - } - (_, Some(ipv6)) => Ok(*ipv6), - } - // Again, we need to pretend we're IPv6, because of our `MappedAddr`s. - #[cfg(wasm_browser)] - return Ok(SocketAddr::new(std::net::Ipv6Addr::LOCALHOST.into(), 0)); - } + let addrs: Vec<_> = self + .transports + .local_addrs() + .into_iter() + .filter_map(|addr| { + let addr: SocketAddr = addr.into_socket_addr()?; + Some(addr) + }) + .collect(); - #[cfg(not(wasm_browser))] - fn max_transmit_segments(&self) -> usize { - if let Some(socket) = self.sockets.v6.as_ref() { - std::cmp::min( - socket.max_transmit_segments(), - self.sockets.v4.max_transmit_segments(), - ) - } else { - self.sockets.v4.max_transmit_segments() + if let Some(addr) = addrs.iter().find(|addr| addr.is_ipv6()) { + return Ok(*addr); } - } - - #[cfg(wasm_browser)] - fn max_transmit_segments(&self) -> usize { - 1 - } - - #[cfg(not(wasm_browser))] - fn max_receive_segments(&self) -> usize { - if let Some(socket) = self.sockets.v6.as_ref() { - // `max_receive_segments` controls the size of the `RecvMeta` buffer - // that quinn creates. Having buffers slightly bigger than necessary - // isn't terrible, and makes sure a single socket can read the maximum - // amount with a single poll. We considered adding these numbers instead, - // but we never get data from both sockets at the same time in `poll_recv` - // and it's impossible and unnecessary to be refactored that way. - std::cmp::max( - socket.max_receive_segments(), - self.sockets.v4.max_receive_segments(), - ) - } else { - self.sockets.v4.max_receive_segments() + if let Some(SocketAddr::V4(addr)) = addrs.first() { + // Pretend to be IPv6, because our `MappedAddr`s need to be IPv6. + let ip = addr.ip().to_ipv6_mapped().into(); + return Ok(SocketAddr::new(ip, addr.port())); } + + Err(io::Error::other("no valid address available")) } #[cfg(wasm_browser)] - fn max_receive_segments(&self) -> usize { - 1 + fn local_addr(&self) -> io::Result { + // Again, we need to pretend we're IPv6, because of our `MappedAddr`s. + Ok(SocketAddr::new(std::net::Ipv6Addr::LOCALHOST.into(), 0)) } - #[cfg(not(wasm_browser))] - fn may_fragment(&self) -> bool { - if let Some(socket) = self.sockets.v6.as_ref() { - socket.may_fragment() || self.sockets.v4.may_fragment() - } else { - self.sockets.v4.may_fragment() - } + fn max_receive_segments(&self) -> usize { + self.transports.max_receive_segments() } - #[cfg(wasm_browser)] fn may_fragment(&self) -> bool { - false - } -} - -#[derive(Debug)] -struct IoPoller { - #[cfg(not(wasm_browser))] - ipv4_poller: Pin>, - #[cfg(not(wasm_browser))] - ipv6_poller: Option>>, - relay_sender: RelayDatagramSendChannelSender, -} - -impl quinn::UdpPoller for IoPoller { - fn poll_writable(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - // This version returns Ready as soon as any of them are ready. - let this = &mut *self; - #[cfg(not(wasm_browser))] - match this.ipv4_poller.as_mut().poll_writable(cx) { - Poll::Ready(_) => return Poll::Ready(Ok(())), - Poll::Pending => (), - } - #[cfg(not(wasm_browser))] - if let Some(ref mut ipv6_poller) = this.ipv6_poller { - match ipv6_poller.as_mut().poll_writable(cx) { - Poll::Ready(_) => return Poll::Ready(Ok(())), - Poll::Pending => (), - } - } - this.relay_sender.poll_writable(cx) + self.transports.may_fragment() } } #[derive(Debug)] enum ActorMessage { Shutdown, + PingActions(Vec), EndpointPingExpired(usize, stun_rs::TransactionId), NetReport( Result>, NetReportError>, @@ -2431,16 +1681,13 @@ struct Actor { msock: Arc, msg_receiver: mpsc::Receiver, msg_sender: mpsc::Sender, - relay_actor_sender: mpsc::Sender, - relay_actor_cancel_token: CancellationToken, /// When set, is an AfterFunc timer that will call MagicSock::do_periodic_stun. periodic_re_stun_timer: time::Interval, /// The `NetInfo` provided in the last call to `net_info_func`. It's used to deduplicate calls to netInfoFunc. net_info_last: Option, - /// Socket state, grouped so we can cfg it out as one for browsers #[cfg(not(wasm_browser))] - sockets: ActorSocketState, + port_mapper: portmapper::Client, /// Configuration for net report net_report_config: net_report::Options, @@ -2454,90 +1701,57 @@ struct Actor { net_reporter: net_report::Client, network_monitor: netmon::Monitor, -} - -/// Actor state that relies on sockets being available. -/// -/// We group these together into their own struct to make it easier to cfg out at once. -#[cfg(not(wasm_browser))] -struct ActorSocketState { - /// The NAT-PMP/PCP/UPnP prober/client, for requesting port mappings from NAT devices. - port_mapper: portmapper::Client, - - // The underlying UDP sockets used to send/rcv packets. - v4: Arc, - v6: Option>, + network_change_sender: transports::NetworkChangeSender, } #[cfg(not(wasm_browser))] -impl ActorSocketState { - fn bind( - addr_v4: Option, - addr_v6: Option, - metrics: Arc, - ) -> io::Result { - let port_mapper = portmapper::Client::with_metrics(Default::default(), metrics); - let (v4, v6) = Self::bind_sockets(addr_v4, addr_v6)?; - - let this = Self { - port_mapper, - v4, - v6, - }; - - let port = this.port_v4(); - - // NOTE: we can end up with a zero port if `netwatch::UdpSocket::socket_addr` fails - match port.try_into() { - Ok(non_zero_port) => { - this.port_mapper.update_local_port(non_zero_port); - } - Err(_zero_port) => debug!("Skipping port mapping with zero local port"), +fn bind_ip( + addr_v4: SocketAddrV4, + addr_v6: Option, + metrics: &EndpointMetrics, +) -> io::Result<(Vec, portmapper::Client)> { + let port_mapper = + portmapper::Client::with_metrics(Default::default(), metrics.portmapper.clone()); + + let v4 = Arc::new(bind_with_fallback(SocketAddr::V4(addr_v4))?); + let ip4_port = v4.local_addr()?.port(); + let ip6_port = ip4_port.checked_add(1).unwrap_or(ip4_port - 1); + + let addr_v6 = + addr_v6.unwrap_or_else(|| SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, ip6_port, 0, 0)); + + let v6 = match bind_with_fallback(SocketAddr::V6(addr_v6)) { + Ok(sock) => Some(Arc::new(sock)), + Err(err) => { + info!("bind ignoring IPv6 bind failure: {:?}", err); + None } + }; - Ok(this) - } + let port = v4.local_addr().map_or(0, |p| p.port()); - /// Returns the ipv4 port, or 0 if `netwatch::UdpSocket::socket_addr` failed. - fn port_v4(&self) -> u16 { - self.v4.local_addr().map_or(0, |p| p.port()) + let mut ip = vec![IpTransport::new( + addr_v4.into(), + v4, + metrics.magicsock.clone(), + )]; + if let Some(v6) = v6 { + ip.push(IpTransport::new( + addr_v6.into(), + v6, + metrics.magicsock.clone(), + )) } - fn bind_sockets( - addr_v4: Option, - addr_v6: Option, - ) -> io::Result<(Arc, Option>)> { - let addr_v4 = addr_v4.unwrap_or_else(|| SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)); - let v4 = Arc::new(bind_with_fallback(SocketAddr::V4(addr_v4))?); - - let ip4_port = v4.local_addr()?.port(); - let ip6_port = ip4_port.checked_add(1).unwrap_or(ip4_port - 1); - let addr_v6 = - addr_v6.unwrap_or_else(|| SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, ip6_port, 0, 0)); - let v6 = match bind_with_fallback(SocketAddr::V6(addr_v6)) { - Ok(sock) => Some(Arc::new(sock)), - Err(err) => { - info!("bind ignoring IPv6 bind failure: {:?}", err); - None - } - }; - - Ok((v4, v6)) + // NOTE: we can end up with a zero port if `netwatch::UdpSocket::socket_addr` fails + match port.try_into() { + Ok(non_zero_port) => { + port_mapper.update_local_port(non_zero_port); + } + Err(_zero_port) => debug!("Skipping port mapping with zero local port"), } - fn msock_socket_state(&self) -> io::Result { - let ipv4_addr = self.v4.local_addr()?; - let ipv6_addr = self.v6.as_ref().and_then(|c| c.local_addr().ok()); - - let socket_state = SocketState { - port: AtomicU16::new(self.port_v4()), - local_addrs: std::sync::RwLock::new((ipv4_addr, ipv6_addr)), - v4: UdpConn::wrap(self.v4.clone()), - v6: self.v6.clone().map(UdpConn::wrap), - }; - - Ok(socket_state) - } + Ok((ip, port_mapper)) } #[derive(Debug, Snafu)] @@ -2552,17 +1766,21 @@ enum NetReportError { } impl Actor { - async fn run(mut self, mut link_change_r: mpsc::Receiver, _token: CallbackToken) { - // Let the the heartbeat only start a couple seconds later + async fn run( + mut self, + mut watcher: impl Watcher> + Send + Sync, + sender: UdpSender, + ) { + // Setup network monitoring + let mut netmon_watcher = self.network_monitor.interface_state(); + let mut current_netmon_state = netmon_watcher.get().expect("missing network state"); + #[cfg(not(wasm_browser))] - let mut direct_addr_heartbeat_timer = time::interval_at( - time::Instant::now() + HEARTBEAT_INTERVAL, - HEARTBEAT_INTERVAL, - ); + let mut direct_addr_heartbeat_timer = time::interval(HEARTBEAT_INTERVAL); let mut direct_addr_update_receiver = self.msock.direct_addr_update_state.running.subscribe(); #[cfg(not(wasm_browser))] - let mut portmap_watcher = self.sockets.port_mapper.watch_external_address(); + let mut portmap_watcher = self.port_mapper.watch_external_address(); let mut discovery_events: BoxStream = Box::pin(n0_future::stream::empty()); if let Some(d) = self.msock.discovery() { @@ -2574,7 +1792,7 @@ impl Actor { let mut receiver_closed = false; #[cfg_attr(wasm_browser, allow(unused_mut))] let mut portmap_watcher_closed = false; - let mut link_change_closed = false; + loop { self.msock.metrics.magicsock.actor_tick_main.inc(); #[cfg(not(wasm_browser))] @@ -2599,7 +1817,7 @@ impl Actor { trace!(?msg, "tick: msg"); self.msock.metrics.magicsock.actor_tick_msg.inc(); - if self.handle_actor_message(msg).await { + if self.handle_actor_message(msg, &sender).await { return; } } @@ -2608,6 +1826,19 @@ impl Actor { self.msock.metrics.magicsock.actor_tick_re_stun.inc(); self.msock.re_stun("periodic"); } + new_addr = watcher.updated() => { + match new_addr { + Ok(addrs) => { + if !addrs.is_empty() { + trace!(?addrs, "local addrs"); + self.msock.publish_my_addr(); + } + } + Err(_) => { + warn!("local addr watcher stopped"); + } + } + } change = portmap_watcher_changed, if !portmap_watcher_closed => { #[cfg(not(wasm_browser))] { @@ -2640,7 +1871,7 @@ impl Actor { self.msock.node_map.prune_inactive(); let msgs = self.msock.node_map.nodes_stayin_alive(); - self.handle_ping_actions(msgs).await; + self.handle_ping_actions(&sender, msgs).await; } } _ = direct_addr_update_receiver.changed() => { @@ -2651,18 +1882,17 @@ impl Actor { self.refresh_direct_addrs(reason).await; } } - is_major = link_change_r.recv(), if !link_change_closed => { - let Some(is_major) = is_major else { + state = netmon_watcher.updated() => { + let Ok(state) = state else { trace!("tick: link change receiver closed"); self.msock.metrics.magicsock.actor_tick_other.inc(); - - link_change_closed = true; continue; }; - + let is_major = state.is_major_change(¤t_netmon_state); + current_netmon_state = state; trace!("tick: link change {}", is_major); self.msock.metrics.magicsock.actor_link_change.inc(); - self.handle_network_change(is_major).await; + self.handle_network_change(is_major); } // Even if `discovery_events` yields `None`, it could begin to yield // `Some` again in the future, so we don't want to disable this branch @@ -2686,24 +1916,17 @@ impl Actor { } } - async fn handle_network_change(&mut self, is_major: bool) { + fn handle_network_change(&mut self, is_major: bool) { debug!("link change detected: major? {}", is_major); if is_major { - #[cfg(not(wasm_browser))] - { - if let Err(err) = self.sockets.v4.rebind() { - warn!("failed to rebind Udp IPv4 socket: {:?}", err); - }; - if let Some(ref socket) = self.sockets.v6 { - if let Err(err) = socket.rebind() { - warn!("failed to rebind Udp IPv6 socket: {:?}", err); - }; - } - self.msock.dns_resolver.clear_cache(); + if let Err(err) = self.network_change_sender.rebind() { + warn!("failed to rebind transports: {:?}", err); } + + #[cfg(not(wasm_browser))] + self.msock.dns_resolver.clear_cache(); self.msock.re_stun("link-change-major"); - self.close_stale_relay_connections().await; self.reset_endpoint_states(); } else { self.msock.re_stun("link-change-minor"); @@ -2711,28 +1934,23 @@ impl Actor { } #[instrument(skip_all)] - async fn handle_ping_actions(&mut self, msgs: Vec) { - // TODO: This used to make sure that all ping actions are sent. Though on the - // poll_send/try_send path we also do fire-and-forget. try_send_ping_actions() - // really should store any unsent pings on the Inner and send them at the next - // possible time. - if let Err(err) = self.msock.try_send_ping_actions(msgs) { - warn!("Not all ping actions were sent: {err:#}"); + async fn handle_ping_actions(&mut self, sender: &UdpSender, msgs: Vec) { + if let Err(err) = self.msock.send_ping_actions(sender, msgs).await { + warn!("Failed to send ping actions: {err:#}"); } } /// Processes an incoming actor message. /// /// Returns `true` if it was a shutdown. - async fn handle_actor_message(&mut self, msg: ActorMessage) -> bool { + async fn handle_actor_message(&mut self, msg: ActorMessage, sender: &UdpSender) -> bool { match msg { ActorMessage::Shutdown => { debug!("shutting down"); self.msock.node_map.notify_shutdown(); #[cfg(not(wasm_browser))] - self.sockets.port_mapper.deactivate(); - self.relay_actor_cancel_token.cancel(); + self.port_mapper.deactivate(); debug!("shutdown complete"); return true; @@ -2759,7 +1977,10 @@ impl Actor { } #[cfg(test)] ActorMessage::ForceNetworkChange(is_major) => { - self.handle_network_change(is_major).await; + self.handle_network_change(is_major); + } + ActorMessage::PingActions(ping_actions) => { + self.handle_ping_actions(sender, ping_actions).await; } } @@ -2779,7 +2000,7 @@ impl Actor { debug!("starting direct addr update ({})", why); #[cfg(not(wasm_browser))] - self.sockets.port_mapper.procure_mapping(); + self.port_mapper.procure_mapping(); self.update_net_info(why).await; } @@ -2793,7 +2014,7 @@ impl Actor { /// - The local interfaces IP addresses. #[cfg(not(wasm_browser))] fn update_direct_addresses(&mut self, net_report_report: Option>) { - let portmap_watcher = self.sockets.port_mapper.watch_external_address(); + let portmap_watcher = self.port_mapper.watch_external_address(); // We only want to have one DirectAddr for each SocketAddr we have. So we store // this as a map of SocketAddr -> DirectAddrType. At the end we will construct a @@ -2820,17 +2041,25 @@ impl Actor { // port locally, assume they might've added a static // port mapping on their router to the same explicit // port that we are running with. Worst case it's an invalid candidate mapping. - let port = self.msock.sockets.port.load(Ordering::Relaxed); - if net_report_report - .mapping_varies_by_dest_ip - .unwrap_or_default() - && port != 0 - { - let mut addr = global_v4; - addr.set_port(port); - addrs - .entry(addr.into()) - .or_insert(DirectAddrType::Stun4LocalPort); + let port = self.msock.ip_bind_addrs().iter().find_map(|addr| { + if addr.port() != 0 { + Some(addr.port()) + } else { + None + } + }); + + if let Some(port) = port { + if net_report_report + .mapping_varies_by_dest_ip + .unwrap_or_default() + { + let mut addr = global_v4; + addr.set_port(port); + addrs + .entry(addr.into()) + .or_insert(DirectAddrType::Stun4LocalPort); + } } } if let Some(global_v6) = net_report_report.global_v6 { @@ -2840,17 +2069,29 @@ impl Actor { } } - let local_addr_v4 = self.sockets.v4.local_addr().ok(); - let local_addr_v6 = self.sockets.v6.as_ref().and_then(|c| c.local_addr().ok()); - - let is_unspecified_v4 = local_addr_v4 - .map(|a| a.ip().is_unspecified()) - .unwrap_or(false); - let is_unspecified_v6 = local_addr_v6 - .map(|a| a.ip().is_unspecified()) - .unwrap_or(false); + let local_addrs: Vec<(SocketAddr, SocketAddr)> = self + .msock + .ip_bind_addrs() + .iter() + .copied() + .zip(self.msock.ip_local_addrs()) + .collect(); let msock = self.msock.clone(); + let has_ipv4_unspecified = local_addrs.iter().find_map(|(_, a)| { + if a.is_ipv4() && a.ip().is_unspecified() { + Some(a.port()) + } else { + None + } + }); + let has_ipv6_unspecified = local_addrs.iter().find_map(|(_, a)| { + if a.is_ipv6() && a.ip().is_unspecified() { + Some(a.port()) + } else { + None + } + }); // The following code can be slow, we do not want to block the caller since it would // block the actor loop. @@ -2858,7 +2099,10 @@ impl Actor { async move { // If a socket is bound to the unspecified address, create SocketAddrs for // each local IP address by pairing it with the port the socket is bound on. - if is_unspecified_v4 || is_unspecified_v6 { + if local_addrs + .iter() + .any(|(_, local)| local.ip().is_unspecified()) + { // Depending on the OS and network interfaces attached and their state // enumerating the local interfaces can take a long time. Especially // Windows is very slow. @@ -2873,15 +2117,11 @@ impl Actor { // or public addresses, this allows testing offline. ips = loopback; } + for ip in ips { let port_if_unspecified = match ip { - IpAddr::V4(_) if is_unspecified_v4 => { - local_addr_v4.map(|addr| addr.port()) - } - IpAddr::V6(_) if is_unspecified_v6 => { - local_addr_v6.map(|addr| addr.port()) - } - _ => None, + IpAddr::V4(_) => has_ipv4_unspecified, + IpAddr::V6(_) => has_ipv6_unspecified, }; if let Some(port) = port_if_unspecified { let addr = SocketAddr::new(ip, port); @@ -2891,14 +2131,9 @@ impl Actor { } // If a socket is bound to a specific address, add it. - if !is_unspecified_v4 { - if let Some(addr) = local_addr_v4 { - addrs.entry(addr).or_insert(DirectAddrType::Local); - } - } - if !is_unspecified_v6 { - if let Some(addr) = local_addr_v6 { - addrs.entry(addr).or_insert(DirectAddrType::Local); + for (bound, local) in local_addrs { + if !bound.ip().is_unspecified() { + addrs.entry(local).or_insert(DirectAddrType::Local); } } @@ -3029,12 +2264,7 @@ impl Actor { self.no_v4_send = !r.ipv4_can_send; #[cfg(not(wasm_browser))] - let have_port_map = self - .sockets - .port_mapper - .watch_external_address() - .borrow() - .is_some(); + let have_port_map = self.port_mapper.watch_external_address().borrow().is_some(); #[cfg(wasm_browser)] let have_port_map = false; @@ -3066,7 +2296,8 @@ impl Actor { ni.preferred_relay = self.pick_relay_fallback(); } - self.set_nearest_relay(ni.preferred_relay.clone()); + // Notify all transports + self.network_change_sender.on_network_change(&ni); // TODO: set link type self.call_net_info_callback(ni).await; @@ -3075,28 +2306,6 @@ impl Actor { self.update_direct_addresses(report); } - fn set_nearest_relay(&mut self, relay_url: Option) { - let my_relay = self.msock.my_relay(); - if relay_url == my_relay { - // No change. - return; - } - let old_relay = self.msock.set_my_relay(relay_url.clone()); - - if let Some(ref relay_url) = relay_url { - self.msock.metrics.magicsock.relay_home_change.inc(); - - // On change, notify all currently connected relay servers and - // start connecting to our home relay if we are not already. - info!("home is now relay {}, was {:?}", relay_url, old_relay); - self.msock.publish_my_addr(); - - self.send_relay_actor(RelayActorMessage::SetHome { - url: relay_url.clone(), - }); - } - } - /// Returns a deterministic relay node to connect to. This is only used if net_report /// couldn't find the nearest one, for instance, if UDP is blocked and thus STUN /// latency checks aren't working. @@ -3129,38 +2338,6 @@ impl Actor { fn reset_endpoint_states(&mut self) { self.msock.node_map.reset_node_states() } - - /// Tells the relay actor to close stale relay connections. - /// - /// The relay connections who's local endpoints no longer exist after a network change - /// will error out soon enough. Closing them eagerly speeds this up however and allows - /// re-establishing a relay connection faster. - async fn close_stale_relay_connections(&self) { - let ifs = interfaces::State::new().await; - #[cfg(not(wasm_browser))] - let local_ips = ifs - .interfaces - .values() - .flat_map(|netif| netif.addrs()) - .map(|ipnet| ipnet.addr()) - .collect(); - // In browsers, we don't have this information. This will do the right thing in the ActiveRelayActor, though. - #[cfg(wasm_browser)] - let local_ips = Vec::new(); - self.send_relay_actor(RelayActorMessage::MaybeCloseRelaysOnRebind(local_ips)); - } - - fn send_relay_actor(&self, msg: RelayActorMessage) { - match self.relay_actor_sender.try_send(msg) { - Ok(_) => {} - Err(mpsc::error::TrySendError::Closed(_)) => { - warn!("unable to send to relay actor, already closed"); - } - Err(mpsc::error::TrySendError::Full(_)) => { - warn!("dropping message for relay actor, channel is full"); - } - } - } } fn new_re_stun_timer(initial_delay: bool) -> time::Interval { @@ -3279,28 +2456,6 @@ impl DiscoveredDirectAddrs { } } -/// Split a transmit containing a GSO payload into individual packets. -/// -/// This allocates the data. -/// -/// If the transmit has a segment size it contains multiple GSO packets. It will be split -/// into multiple packets according to that segment size. If it does not have a segment -/// size, the contents will be sent as a single packet. -// TODO: If quinn stayed on bytes this would probably be much cheaper, probably. Need to -// figure out where they allocate the Vec. -fn split_packets(transmit: &quinn_udp::Transmit) -> RelayContents { - let mut res = SmallVec::with_capacity(1); - let contents = transmit.contents; - if let Some(segment_size) = transmit.segment_size { - for chunk in contents.chunks(segment_size) { - res.push(Bytes::from(chunk.to_vec())); - } - } else { - res.push(Bytes::from(contents.to_vec())); - } - res -} - /// The fake address used by the QUIC layer to address a node. /// /// You can consider this as nothing more than a lookup key for a node the [`MagicSock`] knows @@ -3454,7 +2609,7 @@ impl Display for DirectAddrType { /// Contains information about the host's network state. #[derive(Debug, Clone, PartialEq)] -struct NetInfo { +pub(crate) struct NetInfo { /// Says whether the host's NAT mappings vary based on the destination IP. mapping_varies_by_dest_ip: Option, @@ -3534,30 +2689,25 @@ impl NetInfo { mod tests { use std::{collections::BTreeSet, sync::Arc, time::Duration}; - use bytes::Bytes; use data_encoding::HEXLOWER; - use iroh_base::{NodeAddr, NodeId, PublicKey, RelayUrl, SecretKey}; + use iroh_base::{NodeAddr, NodeId, PublicKey, SecretKey}; use iroh_relay::RelayMap; use n0_future::{time, StreamExt}; use n0_snafu::{Result, ResultExt}; + use n0_watcher::Watcher; use quinn::ServerConfig; use rand::{Rng, RngCore}; use tokio::task::JoinSet; use tokio_util::task::AbortOnDropHandle; - use tracing::{debug, error, info, info_span, instrument, Instrument}; + use tracing::{error, info, info_span, instrument, Instrument}; use tracing_test::traced_test; - use super::{split_packets, NodeIdMappedAddr, Options}; + use super::{NodeIdMappedAddr, Options}; use crate::{ - defaults::staging::{self, EU_RELAY_HOSTNAME}, dns::DnsResolver, endpoint::{DirectAddr, PathSelection, Source}, - magicsock::{ - node_map, Handle, MagicSock, RelayContents, RelayDatagramRecvQueue, RelayRecvDatagram, - }, - tls, - watcher::Watcher as _, - Endpoint, RelayMode, + magicsock::{node_map, Handle, MagicSock}, + tls, Endpoint, RelayMode, }; const ALPN: &[u8] = b"n0/test/1"; @@ -4069,46 +3219,6 @@ mod tests { Ok(()) } - #[test] - fn test_split_packets() { - fn mk_transmit(contents: &[u8], segment_size: Option) -> quinn_udp::Transmit<'_> { - let destination = "127.0.0.1:0".parse().unwrap(); - quinn_udp::Transmit { - destination, - ecn: None, - contents, - segment_size, - src_ip: None, - } - } - fn mk_expected(parts: impl IntoIterator) -> RelayContents { - parts - .into_iter() - .map(|p| p.as_bytes().to_vec().into()) - .collect() - } - // no split - assert_eq!( - split_packets(&mk_transmit(b"hello", None)), - mk_expected(["hello"]) - ); - // split without rest - assert_eq!( - split_packets(&mk_transmit(b"helloworld", Some(5))), - mk_expected(["hello", "world"]) - ); - // split with rest and second transmit - assert_eq!( - split_packets(&mk_transmit(b"hello world", Some(5))), - mk_expected(["hello", " worl", "d"]) // spellchecker:disable-line - ); - // split that results in 1 packet - assert_eq!( - split_packets(&mk_transmit(b"hello world", Some(1000))), - mk_expected(["hello world"]) - ); - } - #[tokio::test] #[traced_test] async fn test_local_endpoints() { @@ -4125,36 +3235,6 @@ mod tests { assert_eq!(eps0, eps1); } - #[tokio::test] - async fn test_watch_home_relay() { - // use an empty relay map to get full control of the changes during the test - let ops = Options { - relay_map: RelayMap::empty(), - ..Default::default() - }; - let msock = MagicSock::spawn(ops).await.unwrap(); - let mut relay_stream = msock.home_relay().stream().filter_map(|r| r); - - // no relay, nothing to report - assert_eq!( - n0_future::future::poll_once(relay_stream.next()).await, - None - ); - - let url: RelayUrl = format!("https://{}", EU_RELAY_HOSTNAME).parse().unwrap(); - msock.set_my_relay(Some(url.clone())); - - assert_eq!(relay_stream.next().await, Some(url.clone())); - - // drop the stream and query it again, the result should be immediately available - - let mut relay_stream = msock.home_relay().stream().filter_map(|r| r); - assert_eq!( - n0_future::future::poll_once(relay_stream.next()).await, - Some(Some(url)) - ); - } - /// Creates a new [`quinn::Endpoint`] hooked up to a [`MagicSock`]. /// /// This is without involving [`crate::endpoint::Endpoint`]. The socket will accept @@ -4181,7 +3261,7 @@ mod tests { dns_resolver, proxy_url: None, server_config, - insecure_skip_relay_cert_verify: true, + insecure_skip_relay_cert_verify: false, path_selection: PathSelection::default(), metrics: Default::default(), }; @@ -4453,60 +3533,6 @@ mod tests { // But we don't have that much private access to the NodeMap. This will do for now. } - #[tokio::test(flavor = "multi_thread")] - async fn test_relay_datagram_queue() { - let queue = Arc::new(RelayDatagramRecvQueue::new()); - let url = staging::default_na_relay_node().url; - let capacity = queue.queue.capacity().unwrap(); - - let mut tasks = JoinSet::new(); - - tasks.spawn({ - let queue = queue.clone(); - async move { - let mut expected_msgs: BTreeSet = (0..capacity).collect(); - while !expected_msgs.is_empty() { - let datagram = n0_future::future::poll_fn(|cx| { - queue.poll_recv(cx).map(|result| result.unwrap()) - }) - .await; - - let msg_num = usize::from_le_bytes(datagram.buf.as_ref().try_into().unwrap()); - debug!("Received {msg_num}"); - - if !expected_msgs.remove(&msg_num) { - panic!("Received message number {msg_num} twice or more, but expected it only exactly once."); - } - } - } - }); - - for i in 0..capacity { - tasks.spawn({ - let queue = queue.clone(); - let url = url.clone(); - async move { - debug!("Sending {i}"); - queue - .try_send(RelayRecvDatagram { - url, - src: PublicKey::from_bytes(&[0u8; 32]).unwrap(), - buf: Bytes::copy_from_slice(&i.to_le_bytes()), - }) - .unwrap(); - } - }); - } - - // We expect all of this work to be done in 10 seconds max. - if tokio::time::timeout(Duration::from_secs(10), tasks.join_all()) - .await - .is_err() - { - panic!("Timeout - not all messages between 0 and {capacity} received."); - } - } - #[tokio::test] async fn test_add_node_addr() -> Result { let stack = MagicStack::new(RelayMode::Default).await; diff --git a/iroh/src/magicsock/node_map.rs b/iroh/src/magicsock/node_map.rs index 6392ebeb0c4..c15aa69c008 100644 --- a/iroh/src/magicsock/node_map.rs +++ b/iroh/src/magicsock/node_map.rs @@ -15,13 +15,10 @@ use self::{ best_addr::ClearReason, node_state::{NodeState, Options, PingHandled}, }; -use super::{metrics::Metrics, ActorMessage, DiscoMessageSource, NodeIdMappedAddr}; +use super::{metrics::Metrics, transports, ActorMessage, NodeIdMappedAddr}; +use crate::disco::{CallMeMaybe, Pong, SendAddr}; #[cfg(any(test, feature = "test-utils"))] use crate::endpoint::PathSelection; -use crate::{ - disco::{CallMeMaybe, Pong, SendAddr}, - watcher, -}; mod best_addr; mod node_state; @@ -228,7 +225,7 @@ impl NodeMap { .handle_ping(sender, src, tx_id) } - pub(super) fn handle_pong(&self, sender: PublicKey, src: &DiscoMessageSource, pong: Pong) { + pub(super) fn handle_pong(&self, sender: PublicKey, src: &transports::Addr, pong: Pong) { self.inner .lock() .expect("poisoned") @@ -264,8 +261,8 @@ impl NodeMap { let ep = inner.get_mut(NodeStateKey::NodeIdMappedAddr(addr))?; let public_key = *ep.public_key(); trace!(dest = %addr, node_id = %public_key.fmt_short(), "dst mapped to NodeId"); - let (udp_addr, relay_url, msgs) = ep.get_send_addrs(have_ipv6, metrics); - Some((public_key, udp_addr, relay_url, msgs)) + let (udp_addr, relay_url, ping_actions) = ep.get_send_addrs(have_ipv6, metrics); + Some((public_key, udp_addr, relay_url, ping_actions)) } pub(super) fn notify_shutdown(&self) { @@ -303,13 +300,13 @@ impl NodeMap { .collect() } - /// Returns a [`watcher::Direct`] for given node's [`ConnectionType`]. + /// Returns a [`n0_watcher::Direct`] for given node's [`ConnectionType`]. /// /// # Errors /// /// Will return `None` if there is not an entry in the [`NodeMap`] for /// the `node_id` - pub(super) fn conn_type(&self, node_id: NodeId) -> Option> { + pub(super) fn conn_type(&self, node_id: NodeId) -> Option> { self.inner.lock().expect("poisoned").conn_type(node_id) } @@ -505,14 +502,14 @@ impl NodeMapInner { /// /// Will return `None` if there is not an entry in the [`NodeMap`] for /// the `public_key` - fn conn_type(&self, node_id: NodeId) -> Option> { + fn conn_type(&self, node_id: NodeId) -> Option> { self.get(NodeStateKey::NodeId(node_id)) .map(|ep| ep.conn_type()) } - fn handle_pong(&mut self, sender: NodeId, src: &DiscoMessageSource, pong: Pong) { + fn handle_pong(&mut self, sender: NodeId, src: &transports::Addr, pong: Pong) { if let Some(ns) = self.get_mut(NodeStateKey::NodeId(sender)).as_mut() { - let insert = ns.handle_pong(&pong, src.into()); + let insert = ns.handle_pong(&pong, src.clone().into()); if let Some((src, key)) = insert { self.set_node_key_for_ip_port(src, &key); } diff --git a/iroh/src/magicsock/node_map/node_state.rs b/iroh/src/magicsock/node_map/node_state.rs index b9c80198cbf..f48e5488c3a 100644 --- a/iroh/src/magicsock/node_map/node_state.rs +++ b/iroh/src/magicsock/node_map/node_state.rs @@ -11,6 +11,7 @@ use n0_future::{ task::{self, AbortOnDropHandle}, time::{self, Duration, Instant}, }; +use n0_watcher::Watchable; use serde::{Deserialize, Serialize}; use tokio::sync::mpsc; use tracing::{debug, event, info, instrument, trace, warn, Level}; @@ -26,7 +27,6 @@ use crate::endpoint::PathSelection; use crate::{ disco::{self, SendAddr}, magicsock::{ActorMessage, MagicsockMetrics, NodeIdMappedAddr, HEARTBEAT_INTERVAL}, - watcher::{self, Watchable}, }; /// Number of addresses that are not active that we keep around per node. @@ -202,7 +202,7 @@ impl NodeState { self.id } - pub(super) fn conn_type(&self) -> watcher::Direct { + pub(super) fn conn_type(&self) -> n0_watcher::Direct { self.conn_type.watch() } @@ -586,7 +586,6 @@ impl NodeState { } } } - // We send pings regardless of whether we have a RelayUrl. If we were given any // direct address paths to contact but no RelayUrl, we still need to send a DISCO // ping to the direct address paths so that the other node will learn about us and @@ -1177,19 +1176,18 @@ impl NodeState { metrics.nodes_contacted.inc(); } let (udp_addr, relay_url) = self.addr_for_send(&now, have_ipv6, metrics); - let mut ping_msgs = Vec::new(); - - if self.want_call_me_maybe(&now) { - ping_msgs = self.send_call_me_maybe(now, SendCallMeMaybe::IfNoRecent); - } + let ping_msgs = if self.want_call_me_maybe(&now) { + self.send_call_me_maybe(now, SendCallMeMaybe::IfNoRecent) + } else { + Vec::new() + }; trace!( ?udp_addr, ?relay_url, pings = %ping_msgs.len(), "found send address", ); - (udp_addr, relay_url, ping_msgs) } diff --git a/iroh/src/magicsock/transports.rs b/iroh/src/magicsock/transports.rs new file mode 100644 index 00000000000..b37f9ee3b94 --- /dev/null +++ b/iroh/src/magicsock/transports.rs @@ -0,0 +1,597 @@ +use std::{ + io::{self, IoSliceMut}, + net::{IpAddr, Ipv6Addr, SocketAddr, SocketAddrV6}, + pin::Pin, + sync::{atomic::AtomicUsize, Arc}, + task::{Context, Poll}, +}; + +use iroh_base::{NodeId, RelayUrl}; +use n0_watcher::Watcher; +use relay::{RelayNetworkChangeSender, RelaySender}; +use smallvec::SmallVec; +use tracing::{error, trace, warn}; + +#[cfg(not(wasm_browser))] +mod ip; +mod relay; + +#[cfg(not(wasm_browser))] +pub(crate) use self::ip::IpTransport; +#[cfg(not(wasm_browser))] +use self::ip::{IpNetworkChangeSender, IpSender}; +pub(crate) use self::relay::{RelayActorConfig, RelayTransport}; +use super::{MagicSock, NetInfo}; + +/// Manages the different underlying data transports that the magicsock +/// can support. +#[derive(Debug)] +pub(crate) struct Transports { + #[cfg(not(wasm_browser))] + ip: Vec, + relay: Vec, + + poll_recv_counter: AtomicUsize, +} + +#[cfg(not(wasm_browser))] +pub(crate) type LocalAddrsWatch = n0_watcher::Map< + ( + n0_watcher::Join>, + n0_watcher::Join< + Option<(RelayUrl, NodeId)>, + n0_watcher::Map>, Option<(RelayUrl, NodeId)>>, + >, + ), + Vec, +>; + +#[cfg(wasm_browser)] +pub(crate) type LocalAddrsWatch = n0_watcher::Map< + n0_watcher::Join< + Option<(RelayUrl, NodeId)>, + n0_watcher::Map>, Option<(RelayUrl, NodeId)>>, + >, + Vec, +>; + +impl Transports { + /// Creates a new transports structure. + pub(crate) fn new( + #[cfg(not(wasm_browser))] ip: Vec, + relay: Vec, + ) -> Self { + Self { + #[cfg(not(wasm_browser))] + ip, + relay, + poll_recv_counter: Default::default(), + } + } + + pub(crate) fn poll_recv( + &mut self, + cx: &mut Context, + bufs: &mut [io::IoSliceMut<'_>], + metas: &mut [quinn_udp::RecvMeta], + msock: &MagicSock, + ) -> Poll> { + debug_assert_eq!(bufs.len(), metas.len(), "non matching bufs & metas"); + if msock.is_closed() { + return Poll::Pending; + } + + let mut source_addrs = vec![Addr::default(); metas.len()]; + match self.inner_poll_recv(cx, bufs, metas, &mut source_addrs)? { + Poll::Pending | Poll::Ready(0) => Poll::Pending, + Poll::Ready(n) => { + msock.process_datagrams(&mut bufs[..n], &mut metas[..n], &source_addrs[..n]); + Poll::Ready(Ok(n)) + } + } + } + + /// Tries to recv data, on all available transports. + fn inner_poll_recv( + &mut self, + cx: &mut Context, + bufs: &mut [IoSliceMut<'_>], + metas: &mut [quinn_udp::RecvMeta], + source_addrs: &mut [Addr], + ) -> Poll> { + debug_assert_eq!(bufs.len(), metas.len(), "non matching bufs & metas"); + + macro_rules! poll_transport { + ($socket:expr) => { + match $socket.poll_recv(cx, bufs, metas, source_addrs)? { + Poll::Pending | Poll::Ready(0) => {} + Poll::Ready(n) => { + return Poll::Ready(Ok(n)); + } + } + }; + } + + // To improve fairness, every other call reverses the ordering of polling. + + let counter = self + .poll_recv_counter + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + if counter % 2 == 0 { + #[cfg(not(wasm_browser))] + for transport in &mut self.ip { + poll_transport!(transport); + } + for transport in &mut self.relay { + poll_transport!(transport); + } + } else { + for transport in self.relay.iter_mut().rev() { + poll_transport!(transport); + } + #[cfg(not(wasm_browser))] + for transport in self.ip.iter_mut().rev() { + poll_transport!(transport); + } + } + + Poll::Pending + } + + /// Returns a list of all currently known local addresses. + /// + /// For IP based transports this is the [`SocketAddr`] of the socket, + /// for relay transports, this is the home relay. + pub(crate) fn local_addrs(&self) -> Vec { + self.local_addrs_watch().get().expect("not disconnected") + } + + /// Watch for all currently known local addresses. + #[cfg(not(wasm_browser))] + pub(crate) fn local_addrs_watch(&self) -> LocalAddrsWatch { + let ips = n0_watcher::Join::new(self.ip.iter().map(|t| t.local_addr_watch())); + let relays = n0_watcher::Join::new(self.relay.iter().map(|t| t.local_addr_watch())); + + (ips, relays) + .map(|(ips, relays)| { + ips.into_iter() + .map(Addr::from) + .chain( + relays + .into_iter() + .flatten() + .map(|(relay_url, node_id)| Addr::Relay(relay_url, node_id)), + ) + .collect() + }) + .expect("disconnected") + } + + #[cfg(wasm_browser)] + pub(crate) fn local_addrs_watch(&self) -> LocalAddrsWatch { + let relays = self.relay.iter().map(|t| t.local_addr_watch()); + n0_watcher::Join::new(relays) + .map(|relays| relays.into_iter().flatten().map(Addr::from).collect()) + .expect("disconnected") + } + + /// Returns the bound addresses for IP based transports + #[cfg(not(wasm_browser))] + pub(crate) fn ip_bind_addrs(&self) -> Vec { + self.ip.iter().map(|t| t.bind_addr()).collect() + } + + #[cfg(not(wasm_browser))] + pub(crate) fn max_transmit_segments(&self) -> usize { + let res = self.ip.iter().map(|t| t.max_transmit_segments()).min(); + res.unwrap_or(1) + } + + #[cfg(wasm_browser)] + pub(crate) fn max_transmit_segments(&self) -> usize { + 1 + } + + #[cfg(not(wasm_browser))] + pub(crate) fn max_receive_segments(&self) -> usize { + // `max_receive_segments` controls the size of the `RecvMeta` buffer + // that quinn creates. Having buffers slightly bigger than necessary + // isn't terrible, and makes sure a single socket can read the maximum + // amount with a single poll. We considered adding these numbers instead, + // but we never get data from both sockets at the same time in `poll_recv` + // and it's impossible and unnecessary to be refactored that way. + + let res = self.ip.iter().map(|t| t.max_receive_segments()).max(); + res.unwrap_or(1) + } + + #[cfg(wasm_browser)] + pub(crate) fn max_receive_segments(&self) -> usize { + 1 + } + + #[cfg(not(wasm_browser))] + pub(crate) fn may_fragment(&self) -> bool { + self.ip.iter().any(|t| t.may_fragment()) + } + + #[cfg(wasm_browser)] + pub(crate) fn may_fragment(&self) -> bool { + false + } + + pub(crate) fn create_sender(&self, msock: Arc) -> UdpSender { + #[cfg(not(wasm_browser))] + let ip = self.ip.iter().map(|t| t.create_sender()).collect(); + let relay = self.relay.iter().map(|t| t.create_sender()).collect(); + let max_transmit_segments = self.max_transmit_segments(); + + UdpSender { + #[cfg(not(wasm_browser))] + ip, + msock, + relay, + max_transmit_segments, + } + } + + /// Handles potential changes to the underlying network conditions. + pub(crate) fn create_network_change_sender(&self) -> NetworkChangeSender { + NetworkChangeSender { + #[cfg(not(wasm_browser))] + ip: self + .ip + .iter() + .map(|t| t.create_network_change_sender()) + .collect(), + relay: self + .relay + .iter() + .map(|t| t.create_network_change_sender()) + .collect(), + } + } +} + +#[derive(Debug)] +pub(crate) struct NetworkChangeSender { + #[cfg(not(wasm_browser))] + ip: Vec, + relay: Vec, +} + +impl NetworkChangeSender { + pub(crate) fn on_network_change(&self, info: &NetInfo) { + #[cfg(not(wasm_browser))] + for ip in &self.ip { + ip.on_network_change(info); + } + + for relay in &self.relay { + relay.on_network_change(info); + } + } + + /// Rebinds underlying connections, if necessary. + pub(crate) fn rebind(&self) -> std::io::Result<()> { + let mut res = Ok(()); + + #[cfg(not(wasm_browser))] + for transport in &self.ip { + if let Err(err) = transport.rebind() { + warn!("failed to rebind {:?}", err); + res = Err(err); + } + } + + for transport in &self.relay { + if let Err(err) = transport.rebind() { + warn!("failed to rebind {:?}", err); + res = Err(err); + } + } + res + } +} + +/// An outgoing packet +#[derive(Debug, Clone)] +pub(crate) struct Transmit<'a> { + pub(crate) ecn: Option, + pub(crate) contents: &'a [u8], + pub(crate) segment_size: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum Addr { + Ip(SocketAddr), + Relay(RelayUrl, NodeId), +} + +impl Default for Addr { + fn default() -> Self { + Self::Ip(SocketAddr::V6(SocketAddrV6::new( + Ipv6Addr::UNSPECIFIED, + 0, + 0, + 0, + ))) + } +} + +impl From for Addr { + fn from(value: SocketAddr) -> Self { + Self::Ip(value) + } +} + +impl From<(RelayUrl, NodeId)> for Addr { + fn from(value: (RelayUrl, NodeId)) -> Self { + Self::Relay(value.0, value.1) + } +} + +impl Addr { + pub(crate) fn is_relay(&self) -> bool { + matches!(self, Self::Relay(..)) + } + + pub(crate) fn is_ip(&self) -> bool { + matches!(self, Self::Ip(..)) + } + + /// Returns `None` if not an `Ip`. + pub(crate) fn into_socket_addr(self) -> Option { + match self { + Self::Ip(ip) => Some(ip), + Self::Relay(..) => None, + } + } +} + +#[derive(Debug)] +pub(crate) struct UdpSender { + msock: Arc, // :( + #[cfg(not(wasm_browser))] + ip: Vec, + relay: Vec, + max_transmit_segments: usize, +} + +impl UdpSender { + pub(crate) async fn send( + &self, + destination: &Addr, + src: Option, + transmit: &Transmit<'_>, + ) -> io::Result<()> { + trace!(?destination, "sending"); + + let mut any_match = false; + match destination { + #[cfg(wasm_browser)] + Addr::Ip(..) => return Err(io::Error::other("IP is unsupported in browser")), + #[cfg(not(wasm_browser))] + Addr::Ip(addr) => { + for sender in &self.ip { + if sender.is_valid_send_addr(addr) { + any_match = true; + match sender.send(*addr, src, transmit).await { + Ok(()) => { + return Ok(()); + } + Err(err) => { + warn!("ip failed to send: {:?}", err); + } + } + } + } + } + Addr::Relay(url, node_id) => { + for sender in &self.relay { + if sender.is_valid_send_addr(url, node_id) { + any_match = true; + match sender.send(url.clone(), *node_id, transmit).await { + Ok(()) => { + return Ok(()); + } + Err(err) => { + warn!("relay failed to send: {:?}", err); + } + } + } + } + } + } + if any_match { + Err(io::Error::other("all available transports failed")) + } else { + Err(io::Error::other("no transport available")) + } + } + + pub(crate) fn inner_poll_send( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context, + destination: &Addr, + src: Option, + transmit: &Transmit<'_>, + ) -> Poll> { + trace!(?destination, "sending"); + + match destination { + #[cfg(wasm_browser)] + Addr::Ip(..) => { + return Poll::Ready(Err(io::Error::other("IP is unsupported in browser"))) + } + #[cfg(not(wasm_browser))] + Addr::Ip(addr) => { + for sender in &mut self.ip { + if sender.is_valid_send_addr(addr) { + match Pin::new(sender).poll_send(cx, *addr, src, transmit) { + Poll::Pending => {} + Poll::Ready(res) => return Poll::Ready(res), + } + } + } + } + Addr::Relay(url, node_id) => { + for sender in &mut self.relay { + if sender.is_valid_send_addr(url, node_id) { + match sender.poll_send(cx, url.clone(), *node_id, transmit) { + Poll::Pending => {} + Poll::Ready(res) => return Poll::Ready(res), + } + } + } + } + } + Poll::Pending + } + + /// Best effort sending + fn inner_try_send( + &self, + destination: &Addr, + src: Option, + transmit: &Transmit<'_>, + ) -> io::Result<()> { + trace!(?destination, "sending, best effort"); + + match destination { + #[cfg(wasm_browser)] + Addr::Ip(..) => return Err(io::Error::other("IP is unsupported in browser")), + #[cfg(not(wasm_browser))] + Addr::Ip(addr) => { + for transport in &self.ip { + if transport.is_valid_send_addr(addr) { + match transport.try_send(*addr, src, transmit) { + Ok(()) => return Ok(()), + Err(_err) => { + continue; + } + } + } + } + } + Addr::Relay(url, node_id) => { + for transport in &self.relay { + if transport.is_valid_send_addr(url, node_id) { + match transport.try_send(url.clone(), *node_id, transmit) { + Ok(()) => return Ok(()), + Err(_err) => { + continue; + } + } + } + } + } + } + Err(io::Error::new( + io::ErrorKind::WouldBlock, + "no transport ready", + )) + } +} + +impl quinn::UdpSender for UdpSender { + fn poll_send( + mut self: Pin<&mut Self>, + transmit: &quinn_udp::Transmit, + cx: &mut Context, + ) -> Poll> { + let active_paths = self.msock.prepare_send(transmit)?; + + if active_paths.is_empty() { + // Returning Ok here means we let QUIC timeout. + // Returning an error would immediately fail a connection. + // The philosophy of quinn-udp is that a UDP connection could + // come back at any time or missing should be transient so chooses to let + // these kind of errors time out. See test_try_send_no_send_addr to try + // this out. + error!("no paths available for node, voiding transmit"); + return Poll::Ready(Ok(())); + } + + let mut results = SmallVec::<[_; 3]>::new(); + + trace!(?active_paths, "attempting to send"); + + for destination in active_paths { + let src = transmit.src_ip; + let transmit = Transmit { + ecn: transmit.ecn, + contents: transmit.contents, + segment_size: transmit.segment_size, + }; + + let res = self + .as_mut() + .inner_poll_send(cx, &destination, src, &transmit); + match res { + Poll::Ready(Ok(())) => { + trace!(dst = ?destination, "sent transmit"); + } + Poll::Ready(Err(ref err)) => { + warn!(dst = ?destination, "failed to send: {err:#}"); + } + Poll::Pending => {} + } + results.push(res); + } + + if results.iter().all(|p| matches!(p, Poll::Pending)) { + // Handle backpressure. + return Poll::Pending; + } + Poll::Ready(Ok(())) + } + + fn max_transmit_segments(&self) -> usize { + self.max_transmit_segments + } + + fn try_send(self: Pin<&mut Self>, transmit: &quinn_udp::Transmit) -> io::Result<()> { + let active_paths = self.msock.prepare_send(transmit)?; + if active_paths.is_empty() { + // Returning Ok here means we let QUIC timeout. + // Returning an error would immediately fail a connection. + // The philosophy of quinn-udp is that a UDP connection could + // come back at any time or missing should be transient so chooses to let + // these kind of errors time out. See test_try_send_no_send_addr to try + // this out. + error!("no paths available for node, voiding transmit"); + return Ok(()); + } + + let mut results = SmallVec::<[_; 3]>::new(); + + trace!(?active_paths, "attempting to send"); + + for destination in active_paths { + let src = transmit.src_ip; + let transmit = Transmit { + ecn: transmit.ecn, + contents: transmit.contents, + segment_size: transmit.segment_size, + }; + + let res = self.inner_try_send(&destination, src, &transmit); + match res { + Ok(()) => { + trace!(dst = ?destination, "sent transmit"); + } + Err(ref err) => { + warn!(dst = ?destination, "failed to send: {err:#}"); + } + } + results.push(res); + } + + if results.iter().all(|p| p.is_err()) { + return Err(io::Error::other("all failed")); + } + Ok(()) + } +} diff --git a/iroh/src/magicsock/transports/ip.rs b/iroh/src/magicsock/transports/ip.rs new file mode 100644 index 00000000000..cc4b945d541 --- /dev/null +++ b/iroh/src/magicsock/transports/ip.rs @@ -0,0 +1,249 @@ +use std::{ + io, + net::{IpAddr, SocketAddr}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use n0_watcher::Watchable; +use netwatch::{UdpSender, UdpSocket}; +use pin_project::pin_project; +use tracing::trace; + +use super::{Addr, Transmit}; +use crate::metrics::MagicsockMetrics; + +#[derive(Debug)] +pub(crate) struct IpTransport { + bind_addr: SocketAddr, + socket: Arc, + local_addr: Watchable, + metrics: Arc, +} + +impl IpTransport { + pub(crate) fn new( + bind_addr: SocketAddr, + socket: Arc, + metrics: Arc, + ) -> Self { + // Currently gets updated on manual rebind + // TODO: update when UdpSocket under the hood rebinds automatically + let local_addr = Watchable::new(socket.local_addr().expect("invalid socket")); + + Self { + bind_addr, + socket, + local_addr, + metrics, + } + } + + /// NOTE: Receiving on a closed socket will return [`Poll::Pending`] indefinitely. + pub(super) fn poll_recv( + &mut self, + cx: &mut Context, + bufs: &mut [io::IoSliceMut<'_>], + metas: &mut [quinn_udp::RecvMeta], + source_addrs: &mut [Addr], + ) -> Poll> { + match self.socket.poll_recv_quinn(cx, bufs, metas) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(n)) => { + for (addr, el) in source_addrs.iter_mut().zip(metas.iter()).take(n) { + *addr = el.addr.into(); + } + Poll::Ready(Ok(n)) + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + } + } + + pub(super) fn local_addr_watch(&self) -> n0_watcher::Direct { + self.local_addr.watch() + } + + pub(super) fn max_transmit_segments(&self) -> usize { + self.socket.max_gso_segments() + } + + pub(super) fn max_receive_segments(&self) -> usize { + self.socket.gro_segments() + } + + pub(super) fn may_fragment(&self) -> bool { + self.socket.may_fragment() + } + + pub(crate) fn bind_addr(&self) -> SocketAddr { + self.bind_addr + } + + pub(super) fn create_network_change_sender(&self) -> IpNetworkChangeSender { + IpNetworkChangeSender { + socket: self.socket.clone(), + local_addr: self.local_addr.clone(), + } + } + + pub(crate) fn socket(&self) -> Arc { + self.socket.clone() + } + + pub(super) fn create_sender(&self) -> IpSender { + let sender = self.socket.clone().create_sender(); + IpSender { + bind_addr: self.bind_addr, + sender, + metrics: self.metrics.clone(), + } + } +} + +#[derive(Debug)] +pub(super) struct IpNetworkChangeSender { + socket: Arc, + local_addr: Watchable, +} + +impl IpNetworkChangeSender { + pub(super) fn rebind(&self) -> io::Result<()> { + self.socket.rebind()?; + let addr = self.socket.local_addr()?; + self.local_addr.set(addr).ok(); + + Ok(()) + } + + pub(super) fn on_network_change(&self, _info: &crate::magicsock::NetInfo) { + // Nothing to do for now + } +} + +#[derive(Debug)] +#[pin_project] +pub(super) struct IpSender { + bind_addr: SocketAddr, + #[pin] + sender: UdpSender, + metrics: Arc, +} + +impl IpSender { + pub(super) fn is_valid_send_addr(&self, addr: &SocketAddr) -> bool { + #[allow(clippy::match_like_matches_macro)] + match (self.bind_addr, addr) { + (SocketAddr::V4(_), SocketAddr::V4(..)) => true, + (SocketAddr::V6(_), SocketAddr::V6(..)) => true, + _ => false, + } + } + + pub(super) async fn send( + &self, + destination: SocketAddr, + src: Option, + transmit: &Transmit<'_>, + ) -> io::Result<()> { + trace!("sending to {}", destination); + let total_bytes = transmit.contents.len() as u64; + let res = self + .sender + .send(&quinn_udp::Transmit { + destination, + ecn: transmit.ecn, + contents: transmit.contents, + segment_size: transmit.segment_size, + src_ip: src, + }) + .await; + trace!("send res: {:?}", res); + + match res { + Ok(res) => { + match destination { + SocketAddr::V4(_) => { + self.metrics.send_ipv4.inc_by(total_bytes); + } + SocketAddr::V6(_) => { + self.metrics.send_ipv6.inc_by(total_bytes); + } + } + Ok(res) + } + Err(err) => Err(err), + } + } + + pub(super) fn poll_send( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context, + destination: SocketAddr, + src: Option, + transmit: &Transmit<'_>, + ) -> Poll> { + trace!("sending to {}", destination); + let total_bytes = transmit.contents.len() as u64; + let res = Pin::new(&mut self.sender).poll_send( + &quinn_udp::Transmit { + destination, + ecn: transmit.ecn, + contents: transmit.contents, + segment_size: transmit.segment_size, + src_ip: src, + }, + cx, + ); + trace!("send res: {:?}", res); + + match res { + Poll::Ready(Ok(res)) => { + match destination { + SocketAddr::V4(_) => { + self.metrics.send_ipv4.inc_by(total_bytes); + } + SocketAddr::V6(_) => { + self.metrics.send_ipv6.inc_by(total_bytes); + } + } + Poll::Ready(Ok(res)) + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => Poll::Pending, + } + } + + pub(super) fn try_send( + &self, + destination: SocketAddr, + src: Option, + transmit: &Transmit<'_>, + ) -> io::Result<()> { + trace!("sending to {}", destination); + let total_bytes = transmit.contents.len() as u64; + let res = self.sender.try_send(&quinn_udp::Transmit { + destination, + ecn: transmit.ecn, + contents: transmit.contents, + segment_size: transmit.segment_size, + src_ip: src, + }); + trace!("send res: {:?}", res); + + match res { + Ok(res) => { + match destination { + SocketAddr::V4(_) => { + self.metrics.send_ipv4.inc_by(total_bytes); + } + SocketAddr::V6(_) => { + self.metrics.send_ipv6.inc_by(total_bytes); + } + } + Ok(res) + } + Err(err) => Err(err), + } + } +} diff --git a/iroh/src/magicsock/transports/relay.rs b/iroh/src/magicsock/transports/relay.rs new file mode 100644 index 00000000000..0a9d9ef89c0 --- /dev/null +++ b/iroh/src/magicsock/transports/relay.rs @@ -0,0 +1,423 @@ +use std::{ + io, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use iroh_base::{NodeId, RelayUrl}; +use n0_future::{ + ready, + task::{self, AbortOnDropHandle}, +}; +use n0_watcher::{Watchable, Watcher as _}; +use smallvec::SmallVec; +use tokio::sync::mpsc; +use tokio_util::sync::PollSender; +use tracing::{error, info_span, trace, warn, Instrument}; + +use super::{Addr, Transmit}; +use crate::magicsock::RelayContents; + +mod actor; + +pub(crate) use self::actor::Config as RelayActorConfig; +use self::actor::{RelayActor, RelayActorMessage, RelayRecvDatagram, RelaySendItem}; + +#[derive(Debug)] +pub(crate) struct RelayTransport { + /// Queue to receive datagrams from relays for [`quinn::AsyncUdpSocket::poll_recv`]. + relay_datagram_recv_queue: mpsc::Receiver, + /// Channel on which to send datagrams via a relay server. + relay_datagram_send_channel: mpsc::Sender, + actor_sender: mpsc::Sender, + _actor_handle: AbortOnDropHandle<()>, + my_relay: Watchable>, + my_node_id: NodeId, +} + +impl RelayTransport { + pub(crate) fn new(config: RelayActorConfig) -> Self { + let (relay_datagram_send_tx, relay_datagram_send_rx) = mpsc::channel(256); + + let (relay_datagram_recv_tx, relay_datagram_recv_rx) = mpsc::channel(512); + + let (actor_sender, actor_receiver) = mpsc::channel(256); + + let my_node_id = config.secret_key.public(); + let my_relay = config.my_relay.clone(); + + let relay_actor = RelayActor::new(config, relay_datagram_recv_tx); + + let actor_handle = AbortOnDropHandle::new(task::spawn( + async move { + relay_actor + .run(actor_receiver, relay_datagram_send_rx) + .await; + } + .instrument(info_span!("relay-actor")), + )); + + Self { + relay_datagram_recv_queue: relay_datagram_recv_rx, + relay_datagram_send_channel: relay_datagram_send_tx, + actor_sender, + _actor_handle: actor_handle, + my_relay, + my_node_id, + } + } + + pub(crate) fn create_sender(&self) -> RelaySender { + RelaySender { + sender: PollSender::new(self.relay_datagram_send_channel.clone()), + } + } + + pub(super) fn poll_recv( + &mut self, + cx: &mut Context, + bufs: &mut [io::IoSliceMut<'_>], + metas: &mut [quinn_udp::RecvMeta], + source_addrs: &mut [Addr], + ) -> Poll> { + let mut num_msgs = 0; + for ((buf_out, meta_out), addr) in bufs + .iter_mut() + .zip(metas.iter_mut()) + .zip(source_addrs.iter_mut()) + { + let dm = match self.relay_datagram_recv_queue.poll_recv(cx) { + Poll::Ready(Some(recv)) => recv, + Poll::Ready(None) => { + error!("relay_recv_channel closed"); + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::NotConnected, + "connection closed", + ))); + } + Poll::Pending => { + break; + } + }; + + buf_out[..dm.buf.len()].copy_from_slice(&dm.buf); + meta_out.len = dm.buf.len(); + meta_out.stride = dm.buf.len(); + meta_out.ecn = None; + meta_out.dst_ip = None; // TODO: insert the relay url for this relay + + *addr = (dm.url, dm.src).into(); + num_msgs += 1; + } + + // If we have any msgs to report, they are in the first `num_msgs_total` slots + if num_msgs > 0 { + debug_assert!(num_msgs <= metas.len()); + Poll::Ready(Ok(num_msgs)) + } else { + Poll::Pending + } + } + + pub(super) fn local_addr_watch( + &self, + ) -> n0_watcher::Map>, Option<(RelayUrl, NodeId)>> { + let my_node_id = self.my_node_id; + self.my_relay + .watch() + .map(move |url| url.map(|url| (url, my_node_id))) + .expect("disconnected") + } + + pub(super) fn create_network_change_sender(&self) -> RelayNetworkChangeSender { + RelayNetworkChangeSender { + sender: self.actor_sender.clone(), + } + } +} + +#[derive(Debug)] +pub(super) struct RelayNetworkChangeSender { + sender: mpsc::Sender, +} + +impl RelayNetworkChangeSender { + pub(super) fn on_network_change(&self, info: &crate::magicsock::NetInfo) { + self.send_relay_actor(RelayActorMessage::NetworkChange { info: info.clone() }); + } + + pub(super) fn rebind(&self) -> io::Result<()> { + self.send_relay_actor(RelayActorMessage::MaybeCloseRelaysOnRebind); + + Ok(()) + } + + fn send_relay_actor(&self, msg: RelayActorMessage) { + match self.sender.try_send(msg) { + Ok(_) => {} + Err(mpsc::error::TrySendError::Closed(_)) => { + warn!("unable to send to relay actor, already closed"); + } + Err(mpsc::error::TrySendError::Full(_)) => { + warn!("dropping message for relay actor, channel is full"); + } + } + } +} + +/// Sender to send datagrams to the [`RelayActor`]. +/// +/// This includes the waker coordination required to support [`quinn::UdpSender::poll_send`]. +#[derive(Debug, Clone)] +pub(crate) struct RelaySender { + sender: PollSender, +} + +impl RelaySender { + pub(super) fn is_valid_send_addr(&self, _url: &RelayUrl, _node_id: &NodeId) -> bool { + true + } + + pub(super) async fn send( + &self, + dest_url: RelayUrl, + dest_node: NodeId, + transmit: &Transmit<'_>, + ) -> io::Result<()> { + let contents = split_packets(transmit); + + let item = RelaySendItem { + remote_node: dest_node, + url: dest_url.clone(), + datagrams: contents, + }; + + let dest_node = item.remote_node; + let dest_url = item.url.clone(); + let Some(sender) = self.sender.get_ref() else { + return Err(io::Error::other("channel closed")); + }; + match sender.send(item).await { + Ok(_) => { + trace!(node = %dest_node.fmt_short(), relay_url = %dest_url, + "send relay: message queued"); + Ok(()) + } + Err(mpsc::error::SendError(_)) => { + error!(node = %dest_node.fmt_short(), relay_url = %dest_url, + "send relay: message dropped, channel to actor is closed"); + Err(io::Error::new( + io::ErrorKind::ConnectionReset, + "channel to actor is closed", + )) + } + } + } + + pub(super) fn poll_send( + &mut self, + cx: &mut Context, + dest_url: RelayUrl, + dest_node: NodeId, + transmit: &Transmit<'_>, + ) -> Poll> { + match ready!(self.sender.poll_reserve(cx)) { + Ok(()) => { + trace!(node = %dest_node.fmt_short(), relay_url = %dest_url, + "send relay: message queued"); + + let contents = split_packets(transmit); + let item = RelaySendItem { + remote_node: dest_node, + url: dest_url.clone(), + datagrams: contents, + }; + let dest_node = item.remote_node; + let dest_url = item.url.clone(); + + match self.sender.send_item(item) { + Ok(()) => Poll::Ready(Ok(())), + Err(_err) => { + error!(node = %dest_node.fmt_short(), relay_url = %dest_url, + "send relay: message dropped, channel to actor is closed"); + Poll::Ready(Err(io::Error::new( + io::ErrorKind::ConnectionReset, + "channel to actor is closed", + ))) + } + } + } + Err(_err) => { + error!(node = %dest_node.fmt_short(), relay_url = %dest_url, + "send relay: message dropped, channel to actor is closed"); + Poll::Ready(Err(io::Error::new( + io::ErrorKind::ConnectionReset, + "channel to actor is closed", + ))) + } + } + } + + pub(super) fn try_send( + &self, + dest_url: RelayUrl, + dest_node: NodeId, + transmit: &Transmit<'_>, + ) -> io::Result<()> { + let contents = split_packets(transmit); + + let item = RelaySendItem { + remote_node: dest_node, + url: dest_url.clone(), + datagrams: contents, + }; + + let dest_node = item.remote_node; + let dest_url = item.url.clone(); + + let Some(sender) = self.sender.get_ref() else { + return Err(io::Error::other("channel closed")); + }; + + match sender.try_send(item) { + Ok(_) => { + trace!(node = %dest_node.fmt_short(), relay_url = %dest_url, + "send relay: message queued"); + Ok(()) + } + Err(mpsc::error::TrySendError::Closed(_)) => { + error!(node = %dest_node.fmt_short(), relay_url = %dest_url, + "send relay: message dropped, channel to actor is closed"); + Err(io::Error::new( + io::ErrorKind::ConnectionReset, + "channel to actor is closed", + )) + } + Err(mpsc::error::TrySendError::Full(_)) => { + warn!(node = %dest_node.fmt_short(), relay_url = %dest_url, + "send relay: message dropped, channel to actor is full"); + Err(io::Error::new(io::ErrorKind::WouldBlock, "channel full")) + } + } + } +} + +/// Split a transmit containing a GSO payload into individual packets. +/// +/// This allocates the data. +/// +/// If the transmit has a segment size it contains multiple GSO packets. It will be split +/// into multiple packets according to that segment size. If it does not have a segment +/// size, the contents will be sent as a single packet. +// TODO: If quinn stayed on bytes this would probably be much cheaper, probably. Need to +// figure out where they allocate the Vec. +fn split_packets(transmit: &Transmit<'_>) -> RelayContents { + let mut res = SmallVec::with_capacity(1); + let contents = transmit.contents; + if let Some(segment_size) = transmit.segment_size { + for chunk in contents.chunks(segment_size) { + res.push(Bytes::from(chunk.to_vec())); + } + } else { + res.push(Bytes::from(contents.to_vec())); + } + res +} + +#[cfg(test)] +mod tests { + use std::{collections::BTreeSet, time::Duration}; + + use iroh_base::NodeId; + use tokio::task::JoinSet; + use tracing::debug; + + use super::*; + use crate::defaults::staging; + + #[test] + fn test_split_packets() { + fn mk_transmit(contents: &[u8], segment_size: Option) -> Transmit<'_> { + Transmit { + ecn: None, + contents, + segment_size, + } + } + fn mk_expected(parts: impl IntoIterator) -> RelayContents { + parts + .into_iter() + .map(|p| p.as_bytes().to_vec().into()) + .collect() + } + // no split + assert_eq!( + split_packets(&mk_transmit(b"hello", None)), + mk_expected(["hello"]) + ); + // split without rest + assert_eq!( + split_packets(&mk_transmit(b"helloworld", Some(5))), + mk_expected(["hello", "world"]) + ); + // split with rest and second transmit + assert_eq!( + split_packets(&mk_transmit(b"hello world", Some(5))), + mk_expected(["hello", " worl", "d"]) // spellchecker:disable-line + ); + // split that results in 1 packet + assert_eq!( + split_packets(&mk_transmit(b"hello world", Some(1000))), + mk_expected(["hello world"]) + ); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_relay_datagram_queue() { + let capacity = 16; + let (sender, mut receiver) = mpsc::channel(capacity); + let url = staging::default_na_relay_node().url; + + let mut tasks = JoinSet::new(); + + tasks.spawn({ + async move { + let mut expected_msgs: BTreeSet = (0..capacity).collect(); + while !expected_msgs.is_empty() { + let datagram: RelayRecvDatagram = receiver.recv().await.unwrap(); + let msg_num = usize::from_le_bytes(datagram.buf.as_ref().try_into().unwrap()); + debug!("Received {msg_num}"); + + if !expected_msgs.remove(&msg_num) { + panic!("Received message number {msg_num} twice or more, but expected it only exactly once."); + } + } + } + }); + + for i in 0..capacity { + tasks.spawn({ + let sender = sender.clone(); + let url = url.clone(); + async move { + debug!("Sending {i}"); + sender + .try_send(RelayRecvDatagram { + url, + src: NodeId::from_bytes(&[0u8; 32]).unwrap(), + buf: Bytes::copy_from_slice(&i.to_le_bytes()), + }) + .unwrap(); + } + }); + } + + // We expect all of this work to be done in 10 seconds max. + if tokio::time::timeout(Duration::from_secs(10), tasks.join_all()) + .await + .is_err() + { + panic!("Timeout - not all messages between 0 and {capacity} received."); + } + } +} diff --git a/iroh/src/magicsock/relay_actor.rs b/iroh/src/magicsock/transports/relay/actor.rs similarity index 92% rename from iroh/src/magicsock/relay_actor.rs rename to iroh/src/magicsock/transports/relay/actor.rs index 56ee46b50db..37e84d00a99 100644 --- a/iroh/src/magicsock/relay_actor.rs +++ b/iroh/src/magicsock/transports/relay/actor.rs @@ -50,18 +50,19 @@ use n0_future::{ time::{self, Duration, Instant, MissedTickBehavior}, FuturesUnorderedBounded, SinkExt, StreamExt, }; +use n0_watcher::Watchable; use nested_enum_utils::common_fields; +use netwatch::interfaces; use snafu::{IntoError, ResultExt, Snafu}; use tokio::sync::{mpsc, oneshot}; use tokio_util::sync::CancellationToken; -use tracing::{debug, error, event, info_span, instrument, trace, warn, Instrument, Level}; +use tracing::{debug, error, event, info, info_span, instrument, trace, warn, Instrument, Level}; use url::Url; -use super::RelayDatagramSendChannelReceiver; #[cfg(not(wasm_browser))] use crate::dns::DnsResolver; use crate::{ - magicsock::{MagicSock, Metrics as MagicsockMetrics, RelayContents, RelayDatagramRecvQueue}, + magicsock::{Metrics as MagicsockMetrics, NetInfo, RelayContents}, util::MaybeFuture, }; @@ -135,7 +136,7 @@ struct ActiveRelayActor { /// Inbox for messages which involve sending to the relay server. inbox: mpsc::Receiver, /// Queue for received relay datagrams. - relay_datagrams_recv: Arc, + relay_datagrams_recv: mpsc::Sender, /// Channel on which we queue packets to send to the relay. relay_datagrams_send: mpsc::Receiver, @@ -197,7 +198,7 @@ struct ActiveRelayActorOptions { prio_inbox_: mpsc::Receiver, inbox: mpsc::Receiver, relay_datagrams_send: mpsc::Receiver, - relay_datagrams_recv: Arc, + relay_datagrams_recv: mpsc::Sender, connection_opts: RelayConnectionOptions, stop_token: CancellationToken, metrics: Arc, @@ -842,28 +843,24 @@ impl ConnectedRelayState { } pub(super) enum RelayActorMessage { - MaybeCloseRelaysOnRebind(Vec), - SetHome { url: RelayUrl }, + MaybeCloseRelaysOnRebind, + NetworkChange { info: NetInfo }, } #[derive(Debug, Clone)] -pub(super) struct RelaySendItem { +pub(crate) struct RelaySendItem { /// The destination for the datagrams. - pub(super) remote_node: NodeId, + pub(crate) remote_node: NodeId, /// The home relay of the remote node. - pub(super) url: RelayUrl, + pub(crate) url: RelayUrl, /// One or more datagrams to send. - pub(super) datagrams: RelayContents, + pub(crate) datagrams: RelayContents, } pub(super) struct RelayActor { - msock: Arc, + config: Config, /// Queue on which to put received datagrams. - /// - /// [`AsyncUdpSocket::poll_recv`] will read from this queue. - /// - /// [`AsyncUdpSocket::poll_recv`]: quinn::AsyncUdpSocket::poll_recv - relay_datagram_recv_queue: Arc, + relay_datagram_recv_queue: mpsc::Sender, /// The actors managing each currently used relay server. /// /// These actors will exit when they have any inactivity. Otherwise they will keep @@ -872,34 +869,43 @@ pub(super) struct RelayActor { /// The tasks for the [`ActiveRelayActor`]s in `active_relays` above. active_relay_tasks: JoinSet<()>, cancel_token: CancellationToken, - protocol: iroh_relay::http::Protocol, +} + +#[derive(Debug)] +pub struct Config { + pub my_relay: Watchable>, + pub secret_key: SecretKey, + #[cfg(not(wasm_browser))] + pub dns_resolver: DnsResolver, + /// Proxy + pub proxy_url: Option, + /// If the last net_report report, reports IPv6 to be available. + pub ipv6_reported: Arc, + #[cfg(any(test, feature = "test-utils"))] + pub insecure_skip_relay_cert_verify: bool, + pub metrics: Arc, + pub protocol: iroh_relay::http::Protocol, } impl RelayActor { pub(super) fn new( - msock: Arc, - relay_datagram_recv_queue: Arc, - protocol: iroh_relay::http::Protocol, + config: Config, + relay_datagram_recv_queue: mpsc::Sender, ) -> Self { let cancel_token = CancellationToken::new(); Self { - msock, + config, relay_datagram_recv_queue, active_relays: Default::default(), active_relay_tasks: JoinSet::new(), cancel_token, - protocol, } } - pub(super) fn cancel_token(&self) -> CancellationToken { - self.cancel_token.clone() - } - pub(super) async fn run( mut self, mut receiver: mpsc::Receiver, - mut datagram_send_channel: RelayDatagramSendChannelReceiver, + mut datagram_send_channel: mpsc::Receiver, ) { // When this future is present, it is sending pending datagrams to an // ActiveRelayActor. We can not process further datagrams during this time. @@ -964,11 +970,11 @@ impl RelayActor { async fn handle_msg(&mut self, msg: RelayActorMessage) { match msg { - RelayActorMessage::SetHome { url } => { - self.set_home_relay(url).await; + RelayActorMessage::NetworkChange { info } => { + self.on_network_change(info).await; } - RelayActorMessage::MaybeCloseRelaysOnRebind(ifs) => { - self.maybe_close_relays_on_rebind(&ifs).await; + RelayActorMessage::MaybeCloseRelaysOnRebind => { + self.maybe_close_relays_on_rebind().await; } } } @@ -1001,6 +1007,28 @@ impl RelayActor { } } + async fn on_network_change(&mut self, info: NetInfo) { + let my_relay = self.config.my_relay.get(); + if info.preferred_relay == my_relay { + // No change. + return; + } + let old_relay = self + .config + .my_relay + .set(info.preferred_relay.clone()) + .unwrap_or_else(|e| e); + + if let Some(relay_url) = info.preferred_relay { + self.config.metrics.relay_home_change.inc(); + + // On change, notify all currently connected relay servers and + // start connecting to our home relay if we are not already. + info!("home is now relay {}, was {:?}", relay_url, old_relay); + self.set_home_relay(relay_url).await; + } + } + async fn set_home_relay(&mut self, home_url: RelayUrl) { let home_url_ref = &home_url; n0_future::join_all(self.active_relays.iter().map(|(url, handle)| async move { @@ -1066,7 +1094,7 @@ impl RelayActor { Some(e) => e.clone(), None => { let handle = self.start_active_relay(url.clone()); - if Some(&url) == self.msock.my_relay().as_ref() { + if Some(&url) == self.config.my_relay.get().as_ref() { if let Err(err) = handle .inbox_addr .try_send(ActiveRelayMessage::SetHomeRelay(true)) @@ -1085,14 +1113,14 @@ impl RelayActor { debug!(?url, "Adding relay connection"); let connection_opts = RelayConnectionOptions { - secret_key: self.msock.secret_key.clone(), + secret_key: self.config.secret_key.clone(), #[cfg(not(wasm_browser))] - dns_resolver: self.msock.dns_resolver.clone(), - proxy_url: self.msock.proxy_url().cloned(), - prefer_ipv6: self.msock.ipv6_reported.clone(), + dns_resolver: self.config.dns_resolver.clone(), + proxy_url: self.config.proxy_url.clone(), + prefer_ipv6: self.config.ipv6_reported.clone(), #[cfg(any(test, feature = "test-utils"))] - insecure_skip_cert_verify: self.msock.insecure_skip_relay_cert_verify, - protocol: self.protocol, + insecure_skip_cert_verify: self.config.insecure_skip_relay_cert_verify, + protocol: self.config.protocol, }; // TODO: Replace 64 with PER_CLIENT_SEND_QUEUE_DEPTH once that's unused @@ -1108,7 +1136,7 @@ impl RelayActor { relay_datagrams_recv: self.relay_datagram_recv_queue.clone(), connection_opts, stop_token: self.cancel_token.child_token(), - metrics: self.msock.metrics.magicsock.clone(), + metrics: self.config.metrics.clone(), }; let actor = ActiveRelayActor::new(opts); self.active_relay_tasks.spawn( @@ -1131,13 +1159,28 @@ impl RelayActor { /// Called in response to a rebind, any relay connection originating from an address /// that's not known to be currently a local IP address should be closed. All the other /// relay connections are pinged. - async fn maybe_close_relays_on_rebind(&mut self, okay_local_ips: &[IpAddr]) { - let send_futs = self.active_relays.values().map(|handle| async move { - handle - .inbox_addr - .send(ActiveRelayMessage::CheckConnection(okay_local_ips.to_vec())) - .await - .ok(); + async fn maybe_close_relays_on_rebind(&mut self) { + #[cfg(not(wasm_browser))] + let ifs = interfaces::State::new().await; + #[cfg(not(wasm_browser))] + let local_ips: Vec<_> = ifs + .interfaces + .values() + .flat_map(|netif| netif.addrs()) + .map(|ipnet| ipnet.addr()) + .collect(); + // In browsers, we don't have this information. This will do the right thing in the ActiveRelayActor, though. + #[cfg(wasm_browser)] + let local_ips = Vec::new(); + let send_futs = self.active_relays.values().map(|handle| { + let local_ips = local_ips.clone(); + async move { + handle + .inbox_addr + .send(ActiveRelayMessage::CheckConnection(local_ips)) + .await + .ok(); + } }); n0_future::join_all(send_futs).await; self.log_active_relay(); @@ -1149,8 +1192,8 @@ impl RelayActor { .retain(|_url, handle| !handle.inbox_addr.is_closed()); // Make sure home relay exists - if let Some(ref url) = self.msock.my_relay() { - self.active_relay_handle(url.clone()); + if let Some(url) = self.config.my_relay.get() { + self.active_relay_handle(url); } self.log_active_relay(); } @@ -1209,10 +1252,10 @@ struct RelaySendPacket { /// /// This could be either a QUIC or DISCO packet. #[derive(Debug)] -pub(super) struct RelayRecvDatagram { - pub(super) url: RelayUrl, - pub(super) src: NodeId, - pub(super) buf: Bytes, +pub(crate) struct RelayRecvDatagram { + pub(crate) url: RelayUrl, + pub(crate) src: NodeId, + pub(crate) buf: Bytes, } /// Combines datagrams into a single DISCO frame of at most MAX_PACKET_SIZE. @@ -1220,7 +1263,7 @@ pub(super) struct RelayRecvDatagram { /// The disco `iroh_relay::protos::Frame::SendPacket` frame can contain more then a single /// datagram. Each datagram in this frame is prefixed with a little-endian 2-byte length /// prefix. This occurs when Quinn sends a GSO transmit containing more than one datagram, -/// which are split using [`crate::magicsock::split_packets`]. +/// which are split using `split_packets`. /// /// The [`PacketSplitIter`] does the inverse and splits such packets back into individual /// datagrams. @@ -1333,7 +1376,6 @@ mod tests { use bytes::Bytes; use iroh_base::{NodeId, RelayUrl, SecretKey}; use iroh_relay::PingTracker; - use n0_future::future; use n0_snafu::{Error, Result, ResultExt}; use smallvec::smallvec; use tokio::sync::{mpsc, oneshot}; @@ -1343,18 +1385,10 @@ mod tests { use super::{ ActiveRelayActor, ActiveRelayActorOptions, ActiveRelayMessage, ActiveRelayPrioMessage, - RelayConnectionOptions, RelaySendItem, MAX_PACKET_SIZE, - }; - use crate::{ - dns::DnsResolver, - magicsock::{ - relay_actor::{ - PacketizeIter, RELAY_INACTIVE_CLEANUP_TIME, UNDELIVERABLE_DATAGRAM_TIMEOUT, - }, - RelayDatagramRecvQueue, RelayRecvDatagram, - }, - test_utils, + PacketizeIter, RelayConnectionOptions, RelayRecvDatagram, RelaySendItem, MAX_PACKET_SIZE, + RELAY_INACTIVE_CLEANUP_TIME, UNDELIVERABLE_DATAGRAM_TIMEOUT, }; + use crate::{dns::DnsResolver, test_utils}; #[test] fn test_packetize_iter() { @@ -1396,7 +1430,7 @@ mod tests { prio_inbox_rx: mpsc::Receiver, inbox_rx: mpsc::Receiver, relay_datagrams_send: mpsc::Receiver, - relay_datagrams_recv: Arc, + relay_datagrams_recv: mpsc::Sender, span: tracing::Span, ) -> AbortOnDropHandle<()> { let opts = ActiveRelayActorOptions { @@ -1427,7 +1461,7 @@ mod tests { /// [`ActiveRelayNode`] under test to check connectivity works. fn start_echo_node(relay_url: RelayUrl) -> (NodeId, AbortOnDropHandle<()>) { let secret_key = SecretKey::from_bytes(&[8u8; 32]); - let recv_datagram_queue = Arc::new(RelayDatagramRecvQueue::new()); + let (recv_datagram_tx, mut recv_datagram_rx) = mpsc::channel(16); let (send_datagram_tx, send_datagram_rx) = mpsc::channel(16); let (prio_inbox_tx, prio_inbox_rx) = mpsc::channel(8); let (inbox_tx, inbox_rx) = mpsc::channel(16); @@ -1439,15 +1473,15 @@ mod tests { prio_inbox_rx, inbox_rx, send_datagram_rx, - recv_datagram_queue.clone(), + recv_datagram_tx, info_span!("echo-node"), ); let echo_task = tokio::spawn({ let relay_url = relay_url.clone(); async move { loop { - let datagram = future::poll_fn(|cx| recv_datagram_queue.poll_recv(cx)).await; - if let Ok(recv) = datagram { + let datagram = recv_datagram_rx.recv().await; + if let Some(recv) = datagram { let RelayRecvDatagram { url: _, src, buf } = recv; info!(from = src.fmt_short(), "Received datagram"); let send = RelaySendItem { @@ -1485,7 +1519,7 @@ mod tests { async fn send_recv_echo( item: RelaySendItem, tx: &mpsc::Sender, - rx: &Arc, + rx: &mut mpsc::Receiver, ) -> Result<()> { assert!(item.datagrams.len() == 1); tokio::time::timeout(Duration::from_secs(10), async move { @@ -1496,7 +1530,7 @@ mod tests { url: _, src: _, buf, - } = future::poll_fn(|cx| rx.poll_recv(cx)).await?; + } = rx.recv().await.unwrap(); assert_eq!(buf.as_ref(), item.datagrams[0]); @@ -1520,7 +1554,7 @@ mod tests { let (peer_node, _echo_node_task) = start_echo_node(relay_url.clone()); let secret_key = SecretKey::from_bytes(&[1u8; 32]); - let datagram_recv_queue = Arc::new(RelayDatagramRecvQueue::new()); + let (datagram_recv_tx, mut datagram_recv_rx) = mpsc::channel(16); let (send_datagram_tx, send_datagram_rx) = mpsc::channel(16); let (_prio_inbox_tx, prio_inbox_rx) = mpsc::channel(8); let (inbox_tx, inbox_rx) = mpsc::channel(16); @@ -1532,7 +1566,7 @@ mod tests { prio_inbox_rx, inbox_rx, send_datagram_rx, - datagram_recv_queue.clone(), + datagram_recv_tx.clone(), info_span!("actor-under-test"), ); @@ -1546,7 +1580,7 @@ mod tests { send_recv_echo( hello_send_item.clone(), &send_datagram_tx, - &datagram_recv_queue, + &mut datagram_recv_rx, ) .await?; @@ -1581,7 +1615,7 @@ mod tests { send_recv_echo( hello_send_item.clone(), &send_datagram_tx, - &datagram_recv_queue, + &mut datagram_recv_rx, ) .await?; @@ -1601,7 +1635,7 @@ mod tests { send_recv_echo( hello_send_item.clone(), &send_datagram_tx, - &datagram_recv_queue, + &mut datagram_recv_rx, ) .await?; @@ -1618,7 +1652,7 @@ mod tests { let (_relay_map, relay_url, _server) = test_utils::run_relay_server().await?; let secret_key = SecretKey::from_bytes(&[1u8; 32]); - let datagram_recv_queue = Arc::new(RelayDatagramRecvQueue::new()); + let (datagram_recv_tx, _datagram_recv_rx) = mpsc::channel(16); let (_send_datagram_tx, send_datagram_rx) = mpsc::channel(16); let (_prio_inbox_tx, prio_inbox_rx) = mpsc::channel(8); let (inbox_tx, inbox_rx) = mpsc::channel(16); @@ -1630,7 +1664,7 @@ mod tests { prio_inbox_rx, inbox_rx, send_datagram_rx, - datagram_recv_queue.clone(), + datagram_recv_tx, info_span!("actor-under-test"), ); diff --git a/iroh/src/magicsock/udp_conn.rs b/iroh/src/magicsock/udp_conn.rs deleted file mode 100644 index 5c618121508..00000000000 --- a/iroh/src/magicsock/udp_conn.rs +++ /dev/null @@ -1,81 +0,0 @@ -use std::{ - fmt::Debug, - io, - net::SocketAddr, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; - -use netwatch::UdpSocket; -use quinn::AsyncUdpSocket; -use quinn_udp::Transmit; - -/// Wrapper struct to implement Quinn's [`AsyncUdpSocket`] for [`UdpSocket`]. -#[derive(Debug, Clone)] -pub(super) struct UdpConn { - inner: Arc, -} - -impl UdpConn { - pub(super) fn wrap(inner: Arc) -> Self { - Self { inner } - } - - pub(super) fn as_socket_ref(&self) -> &UdpSocket { - &self.inner - } - - pub(super) fn create_io_poller(&self) -> Pin> { - Box::pin(IoPoller { - io: self.inner.clone(), - }) - } -} - -impl AsyncUdpSocket for UdpConn { - fn create_io_poller(self: Arc) -> Pin> { - (*self).create_io_poller() - } - - fn try_send(&self, transmit: &Transmit<'_>) -> io::Result<()> { - self.inner.try_send_quinn(transmit) - } - - fn poll_recv( - &self, - cx: &mut Context, - bufs: &mut [io::IoSliceMut<'_>], - meta: &mut [quinn_udp::RecvMeta], - ) -> Poll> { - self.inner.poll_recv_quinn(cx, bufs, meta) - } - - fn local_addr(&self) -> io::Result { - self.inner.local_addr() - } - - fn may_fragment(&self) -> bool { - self.inner.may_fragment() - } - - fn max_transmit_segments(&self) -> usize { - self.inner.max_gso_segments() - } - - fn max_receive_segments(&self) -> usize { - self.inner.gro_segments() - } -} - -/// Poller for when the socket is writable. -#[derive(Debug)] -struct IoPoller { - io: Arc, -} - -impl quinn::UdpPoller for IoPoller { - fn poll_writable(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.io.poll_writable(cx) - } -} diff --git a/iroh/src/net_report.rs b/iroh/src/net_report.rs index 0919fc1ea23..4991db9b0ff 100644 --- a/iroh/src/net_report.rs +++ b/iroh/src/net_report.rs @@ -617,6 +617,8 @@ impl Actor { self.metrics.clone(), #[cfg(not(wasm_browser))] socket_state, + #[cfg(any(test, feature = "test-utils"))] + opts.insecure_skip_relay_cert_verify, ); self.current_report_run = Some(ReportRun { diff --git a/iroh/src/net_report/options.rs b/iroh/src/net_report/options.rs index ba55a614530..5540f56e976 100644 --- a/iroh/src/net_report/options.rs +++ b/iroh/src/net_report/options.rs @@ -50,6 +50,9 @@ mod imp { /// /// On by default pub(crate) https: bool, + + #[cfg(any(test, feature = "test-utils"))] + pub(crate) insecure_skip_relay_cert_verify: bool, } impl Default for Options { @@ -61,6 +64,8 @@ mod imp { icmp_v4: true, icmp_v6: true, https: true, + #[cfg(any(test, feature = "test-utils"))] + insecure_skip_relay_cert_verify: false, } } } @@ -75,6 +80,8 @@ mod imp { icmp_v4: false, icmp_v6: false, https: false, + #[cfg(any(test, feature = "test-utils"))] + insecure_skip_relay_cert_verify: false, } } @@ -114,6 +121,13 @@ mod imp { self } + /// Skip cert verification + #[cfg(any(test, feature = "test-utils"))] + pub fn insecure_skip_relay_cert_verify(mut self, skip: bool) -> Self { + self.insecure_skip_relay_cert_verify = skip; + self + } + /// Turn the options into set of valid protocols pub(crate) fn to_protocols(&self) -> BTreeSet { let mut protocols = BTreeSet::new(); diff --git a/iroh/src/net_report/reportgen.rs b/iroh/src/net_report/reportgen.rs index 28b8e5cbe26..4484f20047d 100644 --- a/iroh/src/net_report/reportgen.rs +++ b/iroh/src/net_report/reportgen.rs @@ -115,6 +115,7 @@ impl Client { protocols: BTreeSet, metrics: Arc, #[cfg(not(wasm_browser))] socket_state: SocketState, + #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: bool, ) -> Self { let (msg_tx, msg_rx) = mpsc::channel(32); let addr = Addr { @@ -134,6 +135,8 @@ impl Client { #[cfg(not(wasm_browser))] hairpin_actor: hairpin::Client::new(net_report, addr), metrics, + #[cfg(any(test, feature = "test-utils"))] + insecure_skip_relay_cert_verify, }; let task = task::spawn(async move { actor.run().await }.instrument(info_span!("reportgen.actor"))); @@ -215,6 +218,8 @@ struct Actor { #[cfg(not(wasm_browser))] hairpin_actor: hairpin::Client, metrics: Arc, + #[cfg(any(test, feature = "test-utils"))] + insecure_skip_relay_cert_verify: bool, } #[allow(missing_docs)] @@ -655,6 +660,8 @@ impl Actor { pinger, #[cfg(not(wasm_browser))] socket_state, + #[cfg(any(test, feature = "test-utils"))] + self.insecure_skip_relay_cert_verify, ) .instrument(debug_span!("run_probe", %probe)), ); @@ -815,9 +822,10 @@ pub enum QuicError { } /// Pieces needed to do QUIC address discovery. -#[derive(Debug, Clone)] +#[derive(derive_more::Debug, Clone)] pub struct QuicConfig { /// A QUIC Endpoint + #[debug("quinn::Endpoint")] pub ep: quinn::Endpoint, /// A client config. pub client_config: rustls::ClientConfig, @@ -830,6 +838,7 @@ pub struct QuicConfig { /// Executes a particular [`Probe`], including using a delayed start if needed. /// /// If *stun_sock4* and *stun_sock6* are `None` the STUN probes are disabled. +#[allow(clippy::too_many_arguments)] async fn run_probe( reportstate: Addr, relay_node: Arc, @@ -838,6 +847,7 @@ async fn run_probe( metrics: Arc, #[cfg(not(wasm_browser))] pinger: Pinger, #[cfg(not(wasm_browser))] socket_state: SocketState, + #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: bool, ) -> Result { if !probe.delay().is_zero() { trace!("delaying probe"); @@ -912,7 +922,8 @@ async fn run_probe( #[cfg(not(wasm_browser))] &socket_state.dns_resolver, node, - None, + #[cfg(any(test, feature = "test-utils"))] + insecure_skip_relay_cert_verify, ) .await { @@ -1424,9 +1435,6 @@ enum MeasureHttpsLatencyError { #[cfg(not(wasm_browser))] #[snafu(transparent)] DnsLookup { source: StaggeredError }, - #[cfg(not(wasm_browser))] - #[snafu(display("Invalid certificate"))] - InvalidCertificate { source: reqwest::Error }, #[snafu(display("Creating HTTP client failed"))] CreateReqwestClient { source: reqwest::Error }, #[snafu(display("HTTP request failed"))] @@ -1443,7 +1451,7 @@ enum MeasureHttpsLatencyError { async fn measure_https_latency( #[cfg(not(wasm_browser))] dns_resolver: &DnsResolver, node: &RelayNode, - certs: Option>>, + #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: bool, ) -> Result<(Duration, IpAddr), MeasureHttpsLatencyError> { let url = node.url.join(RELAY_PROBE_PATH)?; @@ -1474,14 +1482,9 @@ async fn measure_https_latency( builder = builder.resolve_to_addrs(domain, &addrs); } - #[cfg(not(wasm_browser))] - if let Some(certs) = certs { - for cert in certs { - let cert = reqwest::Certificate::from_der(&cert) - .context(measure_https_latency_error::InvalidCertificateSnafu)?; - builder = builder.add_root_certificate(cert); - } - } + #[cfg(all(not(wasm_browser), any(test, feature = "test-utils")))] + let builder = builder.danger_accept_invalid_certs(insecure_skip_relay_cert_verify); + let client = builder .build() .context(measure_https_latency_error::CreateReqwestClientSnafu)?; @@ -1877,11 +1880,10 @@ mod tests { #[tokio::test] async fn test_measure_https_latency() -> Result { - let (server, relay) = test_utils::relay().await; + let (_server, relay) = test_utils::relay().await; let dns_resolver = dns::tests::resolver(); tracing::info!(relay_url = ?relay.url , "RELAY_URL"); - let (latency, ip) = - measure_https_latency(&dns_resolver, &relay, server.certificates()).await?; + let (latency, ip) = measure_https_latency(&dns_resolver, &relay, true).await?; assert!(latency > Duration::ZERO); diff --git a/iroh/src/protocol.rs b/iroh/src/protocol.rs index 04ff2af58aa..529ad11870f 100644 --- a/iroh/src/protocol.rs +++ b/iroh/src/protocol.rs @@ -63,7 +63,6 @@ use crate::{ /// /// ```no_run /// # use std::sync::Arc; -/// # use futures_lite::future::Boxed as BoxedFuture; /// # use n0_snafu::ResultExt; /// # use iroh::{endpoint::Connecting, protocol::{ProtocolHandler, Router}, Endpoint, NodeAddr}; /// # @@ -544,10 +543,11 @@ mod tests { use std::{sync::Mutex, time::Duration}; use n0_snafu::{Result, ResultExt}; + use n0_watcher::Watcher; use quinn::ApplicationClose; use super::*; - use crate::{endpoint::ConnectionError, watcher::Watcher, RelayMode}; + use crate::{endpoint::ConnectionError, RelayMode}; #[tokio::test] async fn test_shutdown() -> Result { @@ -585,16 +585,24 @@ mod tests { Ok(()) } } + #[tokio::test] async fn test_limiter() -> Result { - let e1 = Endpoint::builder().bind().await?; + // tracing_subscriber::fmt::try_init().ok(); + let e1 = Endpoint::builder() + .relay_mode(RelayMode::Disabled) + .bind() + .await?; // deny all access let proto = AccessLimit::new(Echo, |_node_id| false); let r1 = Router::builder(e1.clone()).accept(ECHO_ALPN, proto).spawn(); let addr1 = r1.endpoint().node_addr().initialized().await?; - - let e2 = Endpoint::builder().bind().await?; + dbg!(&addr1); + let e2 = Endpoint::builder() + .relay_mode(RelayMode::Disabled) + .bind() + .await?; println!("connecting"); let conn = e2.connect(addr1, ECHO_ALPN).await?; @@ -633,6 +641,7 @@ mod tests { } } + eprintln!("creating ep1"); let endpoint = Endpoint::builder() .relay_mode(RelayMode::Disabled) .bind() @@ -640,16 +649,21 @@ mod tests { let router = Router::builder(endpoint) .accept(TEST_ALPN, TestProtocol::default()) .spawn(); + eprintln!("waiting for node addr"); let addr = router.endpoint().node_addr().initialized().await?; + eprintln!("creating ep2"); let endpoint2 = Endpoint::builder() .relay_mode(RelayMode::Disabled) .bind() .await?; + eprintln!("connecting to {:?}", addr); let conn = endpoint2.connect(addr, TEST_ALPN).await?; + eprintln!("starting shutdown"); router.shutdown().await.e()?; + eprintln!("waiting for closed conn"); let reason = conn.closed().await; assert_eq!( reason, diff --git a/iroh/src/test_utils.rs b/iroh/src/test_utils.rs index 368c66b1394..9e6c87def79 100644 --- a/iroh/src/test_utils.rs +++ b/iroh/src/test_utils.rs @@ -372,6 +372,7 @@ pub(crate) mod pkarr_dns_state { use iroh_base::NodeId; use iroh_relay::node_info::{NodeIdExt, NodeInfo, IROH_TXT_NAME}; use pkarr::SignedPacket; + use tracing::debug; use crate::test_utils::dns_server::QueryHandler; @@ -398,7 +399,13 @@ pub(crate) mod pkarr_dns_state { pub async fn on_node(&self, node: &NodeId, timeout: Duration) -> std::io::Result<()> { let timeout = tokio::time::sleep(timeout); tokio::pin!(timeout); - while self.get(node, |p| p.is_none()) { + while self.get(node, |p| { + let node_info = p + .as_ref() + .and_then(|p| NodeInfo::from_pkarr_signed_packet(p).ok()); + debug!("got info {:#?}", node_info); + p.is_none() + }) { tokio::select! { _ = &mut timeout => return Err(std::io::Error::other("timeout")), _ = self.on_update() => {} diff --git a/iroh/src/watcher.rs b/iroh/src/watcher.rs deleted file mode 100644 index a9b17726ec9..00000000000 --- a/iroh/src/watcher.rs +++ /dev/null @@ -1,699 +0,0 @@ -//! Watchable values. -//! -//! A [`Watchable`] exists to keep track of a value which may change over time. It allows -//! observers to be notified of changes to the value. The aim is to always be aware of the -//! **last** value, not to observe *every* value change. -//! -//! In that way, a [`Watchable`] is like a [`tokio::sync::broadcast::Sender`] (and a -//! [`Watcher`] is like a [`tokio::sync::broadcast::Receiver`]), except that there's no risk -//! of the channel filling up, but instead you might miss items. -//! -//! This module is meant to be imported like this (if you use all of these things): -//! ```ignore -//! use iroh::watcher::{self, Watchable, Watcher as _}; -//! ``` - -#[cfg(not(iroh_loom))] -use std::sync; -use std::{ - collections::VecDeque, - future::Future, - pin::Pin, - sync::{Arc, Weak}, - task::{self, Poll, Waker}, -}; - -#[cfg(iroh_loom)] -use loom::sync; -use snafu::Snafu; -use sync::{Mutex, RwLock}; - -/// A wrapper around a value that notifies [`Watcher`]s when the value is modified. -/// -/// Only the most recent value is available to any observer, but the observer is guaranteed -/// to be notified of the most recent value. -#[derive(Debug, Default)] -pub struct Watchable { - shared: Arc>, -} - -impl Clone for Watchable { - fn clone(&self) -> Self { - Self { - shared: self.shared.clone(), - } - } -} - -impl Watchable { - /// Creates a [`Watchable`] initialized to given value. - pub fn new(value: T) -> Self { - Self { - shared: Arc::new(Shared { - state: RwLock::new(State { - value, - epoch: INITIAL_EPOCH, - }), - watchers: Default::default(), - }), - } - } - - /// Sets a new value. - /// - /// Returns `Ok(previous_value)` if the value was different from the one set, or - /// returns the provided value back as `Err(value)` if the value didn't change. - /// - /// Watchers are only notified if the value changed. - pub fn set(&self, value: T) -> Result { - // We don't actually write when the value didn't change, but there's unfortunately - // no way to upgrade a read guard to a write guard, and locking as read first, then - // dropping and locking as write introduces a possible race condition. - let mut state = self.shared.state.write().expect("poisoned"); - - // Find out if the value changed - let changed = state.value != value; - - let ret = if changed { - let old = std::mem::replace(&mut state.value, value); - state.epoch += 1; - Ok(old) - } else { - Err(value) - }; - drop(state); // No need to write anymore - - // Notify watchers - if changed { - for watcher in self.shared.watchers.lock().expect("poisoned").drain(..) { - watcher.wake(); - } - } - ret - } - - /// Creates a [`Direct`] [`Watcher`], allowing the value to be observed, but not modified. - pub fn watch(&self) -> Direct { - Direct { - epoch: self.shared.state.read().expect("poisoned").epoch, - shared: Arc::downgrade(&self.shared), - } - } - - /// Returns the currently stored value. - pub fn get(&self) -> T { - self.shared.get() - } -} - -/// A handle to a value that's represented by one or more underlying [`Watchable`]s. -/// -/// A [`Watcher`] can get the current value, and will be notified when the value changes. -/// Only the most recent value is accessible, and if the threads with the underlying [`Watchable`]s -/// change the value faster than the threads with the [`Watcher`] can keep up with, then -/// it'll miss in-between values. -/// When the thread changing the [`Watchable`] pauses updating, the [`Watcher`] will always -/// end up reporting the most recent state eventually. -/// -/// Watchers can be modified via [`Watcher::map`] to observe a value derived from the original -/// value via a function. -/// -/// Watchers can be combined via [`Watcher::or`] to allow observing multiple values at once and -/// getting an update in case any of the values updates. -/// -/// One of the underlying [`Watchable`]s might already be dropped. In that case, -/// the watcher will be "disconnected" and return [`Err(Disconnected)`](Disconnected) -/// on some function calls or, when turned into a stream, that stream will end. -pub trait Watcher: Clone { - /// The type of value that can change. - /// - /// We require `Clone`, because we need to be able to make - /// the values have a lifetime that's detached from the original [`Watchable`]'s - /// lifetime. - /// - /// We require `Eq`, to be able to check whether the value actually changed or - /// not, so we can notify or not notify accordingly. - type Value: Clone + Eq; - - /// Returns the current state of the underlying value, or errors out with - /// [`Disconnected`], if one of the underlying [`Watchable`]s has been dropped. - fn get(&self) -> Result; - - /// Polls for the next value, or returns [`Disconnected`] if one of the underlying - /// [`Watchable`]s has been dropped. - fn poll_updated( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll>; - - /// Returns a future completing with `Ok(value)` once a new value is set, or with - /// [`Err(Disconnected)`](Disconnected) if the connected [`Watchable`] was dropped. - /// - /// # Cancel Safety - /// - /// The returned future is cancel-safe. - fn updated(&mut self) -> NextFut<'_, Self> { - NextFut { watcher: self } - } - - /// Returns a future completing once the value is set to [`Some`] value. - /// - /// If the current value is [`Some`] value, this future will resolve immediately. - /// - /// This is a utility for the common case of storing an [`Option`] inside a - /// [`Watchable`]. - /// - /// # Cancel Safety - /// - /// The returned future is cancel-safe. - fn initialized(&mut self) -> InitializedFut<'_, T, Self> - where - Self: Watcher>, - { - InitializedFut { - initial: match self.get() { - Ok(Some(value)) => Some(Ok(value)), - Ok(None) => None, - Err(val) => Some(Err(val)), - }, - watcher: self, - } - } - - /// Returns a stream which will yield the most recent values as items. - /// - /// The first item of the stream is the current value, so that this stream can be easily - /// used to operate on the most recent value. - /// - /// Note however, that only the last item is stored. If the stream is not polled when an - /// item is available it can be replaced with another item by the time it is polled. - /// - /// This stream ends once the original [`Watchable`] has been dropped. - /// - /// # Cancel Safety - /// - /// The returned stream is cancel-safe. - fn stream(self) -> Stream - where - Self: Unpin, - { - Stream { - initial: self.get().ok(), - watcher: self, - } - } - - /// Returns a stream which will yield the most recent values as items, starting from - /// the next unobserved future value. - /// - /// This means this stream will only yield values when the watched value changes, - /// the value stored at the time the stream is created is not yielded. - /// - /// Note however, that only the last item is stored. If the stream is not polled when an - /// item is available it can be replaced with another item by the time it is polled. - /// - /// This stream ends once the original [`Watchable`] has been dropped. - /// - /// # Cancel Safety - /// - /// The returned stream is cancel-safe. - fn stream_updates_only(self) -> Stream - where - Self: Unpin, - { - Stream { - initial: None, - watcher: self, - } - } - - /// Maps this watcher with a function that transforms the observed values. - /// - /// The returned watcher will only register updates, when the *mapped* value - /// observably changes. For this, it needs to store a clone of `T` in the watcher. - fn map( - self, - map: impl Fn(Self::Value) -> T + 'static, - ) -> Result, Disconnected> { - Ok(Map { - current: (map)(self.get()?), - map: Arc::new(map), - watcher: self, - }) - } - - /// Returns a watcher that updates every time this or the other watcher - /// updates, and yields both watcher's items together when that happens. - fn or(self, other: W) -> (Self, W) { - (self, other) - } -} - -/// The immediate, direct observer of a [`Watchable`] value. -/// -/// This type is mainly used via the [`Watcher`] interface. -#[derive(Debug, Clone)] -pub struct Direct { - epoch: u64, - shared: Weak>, -} - -impl Watcher for Direct { - type Value = T; - - fn get(&self) -> Result { - let shared = self.shared.upgrade().ok_or(DisconnectedSnafu.build())?; - Ok(shared.get()) - } - - fn poll_updated( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll> { - let Some(shared) = self.shared.upgrade() else { - return Poll::Ready(Err(DisconnectedSnafu.build())); - }; - match shared.poll_updated(cx, self.epoch) { - Poll::Pending => Poll::Pending, - Poll::Ready((current_epoch, value)) => { - self.epoch = current_epoch; - Poll::Ready(Ok(value)) - } - } - } -} - -impl Watcher for (S, T) { - type Value = (S::Value, T::Value); - - fn get(&self) -> Result { - Ok((self.0.get()?, self.1.get()?)) - } - - fn poll_updated( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll> { - let poll_0 = self.0.poll_updated(cx)?; - let poll_1 = self.1.poll_updated(cx)?; - match (poll_0, poll_1) { - (Poll::Ready(s), Poll::Ready(t)) => Poll::Ready(Ok((s, t))), - (Poll::Ready(s), Poll::Pending) => Poll::Ready(self.1.get().map(move |t| (s, t))), - (Poll::Pending, Poll::Ready(t)) => Poll::Ready(self.0.get().map(move |s| (s, t))), - (Poll::Pending, Poll::Pending) => Poll::Pending, - } - } -} - -/// Wraps a [`Watcher`] to allow observing a derived value. -/// -/// See [`Watcher::map`]. -#[derive(derive_more::Debug, Clone)] -pub struct Map { - #[debug("Arc T + 'static>")] - map: Arc T + 'static>, - watcher: W, - current: T, -} - -impl Watcher for Map { - type Value = T; - - fn get(&self) -> Result { - Ok((self.map)(self.watcher.get()?)) - } - - fn poll_updated( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll> { - loop { - let value = n0_future::ready!(self.watcher.poll_updated(cx)?); - let mapped = (self.map)(value); - if mapped != self.current { - self.current = mapped.clone(); - return Poll::Ready(Ok(mapped)); - } else { - self.current = mapped; - } - } - } -} - -/// Future returning the next item after the current one in a [`Watcher`]. -/// -/// See [`Watcher::updated`]. -/// -/// # Cancel Safety -/// -/// This future is cancel-safe. -#[derive(Debug)] -pub struct NextFut<'a, W: Watcher> { - watcher: &'a mut W, -} - -impl Future for NextFut<'_, W> { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { - self.watcher.poll_updated(cx) - } -} - -/// Future returning the current or next value that's [`Some`] value. -/// in a [`Watcher`]. -/// -/// See [`Watcher::initialized`]. -/// -/// # Cancel Safety -/// -/// This Future is cancel-safe. -#[derive(Debug)] -pub struct InitializedFut<'a, T, W: Watcher>> { - initial: Option>, - watcher: &'a mut W, -} - -impl> + Unpin> Future - for InitializedFut<'_, T, W> -{ - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { - if let Some(value) = self.as_mut().initial.take() { - return Poll::Ready(value); - } - loop { - if let Some(value) = n0_future::ready!(self.as_mut().watcher.poll_updated(cx)?) { - return Poll::Ready(Ok(value)); - } - } - } -} - -/// A stream for a [`Watcher`]'s next values. -/// -/// See [`Watcher::stream`] and [`Watcher::stream_updates_only`]. -/// -/// # Cancel Safety -/// -/// This stream is cancel-safe. -#[derive(Debug, Clone)] -pub struct Stream { - initial: Option, - watcher: W, -} - -impl n0_future::stream::Stream for Stream -where - W::Value: Unpin, -{ - type Item = W::Value; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - if let Some(value) = self.as_mut().initial.take() { - return Poll::Ready(Some(value)); - } - match self.as_mut().watcher.poll_updated(cx) { - Poll::Ready(Ok(value)) => Poll::Ready(Some(value)), - Poll::Ready(Err(_)) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - } - } -} - -/// The error for when a [`Watcher`] is disconnected from its underlying -/// [`Watchable`] value, because of that watchable having been dropped. -#[derive(Debug, Snafu)] -#[snafu(display("Watch lost connection to underlying Watchable, it was dropped"))] -pub struct Disconnected { - backtrace: Option, - #[snafu(implicit)] - span_trace: n0_snafu::SpanTrace, -} - -// Private: - -const INITIAL_EPOCH: u64 = 1; - -/// The shared state for a [`Watchable`]. -#[derive(Debug, Default)] -struct Shared { - /// The value to be watched and its current epoch. - state: RwLock>, - watchers: Mutex>, -} - -#[derive(Debug)] -struct State { - value: T, - epoch: u64, -} - -impl Default for State { - fn default() -> Self { - Self { - value: Default::default(), - epoch: INITIAL_EPOCH, - } - } -} - -impl Shared { - /// Returns the value, initialized or not. - fn get(&self) -> T { - self.state.read().expect("poisoned").value.clone() - } - - fn poll_updated(&self, cx: &mut task::Context<'_>, last_epoch: u64) -> Poll<(u64, T)> { - { - let state = self.state.read().expect("poisoned"); - let epoch = state.epoch; - - if last_epoch < epoch { - // Once initialized, our Option is never set back to None, but nevertheless - // this code is safer without relying on that invariant. - return Poll::Ready((epoch, state.value.clone())); - } - } - - self.watchers - .lock() - .expect("poisoned") - .push_back(cx.waker().to_owned()); - - #[cfg(iroh_loom)] - loom::thread::yield_now(); - - { - let state = self.state.read().expect("poisoned"); - let epoch = state.epoch; - - if last_epoch < epoch { - // Once initialized our Option is never set back to None, but nevertheless - // this code is safer without relying on that invariant. - return Poll::Ready((epoch, state.value.clone())); - } - } - - Poll::Pending - } -} - -#[cfg(test)] -mod tests { - use std::time::{Duration, Instant}; - - use n0_future::StreamExt; - use rand::{thread_rng, Rng}; - use tokio::task::JoinSet; - use tokio_util::sync::CancellationToken; - - use super::*; - - #[tokio::test] - async fn test_watcher() { - let cancel = CancellationToken::new(); - let watchable = Watchable::new(17); - - assert_eq!(watchable.watch().stream().next().await.unwrap(), 17); - - let start = Instant::now(); - // spawn watchers - let mut tasks = JoinSet::new(); - for i in 0..3 { - let mut watch = watchable.watch().stream(); - let cancel = cancel.clone(); - tasks.spawn(async move { - println!("[{i}] spawn"); - let mut expected_value = 17; - loop { - tokio::select! { - biased; - Some(value) = &mut watch.next() => { - println!("{:?} [{i}] update: {value}", start.elapsed()); - assert_eq!(value, expected_value); - if expected_value == 17 { - expected_value = 0; - } else { - expected_value += 1; - } - }, - _ = cancel.cancelled() => { - println!("{:?} [{i}] cancel", start.elapsed()); - assert_eq!(expected_value, 10); - break; - } - } - } - }); - } - for i in 0..3 { - let mut watch = watchable.watch().stream_updates_only(); - let cancel = cancel.clone(); - tasks.spawn(async move { - println!("[{i}] spawn"); - let mut expected_value = 0; - loop { - tokio::select! { - biased; - Some(value) = watch.next() => { - println!("{:?} [{i}] stream update: {value}", start.elapsed()); - assert_eq!(value, expected_value); - expected_value += 1; - }, - _ = cancel.cancelled() => { - println!("{:?} [{i}] cancel", start.elapsed()); - assert_eq!(expected_value, 10); - break; - } - else => { - panic!("stream died"); - } - } - } - }); - } - - // set value - for next_value in 0..10 { - let sleep = Duration::from_nanos(thread_rng().gen_range(0..100_000_000)); - println!("{:?} sleep {sleep:?}", start.elapsed()); - tokio::time::sleep(sleep).await; - - let changed = watchable.set(next_value); - println!("{:?} set {next_value} changed={changed:?}", start.elapsed()); - } - - println!("cancel"); - cancel.cancel(); - while let Some(res) = tasks.join_next().await { - res.expect("task failed"); - } - } - - #[test] - fn test_get() { - let watchable = Watchable::new(None); - assert!(watchable.get().is_none()); - - watchable.set(Some(1u8)).ok(); - assert_eq!(watchable.get(), Some(1u8)); - } - - #[tokio::test] - async fn test_initialize() { - let watchable = Watchable::new(None); - - let mut watcher = watchable.watch(); - let mut initialized = watcher.initialized(); - - let poll = n0_future::future::poll_once(&mut initialized).await; - assert!(poll.is_none()); - - watchable.set(Some(1u8)).ok(); - - let poll = n0_future::future::poll_once(&mut initialized).await; - assert_eq!(poll.unwrap().unwrap(), 1u8); - } - - #[tokio::test] - async fn test_initialize_already_init() { - let watchable = Watchable::new(Some(1u8)); - - let mut watcher = watchable.watch(); - let mut initialized = watcher.initialized(); - - let poll = n0_future::future::poll_once(&mut initialized).await; - assert_eq!(poll.unwrap().unwrap(), 1u8); - } - - #[test] - fn test_initialized_always_resolves() { - #[cfg(not(iroh_loom))] - use std::thread; - - #[cfg(iroh_loom)] - use loom::thread; - - let test_case = || { - let watchable = Watchable::>::new(None); - - let mut watch = watchable.watch(); - let thread = thread::spawn(move || n0_future::future::block_on(watch.initialized())); - - watchable.set(Some(42)).ok(); - - thread::yield_now(); - - let value: u8 = thread.join().unwrap().unwrap(); - - assert_eq!(value, 42); - }; - - #[cfg(iroh_loom)] - loom::model(test_case); - #[cfg(not(iroh_loom))] - test_case(); - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_update_cancel_safety() { - let watchable = Watchable::new(0); - let mut watch = watchable.watch(); - const MAX: usize = 100_000; - - let handle = tokio::spawn(async move { - let mut last_observed = 0; - - while last_observed != MAX { - tokio::select! { - val = watch.updated() => { - let Ok(val) = val else { - return; - }; - - assert_ne!(val, last_observed, "never observe the same value twice, even with cancellation"); - last_observed = val; - } - _ = tokio::time::sleep(Duration::from_micros(thread_rng().gen_range(0..10_000))) => { - // We cancel the other future and start over again - continue; - } - } - } - }); - - for i in 1..=MAX { - watchable.set(i).ok(); - if thread_rng().gen_bool(0.2) { - tokio::task::yield_now().await; - } - } - - tokio::time::timeout(Duration::from_secs(10), handle) - .await - .unwrap() - .unwrap() - } -}