@@ -13,7 +13,7 @@ use tokio::sync::Mutex;
1313use bytes:: Bytes ;
1414use futures_util:: { SinkExt , StreamExt } ;
1515use global_error:: * ;
16- use http_body_util:: Full ;
16+ use http_body_util:: { BodyExt , Full } ;
1717use hyper:: body:: Incoming as BodyIncoming ;
1818use hyper:: header:: HeaderName ;
1919use hyper:: { Request , Response , StatusCode } ;
@@ -34,6 +34,68 @@ const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
3434const ROUTE_CACHE_TTL : Duration = Duration :: from_secs ( 60 * 10 ) ; // 10 minutes
3535const PROXY_STATE_CACHE_TTL : Duration = Duration :: from_secs ( 60 * 60 ) ; // 1 hour
3636
37+ /// Response body type that can handle both streaming and buffered responses
38+ #[ derive( Debug ) ]
39+ pub enum ResponseBody {
40+ /// Buffered response body
41+ Full ( Full < Bytes > ) ,
42+ /// Streaming response body
43+ Incoming ( BodyIncoming ) ,
44+ }
45+
46+ impl http_body:: Body for ResponseBody {
47+ type Data = Bytes ;
48+ type Error = Box < dyn std:: error:: Error + Send + Sync > ;
49+
50+ fn poll_frame (
51+ self : std:: pin:: Pin < & mut Self > ,
52+ cx : & mut std:: task:: Context < ' _ > ,
53+ ) -> std:: task:: Poll < Option < Result < http_body:: Frame < Self :: Data > , Self :: Error > > > {
54+ match self . get_mut ( ) {
55+ ResponseBody :: Full ( body) => {
56+ let pin = std:: pin:: Pin :: new ( body) ;
57+ match pin. poll_frame ( cx) {
58+ std:: task:: Poll :: Ready ( Some ( Ok ( frame) ) ) => {
59+ std:: task:: Poll :: Ready ( Some ( Ok ( frame) ) )
60+ }
61+ std:: task:: Poll :: Ready ( Some ( Err ( e) ) ) => {
62+ std:: task:: Poll :: Ready ( Some ( Err ( Box :: new ( e) ) ) )
63+ }
64+ std:: task:: Poll :: Ready ( None ) => std:: task:: Poll :: Ready ( None ) ,
65+ std:: task:: Poll :: Pending => std:: task:: Poll :: Pending ,
66+ }
67+ }
68+ ResponseBody :: Incoming ( body) => {
69+ let pin = std:: pin:: Pin :: new ( body) ;
70+ match pin. poll_frame ( cx) {
71+ std:: task:: Poll :: Ready ( Some ( Ok ( frame) ) ) => {
72+ std:: task:: Poll :: Ready ( Some ( Ok ( frame) ) )
73+ }
74+ std:: task:: Poll :: Ready ( Some ( Err ( e) ) ) => {
75+ std:: task:: Poll :: Ready ( Some ( Err ( Box :: new ( e) ) ) )
76+ }
77+ std:: task:: Poll :: Ready ( None ) => std:: task:: Poll :: Ready ( None ) ,
78+ std:: task:: Poll :: Pending => std:: task:: Poll :: Pending ,
79+ }
80+ }
81+ }
82+ }
83+
84+ fn is_end_stream ( & self ) -> bool {
85+ match self {
86+ ResponseBody :: Full ( body) => body. is_end_stream ( ) ,
87+ ResponseBody :: Incoming ( body) => body. is_end_stream ( ) ,
88+ }
89+ }
90+
91+ fn size_hint ( & self ) -> http_body:: SizeHint {
92+ match self {
93+ ResponseBody :: Full ( body) => body. size_hint ( ) ,
94+ ResponseBody :: Incoming ( body) => body. size_hint ( ) ,
95+ }
96+ }
97+ }
98+
3799// Routing types
38100#[ derive( Clone , Debug ) ]
39101pub struct RouteTarget {
@@ -71,7 +133,7 @@ pub struct StructuredResponse {
71133}
72134
73135impl StructuredResponse {
74- pub fn build_response ( & self ) -> GlobalResult < Response < Full < Bytes > > > {
136+ pub fn build_response ( & self ) -> GlobalResult < Response < ResponseBody > > {
75137 let mut body = StdHashMap :: new ( ) ;
76138 body. insert ( "message" , self . message . clone ( ) . into_owned ( ) ) ;
77139
@@ -85,7 +147,7 @@ impl StructuredResponse {
85147 let response = Response :: builder ( )
86148 . status ( self . status )
87149 . header ( hyper:: header:: CONTENT_TYPE , "application/json" )
88- . body ( Full :: new ( bytes) ) ?;
150+ . body ( ResponseBody :: Full ( Full :: new ( bytes) ) ) ?;
89151
90152 Ok ( response)
91153 }
@@ -605,7 +667,7 @@ impl ProxyService {
605667 & self ,
606668 req : Request < BodyIncoming > ,
607669 request_context : & mut RequestContext ,
608- ) -> GlobalResult < Response < Full < Bytes > > > {
670+ ) -> GlobalResult < Response < ResponseBody > > {
609671 let host = req
610672 . headers ( )
611673 . get ( hyper:: header:: HOST )
@@ -641,7 +703,7 @@ impl ProxyService {
641703 tracing:: error!( ?err, "Routing error" ) ;
642704 return Ok ( Response :: builder ( )
643705 . status ( StatusCode :: BAD_GATEWAY )
644- . body ( Full :: < Bytes > :: new ( Bytes :: new ( ) ) ) ?) ;
706+ . body ( ResponseBody :: Full ( Full :: < Bytes > :: new ( Bytes :: new ( ) ) ) ) ?) ;
645707 }
646708 } ;
647709
@@ -669,14 +731,14 @@ impl ProxyService {
669731 let res = if !self . state . check_rate_limit ( client_ip, & actor_id) . await ? {
670732 Response :: builder ( )
671733 . status ( StatusCode :: TOO_MANY_REQUESTS )
672- . body ( Full :: < Bytes > :: new ( Bytes :: new ( ) ) )
734+ . body ( ResponseBody :: Full ( Full :: < Bytes > :: new ( Bytes :: new ( ) ) ) )
673735 . map_err ( Into :: into)
674736 }
675737 // Check in-flight limit
676738 else if !self . state . acquire_in_flight ( client_ip, & actor_id) . await ? {
677739 Response :: builder ( )
678740 . status ( StatusCode :: TOO_MANY_REQUESTS )
679- . body ( Full :: < Bytes > :: new ( Bytes :: new ( ) ) )
741+ . body ( ResponseBody :: Full ( Full :: < Bytes > :: new ( Bytes :: new ( ) ) ) )
680742 . map_err ( Into :: into)
681743 } else {
682744 // Increment metrics
@@ -782,7 +844,7 @@ impl ProxyService {
782844 req : Request < BodyIncoming > ,
783845 mut target : RouteTarget ,
784846 request_context : & mut RequestContext ,
785- ) -> GlobalResult < Response < Full < Bytes > > > {
847+ ) -> GlobalResult < Response < ResponseBody > > {
786848 // Get middleware config for this actor if it exists
787849 let middleware_config = match & target. actor_id {
788850 Some ( actor_id) => self . state . get_middleware_config ( actor_id) . await ?,
@@ -894,20 +956,38 @@ impl ProxyService {
894956 Ok ( Ok ( resp) ) => {
895957 let response_receive_time = request_send_start. elapsed ( ) ;
896958
897- // Convert the hyper::body::Incoming to http_body_util::Full<Bytes>
898959 let ( parts, body) = resp. into_parts ( ) ;
899960
900- // Read the response body
901- let body_bytes = match http_body_util:: BodyExt :: collect ( body) . await {
902- Ok ( collected) => collected. to_bytes ( ) ,
903- Err ( _) => Bytes :: new ( ) ,
904- } ;
961+ // Check if this is a streaming response by examining headers
962+ // let is_streaming = parts.headers.get("content-type")
963+ // .and_then(|ct| ct.to_str().ok())
964+ // .map(|ct| ct.contains("text/event-stream") || ct.contains("application/stream"))
965+ // .unwrap_or(false);
966+ let is_streaming = true ;
967+
968+ if is_streaming {
969+ // For streaming responses, pass through the body without buffering
970+ tracing:: debug!( "Detected streaming response, preserving stream" ) ;
905971
906- // Set actual response body size in analytics
907- request_context. guard_response_body_bytes = Some ( body_bytes . len ( ) as u64 ) ;
972+ // We can't easily calculate response size for streaming, so set it to None
973+ request_context. guard_response_body_bytes = None ;
908974
909- let full_body = Full :: new ( body_bytes) ;
910- return Ok ( Response :: from_parts ( parts, full_body) ) ;
975+ let streaming_body = ResponseBody :: Incoming ( body) ;
976+ return Ok ( Response :: from_parts ( parts, streaming_body) ) ;
977+ } else {
978+ // For non-streaming responses, buffer as before
979+ let body_bytes = match BodyExt :: collect ( body) . await {
980+ Ok ( collected) => collected. to_bytes ( ) ,
981+ Err ( _) => Bytes :: new ( ) ,
982+ } ;
983+
984+ // Set actual response body size in analytics
985+ request_context. guard_response_body_bytes =
986+ Some ( body_bytes. len ( ) as u64 ) ;
987+
988+ let full_body = ResponseBody :: Full ( Full :: new ( body_bytes) ) ;
989+ return Ok ( Response :: from_parts ( parts, full_body) ) ;
990+ }
911991 }
912992 Ok ( Err ( err) ) => {
913993 if !err. is_connect ( ) || attempts >= max_attempts {
@@ -944,7 +1024,9 @@ impl ProxyService {
9441024 tracing:: error!( ?err, "Routing error" ) ;
9451025 return Ok ( Response :: builder ( )
9461026 . status ( StatusCode :: BAD_GATEWAY )
947- . body ( Full :: < Bytes > :: new ( Bytes :: new ( ) ) ) ?) ;
1027+ . body ( ResponseBody :: Full ( Full :: < Bytes > :: new (
1028+ Bytes :: new ( ) ,
1029+ ) ) ) ?) ;
9481030 }
9491031 } ;
9501032
@@ -980,7 +1062,7 @@ impl ProxyService {
9801062
9811063 Ok ( Response :: builder ( )
9821064 . status ( status_code)
983- . body ( Full :: < Bytes > :: new ( Bytes :: new ( ) ) ) ?)
1065+ . body ( ResponseBody :: Full ( Full :: < Bytes > :: new ( Bytes :: new ( ) ) ) ) ?)
9841066 }
9851067
9861068 // Common function to build a request URI and headers
@@ -1033,7 +1115,7 @@ impl ProxyService {
10331115 req : Request < BodyIncoming > ,
10341116 mut target : RouteTarget ,
10351117 _request_context : & mut RequestContext ,
1036- ) -> GlobalResult < Response < Full < Bytes > > > {
1118+ ) -> GlobalResult < Response < ResponseBody > > {
10371119 // Get actor and server IDs for metrics and middleware
10381120 let actor_id = target. actor_id ;
10391121 let server_id = target. server_id ;
@@ -1606,15 +1688,18 @@ impl ProxyService {
16061688 // Create a new response with an empty body - WebSocket upgrades don't need a body
16071689 Ok ( Response :: from_parts (
16081690 parts,
1609- Full :: < Bytes > :: new ( Bytes :: new ( ) ) ,
1691+ ResponseBody :: Full ( Full :: < Bytes > :: new ( Bytes :: new ( ) ) ) ,
16101692 ) )
16111693 }
16121694}
16131695
16141696impl ProxyService {
16151697 // Process an individual request
16161698 #[ tracing:: instrument( skip_all) ]
1617- pub async fn process ( & self , req : Request < BodyIncoming > ) -> GlobalResult < Response < Full < Bytes > > > {
1699+ pub async fn process (
1700+ & self ,
1701+ req : Request < BodyIncoming > ,
1702+ ) -> GlobalResult < Response < ResponseBody > > {
16181703 // Create request context for analytics tracking
16191704 let mut request_context = RequestContext :: new ( self . state . clickhouse_inserter . clone ( ) ) ;
16201705
0 commit comments