Skip to content

Commit abb4fb4

Browse files
committed
Add trait
1 parent 3bf1bda commit abb4fb4

13 files changed

Lines changed: 164 additions & 80 deletions

File tree

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,4 @@ jobs:
5858
- name: cargo test
5959
env:
6060
COLUMNS: "80"
61-
run: RUSTFLAGS="-C linker=clang -C link-arg=-fuse-ld=lld" cargo test --profile=ci --workspace --all-features --all-targets
61+
run: RUSTFLAGS="-C linker=clang -C link-arg=-fuse-ld=lld" cargo test --profile=ci --workspace --default-features --all-targets

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/api-snowflake-rest/src/server/logic.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,18 +58,20 @@ pub async fn handle_login_request(
5858

5959
// set database, schema when provided
6060
if let Some(db) = params.database_name {
61-
session.set_database(&db).context(SetVariableSnafu {
61+
session.set_database(&db).await.context(SetVariableSnafu {
6262
variable: "database",
6363
})?;
6464
}
6565
if let Some(schema) = params.schema_name {
6666
session
6767
.set_schema(&schema)
68+
.await
6869
.context(SetVariableSnafu { variable: "schema" })?;
6970
}
7071
if let Some(warehouse) = params.warehouse {
7172
session
7273
.set_warehouse(&warehouse)
74+
.await
7375
.context(SetVariableSnafu {
7476
variable: "warehouse",
7577
})?;

crates/executor/src/query.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ impl UserQuery {
283283
})
284284
.collect();
285285

286-
self.session.set_session_variable(set, params)?;
286+
self.session.set_session_variable(set, params).await?;
287287
return self.status_response();
288288
}
289289
Statement::Use(entity) => {
@@ -432,7 +432,7 @@ impl UserQuery {
432432
Some(self.session.ctx.session_id()),
433433
),
434434
)]);
435-
self.session.set_session_variable(true, params)?;
435+
self.session.set_session_variable(true, params).await?;
436436
self.status_response()
437437
}
438438

@@ -489,7 +489,9 @@ impl UserQuery {
489489
}?;
490490
session_params.insert(key, session_value);
491491
}
492-
self.session.set_session_variable(true, session_params)?;
492+
self.session
493+
.set_session_variable(true, session_params)
494+
.await?;
493495
self.status_response()
494496
}
495497

@@ -538,7 +540,9 @@ impl UserQuery {
538540
session_params.insert(name, session_value);
539541
}
540542

541-
self.session.set_session_variable(true, session_params)?;
543+
self.session
544+
.set_session_variable(true, session_params)
545+
.await?;
542546
}
543547

544548
self.status_response()

crates/executor/src/service.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use crate::utils::{Config, MemPoolType};
3434
use catalog::catalog_list::EmbucketCatalogList;
3535
use catalog_metastore::{InMemoryMetastore, Metastore, TableIdent as MetastoreTableIdent};
3636
#[cfg(feature = "state-store")]
37-
use state_store::StateStore;
37+
use state_store::{DynamoDbStateStore, StateStore};
3838
use tokio::sync::RwLock;
3939
use tokio::time::Duration;
4040
use tracing::Instrument;
@@ -149,7 +149,7 @@ pub struct CoreExecutionService {
149149
runtime_env: Arc<RuntimeEnv>,
150150
queries: Arc<RunningQueriesRegistry>,
151151
#[cfg(feature = "state-store")]
152-
state_store: Arc<StateStore>,
152+
state_store: Arc<dyn StateStore>,
153153
}
154154

155155
impl CoreExecutionService {
@@ -165,7 +165,7 @@ impl CoreExecutionService {
165165
let catalog_list = Self::catalog_list(metastore.clone(), &config).await?;
166166
let runtime_env = Self::runtime_env(&config, catalog_list.clone())?;
167167
#[cfg(feature = "state-store")]
168-
let state_store = StateStore::new_from_env()
168+
let state_store = DynamoDbStateStore::new_from_env()
169169
.await
170170
.context(ex_error::StateStoreSnafu)?;
171171
Ok(Self {
@@ -283,7 +283,6 @@ impl ExecutionService for CoreExecutionService {
283283
self.config.clone(),
284284
self.catalog_list.clone(),
285285
self.runtime_env.clone(),
286-
#[cfg(feature = "state-store")]
287286
session_id,
288287
#[cfg(feature = "state-store")]
289288
self.state_store.clone(),
@@ -298,7 +297,7 @@ impl ExecutionService for CoreExecutionService {
298297

299298
#[cfg(feature = "state-store")]
300299
self.state_store
301-
.put_new_session(session_id.to_string())
300+
.put_new_session(session_id)
302301
.await
303302
.context(ex_error::StateStoreSnafu)?;
304303

crates/executor/src/session.rs

Lines changed: 111 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,13 @@ use functions::session_params::{SessionParams, SessionProperty};
3131
use functions::table::register_udtfs;
3232
use snafu::ResultExt;
3333
#[cfg(feature = "state-store")]
34-
use state_store::{StateStore, Variable};
34+
use state_store::{SessionRecord, StateStore, Variable};
3535
use std::collections::{HashMap, VecDeque};
3636
use std::num::NonZero;
3737
use std::sync::atomic::AtomicI64;
3838
use std::sync::{Arc, RwLock};
3939
use std::thread::available_parallelism;
4040
use time::{Duration, OffsetDateTime};
41-
#[cfg(feature = "state-store")]
42-
use tracing::warn;
4341

4442
pub const SESSION_INACTIVITY_EXPIRATION_SECONDS: i64 = 5 * 60;
4543
static MINIMUM_PARALLEL_OUTPUT_FILES: usize = 1;
@@ -54,7 +52,7 @@ pub const fn to_unix(t: OffsetDateTime) -> i64 {
5452
pub struct UserSession {
5553
pub metastore: Arc<dyn Metastore>,
5654
#[cfg(feature = "state-store")]
57-
state_store: Arc<StateStore>,
55+
state_store: Arc<dyn StateStore>,
5856
// running_queries contains all the queries running across sessions
5957
pub running_queries: Arc<dyn RunningQueries>,
6058
pub ctx: SessionContext,
@@ -65,6 +63,7 @@ pub struct UserSession {
6563
pub expiry: AtomicI64,
6664
pub session_params: Arc<SessionParams>,
6765
pub recent_queries: Arc<RwLock<VecDeque<QueryId>>>,
66+
pub session_id: String,
6867
}
6968

7069
impl UserSession {
@@ -75,8 +74,8 @@ impl UserSession {
7574
config: Arc<Config>,
7675
catalog_list: Arc<EmbucketCatalogList>,
7776
runtime_env: Arc<RuntimeEnv>,
78-
#[cfg(feature = "state-store")] session_id: &str,
79-
#[cfg(feature = "state-store")] state_store: Arc<StateStore>,
77+
session_id: &str,
78+
#[cfg(feature = "state-store")] state_store: Arc<dyn StateStore>,
8079
) -> Result<Self> {
8180
let sql_parser_dialect = config
8281
.sql_parser_dialect
@@ -91,7 +90,8 @@ impl UserSession {
9190

9291
let session_params = SessionParams::default();
9392
#[cfg(feature = "state-store")]
94-
let session_params_arc = Self::session_params(session_id, state_store.clone()).await;
93+
let session_params_arc =
94+
Arc::new(Self::session_params(session_id, state_store.clone()).await);
9595
#[cfg(not(feature = "state-store"))]
9696
let session_params_arc = Arc::new(session_params.clone());
9797
let mut config_options = ConfigOptions::from_env().context(ex_error::DataFusionSnafu)?;
@@ -159,52 +159,79 @@ impl UserSession {
159159
)),
160160
session_params: session_params_arc,
161161
recent_queries: Arc::new(RwLock::new(VecDeque::new())),
162+
session_id: session_id.to_string(),
162163
};
163164
Ok(session)
164165
}
165166

166167
#[cfg(feature = "state-store")]
167168
pub async fn session_params(
168169
session_id: &str,
169-
state_store: Arc<StateStore>,
170-
) -> Arc<SessionParams> {
170+
state_store: Arc<dyn StateStore>,
171+
) -> SessionParams {
171172
let session_params = SessionParams::default();
172173
#[cfg(feature = "state-store")]
173-
match state_store
174-
.get_session(session_id)
174+
if let Some(params) = Self::get_session_state_params(session_id, state_store).await {
175+
session_params.set_properties(params);
176+
}
177+
session_params
178+
}
179+
180+
#[cfg(feature = "state-store")]
181+
pub async fn get_session_state_params(
182+
session_id: &str,
183+
state_store: Arc<dyn StateStore>,
184+
) -> Option<HashMap<String, SessionProperty>> {
185+
state_store.get_session(session_id).await.ok().map(|sr| {
186+
sr.variables
187+
.into_iter()
188+
.map(|(n, v)| (n, state_store_variable_to_property(v, session_id)))
189+
.collect()
190+
})
191+
}
192+
193+
#[cfg(feature = "state-store")]
194+
pub async fn set_session_state_params(
195+
&self,
196+
set: bool,
197+
params: HashMap<String, SessionProperty>,
198+
) -> Result<()> {
199+
let mut session_record =
200+
if let Ok(session) = self.state_store.get_session(&self.session_id).await {
201+
session
202+
} else {
203+
SessionRecord::new(&self.session_id)
204+
};
205+
let current_params: HashMap<String, SessionProperty> = session_record
206+
.variables
207+
.into_iter()
208+
.map(|(n, v)| (n, state_store_variable_to_property(v, &self.session_id)))
209+
.collect();
210+
let session_params = SessionParams::default();
211+
session_params.set_properties(current_params);
212+
213+
if set {
214+
session_params.set_properties(params);
215+
} else {
216+
session_params.remove_properties(params);
217+
}
218+
session_record.variables = session_params_to_state_variables(&session_params);
219+
self.state_store
220+
.put_session(session_record)
175221
.await
176222
.context(ex_error::StateStoreSnafu)
177-
{
178-
Ok(session) => {
179-
let params = session
180-
.variables
181-
.into_iter()
182-
.map(|v| {
183-
(
184-
v.name.clone(),
185-
state_store_variable_to_property(v, session_id),
186-
)
187-
})
188-
.collect();
189-
session_params.set_properties(params);
190-
}
191-
Err(_) => {
192-
warn!("Failed to retrieve session from state store for {session_id}");
193-
}
194-
}
195-
Arc::new(session_params.clone())
196223
}
197224

198-
pub fn set_database(&self, database: &str) -> Result<()> {
199-
self.set_variable("database", database)
225+
pub async fn set_database(&self, database: &str) -> Result<()> {
226+
self.set_variable("database", database).await
200227
}
201228

202-
pub fn set_schema(&self, schema: &str) -> Result<()> {
203-
self.set_variable("schema", schema)
229+
pub async fn set_schema(&self, schema: &str) -> Result<()> {
230+
self.set_variable("schema", schema).await
204231
}
205232

206-
pub fn set_warehouse(&self, warehouse: &str) -> Result<()> {
207-
self.set_variable("warehouse", warehouse)
233+
pub async fn set_warehouse(&self, warehouse: &str) -> Result<()> {
234+
self.set_variable("warehouse", warehouse).await
208235
}
209236

210237
#[tracing::instrument(
@@ -213,17 +240,19 @@ impl UserSession {
213240
skip(self),
214241
err
215242
)]
216-
fn set_variable(&self, key: &str, value: &str) -> Result<()> {
243+
async fn set_variable(&self, key: &str, value: &str) -> Result<()> {
217244
if key.is_empty() || value.is_empty() {
218245
return ex_error::OnyUseWithVariablesSnafu.fail();
219246
}
220-
let session_id = self.ctx.session_id();
221247
let params = HashMap::from([(
222248
key.to_string(),
223-
SessionProperty::from_str_value(key.to_string(), value.to_string(), Some(session_id)),
249+
SessionProperty::from_str_value(
250+
key.to_string(),
251+
value.to_string(),
252+
Some(self.session_id.clone()),
253+
),
224254
)]);
225-
self.set_session_variable(true, params)?;
226-
Ok(())
255+
self.set_session_variable(true, params).await
227256
}
228257

229258
pub fn query<S>(self: &Arc<Self>, query: S, query_context: QueryContext) -> UserQuery
@@ -233,11 +262,15 @@ impl UserSession {
233262
UserQuery::new(self.clone(), query.into(), query_context)
234263
}
235264

236-
pub fn set_session_variable(
265+
#[allow(clippy::unused_async)]
266+
pub async fn set_session_variable(
237267
&self,
238268
set: bool,
239269
params: HashMap<String, SessionProperty>,
240270
) -> Result<()> {
271+
#[cfg(feature = "state-store")]
272+
self.set_session_state_params(set, params.clone()).await?;
273+
241274
let state = self.ctx.state_ref();
242275
let mut write = state.write();
243276

@@ -263,8 +296,7 @@ impl UserSession {
263296
if set {
264297
cfg.set_properties(session_params);
265298
} else {
266-
cfg.remove_properties(session_params)
267-
.context(ex_error::DataFusionSnafu)?;
299+
cfg.remove_properties(session_params);
268300
}
269301
}
270302
Ok(())
@@ -315,7 +347,7 @@ pub fn parse_bool(value: &str) -> Option<bool> {
315347
}
316348

317349
#[cfg(feature = "state-store")]
318-
#[allow(clippy::as_conversions)]
350+
#[allow(clippy::as_conversions, clippy::cast_possible_wrap)]
319351
pub fn state_store_variable_to_property(var: Variable, session_id: &str) -> SessionProperty {
320352
SessionProperty {
321353
session_id: Some(session_id.to_string()),
@@ -334,3 +366,37 @@ pub fn state_store_variable_to_property(var: Variable, session_id: &str) -> Sess
334366
name: var.name,
335367
}
336368
}
369+
370+
#[cfg(feature = "state-store")]
371+
#[allow(
372+
clippy::as_conversions,
373+
clippy::cast_possible_wrap,
374+
clippy::cast_sign_loss
375+
)]
376+
#[must_use]
377+
pub fn session_params_to_state_variables(
378+
session_params: &SessionParams,
379+
) -> HashMap<String, Variable> {
380+
session_params
381+
.properties
382+
.iter()
383+
.map(|entry| {
384+
let key = entry.key().clone();
385+
let prop = entry.value().clone();
386+
387+
let created_at = prop.created_on.timestamp() as u64;
388+
let updated_at = Some(prop.updated_on.timestamp() as u64);
389+
390+
let var = Variable {
391+
name: prop.name,
392+
value: prop.value,
393+
value_type: prop.property_type,
394+
comment: prop.comment,
395+
created_at,
396+
updated_at,
397+
};
398+
399+
(key, var)
400+
})
401+
.collect()
402+
}

0 commit comments

Comments
 (0)