Skip to content

Commit 4d554c3

Browse files
committed
feat(client): add some general HTTP/1 client middleware
1 parent 4595a08 commit 4d554c3

File tree

2 files changed

+156
-0
lines changed

2 files changed

+156
-0
lines changed

src/client/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,7 @@ pub mod legacy;
77
#[cfg(feature = "client-pool")]
88
pub mod pool;
99

10+
pub mod service;
11+
1012
#[cfg(feature = "client-proxy")]
1113
pub mod proxy;

src/client/service.rs

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
//! Middleware services for normalizing request URIs and Host headers.
2+
3+
use std::task::{Context, Poll};
4+
5+
use http::header::{HeaderValue, HOST};
6+
use http::{Method, Request, Uri};
7+
use tower_service::Service;
8+
9+
/// A middleware that ensures the Host header matches the URI's authority.
10+
///
11+
/// Particularly useful for HTTP/1 clients and proxies, where the Host
12+
/// header is mandatory and should be derived from the request URI.
13+
#[derive(Clone, Debug)]
14+
pub struct SetHost<S> {
15+
inner: S,
16+
}
17+
18+
/// A middleware that modifies the request target for HTTP/1 semantics.
19+
///
20+
/// Ensures CONNECT uses authority-form, and all other methods use origin-form.
21+
#[derive(Clone, Debug)]
22+
pub struct Http1RequestTarget<S> {
23+
inner: S,
24+
}
25+
26+
// ===== impl SetHost =====
27+
28+
impl<S> SetHost<S> {
29+
/// Create a new `SetHost` middleware wrapping the given service.
30+
pub fn new(inner: S) -> Self {
31+
SetHost { inner }
32+
}
33+
34+
/// Access the inner service.
35+
pub fn inner(&self) -> &S {
36+
&self.inner
37+
}
38+
}
39+
40+
impl<S, ReqBody> Service<Request<ReqBody>> for SetHost<S>
41+
where
42+
S: Service<Request<ReqBody>>,
43+
{
44+
type Response = S::Response;
45+
type Error = S::Error;
46+
type Future = S::Future;
47+
48+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
49+
self.inner.poll_ready(cx)
50+
}
51+
52+
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
53+
if req.uri().authority().is_some() {
54+
let uri = req.uri().clone();
55+
req.headers_mut().entry(HOST).or_insert_with(|| {
56+
let hostname = uri.host().expect("authority implies host");
57+
if let Some(port) = get_non_default_port(&uri) {
58+
let s = format!("{hostname}:{port}");
59+
HeaderValue::from_str(&s)
60+
} else {
61+
HeaderValue::from_str(hostname)
62+
}
63+
.expect("uri host is valid header value")
64+
});
65+
}
66+
self.inner.call(req)
67+
}
68+
}
69+
70+
fn get_non_default_port(uri: &Uri) -> Option<http::uri::Port<&str>> {
71+
match (uri.port().map(|p| p.as_u16()), is_schema_secure(uri)) {
72+
(Some(443), true) => None,
73+
(Some(80), false) => None,
74+
_ => uri.port(),
75+
}
76+
}
77+
78+
fn is_schema_secure(uri: &Uri) -> bool {
79+
uri.scheme_str()
80+
.map(|scheme_str| matches!(scheme_str, "wss" | "https"))
81+
.unwrap_or_default()
82+
}
83+
84+
// ===== impl Http1RequestTarget =====
85+
86+
impl<S> Http1RequestTarget<S> {
87+
/// Create a new `Http1RequestTarget` middleware wrapping the given service.
88+
pub fn new(inner: S) -> Self {
89+
Http1RequestTarget { inner }
90+
}
91+
92+
/// Access the inner service.
93+
pub fn inner(&self) -> &S {
94+
&self.inner
95+
}
96+
}
97+
98+
impl<S, ReqBody> Service<Request<ReqBody>> for Http1RequestTarget<S>
99+
where
100+
S: Service<Request<ReqBody>>,
101+
{
102+
type Response = S::Response;
103+
type Error = S::Error;
104+
type Future = S::Future;
105+
106+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
107+
self.inner.poll_ready(cx)
108+
}
109+
110+
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
111+
// CONNECT always sends authority-form, so check it first...
112+
if req.method() == Method::CONNECT {
113+
authority_form(req.uri_mut());
114+
} else {
115+
origin_form(req.uri_mut());
116+
}
117+
self.inner.call(req)
118+
}
119+
}
120+
121+
fn origin_form(uri: &mut Uri) {
122+
let path = match uri.path_and_query() {
123+
Some(path) if path.as_str() != "/" => {
124+
let mut parts = ::http::uri::Parts::default();
125+
parts.path_and_query = Some(path.clone());
126+
Uri::from_parts(parts).expect("path is valid uri")
127+
}
128+
_none_or_just_slash => {
129+
debug_assert!(Uri::default() == "/");
130+
Uri::default()
131+
}
132+
};
133+
*uri = path
134+
}
135+
136+
fn authority_form(uri: &mut Uri) {
137+
if let Some(path) = uri.path_and_query() {
138+
// `https://hyper.rs` would parse with `/` path, don't
139+
// annoy people about that...
140+
if path != "/" {
141+
tracing::debug!("HTTP/1.1 CONNECT request stripping path: {:?}", path);
142+
}
143+
}
144+
*uri = match uri.authority() {
145+
Some(auth) => {
146+
let mut parts = ::http::uri::Parts::default();
147+
parts.authority = Some(auth.clone());
148+
Uri::from_parts(parts).expect("authority is valid")
149+
}
150+
None => {
151+
unreachable!("authority_form with relative uri");
152+
}
153+
};
154+
}

0 commit comments

Comments
 (0)