@@ -44,6 +44,7 @@ use tokio::time::{sleep_until, timeout};
4444use tokio_tungstenite:: tungstenite:: Utf8Bytes ;
4545
4646use crate :: auth:: SpacetimeAuth ;
47+ use crate :: util:: serde:: humantime_duration;
4748use crate :: util:: websocket:: {
4849 CloseCode , CloseFrame , Message as WsMessage , WebSocketConfig , WebSocketStream , WebSocketUpgrade , WsError ,
4950} ;
@@ -55,6 +56,16 @@ pub const TEXT_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_api::TEXT_PRO
5556#[ allow( clippy:: declare_interior_mutable_const) ]
5657pub const BIN_PROTOCOL : HeaderValue = HeaderValue :: from_static ( ws_api:: BIN_PROTOCOL ) ;
5758
59+ pub trait HasWebSocketOptions {
60+ fn websocket_options ( & self ) -> WebSocketOptions ;
61+ }
62+
63+ impl < T : HasWebSocketOptions > HasWebSocketOptions for Arc < T > {
64+ fn websocket_options ( & self ) -> WebSocketOptions {
65+ ( * * self ) . websocket_options ( )
66+ }
67+ }
68+
5869#[ derive( Deserialize ) ]
5970pub struct SubscribeParams {
6071 pub name_or_identity : NameOrIdentity ,
@@ -88,7 +99,7 @@ pub async fn handle_websocket<S>(
8899 ws : WebSocketUpgrade ,
89100) -> axum:: response:: Result < impl IntoResponse >
90101where
91- S : NodeDelegate + ControlStateDelegate ,
102+ S : NodeDelegate + ControlStateDelegate + HasWebSocketOptions ,
92103{
93104 if connection_id. is_some ( ) {
94105 // TODO: Bump this up to `log::warn!` after removing the client SDKs' uses of that parameter.
@@ -146,6 +157,7 @@ where
146157 . max_message_size ( Some ( 0x2000000 ) )
147158 . max_frame_size ( None )
148159 . accept_unmasked_frames ( false ) ;
160+ let ws_opts = ctx. websocket_options ( ) ;
149161
150162 tokio:: spawn ( async move {
151163 let ws = match ws_upgrade. upgrade ( ws_config) . await {
@@ -163,7 +175,7 @@ where
163175 None => log:: debug!( "New client connected from unknown ip" ) ,
164176 }
165177
166- let actor = |client, sendrx| ws_client_actor ( client, ws, sendrx) ;
178+ let actor = |client, sendrx| ws_client_actor ( ws_opts , client, ws, sendrx) ;
167179 let client = match ClientConnection :: spawn ( client_id, client_config, leader. replica_id , module_rx, actor) . await
168180 {
169181 Ok ( s) => s,
@@ -198,13 +210,13 @@ where
198210struct ActorState {
199211 pub client_id : ClientActorId ,
200212 pub database : Identity ,
201- config : ActorConfig ,
213+ config : WebSocketOptions ,
202214 closed : AtomicBool ,
203215 got_pong : AtomicBool ,
204216}
205217
206218impl ActorState {
207- pub fn new ( database : Identity , client_id : ClientActorId , config : ActorConfig ) -> Self {
219+ pub fn new ( database : Identity , client_id : ClientActorId , config : WebSocketOptions ) -> Self {
208220 Self {
209221 database,
210222 client_id,
@@ -235,14 +247,19 @@ impl ActorState {
235247 }
236248}
237249
238- struct ActorConfig {
250+ /// Configuration for WebSocket connections.
251+ #[ derive( Clone , Copy , Debug , PartialEq , serde:: Serialize , serde:: Deserialize ) ]
252+ #[ serde( rename_all = "kebab-case" ) ]
253+ pub struct WebSocketOptions {
239254 /// Interval at which to send `Ping` frames.
240255 ///
241256 /// We use pings for connection keep-alive.
242257 /// Value must be smaller than `idle_timeout`.
243258 ///
244259 /// Default: 15s
245- ping_interval : Duration ,
260+ #[ serde( with = "humantime_duration" ) ]
261+ #[ serde( default = "WebSocketOptions::default_ping_interval" ) ]
262+ pub ping_interval : Duration ,
246263 /// Amount of time after which an idle connection is closed.
247264 ///
248265 /// A connection is considered idle if no data is received nor sent.
@@ -251,47 +268,80 @@ struct ActorConfig {
251268 /// Value must be greater than `ping_interval`.
252269 ///
253270 /// Default: 30s
254- idle_timeout : Duration ,
271+ #[ serde( with = "humantime_duration" ) ]
272+ #[ serde( default = "WebSocketOptions::default_idle_timeout" ) ]
273+ pub idle_timeout : Duration ,
255274 /// For how long to keep draining the incoming messages until a client close
256275 /// is received.
257276 ///
258277 /// Default: 250ms
259- close_handshake_timeout : Duration ,
278+ #[ serde( with = "humantime_duration" ) ]
279+ #[ serde( default = "WebSocketOptions::default_close_handshake_timeout" ) ]
280+ pub close_handshake_timeout : Duration ,
260281 /// Maximum number of messages to queue for processing.
261282 ///
262283 /// If this number is exceeded, the client is disconnected.
263284 ///
264285 /// Default: 2048
265- incoming_queue_length : NonZeroUsize ,
286+ #[ serde( default = "WebSocketOptions::default_incoming_queue_length" ) ]
287+ pub incoming_queue_length : NonZeroUsize ,
266288}
267289
268- impl Default for ActorConfig {
290+ impl Default for WebSocketOptions {
269291 fn default ( ) -> Self {
270- Self {
271- ping_interval : Duration :: from_secs ( 15 ) ,
272- idle_timeout : Duration :: from_secs ( 30 ) ,
273- close_handshake_timeout : Duration :: from_millis ( 250 ) ,
274- incoming_queue_length :
275- // SAFETY: 2048 > 0, qed
276- unsafe { NonZeroUsize :: new_unchecked ( 2048 ) }
277- }
292+ Self :: DEFAULT
278293 }
279294}
280295
281- async fn ws_client_actor ( client : ClientConnection , ws : WebSocketStream , sendrx : MeteredReceiver < SerializableMessage > ) {
296+ impl WebSocketOptions {
297+ const DEFAULT_PING_INTERVAL : Duration = Duration :: from_secs ( 15 ) ;
298+ const DEFAULT_IDLE_TIMEOUT : Duration = Duration :: from_secs ( 30 ) ;
299+ const DEFAULT_CLOSE_HANDSHAKE_TIMEOUT : Duration = Duration :: from_millis ( 250 ) ;
300+ const DEFAULT_INCOMING_QUEUE_LENGTH : NonZeroUsize = NonZeroUsize :: new ( 2048 ) . expect ( "2048 > 0, qed" ) ;
301+
302+ const DEFAULT : Self = Self {
303+ ping_interval : Self :: DEFAULT_PING_INTERVAL ,
304+ idle_timeout : Self :: DEFAULT_IDLE_TIMEOUT ,
305+ close_handshake_timeout : Self :: DEFAULT_CLOSE_HANDSHAKE_TIMEOUT ,
306+ incoming_queue_length : Self :: DEFAULT_INCOMING_QUEUE_LENGTH ,
307+ } ;
308+
309+ const fn default_ping_interval ( ) -> Duration {
310+ Self :: DEFAULT_PING_INTERVAL
311+ }
312+
313+ const fn default_idle_timeout ( ) -> Duration {
314+ Self :: DEFAULT_IDLE_TIMEOUT
315+ }
316+
317+ const fn default_close_handshake_timeout ( ) -> Duration {
318+ Self :: DEFAULT_CLOSE_HANDSHAKE_TIMEOUT
319+ }
320+
321+ const fn default_incoming_queue_length ( ) -> NonZeroUsize {
322+ Self :: DEFAULT_INCOMING_QUEUE_LENGTH
323+ }
324+ }
325+
326+ async fn ws_client_actor (
327+ options : WebSocketOptions ,
328+ client : ClientConnection ,
329+ ws : WebSocketStream ,
330+ sendrx : MeteredReceiver < SerializableMessage > ,
331+ ) {
282332 // ensure that even if this task gets cancelled, we always cleanup the connection
283333 let mut client = scopeguard:: guard ( client, |client| {
284334 tokio:: spawn ( client. disconnect ( ) ) ;
285335 } ) ;
286336
287- ws_client_actor_inner ( & mut client, < _ > :: default ( ) , ws, sendrx) . await ;
337+ ws_client_actor_inner ( & mut client, options , ws, sendrx) . await ;
288338
289339 ScopeGuard :: into_inner ( client) . disconnect ( ) . await ;
290340}
291341
292342async fn ws_client_actor_inner (
293343 client : & mut ClientConnection ,
294- config : ActorConfig ,
344+ config : WebSocketOptions ,
295345 ws : WebSocketStream ,
296346 sendrx : MeteredReceiver < SerializableMessage > ,
297347) {
@@ -1160,7 +1210,7 @@ mod tests {
11601210 dummy_actor_state_with_config ( <_ >:: default ( ) )
11611211 }
11621212
1163- fn dummy_actor_state_with_config ( config : ActorConfig ) -> ActorState {
1213+ fn dummy_actor_state_with_config ( config : WebSocketOptions ) -> ActorState {
11641214 ActorState :: new ( Identity :: ZERO , dummy_client_id ( ) , config)
11651215 }
11661216
@@ -1482,7 +1532,7 @@ mod tests {
14821532
14831533 #[ tokio:: test]
14841534 async fn main_loop_terminates_on_idle_timeout ( ) {
1485- let state = Arc :: new ( dummy_actor_state_with_config ( ActorConfig {
1535+ let state = Arc :: new ( dummy_actor_state_with_config ( WebSocketOptions {
14861536 idle_timeout : Duration :: from_millis ( 10 ) ,
14871537 ..<_ >:: default ( )
14881538 } ) ) ;
@@ -1520,7 +1570,7 @@ mod tests {
15201570
15211571 #[ tokio:: test]
15221572 async fn main_loop_keepalive_keeps_alive ( ) {
1523- let state = Arc :: new ( dummy_actor_state_with_config ( ActorConfig {
1573+ let state = Arc :: new ( dummy_actor_state_with_config ( WebSocketOptions {
15241574 ping_interval : Duration :: from_millis ( 5 ) ,
15251575 idle_timeout : Duration :: from_millis ( 10 ) ,
15261576 ..<_ >:: default ( )
@@ -1616,7 +1666,7 @@ mod tests {
16161666
16171667 #[ tokio:: test]
16181668 async fn recv_queue_sends_close_when_at_capacity ( ) {
1619- let state = Arc :: new ( dummy_actor_state_with_config ( ActorConfig {
1669+ let state = Arc :: new ( dummy_actor_state_with_config ( WebSocketOptions {
16201670 incoming_queue_length : 10 . try_into ( ) . unwrap ( ) ,
16211671 ..<_ >:: default ( )
16221672 } ) ) ;
@@ -1632,7 +1682,7 @@ mod tests {
16321682
16331683 #[ tokio:: test]
16341684 async fn recv_queue_closes_state_if_sender_gone ( ) {
1635- let state = Arc :: new ( dummy_actor_state_with_config ( ActorConfig {
1685+ let state = Arc :: new ( dummy_actor_state_with_config ( WebSocketOptions {
16361686 incoming_queue_length : 10 . try_into ( ) . unwrap ( ) ,
16371687 ..<_ >:: default ( )
16381688 } ) ) ;
@@ -1695,4 +1745,27 @@ mod tests {
16951745 Poll :: Ready ( Ok ( ( ) ) )
16961746 }
16971747 }
1748+
1749+ #[ test]
1750+ fn options_toml_roundtrip ( ) {
1751+ let options = WebSocketOptions :: default ( ) ;
1752+ let toml = toml:: to_string ( & options) . unwrap ( ) ;
1753+ assert_eq ! ( options, toml:: from_str:: <WebSocketOptions >( & toml) . unwrap( ) ) ;
1754+ }
1755+
1756+ #[ test]
1757+ fn options_from_partial_toml ( ) {
1758+ let toml = r#"
1759+ ping-interval = "53s"
1760+ idle-timeout = "1m 3s"
1761+ "# ;
1762+
1763+ let expected = WebSocketOptions {
1764+ ping_interval : Duration :: from_secs ( 53 ) ,
1765+ idle_timeout : Duration :: from_secs ( 63 ) ,
1766+ ..<_ >:: default ( )
1767+ } ;
1768+
1769+ assert_eq ! ( expected, toml:: from_str( toml) . unwrap( ) ) ;
1770+ }
16981771}
0 commit comments