Skip to content

Commit 7f95c7a

Browse files
authored
Merge pull request #209 from cipherstash/feat/mapper/group-by
feat: SQL transformations related to `GROUP BY`
2 parents 71a6655 + 9e4d145 commit 7f95c7a

60 files changed

Lines changed: 2830 additions & 1838 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

Cargo.lock

Lines changed: 34 additions & 11 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ strip = "none"
3535
debug = true
3636

3737
[workspace.dependencies]
38-
sqlparser = { version = "^0.52", features = ["bigdecimal", "serde"] }
38+
sqltk = { version = "0.5.0" }
39+
sqltk-parser = { version = "0.52.0" }
3940
thiserror = "2.0.9"
4041
tokio = { version = "1.44", features = ["full"] }
4142
tracing = "0.1"

mise.local.example.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,6 @@ CS_LOG__AUTHENTICATION_LEVEL = "info"
4242
CS_LOG__CONTEXT_LEVEL = "info"
4343
CS_LOG__KEYSET_LEVEL = "info"
4444
CS_LOG__PROTOCOL_LEVEL = "info"
45-
CS_LOG__MAPPER_LEVEL = "info"
45+
CS_LOG__MAPPER_LEVEL = "debug"
4646
CS_LOG__SCHEMA_LEVEL = "info"
4747
CS_LOG__CONFIG_LEVEL = "info"

mise.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ run = "docker compose rm --stop --force proxy proxy-tls"
108108
alias = ['t', 'ci']
109109
description = "Runs all tests (hygiene, unit, integration)"
110110
run = """
111+
mise run rust:version
111112
mise run test:check
112113
mise run test:format
113114
mise run test:clippy
@@ -369,6 +370,15 @@ mise --env tls run proxy:down
369370
description = "Runs cargo nextest, skipping integration tests"
370371
run = 'cargo nextest run --no-fail-fast --nocapture -E "not package(cipherstash-proxy-integration)" {{arg(name="test",default="")}}'
371372

373+
[tasks."rust:version"]
374+
description = "Outputs rust toolchain version info"
375+
run = """
376+
echo "rustc --version = " $(rustc --version)
377+
echo "cargo --version = " $(cargo --version)
378+
echo "cargo fmt --version = " $(cargo fmt --version)
379+
echo "cargo clippy --version = " $(cargo clippy --version)
380+
"""
381+
372382
[tasks."test:format"]
373383
description = "Runs cargo fmt"
374384
run = 'cargo fmt --all -- --check'

packages/cipherstash-proxy/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ rustls-pki-types = "1.10.0"
4040
serde = "1.0"
4141
serde_json = "1.0"
4242
socket2 = "0.5.7"
43-
sqlparser = { workspace = true }
43+
sqltk = { workspace = true }
44+
sqltk-parser = { workspace = true }
4445
thiserror = { workspace = true }
4546
tokio = { workspace = true }
4647
tokio-postgres = { version = "0.7", features = [

packages/cipherstash-proxy/src/encrypt/schema/manager.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::error::Error;
44
use crate::{connect, log::SCHEMA};
55
use arc_swap::ArcSwap;
66
use eql_mapper::{Column, Schema, Table};
7-
use sqlparser::ast::Ident;
7+
use sqltk_parser::ast::Ident;
88
use std::sync::Arc;
99
use std::time::Duration;
1010
use tokio::{task::JoinHandle, time};

packages/cipherstash-proxy/src/eql/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use cipherstash_client::{
33
zerokms::{encrypted_record, EncryptedRecord},
44
};
55
use serde::{Deserialize, Serialize};
6-
use sqlparser::ast::Ident;
6+
use sqltk_parser::ast::Ident;
77

88
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
99
pub struct Plaintext {

packages/cipherstash-proxy/src/error.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::{postgresql::Column, Identifier};
22
use bytes::BytesMut;
33
use cipherstash_client::encryption;
4+
use eql_mapper::EqlMapperError;
45
use metrics_exporter_prometheus::BuildError;
56
use std::{io, time::Duration};
67
use thiserror::Error;
@@ -92,6 +93,9 @@ pub enum MappingError {
9293

9394
#[error("Statement encountered an internal error. This may be a bug in the statement mapping module of CipherStash Proxy. Please visit {}#mapping-internal-error for more information.", ERROR_DOC_BASE_URL)]
9495
Internal(String),
96+
97+
#[error(transparent)]
98+
EqlMapper(#[from] EqlMapperError),
9599
}
96100

97101
#[derive(Error, Debug)]
@@ -320,8 +324,8 @@ impl From<serde_json::Error> for Error {
320324
}
321325
}
322326

323-
impl From<sqlparser::parser::ParserError> for Error {
324-
fn from(e: sqlparser::parser::ParserError) -> Self {
327+
impl From<sqltk_parser::parser::ParserError> for Error {
328+
fn from(e: sqltk_parser::parser::ParserError) -> Self {
325329
Error::Mapping(MappingError::InvalidSqlStatement(e.to_string()))
326330
}
327331
}

packages/cipherstash-proxy/src/postgresql/data/from_sql.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use cipherstash_config::ColumnType;
1111
use postgres_types::FromSql;
1212
use postgres_types::Type;
1313
use rust_decimal::Decimal;
14-
use sqlparser::ast::Value;
14+
use sqltk_parser::ast::Value;
1515
use std::str::FromStr;
1616
use tracing::debug;
1717

packages/cipherstash-proxy/src/postgresql/frontend.rs

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,14 @@ use crate::prometheus::{
2626
use crate::Encrypted;
2727
use bytes::BytesMut;
2828
use cipherstash_client::encryption::Plaintext;
29-
use eql_mapper::{self, EqlMapperError, EqlValue, NodeKey, TableColumn, TypedStatement};
29+
use eql_mapper::{self, EqlMapperError, EqlValue, TableColumn, TypeCheckedStatement};
3030
use metrics::{counter, histogram};
3131
use pg_escape::quote_literal;
3232
use serde::Serialize;
33-
use sqlparser::ast::{self, Expr, Value};
34-
use sqlparser::dialect::PostgreSqlDialect;
35-
use sqlparser::parser::Parser;
33+
use sqltk::NodeKey;
34+
use sqltk_parser::ast::{self, Value};
35+
use sqltk_parser::dialect::PostgreSqlDialect;
36+
use sqltk_parser::parser::Parser;
3637
use std::collections::HashMap;
3738
use std::time::Instant;
3839
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
@@ -285,7 +286,7 @@ where
285286

286287
match self.to_encryptable_statement(&typed_statement, vec![])? {
287288
Some(statement) => {
288-
if statement.has_literals() || typed_statement.has_nodes_to_wrap() {
289+
if typed_statement.requires_transform() {
289290
let encrypted_literals = self
290291
.encrypt_literals(&typed_statement, &statement.literal_columns)
291292
.await?;
@@ -356,7 +357,7 @@ where
356357
///
357358
async fn encrypt_literals(
358359
&mut self,
359-
typed_statement: &TypedStatement<'_>,
360+
typed_statement: &TypeCheckedStatement<'_>,
360361
literal_columns: &Vec<Option<Column>>,
361362
) -> Result<Vec<Option<Encrypted>>, Error> {
362363
let literal_values = typed_statement.literal_values();
@@ -402,14 +403,14 @@ where
402403
///
403404
async fn transform_statement(
404405
&mut self,
405-
typed_statement: &TypedStatement<'_>,
406+
typed_statement: &TypeCheckedStatement<'_>,
406407
encrypted_literals: &Vec<Option<Encrypted>>,
407408
) -> Result<Option<ast::Statement>, Error> {
408409
// Convert literals to ast Expr
409410
let mut encrypted_expressions = vec![];
410411
for encrypted in encrypted_literals {
411412
let e = match encrypted {
412-
Some(en) => Some(to_json_literal_expr(&en)?),
413+
Some(en) => Some(to_json_literal_value(&en)?),
413414
None => None,
414415
};
415416
encrypted_expressions.push(e);
@@ -426,11 +427,10 @@ where
426427

427428
debug!(target: MAPPER,
428429
client_id = self.context.client_id,
429-
nodes_to_wrap = typed_statement.nodes_to_wrap.len(),
430430
literals = encrypted_nodes.len(),
431431
);
432432

433-
if !typed_statement.has_nodes_to_wrap() && encrypted_nodes.is_empty() {
433+
if !typed_statement.requires_transform() {
434434
return Ok(None);
435435
}
436436

@@ -500,7 +500,7 @@ where
500500

501501
match self.to_encryptable_statement(&typed_statement, param_types)? {
502502
Some(statement) => {
503-
if statement.has_literals() || typed_statement.has_nodes_to_wrap() {
503+
if typed_statement.requires_transform() {
504504
let encrypted_literals = self
505505
.encrypt_literals(&typed_statement, &statement.literal_columns)
506506
.await?;
@@ -607,7 +607,7 @@ where
607607
///
608608
fn to_encryptable_statement(
609609
&self,
610-
typed_statement: &TypedStatement<'_>,
610+
typed_statement: &TypeCheckedStatement<'_>,
611611
param_types: Vec<i32>,
612612
) -> Result<Option<Statement>, Error> {
613613
let param_columns = self.get_param_columns(typed_statement)?;
@@ -619,8 +619,7 @@ where
619619

620620
if (param_columns.is_empty() || no_encrypted_param_columns)
621621
&& (projection_columns.is_empty() || no_encrypted_projection_columns)
622-
&& literal_columns.is_empty()
623-
&& !typed_statement.has_nodes_to_wrap()
622+
&& !typed_statement.requires_transform()
624623
{
625624
return Ok(None);
626625
}
@@ -735,7 +734,10 @@ where
735734
Ok(encrypted)
736735
}
737736

738-
fn type_check<'a>(&self, statement: &'a ast::Statement) -> Result<TypedStatement<'a>, Error> {
737+
fn type_check<'a>(
738+
&self,
739+
statement: &'a ast::Statement,
740+
) -> Result<TypeCheckedStatement<'a>, Error> {
739741
match eql_mapper::type_check(self.context.get_table_resolver(), statement) {
740742
Ok(typed_statement) => {
741743
debug!(target: MAPPER,
@@ -778,10 +780,10 @@ where
778780
///
779781
fn get_projection_columns(
780782
&self,
781-
typed_statement: &eql_mapper::TypedStatement<'_>,
783+
typed_statement: &eql_mapper::TypeCheckedStatement<'_>,
782784
) -> Result<Vec<Option<Column>>, Error> {
783785
let mut projection_columns = vec![];
784-
if let Some(eql_mapper::Projection::WithColumns(columns)) = &typed_statement.projection {
786+
if let eql_mapper::Projection::WithColumns(columns) = &typed_statement.projection {
785787
for col in columns {
786788
let eql_mapper::ProjectionColumn { ty, .. } = col;
787789
let configured_column = match ty {
@@ -813,13 +815,13 @@ where
813815
///
814816
fn get_param_columns(
815817
&self,
816-
typed_statement: &eql_mapper::TypedStatement<'_>,
818+
typed_statement: &eql_mapper::TypeCheckedStatement<'_>,
817819
) -> Result<Vec<Option<Column>>, Error> {
818820
let mut param_columns = vec![];
819821

820822
for param in typed_statement.params.iter() {
821823
let configured_column = match param {
822-
eql_mapper::Value::Eql(EqlValue(TableColumn { table, column })) => {
824+
(_, eql_mapper::Value::Eql(EqlValue(TableColumn { table, column }))) => {
823825
let identifier = Identifier::from((table, column));
824826

825827
debug!(
@@ -841,7 +843,7 @@ where
841843

842844
fn get_literal_columns(
843845
&self,
844-
typed_statement: &eql_mapper::TypedStatement<'_>,
846+
typed_statement: &eql_mapper::TypeCheckedStatement<'_>,
845847
) -> Result<Vec<Option<Column>>, Error> {
846848
let mut literal_columns = vec![];
847849

@@ -967,9 +969,9 @@ fn literals_to_plaintext(
967969
Ok(plaintexts)
968970
}
969971

970-
fn to_json_literal_expr<T>(literal: &T) -> Result<Expr, Error>
972+
fn to_json_literal_value<T>(literal: &T) -> Result<Value, Error>
971973
where
972974
T: ?Sized + Serialize,
973975
{
974-
Ok(serde_json::to_string(literal).map(|json| Expr::Value(Value::SingleQuotedString(json)))?)
976+
Ok(serde_json::to_string(literal).map(Value::SingleQuotedString)?)
975977
}

0 commit comments

Comments
 (0)