Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 11 additions & 31 deletions crates/client-api/src/routes/subscribe.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::collections::VecDeque;
use std::mem;
use std::pin::{pin, Pin};
use std::time::Duration;
Expand All @@ -15,7 +14,10 @@ use http::{HeaderValue, StatusCode};
use scopeguard::ScopeGuard;
use serde::Deserialize;
use spacetimedb::client::messages::{serialize, IdentityTokenMessage, SerializableMessage, SerializeBuffer};
use spacetimedb::client::{ClientActorId, ClientConfig, ClientConnection, DataMessage, MessageHandleError, Protocol};
use spacetimedb::client::{
ClientActorId, ClientConfig, ClientConnection, DataMessage, MessageHandleError, MeteredDeque, MeteredReceiver,
Protocol,
};
use spacetimedb::execution_context::WorkloadType;
use spacetimedb::host::module_host::ClientConnectedError;
use spacetimedb::host::NoSuchModule;
Expand All @@ -25,7 +27,6 @@ use spacetimedb::Identity;
use spacetimedb_client_api_messages::websocket::{self as ws_api, Compression};
use spacetimedb_lib::connection_id::{ConnectionId, ConnectionIdForUrl};
use std::time::Instant;
use tokio::sync::mpsc;
use tokio_tungstenite::tungstenite::Utf8Bytes;

use crate::auth::SpacetimeAuth;
Expand Down Expand Up @@ -182,7 +183,7 @@ where

/// Period of the liveness-check interval created in `ws_client_actor_inner`.
const LIVELINESS_TIMEOUT: Duration = Duration::from_secs(60);

async fn ws_client_actor(client: ClientConnection, ws: WebSocketStream, sendrx: mpsc::Receiver<SerializableMessage>) {
async fn ws_client_actor(client: ClientConnection, ws: WebSocketStream, sendrx: MeteredReceiver<SerializableMessage>) {
// ensure that even if this task gets cancelled, we always clean up the connection
let mut client = scopeguard::guard(client, |client| {
tokio::spawn(client.disconnect());
Expand All @@ -204,11 +205,13 @@ async fn make_progress<Fut: Future>(fut: &mut Pin<&mut MaybeDone<Fut>>) {
async fn ws_client_actor_inner(
client: &mut ClientConnection,
mut ws: WebSocketStream,
mut sendrx: mpsc::Receiver<SerializableMessage>,
mut sendrx: MeteredReceiver<SerializableMessage>,
) {
let mut liveness_check_interval = tokio::time::interval(LIVELINESS_TIMEOUT);
let mut got_pong = true;

let addr = client.module.info().database_identity;

// Build a queue of incoming messages to handle, to be processed one at a time,
// in the order they're received.
//
Expand All @@ -222,32 +225,14 @@ async fn ws_client_actor_inner(
// `select!` for examples of how to do this.
//
// TODO: do we want this to have a fixed capacity? or should it be unbounded
let mut message_queue = VecDeque::<(DataMessage, Instant)>::new();
let mut message_queue = MeteredDeque::<(DataMessage, Instant)>::new(
WORKER_METRICS.total_incoming_queue_length.with_label_values(&addr),
);
let mut current_message = pin!(MaybeDone::Gone);

let mut closed = false;
let mut rx_buf = Vec::new();

let addr = client.module.info().database_identity;

// Grab handles on the total incoming and outgoing queue length metrics,
// which we'll increment and decrement as we push into and pull out of those queues.
// Note that `total_outgoing_queue_length` is incremented separately,
// by `ClientConnectionSender::send` in core/src/client/client_connection.rs;
// we're only responsible for decrementing that one.
// Also note that much care must be taken to clean up these metrics when the connection closes!
// Any path which exits this function must decrement each of these metrics
// by the number of messages still waiting in this client's queue,
// or else they will grow without bound as clients disconnect, and be useless.
let incoming_queue_length_metric = WORKER_METRICS.total_incoming_queue_length.with_label_values(&addr);
let outgoing_queue_length_metric = WORKER_METRICS.total_outgoing_queue_length.with_label_values(&addr);

let clean_up_metrics = |message_queue: &VecDeque<(DataMessage, Instant)>,
sendrx: &mpsc::Receiver<SerializableMessage>| {
incoming_queue_length_metric.sub(message_queue.len() as _);
outgoing_queue_length_metric.sub(sendrx.len() as _);
};

let mut msg_buffer = SerializeBuffer::new(client.config);
loop {
rx_buf.clear();
Expand All @@ -257,7 +242,6 @@ async fn ws_client_actor_inner(
}
if let MaybeDone::Gone = *current_message {
if let Some((message, timer)) = message_queue.pop_front() {
incoming_queue_length_metric.dec();
let client = client.clone();
let fut = async move { client.handle_message(message, timer).await };
current_message.set(MaybeDone::Future(fut));
Expand Down Expand Up @@ -286,15 +270,13 @@ async fn ws_client_actor_inner(
}
// the client sent us a close frame
None => {
clean_up_metrics(&message_queue, &sendrx);
break
},
},

// If we have an outgoing message to send, send it off.
// No incoming `message` to handle, so `continue`.
Some(n) = sendrx.recv_many(&mut rx_buf, 32).map(|n| (n != 0).then_some(n)) => {
outgoing_queue_length_metric.sub(n as _);
if closed {
// TODO: this isn't great. when we receive a close request from the peer,
// tungstenite doesn't let us send any new messages on the socket,
Expand Down Expand Up @@ -379,7 +361,6 @@ async fn ws_client_actor_inner(
} else {
// the client never responded to our ping; drop them without trying to send them a Close
log::warn!("client {} timed out", client.id);
clean_up_metrics(&message_queue, &sendrx);
break;
}
}
Expand All @@ -394,7 +375,6 @@ async fn ws_client_actor_inner(
match message {
Item::Message(ClientMessage::Message(message)) => {
let timer = Instant::now();
incoming_queue_length_metric.inc();
message_queue.push_back((message, timer))
}
Item::HandleResult(res) => {
Expand Down
3 changes: 2 additions & 1 deletion crates/core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ mod message_handlers;
pub mod messages;

pub use client_connection::{
ClientConfig, ClientConnection, ClientConnectionSender, ClientSendError, DataMessage, Protocol,
ClientConfig, ClientConnection, ClientConnectionSender, ClientSendError, DataMessage, MeteredDeque,
MeteredReceiver, Protocol,
};
pub use client_connection_index::ClientActorIndex;
pub use message_handlers::MessageHandleError;
Expand Down
117 changes: 115 additions & 2 deletions crates/core/src/client/client_connection.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::VecDeque;
use std::ops::Deref;
use std::sync::atomic::{AtomicBool, Ordering::Relaxed};
use std::sync::Arc;
Expand Down Expand Up @@ -132,14 +133,15 @@ pub enum ClientSendError {
}

impl ClientConnectionSender {
pub fn dummy_with_channel(id: ClientActorId, config: ClientConfig) -> (Self, mpsc::Receiver<SerializableMessage>) {
pub fn dummy_with_channel(id: ClientActorId, config: ClientConfig) -> (Self, MeteredReceiver<SerializableMessage>) {
let (sendtx, rx) = mpsc::channel(1);
// just make something up, it doesn't need to be attached to a real task
let abort_handle = match tokio::runtime::Handle::try_current() {
Ok(h) => h.spawn(async {}).abort_handle(),
Err(_) => tokio::runtime::Runtime::new().unwrap().spawn(async {}).abort_handle(),
};

let rx = MeteredReceiver::new(rx);
let cancelled = AtomicBool::new(false);
let sender = Self {
id,
Expand Down Expand Up @@ -257,6 +259,116 @@ impl DataMessage {
}
}

/// A [`VecDeque`] paired with a gauge that is adjusted as the deque grows and shrinks.
///
/// Each push increments the gauge and each successful pop decrements it.
/// When the deque is dropped, the number of elements still queued is
/// subtracted from the gauge so that the metric does not leak.
pub struct MeteredDeque<T> {
    /// The underlying queue holding the elements.
    inner: VecDeque<T>,
    /// Gauge adjusted by one on every push and every successful pop.
    gauge: IntGauge,
}

impl<T> MeteredDeque<T> {
    /// Creates an empty deque whose size changes are reported to `gauge`.
    pub fn new(gauge: IntGauge) -> Self {
        let inner = VecDeque::new();
        Self { inner, gauge }
    }

    /// Removes and returns the front element, if any.
    ///
    /// The gauge is decremented only when an element was actually removed.
    pub fn pop_front(&mut self) -> Option<T> {
        let popped = self.inner.pop_front();
        if popped.is_some() {
            self.gauge.dec();
        }
        popped
    }

    /// Removes and returns the back element, if any.
    ///
    /// The gauge is decremented only when an element was actually removed.
    pub fn pop_back(&mut self) -> Option<T> {
        let popped = self.inner.pop_back();
        if popped.is_some() {
            self.gauge.dec();
        }
        popped
    }

    /// Increments the gauge, then prepends `value` to the deque.
    pub fn push_front(&mut self, value: T) {
        self.gauge.inc();
        self.inner.push_front(value);
    }

    /// Increments the gauge, then appends `value` to the deque.
    pub fn push_back(&mut self, value: T) {
        self.gauge.inc();
        self.inner.push_back(value);
    }

    /// Returns the number of elements currently in the deque.
    pub fn len(&self) -> usize {
        VecDeque::len(&self.inner)
    }

    /// Returns `true` if the deque holds no elements.
    pub fn is_empty(&self) -> bool {
        VecDeque::is_empty(&self.inner)
    }
}

impl<T> Drop for MeteredDeque<T> {
    /// Subtracts the elements still queued from the gauge.
    fn drop(&mut self) {
        // Each remaining element was counted by a push but will never be popped,
        // so take them all off the gauge in one step.
        let remaining = self.inner.len();
        self.gauge.sub(remaining as _);
    }
}

/// The receiving end of a channel, optionally paired with a gauge tracking the channel's size.
///
/// Receiving messages decrements the gauge. When the receiver is dropped, the
/// number of messages still in the channel is subtracted from the gauge so that
/// the metric does not leak. This wrapper never increments the gauge itself.
pub struct MeteredReceiver<T> {
    /// The wrapped channel receiver.
    inner: mpsc::Receiver<T>,
    /// Gauge to decrement as messages are received; `None` means no metric is updated.
    gauge: Option<IntGauge>,
}

impl<T> MeteredReceiver<T> {
    /// Wraps `inner` without a gauge; no metric is updated by this receiver.
    pub fn new(inner: mpsc::Receiver<T>) -> Self {
        let gauge = None;
        Self { inner, gauge }
    }

    /// Wraps `inner`, decrementing `gauge` as messages are received.
    pub fn with_gauge(inner: mpsc::Receiver<T>, gauge: IntGauge) -> Self {
        let gauge = Some(gauge);
        Self { inner, gauge }
    }

    /// Receives the next message from the wrapped receiver.
    ///
    /// The gauge, if present, is decremented only when a message was actually received.
    pub async fn recv(&mut self) -> Option<T> {
        let received = self.inner.recv().await;
        if let (Some(_), Some(gauge)) = (&received, &self.gauge) {
            gauge.dec();
        }
        received
    }

    /// Receives up to `max` messages into `buf` and returns how many were added.
    ///
    /// The gauge, if present, is reduced by the returned count.
    pub async fn recv_many(&mut self, buf: &mut Vec<T>, max: usize) -> usize {
        let received = self.inner.recv_many(buf, max).await;
        if let Some(gauge) = self.gauge.as_ref() {
            gauge.sub(received as _);
        }
        received
    }

    /// Returns the number of messages currently in the channel.
    pub fn len(&self) -> usize {
        mpsc::Receiver::len(&self.inner)
    }

    /// Returns `true` if the channel currently holds no messages.
    pub fn is_empty(&self) -> bool {
        mpsc::Receiver::is_empty(&self.inner)
    }

    /// Closes the wrapped receiver.
    pub fn close(&mut self) {
        mpsc::Receiver::close(&mut self.inner);
    }
}

impl<T> Drop for MeteredReceiver<T> {
    /// Subtracts the messages still sitting in the channel from the gauge.
    fn drop(&mut self) {
        // Close the channel before sampling its length. `inner` is only dropped
        // after this method returns, so without closing first a sender could
        // enqueue a message after `len()` was read; that message would
        // (presumably) be counted on the sending side but never subtracted
        // here, leaking the metric.
        // NOTE(review): outstanding send permits may still deliver after `close`;
        // confirm the sending side does not rely on them.
        self.inner.close();

        if let Some(gauge) = &self.gauge {
            gauge.sub(self.inner.len() as _);
        }
    }
}
Comment thread
joshua-spacetime marked this conversation as resolved.

/// If a client racks up this many messages in the queue without ACK'ing
/// anything, we boot 'em.
const CLIENT_CHANNEL_CAPACITY: usize = 16 * KB;
Expand All @@ -269,7 +381,7 @@ impl ClientConnection {
config: ClientConfig,
replica_id: u64,
mut module_rx: watch::Receiver<ModuleHost>,
actor: impl FnOnce(ClientConnection, mpsc::Receiver<SerializableMessage>) -> Fut,
actor: impl FnOnce(ClientConnection, MeteredReceiver<SerializableMessage>) -> Fut,
) -> Result<ClientConnection, ClientConnectedError>
where
Fut: Future<Output = ()> + Send + 'static,
Expand Down Expand Up @@ -299,6 +411,7 @@ impl ClientConnection {
.abort_handle();

let metrics = ClientConnectionMetrics::new(database_identity, config.protocol);
let sendrx = MeteredReceiver::with_gauge(sendrx, metrics.sendtx_queue_size.clone());

let sender = Arc::new(ClientConnectionSender {
id,
Expand Down
12 changes: 7 additions & 5 deletions crates/core/src/subscription/module_subscription_actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,7 @@ mod tests {
SerializableMessage, SubscriptionData, SubscriptionError, SubscriptionMessage, SubscriptionResult,
SubscriptionUpdateMessage, TransactionUpdateMessage,
};
use crate::client::{ClientActorId, ClientConfig, ClientConnectionSender, ClientName, Protocol};
use crate::client::{ClientActorId, ClientConfig, ClientConnectionSender, ClientName, MeteredReceiver, Protocol};
use crate::db::datastore::system_tables::{StRowLevelSecurityRow, ST_ROW_LEVEL_SECURITY_ID};
use crate::db::relational_db::tests_utils::{
begin_mut_tx, begin_tx, insert, with_auto_commit, with_read_only, TestDB,
Expand Down Expand Up @@ -964,7 +964,7 @@ mod tests {
use spacetimedb_sats::product;
use std::time::Instant;
use std::{sync::Arc, time::Duration};
use tokio::sync::mpsc::{self, Receiver};
use tokio::sync::mpsc::{self};

fn add_subscriber(db: Arc<RelationalDB>, sql: &str, assert: Option<AssertTxFn>) -> Result<(), DBError> {
// Create and enter a Tokio runtime to run the `ModuleSubscriptions`' background workers in parallel.
Expand Down Expand Up @@ -1072,7 +1072,7 @@ mod tests {
fn client_connection_with_compression(
client_id: ClientActorId,
compression: Compression,
) -> (Arc<ClientConnectionSender>, Receiver<SerializableMessage>) {
) -> (Arc<ClientConnectionSender>, MeteredReceiver<SerializableMessage>) {
let (sender, rx) = ClientConnectionSender::dummy_with_channel(
client_id,
ClientConfig {
Expand All @@ -1085,7 +1085,9 @@ mod tests {
}

/// Instantiate a client connection
fn client_connection(client_id: ClientActorId) -> (Arc<ClientConnectionSender>, Receiver<SerializableMessage>) {
fn client_connection(
client_id: ClientActorId,
) -> (Arc<ClientConnectionSender>, MeteredReceiver<SerializableMessage>) {
client_connection_with_compression(client_id, Compression::None)
}

Expand Down Expand Up @@ -1159,7 +1161,7 @@ mod tests {

/// Pull a message from receiver and assert that it is a `TxUpdate` with the expected rows
async fn assert_tx_update_for_table(
rx: &mut Receiver<SerializableMessage>,
rx: &mut MeteredReceiver<SerializableMessage>,
table_id: TableId,
schema: &ProductType,
inserts: impl IntoIterator<Item = ProductValue>,
Expand Down
Loading