@@ -8,26 +8,44 @@ use std::{
88 } ,
99} ;
1010
11- use defguard_version:: { get_tracing_variables, ComponentInfo , DefguardComponent , Version } ;
11+ use defguard_version:: {
12+ get_tracing_variables,
13+ server:: { grpc:: DefguardVersionInterceptor , DefguardVersionLayer } ,
14+ ComponentInfo , DefguardComponent , Version ,
15+ } ;
1216use tokio:: sync:: { mpsc, oneshot} ;
1317use tokio_stream:: wrappers:: UnboundedReceiverStream ;
14- use tonic:: { Request , Response , Status , Streaming } ;
18+ use tonic:: {
19+ transport:: { Identity , Server , ServerTlsConfig } ,
20+ Request , Response , Status , Streaming ,
21+ } ;
22+ use tower:: ServiceBuilder ;
1523use tracing:: Instrument ;
1624
1725use crate :: {
1826 error:: ApiError ,
27+ http:: GRPC_SERVER_RESTART_CHANNEL ,
1928 proto:: { core_request, core_response, proxy_server, CoreRequest , CoreResponse , DeviceInfo } ,
29+ MIN_CORE_VERSION , VERSION ,
2030} ;
2131
2232// connected clients
2333type ClientMap = HashMap < SocketAddr , mpsc:: UnboundedSender < Result < CoreRequest , Status > > > ;
2434
35+ #[ derive( Debug , Clone , Default ) ]
36+ pub ( crate ) struct Configuration {
37+ pub ( crate ) grpc_key_pem : String ,
38+ pub ( crate ) grpc_cert_pem : String ,
39+ }
40+
2541pub ( crate ) struct ProxyServer {
2642 current_id : Arc < AtomicU64 > ,
2743 clients : Arc < Mutex < ClientMap > > ,
2844 results : Arc < Mutex < HashMap < u64 , oneshot:: Sender < core_response:: Payload > > > > ,
2945 pub ( crate ) connected : Arc < AtomicBool > ,
3046 pub ( crate ) core_version : Arc < Mutex < Option < Version > > > ,
47+ config : Arc < Mutex < Option < Configuration > > > ,
48+ setup_in_progress : Arc < AtomicBool > ,
3149}
3250
3351impl ProxyServer {
@@ -40,18 +58,94 @@ impl ProxyServer {
4058 results : Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ,
4159 connected : Arc :: new ( AtomicBool :: new ( false ) ) ,
4260 core_version : Arc :: new ( Mutex :: new ( None ) ) ,
61+ config : Arc :: new ( Mutex :: new ( None ) ) ,
62+ setup_in_progress : Arc :: new ( AtomicBool :: new ( false ) ) ,
4363 }
4464 }
4565
46- /// Sends message to the other side of RPC, with given `payload` and optional `device_info`.
47- /// Returns `tokio::sync::oneshot::Reveicer` to let the caller await reply.
48- #[ instrument( name = "send_grpc_message" , level = "debug" , skip( self , payload) ) ]
66+ pub ( crate ) fn set_tls_config ( & self , cert_pem : String , key_pem : String ) -> Result < ( ) , ApiError > {
67+ let mut lock = self
68+ . config
69+ . lock ( )
70+ . expect ( "Failed to acquire lock on config mutex when updating TLS configuration" ) ;
71+ let config = lock. get_or_insert_with ( Configuration :: default) ;
72+ config. grpc_cert_pem = cert_pem;
73+ config. grpc_key_pem = key_pem;
74+ Ok ( ( ) )
75+ }
76+
77+ pub ( crate ) fn configure ( & self , config : Configuration ) {
78+ let mut lock = self
79+ . config
80+ . lock ( )
81+ . expect ( "Failed to acquire lock on config mutex when applying proxy configuration" ) ;
82+ * lock = Some ( config) ;
83+ }
84+
85+ pub ( crate ) fn get_configuration ( & self ) -> Option < Configuration > {
86+ let lock = self
87+ . config
88+ . lock ( )
89+ . expect ( "Failed to acquire lock on config mutex when retrieving proxy configuration" ) ;
90+ lock. clone ( )
91+ }
92+
93+ pub ( crate ) async fn run ( self , addr : SocketAddr ) -> Result < ( ) , anyhow:: Error > {
94+ info ! ( "Starting gRPC server on {addr}" ) ;
95+ let config = self . get_configuration ( ) ;
96+ let ( grpc_cert, grpc_key) = if let Some ( cfg) = config {
97+ ( cfg. grpc_cert_pem , cfg. grpc_key_pem )
98+ } else {
99+ return Err ( anyhow:: anyhow!( "gRPC server configuration is missing" ) ) ;
100+ } ;
101+
102+ let identity = Identity :: from_pem ( grpc_cert, grpc_key) ;
103+ let mut builder =
104+ Server :: builder ( ) . tls_config ( ServerTlsConfig :: new ( ) . identity ( identity) ) ?;
105+
106+ let own_version = Version :: parse ( VERSION ) ?;
107+ let versioned_service = ServiceBuilder :: new ( )
108+ . layer ( tonic:: service:: InterceptorLayer :: new (
109+ DefguardVersionInterceptor :: new (
110+ own_version. clone ( ) ,
111+ DefguardComponent :: Core ,
112+ MIN_CORE_VERSION ,
113+ false ,
114+ ) ,
115+ ) )
116+ . layer ( DefguardVersionLayer :: new ( own_version) )
117+ . service ( proxy_server:: ProxyServer :: new ( self . clone ( ) ) ) ;
118+
119+ builder
120+ . add_service ( versioned_service)
121+ . serve_with_shutdown ( addr, async move {
122+ let mut rx_lock = GRPC_SERVER_RESTART_CHANNEL . 1 . lock ( ) . await ;
123+ rx_lock. recv ( ) . await ;
124+ info ! ( "Shutting down gRPC server for restart..." ) ;
125+ } )
126+ . await
127+ . map_err ( |err| {
128+ error ! ( "gRPC server error: {err}" ) ;
129+ err
130+ } ) ?;
131+
132+ Ok ( ( ) )
133+ }
134+
135+ /// Sends message to the other side of RPC, with given `payload` and `device_info`.
136+ #[ instrument( level = "debug" , skip( self , payload) ) ]
49137 pub ( crate ) fn send (
50138 & self ,
51139 payload : core_request:: Payload ,
52140 device_info : DeviceInfo ,
53141 ) -> Result < oneshot:: Receiver < core_response:: Payload > , ApiError > {
54- if let Some ( client_tx) = self . clients . lock ( ) . unwrap ( ) . values ( ) . next ( ) {
142+ if let Some ( client_tx) = self
143+ . clients
144+ . lock ( )
145+ . expect ( "Failed to acquire lock on clients hashmap when sending message to core" )
146+ . values ( )
147+ . next ( )
148+ {
55149 let id = self . current_id . fetch_add ( 1 , Ordering :: Relaxed ) ;
56150 let res = CoreRequest {
57151 id,
@@ -63,8 +157,10 @@ impl ProxyServer {
63157 return Err ( ApiError :: Unexpected ( "Failed to send CoreRequest" . into ( ) ) ) ;
64158 }
65159 let ( tx, rx) = oneshot:: channel ( ) ;
66- let mut results = self . results . lock ( ) . unwrap ( ) ;
67- results. insert ( id, tx) ;
160+ self . results
161+ . lock ( )
162+ . expect ( "Failed to acquire lock on results hashmap when sending CoreRequest" )
163+ . insert ( id, tx) ;
68164 self . connected . store ( true , Ordering :: Relaxed ) ;
69165 Ok ( rx)
70166 } else {
@@ -75,6 +171,14 @@ impl ProxyServer {
75171 ) )
76172 }
77173 }
174+
175+ pub ( crate ) fn setup_completed ( & self ) -> bool {
176+ let lock = self
177+ . config
178+ . lock ( )
179+ . expect ( "Failed to acquire lock on config mutex when checking setup status" ) ;
180+ lock. is_some ( )
181+ }
78182}
79183
80184impl Clone for ProxyServer {
@@ -85,6 +189,8 @@ impl Clone for ProxyServer {
85189 results : Arc :: clone ( & self . results ) ,
86190 connected : Arc :: clone ( & self . connected ) ,
87191 core_version : Arc :: clone ( & self . core_version ) ,
192+ config : Arc :: clone ( & self . config ) ,
193+ setup_in_progress : Arc :: clone ( & self . setup_in_progress ) ,
88194 }
89195 }
90196}
@@ -99,14 +205,22 @@ impl proxy_server::Proxy for ProxyServer {
99205 & self ,
100206 request : Request < Streaming < CoreResponse > > ,
101207 ) -> Result < Response < Self :: BidiStream > , Status > {
208+ if !self . setup_completed ( ) {
209+ error ! ( "Received bidi connection before setup completion" ) ;
210+ return Err ( Status :: failed_precondition (
211+ "Setup must be completed before establishing bidi connection" ,
212+ ) ) ;
213+ }
214+
102215 let Some ( address) = request. remote_addr ( ) else {
103216 error ! ( "Failed to determine client address for request: {request:?}" ) ;
104217 return Err ( Status :: internal ( "Failed to determine client address" ) ) ;
105218 } ;
106219 let maybe_info = ComponentInfo :: from_metadata ( request. metadata ( ) ) ;
107220 let ( version, info) = get_tracing_variables ( & maybe_info) ;
108- let mut core_version = self . core_version . lock ( ) . unwrap ( ) ;
109- * core_version = Some ( version. clone ( ) ) ;
221+ * self . core_version . lock ( ) . expect (
222+ "Failed to acquire lock on core_version mutex when storing version information" ,
223+ ) = Some ( version. clone ( ) ) ;
110224
111225 let span = tracing:: info_span!( "core_bidi_stream" , component = %DefguardComponent :: Core ,
112226 version = version. to_string( ) , info) ;
@@ -115,7 +229,12 @@ impl proxy_server::Proxy for ProxyServer {
115229 info ! ( "Defguard Core gRPC client connected from: {address}" ) ;
116230
117231 let ( tx, rx) = mpsc:: unbounded_channel ( ) ;
118- self . clients . lock ( ) . unwrap ( ) . insert ( address, tx) ;
232+ self . clients
233+ . lock ( )
234+ . expect (
235+ "Failed to acquire lock on clients hashmap when registering new core connection" ,
236+ )
237+ . insert ( address, tx) ;
119238 self . connected . store ( true , Ordering :: Relaxed ) ;
120239
121240 let clients = Arc :: clone ( & self . clients ) ;
@@ -129,9 +248,9 @@ impl proxy_server::Proxy for ProxyServer {
129248 Ok ( Some ( response) ) => {
130249 debug ! ( "Received message from Defguard Core ID={}" , response. id) ;
131250 connected. store ( true , Ordering :: Relaxed ) ;
132- // Discard empty payloads.
133251 if let Some ( payload) = response. payload {
134- if let Some ( rx) = results. lock ( ) . unwrap ( ) . remove ( & response. id ) {
252+ let maybe_rx = results. lock ( ) . expect ( "Failed to acquire lock on results hashmap when processing response" ) . remove ( & response. id ) ;
253+ if let Some ( rx) = maybe_rx {
135254 if let Err ( err) = rx. send ( payload) {
136255 error ! ( "Failed to send message to rx {:?}" , err. type_id( ) ) ;
137256 }
@@ -152,7 +271,7 @@ impl proxy_server::Proxy for ProxyServer {
152271 }
153272 info ! ( "Defguard core client disconnected: {address}" ) ;
154273 connected. store ( false , Ordering :: Relaxed ) ;
155- clients. lock ( ) . unwrap ( ) . remove ( & address) ;
274+ clients. lock ( ) . expect ( "Failed to acquire lock on clients hashmap when removing disconnected client" ) . remove ( & address) ;
156275 }
157276 . instrument ( tracing:: Span :: current ( ) ) ,
158277 ) ;
0 commit comments