Skip to content

Commit 8996d1e

Browse files
Fix v2 sender-view updates after disconnect
1 parent 61be6e6 commit 8996d1e

4 files changed

Lines changed: 335 additions & 84 deletions

File tree

crates/core/src/subscription/module_subscription_actor.rs

Lines changed: 213 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1880,7 +1880,8 @@ mod tests {
18801880
Protocol, WsVersion,
18811881
};
18821882
use crate::db::relational_db::tests_utils::{
1883-
begin_mut_tx, begin_tx, insert, with_auto_commit, with_read_only, TestDB,
1883+
begin_mut_tx, begin_tx, create_view_for_test, insert, insert_into_view, with_auto_commit, with_read_only,
1884+
TestDB,
18841885
};
18851886
use crate::db::relational_db::{Persistence, RelationalDB, Txdata};
18861887
use crate::error::DBError;
@@ -1919,6 +1920,8 @@ mod tests {
19191920
use tokio::sync::mpsc::{self};
19201921
use tokio::sync::watch;
19211922

1923+
const TEST_MESSAGE_TIMEOUT: Duration = Duration::from_millis(20);
1924+
19221925
fn add_subscriber(db: Arc<RelationalDB>, sql: &str, assert: Option<AssertTxFn>) -> Result<(), DBError> {
19231926
// Create and enter a Tokio runtime to run the `ModuleSubscriptions`' background workers in parallel.
19241927
let runtime = tokio::runtime::Runtime::new().unwrap();
@@ -2180,6 +2183,24 @@ mod tests {
21802183
)
21812184
}
21822185

2186+
/// Instantiate a v2 client connection with the default test settings.
2187+
fn v2_client_connection(
2188+
client_id: ClientActorId,
2189+
db: &Arc<RelationalDB>,
2190+
) -> (Arc<ClientConnectionSender>, ClientConnectionReceiver) {
2191+
client_connection_with_config(
2192+
client_id,
2193+
db,
2194+
ClientConfig {
2195+
protocol: Protocol::Binary,
2196+
version: WsVersion::V2,
2197+
compression: ws_v1::Compression::None,
2198+
tx_update_full: true,
2199+
confirmed_reads: false,
2200+
},
2201+
)
2202+
}
2203+
21832204
/// Insert rules into the RLS system table
21842205
fn insert_rls_rules(
21852206
db: &RelationalDB,
@@ -2263,17 +2284,7 @@ mod tests {
22632284
let db = relational_db()?;
22642285

22652286
let client_id = client_id_from_u8(1);
2266-
let (sender, mut rx) = client_connection_with_config(
2267-
client_id,
2268-
&db,
2269-
ClientConfig {
2270-
protocol: Protocol::Binary,
2271-
version: WsVersion::V2,
2272-
compression: ws_v1::Compression::None,
2273-
tx_update_full: true,
2274-
confirmed_reads: false,
2275-
},
2276-
);
2287+
let (sender, mut rx) = v2_client_connection(client_id, &db);
22772288

22782289
let auth = AuthCtx::new(db.owner_identity(), client_id.identity);
22792290
let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());
@@ -2333,17 +2344,7 @@ mod tests {
23332344
let db = relational_db()?;
23342345

23352346
let client_id = client_id_from_u8(1);
2336-
let (sender, mut rx) = client_connection_with_config(
2337-
client_id,
2338-
&db,
2339-
ClientConfig {
2340-
protocol: Protocol::Binary,
2341-
version: WsVersion::V2,
2342-
compression: ws_v1::Compression::None,
2343-
tx_update_full: true,
2344-
confirmed_reads: false,
2345-
},
2346-
);
2347+
let (sender, mut rx) = v2_client_connection(client_id, &db);
23472348

23482349
let auth = AuthCtx::new(db.owner_identity(), client_id.identity);
23492350
let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());
@@ -2405,17 +2406,7 @@ mod tests {
24052406
let db = relational_db()?;
24062407

24072408
let client_id = client_id_from_u8(1);
2408-
let (sender, mut rx) = client_connection_with_config(
2409-
client_id,
2410-
&db,
2411-
ClientConfig {
2412-
protocol: Protocol::Binary,
2413-
version: WsVersion::V2,
2414-
compression: ws_v1::Compression::None,
2415-
tx_update_full: true,
2416-
confirmed_reads: false,
2417-
},
2418-
);
2409+
let (sender, mut rx) = v2_client_connection(client_id, &db);
24192410

24202411
let auth = AuthCtx::new(db.owner_identity(), client_id.identity);
24212412
let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());
@@ -2460,8 +2451,84 @@ mod tests {
24602451

24612452
let _ = commit_tx(&db, &subs, [], [(table_id, product![2_u8])])?;
24622453

2463-
let recv = tokio::time::timeout(Duration::from_millis(20), rx.recv()).await;
2464-
assert!(recv.is_err(), "expected no updates after unsubscribe");
2454+
assert_no_outbound_message(rx.recv()).await;
2455+
2456+
Ok(())
2457+
}
2458+
2459+
#[tokio::test]
2460+
async fn unsubscribe_v2_other_clients_receive_sender_view_updates() -> anyhow::Result<()> {
2461+
let db = relational_db()?;
2462+
2463+
let id_for_a = identity_from_u8(1);
2464+
let client_id_for_a = client_id_from_u8(1);
2465+
let client_id_for_b = client_id_from_u8(2);
2466+
2467+
let (tx_for_a, mut rx_for_a) = v2_client_connection(client_id_for_a, &db);
2468+
let (tx_for_b, mut rx_for_b) = v2_client_connection(client_id_for_b, &db);
2469+
2470+
let auth_for_a = AuthCtx::new(db.owner_identity(), client_id_for_a.identity);
2471+
let auth_for_b = AuthCtx::new(db.owner_identity(), client_id_for_b.identity);
2472+
let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());
2473+
2474+
let (_, view_table_id) = create_view_for_test(&db, "my_view", &[("counter", AlgebraicType::U8)], false)?;
2475+
2476+
// Seed a sender-scoped row that only client A should observe through the view.
2477+
with_auto_commit(&db, |tx| -> anyhow::Result<_> {
2478+
insert_into_view(&db, tx, view_table_id, Some(id_for_a), product![7_u8])?;
2479+
Ok(())
2480+
})?;
2481+
2482+
subs.add_v2_subscription_inner::<crate::host::wasmtime::WasmtimeInstance>(
2483+
None,
2484+
tx_for_a.clone(),
2485+
auth_for_a,
2486+
ws_v2::Subscribe {
2487+
request_id: 1,
2488+
query_set_id: ws_v2::QuerySetId::new(1),
2489+
query_strings: ["select * from my_view".into()].into(),
2490+
},
2491+
Instant::now(),
2492+
None,
2493+
)?;
2494+
subs.add_v2_subscription_inner::<crate::host::wasmtime::WasmtimeInstance>(
2495+
None,
2496+
tx_for_b.clone(),
2497+
auth_for_b,
2498+
ws_v2::Subscribe {
2499+
request_id: 2,
2500+
query_set_id: ws_v2::QuerySetId::new(2),
2501+
query_strings: ["select * from my_view".into()].into(),
2502+
},
2503+
Instant::now(),
2504+
None,
2505+
)?;
2506+
2507+
assert!(matches!(
2508+
rx_for_a.recv().await,
2509+
Some(OutboundMessage::V2(ws_v2::ServerMessage::SubscribeApplied(_)))
2510+
));
2511+
assert!(matches!(
2512+
rx_for_b.recv().await,
2513+
Some(OutboundMessage::V2(ws_v2::ServerMessage::SubscribeApplied(_)))
2514+
));
2515+
2516+
// Dropping client B must not break client A's sender-view bookkeeping.
2517+
subs.remove_subscriber(client_id_for_b);
2518+
2519+
// Delete the backing row and verify the surviving subscriber still receives the view delta.
2520+
let _ = commit_tx(&db, &subs, [(view_table_id, product![id_for_a, 7_u8])], [])?;
2521+
2522+
let schema = ProductType::from([AlgebraicType::U8]);
2523+
assert_v2_tx_update_for_table(
2524+
rx_for_a.recv(),
2525+
ws_v2::QuerySetId::new(1),
2526+
"my_view",
2527+
&schema,
2528+
[],
2529+
[product![7_u8]],
2530+
)
2531+
.await;
24652532

24662533
Ok(())
24672534
}
@@ -2477,6 +2544,59 @@ mod tests {
24772544
Ok(())
24782545
}
24792546

2547+
fn update_row_counts<I, D, BI, BD>(
2548+
rows_received: &mut HashMap<ProductValue, i32>,
2549+
schema: &ProductType,
2550+
inserts: I,
2551+
deletes: D,
2552+
) where
2553+
I: IntoIterator<Item = BI>,
2554+
D: IntoIterator<Item = BD>,
2555+
BI: AsRef<[u8]>,
2556+
BD: AsRef<[u8]>,
2557+
{
2558+
for row in inserts.into_iter().map(|bytes| {
2559+
let mut bytes = bytes.as_ref();
2560+
ProductValue::decode(schema, &mut bytes).unwrap()
2561+
}) {
2562+
*rows_received.entry(row).or_insert(0) += 1;
2563+
}
2564+
2565+
for row in deletes.into_iter().map(|bytes| {
2566+
let mut bytes = bytes.as_ref();
2567+
ProductValue::decode(schema, &mut bytes).unwrap()
2568+
}) {
2569+
*rows_received.entry(row).or_insert(0) -= 1;
2570+
}
2571+
}
2572+
2573+
fn assert_received_rows(
2574+
rows_received: HashMap<ProductValue, i32>,
2575+
inserts: impl IntoIterator<Item = ProductValue>,
2576+
deletes: impl IntoIterator<Item = ProductValue>,
2577+
) {
2578+
assert_eq!(
2579+
rows_received
2580+
.iter()
2581+
.filter(|(_, n)| n > &&0)
2582+
.map(|(row, _)| row)
2583+
.cloned()
2584+
.sorted()
2585+
.collect::<Vec<_>>(),
2586+
inserts.into_iter().sorted().collect::<Vec<_>>()
2587+
);
2588+
assert_eq!(
2589+
rows_received
2590+
.iter()
2591+
.filter(|(_, n)| n < &&0)
2592+
.map(|(row, _)| row)
2593+
.cloned()
2594+
.sorted()
2595+
.collect::<Vec<_>>(),
2596+
deletes.into_iter().sorted().collect::<Vec<_>>()
2597+
);
2598+
}
2599+
24802600
/// Pull a message from receiver and assert that it is a `TxUpdate` with the expected rows
24812601
async fn assert_tx_update_for_table(
24822602
rx: impl Future<Output = Option<OutboundMessage>>,
@@ -2485,7 +2605,7 @@ mod tests {
24852605
inserts: impl IntoIterator<Item = ProductValue>,
24862606
deletes: impl IntoIterator<Item = ProductValue>,
24872607
) {
2488-
match rx.await {
2608+
match recv_outbound_message(rx, "TxUpdate").await {
24892609
Some(OutboundMessage::V1(SerializableMessage::TxUpdate(TransactionUpdateMessage {
24902610
database_update:
24912611
SubscriptionUpdateMessage {
@@ -2512,49 +2632,70 @@ mod tests {
25122632
panic!("expected an uncompressed table update")
25132633
};
25142634

2515-
for row in table_update
2516-
.inserts
2517-
.into_iter()
2518-
.map(|bytes| ProductValue::decode(schema, &mut &*bytes).unwrap())
2519-
{
2520-
*rows_received.entry(row).or_insert(0) += 1;
2521-
}
2522-
2523-
for row in table_update
2524-
.deletes
2525-
.into_iter()
2526-
.map(|bytes| ProductValue::decode(schema, &mut &*bytes).unwrap())
2527-
{
2528-
*rows_received.entry(row).or_insert(0) -= 1;
2529-
}
2635+
update_row_counts(&mut rows_received, schema, &table_update.inserts, &table_update.deletes);
25302636
}
25312637

2532-
assert_eq!(
2533-
rows_received
2534-
.iter()
2535-
.filter(|(_, n)| n > &&0)
2536-
.map(|(row, _)| row)
2537-
.cloned()
2538-
.sorted()
2539-
.collect::<Vec<_>>(),
2540-
inserts.into_iter().sorted().collect::<Vec<_>>()
2541-
);
2542-
assert_eq!(
2543-
rows_received
2544-
.iter()
2545-
.filter(|(_, n)| n < &&0)
2546-
.map(|(row, _)| row)
2547-
.cloned()
2548-
.sorted()
2549-
.collect::<Vec<_>>(),
2550-
deletes.into_iter().sorted().collect::<Vec<_>>()
2551-
);
2638+
assert_received_rows(rows_received, inserts, deletes);
25522639
}
25532640
Some(msg) => panic!("expected a TxUpdate, but got {msg:#?}"),
25542641
None => panic!("The receiver closed due to an error"),
25552642
}
25562643
}
25572644

2645+
/// Pull a message from receiver and assert that it is a v2 `TransactionUpdate`
2646+
/// with the expected rows for a single table in a single query set.
2647+
async fn assert_v2_tx_update_for_table(
2648+
rx: impl Future<Output = Option<OutboundMessage>>,
2649+
query_set_id: ws_v2::QuerySetId,
2650+
table_name: &str,
2651+
schema: &ProductType,
2652+
inserts: impl IntoIterator<Item = ProductValue>,
2653+
deletes: impl IntoIterator<Item = ProductValue>,
2654+
) {
2655+
match recv_outbound_message(rx, "v2 TransactionUpdate").await {
2656+
Some(OutboundMessage::V2(ws_v2::ServerMessage::TransactionUpdate(update))) => {
2657+
assert_eq!(update.query_sets.len(), 1);
2658+
let query_set = &update.query_sets[0];
2659+
assert_eq!(query_set.query_set_id, query_set_id);
2660+
assert_eq!(query_set.tables.len(), 1);
2661+
2662+
let table_update = &query_set.tables[0];
2663+
assert_eq!(table_update.table_name.as_ref(), table_name);
2664+
2665+
let mut rows_received: HashMap<ProductValue, i32> = HashMap::new();
2666+
2667+
for rows in table_update.rows.iter() {
2668+
let ws_v2::TableUpdateRows::PersistentTable(rows) = rows else {
2669+
panic!("expected a persistent-table update")
2670+
};
2671+
2672+
update_row_counts(&mut rows_received, schema, &rows.inserts, &rows.deletes);
2673+
}
2674+
2675+
assert_received_rows(rows_received, inserts, deletes);
2676+
}
2677+
Some(msg) => panic!("expected a v2 TransactionUpdate, but got {msg:#?}"),
2678+
None => panic!("The receiver closed due to an error"),
2679+
}
2680+
}
2681+
2682+
async fn recv_outbound_message(
2683+
rx: impl Future<Output = Option<OutboundMessage>>,
2684+
expected: &str,
2685+
) -> Option<OutboundMessage> {
2686+
tokio::time::timeout(TEST_MESSAGE_TIMEOUT, rx)
2687+
.await
2688+
.unwrap_or_else(|_| panic!("timed out waiting for {expected}"))
2689+
}
2690+
2691+
async fn assert_no_outbound_message(rx: impl Future<Output = Option<OutboundMessage>>) {
2692+
match tokio::time::timeout(TEST_MESSAGE_TIMEOUT, rx).await {
2693+
Err(_) => {}
2694+
Ok(Some(msg)) => panic!("expected no message, got {msg:#?}"),
2695+
Ok(None) => panic!("the receiver closed due to an error"),
2696+
}
2697+
}
2698+
25582699
/// Assert that the future `f` completes only after `durability` is marked
25592700
/// durable.
25602701
///

0 commit comments

Comments
 (0)