From 4032e98d0aa9ced1e426954c1558abbc2a451bf6 Mon Sep 17 00:00:00 2001 From: Mazdak Farrokhzad Date: Tue, 3 Jun 2025 22:18:15 +0200 Subject: [PATCH 1/3] messages::serialize: take/put buffers from/into a SerializeBufferPool --- crates/client-api-messages/src/websocket.rs | 26 ++--- crates/client-api/src/lib.rs | 7 ++ crates/client-api/src/routes/subscribe.rs | 106 ++++++++++++++------ crates/core/src/client/client_connection.rs | 14 ++- crates/core/src/client/messages.rs | 86 +++++++++++++--- crates/core/src/host/host_controller.rs | 4 + crates/sats/src/buffer.rs | 8 ++ crates/standalone/src/lib.rs | 4 + 8 files changed, 191 insertions(+), 64 deletions(-) diff --git a/crates/client-api-messages/src/websocket.rs b/crates/client-api-messages/src/websocket.rs index 2292f4b12fe..b2a92e71bb0 100644 --- a/crates/client-api-messages/src/websocket.rs +++ b/crates/client-api-messages/src/websocket.rs @@ -859,25 +859,19 @@ pub fn decide_compression(len: usize, compression: Compression) -> Compression { } } -pub fn brotli_compress(bytes: &[u8], out: &mut Vec) { - let reader = &mut &bytes[..]; - - // The default Brotli buffer size. - const BUFFER_SIZE: usize = 4096; +pub fn brotli_compress(bytes: &[u8], out: &mut impl io::Write) { // We are optimizing for compression speed, // so we choose the lowest (fastest) level of compression. // Experiments on internal workloads have shown compression ratios between 7:1 and 10:1 // for large `SubscriptionUpdate` messages at this level. - const COMPRESSION_LEVEL: u32 = 1; - // The default value for an internal compression parameter. - // See `BrotliEncoderParams` for more details. - const LG_WIN: u32 = 22; + const COMPRESSION_LEVEL: i32 = 1; - let mut encoder = brotli::CompressorReader::new(reader, BUFFER_SIZE, COMPRESSION_LEVEL, LG_WIN); - - encoder - .read_to_end(out) - .expect("Failed to Brotli compress `SubscriptionUpdateMessage`"); + let params = brotli::enc::BrotliEncoderParams { + quality: COMPRESSION_LEVEL, + ..<_>::default() + }; + let reader = &mut &bytes[..]; + brotli::BrotliCompress(reader, out, ¶ms).expect("should be able to BrotliCompress"); } pub fn brotli_decompress(bytes: &[u8]) -> Result, io::Error> { @@ -886,10 +880,10 @@ pub fn brotli_decompress(bytes: &[u8]) -> Result, io::Error> { Ok(decompressed) } -pub fn gzip_compress(bytes: &[u8], out: &mut Vec) { +pub fn gzip_compress(bytes: &[u8], out: &mut impl io::Write) { let mut encoder = flate2::write::GzEncoder::new(out, flate2::Compression::fast()); encoder.write_all(bytes).unwrap(); - encoder.finish().expect("Failed to gzip compress `bytes`"); + encoder.finish().expect("should be able to gzip compress `bytes`"); } pub fn gzip_decompress(bytes: &[u8]) -> Result, io::Error> { diff --git a/crates/client-api/src/lib.rs b/crates/client-api/src/lib.rs index 37b120ed17c..375849bca3d 100644 --- a/crates/client-api/src/lib.rs +++ b/crates/client-api/src/lib.rs @@ -5,6 +5,7 @@ use async_trait::async_trait; use axum::response::ErrorResponse; use http::StatusCode; +use spacetimedb::client::messages::SerializeBufferPool; use spacetimedb::client::ClientActorIndex; use spacetimedb::energy::{EnergyBalance, EnergyQuanta}; use spacetimedb::host::{HostController, ModuleHost, NoSuchModule, UpdateDatabaseResult}; @@ -39,6 +40,8 @@ pub trait NodeDelegate: Send + Sync { /// The [`Host`] is spawned implicitly if not already running. async fn leader(&self, database_id: u64) -> anyhow::Result>; fn module_logs_dir(&self, replica_id: u64) -> ModuleLogsDir; + + fn websocket_send_serialize_buffer_pool(&self) -> &Arc; } /// Client view of a running module. @@ -371,6 +374,10 @@ impl NodeDelegate for Arc { fn module_logs_dir(&self, replica_id: u64) -> ModuleLogsDir { (**self).module_logs_dir(replica_id) } + + fn websocket_send_serialize_buffer_pool(&self) -> &Arc { + (**self).websocket_send_serialize_buffer_pool() + } } pub fn log_and_500(e: impl std::fmt::Display) -> ErrorResponse { diff --git a/crates/client-api/src/routes/subscribe.rs b/crates/client-api/src/routes/subscribe.rs index 8d09eee3c93..933ff6c282d 100644 --- a/crates/client-api/src/routes/subscribe.rs +++ b/crates/client-api/src/routes/subscribe.rs @@ -1,6 +1,7 @@ use std::collections::VecDeque; use std::mem; use std::pin::{pin, Pin}; +use std::sync::Arc; use std::time::Duration; use axum::extract::{Path, Query, State}; @@ -14,12 +15,14 @@ use futures::{Future, FutureExt, SinkExt, StreamExt}; use http::{HeaderValue, StatusCode}; use scopeguard::ScopeGuard; use serde::Deserialize; -use spacetimedb::client::messages::{serialize, IdentityTokenMessage, SerializableMessage}; +use spacetimedb::client::messages::{serialize, IdentityTokenMessage, SerializableMessage, SerializeBufferPool}; use spacetimedb::client::{ClientActorId, ClientConfig, ClientConnection, DataMessage, MessageHandleError, Protocol}; +use spacetimedb::execution_context::WorkloadType; use spacetimedb::host::module_host::ClientConnectedError; use spacetimedb::host::NoSuchModule; use spacetimedb::util::also_poll; use spacetimedb::worker_metrics::WORKER_METRICS; +use spacetimedb::Identity; use spacetimedb_client_api_messages::websocket::{self as ws_api, Compression}; use spacetimedb_lib::connection_id::{ConnectionId, ConnectionIdForUrl}; use std::time::Instant; @@ -125,6 +128,8 @@ where name: ctx.client_actor_index().next_client_name(), }; + let serialize_buffer_pool = ctx.websocket_send_serialize_buffer_pool().clone(); + let ws_config = WebSocketConfig::default() .max_message_size(Some(0x2000000)) .max_frame_size(None) @@ -146,7 +151,7 @@ where None => log::debug!("New client connected from unknown ip"), } - let actor = |client, sendrx| ws_client_actor(client, ws, sendrx); + let actor = |client, sendrx| ws_client_actor(client, ws, sendrx, serialize_buffer_pool); let client = match ClientConnection::spawn(client_id, client_config, leader.replica_id, module_rx, actor).await { Ok(s) => s, @@ -180,13 +185,18 @@ where const LIVELINESS_TIMEOUT: Duration = Duration::from_secs(60); -async fn ws_client_actor(client: ClientConnection, ws: WebSocketStream, sendrx: mpsc::Receiver) { +async fn ws_client_actor( + client: ClientConnection, + ws: WebSocketStream, + sendrx: mpsc::Receiver, + serialize_buffer_pool: Arc, +) { // ensure that even if this task gets cancelled, we always cleanup the connection let mut client = scopeguard::guard(client, |client| { tokio::spawn(client.disconnect()); }); - ws_client_actor_inner(&mut client, ws, sendrx).await; + ws_client_actor_inner(&mut client, ws, sendrx, &serialize_buffer_pool).await; ScopeGuard::into_inner(client).disconnect().await; } @@ -203,6 +213,7 @@ async fn ws_client_actor_inner( client: &mut ClientConnection, mut ws: WebSocketStream, mut sendrx: mpsc::Receiver, + serialize_buffer_pool: &SerializeBufferPool, ) { let mut liveness_check_interval = tokio::time::interval(LIVELINESS_TIMEOUT); let mut got_pong = true; @@ -299,31 +310,31 @@ async fn ws_client_actor_inner( log::info!("dropping {n} messages due to ws already being closed"); log::debug!("dropped messages: {:?}", &rx_buf[..n]); } else { - let send_all = async { - for msg in rx_buf.drain(..n) { - let workload = msg.workload(); - let num_rows = msg.num_rows(); - - let msg = datamsg_to_wsmsg(serialize(msg, client.config)); - - // These metrics should be updated together, - // or not at all. - if let (Some(workload), Some(num_rows)) = (workload, num_rows) { - WORKER_METRICS - .websocket_sent_num_rows - .with_label_values(&addr, &workload) - .observe(num_rows as f64); - WORKER_METRICS - .websocket_sent_msg_size - .with_label_values(&addr, &workload) - .observe(msg.len() as f64); + let send_all = async { + for msg in rx_buf.drain(..n) { + let workload = msg.workload(); + let num_rows = msg.num_rows(); + + // Serialize the message, report metrics, + // and keep a handle to the buffer. + let msg_data = serialize(serialize_buffer_pool, msg, client.config); + report_ws_sent_metrics(&addr, workload, num_rows, &msg_data); + let msg_alloc = msg_data.allocation(); + + // Buffer the message without necessarily sending it. + ws.feed(datamsg_to_wsmsg(msg_data)).await?; + + // At this point, + // the underlying allocation of `msg_data` should have a single referent + // and this should be `msg_alloc`. + // We can put this back into our pool. + let msg_alloc = msg_alloc.try_into_mut() + .expect("should have a unique referent to `msg_alloc`"); + serialize_buffer_pool.put(msg_alloc); } - // feed() buffers the message, but does not necessarily send it - ws.feed(msg).await?; - } - // now we flush all the messages to the socket - ws.flush().await - }; + // now we flush all the messages to the socket + ws.flush().await + }; // Flush the websocket while continuing to poll the `handle_queue`, // to avoid deadlocks or delays due to enqueued futures holding resources. let send_all = also_poll(send_all, make_progress(&mut current_message)); @@ -394,10 +405,24 @@ async fn ws_client_actor_inner( if let Err(e) = res { if let MessageHandleError::Execution(err) = e { log::error!("{err:#}"); - let msg = serialize(err, client.config); - if let Err(error) = ws.send(datamsg_to_wsmsg(msg)).await { + // Serialize the message and keep a handle to the buffer. + let msg_data = serialize(serialize_buffer_pool, err, client.config); + let msg_alloc = msg_data.allocation(); + + // Buffer the message without necessarily sending it. + if let Err(error) = ws.send(datamsg_to_wsmsg(msg_data)).await { log::warn!("Websocket send error: {error}") } + + // At this point, + // the underlying allocation of `msg_data` should have a single referent + // and this should be `msg_alloc`. + // We can put this back into our pool. + let msg_alloc = msg_alloc + .try_into_mut() + .expect("should have a unique referent to `msg_alloc`"); + serialize_buffer_pool.put(msg_alloc); + continue; } log::debug!("Client caused error on text message: {}", e); @@ -461,6 +486,27 @@ impl ClientMessage { } } +/// Report metrics on sent rows and message sizes to a websocket client. +fn report_ws_sent_metrics( + addr: &Identity, + workload: Option, + num_rows: Option, + msg_ws: &DataMessage, +) { + // These metrics should be updated together, + // or not at all. + if let (Some(workload), Some(num_rows)) = (workload, num_rows) { + WORKER_METRICS + .websocket_sent_num_rows + .with_label_values(addr, &workload) + .observe(num_rows as f64); + WORKER_METRICS + .websocket_sent_msg_size + .with_label_values(addr, &workload) + .observe(msg_ws.len() as f64); + } +} + fn datamsg_to_wsmsg(msg: DataMessage) -> WsMessage { match msg { DataMessage::Text(text) => WsMessage::Text(bytestring_to_utf8bytes(text)), diff --git a/crates/core/src/client/client_connection.rs b/crates/core/src/client/client_connection.rs index 894951c04a3..56bcb13f318 100644 --- a/crates/core/src/client/client_connection.rs +++ b/crates/core/src/client/client_connection.rs @@ -234,17 +234,27 @@ impl From> for DataMessage { } impl DataMessage { + /// Returns the number of bytes this message consists of. pub fn len(&self) -> usize { match self { - DataMessage::Text(s) => s.len(), - DataMessage::Binary(b) => b.len(), + Self::Text(s) => s.len(), + Self::Binary(b) => b.len(), } } + /// Is the message empty? #[must_use] pub fn is_empty(&self) -> bool { self.len() == 0 } + + /// Returns a handle to the underlying allocation of the message without consuming it. + pub fn allocation(&self) -> Bytes { + match self { + DataMessage::Text(alloc) => alloc.as_bytes().clone(), + DataMessage::Binary(alloc) => alloc.clone(), + } + } } // if a client racks up this many messages in the queue without ACK'ing diff --git a/crates/core/src/client/messages.rs b/crates/core/src/client/messages.rs index cdd0e3c1c94..4eb490c94ae 100644 --- a/crates/core/src/client/messages.rs +++ b/crates/core/src/client/messages.rs @@ -3,6 +3,9 @@ use crate::execution_context::WorkloadType; use crate::host::module_host::{EventStatus, ModuleEvent}; use crate::host::ArgsTuple; use crate::messages::websocket as ws; +use bytes::{BufMut, BytesMut}; +use bytestring::ByteString; +use crossbeam_queue::SegQueue; use derive_more::From; use spacetimedb_client_api_messages::websocket::{ BsatnFormat, Compression, FormatSwitch, JsonFormat, OneOffTable, RowListLen, WebsocketFormat, @@ -27,36 +30,87 @@ pub trait ToProtocol { pub(super) type SwitchedServerMessage = FormatSwitch, ws::ServerMessage>; pub(super) type SwitchedDbUpdate = FormatSwitch, ws::DatabaseUpdate>; +/// The initial size of a `serialize` buffer. +/// Currently 4k to align with the linux page size +/// and this should be more than enough in the common case. +const SERIALIZE_BUFFER_INIT_CAP: usize = 4096; + +// A pool of buffers used by [`serialize`]. +#[derive(Default)] +pub struct SerializeBufferPool { + pool: SegQueue, +} + +impl SerializeBufferPool { + /// Puts back a buffer into the pool. + pub fn put(&self, buf: BytesMut) { + self.pool.push(buf); + } + + /// Returns a buffer from the pool or creates a new one. + fn take(&self) -> BytesMut { + match self.pool.pop() { + Some(mut buf) => { + buf.clear(); + buf + } + None => BytesMut::with_capacity(SERIALIZE_BUFFER_INIT_CAP), + } + } + + /// Returns a buffer and inserts a leading `tag`. + fn take_with_tag(&self, tag: u8) -> BytesMut { + let mut buf = self.take(); + buf.put_u8(tag); + buf + } + + /// Returns a buffer, inserts a leading `tag`, and then runs `write`. + fn take_with_tag_and_writer(&self, tag: u8, write: impl FnOnce(&mut bytes::buf::Writer)) -> BytesMut { + let mut writer = self.take_with_tag(tag).writer(); + write(&mut writer); + writer.into_inner() + } +} + /// Serialize `msg` into a [`DataMessage`] containing a [`ws::ServerMessage`]. /// /// If `protocol` is [`Protocol::Binary`], /// the message will be conditionally compressed by this method according to `compression`. -pub fn serialize(msg: impl ToProtocol, config: ClientConfig) -> DataMessage { - // TODO(centril, perf): here we are allocating buffers only to throw them away eventually. - // Consider pooling these allocations so that we reuse them. +pub fn serialize( + pool: &SerializeBufferPool, + msg: impl ToProtocol, + config: ClientConfig, +) -> DataMessage { match msg.to_protocol(config.protocol) { - FormatSwitch::Json(msg) => serde_json::to_string(&SerializeWrapper::new(msg)).unwrap().into(), + FormatSwitch::Json(msg) => { + let mut out: bytes::buf::Writer = pool.take().writer(); + serde_json::to_writer(&mut out, &SerializeWrapper::new(msg)) + .expect("should be able to json encode a `ServerMessage`"); + let out = out.into_inner(); + + let out = out.freeze(); + // SAFETY: `serde_json::to_writer` states that: + // > "Serialization guarantees it only feeds valid UTF-8 sequences to the writer." + let msg_json = unsafe { ByteString::from_bytes_unchecked(out) }; + msg_json.into() + } FormatSwitch::Bsatn(msg) => { // First write the tag so that we avoid shifting the entire message at the end. - let mut msg_bytes = vec![SERVER_MSG_COMPRESSION_TAG_NONE]; + let mut msg_bytes = pool.take_with_tag(SERVER_MSG_COMPRESSION_TAG_NONE); bsatn::to_writer(&mut msg_bytes, &msg).unwrap(); // Conditionally compress the message. let srv_msg = &msg_bytes[1..]; let msg_bytes = match ws::decide_compression(srv_msg.len(), config.compression) { Compression::None => msg_bytes, - Compression::Brotli => { - let mut out = vec![SERVER_MSG_COMPRESSION_TAG_BROTLI]; - ws::brotli_compress(srv_msg, &mut out); - out - } - Compression::Gzip => { - let mut out = vec![SERVER_MSG_COMPRESSION_TAG_GZIP]; - ws::gzip_compress(srv_msg, &mut out); - out - } + Compression::Brotli => pool.take_with_tag_and_writer(SERVER_MSG_COMPRESSION_TAG_BROTLI, |out| { + ws::brotli_compress(srv_msg, out) + }), + Compression::Gzip => pool + .take_with_tag_and_writer(SERVER_MSG_COMPRESSION_TAG_GZIP, |out| ws::gzip_compress(srv_msg, out)), }; - msg_bytes.into() + msg_bytes.freeze().into() } } } diff --git a/crates/core/src/host/host_controller.rs b/crates/core/src/host/host_controller.rs index a3231df6dfd..26eb548188a 100644 --- a/crates/core/src/host/host_controller.rs +++ b/crates/core/src/host/host_controller.rs @@ -2,6 +2,7 @@ use super::module_host::{EventStatus, ModuleHost, ModuleInfo, NoSuchModule}; use super::scheduler::SchedulerStarter; use super::wasmtime::WasmtimeRuntime; use super::{Scheduler, UpdateDatabaseResult}; +use crate::client::messages::SerializeBufferPool; use crate::database_logger::DatabaseLogger; use crate::db::datastore::traits::Program; use crate::db::db_metrics::data_size::DATA_SIZE_METRICS; @@ -94,6 +95,8 @@ pub struct HostController { durability: Arc, /// The page pool all databases will use by cloning the ref counted pool. pub page_pool: PagePool, + // The buffer pool used for serializing websocket messages to send. + pub websocket_send_serialize_buffer_pool: Arc, /// The runtimes for running our modules. runtimes: Arc, /// The CPU cores that are reserved for ModuleHost operations to run on. @@ -184,6 +187,7 @@ impl HostController { data_dir, page_pool: PagePool::new(default_config.page_pool_max_size), db_cores, + websocket_send_serialize_buffer_pool: <_>::default(), } } diff --git a/crates/sats/src/buffer.rs b/crates/sats/src/buffer.rs index c6a81ffed40..c3762282c7f 100644 --- a/crates/sats/src/buffer.rs +++ b/crates/sats/src/buffer.rs @@ -2,6 +2,8 @@ //! without relying on types in third party libraries like `bytes::Bytes`, etc. //! Meant to be kept slim and trim for use across both native and WASM. +use bytes::{BufMut, BytesMut}; + use crate::{i256, u256}; use core::cell::Cell; use core::fmt; @@ -309,6 +311,12 @@ impl BufWriter for &mut [u8] { } } +impl BufWriter for BytesMut { + fn put_slice(&mut self, slice: &[u8]) { + BufMut::put_slice(self, slice); + } +} + /// A [`BufWriter`] that only counts the bytes. #[derive(Default)] pub struct CountWriter { diff --git a/crates/standalone/src/lib.rs b/crates/standalone/src/lib.rs index 6021694c9a7..bb4ae12eca2 100644 --- a/crates/standalone/src/lib.rs +++ b/crates/standalone/src/lib.rs @@ -8,6 +8,7 @@ use crate::subcommands::{extract_schema, start}; use anyhow::{ensure, Context, Ok}; use async_trait::async_trait; use clap::{ArgMatches, Command}; +use spacetimedb::client::messages::SerializeBufferPool; use spacetimedb::client::ClientActorIndex; use spacetimedb::config::{CertificateAuthority, MetadataFile}; use spacetimedb::db::datastore::traits::Program; @@ -161,6 +162,9 @@ impl NodeDelegate for StandaloneEnv { fn module_logs_dir(&self, replica_id: u64) -> ModuleLogsDir { self.data_dir().replica(replica_id).module_logs() } + fn websocket_send_serialize_buffer_pool(&self) -> &Arc { + &self.host_controller.websocket_send_serialize_buffer_pool + } } impl spacetimedb_client_api::ControlStateReadAccess for StandaloneEnv { From 42dc169d76457346ebf10f4f6b77521ba2301f5b Mon Sep 17 00:00:00 2001 From: Noa Date: Thu, 5 Jun 2025 14:17:11 -0500 Subject: [PATCH 2/3] Local buffers --- Cargo.lock | 1 - crates/client-api/src/lib.rs | 7 -- crates/client-api/src/routes/subscribe.rs | 44 +++---- crates/core/Cargo.toml | 1 - crates/core/src/client/messages.rs | 137 ++++++++++++++-------- crates/core/src/host/host_controller.rs | 4 - crates/standalone/src/lib.rs | 4 - 7 files changed, 109 insertions(+), 89 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c9bd4e05985..fae19ae9257 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5470,7 +5470,6 @@ dependencies = [ "core_affinity", "criterion", "crossbeam-channel", - "crossbeam-queue", "derive_more", "dirs", "enum-as-inner", diff --git a/crates/client-api/src/lib.rs b/crates/client-api/src/lib.rs index 375849bca3d..37b120ed17c 100644 --- a/crates/client-api/src/lib.rs +++ b/crates/client-api/src/lib.rs @@ -5,7 +5,6 @@ use async_trait::async_trait; use axum::response::ErrorResponse; use http::StatusCode; -use spacetimedb::client::messages::SerializeBufferPool; use spacetimedb::client::ClientActorIndex; use spacetimedb::energy::{EnergyBalance, EnergyQuanta}; use spacetimedb::host::{HostController, ModuleHost, NoSuchModule, UpdateDatabaseResult}; @@ -40,8 +39,6 @@ pub trait NodeDelegate: Send + Sync { /// The [`Host`] is spawned implicitly if not already running. async fn leader(&self, database_id: u64) -> anyhow::Result>; fn module_logs_dir(&self, replica_id: u64) -> ModuleLogsDir; - - fn websocket_send_serialize_buffer_pool(&self) -> &Arc; } /// Client view of a running module. @@ -374,10 +371,6 @@ impl NodeDelegate for Arc { fn module_logs_dir(&self, replica_id: u64) -> ModuleLogsDir { (**self).module_logs_dir(replica_id) } - - fn websocket_send_serialize_buffer_pool(&self) -> &Arc { - (**self).websocket_send_serialize_buffer_pool() - } } pub fn log_and_500(e: impl std::fmt::Display) -> ErrorResponse { diff --git a/crates/client-api/src/routes/subscribe.rs b/crates/client-api/src/routes/subscribe.rs index 933ff6c282d..e7795c69046 100644 --- a/crates/client-api/src/routes/subscribe.rs +++ b/crates/client-api/src/routes/subscribe.rs @@ -1,7 +1,6 @@ use std::collections::VecDeque; use std::mem; use std::pin::{pin, Pin}; -use std::sync::Arc; use std::time::Duration; use axum::extract::{Path, Query, State}; @@ -15,7 +14,7 @@ use futures::{Future, FutureExt, SinkExt, StreamExt}; use http::{HeaderValue, StatusCode}; use scopeguard::ScopeGuard; use serde::Deserialize; -use spacetimedb::client::messages::{serialize, IdentityTokenMessage, SerializableMessage, SerializeBufferPool}; +use spacetimedb::client::messages::{serialize, IdentityTokenMessage, SerializableMessage, SerializeBuffer}; use spacetimedb::client::{ClientActorId, ClientConfig, ClientConnection, DataMessage, MessageHandleError, Protocol}; use spacetimedb::execution_context::WorkloadType; use spacetimedb::host::module_host::ClientConnectedError; @@ -128,8 +127,6 @@ where name: ctx.client_actor_index().next_client_name(), }; - let serialize_buffer_pool = ctx.websocket_send_serialize_buffer_pool().clone(); - let ws_config = WebSocketConfig::default() .max_message_size(Some(0x2000000)) .max_frame_size(None) @@ -151,7 +148,7 @@ where None => log::debug!("New client connected from unknown ip"), } - let actor = |client, sendrx| ws_client_actor(client, ws, sendrx, serialize_buffer_pool); + let actor = |client, sendrx| ws_client_actor(client, ws, sendrx); let client = match ClientConnection::spawn(client_id, client_config, leader.replica_id, module_rx, actor).await { Ok(s) => s, @@ -185,18 +182,13 @@ where const LIVELINESS_TIMEOUT: Duration = Duration::from_secs(60); -async fn ws_client_actor( - client: ClientConnection, - ws: WebSocketStream, - sendrx: mpsc::Receiver, - serialize_buffer_pool: Arc, -) { +async fn ws_client_actor(client: ClientConnection, ws: WebSocketStream, sendrx: mpsc::Receiver) { // ensure that even if this task gets cancelled, we always cleanup the connection let mut client = scopeguard::guard(client, |client| { tokio::spawn(client.disconnect()); }); - ws_client_actor_inner(&mut client, ws, sendrx, &serialize_buffer_pool).await; + ws_client_actor_inner(&mut client, ws, sendrx).await; ScopeGuard::into_inner(client).disconnect().await; } @@ -213,7 +205,6 @@ async fn ws_client_actor_inner( client: &mut ClientConnection, mut ws: WebSocketStream, mut sendrx: mpsc::Receiver, - serialize_buffer_pool: &SerializeBufferPool, ) { let mut liveness_check_interval = tokio::time::interval(LIVELINESS_TIMEOUT); let mut got_pong = true; @@ -257,6 +248,7 @@ async fn ws_client_actor_inner( outgoing_queue_length_metric.sub(sendrx.len() as _); }; + let mut msg_buffer = SerializeBuffer::new(client.config); loop { rx_buf.clear(); enum Item { @@ -317,29 +309,33 @@ async fn ws_client_actor_inner( // Serialize the message, report metrics, // and keep a handle to the buffer. - let msg_data = serialize(serialize_buffer_pool, msg, client.config); + let (msg_alloc, msg_data) = serialize(msg_buffer, msg, client.config); report_ws_sent_metrics(&addr, workload, num_rows, &msg_data); - let msg_alloc = msg_data.allocation(); // Buffer the message without necessarily sending it. - ws.feed(datamsg_to_wsmsg(msg_data)).await?; + let res = ws.feed(datamsg_to_wsmsg(msg_data)).await; // At this point, // the underlying allocation of `msg_data` should have a single referent // and this should be `msg_alloc`. // We can put this back into our pool. - let msg_alloc = msg_alloc.try_into_mut() + msg_buffer = msg_alloc.try_reclaim() .expect("should have a unique referent to `msg_alloc`"); - serialize_buffer_pool.put(msg_alloc); + + if res.is_err() { + return (res, msg_buffer); + } } // now we flush all the messages to the socket - ws.flush().await + (ws.flush().await, msg_buffer) }; // Flush the websocket while continuing to poll the `handle_queue`, // to avoid deadlocks or delays due to enqueued futures holding resources. let send_all = also_poll(send_all, make_progress(&mut current_message)); let t1 = Instant::now(); - if let Err(error) = send_all.await { + let (send_all_result, buf) = send_all.await; + msg_buffer = buf; + if let Err(error) = send_all_result { log::warn!("Websocket send error: {error}") } let time = t1.elapsed(); @@ -406,8 +402,7 @@ async fn ws_client_actor_inner( if let MessageHandleError::Execution(err) = e { log::error!("{err:#}"); // Serialize the message and keep a handle to the buffer. - let msg_data = serialize(serialize_buffer_pool, err, client.config); - let msg_alloc = msg_data.allocation(); + let (msg_alloc, msg_data) = serialize(msg_buffer, err, client.config); // Buffer the message without necessarily sending it. if let Err(error) = ws.send(datamsg_to_wsmsg(msg_data)).await { @@ -418,10 +413,9 @@ async fn ws_client_actor_inner( // the underlying allocation of `msg_data` should have a single referent // and this should be `msg_alloc`. // We can put this back into our pool. - let msg_alloc = msg_alloc - .try_into_mut() + msg_buffer = msg_alloc + .try_reclaim() .expect("should have a unique referent to `msg_alloc`"); - serialize_buffer_pool.put(msg_alloc); continue; } diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 5599fb11615..d3c5ecfe293 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -48,7 +48,6 @@ bytes.workspace = true bytestring.workspace = true chrono.workspace = true crossbeam-channel.workspace = true -crossbeam-queue.workspace = true derive_more.workspace = true dirs.workspace = true enum-as-inner.workspace = true diff --git a/crates/core/src/client/messages.rs b/crates/core/src/client/messages.rs index 4eb490c94ae..fa801a17bc3 100644 --- a/crates/core/src/client/messages.rs +++ b/crates/core/src/client/messages.rs @@ -3,9 +3,8 @@ use crate::execution_context::WorkloadType; use crate::host::module_host::{EventStatus, ModuleEvent}; use crate::host::ArgsTuple; use crate::messages::websocket as ws; -use bytes::{BufMut, BytesMut}; +use bytes::{BufMut, Bytes, BytesMut}; use bytestring::ByteString; -use crossbeam_queue::SegQueue; use derive_more::From; use spacetimedb_client_api_messages::websocket::{ BsatnFormat, Compression, FormatSwitch, JsonFormat, OneOffTable, RowListLen, WebsocketFormat, @@ -35,41 +34,89 @@ pub(super) type SwitchedDbUpdate = FormatSwitch, /// and this should be more than enough in the common case. const SERIALIZE_BUFFER_INIT_CAP: usize = 4096; -// A pool of buffers used by [`serialize`]. -#[derive(Default)] -pub struct SerializeBufferPool { - pool: SegQueue, +/// A buffer used by [`serialize`] +pub struct SerializeBuffer { + uncompressed: BytesMut, + compressed: BytesMut, } -impl SerializeBufferPool { - /// Puts back a buffer into the pool. - pub fn put(&self, buf: BytesMut) { - self.pool.push(buf); +impl SerializeBuffer { + pub fn new(config: ClientConfig) -> Self { + let uncompressed_capacity = SERIALIZE_BUFFER_INIT_CAP; + let compressed_capacity = if config.compression == Compression::None || config.protocol == Protocol::Text { + 0 + } else { + SERIALIZE_BUFFER_INIT_CAP + }; + Self { + uncompressed: BytesMut::with_capacity(uncompressed_capacity), + compressed: BytesMut::with_capacity(compressed_capacity), + } } - /// Returns a buffer from the pool or creates a new one. - fn take(&self) -> BytesMut { - match self.pool.pop() { - Some(mut buf) => { - buf.clear(); - buf - } - None => BytesMut::with_capacity(SERIALIZE_BUFFER_INIT_CAP), - } + /// Take the uncompressed message as the one to use. + fn uncompressed(self) -> (InUseSerializeBuffer, Bytes) { + let uncompressed = self.uncompressed.freeze(); + let in_use = InUseSerializeBuffer::Uncompressed { + uncompressed: uncompressed.clone(), + compressed: self.compressed, + }; + (in_use, uncompressed) + } + + /// Write uncompressed data with a leading tag. + fn write_with_tag(&mut self, tag: u8, write: F) -> &[u8] + where + F: FnOnce(bytes::buf::Writer<&mut BytesMut>), + { + self.uncompressed.put_u8(tag); + write((&mut self.uncompressed).writer()); + &self.uncompressed[1..] } - /// Returns a buffer and inserts a leading `tag`. - fn take_with_tag(&self, tag: u8) -> BytesMut { - let mut buf = self.take(); - buf.put_u8(tag); - buf + /// Compress the data from a `write_with_tag` call, and change the tag. + fn compress_with_tag( + self, + tag: u8, + write: impl FnOnce(&[u8], &mut bytes::buf::Writer), + ) -> (InUseSerializeBuffer, Bytes) { + let mut writer = self.compressed.writer(); + writer.get_mut().put_u8(tag); + write(&self.uncompressed[1..], &mut writer); + let compressed = writer.into_inner().freeze(); + let in_use = InUseSerializeBuffer::Compressed { + uncompressed: self.uncompressed, + compressed: compressed.clone(), + }; + (in_use, compressed) } +} + +type BytesMutWriter<'a> = bytes::buf::Writer<&'a mut BytesMut>; + +pub enum InUseSerializeBuffer { + Uncompressed { uncompressed: Bytes, compressed: BytesMut }, + Compressed { uncompressed: BytesMut, compressed: Bytes }, +} - /// Returns a buffer, inserts a leading `tag`, and then runs `write`. - fn take_with_tag_and_writer(&self, tag: u8, write: impl FnOnce(&mut bytes::buf::Writer)) -> BytesMut { - let mut writer = self.take_with_tag(tag).writer(); - write(&mut writer); - writer.into_inner() +impl InUseSerializeBuffer { + pub fn try_reclaim(self) -> Option { + let (mut uncompressed, mut compressed) = match self { + Self::Uncompressed { + uncompressed, + compressed, + } => (uncompressed.try_into_mut().ok()?, compressed), + Self::Compressed { + uncompressed, + compressed, + } => (uncompressed, compressed.try_into_mut().ok()?), + }; + uncompressed.clear(); + compressed.clear(); + Some(SerializeBuffer { + uncompressed, + compressed, + }) } } @@ -78,39 +125,35 @@ impl SerializeBufferPool { /// If `protocol` is [`Protocol::Binary`], /// the message will be conditionally compressed by this method according to `compression`. pub fn serialize( - pool: &SerializeBufferPool, + mut buffer: SerializeBuffer, msg: impl ToProtocol, config: ClientConfig, -) -> DataMessage { +) -> (InUseSerializeBuffer, DataMessage) { match msg.to_protocol(config.protocol) { FormatSwitch::Json(msg) => { - let mut out: bytes::buf::Writer = pool.take().writer(); - serde_json::to_writer(&mut out, &SerializeWrapper::new(msg)) + let out: BytesMutWriter<'_> = (&mut buffer.uncompressed).writer(); + serde_json::to_writer(out, &SerializeWrapper::new(msg)) .expect("should be able to json encode a `ServerMessage`"); - let out = out.into_inner(); - let out = out.freeze(); + let (in_use, out) = buffer.uncompressed(); // SAFETY: `serde_json::to_writer` states that: // > "Serialization guarantees it only feeds valid UTF-8 sequences to the writer." let msg_json = unsafe { ByteString::from_bytes_unchecked(out) }; - msg_json.into() + (in_use, msg_json.into()) } FormatSwitch::Bsatn(msg) => { // First write the tag so that we avoid shifting the entire message at the end. - let mut msg_bytes = pool.take_with_tag(SERVER_MSG_COMPRESSION_TAG_NONE); - bsatn::to_writer(&mut msg_bytes, &msg).unwrap(); + let srv_msg = buffer.write_with_tag(SERVER_MSG_COMPRESSION_TAG_NONE, |w| { + bsatn::to_writer(w.into_inner(), &msg).unwrap() + }); // Conditionally compress the message. - let srv_msg = &msg_bytes[1..]; - let msg_bytes = match ws::decide_compression(srv_msg.len(), config.compression) { - Compression::None => msg_bytes, - Compression::Brotli => pool.take_with_tag_and_writer(SERVER_MSG_COMPRESSION_TAG_BROTLI, |out| { - ws::brotli_compress(srv_msg, out) - }), - Compression::Gzip => pool - .take_with_tag_and_writer(SERVER_MSG_COMPRESSION_TAG_GZIP, |out| ws::gzip_compress(srv_msg, out)), + let (in_use, msg_bytes) = match ws::decide_compression(srv_msg.len(), config.compression) { + Compression::None => buffer.uncompressed(), + Compression::Brotli => buffer.compress_with_tag(SERVER_MSG_COMPRESSION_TAG_BROTLI, ws::brotli_compress), + Compression::Gzip => buffer.compress_with_tag(SERVER_MSG_COMPRESSION_TAG_GZIP, ws::gzip_compress), }; - msg_bytes.freeze().into() + (in_use, msg_bytes.into()) } } } diff --git a/crates/core/src/host/host_controller.rs b/crates/core/src/host/host_controller.rs index 26eb548188a..a3231df6dfd 100644 --- a/crates/core/src/host/host_controller.rs +++ b/crates/core/src/host/host_controller.rs @@ -2,7 +2,6 @@ use super::module_host::{EventStatus, ModuleHost, ModuleInfo, NoSuchModule}; use super::scheduler::SchedulerStarter; use super::wasmtime::WasmtimeRuntime; use super::{Scheduler, UpdateDatabaseResult}; -use crate::client::messages::SerializeBufferPool; use crate::database_logger::DatabaseLogger; use crate::db::datastore::traits::Program; use crate::db::db_metrics::data_size::DATA_SIZE_METRICS; @@ -95,8 +94,6 @@ pub struct HostController { durability: Arc, /// The page pool all databases will use by cloning the ref counted pool. pub page_pool: PagePool, - // The buffer pool used for serializing websocket messages to send. - pub websocket_send_serialize_buffer_pool: Arc, /// The runtimes for running our modules. runtimes: Arc, /// The CPU cores that are reserved for ModuleHost operations to run on. @@ -187,7 +184,6 @@ impl HostController { data_dir, page_pool: PagePool::new(default_config.page_pool_max_size), db_cores, - websocket_send_serialize_buffer_pool: <_>::default(), } } diff --git a/crates/standalone/src/lib.rs b/crates/standalone/src/lib.rs index bb4ae12eca2..6021694c9a7 100644 --- a/crates/standalone/src/lib.rs +++ b/crates/standalone/src/lib.rs @@ -8,7 +8,6 @@ use crate::subcommands::{extract_schema, start}; use anyhow::{ensure, Context, Ok}; use async_trait::async_trait; use clap::{ArgMatches, Command}; -use spacetimedb::client::messages::SerializeBufferPool; use spacetimedb::client::ClientActorIndex; use spacetimedb::config::{CertificateAuthority, MetadataFile}; use spacetimedb::db::datastore::traits::Program; @@ -162,9 +161,6 @@ impl NodeDelegate for StandaloneEnv { fn module_logs_dir(&self, replica_id: u64) -> ModuleLogsDir { self.data_dir().replica(replica_id).module_logs() } - fn websocket_send_serialize_buffer_pool(&self) -> &Arc { - &self.host_controller.websocket_send_serialize_buffer_pool - } } impl spacetimedb_client_api::ControlStateReadAccess for StandaloneEnv { From 2131c22bcb86a24f6938febece1c834a3bd26791 Mon Sep 17 00:00:00 2001 From: Phoebe Goldman Date: Mon, 16 Jun 2025 11:46:39 -0400 Subject: [PATCH 3/3] Add `core` dep on `crossbeam-queue` Not sure how this got missed, as it's used in `core/src/startup.rs`. Perhaps a bad merge/rebase? --- Cargo.lock | 1 + crates/core/Cargo.toml | 1 + 2 files changed, 2 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index fae19ae9257..c9bd4e05985 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5470,6 +5470,7 @@ dependencies = [ "core_affinity", "criterion", "crossbeam-channel", + "crossbeam-queue", "derive_more", "dirs", "enum-as-inner", diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index d3c5ecfe293..5599fb11615 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -48,6 +48,7 @@ bytes.workspace = true bytestring.workspace = true chrono.workspace = true crossbeam-channel.workspace = true +crossbeam-queue.workspace = true derive_more.workspace = true dirs.workspace = true enum-as-inner.workspace = true