Skip to content

Commit 4d8c2cf

Browse files
feat: add API to allow an origin to be allowed by CORS
1 parent b34c109 commit 4d8c2cf

3 files changed

Lines changed: 108 additions & 18 deletions

File tree

crates/devtools-core/src/aggregator.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,6 @@ impl<T, const CAP: usize> EventBuf<T, CAP> {
321321
}
322322

323323
/// Push an event into the buffer, overwriting the oldest event if the buffer is full.
324-
// TODO does it really make sense to track the dropped events here?
325324
pub fn push_overwrite(&mut self, item: T) {
326325
if self.inner.push_overwrite(item).is_some() {
327326
self.sent = self.sent.saturating_sub(1);

crates/devtools-core/src/server.rs

Lines changed: 100 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,24 @@ use devtools_wire_format::sources::sources_server::SourcesServer;
99
use devtools_wire_format::tauri::tauri_server;
1010
use devtools_wire_format::tauri::tauri_server::TauriServer;
1111
use futures::{FutureExt, TryStreamExt};
12+
use http::HeaderValue;
13+
use hyper::Body;
1214
use std::net::SocketAddr;
15+
use std::pin::Pin;
16+
use std::sync::{Arc, Mutex};
17+
use std::task::{Context, Poll};
1318
use tokio::sync::mpsc;
19+
use tonic::body::BoxBody;
1420
use tonic::codegen::http::Method;
1521
use tonic::codegen::tokio_stream::wrappers::ReceiverStream;
1622
use tonic::codegen::BoxStream;
1723
use tonic::{Request, Response, Status};
1824
use tonic_health::pb::health_server::{Health, HealthServer};
1925
use tonic_health::server::HealthReporter;
2026
use tonic_health::ServingStatus;
21-
use tower_http::cors::{AllowHeaders, CorsLayer};
27+
use tower::Service;
28+
use tower_http::cors::{AllowHeaders, AllowOrigin, CorsLayer};
29+
use tower_layer::Layer;
2230

2331
/// Default maximum capacity for the channel of events sent from a
2432
/// [`Server`] to each subscribed client.
@@ -28,15 +36,81 @@ use tower_http::cors::{AllowHeaders, CorsLayer};
2836
const DEFAULT_CLIENT_BUFFER_CAPACITY: usize = 1024 * 4;
2937

3038
/// The `gRPC` server that exposes the instrumenting API
31-
pub struct Server(
32-
tonic::transport::server::Router<tower_layer::Stack<CorsLayer, tower_layer::Identity>>,
33-
);
39+
pub struct Server {
40+
router: tonic::transport::server::Router<
41+
tower_layer::Stack<DynamicCorsLayer, tower_layer::Identity>,
42+
>,
43+
handle: ServerHandle,
44+
}
45+
46+
/// A handle to a server that is allowed to modify its properties (such as CORS allowed origins)
47+
#[derive(Clone)]
48+
pub struct ServerHandle {
49+
allowed_origins: Arc<Mutex<Vec<AllowOrigin>>>,
50+
}
51+
52+
impl ServerHandle {
53+
pub fn allow_origin(&self, origin: impl Into<AllowOrigin>) {
54+
self.allowed_origins.lock().unwrap().push(origin.into());
55+
}
56+
}
3457

3558
struct InstrumentService {
3659
tx: mpsc::Sender<Command>,
3760
health_reporter: HealthReporter,
3861
}
3962

63+
#[derive(Clone)]
64+
struct DynamicCorsLayer {
65+
allowed_origins: Arc<Mutex<Vec<AllowOrigin>>>,
66+
}
67+
68+
impl<S> Layer<S> for DynamicCorsLayer {
69+
type Service = DynamicCors<S>;
70+
71+
fn layer(&self, service: S) -> Self::Service {
72+
DynamicCors {
73+
inner: service,
74+
allowed_origins: self.allowed_origins.clone(),
75+
}
76+
}
77+
}
78+
79+
#[derive(Debug, Clone)]
80+
struct DynamicCors<S> {
81+
inner: S,
82+
allowed_origins: Arc<Mutex<Vec<AllowOrigin>>>,
83+
}
84+
85+
type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
86+
87+
impl<S> Service<hyper::Request<Body>> for DynamicCors<S>
88+
where
89+
S: Service<hyper::Request<Body>, Response = hyper::Response<BoxBody>> + Clone + Send + 'static,
90+
S::Future: Send + 'static,
91+
{
92+
type Response = S::Response;
93+
type Error = S::Error;
94+
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
95+
96+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
97+
self.inner.poll_ready(cx)
98+
}
99+
100+
fn call(&mut self, req: hyper::Request<Body>) -> Self::Future {
101+
let mut cors = CorsLayer::new()
102+
// allow `GET` and `POST` when accessing the resource
103+
.allow_methods([Method::GET, Method::POST])
104+
.allow_headers(AllowHeaders::any());
105+
106+
for origin in &*self.allowed_origins.lock().unwrap() {
107+
cors = cors.allow_origin(origin.clone());
108+
}
109+
110+
Box::pin(cors.layer(self.inner.clone()).call(req))
111+
}
112+
}
113+
40114
impl Server {
41115
#[allow(clippy::missing_panics_doc)]
42116
pub fn new(
@@ -51,15 +125,22 @@ impl Server {
51125
.set_serving::<InstrumentServer<InstrumentService>>()
52126
.now_or_never();
53127

54-
let cors = CorsLayer::new()
55-
// allow `GET` and `POST` when accessing the resource
56-
.allow_methods([Method::GET, Method::POST])
57-
.allow_headers(AllowHeaders::any())
58-
.allow_origin(tower_http::cors::Any);
128+
let allowed_origins =
129+
Arc::new(Mutex::new(vec![
130+
if option_env!("__DEVTOOLS_LOCAL_DEVELOPMENT").is_some() {
131+
AllowOrigin::from(tower_http::cors::Any)
132+
} else {
133+
HeaderValue::from_str("https://devtools.crabnebula.dev")
134+
.unwrap()
135+
.into()
136+
},
137+
]));
59138

60139
let router = tonic::transport::Server::builder()
61140
.accept_http1(true)
62-
.layer(cors)
141+
.layer(DynamicCorsLayer {
142+
allowed_origins: allowed_origins.clone(),
143+
})
63144
.add_service(tonic_web::enable(health_service))
64145
.add_service(tonic_web::enable(InstrumentServer::new(
65146
InstrumentService {
@@ -71,7 +152,14 @@ impl Server {
71152
.add_service(tonic_web::enable(MetadataServer::new(metadata_server)))
72153
.add_service(tonic_web::enable(SourcesServer::new(sources_server)));
73154

74-
Self(router)
155+
Self {
156+
router,
157+
handle: ServerHandle { allowed_origins },
158+
}
159+
}
160+
161+
pub fn handle(&self) -> ServerHandle {
162+
self.handle.clone()
75163
}
76164

77165
/// Consumes this [`Server`] and returns a future that will execute the server.
@@ -82,7 +170,7 @@ impl Server {
82170
pub async fn run(self, addr: SocketAddr) -> crate::Result<()> {
83171
tracing::info!("Listening on {}", addr);
84172

85-
self.0.serve(addr).await?;
173+
self.router.serve(addr).await?;
86174

87175
Ok(())
88176
}

crates/devtools/src/lib.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ mod server;
33
use devtools_core::aggregator::Aggregator;
44
use devtools_core::layer::Layer;
55
use devtools_core::server::wire::tauri::tauri_server::TauriServer;
6-
use devtools_core::server::Server;
6+
use devtools_core::server::{Server, ServerHandle};
77
use devtools_core::Command;
88
pub use devtools_core::Error;
99
use devtools_core::{Result, Shared};
@@ -52,6 +52,7 @@ mod ios {
5252

5353
pub struct Devtools {
5454
pub connection: ConnectionInfo,
55+
pub server_handle: ServerHandle,
5556
}
5657

5758
fn init_plugin<R: Runtime>(
@@ -64,10 +65,6 @@ fn init_plugin<R: Runtime>(
6465
.setup(move |app_handle, _api| {
6566
let (mut health_reporter, health_service) = tonic_health::server::health_reporter();
6667

67-
app_handle.manage(Devtools {
68-
connection: connection_info(&addr),
69-
});
70-
7168
health_reporter
7269
.set_serving::<TauriServer<server::TauriService<R>>>()
7370
.now_or_never()
@@ -87,6 +84,12 @@ fn init_plugin<R: Runtime>(
8784
app_handle: app_handle.clone(),
8885
},
8986
);
87+
let server_handle = server.handle();
88+
89+
app_handle.manage(Devtools {
90+
connection: connection_info(&addr),
91+
server_handle,
92+
});
9093

9194
#[cfg(not(target_os = "ios"))]
9295
print_link(&addr);

0 commit comments

Comments
 (0)