Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rivetkit-rust/packages/rivetkit-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ authors.workspace = true
license.workspace = true
edition.workspace = true
workspace = "../../../"
autotests = false

[features]
default = []
Expand Down
3 changes: 2 additions & 1 deletion rivetkit-rust/packages/rivetkit-core/src/actor/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ fn duration_ms(value: u32) -> Duration {
Duration::from_millis(u64::from(value))
}

// Test shim keeps moved tests in crate-root tests/ with private-module access.
#[cfg(test)]
#[path = "../../tests/modules/config.rs"]
#[path = "../../tests/config.rs"]
mod tests;
334 changes: 3 additions & 331 deletions rivetkit-rust/packages/rivetkit-core/src/actor/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -982,335 +982,7 @@ pub(crate) fn make_connection_key(conn_id: &str) -> Vec<u8> {
key
}

// Test shim keeps moved tests in crate-root tests/ with private-module access.
#[cfg(test)]
mod tests {
use std::collections::BTreeSet;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

use parking_lot::Mutex;
use tokio::sync::{Barrier, mpsc};
use tokio::task::yield_now;

use super::{
HibernatableConnectionMetadata, PersistedConnection, decode_persisted_connection,
encode_persisted_connection, hibernatable_id_from_slice, make_connection_key,
};
use crate::actor::context::ActorContext;
use crate::actor::messages::ActorEvent;
use crate::actor::preload::PreloadedKv;
use crate::actor::task::LifecycleEvent;
use crate::kv::Kv;

fn next_non_activity_lifecycle_event(
rx: &mut mpsc::Receiver<LifecycleEvent>,
) -> Option<LifecycleEvent> {
rx.try_recv().ok()
}

#[tokio::test]
async fn restore_persisted_uses_preloaded_connection_prefix_when_present() {
let ctx = ActorContext::new_with_kv(
"actor-preload",
"actor",
Vec::new(),
"local",
Kv::new_in_memory(),
);
let persisted = PersistedConnection {
id: "conn-preloaded".to_owned(),
parameters: vec![1],
state: vec![2],
gateway_id: [1, 2, 3, 4],
request_id: [5, 6, 7, 8],
request_path: "/socket".to_owned(),
..PersistedConnection::default()
};
let preloaded = PreloadedKv::new_with_requested_get_keys(
[(
make_connection_key(&persisted.id),
encode_persisted_connection(&persisted)
.expect("persisted connection should encode"),
)],
Vec::new(),
vec![vec![2]],
);

let restored = ctx
.restore_persisted(Some(&preloaded))
.await
.expect("restore should use preloaded entries instead of unconfigured kv");

assert_eq!(restored.len(), 1);
assert_eq!(restored[0].id(), "conn-preloaded");
assert_eq!(restored[0].state(), vec![2]);
assert!(ctx.connection("conn-preloaded").is_some());
}

#[test]
fn persisted_connection_uses_ts_v4_fixed_id_wire_format() {
let persisted = PersistedConnection {
id: "c".to_owned(),
parameters: vec![1, 2],
state: vec![3],
gateway_id: [10, 11, 12, 13],
request_id: [20, 21, 22, 23],
server_message_index: 9,
client_message_index: 10,
request_path: "/".to_owned(),
..PersistedConnection::default()
};

let encoded =
encode_persisted_connection(&persisted).expect("persisted connection should encode");

assert_eq!(
encoded,
vec![
4, 0, // embedded version
1, b'c', // id
2, 1, 2, // parameters
1, 3, // state
0, // subscriptions
10, 11, 12, 13, // gatewayId fixed data[4]
20, 21, 22, 23, // requestId fixed data[4]
9, 0, // serverMessageIndex
10, 0, // clientMessageIndex
1, b'/', // requestPath
0, // requestHeaders
]
);

let decoded =
decode_persisted_connection(&encoded).expect("persisted connection should decode");
assert_eq!(decoded.gateway_id, [10, 11, 12, 13]);
assert_eq!(decoded.request_id, [20, 21, 22, 23]);
}

#[test]
fn hibernatable_id_validation_returns_rivet_error() {
let error = hibernatable_id_from_slice("gateway_id", &[1, 2, 3])
.expect_err("invalid id should fail");
let error = rivet_error::RivetError::extract(&error);

assert_eq!(error.group(), "actor");
assert_eq!(error.code(), "invalid_request");
}

#[tokio::test(start_paused = true)]
async fn concurrent_disconnects_only_emit_one_close_and_one_hibernation_removal() {
let ctx = ActorContext::new_with_kv(
"actor-race",
"actor",
Vec::new(),
"local",
Kv::new_in_memory(),
);
ctx.configure_connection_runtime(crate::actor::config::ActorConfig::default());
let (events_tx, mut events_rx) = mpsc::unbounded_channel();
ctx.configure_actor_events(Some(events_tx));
let closed = Arc::new(AtomicUsize::new(0));
let observed_conn_id = Arc::new(Mutex::new(None::<String>));

let recv = tokio::spawn({
let closed = closed.clone();
let observed_conn_id = observed_conn_id.clone();
async move {
while let Some(event) = events_rx.recv().await {
match event {
ActorEvent::ConnectionOpen { reply, .. } => reply.send(Ok(())),
ActorEvent::ConnectionClosed { conn } => {
*observed_conn_id.lock() = Some(conn.id().to_owned());
closed.fetch_add(1, Ordering::SeqCst);
break;
}
other => panic!("unexpected event: {other:?}"),
}
}
}
});

let conn = ctx
.connect_with_state(
vec![1],
true,
Some(HibernatableConnectionMetadata {
gateway_id: [1, 2, 3, 4],
request_id: [5, 6, 7, 8],
..HibernatableConnectionMetadata::default()
}),
None,
async { Ok(vec![9]) },
)
.await
.expect("connection should open");
let conn_id = conn.id().to_owned();
ctx.record_connections_updated();
ctx.reset_sleep_timer();

let barrier = Arc::new(Barrier::new(2));
conn.configure_transport_disconnect_handler(Some(Arc::new({
let barrier = barrier.clone();
move |_reason| {
let barrier = barrier.clone();
Box::pin(async move {
barrier.wait().await;
Ok(())
})
}
})));

let first = tokio::spawn({
let conn = conn.clone();
async move { conn.disconnect(Some("first")).await }
});
let second = tokio::spawn({
let conn = conn.clone();
async move { conn.disconnect(Some("second")).await }
});

yield_now().await;
first
.await
.expect("first disconnect task should join")
.expect("first disconnect should succeed");
second
.await
.expect("second disconnect task should join")
.expect("second disconnect should succeed");
recv.await.expect("event receiver should join");

assert_eq!(closed.load(Ordering::SeqCst), 1);
assert_eq!(observed_conn_id.lock().as_deref(), Some(conn_id.as_str()));
assert!(ctx.connection(&conn_id).is_none());

let pending = ctx.take_pending_hibernation_changes_inner();
assert!(pending.updated.is_empty());
assert_eq!(pending.removed, BTreeSet::from([conn_id]));
}

#[tokio::test]
async fn hibernatable_set_state_queues_save_and_non_hibernatable_stays_memory_only() {
let ctx = ActorContext::new_with_kv(
"actor-state-dirty",
"actor",
Vec::new(),
"local",
Kv::new_in_memory(),
);
let (actor_events_tx, mut actor_events_rx) = mpsc::unbounded_channel();
let (lifecycle_events_tx, mut lifecycle_events_rx) = mpsc::channel(4);
ctx.configure_actor_events(Some(actor_events_tx));
ctx.configure_lifecycle_events(Some(lifecycle_events_tx));

let open_replies = tokio::spawn(async move {
for _ in 0..2 {
match actor_events_rx
.recv()
.await
.expect("open event should arrive")
{
ActorEvent::ConnectionOpen { reply, .. } => reply.send(Ok(())),
other => panic!("unexpected actor event: {other:?}"),
}
}
});

let non_hibernatable = ctx
.connect_with_state(vec![1], false, None, None, async { Ok(vec![2]) })
.await
.expect("non-hibernatable connection should open");
non_hibernatable.set_state(vec![3]);
assert_eq!(non_hibernatable.state(), vec![3]);
assert!(
ctx.dirty_hibernatable_conns_inner().is_empty(),
"non-hibernatable state changes should not queue persistence"
);
assert!(
next_non_activity_lifecycle_event(&mut lifecycle_events_rx).is_none(),
"non-hibernatable state changes should not request actor save"
);

let hibernatable = ctx
.connect_with_state(
vec![4],
true,
Some(HibernatableConnectionMetadata {
gateway_id: [1, 2, 3, 4],
request_id: [5, 6, 7, 8],
..HibernatableConnectionMetadata::default()
}),
None,
async { Ok(vec![5]) },
)
.await
.expect("hibernatable connection should open");
hibernatable.set_state(vec![6]);

assert_eq!(
ctx.dirty_hibernatable_conns_inner()
.into_iter()
.map(|conn| conn.id().to_owned())
.collect::<Vec<_>>(),
vec![hibernatable.id().to_owned()]
);
assert_eq!(
next_non_activity_lifecycle_event(&mut lifecycle_events_rx)
.expect("hibernatable state change should request save"),
LifecycleEvent::SaveRequested { immediate: false }
);

open_replies
.await
.expect("open reply task should join cleanly");
}

#[tokio::test(start_paused = true)]
async fn remove_existing_for_disconnect_has_exactly_one_winner() {
let ctx = ActorContext::new_with_kv(
"actor-race",
"actor",
Vec::new(),
"local",
Kv::new_in_memory(),
);
let conn = super::ConnHandle::new("conn-race", vec![1], vec![2], true);
conn.configure_hibernation(Some(HibernatableConnectionMetadata {
gateway_id: [1, 2, 3, 4],
request_id: [5, 6, 7, 8],
..HibernatableConnectionMetadata::default()
}));
ctx.insert_existing(conn);

let barrier = Arc::new(Barrier::new(2));
let first = tokio::spawn({
let ctx = ctx.clone();
let barrier = barrier.clone();
async move {
barrier.wait().await;
ctx.remove_existing_for_disconnect("conn-race")
.map(|conn| conn.id().to_owned())
}
});
let second = tokio::spawn({
let ctx = ctx.clone();
let barrier = barrier.clone();
async move {
barrier.wait().await;
ctx.remove_existing_for_disconnect("conn-race")
.map(|conn| conn.id().to_owned())
}
});

let first = first.await.expect("first task should join");
let second = second.await.expect("second task should join");
let winners = [first, second].into_iter().flatten().collect::<Vec<_>>();

assert_eq!(winners, vec!["conn-race".to_owned()]);
assert!(ctx.connection("conn-race").is_none());

let pending = ctx.take_pending_hibernation_changes_inner();
assert!(pending.updated.is_empty());
assert_eq!(pending.removed, BTreeSet::from(["conn-race".to_owned()]));
}
}
#[path = "../../tests/connection.rs"]
mod tests;
3 changes: 2 additions & 1 deletion rivetkit-rust/packages/rivetkit-core/src/actor/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1547,6 +1547,7 @@ impl std::fmt::Debug for ActorContext {
}
}

// Test shim keeps moved tests in crate-root tests/ with private-module access.
#[cfg(test)]
#[path = "../../tests/modules/context.rs"]
#[path = "../../tests/context.rs"]
pub(crate) mod tests;
3 changes: 2 additions & 1 deletion rivetkit-rust/packages/rivetkit-core/src/actor/kv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ fn apply_list_opts(entries: &mut Vec<(Vec<u8>, Vec<u8>)>, opts: ListOpts) {
}
}

// Test shim keeps moved tests in crate-root tests/ with private-module access.
#[cfg(test)]
#[path = "../../tests/modules/kv.rs"]
#[path = "../../tests/kv.rs"]
pub(crate) mod tests;
3 changes: 2 additions & 1 deletion rivetkit-rust/packages/rivetkit-core/src/actor/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ impl ActorEvent {
}
}

// Test shim keeps moved tests in crate-root tests/ with private-module access.
#[cfg(test)]
#[path = "../../tests/modules/messages.rs"]
#[path = "../../tests/messages.rs"]
mod tests;
Loading
Loading