@@ -31,15 +31,13 @@ use functions::session_params::{SessionParams, SessionProperty};
3131use functions:: table:: register_udtfs;
3232use snafu:: ResultExt ;
3333#[ cfg( feature = "state-store" ) ]
34- use state_store:: { StateStore , Variable } ;
34+ use state_store:: { SessionRecord , StateStore , Variable } ;
3535use std:: collections:: { HashMap , VecDeque } ;
3636use std:: num:: NonZero ;
3737use std:: sync:: atomic:: AtomicI64 ;
3838use std:: sync:: { Arc , RwLock } ;
3939use std:: thread:: available_parallelism;
4040use time:: { Duration , OffsetDateTime } ;
41- #[ cfg( feature = "state-store" ) ]
42- use tracing:: warn;
4341
4442pub const SESSION_INACTIVITY_EXPIRATION_SECONDS : i64 = 5 * 60 ;
4543static MINIMUM_PARALLEL_OUTPUT_FILES : usize = 1 ;
@@ -54,7 +52,7 @@ pub const fn to_unix(t: OffsetDateTime) -> i64 {
5452pub struct UserSession {
5553 pub metastore : Arc < dyn Metastore > ,
5654 #[ cfg( feature = "state-store" ) ]
57- state_store : Arc < StateStore > ,
55+ state_store : Arc < dyn StateStore > ,
5856 // running_queries contains all the queries running across sessions
5957 pub running_queries : Arc < dyn RunningQueries > ,
6058 pub ctx : SessionContext ,
@@ -65,6 +63,7 @@ pub struct UserSession {
6563 pub expiry : AtomicI64 ,
6664 pub session_params : Arc < SessionParams > ,
6765 pub recent_queries : Arc < RwLock < VecDeque < QueryId > > > ,
66+ pub session_id : String ,
6867}
6968
7069impl UserSession {
@@ -75,8 +74,8 @@ impl UserSession {
7574 config : Arc < Config > ,
7675 catalog_list : Arc < EmbucketCatalogList > ,
7776 runtime_env : Arc < RuntimeEnv > ,
78- # [ cfg ( feature = "state-store" ) ] session_id : & str ,
79- #[ cfg( feature = "state-store" ) ] state_store : Arc < StateStore > ,
77+ session_id : & str ,
78+ #[ cfg( feature = "state-store" ) ] state_store : Arc < dyn StateStore > ,
8079 ) -> Result < Self > {
8180 let sql_parser_dialect = config
8281 . sql_parser_dialect
@@ -91,7 +90,8 @@ impl UserSession {
9190
9291 let session_params = SessionParams :: default ( ) ;
9392 #[ cfg( feature = "state-store" ) ]
94- let session_params_arc = Self :: session_params ( session_id, state_store. clone ( ) ) . await ;
93+ let session_params_arc =
94+ Arc :: new ( Self :: session_params ( session_id, state_store. clone ( ) ) . await ) ;
9595 #[ cfg( not( feature = "state-store" ) ) ]
9696 let session_params_arc = Arc :: new ( session_params. clone ( ) ) ;
9797 let mut config_options = ConfigOptions :: from_env ( ) . context ( ex_error:: DataFusionSnafu ) ?;
@@ -159,52 +159,79 @@ impl UserSession {
159159 ) ) ,
160160 session_params : session_params_arc,
161161 recent_queries : Arc :: new ( RwLock :: new ( VecDeque :: new ( ) ) ) ,
162+ session_id : session_id. to_string ( ) ,
162163 } ;
163164 Ok ( session)
164165 }
165166
166167 #[ cfg( feature = "state-store" ) ]
167168 pub async fn session_params (
168169 session_id : & str ,
169- state_store : Arc < StateStore > ,
170- ) -> Arc < SessionParams > {
170+ state_store : Arc < dyn StateStore > ,
171+ ) -> SessionParams {
171172 let session_params = SessionParams :: default ( ) ;
172173 #[ cfg( feature = "state-store" ) ]
173- match state_store
174- . get_session ( session_id)
174+ if let Some ( params) = Self :: get_session_state_params ( session_id, state_store) . await {
175+ session_params. set_properties ( params) ;
176+ }
177+ session_params
178+ }
179+
180+ #[ cfg( feature = "state-store" ) ]
181+ pub async fn get_session_state_params (
182+ session_id : & str ,
183+ state_store : Arc < dyn StateStore > ,
184+ ) -> Option < HashMap < String , SessionProperty > > {
185+ state_store. get_session ( session_id) . await . ok ( ) . map ( |sr| {
186+ sr. variables
187+ . into_iter ( )
188+ . map ( |( n, v) | ( n, state_store_variable_to_property ( v, session_id) ) )
189+ . collect ( )
190+ } )
191+ }
192+
193+ #[ cfg( feature = "state-store" ) ]
194+ pub async fn set_session_state_params (
195+ & self ,
196+ set : bool ,
197+ params : HashMap < String , SessionProperty > ,
198+ ) -> Result < ( ) > {
199+ let mut session_record =
200+ if let Ok ( session) = self . state_store . get_session ( & self . session_id ) . await {
201+ session
202+ } else {
203+ SessionRecord :: new ( & self . session_id )
204+ } ;
205+ let current_params: HashMap < String , SessionProperty > = session_record
206+ . variables
207+ . into_iter ( )
208+ . map ( |( n, v) | ( n, state_store_variable_to_property ( v, & self . session_id ) ) )
209+ . collect ( ) ;
210+ let session_params = SessionParams :: default ( ) ;
211+ session_params. set_properties ( current_params) ;
212+
213+ if set {
214+ session_params. set_properties ( params) ;
215+ } else {
216+ session_params. remove_properties ( params) ;
217+ }
218+ session_record. variables = session_params_to_state_variables ( & session_params) ;
219+ self . state_store
220+ . put_session ( session_record)
175221 . await
176222 . context ( ex_error:: StateStoreSnafu )
177- {
178- Ok ( session) => {
179- let params = session
180- . variables
181- . into_iter ( )
182- . map ( |v| {
183- (
184- v. name . clone ( ) ,
185- state_store_variable_to_property ( v, session_id) ,
186- )
187- } )
188- . collect ( ) ;
189- session_params. set_properties ( params) ;
190- }
191- Err ( _) => {
192- warn ! ( "Failed to retrieve session from state store for {session_id}" ) ;
193- }
194- }
195- Arc :: new ( session_params. clone ( ) )
196223 }
197224
198- pub fn set_database ( & self , database : & str ) -> Result < ( ) > {
199- self . set_variable ( "database" , database)
225+ pub async fn set_database ( & self , database : & str ) -> Result < ( ) > {
226+ self . set_variable ( "database" , database) . await
200227 }
201228
202- pub fn set_schema ( & self , schema : & str ) -> Result < ( ) > {
203- self . set_variable ( "schema" , schema)
229+ pub async fn set_schema ( & self , schema : & str ) -> Result < ( ) > {
230+ self . set_variable ( "schema" , schema) . await
204231 }
205232
206- pub fn set_warehouse ( & self , warehouse : & str ) -> Result < ( ) > {
207- self . set_variable ( "warehouse" , warehouse)
233+ pub async fn set_warehouse ( & self , warehouse : & str ) -> Result < ( ) > {
234+ self . set_variable ( "warehouse" , warehouse) . await
208235 }
209236
210237 #[ tracing:: instrument(
@@ -213,17 +240,19 @@ impl UserSession {
213240 skip( self ) ,
214241 err
215242 ) ]
216- fn set_variable ( & self , key : & str , value : & str ) -> Result < ( ) > {
243+ async fn set_variable ( & self , key : & str , value : & str ) -> Result < ( ) > {
217244 if key. is_empty ( ) || value. is_empty ( ) {
218245 return ex_error:: OnyUseWithVariablesSnafu . fail ( ) ;
219246 }
220- let session_id = self . ctx . session_id ( ) ;
221247 let params = HashMap :: from ( [ (
222248 key. to_string ( ) ,
223- SessionProperty :: from_str_value ( key. to_string ( ) , value. to_string ( ) , Some ( session_id) ) ,
249+ SessionProperty :: from_str_value (
250+ key. to_string ( ) ,
251+ value. to_string ( ) ,
252+ Some ( self . session_id . clone ( ) ) ,
253+ ) ,
224254 ) ] ) ;
225- self . set_session_variable ( true , params) ?;
226- Ok ( ( ) )
255+ self . set_session_variable ( true , params) . await
227256 }
228257
229258 pub fn query < S > ( self : & Arc < Self > , query : S , query_context : QueryContext ) -> UserQuery
@@ -233,11 +262,15 @@ impl UserSession {
233262 UserQuery :: new ( self . clone ( ) , query. into ( ) , query_context)
234263 }
235264
236- pub fn set_session_variable (
265+ #[ allow( clippy:: unused_async) ]
266+ pub async fn set_session_variable (
237267 & self ,
238268 set : bool ,
239269 params : HashMap < String , SessionProperty > ,
240270 ) -> Result < ( ) > {
271+ #[ cfg( feature = "state-store" ) ]
272+ self . set_session_state_params ( set, params. clone ( ) ) . await ?;
273+
241274 let state = self . ctx . state_ref ( ) ;
242275 let mut write = state. write ( ) ;
243276
@@ -263,8 +296,7 @@ impl UserSession {
263296 if set {
264297 cfg. set_properties ( session_params) ;
265298 } else {
266- cfg. remove_properties ( session_params)
267- . context ( ex_error:: DataFusionSnafu ) ?;
299+ cfg. remove_properties ( session_params) ;
268300 }
269301 }
270302 Ok ( ( ) )
@@ -315,7 +347,7 @@ pub fn parse_bool(value: &str) -> Option<bool> {
315347}
316348
317349#[ cfg( feature = "state-store" ) ]
318- #[ allow( clippy:: as_conversions) ]
350+ #[ allow( clippy:: as_conversions, clippy :: cast_possible_wrap ) ]
319351pub fn state_store_variable_to_property ( var : Variable , session_id : & str ) -> SessionProperty {
320352 SessionProperty {
321353 session_id : Some ( session_id. to_string ( ) ) ,
@@ -334,3 +366,37 @@ pub fn state_store_variable_to_property(var: Variable, session_id: &str) -> Sess
334366 name : var. name ,
335367 }
336368}
369+
370+ #[ cfg( feature = "state-store" ) ]
371+ #[ allow(
372+ clippy:: as_conversions,
373+ clippy:: cast_possible_wrap,
374+ clippy:: cast_sign_loss
375+ ) ]
376+ #[ must_use]
377+ pub fn session_params_to_state_variables (
378+ session_params : & SessionParams ,
379+ ) -> HashMap < String , Variable > {
380+ session_params
381+ . properties
382+ . iter ( )
383+ . map ( |entry| {
384+ let key = entry. key ( ) . clone ( ) ;
385+ let prop = entry. value ( ) . clone ( ) ;
386+
387+ let created_at = prop. created_on . timestamp ( ) as u64 ;
388+ let updated_at = Some ( prop. updated_on . timestamp ( ) as u64 ) ;
389+
390+ let var = Variable {
391+ name : prop. name ,
392+ value : prop. value ,
393+ value_type : prop. property_type ,
394+ comment : prop. comment ,
395+ created_at,
396+ updated_at,
397+ } ;
398+
399+ ( key, var)
400+ } )
401+ . collect ( )
402+ }
0 commit comments