@@ -64,6 +64,15 @@ pub struct StreamableHttpServerConfig {
6464 /// or with ports:
6565 /// allowed_hosts = ["example.com", "example.com:8080"]
6666 pub allowed_hosts : Vec < String > ,
67+ /// Allowed browser origins for inbound `Origin` validation.
68+ ///
69+ /// Defaults to an empty list, which disables Origin validation. When
70+ /// non-empty, requests carrying an `Origin` header must match per RFC 6454
71+ /// `(scheme, host, port)`; missing-`Origin` requests still pass. Entries
72+ /// must include a scheme; `"null"` matches the browser's `Origin: null`.
73+ /// examples:
74+ /// allowed_origins = ["https://app.example.com", "http://localhost:8080"]
75+ pub allowed_origins : Vec < String > ,
6776 /// Optional external session store for cross-instance recovery.
6877 ///
6978 /// When set, [`SessionState`] (the client's `initialize` parameters) is
@@ -103,6 +112,7 @@ impl Default for StreamableHttpServerConfig {
103112 json_response : false ,
104113 cancellation_token : CancellationToken :: new ( ) ,
105114 allowed_hosts : vec ! [ "localhost" . into( ) , "127.0.0.1" . into( ) , "::1" . into( ) ] ,
115+ allowed_origins : vec ! [ ] ,
106116 session_store : None ,
107117 }
108118 }
@@ -121,6 +131,18 @@ impl StreamableHttpServerConfig {
121131 self . allowed_hosts . clear ( ) ;
122132 self
123133 }
134+ pub fn with_allowed_origins (
135+ mut self ,
136+ allowed_origins : impl IntoIterator < Item = impl Into < String > > ,
137+ ) -> Self {
138+ self . allowed_origins = allowed_origins. into_iter ( ) . map ( Into :: into) . collect ( ) ;
139+ self
140+ }
141+ /// Disable Origin validation, reverting to the default ignore-Origin behavior.
142+ pub fn disable_allowed_origins ( mut self ) -> Self {
143+ self . allowed_origins . clear ( ) ;
144+ self
145+ }
124146 pub fn with_sse_keep_alive ( mut self , duration : Option < Duration > ) -> Self {
125147 self . sse_keep_alive = duration;
126148 self
@@ -243,6 +265,59 @@ fn host_is_allowed(host: &NormalizedAuthority, allowed_hosts: &[String]) -> bool
243265 } )
244266}
245267
268+ #[ derive( Debug , Clone , PartialEq , Eq ) ]
269+ enum NormalizedOrigin {
270+ Null ,
271+ Tuple {
272+ scheme : String ,
273+ host : String ,
274+ port : Option < u16 > ,
275+ } ,
276+ }
277+
278+ fn parse_origin_value ( value : & str ) -> Option < NormalizedOrigin > {
279+ let value = value. trim ( ) ;
280+ if value. is_empty ( ) {
281+ return None ;
282+ }
283+ if value. eq_ignore_ascii_case ( "null" ) {
284+ return Some ( NormalizedOrigin :: Null ) ;
285+ }
286+ let uri = http:: Uri :: try_from ( value) . ok ( ) ?;
287+ let scheme = uri. scheme_str ( ) ?. to_ascii_lowercase ( ) ;
288+ let authority = uri. authority ( ) ?;
289+ Some ( NormalizedOrigin :: Tuple {
290+ scheme,
291+ host : normalize_host ( authority. host ( ) ) ,
292+ port : authority. port_u16 ( ) ,
293+ } )
294+ }
295+
296+ fn origin_is_allowed ( origin : & NormalizedOrigin , allowed_origins : & [ String ] ) -> bool {
297+ if allowed_origins. is_empty ( ) {
298+ return true ;
299+ }
300+ allowed_origins
301+ . iter ( )
302+ . filter_map ( |raw| parse_origin_value ( raw) )
303+ . any ( |allowed| match ( & allowed, origin) {
304+ ( NormalizedOrigin :: Null , NormalizedOrigin :: Null ) => true ,
305+ (
306+ NormalizedOrigin :: Tuple {
307+ scheme : a_scheme,
308+ host : a_host,
309+ port : a_port,
310+ } ,
311+ NormalizedOrigin :: Tuple {
312+ scheme : o_scheme,
313+ host : o_host,
314+ port : o_port,
315+ } ,
316+ ) => a_scheme == o_scheme && a_host == o_host && ( a_port. is_none ( ) || a_port == o_port) ,
317+ _ => false ,
318+ } )
319+ }
320+
246321fn bad_request_response ( message : & str ) -> BoxResponse {
247322 let body = Full :: from ( message. to_string ( ) ) . boxed ( ) ;
248323
@@ -274,7 +349,30 @@ fn validate_dns_rebinding_headers(
274349 if !host_is_allowed ( & host, & config. allowed_hosts ) {
275350 return Err ( forbidden_response ( "Forbidden: Host header is not allowed" ) ) ;
276351 }
352+ validate_origin_header ( headers, & config. allowed_origins ) ?;
353+ Ok ( ( ) )
354+ }
277355
356+ fn validate_origin_header (
357+ headers : & HeaderMap ,
358+ allowed_origins : & [ String ] ,
359+ ) -> Result < ( ) , BoxResponse > {
360+ if allowed_origins. is_empty ( ) {
361+ return Ok ( ( ) ) ;
362+ }
363+ let Some ( origin_header) = headers. get ( http:: header:: ORIGIN ) else {
364+ return Ok ( ( ) ) ;
365+ } ;
366+ let origin_str = origin_header
367+ . to_str ( )
368+ . map_err ( |_| bad_request_response ( "Bad Request: Invalid Origin header encoding" ) ) ?;
369+ let origin = parse_origin_value ( origin_str)
370+ . ok_or_else ( || bad_request_response ( "Bad Request: Invalid Origin header" ) ) ?;
371+ if !origin_is_allowed ( & origin, allowed_origins) {
372+ return Err ( forbidden_response (
373+ "Forbidden: Origin header is not allowed" ,
374+ ) ) ;
375+ }
278376 Ok ( ( ) )
279377}
280378
0 commit comments