Skip to content

Commit cec4ef4

Browse files
authored
remove sensitive headers when redirecting to foreign uri (#1208)
* remove sensitive headers when redirecting to foreign uri * clippy fix
1 parent 99592b9 commit cec4ef4

1 file changed

Lines changed: 65 additions & 13 deletions

File tree

client/src/middleware/redirect.rs

Lines changed: 65 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,46 @@ use crate::{
22
body::BoxBody,
33
error::{Error, InvalidUri},
44
http::{
5-
header::{CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, LOCATION, TRANSFER_ENCODING},
5+
header::{
6+
AUTHORIZATION, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, COOKIE, LOCATION, PROXY_AUTHORIZATION,
7+
TRANSFER_ENCODING,
8+
},
69
Method, StatusCode, Uri,
710
},
811
response::Response,
912
service::{Service, ServiceRequest},
1013
};
1114

1215
/// middleware for following redirect response.
13-
pub struct FollowRedirect<S> {
16+
pub struct FollowRedirect<S, const MAX_COUNT: usize = 10> {
1417
service: S,
1518
}
1619

1720
impl<S> FollowRedirect<S> {
18-
pub fn new(service: S) -> Self {
21+
/// construct redirect following middleware for client.
22+
///
23+
/// # Examples:
24+
/// ```rust
25+
/// # use xitca_client::{ClientBuilder, middleware::FollowRedirect};
26+
/// let builder = ClientBuilder::new()
27+
/// .middleware(FollowRedirect::new);
28+
/// ```
29+
pub const fn new(service: S) -> Self {
1930
Self { service }
2031
}
2132
}
2233

23-
impl<'r, 'c, S> Service<ServiceRequest<'r, 'c>> for FollowRedirect<S>
34+
impl<S, const MAX: usize> FollowRedirect<S, MAX> {
35+
/// set max depth of redirect following for request. when max value is reached the redirect following
36+
/// would stop and the most recent response will be returned as output.
37+
///
38+
/// Default to 10 times.
39+
pub fn max<const MAX2: usize>(self) -> FollowRedirect<S, MAX2> {
40+
FollowRedirect { service: self.service }
41+
}
42+
}
43+
44+
impl<'r, 'c, S, const MAX: usize> Service<ServiceRequest<'r, 'c>> for FollowRedirect<S, MAX>
2445
where
2546
S: for<'r2, 'c2> Service<ServiceRequest<'r2, 'c2>, Response = Response, Error = Error> + Send + Sync,
2647
{
@@ -33,8 +54,15 @@ where
3354
let mut method = req.method().clone();
3455
let mut uri = req.uri().clone();
3556
let ext = req.extensions().clone();
57+
let mut count = 0;
58+
3659
loop {
3760
let mut res = self.service.call(ServiceRequest { req, client, timeout }).await?;
61+
62+
if count == MAX {
63+
return Ok(res);
64+
}
65+
3866
match res.status() {
3967
StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND | StatusCode::SEE_OTHER => {
4068
if method != Method::HEAD {
@@ -63,6 +91,13 @@ where
6391
.parse::<Uri>()?
6492
.into_parts();
6593

94+
// remove authenticated headers when redirected to different scheme/authority
95+
if parts_location.scheme != parts.scheme || parts_location.authority != parts.authority {
96+
headers.remove(AUTHORIZATION);
97+
headers.remove(PROXY_AUTHORIZATION);
98+
headers.remove(COOKIE);
99+
}
100+
66101
let mut uri_builder = Uri::builder();
67102

68103
if let Some(a) = parts_location.authority.or(parts.authority) {
@@ -80,6 +115,8 @@ where
80115
*req.method_mut() = method.clone();
81116
*req.headers_mut() = headers.clone();
82117
*req.extensions_mut() = ext.clone();
118+
119+
count += 1;
83120
}
84121
}
85122
}
@@ -98,14 +135,9 @@ mod test {
98135
async fn redirect() {
99136
let (handle, service) = mock_service();
100137

101-
let redirect = FollowRedirect::new(service);
102-
103-
let mut req = http::Request::builder()
104-
.uri("http://foo.bar/foo")
105-
.body(Default::default())
106-
.unwrap();
138+
let redirect = FollowRedirect::new(service).max::<1>();
107139

108-
let req = handle.mock(&mut req, |req| match req.uri().path() {
140+
let handler = |req: http::Request<BoxBody>| match req.uri().path() {
109141
"/foo" => Ok(http::Response::builder()
110142
.status(StatusCode::SEE_OTHER)
111143
.header("location", "/bar")
@@ -115,11 +147,31 @@ mod test {
115147
.status(StatusCode::IM_A_TEAPOT)
116148
.body(ResponseBody::Eof)
117149
.unwrap()),
150+
"/fur" => Ok(http::Response::builder()
151+
.status(StatusCode::SEE_OTHER)
152+
.header("location", "/foo")
153+
.body(ResponseBody::Eof)
154+
.unwrap()),
118155
p => panic!("unexpected uri path: {p}"),
119-
});
156+
};
120157

121-
let res = redirect.call(req).await.unwrap();
158+
let mut req = http::Request::builder()
159+
.uri("http://foo.bar/foo")
160+
.body(Default::default())
161+
.unwrap();
122162

163+
let req = handle.mock(&mut req, handler);
164+
let res = redirect.call(req).await.unwrap();
123165
assert_eq!(res.status(), StatusCode::IM_A_TEAPOT);
166+
167+
let mut req = http::Request::builder()
168+
.uri("http://foo.bar/fur")
169+
.body(Default::default())
170+
.unwrap();
171+
172+
let req = handle.mock(&mut req, handler);
173+
let res = redirect.call(req).await.unwrap();
174+
assert_eq!(res.status(), StatusCode::SEE_OTHER);
175+
assert_eq!(res.headers().get(LOCATION).unwrap().to_str().unwrap(), "/bar");
124176
}
125177
}

0 commit comments

Comments
 (0)