diff --git a/crates/client-api-messages/src/websocket.rs b/crates/client-api-messages/src/websocket.rs index 2292f4b12fe..b2a92e71bb0 100644 --- a/crates/client-api-messages/src/websocket.rs +++ b/crates/client-api-messages/src/websocket.rs @@ -859,25 +859,19 @@ pub fn decide_compression(len: usize, compression: Compression) -> Compression { } } -pub fn brotli_compress(bytes: &[u8], out: &mut Vec) { - let reader = &mut &bytes[..]; - - // The default Brotli buffer size. - const BUFFER_SIZE: usize = 4096; +pub fn brotli_compress(bytes: &[u8], out: &mut impl io::Write) { // We are optimizing for compression speed, // so we choose the lowest (fastest) level of compression. // Experiments on internal workloads have shown compression ratios between 7:1 and 10:1 // for large `SubscriptionUpdate` messages at this level. - const COMPRESSION_LEVEL: u32 = 1; - // The default value for an internal compression parameter. - // See `BrotliEncoderParams` for more details. - const LG_WIN: u32 = 22; + const COMPRESSION_LEVEL: i32 = 1; - let mut encoder = brotli::CompressorReader::new(reader, BUFFER_SIZE, COMPRESSION_LEVEL, LG_WIN); - - encoder - .read_to_end(out) - .expect("Failed to Brotli compress `SubscriptionUpdateMessage`"); + let params = brotli::enc::BrotliEncoderParams { + quality: COMPRESSION_LEVEL, + ..<_>::default() + }; + let reader = &mut &bytes[..]; + brotli::BrotliCompress(reader, out, ¶ms).expect("should be able to BrotliCompress"); } pub fn brotli_decompress(bytes: &[u8]) -> Result, io::Error> { @@ -886,10 +880,10 @@ pub fn brotli_decompress(bytes: &[u8]) -> Result, io::Error> { Ok(decompressed) } -pub fn gzip_compress(bytes: &[u8], out: &mut Vec) { +pub fn gzip_compress(bytes: &[u8], out: &mut impl io::Write) { let mut encoder = flate2::write::GzEncoder::new(out, flate2::Compression::fast()); encoder.write_all(bytes).unwrap(); - encoder.finish().expect("Failed to gzip compress `bytes`"); + encoder.finish().expect("should be able to gzip compress `bytes`"); } pub fn gzip_decompress(bytes: &[u8]) -> Result, io::Error> { diff --git a/crates/client-api/src/routes/subscribe.rs b/crates/client-api/src/routes/subscribe.rs index 8d09eee3c93..e7795c69046 100644 --- a/crates/client-api/src/routes/subscribe.rs +++ b/crates/client-api/src/routes/subscribe.rs @@ -14,12 +14,14 @@ use futures::{Future, FutureExt, SinkExt, StreamExt}; use http::{HeaderValue, StatusCode}; use scopeguard::ScopeGuard; use serde::Deserialize; -use spacetimedb::client::messages::{serialize, IdentityTokenMessage, SerializableMessage}; +use spacetimedb::client::messages::{serialize, IdentityTokenMessage, SerializableMessage, SerializeBuffer}; use spacetimedb::client::{ClientActorId, ClientConfig, ClientConnection, DataMessage, MessageHandleError, Protocol}; +use spacetimedb::execution_context::WorkloadType; use spacetimedb::host::module_host::ClientConnectedError; use spacetimedb::host::NoSuchModule; use spacetimedb::util::also_poll; use spacetimedb::worker_metrics::WORKER_METRICS; +use spacetimedb::Identity; use spacetimedb_client_api_messages::websocket::{self as ws_api, Compression}; use spacetimedb_lib::connection_id::{ConnectionId, ConnectionIdForUrl}; use std::time::Instant; @@ -246,6 +248,7 @@ async fn ws_client_actor_inner( outgoing_queue_length_metric.sub(sendrx.len() as _); }; + let mut msg_buffer = SerializeBuffer::new(client.config); loop { rx_buf.clear(); enum Item { @@ -299,36 +302,40 @@ async fn ws_client_actor_inner( log::info!("dropping {n} messages due to ws already being closed"); log::debug!("dropped messages: {:?}", &rx_buf[..n]); } else { - let send_all = async { - for msg in rx_buf.drain(..n) { - let workload = msg.workload(); - let num_rows = msg.num_rows(); - - let msg = datamsg_to_wsmsg(serialize(msg, client.config)); - - // These metrics should be updated together, - // or not at all. - if let (Some(workload), Some(num_rows)) = (workload, num_rows) { - WORKER_METRICS - .websocket_sent_num_rows - .with_label_values(&addr, &workload) - .observe(num_rows as f64); - WORKER_METRICS - .websocket_sent_msg_size - .with_label_values(&addr, &workload) - .observe(msg.len() as f64); + let send_all = async { + for msg in rx_buf.drain(..n) { + let workload = msg.workload(); + let num_rows = msg.num_rows(); + + // Serialize the message, report metrics, + // and keep a handle to the buffer. + let (msg_alloc, msg_data) = serialize(msg_buffer, msg, client.config); + report_ws_sent_metrics(&addr, workload, num_rows, &msg_data); + + // Buffer the message without necessarily sending it. + let res = ws.feed(datamsg_to_wsmsg(msg_data)).await; + + // At this point, + // the underlying allocation of `msg_data` should have a single referent + // and this should be `msg_alloc`. + // We can put this back into our pool. + msg_buffer = msg_alloc.try_reclaim() + .expect("should have a unique referent to `msg_alloc`"); + + if res.is_err() { + return (res, msg_buffer); + } } - // feed() buffers the message, but does not necessarily send it - ws.feed(msg).await?; - } - // now we flush all the messages to the socket - ws.flush().await - }; + // now we flush all the messages to the socket + (ws.flush().await, msg_buffer) + }; // Flush the websocket while continuing to poll the `handle_queue`, // to avoid deadlocks or delays due to enqueued futures holding resources. let send_all = also_poll(send_all, make_progress(&mut current_message)); let t1 = Instant::now(); - if let Err(error) = send_all.await { + let (send_all_result, buf) = send_all.await; + msg_buffer = buf; + if let Err(error) = send_all_result { log::warn!("Websocket send error: {error}") } let time = t1.elapsed(); @@ -394,10 +401,22 @@ async fn ws_client_actor_inner( if let Err(e) = res { if let MessageHandleError::Execution(err) = e { log::error!("{err:#}"); - let msg = serialize(err, client.config); - if let Err(error) = ws.send(datamsg_to_wsmsg(msg)).await { + // Serialize the message and keep a handle to the buffer. + let (msg_alloc, msg_data) = serialize(msg_buffer, err, client.config); + + // Buffer the message without necessarily sending it. + if let Err(error) = ws.send(datamsg_to_wsmsg(msg_data)).await { log::warn!("Websocket send error: {error}") } + + // At this point, + // the underlying allocation of `msg_data` should have a single referent + // and this should be `msg_alloc`. + // We can put this back into our pool. + msg_buffer = msg_alloc + .try_reclaim() + .expect("should have a unique referent to `msg_alloc`"); + continue; } log::debug!("Client caused error on text message: {}", e); @@ -461,6 +480,27 @@ impl ClientMessage { } } +/// Report metrics on sent rows and message sizes to a websocket client. +fn report_ws_sent_metrics( + addr: &Identity, + workload: Option, + num_rows: Option, + msg_ws: &DataMessage, +) { + // These metrics should be updated together, + // or not at all. + if let (Some(workload), Some(num_rows)) = (workload, num_rows) { + WORKER_METRICS + .websocket_sent_num_rows + .with_label_values(addr, &workload) + .observe(num_rows as f64); + WORKER_METRICS + .websocket_sent_msg_size + .with_label_values(addr, &workload) + .observe(msg_ws.len() as f64); + } +} + fn datamsg_to_wsmsg(msg: DataMessage) -> WsMessage { match msg { DataMessage::Text(text) => WsMessage::Text(bytestring_to_utf8bytes(text)), diff --git a/crates/core/src/client/client_connection.rs b/crates/core/src/client/client_connection.rs index ad107d0e0bc..941064413f5 100644 --- a/crates/core/src/client/client_connection.rs +++ b/crates/core/src/client/client_connection.rs @@ -234,17 +234,27 @@ impl From> for DataMessage { } impl DataMessage { + /// Returns the number of bytes this message consists of. pub fn len(&self) -> usize { match self { - DataMessage::Text(s) => s.len(), - DataMessage::Binary(b) => b.len(), + Self::Text(s) => s.len(), + Self::Binary(b) => b.len(), } } + /// Is the message empty? #[must_use] pub fn is_empty(&self) -> bool { self.len() == 0 } + + /// Returns a handle to the underlying allocation of the message without consuming it. + pub fn allocation(&self) -> Bytes { + match self { + DataMessage::Text(alloc) => alloc.as_bytes().clone(), + DataMessage::Binary(alloc) => alloc.clone(), + } + } } // if a client racks up this many messages in the queue without ACK'ing diff --git a/crates/core/src/client/messages.rs b/crates/core/src/client/messages.rs index cdd0e3c1c94..fa801a17bc3 100644 --- a/crates/core/src/client/messages.rs +++ b/crates/core/src/client/messages.rs @@ -3,6 +3,8 @@ use crate::execution_context::WorkloadType; use crate::host::module_host::{EventStatus, ModuleEvent}; use crate::host::ArgsTuple; use crate::messages::websocket as ws; +use bytes::{BufMut, Bytes, BytesMut}; +use bytestring::ByteString; use derive_more::From; use spacetimedb_client_api_messages::websocket::{ BsatnFormat, Compression, FormatSwitch, JsonFormat, OneOffTable, RowListLen, WebsocketFormat, @@ -27,36 +29,131 @@ pub trait ToProtocol { pub(super) type SwitchedServerMessage = FormatSwitch, ws::ServerMessage>; pub(super) type SwitchedDbUpdate = FormatSwitch, ws::DatabaseUpdate>; +/// The initial size of a `serialize` buffer. +/// Currently 4k to align with the linux page size +/// and this should be more than enough in the common case. +const SERIALIZE_BUFFER_INIT_CAP: usize = 4096; + +/// A buffer used by [`serialize`] +pub struct SerializeBuffer { + uncompressed: BytesMut, + compressed: BytesMut, +} + +impl SerializeBuffer { + pub fn new(config: ClientConfig) -> Self { + let uncompressed_capacity = SERIALIZE_BUFFER_INIT_CAP; + let compressed_capacity = if config.compression == Compression::None || config.protocol == Protocol::Text { + 0 + } else { + SERIALIZE_BUFFER_INIT_CAP + }; + Self { + uncompressed: BytesMut::with_capacity(uncompressed_capacity), + compressed: BytesMut::with_capacity(compressed_capacity), + } + } + + /// Take the uncompressed message as the one to use. + fn uncompressed(self) -> (InUseSerializeBuffer, Bytes) { + let uncompressed = self.uncompressed.freeze(); + let in_use = InUseSerializeBuffer::Uncompressed { + uncompressed: uncompressed.clone(), + compressed: self.compressed, + }; + (in_use, uncompressed) + } + + /// Write uncompressed data with a leading tag. + fn write_with_tag(&mut self, tag: u8, write: F) -> &[u8] + where + F: FnOnce(bytes::buf::Writer<&mut BytesMut>), + { + self.uncompressed.put_u8(tag); + write((&mut self.uncompressed).writer()); + &self.uncompressed[1..] + } + + /// Compress the data from a `write_with_tag` call, and change the tag. + fn compress_with_tag( + self, + tag: u8, + write: impl FnOnce(&[u8], &mut bytes::buf::Writer), + ) -> (InUseSerializeBuffer, Bytes) { + let mut writer = self.compressed.writer(); + writer.get_mut().put_u8(tag); + write(&self.uncompressed[1..], &mut writer); + let compressed = writer.into_inner().freeze(); + let in_use = InUseSerializeBuffer::Compressed { + uncompressed: self.uncompressed, + compressed: compressed.clone(), + }; + (in_use, compressed) + } +} + +type BytesMutWriter<'a> = bytes::buf::Writer<&'a mut BytesMut>; + +pub enum InUseSerializeBuffer { + Uncompressed { uncompressed: Bytes, compressed: BytesMut }, + Compressed { uncompressed: BytesMut, compressed: Bytes }, +} + +impl InUseSerializeBuffer { + pub fn try_reclaim(self) -> Option { + let (mut uncompressed, mut compressed) = match self { + Self::Uncompressed { + uncompressed, + compressed, + } => (uncompressed.try_into_mut().ok()?, compressed), + Self::Compressed { + uncompressed, + compressed, + } => (uncompressed, compressed.try_into_mut().ok()?), + }; + uncompressed.clear(); + compressed.clear(); + Some(SerializeBuffer { + uncompressed, + compressed, + }) + } +} + /// Serialize `msg` into a [`DataMessage`] containing a [`ws::ServerMessage`]. /// /// If `protocol` is [`Protocol::Binary`], /// the message will be conditionally compressed by this method according to `compression`. -pub fn serialize(msg: impl ToProtocol, config: ClientConfig) -> DataMessage { - // TODO(centril, perf): here we are allocating buffers only to throw them away eventually. - // Consider pooling these allocations so that we reuse them. +pub fn serialize( + mut buffer: SerializeBuffer, + msg: impl ToProtocol, + config: ClientConfig, +) -> (InUseSerializeBuffer, DataMessage) { match msg.to_protocol(config.protocol) { - FormatSwitch::Json(msg) => serde_json::to_string(&SerializeWrapper::new(msg)).unwrap().into(), + FormatSwitch::Json(msg) => { + let out: BytesMutWriter<'_> = (&mut buffer.uncompressed).writer(); + serde_json::to_writer(out, &SerializeWrapper::new(msg)) + .expect("should be able to json encode a `ServerMessage`"); + + let (in_use, out) = buffer.uncompressed(); + // SAFETY: `serde_json::to_writer` states that: + // > "Serialization guarantees it only feeds valid UTF-8 sequences to the writer." + let msg_json = unsafe { ByteString::from_bytes_unchecked(out) }; + (in_use, msg_json.into()) + } FormatSwitch::Bsatn(msg) => { // First write the tag so that we avoid shifting the entire message at the end. - let mut msg_bytes = vec![SERVER_MSG_COMPRESSION_TAG_NONE]; - bsatn::to_writer(&mut msg_bytes, &msg).unwrap(); + let srv_msg = buffer.write_with_tag(SERVER_MSG_COMPRESSION_TAG_NONE, |w| { + bsatn::to_writer(w.into_inner(), &msg).unwrap() + }); // Conditionally compress the message. - let srv_msg = &msg_bytes[1..]; - let msg_bytes = match ws::decide_compression(srv_msg.len(), config.compression) { - Compression::None => msg_bytes, - Compression::Brotli => { - let mut out = vec![SERVER_MSG_COMPRESSION_TAG_BROTLI]; - ws::brotli_compress(srv_msg, &mut out); - out - } - Compression::Gzip => { - let mut out = vec![SERVER_MSG_COMPRESSION_TAG_GZIP]; - ws::gzip_compress(srv_msg, &mut out); - out - } + let (in_use, msg_bytes) = match ws::decide_compression(srv_msg.len(), config.compression) { + Compression::None => buffer.uncompressed(), + Compression::Brotli => buffer.compress_with_tag(SERVER_MSG_COMPRESSION_TAG_BROTLI, ws::brotli_compress), + Compression::Gzip => buffer.compress_with_tag(SERVER_MSG_COMPRESSION_TAG_GZIP, ws::gzip_compress), }; - msg_bytes.into() + (in_use, msg_bytes.into()) } } } diff --git a/crates/sats/src/buffer.rs b/crates/sats/src/buffer.rs index c6a81ffed40..c3762282c7f 100644 --- a/crates/sats/src/buffer.rs +++ b/crates/sats/src/buffer.rs @@ -2,6 +2,8 @@ //! without relying on types in third party libraries like `bytes::Bytes`, etc. //! Meant to be kept slim and trim for use across both native and WASM. +use bytes::{BufMut, BytesMut}; + use crate::{i256, u256}; use core::cell::Cell; use core::fmt; @@ -309,6 +311,12 @@ impl BufWriter for &mut [u8] { } } +impl BufWriter for BytesMut { + fn put_slice(&mut self, slice: &[u8]) { + BufMut::put_slice(self, slice); + } +} + /// A [`BufWriter`] that only counts the bytes. #[derive(Default)] pub struct CountWriter {