Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ databend-driver = { workspace = true, features = ["rustls", "flight-sql"] }
databend-driver-core = { workspace = true }
tokio-stream = { workspace = true }

serde_json = "1.0"
ctor = "0.2"
env_logger = "0.11.8"
http = "1.0"
Expand Down
55 changes: 23 additions & 32 deletions bindings/python/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
use std::collections::BTreeMap;
use std::collections::HashMap;

use databend_driver::Param;
use databend_driver::Params;
use pyo3::exceptions::PyAttributeError;
use pyo3::types::PyTuple;
Expand Down Expand Up @@ -46,54 +45,46 @@ pub(crate) fn to_sql_params(v: Option<Bound<PyAny>>) -> Params {
let mut params = HashMap::new();
for (k, v) in v.iter() {
let k = k.extract::<String>().unwrap();
let v = to_sql_string(v).unwrap();
let v = py_to_json(v).unwrap();
params.insert(k, v);
}
Params::NamedParams(params)
} else if let Ok(v) = v.downcast::<PyList>() {
let mut params = vec![];
for v in v.iter() {
let v = to_sql_string(v).unwrap();
params.push(v);
}
let params: Vec<serde_json::Value> =
v.iter().map(|v| py_to_json(v).unwrap()).collect();
Params::QuestionParams(params)
} else if let Ok(v) = v.downcast::<PyTuple>() {
let mut params = vec![];
for v in v.iter() {
let v = to_sql_string(v).unwrap();
params.push(v);
}
let params: Vec<serde_json::Value> =
v.iter().map(|v| py_to_json(v).unwrap()).collect();
Params::QuestionParams(params)
} else {
Params::QuestionParams(vec![to_sql_string(v).unwrap()])
Params::QuestionParams(vec![py_to_json(v).unwrap()])
}
}
None => Params::default(),
}
}

fn to_sql_string(v: Bound<PyAny>) -> PyResult<String> {
fn py_to_json(v: Bound<PyAny>) -> PyResult<serde_json::Value> {
if v.is_none() {
return Ok("NULL".to_string());
return Ok(serde_json::Value::Null);
}
match v.downcast::<PyAny>() {
Ok(v) => {
if let Ok(v) = v.extract::<String>() {
Ok(v.as_sql_string())
} else if let Ok(v) = v.extract::<bool>() {
Ok(v.as_sql_string())
} else if let Ok(v) = v.extract::<i64>() {
Ok(v.as_sql_string())
} else if let Ok(v) = v.extract::<f64>() {
Ok(v.as_sql_string())
} else {
Err(PyAttributeError::new_err(format!(
"Invalid parameter type for: {v:?}, expected str, bool, int or float"
)))
}
}
Err(e) => Err(e.into()),
// Check bool before int (bool is a subclass of int in Python)
if let Ok(v) = v.extract::<bool>() {
return Ok(serde_json::Value::Bool(v));
}
if let Ok(v) = v.extract::<i64>() {
return Ok(serde_json::json!(v));
}
if let Ok(v) = v.extract::<f64>() {
return Ok(serde_json::json!(v));
}
if let Ok(v) = v.extract::<String>() {
return Ok(serde_json::Value::String(v));
}
Err(PyAttributeError::new_err(format!(
"Invalid parameter type for: {v:?}, expected str, bool, int or float"
)))
}

pub(super) fn options_as_ref(
Expand Down
3 changes: 3 additions & 0 deletions core/src/capability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@ use semver::Version;
pub struct Capability {
pub streaming_load: bool,
pub arrow_data: bool,
pub server_side_params: bool,
}

impl Capability {
pub fn from_server_version(ver: &Version) -> Capability {
Capability {
streaming_load: ver > &Version::new(1, 2, 781),
arrow_data: ver > &Version::new(1, 2, 835),
// TODO: update version threshold when server PR gets a release version
server_side_params: ver > &Version::new(1, 2, 900),
}
}
}
32 changes: 24 additions & 8 deletions core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -553,9 +553,14 @@ impl APIClient {
}
}

pub async fn start_query(self: &Arc<Self>, sql: &str, need_progress: bool) -> Result<Pages> {
pub async fn start_query(
self: &Arc<Self>,
sql: &str,
need_progress: bool,
params: Option<serde_json::Value>,
) -> Result<Pages> {
info!("start query: {sql}");
let (resp, batches) = self.start_query_inner(sql, None, false).await?;
let (resp, batches) = self.start_query_inner(sql, None, false, params).await?;
Pages::new(self.clone(), resp, batches, need_progress)
}

Expand Down Expand Up @@ -589,6 +594,7 @@ impl APIClient {
sql: &str,
stage_attachment_config: Option<StageAttachmentConfig<'_>>,
force_json_body: bool,
params: Option<serde_json::Value>,
) -> Result<(QueryResponse, Vec<RecordBatch>)> {
if !self.in_active_transaction() {
self.route_hint.next();
Expand All @@ -601,7 +607,8 @@ impl APIClient {
let req = QueryRequest::new(sql)
.with_pagination(self.make_pagination())
.with_session(Some(session_state))
.with_stage_attachment(stage_attachment_config);
.with_stage_attachment(stage_attachment_config)
.with_params(params);

// headers
let query_id = self.gen_query_id();
Expand Down Expand Up @@ -766,16 +773,23 @@ impl APIClient {
Ok(())
}

pub async fn query_all(self: &Arc<Self>, sql: &str) -> Result<Page> {
self.query_all_inner(sql, false).await
pub async fn query_all(
self: &Arc<Self>,
sql: &str,
params: Option<serde_json::Value>,
) -> Result<Page> {
self.query_all_inner(sql, false, params).await
}

pub async fn query_all_inner(
self: &Arc<Self>,
sql: &str,
force_json_body: bool,
params: Option<serde_json::Value>,
) -> Result<Page> {
let (resp, batches) = self.start_query_inner(sql, None, force_json_body).await?;
let (resp, batches) = self
.start_query_inner(sql, None, force_json_body, params)
.await?;
let mut pages = Pages::new(self.clone(), resp, batches, false)?;
let mut all = Page::default();
while let Some(page) = pages.next().await {
Expand Down Expand Up @@ -842,7 +856,9 @@ impl APIClient {
file_format_options: Some(file_format_options),
copy_options: Some(copy_options),
});
let (resp, batches) = self.start_query_inner(sql, stage_attachment, true).await?;
let (resp, batches) = self
.start_query_inner(sql, stage_attachment, true, None)
.await?;
let mut pages = Pages::new(self.clone(), resp, batches, false)?;
let mut all = Page::default();
while let Some(page) = pages.next().await {
Expand All @@ -854,7 +870,7 @@ impl APIClient {
async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
info!("get presigned upload url: {stage}");
let sql = format!("PRESIGN UPLOAD {stage}");
let resp = self.query_all_inner(&sql, true).await?;
let resp = self.query_all_inner(&sql, true, None).await?;
if resp.data.len() != 1 {
return Err(Error::Decode(
"Empty response from server for presigned request".to_string(),
Expand Down
22 changes: 22 additions & 0 deletions core/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ pub struct QueryRequest<'a> {
pagination: Option<PaginationConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
stage_attachment: Option<StageAttachmentConfig<'a>>,
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<serde_json::Value>,
}

#[derive(Serialize, Debug)]
Expand Down Expand Up @@ -54,6 +56,7 @@ impl<'r, 't: 'r> QueryRequest<'r> {
sql,
pagination: None,
stage_attachment: None,
params: None,
}
}

Expand All @@ -74,6 +77,11 @@ impl<'r, 't: 'r> QueryRequest<'r> {
self.stage_attachment = stage_attachment;
self
}

pub fn with_params(mut self, params: Option<serde_json::Value>) -> Self {
self.params = params;
self
}
}

#[cfg(test)]
Expand Down Expand Up @@ -103,4 +111,18 @@ mod test {
);
Ok(())
}

#[test]
fn build_request_with_params() -> Result<()> {
let req = QueryRequest::new("SELECT ? + ?").with_params(Some(serde_json::json!([1, 2])));
assert_eq!(
serde_json::to_string(&req)?,
r#"{"sql":"SELECT ? + ?","params":[1,2]}"#
);

// params=None should not appear in serialized output
let req = QueryRequest::new("SELECT 1").with_params(None);
assert_eq!(serde_json::to_string(&req)?, r#"{"sql":"SELECT 1"}"#);
Ok(())
}
}
10 changes: 5 additions & 5 deletions core/tests/core/retry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ async fn retry_503_then_success() {

let dsn = build_dsn(port, 2, "");
let client = APIClient::new(&dsn, None).await.unwrap();
let result = client.query_all("select 42").await.unwrap();
let result = client.query_all("select 42", None).await.unwrap();

assert_eq!(result.data, vec![vec![Some("42".to_string())]]);
assert_eq!(requests.lock().unwrap().len(), 2);
Expand Down Expand Up @@ -303,7 +303,7 @@ async fn retry_401_with_access_token_file_reload_then_success() {
&format!("access_token_file={}", token_file.to_string_lossy()),
);
let client = APIClient::new(&dsn, None).await.unwrap();
let result = client.query_all("select 'reloaded'").await.unwrap();
let result = client.query_all("select 'reloaded'", None).await.unwrap();

assert_eq!(result.data, vec![vec![Some("reloaded".to_string())]]);
assert_eq!(requests.lock().unwrap().len(), 2);
Expand All @@ -329,7 +329,7 @@ async fn retry_401_auth_reload_stops_at_max_retries() {
&format!("access_token_file={}", token_file.to_string_lossy()),
);
let client = APIClient::new(&dsn, None).await.unwrap();
let err = match client.query_all("select 1").await {
let err = match client.query_all("select 1", None).await {
Ok(_) => panic!("expected unauthorized error"),
Err(err) => err,
};
Expand Down Expand Up @@ -360,7 +360,7 @@ async fn start_query_404_keeps_logic_error() {

let dsn = build_dsn(port, 2, "");
let client = APIClient::new(&dsn, None).await.unwrap();
let err = match client.query_all("select 1").await {
let err = match client.query_all("select 1", None).await {
Ok(_) => panic!("expected logic error"),
Err(err) => err,
};
Expand Down Expand Up @@ -403,7 +403,7 @@ async fn query_page_404_maps_to_query_not_found() {

let dsn = build_dsn(port, 2, "");
let client = APIClient::new(&dsn, None).await.unwrap();
let err = match client.query_all("select 1").await {
let err = match client.query_all("select 1", None).await {
Ok(_) => panic!("expected QueryNotFound"),
Err(err) => err,
};
Expand Down
5 changes: 4 additions & 1 deletion core/tests/core/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ use crate::common::DEFAULT_DSN;
async fn select_simple() {
let dsn = option_env!("TEST_DATABEND_DSN").unwrap_or(DEFAULT_DSN);
let client = APIClient::new(dsn, None).await.unwrap();
let mut pages = client.start_query("select 15532", true).await.unwrap();
let mut pages = client
.start_query("select 15532", true, None)
.await
.unwrap();
let page = pages.next().await.unwrap().unwrap();
assert_eq!(page.data, [[Some("15532".to_string())]]);
}
6 changes: 3 additions & 3 deletions core/tests/core/stage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async fn insert_with_stage(presign: bool) {
.await
.unwrap();
let sql = format!("CREATE TABLE `{table}` (id UInt64, city String, number UInt64)");
client.query_all(&sql).await.unwrap();
client.query_all(&sql, None).await.unwrap();

let sql = format!("INSERT INTO `{table}` VALUES");
let file_format_options = vec![
Expand All @@ -68,7 +68,7 @@ async fn insert_with_stage(presign: bool) {
.unwrap();

let sql = format!("SELECT * FROM `{table}`");
let resp = client.query_all(&sql).await.unwrap();
let resp = client.query_all(&sql, None).await.unwrap();
assert_eq!(resp.data.len(), 6);
let expect = [
["1", "Beijing", "100"],
Expand All @@ -90,7 +90,7 @@ async fn insert_with_stage(presign: bool) {
assert_eq!(result, expect);

let sql = format!("DROP TABLE `{table}`;");
client.query_all(&sql).await.unwrap();
client.query_all(&sql, None).await.unwrap();
}

#[tokio::test]
Expand Down
Loading
Loading