Skip to content

Commit abe4691

Browse files
authored
Core certificate authority, part 1: Proxy (#223)
* certificate handshake * cleanup * versions * prevent concurrent setups, allow multiple connections * inform core about errors * cleanup * bump qs * simplify setup * suggestions * update protos, defguard certs * whitelist defguard certs
1 parent 0a2970b commit abe4691

12 files changed

Lines changed: 1774 additions & 146 deletions

File tree

Cargo.lock

Lines changed: 1314 additions & 72 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ repository = "https://github.com/DefGuard/proxy"
88

99
[dependencies]
1010
defguard_version = { git = "https://github.com/DefGuard/defguard.git", rev = "8649a9ba225d7bd2066a09c9e1347705c34bd158" }
11+
defguard_certs = { git = "https://github.com/DefGuard/defguard.git", rev = "3304a76f1262eb381a44a0d6906595215cb740b8" }
1112
# base `axum` deps
1213
axum = { version = "0.8", features = ["ws"] }
1314
axum-client-ip = "0.7"

deny.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ exceptions = [
111111
"AGPL-3.0-only",
112112
"AGPL-3.0-or-later",
113113
], crate = "defguard_version" },
114+
{ allow = [
115+
"AGPL-3.0-only",
116+
"AGPL-3.0-or-later",
117+
], crate = "defguard_certs" },
114118
]
115119

116120
# Some crates don't have (easily) machine readable licensing information,

proto

src/config.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ pub struct Config {
5959

6060
#[arg(long, env = "DEFGUARD_GRPC_BIND_ADDRESS")]
6161
pub grpc_bind_address: Option<IpAddr>,
62+
63+
// TODO: On different platforms this may be different
64+
#[arg(
65+
long,
66+
env = "DEFGUARD_PROXY_CERT_DIR",
67+
default_value = "/etc/defguard/certs"
68+
)]
69+
pub cert_dir: PathBuf,
6270
}
6371

6472
#[derive(thiserror::Error, Debug)]

src/grpc.rs

Lines changed: 133 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,44 @@ use std::{
88
},
99
};
1010

11-
use defguard_version::{get_tracing_variables, ComponentInfo, DefguardComponent, Version};
11+
use defguard_version::{
12+
get_tracing_variables,
13+
server::{grpc::DefguardVersionInterceptor, DefguardVersionLayer},
14+
ComponentInfo, DefguardComponent, Version,
15+
};
1216
use tokio::sync::{mpsc, oneshot};
1317
use tokio_stream::wrappers::UnboundedReceiverStream;
14-
use tonic::{Request, Response, Status, Streaming};
18+
use tonic::{
19+
transport::{Identity, Server, ServerTlsConfig},
20+
Request, Response, Status, Streaming,
21+
};
22+
use tower::ServiceBuilder;
1523
use tracing::Instrument;
1624

1725
use crate::{
1826
error::ApiError,
27+
http::GRPC_SERVER_RESTART_CHANNEL,
1928
proto::{core_request, core_response, proxy_server, CoreRequest, CoreResponse, DeviceInfo},
29+
MIN_CORE_VERSION, VERSION,
2030
};
2131

2232
// connected clients
2333
type ClientMap = HashMap<SocketAddr, mpsc::UnboundedSender<Result<CoreRequest, Status>>>;
2434

35+
#[derive(Debug, Clone, Default)]
36+
pub(crate) struct Configuration {
37+
pub(crate) grpc_key_pem: String,
38+
pub(crate) grpc_cert_pem: String,
39+
}
40+
2541
pub(crate) struct ProxyServer {
2642
current_id: Arc<AtomicU64>,
2743
clients: Arc<Mutex<ClientMap>>,
2844
results: Arc<Mutex<HashMap<u64, oneshot::Sender<core_response::Payload>>>>,
2945
pub(crate) connected: Arc<AtomicBool>,
3046
pub(crate) core_version: Arc<Mutex<Option<Version>>>,
47+
config: Arc<Mutex<Option<Configuration>>>,
48+
setup_in_progress: Arc<AtomicBool>,
3149
}
3250

3351
impl ProxyServer {
@@ -40,18 +58,94 @@ impl ProxyServer {
4058
results: Arc::new(Mutex::new(HashMap::new())),
4159
connected: Arc::new(AtomicBool::new(false)),
4260
core_version: Arc::new(Mutex::new(None)),
61+
config: Arc::new(Mutex::new(None)),
62+
setup_in_progress: Arc::new(AtomicBool::new(false)),
4363
}
4464
}
4565

46-
/// Sends message to the other side of RPC, with given `payload` and optional `device_info`.
47-
/// Returns `tokio::sync::oneshot::Reveicer` to let the caller await reply.
48-
#[instrument(name = "send_grpc_message", level = "debug", skip(self, payload))]
66+
pub(crate) fn set_tls_config(&self, cert_pem: String, key_pem: String) -> Result<(), ApiError> {
67+
let mut lock = self
68+
.config
69+
.lock()
70+
.expect("Failed to acquire lock on config mutex when updating TLS configuration");
71+
let config = lock.get_or_insert_with(Configuration::default);
72+
config.grpc_cert_pem = cert_pem;
73+
config.grpc_key_pem = key_pem;
74+
Ok(())
75+
}
76+
77+
pub(crate) fn configure(&self, config: Configuration) {
78+
let mut lock = self
79+
.config
80+
.lock()
81+
.expect("Failed to acquire lock on config mutex when applying proxy configuration");
82+
*lock = Some(config);
83+
}
84+
85+
pub(crate) fn get_configuration(&self) -> Option<Configuration> {
86+
let lock = self
87+
.config
88+
.lock()
89+
.expect("Failed to acquire lock on config mutex when retrieving proxy configuration");
90+
lock.clone()
91+
}
92+
93+
pub(crate) async fn run(self, addr: SocketAddr) -> Result<(), anyhow::Error> {
94+
info!("Starting gRPC server on {addr}");
95+
let config = self.get_configuration();
96+
let (grpc_cert, grpc_key) = if let Some(cfg) = config {
97+
(cfg.grpc_cert_pem, cfg.grpc_key_pem)
98+
} else {
99+
return Err(anyhow::anyhow!("gRPC server configuration is missing"));
100+
};
101+
102+
let identity = Identity::from_pem(grpc_cert, grpc_key);
103+
let mut builder =
104+
Server::builder().tls_config(ServerTlsConfig::new().identity(identity))?;
105+
106+
let own_version = Version::parse(VERSION)?;
107+
let versioned_service = ServiceBuilder::new()
108+
.layer(tonic::service::InterceptorLayer::new(
109+
DefguardVersionInterceptor::new(
110+
own_version.clone(),
111+
DefguardComponent::Core,
112+
MIN_CORE_VERSION,
113+
false,
114+
),
115+
))
116+
.layer(DefguardVersionLayer::new(own_version))
117+
.service(proxy_server::ProxyServer::new(self.clone()));
118+
119+
builder
120+
.add_service(versioned_service)
121+
.serve_with_shutdown(addr, async move {
122+
let mut rx_lock = GRPC_SERVER_RESTART_CHANNEL.1.lock().await;
123+
rx_lock.recv().await;
124+
info!("Shutting down gRPC server for restart...");
125+
})
126+
.await
127+
.map_err(|err| {
128+
error!("gRPC server error: {err}");
129+
err
130+
})?;
131+
132+
Ok(())
133+
}
134+
135+
/// Sends message to the other side of RPC, with given `payload` and `device_info`.
136+
#[instrument(level = "debug", skip(self, payload))]
49137
pub(crate) fn send(
50138
&self,
51139
payload: core_request::Payload,
52140
device_info: DeviceInfo,
53141
) -> Result<oneshot::Receiver<core_response::Payload>, ApiError> {
54-
if let Some(client_tx) = self.clients.lock().unwrap().values().next() {
142+
if let Some(client_tx) = self
143+
.clients
144+
.lock()
145+
.expect("Failed to acquire lock on clients hashmap when sending message to core")
146+
.values()
147+
.next()
148+
{
55149
let id = self.current_id.fetch_add(1, Ordering::Relaxed);
56150
let res = CoreRequest {
57151
id,
@@ -63,8 +157,10 @@ impl ProxyServer {
63157
return Err(ApiError::Unexpected("Failed to send CoreRequest".into()));
64158
}
65159
let (tx, rx) = oneshot::channel();
66-
let mut results = self.results.lock().unwrap();
67-
results.insert(id, tx);
160+
self.results
161+
.lock()
162+
.expect("Failed to acquire lock on results hashmap when sending CoreRequest")
163+
.insert(id, tx);
68164
self.connected.store(true, Ordering::Relaxed);
69165
Ok(rx)
70166
} else {
@@ -75,6 +171,14 @@ impl ProxyServer {
75171
))
76172
}
77173
}
174+
175+
pub(crate) fn setup_completed(&self) -> bool {
176+
let lock = self
177+
.config
178+
.lock()
179+
.expect("Failed to acquire lock on config mutex when checking setup status");
180+
lock.is_some()
181+
}
78182
}
79183

80184
impl Clone for ProxyServer {
@@ -85,6 +189,8 @@ impl Clone for ProxyServer {
85189
results: Arc::clone(&self.results),
86190
connected: Arc::clone(&self.connected),
87191
core_version: Arc::clone(&self.core_version),
192+
config: Arc::clone(&self.config),
193+
setup_in_progress: Arc::clone(&self.setup_in_progress),
88194
}
89195
}
90196
}
@@ -99,14 +205,22 @@ impl proxy_server::Proxy for ProxyServer {
99205
&self,
100206
request: Request<Streaming<CoreResponse>>,
101207
) -> Result<Response<Self::BidiStream>, Status> {
208+
if !self.setup_completed() {
209+
error!("Received bidi connection before setup completion");
210+
return Err(Status::failed_precondition(
211+
"Setup must be completed before establishing bidi connection",
212+
));
213+
}
214+
102215
let Some(address) = request.remote_addr() else {
103216
error!("Failed to determine client address for request: {request:?}");
104217
return Err(Status::internal("Failed to determine client address"));
105218
};
106219
let maybe_info = ComponentInfo::from_metadata(request.metadata());
107220
let (version, info) = get_tracing_variables(&maybe_info);
108-
let mut core_version = self.core_version.lock().unwrap();
109-
*core_version = Some(version.clone());
221+
*self.core_version.lock().expect(
222+
"Failed to acquire lock on core_version mutex when storing version information",
223+
) = Some(version.clone());
110224

111225
let span = tracing::info_span!("core_bidi_stream", component = %DefguardComponent::Core,
112226
version = version.to_string(), info);
@@ -115,7 +229,12 @@ impl proxy_server::Proxy for ProxyServer {
115229
info!("Defguard Core gRPC client connected from: {address}");
116230

117231
let (tx, rx) = mpsc::unbounded_channel();
118-
self.clients.lock().unwrap().insert(address, tx);
232+
self.clients
233+
.lock()
234+
.expect(
235+
"Failed to acquire lock on clients hashmap when registering new core connection",
236+
)
237+
.insert(address, tx);
119238
self.connected.store(true, Ordering::Relaxed);
120239

121240
let clients = Arc::clone(&self.clients);
@@ -129,9 +248,9 @@ impl proxy_server::Proxy for ProxyServer {
129248
Ok(Some(response)) => {
130249
debug!("Received message from Defguard Core ID={}", response.id);
131250
connected.store(true, Ordering::Relaxed);
132-
// Discard empty payloads.
133251
if let Some(payload) = response.payload {
134-
if let Some(rx) = results.lock().unwrap().remove(&response.id) {
252+
let maybe_rx = results.lock().expect("Failed to acquire lock on results hashmap when processing response").remove(&response.id);
253+
if let Some(rx) = maybe_rx {
135254
if let Err(err) = rx.send(payload) {
136255
error!("Failed to send message to rx {:?}", err.type_id());
137256
}
@@ -152,7 +271,7 @@ impl proxy_server::Proxy for ProxyServer {
152271
}
153272
info!("Defguard core client disconnected: {address}");
154273
connected.store(false, Ordering::Relaxed);
155-
clients.lock().unwrap().remove(&address);
274+
clients.lock().expect("Failed to acquire lock on clients hashmap when removing disconnected client").remove(&address);
156275
}
157276
.instrument(tracing::Span::current()),
158277
);

src/handlers/desktop_client_mfa.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::collections::hash_map::Entry;
2+
13
use axum::{
24
extract::{
35
ws::{Message, WebSocket},
@@ -10,7 +12,6 @@ use axum::{
1012
use futures_util::{sink::SinkExt, stream::StreamExt};
1113
use serde::Deserialize;
1214
use serde_json::json;
13-
use std::collections::hash_map::Entry;
1415
use tokio::{sync::oneshot, task::JoinSet};
1516

1617
use crate::{

src/handlers/enrollment.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use axum_extra::extract::{cookie::Cookie, PrivateCookieJar};
33
use time::OffsetDateTime;
44

55
use super::register_mfa::router as register_mfa_router;
6-
76
use crate::{
87
error::ApiError,
98
handlers::{get_core_response, mobile_client::register_mobile_auth},

src/handlers/register_mfa.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
use serde::Deserialize;
2-
31
use axum::{extract::State, response::IntoResponse, routing::post, Json, Router};
42
use axum_extra::extract::PrivateCookieJar;
3+
use serde::Deserialize;
54

65
use crate::{
76
error::ApiError,

0 commit comments

Comments
 (0)