Skip to content

Commit 8133fbb

Browse files
risssondavepgreene
authored andcommitted
packages/ak-axum/accept/proxy_protocol: init (goauthentik#21319)
1 parent 910cf59 commit 8133fbb

3 files changed

Lines changed: 90 additions & 4 deletions

File tree

packages/ak-axum/src/accept/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
pub mod proxy_protocol;
12
pub mod tls;
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
use std::{io, time::Duration};
2+
3+
use ak_common::tokio::proxy_protocol::{ProxyProtocolStream, header::Header};
4+
use axum::{Extension, middleware::AddExtension};
5+
use axum_server::accept::{Accept, DefaultAcceptor};
6+
use futures::future::BoxFuture;
7+
use tokio::io::{AsyncRead, AsyncWrite};
8+
use tower::Layer as _;
9+
use tracing::instrument;
10+
11+
#[derive(Clone, Debug)]
12+
pub struct ProxyProtocolState {
13+
pub header: Option<Header<'static>>,
14+
}
15+
16+
#[derive(Clone)]
17+
pub(crate) struct ProxyProtocolAcceptor<A = DefaultAcceptor> {
18+
inner: A,
19+
parsing_timeout: Duration,
20+
}
21+
22+
impl ProxyProtocolAcceptor {
23+
pub(crate) fn new() -> Self {
24+
let inner = DefaultAcceptor::new();
25+
26+
#[cfg(not(test))]
27+
let parsing_timeout = Duration::from_secs(10);
28+
29+
// Don't force tests to wait too long
30+
#[cfg(test)]
31+
let parsing_timeout = Duration::from_secs(1);
32+
33+
Self {
34+
inner,
35+
parsing_timeout,
36+
}
37+
}
38+
}
39+
40+
impl Default for ProxyProtocolAcceptor {
41+
fn default() -> Self {
42+
Self::new()
43+
}
44+
}
45+
46+
impl<A> ProxyProtocolAcceptor<A> {
47+
pub(crate) fn acceptor<Acceptor>(self, acceptor: Acceptor) -> ProxyProtocolAcceptor<Acceptor> {
48+
ProxyProtocolAcceptor {
49+
inner: acceptor,
50+
parsing_timeout: self.parsing_timeout,
51+
}
52+
}
53+
}
54+
55+
impl<A, I, S> Accept<I, S> for ProxyProtocolAcceptor<A>
56+
where
57+
A: Accept<I, S> + Clone + Send + 'static,
58+
A::Stream: AsyncRead + AsyncWrite + Unpin + Send,
59+
A::Service: Send,
60+
A::Future: Send,
61+
I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
62+
S: Send + 'static,
63+
{
64+
type Future = BoxFuture<'static, io::Result<(Self::Stream, Self::Service)>>;
65+
type Service = AddExtension<A::Service, ProxyProtocolState>;
66+
type Stream = ProxyProtocolStream<A::Stream>;
67+
68+
#[instrument(skip_all)]
69+
fn accept(&self, stream: I, service: S) -> Self::Future {
70+
let acceptor = self.inner.clone();
71+
72+
Box::pin(async move {
73+
let (stream, service) = acceptor.accept(stream, service).await?;
74+
let stream = ProxyProtocolStream::new(stream).await?;
75+
76+
let proxy_protocol_state = ProxyProtocolState {
77+
header: stream.header().cloned(),
78+
};
79+
80+
let service = Extension(proxy_protocol_state).layer(service);
81+
82+
Ok((stream, service))
83+
})
84+
}
85+
}

packages/ak-axum/src/server.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use axum_server::{
1212
use eyre::Result;
1313
use tracing::info;
1414

15-
use crate::accept::tls::TlsAcceptor;
15+
use crate::accept::{proxy_protocol::ProxyProtocolAcceptor, tls::TlsAcceptor};
1616

1717
async fn run_plain(
1818
arbiter: Arbiter,
@@ -27,7 +27,7 @@ async fn run_plain(
2727
arbiter.add_net_handle(handle.clone()).await;
2828

2929
let res = axum_server::Server::bind(addr)
30-
.acceptor(DefaultAcceptor::new())
30+
.acceptor(ProxyProtocolAcceptor::new().acceptor(DefaultAcceptor::new()))
3131
.handle(handle)
3232
.serve(router.into_make_service_with_connect_info::<net::SocketAddr>())
3333
.await;
@@ -121,9 +121,9 @@ async fn run_tls(
121121
arbiter.add_net_handle(handle.clone()).await;
122122

123123
axum_server::Server::bind(addr)
124-
.acceptor(TlsAcceptor::new(
124+
.acceptor(ProxyProtocolAcceptor::new().acceptor(TlsAcceptor::new(
125125
RustlsAcceptor::new(config).acceptor(DefaultAcceptor::new()),
126-
))
126+
)))
127127
.handle(handle)
128128
.serve(router.into_make_service_with_connect_info::<net::SocketAddr>())
129129
.await?;

0 commit comments

Comments
 (0)