Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions crates/client-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,15 @@ jsonwebtoken.workspace = true
scopeguard.workspace = true
serde_with.workspace = true
async-stream.workspace = true
humantime.workspace = true

[target.'cfg(not(target_env = "msvc"))'.dependencies]
jemalloc_pprof.workspace = true

[dev-dependencies]
jsonwebtoken.workspace = true
pretty_assertions = { workspace = true, features = ["unstable"] }
toml.workspace = true

[lints]
workspace = true
4 changes: 2 additions & 2 deletions crates/client-api/src/routes/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9;
use spacetimedb_lib::identity::AuthCtx;
use spacetimedb_lib::{sats, Timestamp};

use super::subscribe::handle_websocket;
use super::subscribe::{handle_websocket, HasWebSocketOptions};

#[derive(Deserialize)]
pub struct CallParams {
Expand Down Expand Up @@ -790,7 +790,7 @@ pub struct DatabaseRoutes<S> {

impl<S> Default for DatabaseRoutes<S>
where
S: NodeDelegate + ControlStateDelegate + Clone + 'static,
S: NodeDelegate + ControlStateDelegate + HasWebSocketOptions + Clone + 'static,
{
fn default() -> Self {
use axum::routing::{delete, get, post, put};
Expand Down
125 changes: 99 additions & 26 deletions crates/client-api/src/routes/subscribe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use tokio::time::{sleep_until, timeout};
use tokio_tungstenite::tungstenite::Utf8Bytes;

use crate::auth::SpacetimeAuth;
use crate::util::serde::humantime_duration;
use crate::util::websocket::{
CloseCode, CloseFrame, Message as WsMessage, WebSocketConfig, WebSocketStream, WebSocketUpgrade, WsError,
};
Expand All @@ -55,6 +56,16 @@ pub const TEXT_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_api::TEXT_PRO
#[allow(clippy::declare_interior_mutable_const)]
pub const BIN_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_api::BIN_PROTOCOL);

pub trait HasWebSocketOptions {
fn websocket_options(&self) -> WebSocketOptions;
}

impl<T: HasWebSocketOptions> HasWebSocketOptions for Arc<T> {
fn websocket_options(&self) -> WebSocketOptions {
(**self).websocket_options()
}
}

#[derive(Deserialize)]
pub struct SubscribeParams {
pub name_or_identity: NameOrIdentity,
Expand Down Expand Up @@ -88,7 +99,7 @@ pub async fn handle_websocket<S>(
ws: WebSocketUpgrade,
) -> axum::response::Result<impl IntoResponse>
where
S: NodeDelegate + ControlStateDelegate,
S: NodeDelegate + ControlStateDelegate + HasWebSocketOptions,
{
if connection_id.is_some() {
// TODO: Bump this up to `log::warn!` after removing the client SDKs' uses of that parameter.
Expand Down Expand Up @@ -146,6 +157,7 @@ where
.max_message_size(Some(0x2000000))
.max_frame_size(None)
.accept_unmasked_frames(false);
let ws_opts = ctx.websocket_options();

tokio::spawn(async move {
let ws = match ws_upgrade.upgrade(ws_config).await {
Expand All @@ -163,7 +175,7 @@ where
None => log::debug!("New client connected from unknown ip"),
}

let actor = |client, sendrx| ws_client_actor(client, ws, sendrx);
let actor = |client, sendrx| ws_client_actor(ws_opts, client, ws, sendrx);
let client = match ClientConnection::spawn(client_id, client_config, leader.replica_id, module_rx, actor).await
{
Ok(s) => s,
Expand Down Expand Up @@ -198,13 +210,13 @@ where
struct ActorState {
pub client_id: ClientActorId,
pub database: Identity,
config: ActorConfig,
config: WebSocketOptions,
closed: AtomicBool,
got_pong: AtomicBool,
}

impl ActorState {
pub fn new(database: Identity, client_id: ClientActorId, config: ActorConfig) -> Self {
pub fn new(database: Identity, client_id: ClientActorId, config: WebSocketOptions) -> Self {
Self {
database,
client_id,
Expand Down Expand Up @@ -235,14 +247,19 @@ impl ActorState {
}
}

struct ActorConfig {
/// Configuration for WebSocket connections.
#[derive(Clone, Copy, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct WebSocketOptions {
/// Interval at which to send `Ping` frames.
///
/// We use pings for connection keep-alive.
/// Value must be smaller than `idle_timeout`.
///
/// Default: 15s
ping_interval: Duration,
#[serde(with = "humantime_duration")]
#[serde(default = "WebSocketOptions::default_ping_interval")]
pub ping_interval: Duration,
/// Amount of time after which an idle connection is closed.
///
/// A connection is considered idle if no data is received nor sent.
Expand All @@ -251,47 +268,80 @@ struct ActorConfig {
/// Value must be greater than `ping_interval`.
///
/// Default: 30s
idle_timeout: Duration,
#[serde(with = "humantime_duration")]
#[serde(default = "WebSocketOptions::default_idle_timeout")]
pub idle_timeout: Duration,
/// For how long to keep draining the incoming messages until a client close
/// is received.
///
/// Default: 250ms
close_handshake_timeout: Duration,
#[serde(with = "humantime_duration")]
#[serde(default = "WebSocketOptions::default_close_handshake_timeout")]
pub close_handshake_timeout: Duration,
/// Maximum number of messages to queue for processing.
///
/// If this number is exceeded, the client is disconnected.
///
/// Default: 2048
incoming_queue_length: NonZeroUsize,
#[serde(default = "WebSocketOptions::default_incoming_queue_length")]
pub incoming_queue_length: NonZeroUsize,
}

impl Default for ActorConfig {
impl Default for WebSocketOptions {
fn default() -> Self {
Self {
ping_interval: Duration::from_secs(15),
idle_timeout: Duration::from_secs(30),
close_handshake_timeout: Duration::from_millis(250),
incoming_queue_length:
// SAFETY: 2048 > 0, qed
unsafe { NonZeroUsize::new_unchecked(2048) }
}
Self::DEFAULT
}
}

async fn ws_client_actor(client: ClientConnection, ws: WebSocketStream, sendrx: MeteredReceiver<SerializableMessage>) {
impl WebSocketOptions {
const DEFAULT_PING_INTERVAL: Duration = Duration::from_secs(15);
const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(30);
const DEFAULT_CLOSE_HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(250);
const DEFAULT_INCOMING_QUEUE_LENGTH: NonZeroUsize = NonZeroUsize::new(2048).expect("2048 > 0, qed");

const DEFAULT: Self = Self {
ping_interval: Self::DEFAULT_PING_INTERVAL,
idle_timeout: Self::DEFAULT_IDLE_TIMEOUT,
close_handshake_timeout: Self::DEFAULT_CLOSE_HANDSHAKE_TIMEOUT,
incoming_queue_length: Self::DEFAULT_INCOMING_QUEUE_LENGTH,
};

const fn default_ping_interval() -> Duration {
Self::DEFAULT_PING_INTERVAL
}

const fn default_idle_timeout() -> Duration {
Self::DEFAULT_IDLE_TIMEOUT
}

const fn default_close_handshake_timeout() -> Duration {
Self::DEFAULT_CLOSE_HANDSHAKE_TIMEOUT
}

const fn default_incoming_queue_length() -> NonZeroUsize {
Self::DEFAULT_INCOMING_QUEUE_LENGTH
}
}

async fn ws_client_actor(
options: WebSocketOptions,
client: ClientConnection,
ws: WebSocketStream,
sendrx: MeteredReceiver<SerializableMessage>,
) {
// ensure that even if this task gets cancelled, we always cleanup the connection
let mut client = scopeguard::guard(client, |client| {
tokio::spawn(client.disconnect());
});

ws_client_actor_inner(&mut client, <_>::default(), ws, sendrx).await;
ws_client_actor_inner(&mut client, options, ws, sendrx).await;

ScopeGuard::into_inner(client).disconnect().await;
}

async fn ws_client_actor_inner(
client: &mut ClientConnection,
config: ActorConfig,
config: WebSocketOptions,
ws: WebSocketStream,
sendrx: MeteredReceiver<SerializableMessage>,
) {
Expand Down Expand Up @@ -1160,7 +1210,7 @@ mod tests {
dummy_actor_state_with_config(<_>::default())
}

fn dummy_actor_state_with_config(config: ActorConfig) -> ActorState {
fn dummy_actor_state_with_config(config: WebSocketOptions) -> ActorState {
ActorState::new(Identity::ZERO, dummy_client_id(), config)
}

Expand Down Expand Up @@ -1482,7 +1532,7 @@ mod tests {

#[tokio::test]
async fn main_loop_terminates_on_idle_timeout() {
let state = Arc::new(dummy_actor_state_with_config(ActorConfig {
let state = Arc::new(dummy_actor_state_with_config(WebSocketOptions {
idle_timeout: Duration::from_millis(10),
..<_>::default()
}));
Expand Down Expand Up @@ -1520,7 +1570,7 @@ mod tests {

#[tokio::test]
async fn main_loop_keepalive_keeps_alive() {
let state = Arc::new(dummy_actor_state_with_config(ActorConfig {
let state = Arc::new(dummy_actor_state_with_config(WebSocketOptions {
ping_interval: Duration::from_millis(5),
idle_timeout: Duration::from_millis(10),
..<_>::default()
Expand Down Expand Up @@ -1616,7 +1666,7 @@ mod tests {

#[tokio::test]
async fn recv_queue_sends_close_when_at_capacity() {
let state = Arc::new(dummy_actor_state_with_config(ActorConfig {
let state = Arc::new(dummy_actor_state_with_config(WebSocketOptions {
incoming_queue_length: 10.try_into().unwrap(),
..<_>::default()
}));
Expand All @@ -1632,7 +1682,7 @@ mod tests {

#[tokio::test]
async fn recv_queue_closes_state_if_sender_gone() {
let state = Arc::new(dummy_actor_state_with_config(ActorConfig {
let state = Arc::new(dummy_actor_state_with_config(WebSocketOptions {
incoming_queue_length: 10.try_into().unwrap(),
..<_>::default()
}));
Expand Down Expand Up @@ -1695,4 +1745,27 @@ mod tests {
Poll::Ready(Ok(()))
}
}

#[test]
fn options_toml_roundtrip() {
let options = WebSocketOptions::default();
let toml = toml::to_string(&options).unwrap();
assert_eq!(options, toml::from_str::<WebSocketOptions>(&toml).unwrap());
}

#[test]
fn options_from_partial_toml() {
let toml = r#"
ping-interval = "53s"
idle-timeout = "1m 3s"
"#;

let expected = WebSocketOptions {
ping_interval: Duration::from_secs(53),
idle_timeout: Duration::from_secs(63),
..<_>::default()
};

assert_eq!(expected, toml::from_str(toml).unwrap());
}
}
7 changes: 4 additions & 3 deletions crates/client-api/src/util.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod flat_csv;
pub(crate) mod serde;
pub mod websocket;

use core::fmt;
Expand Down Expand Up @@ -111,16 +112,16 @@ impl NameOrIdentity {
}
}

impl<'de> serde::Deserialize<'de> for NameOrIdentity {
impl<'de> ::serde::Deserialize<'de> for NameOrIdentity {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
D: ::serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
if let Ok(addr) = Identity::from_hex(&s) {
Ok(NameOrIdentity::Identity(IdentityForUrl::from(addr)))
} else {
let name: DatabaseName = s.try_into().map_err(serde::de::Error::custom)?;
let name: DatabaseName = s.try_into().map_err(::serde::de::Error::custom)?;
Ok(NameOrIdentity::Name(name))
}
}
Expand Down
18 changes: 18 additions & 0 deletions crates/client-api/src/util/serde.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/// Ser/De of [`std::time::Duration`] via the `humantime` crate.
///
/// Suitable for use with the `#[serde(with)]` annotation.
pub(crate) mod humantime_duration {
use std::time::Duration;

use ::serde::{Deserialize as _, Deserializer, Serialize as _, Serializer};

pub fn serialize<S: Serializer>(duration: &Duration, ser: S) -> Result<S::Ok, S::Error> {
humantime::format_duration(*duration).to_string().serialize(ser)
}

pub fn deserialize<'de, D: Deserializer<'de>>(de: D) -> Result<Duration, D::Error> {
// TODO: `toml` chokes if we try to derserialize to `&str` here.
let s = String::deserialize(de)?;
humantime::parse_duration(&s).map_err(serde::de::Error::custom)
}
}
2 changes: 1 addition & 1 deletion crates/core/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ impl CertificateAuthority {
}

#[serde_with::serde_as]
#[derive(serde::Deserialize, Default)]
#[derive(Clone, serde::Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub struct LogConfig {
#[serde_as(as = "Option<serde_with::DisplayFromStr>")]
Expand Down
1 change: 1 addition & 0 deletions crates/standalone/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ openssl.workspace = true
parse-size.workspace = true
prometheus.workspace = true
scopeguard.workspace = true
serde.workspace = true
serde_json.workspace = true
sled.workspace = true
socket2.workspace = true
Expand Down
Loading
Loading