@@ -2,25 +2,46 @@ use crate::{
22 body:: BoxBody ,
33 error:: { Error , InvalidUri } ,
44 http:: {
5- header:: { CONTENT_ENCODING , CONTENT_LENGTH , CONTENT_TYPE , LOCATION , TRANSFER_ENCODING } ,
5+ header:: {
6+ AUTHORIZATION , CONTENT_ENCODING , CONTENT_LENGTH , CONTENT_TYPE , COOKIE , LOCATION , PROXY_AUTHORIZATION ,
7+ TRANSFER_ENCODING ,
8+ } ,
69 Method , StatusCode , Uri ,
710 } ,
811 response:: Response ,
912 service:: { Service , ServiceRequest } ,
1013} ;
1114
1215/// middleware for following redirect response.
13- pub struct FollowRedirect < S > {
16+ pub struct FollowRedirect < S , const MAX_COUNT : usize = 10 > {
1417 service : S ,
1518}
1619
1720impl < S > FollowRedirect < S > {
18- pub fn new ( service : S ) -> Self {
21+ /// construct redirect following middleware for client.
22+ ///
23+ /// # Examples:
24+ /// ```rust
25+ /// # use xitca_client::{ClientBuilder, middleware::FollowRedirect};
26+ /// let builder = ClientBuilder::new()
27+ /// .middleware(FollowRedirect::new);
28+ /// ```
29+ pub const fn new ( service : S ) -> Self {
1930 Self { service }
2031 }
2132}
2233
23- impl < ' r , ' c , S > Service < ServiceRequest < ' r , ' c > > for FollowRedirect < S >
34+ impl < S , const MAX : usize > FollowRedirect < S , MAX > {
35+ /// set max depth of redirect following for request. when max value is reached the redirect following
36+ /// would stop and the most recent response will be returned as output.
37+ ///
38+ /// Default to 10 times.
39+ pub fn max < const MAX2 : usize > ( self ) -> FollowRedirect < S , MAX2 > {
40+ FollowRedirect { service : self . service }
41+ }
42+ }
43+
44+ impl < ' r , ' c , S , const MAX : usize > Service < ServiceRequest < ' r , ' c > > for FollowRedirect < S , MAX >
2445where
2546 S : for < ' r2 , ' c2 > Service < ServiceRequest < ' r2 , ' c2 > , Response = Response , Error = Error > + Send + Sync ,
2647{
3354 let mut method = req. method ( ) . clone ( ) ;
3455 let mut uri = req. uri ( ) . clone ( ) ;
3556 let ext = req. extensions ( ) . clone ( ) ;
57+ let mut count = 0 ;
58+
3659 loop {
3760 let mut res = self . service . call ( ServiceRequest { req, client, timeout } ) . await ?;
61+
62+ if count == MAX {
63+ return Ok ( res) ;
64+ }
65+
3866 match res. status ( ) {
3967 StatusCode :: MOVED_PERMANENTLY | StatusCode :: FOUND | StatusCode :: SEE_OTHER => {
4068 if method != Method :: HEAD {
6391 . parse :: < Uri > ( ) ?
6492 . into_parts ( ) ;
6593
94+ // remove authenticated headers when redirected to different scheme/authority
95+ if parts_location. scheme != parts. scheme || parts_location. authority != parts. authority {
96+ headers. remove ( AUTHORIZATION ) ;
97+ headers. remove ( PROXY_AUTHORIZATION ) ;
98+ headers. remove ( COOKIE ) ;
99+ }
100+
66101 let mut uri_builder = Uri :: builder ( ) ;
67102
68103 if let Some ( a) = parts_location. authority . or ( parts. authority ) {
80115 * req. method_mut ( ) = method. clone ( ) ;
81116 * req. headers_mut ( ) = headers. clone ( ) ;
82117 * req. extensions_mut ( ) = ext. clone ( ) ;
118+
119+ count += 1 ;
83120 }
84121 }
85122}
@@ -98,14 +135,9 @@ mod test {
98135 async fn redirect ( ) {
99136 let ( handle, service) = mock_service ( ) ;
100137
101- let redirect = FollowRedirect :: new ( service) ;
102-
103- let mut req = http:: Request :: builder ( )
104- . uri ( "http://foo.bar/foo" )
105- . body ( Default :: default ( ) )
106- . unwrap ( ) ;
138+ let redirect = FollowRedirect :: new ( service) . max :: < 1 > ( ) ;
107139
108- let req = handle . mock ( & mut req , |req| match req. uri ( ) . path ( ) {
140+ let handler = |req : http :: Request < BoxBody > | match req. uri ( ) . path ( ) {
109141 "/foo" => Ok ( http:: Response :: builder ( )
110142 . status ( StatusCode :: SEE_OTHER )
111143 . header ( "location" , "/bar" )
@@ -115,11 +147,31 @@ mod test {
115147 . status ( StatusCode :: IM_A_TEAPOT )
116148 . body ( ResponseBody :: Eof )
117149 . unwrap ( ) ) ,
150+ "/fur" => Ok ( http:: Response :: builder ( )
151+ . status ( StatusCode :: SEE_OTHER )
152+ . header ( "location" , "/foo" )
153+ . body ( ResponseBody :: Eof )
154+ . unwrap ( ) ) ,
118155 p => panic ! ( "unexpected uri path: {p}" ) ,
119- } ) ;
156+ } ;
120157
121- let res = redirect. call ( req) . await . unwrap ( ) ;
158+ let mut req = http:: Request :: builder ( )
159+ . uri ( "http://foo.bar/foo" )
160+ . body ( Default :: default ( ) )
161+ . unwrap ( ) ;
122162
163+ let req = handle. mock ( & mut req, handler) ;
164+ let res = redirect. call ( req) . await . unwrap ( ) ;
123165 assert_eq ! ( res. status( ) , StatusCode :: IM_A_TEAPOT ) ;
166+
167+ let mut req = http:: Request :: builder ( )
168+ . uri ( "http://foo.bar/fur" )
169+ . body ( Default :: default ( ) )
170+ . unwrap ( ) ;
171+
172+ let req = handle. mock ( & mut req, handler) ;
173+ let res = redirect. call ( req) . await . unwrap ( ) ;
174+ assert_eq ! ( res. status( ) , StatusCode :: SEE_OTHER ) ;
175+ assert_eq ! ( res. headers( ) . get( LOCATION ) . unwrap( ) . to_str( ) . unwrap( ) , "/bar" ) ;
124176 }
125177}
0 commit comments