Skip to content

Commit ee56b20

Browse files
authored
feat: proxy and admin cors (#94)
1 parent bb8c0c9 commit ee56b20

7 files changed

Lines changed: 154 additions & 25 deletions

File tree

Cargo.lock

Lines changed: 4 additions & 13 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ axum-server = { version = "0.8.0", default-features = false, features = [
137137
backon = { version = "1.6.0", default-features = false, features = [
138138
"tokio-sleep",
139139
] }
140+
tower-http = { version = "0.6.10", features = ["cors"] }
140141

141142
[build-dependencies]
142143
vergen-git2 = { version = "9.1.0" }

config.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,16 @@ server:
2020
enabled: false
2121
cert_file: cert.pem
2222
key_file: key.pem
23+
cors:
24+
enabled: false
25+
allowed_origins: [ "*" ]
26+
allowed_methods: [ "GET", "POST" ]
27+
allowed_headers: [ "Authorization", "Content-Type" ]
28+
exposed_headers: []
29+
allow_credentials: false
2330
admin:
2431
listen: 127.0.0.1:3001
32+
tls:
33+
enabled: false
34+
cors:
35+
enabled: false

src/admin/mod.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ mod types;
66

77
use std::sync::Arc;
88

9+
use anyhow::Result;
910
use axum::{
1011
Router,
1112
extract::{Request, State},
@@ -107,8 +108,8 @@ impl AppState {
107108
}
108109
}
109110

110-
pub fn create_router(state: AppState) -> Router {
111-
Router::new()
111+
pub fn create_router(state: AppState) -> Result<Router> {
112+
let mut router = Router::new()
112113
.nest(
113114
PATH_PREFIX,
114115
Router::new()
@@ -148,8 +149,14 @@ pub fn create_router(state: AppState) -> Router {
148149
.route("/ui", get(|| async { Redirect::to("/ui/") }))
149150
.route("/ui/", get(aisix_admin_ui::handler))
150151
.route("/ui/{*path}", get(aisix_admin_ui::handler))
151-
.merge(Scalar::with_url("/openapi", ApiDoc::openapi()))
152-
.with_state(state)
152+
.merge(Scalar::with_url("/openapi", ApiDoc::openapi()));
153+
154+
let cors = &state.config.server.admin.cors;
155+
if cors.enabled {
156+
router = router.layer(cors.to_cors_layer()?)
157+
};
158+
159+
Ok(router.with_state(state))
153160
}
154161

155162
async fn auth(

src/config/types.rs

Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
use std::net::SocketAddr;
1+
use std::{net::SocketAddr, str::FromStr};
22

3-
use anyhow::Result;
3+
use anyhow::{Context, Result};
44
use async_trait::async_trait;
5+
use http::{HeaderValue, Method, header::HeaderName};
56
use serde::Deserialize;
67
use tokio::sync::mpsc;
8+
use tower_http::cors::{AllowOrigin, CorsLayer};
79

810
use crate::config::etcd;
911

@@ -46,19 +48,88 @@ pub struct ServerCommonTls {
4648
pub key_file: Option<String>,
4749
}
4850

51+
#[derive(Clone, Debug, Default, Deserialize)]
52+
pub struct ServerCommonCors {
53+
#[serde(default)]
54+
pub enabled: bool,
55+
pub allowed_origins: Option<Vec<String>>,
56+
pub allowed_methods: Option<Vec<String>>,
57+
pub allowed_headers: Option<Vec<String>>,
58+
pub exposed_headers: Option<Vec<String>>,
59+
pub allow_credentials: Option<bool>,
60+
}
61+
62+
impl ServerCommonCors {
63+
pub fn to_cors_layer(&self) -> Result<CorsLayer> {
64+
let mut cors = CorsLayer::new().allow_credentials(self.allow_credentials.unwrap_or(false));
65+
66+
if let Some(origins) = self.allowed_origins.as_deref() {
67+
cors = cors.allow_origin(if origins.iter().any(|o| o == "*") {
68+
AllowOrigin::any()
69+
} else {
70+
AllowOrigin::list(Self::parse_cors_values(
71+
"allowed_origin",
72+
origins,
73+
HeaderValue::from_str,
74+
)?)
75+
});
76+
}
77+
78+
if let Some(methods) = self.allowed_methods.as_deref() {
79+
cors = cors.allow_methods(Self::parse_cors_values(
80+
"allowed_method",
81+
methods,
82+
Method::from_str,
83+
)?);
84+
}
85+
86+
if let Some(headers) = self.allowed_headers.as_deref() {
87+
cors = cors.allow_headers(Self::parse_cors_values(
88+
"allowed_header",
89+
headers,
90+
HeaderName::from_str,
91+
)?);
92+
}
93+
94+
if let Some(headers) = self.exposed_headers.as_deref() {
95+
cors = cors.expose_headers(Self::parse_cors_values(
96+
"exposed_header",
97+
headers,
98+
HeaderName::from_str,
99+
)?);
100+
}
101+
102+
Ok(cors)
103+
}
104+
105+
fn parse_cors_values<T, E, F>(field: &str, values: &[String], mut parse: F) -> Result<Vec<T>>
106+
where
107+
F: FnMut(&str) -> std::result::Result<T, E>,
108+
E: std::error::Error + Send + Sync + 'static,
109+
{
110+
values
111+
.iter()
112+
.map(|value| parse(value).with_context(|| format!("Invalid CORS {}: {}", field, value)))
113+
.collect()
114+
}
115+
}
116+
49117
#[derive(Clone, Debug, Deserialize)]
50118
pub struct ServerProxy {
51119
#[serde(default = "defaults::listen")]
52120
pub listen: SocketAddr,
53121
#[serde(default)]
54122
pub tls: ServerCommonTls,
123+
#[serde(default)]
124+
pub cors: ServerCommonCors,
55125
}
56126

57127
impl Default for ServerProxy {
58128
fn default() -> Self {
59129
Self {
60130
listen: defaults::listen(),
61131
tls: ServerCommonTls::default(),
132+
cors: ServerCommonCors::default(),
62133
}
63134
}
64135
}
@@ -69,13 +140,16 @@ pub struct ServerAdmin {
69140
pub listen: SocketAddr,
70141
#[serde(default)]
71142
pub tls: ServerCommonTls,
143+
#[serde(default)]
144+
pub cors: ServerCommonCors,
72145
}
73146

74147
impl Default for ServerAdmin {
75148
fn default() -> Self {
76149
Self {
77150
listen: defaults::admin_listen(),
78151
tls: ServerCommonTls::default(),
152+
cors: ServerCommonCors::default(),
79153
}
80154
}
81155
}
@@ -212,3 +286,40 @@ impl dyn ConfigProvider {
212286
}
213287
}
214288
}
289+
290+
#[cfg(test)]
291+
mod tests {
292+
use super::ServerCommonCors;
293+
294+
#[test]
295+
fn to_cors_layer_accepts_valid_config() {
296+
let cors = ServerCommonCors {
297+
enabled: true,
298+
allowed_origins: Some(vec!["https://example.com".into()]),
299+
allowed_methods: Some(vec!["GET".into(), "POST".into()]),
300+
allowed_headers: Some(vec!["content-type".into()]),
301+
exposed_headers: Some(vec!["x-request-id".into()]),
302+
allow_credentials: Some(true),
303+
};
304+
305+
assert!(cors.to_cors_layer().is_ok());
306+
}
307+
308+
#[test]
309+
fn to_cors_layer_rejects_invalid_config() {
310+
let cors = ServerCommonCors {
311+
allowed_methods: Some(vec!["NOT A METHOD".into()]),
312+
..Default::default()
313+
};
314+
315+
let result = cors.to_cors_layer();
316+
317+
assert!(result.is_err());
318+
assert!(
319+
result
320+
.err()
321+
.map(|err| err.to_string().contains("Invalid CORS allowed_method"))
322+
.unwrap_or(false)
323+
);
324+
}
325+
}

src/lib.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ pub async fn run_with_provider(
9999
resources.clone(),
100100
gateway,
101101
message_history_storage,
102-
));
102+
))
103+
.context("failed to create proxy router")?;
103104

104105
let res = select! {
105106
res = tokio::signal::ctrl_c() =>
@@ -143,7 +144,7 @@ async fn serve_admin(config: Arc<config::Config>, state: admin::AppState) -> Res
143144
"Admin",
144145
config.server.admin.listen,
145146
&config.server.admin.tls,
146-
admin::create_router(state),
147+
admin::create_router(state).context("failed to create admin router")?,
147148
)
148149
.await
149150
}

src/proxy/mod.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ mod utils;
77

88
use std::sync::Arc;
99

10+
use anyhow::Result;
1011
use axum::{
1112
Router,
1213
extract::DefaultBodyLimit,
@@ -57,8 +58,8 @@ impl AppState {
5758
}
5859
}
5960

60-
pub fn create_router(state: AppState) -> Router {
61-
Router::new()
61+
pub fn create_router(state: AppState) -> Result<Router> {
62+
let mut router = Router::new()
6263
.merge(Router::new().route("/v1/models", get(handlers::models::list_models)))
6364
.route(
6465
"/v1/chat/completions",
@@ -80,6 +81,12 @@ pub fn create_router(state: AppState) -> Router {
8081
.route("/v1/embeddings", post(handlers::embeddings::embeddings))
8182
.layer(DefaultBodyLimit::max(10 * 1024 * 1024))
8283
.layer(from_fn_with_state(state.clone(), middlewares::auth))
83-
.layer(from_fn(middlewares::trace))
84-
.with_state(state)
84+
.layer(from_fn(middlewares::trace));
85+
86+
let cors = &state.config.server.proxy.cors;
87+
if cors.enabled {
88+
router = router.layer(cors.to_cors_layer()?)
89+
};
90+
91+
Ok(router.with_state(state))
8592
}

0 commit comments

Comments
 (0)