Skip to content

Commit 0fa14c5

Browse files
risssondavepgreene
authored andcommitted
packages/ak-axum/extract/client_ip: init (goauthentik#21321)
1 parent 06f80e5 commit 0fa14c5

7 files changed

Lines changed: 272 additions & 3 deletions

File tree

Cargo.lock

Lines changed: 27 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ axum-server = { version = "= 0.8.0", features = ["tls-rustls-no-provider"] }
2424
aws-lc-rs = { version = "= 1.16.2", features = ["fips"] }
2525
axum = { version = "= 0.8.8", features = ["http2", "macros", "ws"] }
2626
clap = { version = "= 4.6.0", features = ["derive", "env"] }
27+
client-ip = { version = "0.2.1", features = ["forwarded-header"] }
2728
colored = "= 3.1.1"
2829
config-rs = { package = "config", version = "= 0.15.22", default-features = false, features = [
2930
"json",

packages/ak-axum/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ publish.workspace = true
1313
ak-common.workspace = true
1414
axum-server.workspace = true
1515
axum.workspace = true
16+
client-ip.workspace = true
1617
durstr.workspace = true
1718
eyre.workspace = true
1819
futures.workspace = true
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
//! axum extractor and middleware to retrieve the client IP.
2+
3+
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
4+
5+
use axum::{
6+
Extension, RequestPartsExt as _,
7+
extract::{ConnectInfo, FromRequestParts, Request},
8+
http::request::Parts,
9+
middleware::Next,
10+
response::Response,
11+
};
12+
use tracing::{Span, instrument};
13+
14+
use crate::{accept::proxy_protocol::ProxyProtocolState, extract::trusted_proxy::TrustedProxy};
15+
16+
/// Client IP.
17+
///
18+
/// The [`client_ip_middleware`] must be added to the router before using this extractor,
19+
/// otherwise this will result in requests erroring.
20+
#[derive(Clone, Copy, Debug)]
21+
pub struct ClientIp(pub IpAddr);
22+
23+
impl<S> FromRequestParts<S> for ClientIp
24+
where
25+
S: Send + Sync,
26+
{
27+
type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;
28+
29+
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
30+
Extension::<Self>::from_request_parts(parts, state)
31+
.await
32+
.map(|Extension(client_ip)| client_ip)
33+
}
34+
}
35+
36+
/// Get the client IP from the request.
37+
#[instrument(skip_all)]
38+
async fn extract_client_ip(parts: &mut Parts) -> IpAddr {
39+
let is_trusted = parts
40+
.extract::<TrustedProxy>()
41+
.await
42+
.unwrap_or(TrustedProxy(false))
43+
.0;
44+
45+
if is_trusted {
46+
if let Ok(ip) = client_ip::rightmost_x_forwarded_for(&parts.headers) {
47+
return ip;
48+
}
49+
50+
if let Ok(ip) = client_ip::x_real_ip(&parts.headers) {
51+
return ip;
52+
}
53+
54+
if let Ok(ip) = client_ip::rightmost_forwarded(&parts.headers) {
55+
return ip;
56+
}
57+
58+
if let Ok(Extension(proxy_protocol_state)) =
59+
parts.extract::<Extension<ProxyProtocolState>>().await
60+
&& let Some(header) = &proxy_protocol_state.header
61+
&& let Some(addr) = header.proxied_address()
62+
{
63+
return addr.source.ip();
64+
}
65+
}
66+
67+
if let Ok(ConnectInfo(addr)) = parts.extract::<ConnectInfo<SocketAddr>>().await {
68+
addr.ip()
69+
} else {
70+
// No connect info means we received a request via a Unix socket, hence localhost
71+
// as default.
72+
Ipv6Addr::LOCALHOST.into()
73+
}
74+
}
75+
76+
/// Middleware required by the [`ClientIp`] extractor.
77+
///
78+
/// Use with [`axum::middleware::from_fn`].
79+
pub async fn client_ip_middleware(request: Request, next: Next) -> Response {
80+
let (mut parts, body) = request.into_parts();
81+
82+
let client_ip = extract_client_ip(&mut parts).await;
83+
Span::current().record("remote", client_ip.to_string());
84+
parts.extensions.insert::<ClientIp>(ClientIp(client_ip));
85+
86+
let request = Request::from_parts(parts, body);
87+
88+
next.run(request).await
89+
}
90+
91+
#[cfg(test)]
92+
mod tests {
93+
use std::net::Ipv4Addr;
94+
95+
use axum::{body::Body, http::Request};
96+
97+
use super::*;
98+
99+
#[tokio::test]
100+
async fn x_forwarded_for_trusted() {
101+
let (mut parts, _) = Request::builder()
102+
.uri("http://example.com/path")
103+
.header("x-forwarded-for", "192.0.2.51, 192.0.2.42")
104+
.extension(TrustedProxy(true))
105+
.body(Body::empty())
106+
.expect("Failed to create request")
107+
.into_parts();
108+
109+
let client_ip = extract_client_ip(&mut parts).await;
110+
111+
assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 42),);
112+
}
113+
114+
#[tokio::test]
115+
async fn x_real_ip_trusted() {
116+
let (mut parts, _) = Request::builder()
117+
.uri("http://example.com/path")
118+
.header("x-real-ip", "192.0.2.42")
119+
.extension(TrustedProxy(true))
120+
.body(Body::empty())
121+
.expect("Failed to create request")
122+
.into_parts();
123+
124+
let client_ip = extract_client_ip(&mut parts).await;
125+
126+
assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 42),);
127+
}
128+
129+
#[tokio::test]
130+
async fn forwarded_header_trusted() {
131+
let (mut parts, _) = Request::builder()
132+
.uri("http://example.com/path")
133+
.header("forwarded", "for=192.0.2.42")
134+
.extension(TrustedProxy(true))
135+
.body(Body::empty())
136+
.expect("Failed to create request")
137+
.into_parts();
138+
139+
let client_ip = extract_client_ip(&mut parts).await;
140+
141+
assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 42),);
142+
}
143+
144+
#[tokio::test]
145+
async fn from_connect_info() {
146+
let connect_addr: SocketAddr = "192.0.2.42:34932"
147+
.parse()
148+
.expect("Failed to parse socket address");
149+
let (mut parts, _) = Request::builder()
150+
.uri("http://example.com/path")
151+
.extension(ConnectInfo(connect_addr))
152+
.extension(TrustedProxy(false))
153+
.body(Body::empty())
154+
.expect("Failed to create request")
155+
.into_parts();
156+
157+
let client_ip = extract_client_ip(&mut parts).await;
158+
159+
assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 42),);
160+
}
161+
162+
#[tokio::test]
163+
async fn headers_untrusted() {
164+
let (mut parts, _) = Request::builder()
165+
.uri("http://example.com/path")
166+
.header("x-forwarded-for", "192.0.2.42")
167+
.extension(TrustedProxy(false))
168+
.body(Body::empty())
169+
.expect("Failed to create request")
170+
.into_parts();
171+
172+
let client_ip = extract_client_ip(&mut parts).await;
173+
174+
assert_eq!(client_ip, Ipv6Addr::LOCALHOST);
175+
}
176+
177+
#[tokio::test]
178+
async fn priority_order() {
179+
// Test that X-Forwarded-For takes priority over other headers when trusted
180+
let (mut parts, _) = Request::builder()
181+
.uri("http://example.com/path")
182+
.header("x-forwarded-for", "192.0.2.1")
183+
.header("x-real-ip", "192.0.2.2")
184+
.header("forwarded", "for=192.0.2.3")
185+
.extension(TrustedProxy(true))
186+
.body(Body::empty())
187+
.expect("Failed to create request")
188+
.into_parts();
189+
190+
let client_ip = extract_client_ip(&mut parts).await;
191+
192+
assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 1),);
193+
}
194+
195+
#[tokio::test]
196+
async fn no_ip_found() {
197+
let (mut parts, _) = Request::builder()
198+
.uri("http://example.com/path")
199+
.body(Body::empty())
200+
.expect("Failed to create request")
201+
.into_parts();
202+
203+
let client_ip = extract_client_ip(&mut parts).await;
204+
205+
assert_eq!(client_ip, Ipv6Addr::LOCALHOST);
206+
}
207+
208+
#[tokio::test]
209+
async fn ipv6() {
210+
let (mut parts, _) = Request::builder()
211+
.uri("http://example.com/path")
212+
.header("x-forwarded-for", "2001:db8::42")
213+
.extension(TrustedProxy(true))
214+
.body(Body::empty())
215+
.expect("Failed to create request")
216+
.into_parts();
217+
218+
let client_ip = extract_client_ip(&mut parts).await;
219+
220+
assert_eq!(client_ip, Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0x42),);
221+
}
222+
223+
#[tokio::test]
224+
async fn multiple_x_forwarded_for() {
225+
let (mut parts, _) = Request::builder()
226+
.uri("http://example.com/path")
227+
.header("x-forwarded-for", "192.0.2.1, 192.0.2.2, 192.0.2.3")
228+
.extension(TrustedProxy(true))
229+
.body(Body::empty())
230+
.expect("Failed to create request")
231+
.into_parts();
232+
233+
let client_ip = extract_client_ip(&mut parts).await;
234+
235+
assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 3),);
236+
}
237+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
//! axum extractors to get information about a request.
22
3+
pub mod client_ip;
34
pub mod trusted_proxy;

packages/ak-axum/src/router.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use tower::ServiceBuilder;
66
use tower_http::timeout::TimeoutLayer;
77

88
use crate::{
9-
extract::trusted_proxy::trusted_proxy_middleware,
9+
extract::{client_ip::client_ip_middleware, trusted_proxy::trusted_proxy_middleware},
1010
tracing::{span_middleware, tracing_middleware},
1111
};
1212

@@ -27,7 +27,8 @@ pub fn wrap_router(router: Router, with_tracing: bool) -> Router {
2727
timeout,
2828
))
2929
.layer(from_fn(span_middleware))
30-
.layer(from_fn(trusted_proxy_middleware));
30+
.layer(from_fn(trusted_proxy_middleware))
31+
.layer(from_fn(client_ip_middleware));
3132
if with_tracing {
3233
router.layer(service_builder.layer(from_fn(tracing_middleware)))
3334
} else {

packages/ak-axum/src/tracing.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::collections::HashMap;
55
use ak_common::config;
66
use axum::{extract::Request, middleware::Next, response::Response};
77
use tokio::time::Instant;
8-
use tracing::{Instrument as _, info, info_span, trace};
8+
use tracing::{Instrument as _, field, info, info_span, trace};
99

1010
/// Create a [`tracing::Span`] for requests.
1111
pub(crate) async fn span_middleware(request: Request, next: Next) -> Response {
@@ -27,6 +27,7 @@ pub(crate) async fn span_middleware(request: Request, next: Next) -> Response {
2727
"request",
2828
path = %request.uri(),
2929
method = %request.method(),
30+
remote = field::Empty,
3031
http_headers = ?http_headers,
3132
);
3233
next.run(request).instrument(span).await

0 commit comments

Comments
 (0)