Skip to content

Commit a8426c9

Browse files
Fix query overwrites in the subscription manager
1 parent c3afc17 commit a8426c9

2 files changed

Lines changed: 120 additions & 10 deletions

File tree

crates/core/src/subscription/module_subscription_actor.rs

Lines changed: 103 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2331,48 +2331,147 @@ mod tests {
23312331
Ok(())
23322332
}
23332333

2334-
/// Test that one client unsubscribing does not affect another
2334+
/// Test that one client subscribing does not affect another
23352335
#[tokio::test]
2336-
async fn test_unsubscribe() -> anyhow::Result<()> {
2336+
async fn test_subscribe_distinct_queries_same_plan() -> anyhow::Result<()> {
23372337
// Establish a connection for each client
23382338
let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(1));
23392339
let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(2));
23402340

23412341
let db = relational_db()?;
23422342
let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());
23432343

2344-
let u_id = db.create_table_for_test(
2344+
let u_id = db.create_table_for_test_with_the_works(
23452345
"u",
23462346
&[
23472347
("i", AlgebraicType::U64),
23482348
("a", AlgebraicType::U64),
23492349
("b", AlgebraicType::U64),
23502350
],
23512351
&[0.into()],
2352+
// The join column for this table does not have to be unique,
2353+
// because pruning only requires us to probe the join index on `v`.
2354+
&[],
2355+
StAccess::Public,
23522356
)?;
2353-
let v_id = db.create_table_for_test(
2357+
let v_id = db.create_table_for_test_with_the_works(
23542358
"v",
23552359
&[
23562360
("i", AlgebraicType::U64),
23572361
("x", AlgebraicType::U64),
23582362
("y", AlgebraicType::U64),
23592363
],
23602364
&[0.into(), 1.into()],
2365+
&[0.into()],
2366+
StAccess::Public,
23612367
)?;
23622368

23632369
commit_tx(&db, &subs, [], [(v_id, product![1u64, 1u64, 1u64])])?;
23642370

23652371
let mut query_ids = 0;
23662372

2373+
// Both clients subscribe to the same query modulo whitespace
23672374
subscribe_multi(
23682375
&subs,
23692376
&["select u.* from u join v on u.i = v.i where v.x = 1"],
23702377
tx_for_a,
23712378
&mut query_ids,
23722379
)?;
2380+
subscribe_multi(
2381+
&subs,
2382+
&["select u.* from u join v on u.i = v.i where v.x = 1"],
2383+
tx_for_b.clone(),
2384+
&mut query_ids,
2385+
)?;
2386+
2387+
// Wait for both subscriptions
2388+
assert_matches!(
2389+
rx_for_a.recv().await,
2390+
Some(SerializableMessage::Subscription(SubscriptionMessage {
2391+
result: SubscriptionResult::SubscribeMulti(_),
2392+
..
2393+
}))
2394+
);
2395+
assert_matches!(
2396+
rx_for_b.recv().await,
2397+
Some(SerializableMessage::Subscription(SubscriptionMessage {
2398+
result: SubscriptionResult::SubscribeMulti(_),
2399+
..
2400+
}))
2401+
);
2402+
2403+
// Insert a new row into `u`
2404+
commit_tx(&db, &subs, [], [(u_id, product![1u64, 0u64, 0u64])])?;
2405+
2406+
assert_tx_update_for_table(
2407+
&mut rx_for_a,
2408+
u_id,
2409+
&ProductType::from([AlgebraicType::U64, AlgebraicType::U64, AlgebraicType::U64]),
2410+
[product![1u64, 0u64, 0u64]],
2411+
[],
2412+
)
2413+
.await;
2414+
2415+
assert_tx_update_for_table(
2416+
&mut rx_for_b,
2417+
u_id,
2418+
&ProductType::from([AlgebraicType::U64, AlgebraicType::U64, AlgebraicType::U64]),
2419+
[product![1u64, 0u64, 0u64]],
2420+
[],
2421+
)
2422+
.await;
2423+
2424+
Ok(())
2425+
}
2426+
2427+
/// Test that one client unsubscribing does not affect another
2428+
#[tokio::test]
2429+
async fn test_unsubscribe_distinct_queries_same_plan() -> anyhow::Result<()> {
2430+
// Establish a connection for each client
2431+
let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(1));
2432+
let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(2));
2433+
2434+
let db = relational_db()?;
2435+
let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());
2436+
2437+
let u_id = db.create_table_for_test_with_the_works(
2438+
"u",
2439+
&[
2440+
("i", AlgebraicType::U64),
2441+
("a", AlgebraicType::U64),
2442+
("b", AlgebraicType::U64),
2443+
],
2444+
&[0.into()],
2445+
// The join column for this table does not have to be unique,
2446+
// because pruning only requires us to probe the join index on `v`.
2447+
&[],
2448+
StAccess::Public,
2449+
)?;
2450+
let v_id = db.create_table_for_test_with_the_works(
2451+
"v",
2452+
&[
2453+
("i", AlgebraicType::U64),
2454+
("x", AlgebraicType::U64),
2455+
("y", AlgebraicType::U64),
2456+
],
2457+
&[0.into(), 1.into()],
2458+
&[0.into()],
2459+
StAccess::Public,
2460+
)?;
2461+
2462+
commit_tx(&db, &subs, [], [(v_id, product![1u64, 1u64, 1u64])])?;
2463+
2464+
let mut query_ids = 0;
2465+
23732466
subscribe_multi(
23742467
&subs,
23752468
&["select u.* from u join v on u.i = v.i where v.x = 1"],
2469+
tx_for_a,
2470+
&mut query_ids,
2471+
)?;
2472+
subscribe_multi(
2473+
&subs,
2474+
&["select u.* from u join v on u.i = v.i where v.x = 1"],
23762475
tx_for_b.clone(),
23772476
&mut query_ids,
23782477
)?;

crates/core/src/subscription/module_subscription_manager.rs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ impl QueriedTableIndexIds {
399399
/// See [`JoinEdge`] for more details.
400400
#[derive(Debug, Default)]
401401
pub struct JoinEdges {
402-
edges: BTreeMap<JoinEdge, HashMap<AlgebraicValue, QueryHash>>,
402+
edges: BTreeMap<JoinEdge, HashMap<AlgebraicValue, HashSet<QueryHash>>>,
403403
}
404404

405405
impl JoinEdges {
@@ -408,18 +408,28 @@ impl JoinEdges {
408408
let mut inserted = false;
409409
for (edge, rhs_val) in qs.query.join_edges() {
410410
inserted = true;
411-
self.edges.entry(edge).or_default().insert(rhs_val, qs.query.hash);
411+
self.edges
412+
.entry(edge)
413+
.or_default()
414+
.entry(rhs_val)
415+
.or_default()
416+
.insert(qs.query.hash);
412417
}
413418
inserted
414419
}
415420

416421
/// If this query has any join edges, remove them from the map.
417422
fn remove_query(&mut self, query: &Query) {
418423
for (edge, rhs_val) in query.join_edges() {
419-
if let Some(hashes) = self.edges.get_mut(&edge) {
420-
hashes.remove(&rhs_val);
421-
if hashes.is_empty() {
422-
self.edges.remove(&edge);
424+
if let Some(values) = self.edges.get_mut(&edge) {
425+
if let Some(hashes) = values.get_mut(&rhs_val) {
426+
hashes.remove(&query.hash);
427+
if hashes.is_empty() {
428+
values.remove(&rhs_val);
429+
if values.is_empty() {
430+
self.edges.remove(&edge);
431+
}
432+
}
423433
}
424434
}
425435
}
@@ -436,6 +446,7 @@ impl JoinEdges {
436446
self.edges
437447
.range(JoinEdge::range_for_table(table_id))
438448
.filter_map(move |(edge, hashes)| find_rhs_val(edge, row).as_ref().and_then(|rhs_val| hashes.get(rhs_val)))
449+
.flatten()
439450
}
440451
}
441452

0 commit comments

Comments
 (0)