diff --git a/Cargo.lock b/Cargo.lock index a311b0653a0..26fa64e8b13 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -296,6 +296,7 @@ checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" dependencies = [ "axum-core", "axum-macros", + "base64", "bytes", "form_urlencoded", "futures-util", @@ -314,8 +315,10 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", + "sha1 0.10.6", "sync_wrapper", "tokio", + "tokio-tungstenite", "tower", "tower-layer", "tower-service", @@ -456,6 +459,15 @@ dependencies = [ "cpufeatures 0.3.0", ] +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "block-buffer" version = "0.12.0" @@ -957,7 +969,7 @@ dependencies = [ "cfg-if", "cpufeatures 0.2.17", "curve25519-dalek-derive", - "digest", + "digest 0.11.2", "fiat-crypto", "rand_core 0.10.0", "rustc_version", @@ -1186,13 +1198,23 @@ version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer 0.10.4", + "crypto-common 0.1.7", +] + [[package]] name = "digest" version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4850db49bf08e663084f7fb5c87d202ef91a3907271aff24a94eb97ff039153c" dependencies = [ - "block-buffer", + "block-buffer 0.12.0", "const-oid", "crypto-common 0.2.1", ] @@ -2374,7 +2396,7 @@ dependencies = [ "data-encoding", "data-encoding-macro", "derive_more", - "digest", + "digest 0.11.2", "ed25519-dalek", "getrandom 0.4.2", "n0-error", @@ -2385,7 +2407,7 @@ dependencies = [ "serde", "serde_json", "serde_test", - "sha1", + "sha1 0.11.0", "sha2", "url", "zeroize", @@ -2530,6 +2552,7 @@ name = "iroh-relay" version = "0.98.0" dependencies = [ "ahash", + "axum", "blake3", "bytes", "cfg_aliases", @@ -2569,7 +2592,7 @@ dependencies = [ "serde", "serde_bytes", "serde_json", - "sha1", + "sha1 0.11.0", "simdutf8", "strum 0.28.0", "time", @@ -2797,9 +2820,9 @@ dependencies = [ [[package]] name = "lru" -version = "0.16.4" +version = "0.16.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f66e8d5d03f609abc3a39e6f08e4164ebf1447a732906d39eb9b99b7919ef39" +checksum = "a1dc47f592c06f33f8e3aea9591776ec7c9f9e4124778ff8a3c3b87159f7e593" [[package]] name = "lru" @@ -2835,7 +2858,7 @@ dependencies = [ "flume", "futures-lite", "getrandom 0.3.4", - "lru 0.16.4", + "lru 0.16.3", "serde", "serde_bencode", "serde_bytes", @@ -4621,6 +4644,17 @@ dependencies = [ "syn", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures 0.2.17", + "digest 0.10.7", +] + [[package]] name = "sha1" version = "0.11.0" @@ -4629,7 +4663,7 @@ checksum = "aacc4cc499359472b4abe1bf11d0b12e688af9a805fa5e3016f9a386dc2d0214" dependencies = [ "cfg-if", "cpufeatures 0.3.0", - "digest", + "digest 0.11.2", ] [[package]] @@ -4646,7 +4680,7 @@ checksum = "446ba717509524cb3f22f17ecc096f10f4822d76ab5c0b9822c5f9c284e825f4" dependencies = [ "cfg-if", "cpufeatures 0.3.0", - "digest", + "digest 0.11.2", ] [[package]] @@ -5173,6 +5207,18 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "tokio-tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.18" @@ -5403,6 +5449,23 @@ dependencies = [ "linked-hash-map", ] +[[package]] +name = "tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8628dcc84e5a09eb3d8423d6cb682965dea9133204e8fb3efee74c2a0c259442" +dependencies = [ + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand 0.9.4", + "sha1 0.10.6", + "thiserror 2.0.18", + "utf-8", +] + [[package]] name = "typenum" version = "1.19.0" @@ -5474,6 +5537,12 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8_iter" version = "1.0.4" diff --git a/iroh-relay/Cargo.toml b/iroh-relay/Cargo.toml index 9963c1f0985..98fe8a02c3d 100644 --- a/iroh-relay/Cargo.toml +++ b/iroh-relay/Cargo.toml @@ -121,6 +121,8 @@ tokio = { version = "1", features = [ tracing-subscriber = { version = "0.3", features = ["env-filter"] } serde_json = "1" n0-tracing-test = "0.3" +# Used by the embedding integration tests. +axum = { version = "0.8", features = ["ws"] } [build-dependencies] cfg_aliases = "0.2.1" diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index 2c43ab3f289..ad7003f8b05 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -54,12 +54,14 @@ use crate::{ }, }; -// type BytesBody = http_body_util::Full; -pub(super) type BytesBody = Box< +/// Boxed HTTP response body produced by [`RelayServiceWithNotify`]. +pub type BytesBody = Box< dyn 'static + Send + Unpin + hyper::body::Body, >; -pub(super) type HyperError = Box; -pub(super) type HyperResult = std::result::Result; +/// Boxed error type returned from [`RelayServiceWithNotify`]'s [`hyper::service::Service`] impl. +pub type HyperError = Box; +/// Result alias for HTTP responses produced by [`RelayServiceWithNotify`]. +pub type HyperResult = std::result::Result; pub(super) type HyperHandler = Box< dyn Fn(Request, ResponseBuilder) -> HyperResult> + Send diff --git a/iroh-relay/tests/relay_axum.rs b/iroh-relay/tests/relay_axum.rs new file mode 100644 index 00000000000..9344da8776c --- /dev/null +++ b/iroh-relay/tests/relay_axum.rs @@ -0,0 +1,231 @@ +//! Embeds iroh-relay into an [`axum::Router`]. +//! +//! The relay handler upgrades the WebSocket via axum's [`WebSocketUpgrade`] +//! extractor, wraps the resulting [`WebSocket`] into a [`BytesStreamSink`] +//! as expected by iroh-relay's protocol handler, and then runs the +//! handshake plus client registration directly. +//! +//! [`BytesStreamSink`]: iroh_relay::protos::streams::BytesStreamSink + +#![cfg(feature = "server")] + +use std::{ + net::SocketAddr, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + time::Duration, +}; + +use axum::{ + Router, + extract::{ + FromRequestParts, State, + ws::{Message, WebSocket, WebSocketUpgrade}, + }, + http::{StatusCode, header::ACCESS_CONTROL_ALLOW_ORIGIN}, + response::Response, + routing::get, +}; +use bytes::Bytes; +use iroh_base::{RelayUrl, SecretKey}; +use iroh_dns::dns::DnsResolver; +use iroh_relay::{ + ExportKeyingMaterial, KeyCache, + client::ClientBuilder, + http::{CLIENT_AUTH_HEADER, ProtocolVersion, RELAY_PATH, RELAY_PROBE_PATH}, + protos::{handshake, streams::StreamError}, + server::{ + AccessConfig, ClientRequest, Metrics, client::Config, clients::Clients, + streams::RelayedStream, + }, + tls::{CaRootsConfig, default_provider}, +}; +use n0_error::{AnyError, Result, StdResultExt}; +use n0_future::{Sink, Stream, task::AbortOnDropHandle}; +use n0_tracing_test::traced_test; +use rand::{RngExt, SeedableRng}; +use tokio::net::TcpListener; +use tracing::{trace, warn}; + +#[derive(Clone, Debug)] +struct RelayState { + key_cache: KeyCache, + access: Arc, + metrics: Arc, + clients: Clients, +} + +impl RelayState { + fn new() -> Self { + Self { + key_cache: KeyCache::new(1024), + access: Arc::new(AccessConfig::Everyone), + metrics: Arc::new(Metrics::default()), + clients: Clients::default(), + } + } +} + +async fn serve_axum() -> Result<(SocketAddr, AbortOnDropHandle<()>)> { + let state = RelayState::new(); + let router = Router::new() + .route(RELAY_PATH, get(relay_handler)) + .route(RELAY_PROBE_PATH, get(ping_handler)) + .with_state(state); + + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + let task = tokio::spawn(async move { + let _ = axum::serve(listener, router.into_make_service()).await; + }); + Ok((addr, AbortOnDropHandle::new(task))) +} + +async fn relay_handler( + State(state): State, + request: axum::extract::Request, +) -> Result { + let (mut parts, _body) = request.into_parts(); + let ws = WebSocketUpgrade::from_request_parts(&mut parts, &state) + .await + .map_err(|_| StatusCode::BAD_REQUEST)?; + let client_auth_header = parts.headers.get(CLIENT_AUTH_HEADER).cloned(); + let ws = ws.protocols([ProtocolVersion::V2.to_str()]); + Ok(ws.on_upgrade(move |socket| async move { + if let Err(error) = handle_relay_websocket(socket, state, parts, client_auth_header).await { + warn!("relay websocket error: {error:#}"); + } + })) +} + +async fn ping_handler() -> impl axum::response::IntoResponse { + (StatusCode::OK, [(ACCESS_CONTROL_ALLOW_ORIGIN, "*")]) +} + +/// Bridges axum's [`WebSocket`] to iroh-relay's [`BytesStreamSink`]. +struct AxumWebSocketAdapter { + inner: Pin>, +} + +impl AxumWebSocketAdapter { + fn new(socket: WebSocket) -> Self { + Self { + inner: Box::pin(socket), + } + } +} + +impl Stream for AxumWebSocketAdapter { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + return match self.inner.as_mut().poll_next(cx) { + Poll::Ready(Some(Ok(Message::Binary(data)))) => Poll::Ready(Some(Ok(data))), + Poll::Ready(Some(Ok(Message::Close(_)))) => Poll::Ready(None), + Poll::Ready(Some(Ok(_))) => continue, + Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(AnyError::from_std(error)))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + }; + } + } +} + +impl Sink for AxumWebSocketAdapter { + type Error = StreamError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner + .as_mut() + .poll_ready(cx) + .map_err(AnyError::from_std) + } + + fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + self.inner + .as_mut() + .start_send(Message::Binary(item)) + .map_err(AnyError::from_std) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner + .as_mut() + .poll_flush(cx) + .map_err(AnyError::from_std) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner + .as_mut() + .poll_close(cx) + .map_err(AnyError::from_std) + } +} + +// Axum's WebSocket has no access to TLS keying material. +impl ExportKeyingMaterial for AxumWebSocketAdapter { + fn export_keying_material>( + &self, + _output: T, + _label: &[u8], + _context: Option<&[u8]>, + ) -> Option { + None + } +} + +async fn handle_relay_websocket( + socket: WebSocket, + state: RelayState, + request_parts: http::request::Parts, + client_auth_header: Option, +) -> Result<(), Box> { + let mut adapter = AxumWebSocketAdapter::new(socket); + let authentication = handshake::serverside(&mut adapter, client_auth_header).await?; + trace!(?authentication.mechanism, "verified authentication"); + + let request = ClientRequest::new(authentication.client_key, request_parts); + let is_authorized = state.access.is_allowed(&request).await; + let client_key = authentication + .authorize_if(is_authorized, &mut adapter) + .await?; + trace!("verified authorization"); + + let stream = RelayedStream::new(adapter, state.key_cache.clone()); + let config = Config::new(client_key, stream, ProtocolVersion::V2); + state.clients.register(config, state.metrics.clone()); + Ok(()) +} + +#[tokio::test] +#[traced_test] +async fn relay_embed_axum() -> Result<()> { + let _ = rustls::crypto::ring::default_provider().install_default(); + let (addr, _guard) = serve_axum().await?; + + let resp = reqwest::get(format!("http://{addr}/ping")) + .await + .std_context("ping request")?; + assert_eq!(resp.status(), 200); + + // Connect a relay client to `/relay` on the same axum-fronted port. + let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64); + let relay_url: RelayUrl = format!("http://{addr}").parse()?; + let tls_config = CaRootsConfig::default().client_config(default_provider())?; + tokio::time::timeout( + Duration::from_secs(5), + ClientBuilder::new( + relay_url, + SecretKey::from_bytes(&rng.random()), + DnsResolver::new(), + ) + .tls_client_config(tls_config) + .connect(), + ) + .await + .std_context("timeout")??; + Ok(()) +} diff --git a/iroh-relay/tests/relay_hyper.rs b/iroh-relay/tests/relay_hyper.rs new file mode 100644 index 00000000000..a5e9c54c59f --- /dev/null +++ b/iroh-relay/tests/relay_hyper.rs @@ -0,0 +1,123 @@ +//! Embeds the relay inside a plain hyper HTTP server. Routes `/relay` +//! through [`RelayServiceWithNotify`] and serves a `/ping` probe matching +//! the iroh-relay server's built-in probe. + +#![cfg(feature = "server")] + +use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration}; + +use bytes::Bytes; +use http::{HeaderMap, Method, Response, StatusCode, header::ACCESS_CONTROL_ALLOW_ORIGIN}; +use http_body_util::Full; +use hyper::{ + Request, + body::Incoming, + server::conn::http1, + service::{Service as _, service_fn}, +}; +use hyper_util::rt::TokioIo; +use iroh_base::{RelayUrl, SecretKey}; +use iroh_dns::dns::DnsResolver; +use iroh_relay::{ + KeyCache, + client::ClientBuilder, + http::{RELAY_PATH, RELAY_PROBE_PATH}, + server::{ + AccessConfig, Metrics, + http_server::{BytesBody, Handlers, RelayService, RelayServiceWithNotify}, + streams::MaybeTlsStream, + }, + tls::{CaRootsConfig, default_provider}, +}; +use n0_error::{Result, StdResultExt}; +use n0_future::task::AbortOnDropHandle; +use n0_tracing_test::traced_test; +use rand::{RngExt, SeedableRng}; +use tokio::{net::TcpListener, sync::Notify}; + +async fn dispatch(req: Request, service: RelayService) -> Response { + match (req.method(), req.uri().path()) { + (&Method::GET, RELAY_PROBE_PATH) => Response::builder() + .status(StatusCode::OK) + .header(ACCESS_CONTROL_ALLOW_ORIGIN, "*") + .body(Box::new(Full::new(Bytes::new())) as BytesBody) + .expect("valid response"), + (&Method::GET, RELAY_PATH) => RelayServiceWithNotify::new(service, Arc::new(Notify::new())) + .call(req) + .await + .expect("RelayServiceWithNotify::call returns Ok"), + _ => Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Box::new(Full::new("not found".into())) as BytesBody) + .expect("valid response"), + } +} + +async fn serve_hyper() -> Result<(SocketAddr, AbortOnDropHandle<()>)> { + let service = RelayService::new( + Handlers::default(), + HeaderMap::new(), + None, + KeyCache::new(1024), + AccessConfig::Everyone, + Arc::new(Metrics::default()), + ); + + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + let task = tokio::spawn(async move { + loop { + let Ok((stream, _)) = listener.accept().await else { + return; + }; + // The relay handler downcasts the hyper `Upgraded` back to + // `TokioIo`, so wrap the stream as such. + let stream = MaybeTlsStream::Plain(stream); + let service = service.clone(); + tokio::spawn(async move { + http1::Builder::new() + .serve_connection( + TokioIo::new(stream), + service_fn(move |req: Request| { + let service = service.clone(); + async move { Ok::<_, Infallible>(dispatch(req, service).await) } + }), + ) + .with_upgrades() + .await + .expect("serve_connection failed"); + }); + } + }); + Ok((addr, AbortOnDropHandle::new(task))) +} + +#[tokio::test] +#[traced_test] +async fn relay_embed_hyper() -> Result<()> { + let _ = rustls::crypto::ring::default_provider().install_default(); + let (addr, _guard) = serve_hyper().await?; + + let res = reqwest::get(format!("http://{addr}/ping")) + .await + .std_context("ping request")?; + assert_eq!(res.status(), 200); + + // Connect a relay client to `/relay`. + let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64); + let relay_url: RelayUrl = format!("http://{addr}").parse()?; + let tls_config = CaRootsConfig::default().client_config(default_provider())?; + tokio::time::timeout( + Duration::from_secs(5), + ClientBuilder::new( + relay_url, + SecretKey::from_bytes(&rng.random()), + DnsResolver::new(), + ) + .tls_client_config(tls_config) + .connect(), + ) + .await + .std_context("timeout")??; + Ok(()) +}