Skip to content

Commit 022ca68

Browse files
Fix query overwrites in the subscription manager (#2905)
1 parent 28186d8 commit 022ca68

2 files changed

Lines changed: 129 additions & 18 deletions

File tree

crates/core/src/subscription/module_subscription_actor.rs

Lines changed: 112 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ impl ModuleSubscriptions {
554554
fn compile_queries(
555555
&self,
556556
sender: Identity,
557-
queries: impl IntoIterator<Item = Box<str>>,
557+
queries: &[Box<str>],
558558
num_queries: usize,
559559
metrics: &SubscriptionMetrics,
560560
) -> Result<(Vec<Arc<Plan>>, AuthCtx, TxId, HistogramTimer), DBError> {
@@ -563,12 +563,13 @@ impl ModuleSubscriptions {
563563
let mut query_hashes = Vec::with_capacity(num_queries);
564564

565565
for sql in queries {
566-
if is_subscribe_to_all_tables(&sql) {
566+
let sql = sql.trim();
567+
if is_subscribe_to_all_tables(sql) {
567568
subscribe_to_all_tables = true;
568569
continue;
569570
}
570-
let hash = QueryHash::from_string(&sql, sender, false);
571-
let hash_with_param = QueryHash::from_string(&sql, sender, true);
571+
let hash = QueryHash::from_string(sql, sender, false);
572+
let hash_with_param = QueryHash::from_string(sql, sender, true);
572573
query_hashes.push((sql, hash, hash_with_param));
573574
}
574575

@@ -606,10 +607,10 @@ impl ModuleSubscriptions {
606607
plans.push(unit);
607608
} else {
608609
plans.push(Arc::new(
609-
compile_query_with_hashes(&auth, &tx, &sql, hash, hash_with_param).map_err(|err| {
610+
compile_query_with_hashes(&auth, &tx, sql, hash, hash_with_param).map_err(|err| {
610611
DBError::WithSql {
611612
error: Box::new(DBError::Other(err.into())),
612-
sql,
613+
sql: sql.into(),
613614
}
614615
})?,
615616
));
@@ -670,7 +671,7 @@ impl ModuleSubscriptions {
670671
let (queries, auth, tx, compile_timer) = return_on_err!(
671672
self.compile_queries(
672673
sender.id.identity,
673-
request.query_strings,
674+
&request.query_strings,
674675
num_queries,
675676
&subscription_metrics
676677
),
@@ -767,7 +768,7 @@ impl ModuleSubscriptions {
767768

768769
let (queries, auth, tx, compile_timer) = self.compile_queries(
769770
sender.id.identity,
770-
subscription.query_strings,
771+
&subscription.query_strings,
771772
num_queries,
772773
&subscription_metrics,
773774
)?;
@@ -2331,48 +2332,147 @@ mod tests {
23312332
Ok(())
23322333
}
23332334

2334-
/// Test that one client unsubscribing does not affect another
2335+
/// Test that one client subscribing does not affect another
23352336
#[tokio::test]
2336-
async fn test_unsubscribe() -> anyhow::Result<()> {
2337+
async fn test_subscribe_distinct_queries_same_plan() -> anyhow::Result<()> {
23372338
// Establish a connection for each client
23382339
let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(1));
23392340
let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(2));
23402341

23412342
let db = relational_db()?;
23422343
let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());
23432344

2344-
let u_id = db.create_table_for_test(
2345+
let u_id = db.create_table_for_test_with_the_works(
23452346
"u",
23462347
&[
23472348
("i", AlgebraicType::U64),
23482349
("a", AlgebraicType::U64),
23492350
("b", AlgebraicType::U64),
23502351
],
23512352
&[0.into()],
2353+
// The join column for this table does not have to be unique,
2354+
// because pruning only requires us to probe the join index on `v`.
2355+
&[],
2356+
StAccess::Public,
23522357
)?;
2353-
let v_id = db.create_table_for_test(
2358+
let v_id = db.create_table_for_test_with_the_works(
23542359
"v",
23552360
&[
23562361
("i", AlgebraicType::U64),
23572362
("x", AlgebraicType::U64),
23582363
("y", AlgebraicType::U64),
23592364
],
23602365
&[0.into(), 1.into()],
2366+
&[0.into()],
2367+
StAccess::Public,
23612368
)?;
23622369

23632370
commit_tx(&db, &subs, [], [(v_id, product![1u64, 1u64, 1u64])])?;
23642371

23652372
let mut query_ids = 0;
23662373

2374+
// Both clients subscribe to the same query modulo whitespace
23672375
subscribe_multi(
23682376
&subs,
23692377
&["select u.* from u join v on u.i = v.i where v.x = 1"],
23702378
tx_for_a,
23712379
&mut query_ids,
23722380
)?;
2381+
subscribe_multi(
2382+
&subs,
2383+
&["select u.* from u join v on u.i = v.i where v.x = 1"],
2384+
tx_for_b.clone(),
2385+
&mut query_ids,
2386+
)?;
2387+
2388+
// Wait for both subscriptions
2389+
assert_matches!(
2390+
rx_for_a.recv().await,
2391+
Some(SerializableMessage::Subscription(SubscriptionMessage {
2392+
result: SubscriptionResult::SubscribeMulti(_),
2393+
..
2394+
}))
2395+
);
2396+
assert_matches!(
2397+
rx_for_b.recv().await,
2398+
Some(SerializableMessage::Subscription(SubscriptionMessage {
2399+
result: SubscriptionResult::SubscribeMulti(_),
2400+
..
2401+
}))
2402+
);
2403+
2404+
// Insert a new row into `u`
2405+
commit_tx(&db, &subs, [], [(u_id, product![1u64, 0u64, 0u64])])?;
2406+
2407+
assert_tx_update_for_table(
2408+
&mut rx_for_a,
2409+
u_id,
2410+
&ProductType::from([AlgebraicType::U64, AlgebraicType::U64, AlgebraicType::U64]),
2411+
[product![1u64, 0u64, 0u64]],
2412+
[],
2413+
)
2414+
.await;
2415+
2416+
assert_tx_update_for_table(
2417+
&mut rx_for_b,
2418+
u_id,
2419+
&ProductType::from([AlgebraicType::U64, AlgebraicType::U64, AlgebraicType::U64]),
2420+
[product![1u64, 0u64, 0u64]],
2421+
[],
2422+
)
2423+
.await;
2424+
2425+
Ok(())
2426+
}
2427+
2428+
/// Test that one client unsubscribing does not affect another
2429+
#[tokio::test]
2430+
async fn test_unsubscribe_distinct_queries_same_plan() -> anyhow::Result<()> {
2431+
// Establish a connection for each client
2432+
let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(1));
2433+
let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(2));
2434+
2435+
let db = relational_db()?;
2436+
let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());
2437+
2438+
let u_id = db.create_table_for_test_with_the_works(
2439+
"u",
2440+
&[
2441+
("i", AlgebraicType::U64),
2442+
("a", AlgebraicType::U64),
2443+
("b", AlgebraicType::U64),
2444+
],
2445+
&[0.into()],
2446+
// The join column for this table does not have to be unique,
2447+
// because pruning only requires us to probe the join index on `v`.
2448+
&[],
2449+
StAccess::Public,
2450+
)?;
2451+
let v_id = db.create_table_for_test_with_the_works(
2452+
"v",
2453+
&[
2454+
("i", AlgebraicType::U64),
2455+
("x", AlgebraicType::U64),
2456+
("y", AlgebraicType::U64),
2457+
],
2458+
&[0.into(), 1.into()],
2459+
&[0.into()],
2460+
StAccess::Public,
2461+
)?;
2462+
2463+
commit_tx(&db, &subs, [], [(v_id, product![1u64, 1u64, 1u64])])?;
2464+
2465+
let mut query_ids = 0;
2466+
23732467
subscribe_multi(
23742468
&subs,
23752469
&["select u.* from u join v on u.i = v.i where v.x = 1"],
2470+
tx_for_a,
2471+
&mut query_ids,
2472+
)?;
2473+
subscribe_multi(
2474+
&subs,
2475+
&["select u.* from u join v on u.i = v.i where v.x = 1"],
23762476
tx_for_b.clone(),
23772477
&mut query_ids,
23782478
)?;

crates/core/src/subscription/module_subscription_manager.rs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ impl QueriedTableIndexIds {
399399
/// See [`JoinEdge`] for more details.
400400
#[derive(Debug, Default)]
401401
pub struct JoinEdges {
402-
edges: BTreeMap<JoinEdge, HashMap<AlgebraicValue, QueryHash>>,
402+
edges: BTreeMap<JoinEdge, HashMap<AlgebraicValue, HashSet<QueryHash>>>,
403403
}
404404

405405
impl JoinEdges {
@@ -408,18 +408,28 @@ impl JoinEdges {
408408
let mut inserted = false;
409409
for (edge, rhs_val) in qs.query.join_edges() {
410410
inserted = true;
411-
self.edges.entry(edge).or_default().insert(rhs_val, qs.query.hash);
411+
self.edges
412+
.entry(edge)
413+
.or_default()
414+
.entry(rhs_val)
415+
.or_default()
416+
.insert(qs.query.hash);
412417
}
413418
inserted
414419
}
415420

416421
/// If this query has any join edges, remove them from the map.
417422
fn remove_query(&mut self, query: &Query) {
418423
for (edge, rhs_val) in query.join_edges() {
419-
if let Some(hashes) = self.edges.get_mut(&edge) {
420-
hashes.remove(&rhs_val);
421-
if hashes.is_empty() {
422-
self.edges.remove(&edge);
424+
if let Some(values) = self.edges.get_mut(&edge) {
425+
if let Some(hashes) = values.get_mut(&rhs_val) {
426+
hashes.remove(&query.hash);
427+
if hashes.is_empty() {
428+
values.remove(&rhs_val);
429+
if values.is_empty() {
430+
self.edges.remove(&edge);
431+
}
432+
}
423433
}
424434
}
425435
}
@@ -436,6 +446,7 @@ impl JoinEdges {
436446
self.edges
437447
.range(JoinEdge::range_for_table(table_id))
438448
.filter_map(move |(edge, hashes)| find_rhs_val(edge, row).as_ref().and_then(|rhs_val| hashes.get(rhs_val)))
449+
.flatten()
439450
}
440451
}
441452

0 commit comments

Comments
 (0)