diff --git a/Cargo.lock b/Cargo.lock index eb28c635c0c2..5c97db4ee8ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -130,6 +130,7 @@ dependencies = [ "authentik-common", "axum", "axum-server", + "client-ip", "durstr", "eyre", "futures", @@ -517,6 +518,16 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" +[[package]] +name = "client-ip" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39d2056bf065c8b4bce5a8898d40e175211ff4410add2a84d695845d3937c729" +dependencies = [ + "forwarded-header-value", + "http", +] + [[package]] name = "cmake" version = "0.1.57" @@ -857,6 +868,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "forwarded-header-value" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8835f84f38484cc86f110a805655697908257fb9a7af005234060891557198e9" +dependencies = [ + "nonempty", + "thiserror 1.0.69", +] + [[package]] name = "fs-err" version = "3.3.0" @@ -1684,6 +1705,12 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nonempty" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9e591e719385e6ebaeb5ce5d3887f7d5676fceca6411d1925ccc95745f3d6f7" + [[package]] name = "notify" version = "8.2.0" diff --git a/Cargo.toml b/Cargo.toml index bac8aadf0077..4a752f49ce5c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ axum-server = { version = "= 0.8.0", features = ["tls-rustls-no-provider"] } aws-lc-rs = { version = "= 1.16.2", features = ["fips"] } axum = { version = "= 0.8.8", features = ["http2", "macros", "ws"] } clap = { version = "= 4.6.0", features = ["derive", "env"] } +client-ip = { version = "0.2.1", features = ["forwarded-header"] } colored = "= 3.1.1" config-rs = { package = "config", version = "= 0.15.22", default-features = false, features = [ "json", diff --git a/packages/ak-axum/Cargo.toml b/packages/ak-axum/Cargo.toml index ce4872eb4473..3eee20cc9704 100644 --- a/packages/ak-axum/Cargo.toml +++ b/packages/ak-axum/Cargo.toml @@ -13,6 +13,7 @@ publish.workspace = true ak-common.workspace = true axum-server.workspace = true axum.workspace = true +client-ip.workspace = true durstr.workspace = true eyre.workspace = true futures.workspace = true diff --git a/packages/ak-axum/src/extract/client_ip.rs b/packages/ak-axum/src/extract/client_ip.rs new file mode 100644 index 000000000000..698f512205b2 --- /dev/null +++ b/packages/ak-axum/src/extract/client_ip.rs @@ -0,0 +1,237 @@ +//! axum extractor and middleware to retrieve the client IP. + +use std::net::{IpAddr, Ipv6Addr, SocketAddr}; + +use axum::{ + Extension, RequestPartsExt as _, + extract::{ConnectInfo, FromRequestParts, Request}, + http::request::Parts, + middleware::Next, + response::Response, +}; +use tracing::{Span, instrument}; + +use crate::{accept::proxy_protocol::ProxyProtocolState, extract::trusted_proxy::TrustedProxy}; + +/// Client IP. +/// +/// The [`client_ip_middleware`] must be added to the router before using this extractor, +/// otherwise this will result in requests erroring. +#[derive(Clone, Copy, Debug)] +pub struct ClientIp(pub IpAddr); + +impl FromRequestParts for ClientIp +where + S: Send + Sync, +{ + type Rejection = as FromRequestParts>::Rejection; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + Extension::::from_request_parts(parts, state) + .await + .map(|Extension(client_ip)| client_ip) + } +} + +/// Get the client IP from the request. +#[instrument(skip_all)] +async fn extract_client_ip(parts: &mut Parts) -> IpAddr { + let is_trusted = parts + .extract::() + .await + .unwrap_or(TrustedProxy(false)) + .0; + + if is_trusted { + if let Ok(ip) = client_ip::rightmost_x_forwarded_for(&parts.headers) { + return ip; + } + + if let Ok(ip) = client_ip::x_real_ip(&parts.headers) { + return ip; + } + + if let Ok(ip) = client_ip::rightmost_forwarded(&parts.headers) { + return ip; + } + + if let Ok(Extension(proxy_protocol_state)) = + parts.extract::>().await + && let Some(header) = &proxy_protocol_state.header + && let Some(addr) = header.proxied_address() + { + return addr.source.ip(); + } + } + + if let Ok(ConnectInfo(addr)) = parts.extract::>().await { + addr.ip() + } else { + // No connect info means we received a request via a Unix socket, hence localhost + // as default. + Ipv6Addr::LOCALHOST.into() + } +} + +/// Middleware required by the [`ClientIp`] extractor. +/// +/// Use with [`axum::middleware::from_fn`]. +pub async fn client_ip_middleware(request: Request, next: Next) -> Response { + let (mut parts, body) = request.into_parts(); + + let client_ip = extract_client_ip(&mut parts).await; + Span::current().record("remote", client_ip.to_string()); + parts.extensions.insert::(ClientIp(client_ip)); + + let request = Request::from_parts(parts, body); + + next.run(request).await +} + +#[cfg(test)] +mod tests { + use std::net::Ipv4Addr; + + use axum::{body::Body, http::Request}; + + use super::*; + + #[tokio::test] + async fn x_forwarded_for_trusted() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header("x-forwarded-for", "192.0.2.51, 192.0.2.42") + .extension(TrustedProxy(true)) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let client_ip = extract_client_ip(&mut parts).await; + + assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 42),); + } + + #[tokio::test] + async fn x_real_ip_trusted() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header("x-real-ip", "192.0.2.42") + .extension(TrustedProxy(true)) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let client_ip = extract_client_ip(&mut parts).await; + + assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 42),); + } + + #[tokio::test] + async fn forwarded_header_trusted() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header("forwarded", "for=192.0.2.42") + .extension(TrustedProxy(true)) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let client_ip = extract_client_ip(&mut parts).await; + + assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 42),); + } + + #[tokio::test] + async fn from_connect_info() { + let connect_addr: SocketAddr = "192.0.2.42:34932" + .parse() + .expect("Failed to parse socket address"); + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .extension(ConnectInfo(connect_addr)) + .extension(TrustedProxy(false)) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let client_ip = extract_client_ip(&mut parts).await; + + assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 42),); + } + + #[tokio::test] + async fn headers_untrusted() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header("x-forwarded-for", "192.0.2.42") + .extension(TrustedProxy(false)) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let client_ip = extract_client_ip(&mut parts).await; + + assert_eq!(client_ip, Ipv6Addr::LOCALHOST); + } + + #[tokio::test] + async fn priority_order() { + // Test that X-Forwarded-For takes priority over other headers when trusted + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header("x-forwarded-for", "192.0.2.1") + .header("x-real-ip", "192.0.2.2") + .header("forwarded", "for=192.0.2.3") + .extension(TrustedProxy(true)) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let client_ip = extract_client_ip(&mut parts).await; + + assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 1),); + } + + #[tokio::test] + async fn no_ip_found() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let client_ip = extract_client_ip(&mut parts).await; + + assert_eq!(client_ip, Ipv6Addr::LOCALHOST); + } + + #[tokio::test] + async fn ipv6() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header("x-forwarded-for", "2001:db8::42") + .extension(TrustedProxy(true)) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let client_ip = extract_client_ip(&mut parts).await; + + assert_eq!(client_ip, Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0x42),); + } + + #[tokio::test] + async fn multiple_x_forwarded_for() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header("x-forwarded-for", "192.0.2.1, 192.0.2.2, 192.0.2.3") + .extension(TrustedProxy(true)) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let client_ip = extract_client_ip(&mut parts).await; + + assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 3),); + } +} diff --git a/packages/ak-axum/src/extract/mod.rs b/packages/ak-axum/src/extract/mod.rs index 2888d8d37235..84430ddd6ac1 100644 --- a/packages/ak-axum/src/extract/mod.rs +++ b/packages/ak-axum/src/extract/mod.rs @@ -1,3 +1,4 @@ //! axum extractors to get information about a request. +pub mod client_ip; pub mod trusted_proxy; diff --git a/packages/ak-axum/src/router.rs b/packages/ak-axum/src/router.rs index 6f678a85b31b..a86cb70c7237 100644 --- a/packages/ak-axum/src/router.rs +++ b/packages/ak-axum/src/router.rs @@ -6,7 +6,7 @@ use tower::ServiceBuilder; use tower_http::timeout::TimeoutLayer; use crate::{ - extract::trusted_proxy::trusted_proxy_middleware, + extract::{client_ip::client_ip_middleware, trusted_proxy::trusted_proxy_middleware}, tracing::{span_middleware, tracing_middleware}, }; @@ -27,7 +27,8 @@ pub fn wrap_router(router: Router, with_tracing: bool) -> Router { timeout, )) .layer(from_fn(span_middleware)) - .layer(from_fn(trusted_proxy_middleware)); + .layer(from_fn(trusted_proxy_middleware)) + .layer(from_fn(client_ip_middleware)); if with_tracing { router.layer(service_builder.layer(from_fn(tracing_middleware))) } else { diff --git a/packages/ak-axum/src/tracing.rs b/packages/ak-axum/src/tracing.rs index 3fd8a989e9bc..5f2bf196fee5 100644 --- a/packages/ak-axum/src/tracing.rs +++ b/packages/ak-axum/src/tracing.rs @@ -5,7 +5,7 @@ use std::collections::HashMap; use ak_common::config; use axum::{extract::Request, middleware::Next, response::Response}; use tokio::time::Instant; -use tracing::{Instrument as _, info, info_span, trace}; +use tracing::{Instrument as _, field, info, info_span, trace}; /// Create a [`tracing::Span`] for requests. pub(crate) async fn span_middleware(request: Request, next: Next) -> Response { @@ -27,6 +27,7 @@ pub(crate) async fn span_middleware(request: Request, next: Next) -> Response { "request", path = %request.uri(), method = %request.method(), + remote = field::Empty, http_headers = ?http_headers, ); next.run(request).instrument(span).await