Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions crates/core/src/subscription/module_subscription_actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1534,6 +1534,83 @@ mod tests {
Ok(())
}

/// Test that a client and the database owner can subscribe to the same query
#[tokio::test]
async fn test_rls_for_owner() -> anyhow::Result<()> {
// Establish a connection for owner and client
let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(0));
let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(1));

let db = relational_db()?;
let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());

// Create table `t`
let table_id = db.create_table_for_test("t", &[("id", AlgebraicType::identity())], &[0.into()])?;

// Restrict access to `t`
insert_rls_rules(&db, [table_id], ["select * from t where id = :sender"])?;

let mut query_ids = 0;

// Have owner and client subscribe to `t`
subscribe_multi(&subs, &["select * from t"], tx_for_a, &mut query_ids)?;
subscribe_multi(&subs, &["select * from t"], tx_for_b, &mut query_ids)?;

// Wait for both subscriptions
assert_matches!(
rx_for_a.recv().await,
Some(SerializableMessage::Subscription(SubscriptionMessage {
result: SubscriptionResult::SubscribeMulti(_),
..
}))
);
assert_matches!(
rx_for_b.recv().await,
Some(SerializableMessage::Subscription(SubscriptionMessage {
result: SubscriptionResult::SubscribeMulti(_),
..
}))
);

let schema = ProductType::from([AlgebraicType::identity()]);

let id_for_b = identity_from_u8(1);
let id_for_c = identity_from_u8(2);

commit_tx(
&db,
&subs,
[],
[
// Insert an identity for client `b` plus a random identity
(table_id, product![id_for_b]),
(table_id, product![id_for_c]),
],
)?;

assert_tx_update_for_table(
&mut rx_for_a,
table_id,
&schema,
// The owner should receive both identities
[product![id_for_b], product![id_for_c]],
[],
)
.await;

assert_tx_update_for_table(
&mut rx_for_b,
table_id,
&schema,
// Client `b` should only receive its identity
[product![id_for_b]],
[],
)
.await;

Ok(())
}

/// Test that we do not send empty updates to clients
#[tokio::test]
async fn test_no_empty_updates() -> anyhow::Result<()> {
Expand Down
6 changes: 5 additions & 1 deletion crates/core/src/subscription/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,11 @@ pub fn compile_query_with_hashes(
let tx = SchemaViewer::new(tx, auth);
let (plans, has_param) = SubscriptionPlan::compile(input, &tx, auth)?;

if has_param {
if auth.is_owner() || has_param {
// Note that when generating hashes for queries from owners,
// we always treat them as if they were parameterized by :sender.
// This is because RLS is not applicable to owners.
// Hence owner hashes must never overlap with client hashes.
return Ok(Plan::new(plans, hash_with_param, input.to_owned()));
}
Ok(Plan::new(plans, hash, input.to_owned()))
Expand Down
26 changes: 18 additions & 8 deletions crates/core/src/subscription/subscription.rs
Original file line number Diff line number Diff line change
Expand Up @@ -614,13 +614,25 @@ pub(crate) fn get_all(relational_db: &RelationalDB, tx: &Tx, auth: &AuthCtx) ->
.get_all_tables(tx)?
.iter()
.map(Deref::deref)
.filter(|t| {
t.table_type == StTableType::User && (auth.owner == auth.caller || t.table_access == StAccess::Public)
})
.filter(|t| t.table_type == StTableType::User && (auth.is_owner() || t.table_access == StAccess::Public))
.map(|schema| {
let sql = format!("SELECT * FROM {}", schema.table_name);
SubscriptionPlan::compile(&sql, &SchemaViewer::new(tx, auth), auth)
.map(|(plans, has_param)| Plan::new(plans, QueryHash::from_string(&sql, auth.caller, has_param), sql))
let tx = SchemaViewer::new(tx, auth);
SubscriptionPlan::compile(&sql, &tx, auth).map(|(plans, has_param)| {
Plan::new(
plans,
QueryHash::from_string(
&sql,
auth.caller,
// Note that when generating hashes for queries from owners,
// we always treat them as if they were parameterized by :sender.
// This is because RLS is not applicable to owners.
// Hence owner hashes must never overlap with client hashes.
auth.is_owner() || has_param,
),
sql,
)
})
})
.collect::<Result<_, _>>()?)
}
Expand All @@ -638,9 +650,7 @@ pub(crate) fn legacy_get_all(
.get_all_tables(tx)?
.iter()
.map(Deref::deref)
.filter(|t| {
t.table_type == StTableType::User && (auth.owner == auth.caller || t.table_access == StAccess::Public)
})
.filter(|t| t.table_type == StTableType::User && (auth.is_owner() || t.table_access == StAccess::Public))
.map(|src| SupportedQuery {
kind: query::Supported::Select,
expr: QueryExpr::new(src),
Expand Down
4 changes: 2 additions & 2 deletions crates/expr/src/rls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub fn resolve_views_for_sub(
has_param: &mut bool,
) -> anyhow::Result<Vec<ProjectName>> {
// RLS does not apply to the database owner
if auth.caller == auth.owner {
if auth.is_owner() {
return Ok(vec![expr]);
}

Expand Down Expand Up @@ -56,7 +56,7 @@ pub fn resolve_views_for_sub(
/// Mainly a wrapper around [resolve_views_for_expr].
pub fn resolve_views_for_sql(tx: &impl SchemaView, expr: ProjectList, auth: &AuthCtx) -> anyhow::Result<ProjectList> {
// RLS does not apply to the database owner
if auth.caller == auth.owner {
if auth.is_owner() {
return Ok(expr);
}
// The subscription language is a subset of the sql language.
Expand Down
4 changes: 4 additions & 0 deletions crates/lib/src/identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ impl AuthCtx {
pub fn for_current(owner: Identity) -> Self {
Self { owner, caller: owner }
}
/// Does `owner == caller`
pub fn is_owner(&self) -> bool {
self.owner == self.caller
}
/// WARNING: Use this only for simple test were the `auth` don't matter
pub fn for_testing() -> Self {
AuthCtx {
Expand Down
Loading