Skip to content

Commit b5f9b56

Browse files
snowflake rest retries support (#24)
* replace oneshot by JoinHandle * switch to different execution model with results waiting until task joined * fix abort * simplify executor interface
1 parent 600c44d commit b5f9b56

14 files changed

Lines changed: 324 additions & 372 deletions

File tree

crates/api-snowflake-rest/src/server/error.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use datafusion::arrow::error::ArrowError;
66
use error_stack::ErrorChainExt;
77
use error_stack::ErrorExt;
88
use error_stack_trace;
9-
use executor::QueryRecordId;
9+
use executor::QueryId;
1010
use executor::error::OperationOn;
1111
use executor::error_code::ErrorCode;
1212
use executor::snowflake_error::Entity;
@@ -132,11 +132,11 @@ impl Error {
132132
InvalidAuthDataSnafu.build()
133133
}
134134

135-
pub fn query_id(&self) -> QueryRecordId {
135+
pub fn query_id(&self) -> QueryId {
136136
if let Self::Execution { source, .. } = self {
137137
source.query_id()
138138
} else {
139-
QueryRecordId::default()
139+
QueryId::default()
140140
}
141141
}
142142

@@ -236,8 +236,7 @@ impl Error {
236236
tracing::Span::current()
237237
.record("error_code", error_code.to_string())
238238
.record("sql_state", sql_state.to_string())
239-
.record("query_id", self.query_id().as_i64())
240-
.record("query_uuid", self.query_id().as_uuid().to_string())
239+
.record("query_id", self.query_id().to_string())
241240
.record("display_error", &display_error)
242241
.record("debug_error", self.debug_error_message())
243242
.record("error_stack_trace", self.output_msg())
@@ -260,7 +259,7 @@ impl Error {
260259
returned: None,
261260
query_result_format: None,
262261
// Query uuid is returned to the user
263-
query_id: Some(self.query_id().as_uuid().to_string()),
262+
query_id: Some(self.query_id().to_string()),
264263
}),
265264
code: Some(error_code.to_string()),
266265
});

crates/api-snowflake-rest/src/server/handlers.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ pub async fn login(
3434
name = "api_snowflake_rest::query",
3535
level = "debug",
3636
skip(state),
37-
fields(query_id, query_uuid),
37+
fields(query_id),
3838
err,
3939
ret(level = tracing::Level::TRACE),
4040
)]
@@ -65,9 +65,10 @@ pub async fn abort(
6565
request_id,
6666
}): Json<AbortRequestBody>,
6767
) -> Result<Json<serde_json::value::Value>> {
68-
state
68+
let query_id = state
6969
.execution_svc
70-
.abort_query(RunningQueryId::ByRequestId(request_id, sql_text))?;
70+
.locate_query_id(RunningQueryId::ByRequestId(request_id, sql_text))?;
71+
state.execution_svc.abort(query_id)?;
7172
Ok(Json(serde_json::value::Value::Null))
7273
}
7374

crates/api-snowflake-rest/src/server/helpers.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ fn records_to_json_string(recs: &[RecordBatch]) -> std::result::Result<String, E
7474
)]
7575
pub fn handle_query_ok_result(
7676
sql_text: &str,
77-
query_uuid: Uuid,
77+
query_id: Uuid,
7878
query_result: QueryResult,
7979
ser_fmt: DataSerializationFormat,
8080
) -> Result<JsonResponse> {
@@ -115,7 +115,7 @@ pub fn handle_query_ok_result(
115115
row_set_base_64,
116116
total: Some(total_rows),
117117
returned: Some(returned_rows),
118-
query_id: Some(query_uuid.to_string()),
118+
query_id: Some(query_id.to_string()),
119119
error_code: None,
120120
sql_state: Some(SqlState::Success.to_string()),
121121
}),

crates/api-snowflake-rest/src/server/logic.rs

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::server::error::{
88
};
99
use crate::server::helpers::handle_query_ok_result;
1010
use api_snowflake_rest_sessions::helpers::{create_jwt, ensure_jwt_secret_is_valid, jwt_claims};
11+
use executor::RunningQueryId;
1112
use executor::models::QueryContext;
1213
use snafu::{OptionExt, ResultExt};
1314
use time::Duration;
@@ -87,11 +88,28 @@ pub async fn handle_query_request(
8788
return api_snowflake_rest_error::NotImplementedSnafu.fail();
8889
}
8990

90-
let query_uuid = query_context.query_id.as_uuid();
91-
let result = state
92-
.execution_svc
93-
.query(session_id, &sql_text, query_context)
94-
.await?;
91+
// find running query by request_id
92+
let session = state.execution_svc.get_session(session_id).await?;
93+
let query_id_res = session
94+
.running_queries
95+
.locate_query_id(RunningQueryId::ByRequestId(
96+
query.request_id,
97+
sql_text.clone(),
98+
));
9599

96-
handle_query_ok_result(&sql_text, query_uuid, result, serialization_format)
100+
let (result, query_id) = if query.retry_count.unwrap_or_default() > 0
101+
&& let Ok(query_id) = query_id_res
102+
{
103+
let result = state.execution_svc.wait(query_id).await?;
104+
(result, query_id)
105+
} else {
106+
let query_id = query_context.query_id;
107+
let result = state
108+
.execution_svc
109+
.query(session_id, &sql_text, query_context)
110+
.await?;
111+
(result, query_id)
112+
};
113+
114+
handle_query_ok_result(&sql_text, query_id, result, serialization_format)
97115
}

crates/executor/src/datafusion/rewriters/session_context.rs

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::models::QueryContext;
2-
use crate::query_types::QueryRecordId;
3-
use crate::running_queries::{RunningQueries, RunningQueryId};
2+
use crate::query_types::QueryId;
3+
use crate::running_queries::RunningQueries;
44
use datafusion::arrow::array::{ListArray, ListBuilder, StringBuilder};
55
use datafusion::logical_expr::{Expr, LogicalPlan};
66
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
@@ -19,7 +19,7 @@ pub struct SessionContextExprRewriter {
1919
pub session_id: String,
2020
pub version: String,
2121
pub query_context: QueryContext,
22-
pub recent_queries: Arc<RwLock<VecDeque<QueryRecordId>>>,
22+
pub recent_queries: Arc<RwLock<VecDeque<QueryId>>>,
2323
pub running_queries: Arc<dyn RunningQueries>,
2424
}
2525

@@ -64,24 +64,17 @@ impl SessionContextExprRewriter {
6464

6565
#[allow(clippy::needless_pass_by_value)]
6666
pub fn cancel_query(&self, query_id: String) -> Result<ScalarValue> {
67-
let Ok(query_uuid) = Uuid::from_str(&query_id) else {
67+
let Ok(query_id) = Uuid::from_str(&query_id) else {
6868
return Ok(utf8_val("Invalid UUID."));
6969
};
7070

71-
let query_id = QueryRecordId::from(query_uuid);
72-
if self
73-
.running_queries
74-
.abort(RunningQueryId::ByQueryId(query_id))
75-
.is_err()
76-
{
71+
let query_id = QueryId::from(query_id);
72+
if self.running_queries.abort(query_id).is_err() {
7773
Ok(utf8_val(
7874
"Identified SQL statement is not currently executing.",
7975
))
8076
} else {
81-
Ok(utf8_val(format!(
82-
"query [{}] terminated.",
83-
query_id.as_uuid()
84-
)))
77+
Ok(utf8_val(format!("query [{query_id}] terminated.")))
8578
}
8679
}
8780
}

crates/executor/src/error.rs

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::snowflake_error::SnowflakeError;
2-
use crate::query_types::{QueryRecordId, QueryStatus};
2+
use crate::query_types::{QueryId, QueryStatus};
33
use catalog::error::Error as CatalogError;
44
use datafusion_common::DataFusionError;
55
use error_stack_trace;
@@ -544,18 +544,18 @@ pub enum Error {
544544
location: Location,
545545
},
546546

547-
#[snafu(display("{}: Query execution error: {source}", query_id.as_uuid()))]
547+
#[snafu(display("{query_id}: Query execution error: {source}"))]
548548
QueryExecution {
549-
query_id: QueryRecordId,
549+
query_id: QueryId,
550550
#[snafu(source(from(Error, Box::new)))]
551551
source: Box<Error>,
552552
#[snafu(implicit)]
553553
location: Location,
554554
},
555555

556-
#[snafu(display("Query {} isn't running", query_id.as_uuid()))]
556+
#[snafu(display("Query {query_id} isn't running"))]
557557
QueryIsntRunning {
558-
query_id: QueryRecordId,
558+
query_id: QueryId,
559559
#[snafu(implicit)]
560560
location: Location,
561561
},
@@ -568,48 +568,32 @@ pub enum Error {
568568
},
569569

570570
// When user tried to get result before query finished
571-
#[snafu(display("Query {} is running", query_id.as_uuid()))]
571+
#[snafu(display("Query {query_id} is running"))]
572572
QueryIsRunning {
573-
query_id: QueryRecordId,
573+
query_id: QueryId,
574574
#[snafu(implicit)]
575575
location: Location,
576576
},
577577

578-
#[snafu(display("Query {} cancelled", query_id.as_uuid()))]
578+
#[snafu(display("Query {query_id} cancelled"))]
579579
QueryCancelled {
580-
query_id: QueryRecordId,
580+
query_id: QueryId,
581581
#[snafu(implicit)]
582582
location: Location,
583583
},
584584

585-
#[snafu(display("Query [{}] result sending error", query_id.as_uuid()))]
586-
QueryResultSend {
587-
query_id: QueryRecordId,
588-
#[snafu(implicit)]
589-
location: Location,
590-
},
591-
592-
#[snafu(display("Query [{}] result recv error: {error}", query_id.as_uuid()))]
593-
QueryResultRecv {
594-
query_id: QueryRecordId,
595-
#[snafu(source)]
596-
error: tokio::sync::oneshot::error::RecvError,
597-
#[snafu(implicit)]
598-
location: Location,
599-
},
600-
601-
#[snafu(display("Query [{}] result notify error: {error}", query_id.as_uuid()))]
585+
#[snafu(display("Query {query_id} result notify error: {error}"))]
602586
QueryStatusRecv {
603-
query_id: QueryRecordId,
587+
query_id: QueryId,
604588
#[snafu(source)]
605589
error: tokio::sync::watch::error::RecvError,
606590
#[snafu(implicit)]
607591
location: Location,
608592
},
609593

610-
#[snafu(display("Query [{}] status notify error: {error}", query_id.as_uuid()))]
594+
#[snafu(display("Query {query_id} status notify error: {error}"))]
611595
NotifyQueryStatus {
612-
query_id: QueryRecordId,
596+
query_id: QueryId,
613597
#[snafu(source)]
614598
error: tokio::sync::watch::error::SendError<QueryStatus>,
615599
#[snafu(implicit)]
@@ -624,14 +608,38 @@ pub enum Error {
624608
#[snafu(implicit)]
625609
location: Location,
626610
},
611+
612+
#[snafu(display("Failed to get async result for query [{query_id}]: {error}"))]
613+
AsyncResultTaskJoin {
614+
#[snafu(source)]
615+
error: tokio::task::JoinError,
616+
query_id: QueryId,
617+
#[snafu(implicit)]
618+
location: Location,
619+
},
620+
621+
#[snafu(display("Failed to join query subtask: {error}"))]
622+
QuerySubtaskJoin {
623+
#[snafu(source)]
624+
error: tokio::task::JoinError,
625+
#[snafu(implicit)]
626+
location: Location,
627+
},
628+
629+
#[snafu(display("Missing results handle in running query [{query_id:?}]"))]
630+
NoJoinHandle {
631+
query_id: QueryId,
632+
#[snafu(implicit)]
633+
location: Location,
634+
},
627635
}
628636

629637
impl Error {
630-
pub fn query_id(&self) -> QueryRecordId {
638+
pub fn query_id(&self) -> QueryId {
631639
if let Self::QueryExecution { query_id, .. } = self {
632640
*query_id
633641
} else {
634-
QueryRecordId::default()
642+
QueryId::default()
635643
}
636644
}
637645
#[must_use]

crates/executor/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ pub mod utils;
1717
pub mod tests;
1818

1919
pub use error::{Error, Result};
20-
pub use query_types::{QueryRecordId, QueryStatus};
20+
pub use query_types::{QueryId, QueryStatus};
2121
pub use running_queries::RunningQueryId;
2222
pub use snowflake_error::SnowflakeError;
2323

crates/executor/src/models.rs

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,19 @@
1-
use crate::query_types::{QueryRecordId, QueryStatus};
1+
use crate::query_types::QueryId;
22
use datafusion::arrow::array::RecordBatch;
33
use datafusion::arrow::datatypes::{DataType, Field, Schema as ArrowSchema, TimeUnit};
44
use datafusion_common::arrow::datatypes::Schema;
55
use functions::to_snowflake_datatype;
66
use serde::{Deserialize, Serialize};
77
use std::collections::HashMap;
88
use std::sync::Arc;
9-
use tokio::sync::oneshot;
109
use uuid::Uuid;
1110

12-
pub struct AsyncQueryHandle {
13-
pub query_id: QueryRecordId,
14-
pub rx: oneshot::Receiver<QueryResultStatus>,
15-
}
16-
1711
#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
1812
pub struct QueryContext {
1913
pub database: Option<String>,
2014
pub schema: Option<String>,
2115
pub worksheet_id: Option<i64>,
22-
pub query_id: QueryRecordId,
16+
pub query_id: QueryId,
2317
pub request_id: Option<Uuid>,
2418
pub ip_address: Option<String>,
2519
}
@@ -35,14 +29,14 @@ impl QueryContext {
3529
database,
3630
schema,
3731
worksheet_id,
38-
query_id: QueryRecordId::default(),
32+
query_id: QueryId::default(),
3933
request_id: None,
4034
ip_address: None,
4135
}
4236
}
4337

4438
#[must_use]
45-
pub const fn with_query_id(mut self, new_id: QueryRecordId) -> Self {
39+
pub const fn with_query_id(mut self, new_id: QueryId) -> Self {
4640
self.query_id = new_id;
4741
self
4842
}
@@ -80,12 +74,6 @@ impl QueryResult {
8074
}
8175
}
8276

83-
#[derive(Debug)]
84-
pub struct QueryResultStatus {
85-
pub query_result: std::result::Result<QueryResult, crate::Error>,
86-
pub status: QueryStatus,
87-
}
88-
8977
// TODO: We should not have serde dependency here
9078
// Instead it should be in api-snowflake-rest
9179
#[derive(Debug, Serialize, Deserialize, Clone)]

crates/executor/src/query.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,7 @@ impl UserQuery {
262262
skip(self),
263263
fields(
264264
statement,
265-
query_id = self.query_context.query_id.as_i64(),
266-
query_uuid = self.query_context.query_id.as_uuid().to_string(),
265+
query_id = self.query_context.query_id.to_string(),
267266
),
268267
err
269268
)]

0 commit comments

Comments
 (0)