1- use std:: net:: SocketAddr ;
1+ use std:: { net:: SocketAddr , str :: FromStr } ;
22
3- use anyhow:: Result ;
3+ use anyhow:: { Context , Result } ;
44use async_trait:: async_trait;
5+ use http:: { HeaderValue , Method , header:: HeaderName } ;
56use serde:: Deserialize ;
67use tokio:: sync:: mpsc;
8+ use tower_http:: cors:: { AllowOrigin , CorsLayer } ;
79
810use crate :: config:: etcd;
911
@@ -46,19 +48,88 @@ pub struct ServerCommonTls {
4648 pub key_file : Option < String > ,
4749}
4850
51+ #[ derive( Clone , Debug , Default , Deserialize ) ]
52+ pub struct ServerCommonCors {
53+ #[ serde( default ) ]
54+ pub enabled : bool ,
55+ pub allowed_origins : Option < Vec < String > > ,
56+ pub allowed_methods : Option < Vec < String > > ,
57+ pub allowed_headers : Option < Vec < String > > ,
58+ pub exposed_headers : Option < Vec < String > > ,
59+ pub allow_credentials : Option < bool > ,
60+ }
61+
62+ impl ServerCommonCors {
63+ pub fn to_cors_layer ( & self ) -> Result < CorsLayer > {
64+ let mut cors = CorsLayer :: new ( ) . allow_credentials ( self . allow_credentials . unwrap_or ( false ) ) ;
65+
66+ if let Some ( origins) = self . allowed_origins . as_deref ( ) {
67+ cors = cors. allow_origin ( if origins. iter ( ) . any ( |o| o == "*" ) {
68+ AllowOrigin :: any ( )
69+ } else {
70+ AllowOrigin :: list ( Self :: parse_cors_values (
71+ "allowed_origin" ,
72+ origins,
73+ HeaderValue :: from_str,
74+ ) ?)
75+ } ) ;
76+ }
77+
78+ if let Some ( methods) = self . allowed_methods . as_deref ( ) {
79+ cors = cors. allow_methods ( Self :: parse_cors_values (
80+ "allowed_method" ,
81+ methods,
82+ Method :: from_str,
83+ ) ?) ;
84+ }
85+
86+ if let Some ( headers) = self . allowed_headers . as_deref ( ) {
87+ cors = cors. allow_headers ( Self :: parse_cors_values (
88+ "allowed_header" ,
89+ headers,
90+ HeaderName :: from_str,
91+ ) ?) ;
92+ }
93+
94+ if let Some ( headers) = self . exposed_headers . as_deref ( ) {
95+ cors = cors. expose_headers ( Self :: parse_cors_values (
96+ "exposed_header" ,
97+ headers,
98+ HeaderName :: from_str,
99+ ) ?) ;
100+ }
101+
102+ Ok ( cors)
103+ }
104+
105+ fn parse_cors_values < T , E , F > ( field : & str , values : & [ String ] , mut parse : F ) -> Result < Vec < T > >
106+ where
107+ F : FnMut ( & str ) -> std:: result:: Result < T , E > ,
108+ E : std:: error:: Error + Send + Sync + ' static ,
109+ {
110+ values
111+ . iter ( )
112+ . map ( |value| parse ( value) . with_context ( || format ! ( "Invalid CORS {}: {}" , field, value) ) )
113+ . collect ( )
114+ }
115+ }
116+
49117#[ derive( Clone , Debug , Deserialize ) ]
50118pub struct ServerProxy {
51119 #[ serde( default = "defaults::listen" ) ]
52120 pub listen : SocketAddr ,
53121 #[ serde( default ) ]
54122 pub tls : ServerCommonTls ,
123+ #[ serde( default ) ]
124+ pub cors : ServerCommonCors ,
55125}
56126
57127impl Default for ServerProxy {
58128 fn default ( ) -> Self {
59129 Self {
60130 listen : defaults:: listen ( ) ,
61131 tls : ServerCommonTls :: default ( ) ,
132+ cors : ServerCommonCors :: default ( ) ,
62133 }
63134 }
64135}
@@ -69,13 +140,16 @@ pub struct ServerAdmin {
69140 pub listen : SocketAddr ,
70141 #[ serde( default ) ]
71142 pub tls : ServerCommonTls ,
143+ #[ serde( default ) ]
144+ pub cors : ServerCommonCors ,
72145}
73146
74147impl Default for ServerAdmin {
75148 fn default ( ) -> Self {
76149 Self {
77150 listen : defaults:: admin_listen ( ) ,
78151 tls : ServerCommonTls :: default ( ) ,
152+ cors : ServerCommonCors :: default ( ) ,
79153 }
80154 }
81155}
@@ -212,3 +286,40 @@ impl dyn ConfigProvider {
212286 }
213287 }
214288}
289+
290+ #[ cfg( test) ]
291+ mod tests {
292+ use super :: ServerCommonCors ;
293+
294+ #[ test]
295+ fn to_cors_layer_accepts_valid_config ( ) {
296+ let cors = ServerCommonCors {
297+ enabled : true ,
298+ allowed_origins : Some ( vec ! [ "https://example.com" . into( ) ] ) ,
299+ allowed_methods : Some ( vec ! [ "GET" . into( ) , "POST" . into( ) ] ) ,
300+ allowed_headers : Some ( vec ! [ "content-type" . into( ) ] ) ,
301+ exposed_headers : Some ( vec ! [ "x-request-id" . into( ) ] ) ,
302+ allow_credentials : Some ( true ) ,
303+ } ;
304+
305+ assert ! ( cors. to_cors_layer( ) . is_ok( ) ) ;
306+ }
307+
308+ #[ test]
309+ fn to_cors_layer_rejects_invalid_config ( ) {
310+ let cors = ServerCommonCors {
311+ allowed_methods : Some ( vec ! [ "NOT A METHOD" . into( ) ] ) ,
312+ ..Default :: default ( )
313+ } ;
314+
315+ let result = cors. to_cors_layer ( ) ;
316+
317+ assert ! ( result. is_err( ) ) ;
318+ assert ! (
319+ result
320+ . err( )
321+ . map( |err| err. to_string( ) . contains( "Invalid CORS allowed_method" ) )
322+ . unwrap_or( false )
323+ ) ;
324+ }
325+ }
0 commit comments