Skip to content

Commit 9e4d145

Browse files
committed
feat: TransformationRule dry-run support
The Proxy must be able to reliably detect whether a statement requires transformation because if a statement does not require transformation then the potentially expensive AST rebuilding step can be skipped. The result of type-checking is insufficient in general to tell whether a statement requires transformation unless the `TransformationRule` logic is duplicated in the Proxy - which we don't want of course. This commit extends the `TransformationRule` trait with a `would_edit` method which answers the question "would this rule change the AST if it was applied?". Additionally, a new `TranformationRule` impl `DryRunnable` wraps another rule in such a way that it can "pretend" to be performing a `Transform` (as far as `sqltk` is concerned) when really its doing a dry-run after which it will tell us if it *would* change the AST. `TypedCheckedStatement::requires_transform` is a new method that wraps up the dry-run logic and tells the called whether a statement must be transformed
1 parent 4aada7e commit 9e4d145

16 files changed

Lines changed: 567 additions & 247 deletions

mise.local.example.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,6 @@ CS_LOG__AUTHENTICATION_LEVEL = "info"
4242
CS_LOG__CONTEXT_LEVEL = "info"
4343
CS_LOG__KEYSET_LEVEL = "info"
4444
CS_LOG__PROTOCOL_LEVEL = "info"
45-
CS_LOG__MAPPER_LEVEL = "info"
45+
CS_LOG__MAPPER_LEVEL = "debug"
4646
CS_LOG__SCHEMA_LEVEL = "info"
4747
CS_LOG__CONFIG_LEVEL = "info"

packages/cipherstash-proxy/src/error.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::{postgresql::Column, Identifier};
22
use bytes::BytesMut;
33
use cipherstash_client::encryption;
4+
use eql_mapper::EqlMapperError;
45
use metrics_exporter_prometheus::BuildError;
56
use std::{io, time::Duration};
67
use thiserror::Error;
@@ -92,6 +93,9 @@ pub enum MappingError {
9293

9394
#[error("Statement encountered an internal error. This may be a bug in the statement mapping module of CipherStash Proxy. Please visit {}#mapping-internal-error for more information.", ERROR_DOC_BASE_URL)]
9495
Internal(String),
96+
97+
#[error(transparent)]
98+
EqlMapper(#[from] EqlMapperError),
9599
}
96100

97101
#[derive(Error, Debug)]

packages/cipherstash-proxy/src/postgresql/frontend.rs

Lines changed: 68 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ use crate::prometheus::{
2626
use crate::Encrypted;
2727
use bytes::BytesMut;
2828
use cipherstash_client::encryption::Plaintext;
29-
use eql_mapper::{self, EqlMapperError, EqlValue, TableColumn, TypedStatement};
29+
use eql_mapper::{self, EqlMapperError, EqlValue, TableColumn, TypeCheckedStatement};
3030
use metrics::{counter, histogram};
3131
use pg_escape::quote_literal;
3232
use serde::Serialize;
33-
use sqltk::AsNodeKey;
33+
use sqltk::NodeKey;
3434
use sqltk_parser::ast::{self, Value};
3535
use sqltk_parser::dialect::PostgreSqlDialect;
3636
use sqltk_parser::parser::Parser;
@@ -286,22 +286,23 @@ where
286286

287287
match self.to_encryptable_statement(&typed_statement, vec![])? {
288288
Some(statement) => {
289-
let encrypted_literals = self
290-
.encrypt_literals(&typed_statement, &statement.literal_columns)
291-
.await?;
292-
293-
if let Some(transformed_statement) = self
294-
.transform_statement(&typed_statement, &encrypted_literals)
295-
.await?
296-
{
297-
debug!(target: MAPPER,
298-
client_id = self.context.client_id,
299-
transformed_statement = ?transformed_statement,
300-
transformed_statement_text = %transformed_statement,
301-
);
289+
if typed_statement.requires_transform() {
290+
let encrypted_literals = self
291+
.encrypt_literals(&typed_statement, &statement.literal_columns)
292+
.await?;
293+
294+
if let Some(transformed_statement) = self
295+
.transform_statement(&typed_statement, &encrypted_literals)
296+
.await?
297+
{
298+
debug!(target: MAPPER,
299+
client_id = self.context.client_id,
300+
transformed_statement = ?transformed_statement,
301+
);
302302

303-
transformed_statements.push(transformed_statement);
304-
encrypted = true;
303+
transformed_statements.push(transformed_statement);
304+
encrypted = true;
305+
}
305306
}
306307
debug!(target: MAPPER,
307308
client_id = self.context.client_id,
@@ -356,7 +357,7 @@ where
356357
///
357358
async fn encrypt_literals(
358359
&mut self,
359-
typed_statement: &TypedStatement<'_>,
360+
typed_statement: &TypeCheckedStatement<'_>,
360361
literal_columns: &Vec<Option<Column>>,
361362
) -> Result<Vec<Option<Encrypted>>, Error> {
362363
let literal_values = typed_statement.literal_values();
@@ -402,14 +403,14 @@ where
402403
///
403404
async fn transform_statement(
404405
&mut self,
405-
typed_statement: &TypedStatement<'_>,
406+
typed_statement: &TypeCheckedStatement<'_>,
406407
encrypted_literals: &Vec<Option<Encrypted>>,
407408
) -> Result<Option<ast::Statement>, Error> {
408409
// Convert literals to ast Expr
409410
let mut encrypted_expressions = vec![];
410411
for encrypted in encrypted_literals {
411412
let e = match encrypted {
412-
Some(en) => Some(to_json_literal_expr(&en)?),
413+
Some(en) => Some(to_json_literal_value(&en)?),
413414
None => None,
414415
};
415416
encrypted_expressions.push(e);
@@ -421,14 +422,18 @@ where
421422
.literals
422423
.iter()
423424
.zip(encrypted_expressions.into_iter())
424-
.filter_map(|((_, original_node), en)| en.map(|en| (original_node.as_node_key(), en)))
425+
.filter_map(|((_, original_node), en)| en.map(|en| (NodeKey::new(*original_node), en)))
425426
.collect::<HashMap<_, _>>();
426427

427428
debug!(target: MAPPER,
428429
client_id = self.context.client_id,
429430
literals = encrypted_nodes.len(),
430431
);
431432

433+
if !typed_statement.requires_transform() {
434+
return Ok(None);
435+
}
436+
432437
let transformed_statement = typed_statement
433438
.transform(encrypted_nodes)
434439
.map_err(|e| MappingError::StatementCouldNotBeTransformed(e.to_string()))?;
@@ -495,21 +500,22 @@ where
495500

496501
match self.to_encryptable_statement(&typed_statement, param_types)? {
497502
Some(statement) => {
498-
let encrypted_literals = self
499-
.encrypt_literals(&typed_statement, &statement.literal_columns)
500-
.await?;
501-
502-
if let Some(transformed_statement) = self
503-
.transform_statement(&typed_statement, &encrypted_literals)
504-
.await?
505-
{
506-
debug!(target: MAPPER,
507-
client_id = self.context.client_id,
508-
transformed_statement = ?transformed_statement,
509-
transformed_statement_text = %transformed_statement,
510-
);
503+
if typed_statement.requires_transform() {
504+
let encrypted_literals = self
505+
.encrypt_literals(&typed_statement, &statement.literal_columns)
506+
.await?;
507+
508+
if let Some(transformed_statement) = self
509+
.transform_statement(&typed_statement, &encrypted_literals)
510+
.await?
511+
{
512+
debug!(target: MAPPER,
513+
client_id = self.context.client_id,
514+
transformed_statement = ?transformed_statement,
515+
);
511516

512-
message.rewrite_statement(transformed_statement.to_string());
517+
message.rewrite_statement(transformed_statement.to_string());
518+
}
513519
}
514520

515521
counter!(STATEMENTS_ENCRYPTED_TOTAL).increment(1);
@@ -526,14 +532,19 @@ where
526532
counter!(STATEMENTS_PASSTHROUGH_TOTAL).increment(1);
527533
}
528534
}
529-
let bytes = BytesMut::try_from(message)?;
530535

531-
debug!(target: MAPPER,
536+
if message.requires_rewrite() {
537+
let bytes = BytesMut::try_from(message)?;
538+
539+
debug!(target: MAPPER,
532540
client_id = self.context.client_id,
533541
msg = "Rewrite Parse",
534542
bytes = ?bytes);
535543

536-
Ok(Some(bytes))
544+
Ok(Some(bytes))
545+
} else {
546+
Ok(None)
547+
}
537548
}
538549

539550
///
@@ -596,13 +607,23 @@ where
596607
///
597608
fn to_encryptable_statement(
598609
&self,
599-
typed_statement: &TypedStatement<'_>,
610+
typed_statement: &TypeCheckedStatement<'_>,
600611
param_types: Vec<i32>,
601612
) -> Result<Option<Statement>, Error> {
602613
let param_columns = self.get_param_columns(typed_statement)?;
603614
let projection_columns = self.get_projection_columns(typed_statement)?;
604615
let literal_columns = self.get_literal_columns(typed_statement)?;
605616

617+
let no_encrypted_param_columns = param_columns.iter().all(|c| c.is_none());
618+
let no_encrypted_projection_columns = projection_columns.iter().all(|c| c.is_none());
619+
620+
if (param_columns.is_empty() || no_encrypted_param_columns)
621+
&& (projection_columns.is_empty() || no_encrypted_projection_columns)
622+
&& !typed_statement.requires_transform()
623+
{
624+
return Ok(None);
625+
}
626+
606627
debug!(target: MAPPER,
607628
client_id = self.context.client_id,
608629
msg = "Encryptable Statement",
@@ -713,7 +734,10 @@ where
713734
Ok(encrypted)
714735
}
715736

716-
fn type_check<'a>(&self, statement: &'a ast::Statement) -> Result<TypedStatement<'a>, Error> {
737+
fn type_check<'a>(
738+
&self,
739+
statement: &'a ast::Statement,
740+
) -> Result<TypeCheckedStatement<'a>, Error> {
717741
match eql_mapper::type_check(self.context.get_table_resolver(), statement) {
718742
Ok(typed_statement) => {
719743
debug!(target: MAPPER,
@@ -756,7 +780,7 @@ where
756780
///
757781
fn get_projection_columns(
758782
&self,
759-
typed_statement: &eql_mapper::TypedStatement<'_>,
783+
typed_statement: &eql_mapper::TypeCheckedStatement<'_>,
760784
) -> Result<Vec<Option<Column>>, Error> {
761785
let mut projection_columns = vec![];
762786
if let eql_mapper::Projection::WithColumns(columns) = &typed_statement.projection {
@@ -791,7 +815,7 @@ where
791815
///
792816
fn get_param_columns(
793817
&self,
794-
typed_statement: &eql_mapper::TypedStatement<'_>,
818+
typed_statement: &eql_mapper::TypeCheckedStatement<'_>,
795819
) -> Result<Vec<Option<Column>>, Error> {
796820
let mut param_columns = vec![];
797821

@@ -819,7 +843,7 @@ where
819843

820844
fn get_literal_columns(
821845
&self,
822-
typed_statement: &eql_mapper::TypedStatement<'_>,
846+
typed_statement: &eql_mapper::TypeCheckedStatement<'_>,
823847
) -> Result<Vec<Option<Column>>, Error> {
824848
let mut literal_columns = vec![];
825849

@@ -945,7 +969,7 @@ fn literals_to_plaintext(
945969
Ok(plaintexts)
946970
}
947971

948-
fn to_json_literal_expr<T>(literal: &T) -> Result<Value, Error>
972+
fn to_json_literal_value<T>(literal: &T) -> Result<Value, Error>
949973
where
950974
T: ?Sized + Serialize,
951975
{

packages/eql-mapper/src/encrypted_statement.rs

Lines changed: 0 additions & 61 deletions
This file was deleted.

packages/eql-mapper/src/eql_mapper.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ use super::importer::{ImportError, Importer};
22
use crate::{
33
inference::{TypeError, TypeInferencer},
44
unifier::{EqlValue, Unifier},
5-
DepMut, Fmt, Param, ParamError, ScopeError, ScopeTracker, TableResolver, Type, TypeRegistry,
6-
TypedStatement, Value,
5+
DepMut, Fmt, Param, ParamError, ScopeError, ScopeTracker, TableResolver, Type,
6+
TypeCheckedStatement, TypeRegistry, Value,
77
};
88
use sqltk::{Break, NodeKey, Visitable, Visitor};
99
use sqltk_parser::ast::{self as ast, Statement};
@@ -22,19 +22,19 @@ use tracing::{event, span, Level};
2222
/// - all operators and functions used with literals destined to be transformed to EQL types are semantically valid for
2323
/// that EQL type
2424
///
25-
/// A successful type check will return a [`TypedStatement`] which can be interrogated to discover the required params
25+
/// A successful type check will return a [`TypeCheckedStatement`] which can be interrogated to discover the required params
2626
/// and their types, the types and plaintext values of all literals, and an optional projection type (the optionality
2727
/// depending on the specific statement).
2828
///
29-
/// Invoking [`TypedStatement::transform`] will return an updated [`Statement`] where all plaintext literals have been
29+
/// Invoking [`TypeCheckedStatement::transform`] will return an updated [`Statement`] where all plaintext literals have been
3030
/// replaced with their EQL (encrypted) equivalent and specific SQL operators and functions will have been rewritten to
3131
/// invoke the EQL equivalents.
3232
///
3333
/// An [`EqlMapperError`] is returned if type checking fails.
3434
pub fn type_check<'ast>(
3535
resolver: Arc<TableResolver>,
3636
statement: &'ast Statement,
37-
) -> Result<TypedStatement<'ast>, EqlMapperError> {
37+
) -> Result<TypeCheckedStatement<'ast>, EqlMapperError> {
3838
let mut mapper = EqlMapper::<'ast>::new_with_resolver(resolver);
3939
match statement.accept(&mut mapper) {
4040
ControlFlow::Continue(()) => mapper.resolve(statement),
@@ -137,7 +137,7 @@ impl<'ast> EqlMapper<'ast> {
137137
pub fn resolve(
138138
self,
139139
statement: &'ast Statement,
140-
) -> Result<TypedStatement<'ast>, EqlMapperError> {
140+
) -> Result<TypeCheckedStatement<'ast>, EqlMapperError> {
141141
let span_begin = span!(
142142
target: "eqlmapper::spans",
143143
Level::TRACE,
@@ -169,7 +169,7 @@ impl<'ast> EqlMapper<'ast> {
169169
node_types = %Fmt(&node_types)
170170
);
171171

172-
Ok(TypedStatement {
172+
Ok(TypeCheckedStatement {
173173
statement,
174174
projection,
175175
params,

0 commit comments

Comments
 (0)