Skip to content

Commit eeb1d59

Browse files
committed
feat: support server-side parameter binding via /v1/query params field
ref #759 When the server supports it (version > 1.2.900), send raw SQL with a JSON `params` field instead of interpolating parameters client-side. Falls back to client-side interpolation for older servers or when SQL contains `$N` column position placeholders (which the server uses for stage column refs). Changes: - Add `params` field to `QueryRequest` (core) - Add `server_side_params` capability flag with version threshold - Thread params through `start_query` / `query_all` in core client - Add `Params::to_json_value()` with `sql_string_to_json()` reverse parser - Add `PlaceholderVisitor::has_dollar_positions()` for `$N` detection - Add `*_with_params()` methods to `IConnection` trait with defaults - Override in `RestAPIConnection` to pass params to server - Route in `QueryBuilder`/`ExecBuilder`: server-side when supported and no `$N`, client-side otherwise - Add `to_json_params()` helper in Python bindings for future use
1 parent 45a7eef commit eeb1d59

13 files changed

Lines changed: 357 additions & 21 deletions

File tree

bindings/python/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ databend-driver = { workspace = true, features = ["rustls", "flight-sql"] }
1919
databend-driver-core = { workspace = true }
2020
tokio-stream = { workspace = true }
2121

22+
serde_json = "1.0"
2223
ctor = "0.2"
2324
env_logger = "0.11.8"
2425
http = "1.0"

bindings/python/src/utils.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,56 @@ fn to_sql_string(v: Bound<PyAny>) -> PyResult<String> {
9696
}
9797
}
9898

99+
/// Convert Python params directly to JSON values, preserving native types.
100+
pub(crate) fn to_json_params(v: Option<Bound<PyAny>>) -> Option<serde_json::Value> {
101+
match v {
102+
Some(v) => {
103+
if let Ok(v) = v.downcast::<PyDict>() {
104+
let mut map = serde_json::Map::new();
105+
for (k, v) in v.iter() {
106+
let k = k.extract::<String>().unwrap();
107+
let v = py_to_json(v).unwrap();
108+
map.insert(k, v);
109+
}
110+
Some(serde_json::Value::Object(map))
111+
} else if let Ok(v) = v.downcast::<PyList>() {
112+
let arr: Vec<serde_json::Value> =
113+
v.iter().map(|v| py_to_json(v).unwrap()).collect();
114+
Some(serde_json::Value::Array(arr))
115+
} else if let Ok(v) = v.downcast::<PyTuple>() {
116+
let arr: Vec<serde_json::Value> =
117+
v.iter().map(|v| py_to_json(v).unwrap()).collect();
118+
Some(serde_json::Value::Array(arr))
119+
} else {
120+
Some(serde_json::Value::Array(vec![py_to_json(v).unwrap()]))
121+
}
122+
}
123+
None => None,
124+
}
125+
}
126+
127+
fn py_to_json(v: Bound<PyAny>) -> PyResult<serde_json::Value> {
128+
if v.is_none() {
129+
return Ok(serde_json::Value::Null);
130+
}
131+
// Check bool before int (bool is a subclass of int in Python)
132+
if let Ok(v) = v.extract::<bool>() {
133+
return Ok(serde_json::Value::Bool(v));
134+
}
135+
if let Ok(v) = v.extract::<i64>() {
136+
return Ok(serde_json::json!(v));
137+
}
138+
if let Ok(v) = v.extract::<f64>() {
139+
return Ok(serde_json::json!(v));
140+
}
141+
if let Ok(v) = v.extract::<String>() {
142+
return Ok(serde_json::Value::String(v));
143+
}
144+
Err(PyAttributeError::new_err(format!(
145+
"Invalid parameter type for: {v:?}, expected str, bool, int or float"
146+
)))
147+
}
148+
99149
pub(super) fn options_as_ref(
100150
format_options: &Option<BTreeMap<String, String>>,
101151
) -> Option<BTreeMap<&str, &str>> {

core/src/capability.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,16 @@ use semver::Version;
1818
pub struct Capability {
1919
pub streaming_load: bool,
2020
pub arrow_data: bool,
21+
pub server_side_params: bool,
2122
}
2223

2324
impl Capability {
2425
pub fn from_server_version(ver: &Version) -> Capability {
2526
Capability {
2627
streaming_load: ver > &Version::new(1, 2, 781),
2728
arrow_data: ver > &Version::new(1, 2, 835),
29+
// TODO: update version threshold when server PR gets a release version
30+
server_side_params: ver > &Version::new(1, 2, 900),
2831
}
2932
}
3033
}

core/src/client.rs

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -553,9 +553,14 @@ impl APIClient {
553553
}
554554
}
555555

556-
pub async fn start_query(self: &Arc<Self>, sql: &str, need_progress: bool) -> Result<Pages> {
556+
pub async fn start_query(
557+
self: &Arc<Self>,
558+
sql: &str,
559+
need_progress: bool,
560+
params: Option<serde_json::Value>,
561+
) -> Result<Pages> {
557562
info!("start query: {sql}");
558-
let (resp, batches) = self.start_query_inner(sql, None, false).await?;
563+
let (resp, batches) = self.start_query_inner(sql, None, false, params).await?;
559564
Pages::new(self.clone(), resp, batches, need_progress)
560565
}
561566

@@ -589,6 +594,7 @@ impl APIClient {
589594
sql: &str,
590595
stage_attachment_config: Option<StageAttachmentConfig<'_>>,
591596
force_json_body: bool,
597+
params: Option<serde_json::Value>,
592598
) -> Result<(QueryResponse, Vec<RecordBatch>)> {
593599
if !self.in_active_transaction() {
594600
self.route_hint.next();
@@ -601,7 +607,8 @@ impl APIClient {
601607
let req = QueryRequest::new(sql)
602608
.with_pagination(self.make_pagination())
603609
.with_session(Some(session_state))
604-
.with_stage_attachment(stage_attachment_config);
610+
.with_stage_attachment(stage_attachment_config)
611+
.with_params(params);
605612

606613
// headers
607614
let query_id = self.gen_query_id();
@@ -766,16 +773,23 @@ impl APIClient {
766773
Ok(())
767774
}
768775

769-
pub async fn query_all(self: &Arc<Self>, sql: &str) -> Result<Page> {
770-
self.query_all_inner(sql, false).await
776+
pub async fn query_all(
777+
self: &Arc<Self>,
778+
sql: &str,
779+
params: Option<serde_json::Value>,
780+
) -> Result<Page> {
781+
self.query_all_inner(sql, false, params).await
771782
}
772783

773784
pub async fn query_all_inner(
774785
self: &Arc<Self>,
775786
sql: &str,
776787
force_json_body: bool,
788+
params: Option<serde_json::Value>,
777789
) -> Result<Page> {
778-
let (resp, batches) = self.start_query_inner(sql, None, force_json_body).await?;
790+
let (resp, batches) = self
791+
.start_query_inner(sql, None, force_json_body, params)
792+
.await?;
779793
let mut pages = Pages::new(self.clone(), resp, batches, false)?;
780794
let mut all = Page::default();
781795
while let Some(page) = pages.next().await {
@@ -842,7 +856,9 @@ impl APIClient {
842856
file_format_options: Some(file_format_options),
843857
copy_options: Some(copy_options),
844858
});
845-
let (resp, batches) = self.start_query_inner(sql, stage_attachment, true).await?;
859+
let (resp, batches) = self
860+
.start_query_inner(sql, stage_attachment, true, None)
861+
.await?;
846862
let mut pages = Pages::new(self.clone(), resp, batches, false)?;
847863
let mut all = Page::default();
848864
while let Some(page) = pages.next().await {
@@ -854,7 +870,7 @@ impl APIClient {
854870
async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
855871
info!("get presigned upload url: {stage}");
856872
let sql = format!("PRESIGN UPLOAD {stage}");
857-
let resp = self.query_all_inner(&sql, true).await?;
873+
let resp = self.query_all_inner(&sql, true, None).await?;
858874
if resp.data.len() != 1 {
859875
return Err(Error::Decode(
860876
"Empty response from server for presigned request".to_string(),

core/src/request.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ pub struct QueryRequest<'a> {
2626
pagination: Option<PaginationConfig>,
2727
#[serde(skip_serializing_if = "Option::is_none")]
2828
stage_attachment: Option<StageAttachmentConfig<'a>>,
29+
#[serde(skip_serializing_if = "Option::is_none")]
30+
params: Option<serde_json::Value>,
2931
}
3032

3133
#[derive(Serialize, Debug)]
@@ -54,6 +56,7 @@ impl<'r, 't: 'r> QueryRequest<'r> {
5456
sql,
5557
pagination: None,
5658
stage_attachment: None,
59+
params: None,
5760
}
5861
}
5962

@@ -74,6 +77,11 @@ impl<'r, 't: 'r> QueryRequest<'r> {
7477
self.stage_attachment = stage_attachment;
7578
self
7679
}
80+
81+
pub fn with_params(mut self, params: Option<serde_json::Value>) -> Self {
82+
self.params = params;
83+
self
84+
}
7785
}
7886

7987
#[cfg(test)]
@@ -103,4 +111,18 @@ mod test {
103111
);
104112
Ok(())
105113
}
114+
115+
#[test]
116+
fn build_request_with_params() -> Result<()> {
117+
let req = QueryRequest::new("SELECT ? + ?").with_params(Some(serde_json::json!([1, 2])));
118+
assert_eq!(
119+
serde_json::to_string(&req)?,
120+
r#"{"sql":"SELECT ? + ?","params":[1,2]}"#
121+
);
122+
123+
// params=None should not appear in serialized output
124+
let req = QueryRequest::new("SELECT 1").with_params(None);
125+
assert_eq!(serde_json::to_string(&req)?, r#"{"sql":"SELECT 1"}"#);
126+
Ok(())
127+
}
106128
}

core/tests/core/retry.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ async fn retry_503_then_success() {
262262

263263
let dsn = build_dsn(port, 2, "");
264264
let client = APIClient::new(&dsn, None).await.unwrap();
265-
let result = client.query_all("select 42").await.unwrap();
265+
let result = client.query_all("select 42", None).await.unwrap();
266266

267267
assert_eq!(result.data, vec![vec![Some("42".to_string())]]);
268268
assert_eq!(requests.lock().unwrap().len(), 2);
@@ -303,7 +303,7 @@ async fn retry_401_with_access_token_file_reload_then_success() {
303303
&format!("access_token_file={}", token_file.to_string_lossy()),
304304
);
305305
let client = APIClient::new(&dsn, None).await.unwrap();
306-
let result = client.query_all("select 'reloaded'").await.unwrap();
306+
let result = client.query_all("select 'reloaded'", None).await.unwrap();
307307

308308
assert_eq!(result.data, vec![vec![Some("reloaded".to_string())]]);
309309
assert_eq!(requests.lock().unwrap().len(), 2);
@@ -329,7 +329,7 @@ async fn retry_401_auth_reload_stops_at_max_retries() {
329329
&format!("access_token_file={}", token_file.to_string_lossy()),
330330
);
331331
let client = APIClient::new(&dsn, None).await.unwrap();
332-
let err = match client.query_all("select 1").await {
332+
let err = match client.query_all("select 1", None).await {
333333
Ok(_) => panic!("expected unauthorized error"),
334334
Err(err) => err,
335335
};
@@ -360,7 +360,7 @@ async fn start_query_404_keeps_logic_error() {
360360

361361
let dsn = build_dsn(port, 2, "");
362362
let client = APIClient::new(&dsn, None).await.unwrap();
363-
let err = match client.query_all("select 1").await {
363+
let err = match client.query_all("select 1", None).await {
364364
Ok(_) => panic!("expected logic error"),
365365
Err(err) => err,
366366
};
@@ -403,7 +403,7 @@ async fn query_page_404_maps_to_query_not_found() {
403403

404404
let dsn = build_dsn(port, 2, "");
405405
let client = APIClient::new(&dsn, None).await.unwrap();
406-
let err = match client.query_all("select 1").await {
406+
let err = match client.query_all("select 1", None).await {
407407
Ok(_) => panic!("expected QueryNotFound"),
408408
Err(err) => err,
409409
};

core/tests/core/simple.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@ use crate::common::DEFAULT_DSN;
2121
async fn select_simple() {
2222
let dsn = option_env!("TEST_DATABEND_DSN").unwrap_or(DEFAULT_DSN);
2323
let client = APIClient::new(dsn, None).await.unwrap();
24-
let mut pages = client.start_query("select 15532", true).await.unwrap();
24+
let mut pages = client
25+
.start_query("select 15532", true, None)
26+
.await
27+
.unwrap();
2528
let page = pages.next().await.unwrap().unwrap();
2629
assert_eq!(page.data, [[Some("15532".to_string())]]);
2730
}

core/tests/core/stage.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ async fn insert_with_stage(presign: bool) {
4848
.await
4949
.unwrap();
5050
let sql = format!("CREATE TABLE `{table}` (id UInt64, city String, number UInt64)");
51-
client.query_all(&sql).await.unwrap();
51+
client.query_all(&sql, None).await.unwrap();
5252

5353
let sql = format!("INSERT INTO `{table}` VALUES");
5454
let file_format_options = vec![
@@ -68,7 +68,7 @@ async fn insert_with_stage(presign: bool) {
6868
.unwrap();
6969

7070
let sql = format!("SELECT * FROM `{table}`");
71-
let resp = client.query_all(&sql).await.unwrap();
71+
let resp = client.query_all(&sql, None).await.unwrap();
7272
assert_eq!(resp.data.len(), 6);
7373
let expect = [
7474
["1", "Beijing", "100"],
@@ -90,7 +90,7 @@ async fn insert_with_stage(presign: bool) {
9090
assert_eq!(result, expect);
9191

9292
let sql = format!("DROP TABLE `{table}`;");
93-
client.query_all(&sql).await.unwrap();
93+
client.query_all(&sql, None).await.unwrap();
9494
}
9595

9696
#[tokio::test]

0 commit comments

Comments
 (0)