11use super :: execution_unit:: QueryHash ;
22use super :: module_subscription_manager:: { Plan , SubscriptionGaugeStats , SubscriptionManager } ;
3- use super :: query:: compile_read_only_query ;
3+ use super :: query:: compile_query_with_hashes ;
44use super :: tx:: DeltaTx ;
55use super :: { collect_table_update, record_exec_metrics, TableUpdateType } ;
66use crate :: client:: messages:: {
@@ -16,8 +16,8 @@ use crate::estimation::estimate_rows_scanned;
1616use crate :: execution_context:: { Workload , WorkloadType } ;
1717use crate :: host:: module_host:: { DatabaseUpdate , EventStatus , ModuleEvent } ;
1818use crate :: messages:: websocket:: Subscribe ;
19- use crate :: sql:: ast:: SchemaViewer ;
2019use crate :: subscription:: execute_plans;
20+ use crate :: subscription:: query:: is_subscribe_to_all_tables;
2121use crate :: vm:: check_row_limit;
2222use crate :: worker_metrics:: WORKER_METRICS ;
2323use parking_lot:: RwLock ;
@@ -27,8 +27,6 @@ use spacetimedb_client_api_messages::websocket::{
2727 UnsubscribeMulti ,
2828} ;
2929use spacetimedb_execution:: pipelined:: PipelinedProject ;
30- use spacetimedb_expr:: check:: parse_and_type_sub;
31- use spacetimedb_expr:: errors:: TypingError ;
3230use spacetimedb_lib:: identity:: AuthCtx ;
3331use spacetimedb_lib:: metrics:: ExecutionMetrics ;
3432use spacetimedb_lib:: Identity ;
@@ -105,6 +103,20 @@ type FullSubscriptionUpdate = FormatSwitch<ws::DatabaseUpdate<BsatnFormat>, ws::
105103
106104/// A utility for sending an error message to a client and returning early
107105macro_rules! return_on_err {
106+ ( $expr: expr, $handler: expr) => {
107+ match $expr {
108+ Ok ( val) => val,
109+ Err ( e) => {
110+ // TODO: Handle errors sending messages.
111+ let _ = $handler( e. to_string( ) . into( ) ) ;
112+ return Ok ( ( ) ) ;
113+ }
114+ }
115+ } ;
116+ }
117+
118+ /// A utility for sending an error message to a client and returning early
119+ macro_rules! return_on_err_with_sql {
108120 ( $expr: expr, $sql: expr, $handler: expr) => {
109121 match $expr. map_err( |err| DBError :: WithSql {
110122 sql: $sql. into( ) ,
@@ -120,12 +132,6 @@ macro_rules! return_on_err {
120132 } ;
121133}
122134
123- /// Hash a sql query, using the caller's identity if necessary
124- fn hash_query ( sql : & str , tx : & TxId , auth : & AuthCtx ) -> Result < QueryHash , TypingError > {
125- parse_and_type_sub ( sql, & SchemaViewer :: new ( tx, auth) , auth)
126- . map ( |( _, has_param) | QueryHash :: from_string ( sql, auth. caller , has_param) )
127- }
128-
129135impl ModuleSubscriptions {
130136 pub fn new ( relational_db : Arc < RelationalDB > , subscriptions : Subscriptions , owner_identity : Identity ) -> Self {
131137 let stats = Box :: new ( SubscriptionGauges :: new ( & relational_db. database_identity ( ) ) ) ;
@@ -248,29 +254,34 @@ impl ModuleSubscriptions {
248254 } )
249255 } ;
250256
257+ let sql = request. query ;
258+ let auth = AuthCtx :: new ( self . owner_identity , sender. id . identity ) ;
259+ let hash = QueryHash :: from_string ( & sql, auth. caller , false ) ;
260+ let hash_with_param = QueryHash :: from_string ( & sql, auth. caller , true ) ;
261+
251262 let tx = scopeguard:: guard ( self . relational_db . begin_tx ( Workload :: Subscribe ) , |tx| {
252263 self . relational_db . release_tx ( tx) ;
253264 } ) ;
254- let auth = AuthCtx :: new ( self . owner_identity , sender. id . identity ) ;
255- let query = super :: query:: WHITESPACE . replace_all ( & request. query , " " ) ;
256- let sql = query. trim ( ) ;
257-
258- let hash = return_on_err ! ( hash_query( sql, & tx, & auth) , sql, send_err_msg) ;
259265
260266 let existing_query = {
261267 let guard = self . subscriptions . read ( ) ;
262268 guard. query ( & hash)
263269 } ;
264270
265- let query = return_on_err ! (
266- existing_query
267- . map( Ok )
268- . unwrap_or_else( || compile_read_only_query( & auth, & tx, sql) . map( Arc :: new) ) ,
271+ let query = return_on_err_with_sql ! (
272+ existing_query. map( Ok ) . unwrap_or_else( || compile_query_with_hashes(
273+ & auth,
274+ & tx,
275+ & sql,
276+ hash,
277+ hash_with_param
278+ )
279+ . map( Arc :: new) ) ,
269280 sql,
270281 send_err_msg
271282 ) ;
272283
273- let ( table_rows, metrics) = return_on_err ! (
284+ let ( table_rows, metrics) = return_on_err_with_sql ! (
274285 self . evaluate_initial_subscription( sender. clone( ) , query. clone( ) , & tx, & auth, TableUpdateType :: Subscribe ) ,
275286 query. sql( ) ,
276287 send_err_msg
@@ -356,7 +367,7 @@ impl ModuleSubscriptions {
356367 self . relational_db . release_tx ( tx) ;
357368 } ) ;
358369 let auth = AuthCtx :: new ( self . owner_identity , sender. id . identity ) ;
359- let ( table_rows, metrics) = return_on_err ! (
370+ let ( table_rows, metrics) = return_on_err_with_sql ! (
360371 self . evaluate_initial_subscription( sender. clone( ) , query. clone( ) , & tx, & auth, TableUpdateType :: Unsubscribe ) ,
361372 query. sql( ) ,
362373 send_err_msg
@@ -452,6 +463,74 @@ impl ModuleSubscriptions {
452463 Ok ( ( ) )
453464 }
454465
466+ /// Compiles the queries in a [Subscribe] or [SubscribeMulti] message.
467+ ///
468+ /// Note, we hash queries to avoid recompilation,
469+ /// but we need to know if a query is parameterized in order to hash it correctly.
470+ /// This requires that we type check which in turn requires that we start a tx.
471+ ///
472+ /// Unfortunately parsing with sqlparser is quite expensive,
473+ /// so we'd like to avoid that cost while holding the tx lock,
474+ /// especially since all we're trying to do is generate a hash.
475+ ///
476+ /// Instead we generate two hashes and outside of the tx lock.
477+ /// If either one is currently tracked, we can avoid recompilation.
478+ fn compile_queries (
479+ & self ,
480+ sender : Identity ,
481+ queries : impl IntoIterator < Item = Box < str > > ,
482+ num_queries : usize ,
483+ ) -> Result < ( Vec < Arc < Plan > > , AuthCtx , TxId ) , DBError > {
484+ let mut subscribe_to_all_tables = false ;
485+ let mut plans = Vec :: with_capacity ( num_queries) ;
486+ let mut query_hashes = Vec :: with_capacity ( num_queries) ;
487+
488+ for sql in queries {
489+ if is_subscribe_to_all_tables ( & sql) {
490+ subscribe_to_all_tables = true ;
491+ continue ;
492+ }
493+ let hash = QueryHash :: from_string ( & sql, sender, false ) ;
494+ let hash_with_param = QueryHash :: from_string ( & sql, sender, true ) ;
495+ query_hashes. push ( ( sql, hash, hash_with_param) ) ;
496+ }
497+
498+ let auth = AuthCtx :: new ( self . owner_identity , sender) ;
499+
500+ // We always get the db lock before the subscription lock to avoid deadlocks.
501+ let tx = scopeguard:: guard ( self . relational_db . begin_tx ( Workload :: Subscribe ) , |tx| {
502+ self . relational_db . release_tx ( tx) ;
503+ } ) ;
504+ let guard = self . subscriptions . read ( ) ;
505+
506+ if subscribe_to_all_tables {
507+ plans. extend (
508+ super :: subscription:: get_all ( & self . relational_db , & tx, & auth) ?
509+ . into_iter ( )
510+ . map ( Arc :: new) ,
511+ ) ;
512+ }
513+
514+ for ( sql, hash, hash_with_param) in query_hashes {
515+ if let Some ( unit) = guard. query ( & hash) {
516+ plans. push ( unit) ;
517+ } else if let Some ( unit) = guard. query ( & hash_with_param) {
518+ plans. push ( unit) ;
519+ } else {
520+ plans. push ( Arc :: new (
521+ compile_query_with_hashes ( & auth, & tx, & sql, hash, hash_with_param) . map_err ( |err| {
522+ DBError :: WithSql {
523+ error : Box :: new ( DBError :: Other ( err. into ( ) ) ) ,
524+ sql,
525+ }
526+ } ) ?,
527+ ) ) ;
528+ }
529+ }
530+
531+ Ok ( ( plans, auth, scopeguard:: ScopeGuard :: into_inner ( tx) ) )
532+ }
533+
455534 #[ tracing:: instrument( level = "trace" , skip_all) ]
456535 pub fn add_multi_subscription (
457536 & self ,
@@ -473,39 +552,14 @@ impl ModuleSubscriptions {
473552 } ) ;
474553 } ;
475554
476- // We always get the db lock before the subscription lock to avoid deadlocks.
477- let tx = scopeguard:: guard ( self . relational_db . begin_tx ( Workload :: Subscribe ) , |tx| {
555+ let num_queries = request. query_strings . len ( ) ;
556+ let ( queries, auth, tx) = return_on_err ! (
557+ self . compile_queries( sender. id. identity, request. query_strings, num_queries) ,
558+ send_err_msg
559+ ) ;
560+ let tx = scopeguard:: guard ( tx, |tx| {
478561 self . relational_db . release_tx ( tx) ;
479562 } ) ;
480- let auth = AuthCtx :: new ( self . owner_identity , sender. id . identity ) ;
481- let mut queries = vec ! [ ] ;
482- let guard = self . subscriptions . read ( ) ;
483- for sql in request
484- . query_strings
485- . iter ( )
486- . map ( |sql| super :: query:: WHITESPACE . replace_all ( sql, " " ) )
487- {
488- let sql = sql. trim ( ) ;
489- if sql == super :: query:: SUBSCRIBE_TO_ALL_QUERY {
490- queries. extend (
491- super :: subscription:: get_all ( & self . relational_db , & tx, & auth) ?
492- . into_iter ( )
493- . map ( Arc :: new) ,
494- ) ;
495- continue ;
496- }
497-
498- let hash = return_on_err ! ( hash_query( sql, & tx, & auth) , sql, send_err_msg) ;
499-
500- if let Some ( unit) = guard. query ( & hash) {
501- queries. push ( unit) ;
502- } else {
503- let compiled = return_on_err ! ( compile_read_only_query( & auth, & tx, sql) , sql, send_err_msg) ;
504- queries. push ( Arc :: new ( compiled) ) ;
505- }
506- }
507-
508- drop ( guard) ;
509563
510564 // We minimize locking so that other clients can add subscriptions concurrently.
511565 // We are protected from race conditions with broadcasts, because we have the db lock,
@@ -561,40 +615,11 @@ impl ModuleSubscriptions {
561615 timer : Instant ,
562616 _assert : Option < AssertTxFn > ,
563617 ) -> Result < ( ) , DBError > {
564- let tx = scopeguard:: guard ( self . relational_db . begin_tx ( Workload :: Subscribe ) , |tx| {
618+ let num_queries = subscription. query_strings . len ( ) ;
619+ let ( queries, auth, tx) = self . compile_queries ( sender. id . identity , subscription. query_strings , num_queries) ?;
620+ let tx = scopeguard:: guard ( tx, |tx| {
565621 self . relational_db . release_tx ( tx) ;
566622 } ) ;
567- let request_id = subscription. request_id ;
568- let auth = AuthCtx :: new ( self . owner_identity , sender. id . identity ) ;
569- let mut queries = vec ! [ ] ;
570-
571- let guard = self . subscriptions . read ( ) ;
572-
573- for sql in subscription
574- . query_strings
575- . iter ( )
576- . map ( |sql| super :: query:: WHITESPACE . replace_all ( sql, " " ) )
577- {
578- let sql = sql. trim ( ) ;
579- if sql == super :: query:: SUBSCRIBE_TO_ALL_QUERY {
580- queries. extend (
581- super :: subscription:: get_all ( & self . relational_db , & tx, & auth) ?
582- . into_iter ( )
583- . map ( Arc :: new) ,
584- ) ;
585- continue ;
586- }
587-
588- let hash = hash_query ( sql, & tx, & auth) ?;
589- if let Some ( unit) = guard. query ( & hash) {
590- queries. push ( unit) ;
591- } else {
592- let compiled = compile_read_only_query ( & auth, & tx, sql) ?;
593- queries. push ( Arc :: new ( compiled) ) ;
594- }
595- }
596-
597- drop ( guard) ;
598623
599624 check_row_limit (
600625 & queries,
@@ -639,7 +664,7 @@ impl ModuleSubscriptions {
639664 // on the wire
640665 let _ = sender. send_message ( SubscriptionUpdateMessage {
641666 database_update,
642- request_id : Some ( request_id) ,
667+ request_id : Some ( subscription . request_id ) ,
643668 timer : Some ( timer) ,
644669 } ) ;
645670 Ok ( ( ) )
0 commit comments