@@ -26,13 +26,14 @@ use crate::prometheus::{
2626use crate :: Encrypted ;
2727use bytes:: BytesMut ;
2828use cipherstash_client:: encryption:: Plaintext ;
29- use eql_mapper:: { self , EqlMapperError , EqlValue , NodeKey , TableColumn , TypedStatement } ;
29+ use eql_mapper:: { self , EqlMapperError , EqlValue , TableColumn , TypeCheckedStatement } ;
3030use metrics:: { counter, histogram} ;
3131use pg_escape:: quote_literal;
3232use serde:: Serialize ;
33- use sqlparser:: ast:: { self , Expr , Value } ;
34- use sqlparser:: dialect:: PostgreSqlDialect ;
35- use sqlparser:: parser:: Parser ;
33+ use sqltk:: NodeKey ;
34+ use sqltk_parser:: ast:: { self , Value } ;
35+ use sqltk_parser:: dialect:: PostgreSqlDialect ;
36+ use sqltk_parser:: parser:: Parser ;
3637use std:: collections:: HashMap ;
3738use std:: time:: Instant ;
3839use tokio:: io:: { AsyncRead , AsyncWrite , AsyncWriteExt } ;
@@ -285,7 +286,7 @@ where
285286
286287 match self . to_encryptable_statement ( & typed_statement, vec ! [ ] ) ? {
287288 Some ( statement) => {
288- if statement . has_literals ( ) || typed_statement. has_nodes_to_wrap ( ) {
289+ if typed_statement. requires_transform ( ) {
289290 let encrypted_literals = self
290291 . encrypt_literals ( & typed_statement, & statement. literal_columns )
291292 . await ?;
@@ -356,7 +357,7 @@ where
356357 ///
357358 async fn encrypt_literals (
358359 & mut self ,
359- typed_statement : & TypedStatement < ' _ > ,
360+ typed_statement : & TypeCheckedStatement < ' _ > ,
360361 literal_columns : & Vec < Option < Column > > ,
361362 ) -> Result < Vec < Option < Encrypted > > , Error > {
362363 let literal_values = typed_statement. literal_values ( ) ;
@@ -402,14 +403,14 @@ where
402403 ///
403404 async fn transform_statement (
404405 & mut self ,
405- typed_statement : & TypedStatement < ' _ > ,
406+ typed_statement : & TypeCheckedStatement < ' _ > ,
406407 encrypted_literals : & Vec < Option < Encrypted > > ,
407408 ) -> Result < Option < ast:: Statement > , Error > {
408409 // Convert literals to ast Expr
409410 let mut encrypted_expressions = vec ! [ ] ;
410411 for encrypted in encrypted_literals {
411412 let e = match encrypted {
412- Some ( en) => Some ( to_json_literal_expr ( & en) ?) ,
413+ Some ( en) => Some ( to_json_literal_value ( & en) ?) ,
413414 None => None ,
414415 } ;
415416 encrypted_expressions. push ( e) ;
@@ -426,11 +427,10 @@ where
426427
427428 debug ! ( target: MAPPER ,
428429 client_id = self . context. client_id,
429- nodes_to_wrap = typed_statement. nodes_to_wrap. len( ) ,
430430 literals = encrypted_nodes. len( ) ,
431431 ) ;
432432
433- if !typed_statement. has_nodes_to_wrap ( ) && encrypted_nodes . is_empty ( ) {
433+ if !typed_statement. requires_transform ( ) {
434434 return Ok ( None ) ;
435435 }
436436
@@ -500,7 +500,7 @@ where
500500
501501 match self . to_encryptable_statement ( & typed_statement, param_types) ? {
502502 Some ( statement) => {
503- if statement . has_literals ( ) || typed_statement. has_nodes_to_wrap ( ) {
503+ if typed_statement. requires_transform ( ) {
504504 let encrypted_literals = self
505505 . encrypt_literals ( & typed_statement, & statement. literal_columns )
506506 . await ?;
@@ -607,7 +607,7 @@ where
607607 ///
608608 fn to_encryptable_statement (
609609 & self ,
610- typed_statement : & TypedStatement < ' _ > ,
610+ typed_statement : & TypeCheckedStatement < ' _ > ,
611611 param_types : Vec < i32 > ,
612612 ) -> Result < Option < Statement > , Error > {
613613 let param_columns = self . get_param_columns ( typed_statement) ?;
@@ -619,8 +619,7 @@ where
619619
620620 if ( param_columns. is_empty ( ) || no_encrypted_param_columns)
621621 && ( projection_columns. is_empty ( ) || no_encrypted_projection_columns)
622- && literal_columns. is_empty ( )
623- && !typed_statement. has_nodes_to_wrap ( )
622+ && !typed_statement. requires_transform ( )
624623 {
625624 return Ok ( None ) ;
626625 }
@@ -735,7 +734,10 @@ where
735734 Ok ( encrypted)
736735 }
737736
738- fn type_check < ' a > ( & self , statement : & ' a ast:: Statement ) -> Result < TypedStatement < ' a > , Error > {
737+ fn type_check < ' a > (
738+ & self ,
739+ statement : & ' a ast:: Statement ,
740+ ) -> Result < TypeCheckedStatement < ' a > , Error > {
739741 match eql_mapper:: type_check ( self . context . get_table_resolver ( ) , statement) {
740742 Ok ( typed_statement) => {
741743 debug ! ( target: MAPPER ,
@@ -778,10 +780,10 @@ where
778780 ///
779781 fn get_projection_columns (
780782 & self ,
781- typed_statement : & eql_mapper:: TypedStatement < ' _ > ,
783+ typed_statement : & eql_mapper:: TypeCheckedStatement < ' _ > ,
782784 ) -> Result < Vec < Option < Column > > , Error > {
783785 let mut projection_columns = vec ! [ ] ;
784- if let Some ( eql_mapper:: Projection :: WithColumns ( columns) ) = & typed_statement. projection {
786+ if let eql_mapper:: Projection :: WithColumns ( columns) = & typed_statement. projection {
785787 for col in columns {
786788 let eql_mapper:: ProjectionColumn { ty, .. } = col;
787789 let configured_column = match ty {
@@ -813,13 +815,13 @@ where
813815 ///
814816 fn get_param_columns (
815817 & self ,
816- typed_statement : & eql_mapper:: TypedStatement < ' _ > ,
818+ typed_statement : & eql_mapper:: TypeCheckedStatement < ' _ > ,
817819 ) -> Result < Vec < Option < Column > > , Error > {
818820 let mut param_columns = vec ! [ ] ;
819821
820822 for param in typed_statement. params . iter ( ) {
821823 let configured_column = match param {
822- eql_mapper:: Value :: Eql ( EqlValue ( TableColumn { table, column } ) ) => {
824+ ( _ , eql_mapper:: Value :: Eql ( EqlValue ( TableColumn { table, column } ) ) ) => {
823825 let identifier = Identifier :: from ( ( table, column) ) ;
824826
825827 debug ! (
@@ -841,7 +843,7 @@ where
841843
842844 fn get_literal_columns (
843845 & self ,
844- typed_statement : & eql_mapper:: TypedStatement < ' _ > ,
846+ typed_statement : & eql_mapper:: TypeCheckedStatement < ' _ > ,
845847 ) -> Result < Vec < Option < Column > > , Error > {
846848 let mut literal_columns = vec ! [ ] ;
847849
@@ -967,9 +969,9 @@ fn literals_to_plaintext(
967969 Ok ( plaintexts)
968970}
969971
970- fn to_json_literal_expr < T > ( literal : & T ) -> Result < Expr , Error >
972+ fn to_json_literal_value < T > ( literal : & T ) -> Result < Value , Error >
971973where
972974 T : ?Sized + Serialize ,
973975{
974- Ok ( serde_json:: to_string ( literal) . map ( |json| Expr :: Value ( Value :: SingleQuotedString ( json ) ) ) ?)
976+ Ok ( serde_json:: to_string ( literal) . map ( Value :: SingleQuotedString ) ?)
975977}
0 commit comments