Skip to content

Commit 4bf2a7c

Browse files
committed
refactor(context): push frontend coordination logic into Context
Add helper methods to Context that encapsulate multi-step operations previously done inline in frontend.rs: - with_session: update session metrics if session ID present, no-op otherwise - set_execute_for_portal: look up session ID and set execute state - close_statement_and_portal: close both statement and portal together - get_statement_session_or_latest: session lookup with fallback logic - reload_schema_if_changed: conditional schema reload Simplifies frontend.rs by centralizing related operations in Context.
1 parent e6bf1de commit 4bf2a7c

2 files changed

Lines changed: 63 additions & 41 deletions

File tree

packages/cipherstash-proxy/src/postgresql/context/mod.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,12 @@ where
302302
let _ = self.execute.write().map(|mut queue| queue.add(ctx));
303303
}
304304

305+
/// Set execute state for portal, looking up session ID internally.
306+
pub fn set_execute_for_portal(&mut self, name: Name) {
307+
let session_id = self.get_portal_session_id(&name);
308+
self.set_execute(name, session_id);
309+
}
310+
305311
/// Marks the current Execution as Complete.
306312
///
307313
/// Transfers accumulated timing data from ExecuteContext to SessionMetricsContext.phase_timing:
@@ -393,6 +399,12 @@ where
393399
.map(|mut guarded| guarded.remove(name));
394400
}
395401

402+
/// Close both statement and its associated portal.
403+
pub fn close_statement_and_portal(&mut self, name: &Name) {
404+
self.close_portal(name);
405+
self.close_statement(name);
406+
}
407+
396408
pub fn add_portal(&mut self, name: Name, portal: Portal) {
397409
debug!(target: CONTEXT, client_id = self.client_id, name = ?name, portal = ?portal);
398410
let _ = self.portals.write().map(|mut portals| {
@@ -421,6 +433,24 @@ where
421433
sessions.get(name).copied()
422434
}
423435

436+
/// Get session for statement, falling back to latest session with warning log.
437+
pub fn get_statement_session_or_latest(&self, name: &Name) -> Option<SessionId> {
438+
if let Some(id) = self.get_statement_session(name) {
439+
return Some(id);
440+
}
441+
442+
let fallback = self.latest_session_id();
443+
if fallback.is_some() {
444+
warn!(
445+
target: CONTEXT,
446+
client_id = self.client_id,
447+
prepared_statement = ?name,
448+
msg = "Session lookup failed for prepared statement, using latest session"
449+
);
450+
}
451+
fallback
452+
}
453+
424454
///
425455
/// Close the portal identified by `name`
426456
/// Portal is removed from queue
@@ -740,6 +770,13 @@ where
740770
debug!(target: CONTEXT, msg = "Database schema reloaded", ?response);
741771
}
742772

773+
/// Reload schema if it has changed since last check.
774+
pub async fn reload_schema_if_changed(&self) {
775+
if self.schema_changed() {
776+
self.reload_schema().await;
777+
}
778+
}
779+
743780
pub fn is_passthrough(&self) -> bool {
744781
self.encrypt_config.is_empty() || self.config.mapping_disabled()
745782
}
@@ -927,6 +964,16 @@ where
927964
});
928965
}
929966

967+
/// Update statement metadata if session ID is present, no-op otherwise.
968+
pub fn with_session<F>(&mut self, session_id: Option<SessionId>, f: F)
969+
where
970+
F: FnOnce(&mut SessionMetricsContext),
971+
{
972+
if let Some(sid) = session_id {
973+
self.with_session_metrics_mut(sid, f);
974+
}
975+
}
976+
930977
/// Record server wait for first response; otherwise accumulate response time for the current execute
931978
pub fn record_execute_server_timing(&mut self, duration: Duration) {
932979
if let Ok(mut queue) = self.execute.write() {

packages/cipherstash-proxy/src/postgresql/frontend.rs

Lines changed: 16 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use super::parser::SqlParser;
1111
use super::protocol::{self};
1212
use crate::connect::Sender;
1313
use crate::error::{EncryptError, Error, MappingError};
14-
use crate::log::{CONTEXT, MAPPER, PROTOCOL};
14+
use crate::log::{MAPPER, PROTOCOL};
1515
use crate::postgresql::context::column::Column;
1616
use crate::postgresql::context::statement_metadata::{ProtocolType, StatementType};
1717
use crate::postgresql::context::Portal;
@@ -268,9 +268,7 @@ where
268268
?code,
269269
);
270270

271-
if self.context.schema_changed() {
272-
self.context.reload_schema().await;
273-
}
271+
self.context.reload_schema_if_changed().await;
274272

275273
if self.error_state.is_some() {
276274
debug!(target: PROTOCOL,
@@ -331,20 +329,16 @@ where
331329
debug!(target: PROTOCOL, client_id = self.context.client_id, ?close);
332330
match close.target {
333331
Target::Portal => self.context.close_portal(&close.name),
334-
Target::Statement => {
335-
self.context.close_portal(&close.name);
336-
self.context.close_statement(&close.name);
337-
}
332+
Target::Statement => self.context.close_statement_and_portal(&close.name),
338333
}
339334
Ok(())
340335
}
341336

342337
async fn execute_handler(&mut self, bytes: &BytesMut) -> Result<(), Error> {
343338
let execute = Execute::try_from(bytes)?;
344339
debug!(target: PROTOCOL, client_id = self.context.client_id, ?execute);
345-
let session_id = self.context.get_portal_session_id(&execute.portal);
346340
self.context
347-
.set_execute(execute.portal.to_owned(), session_id);
341+
.set_execute_for_portal(execute.portal.to_owned());
348342
Ok(())
349343
}
350344

@@ -968,28 +962,14 @@ where
968962

969963
let mut bind = Bind::try_from(bytes)?;
970964

971-
let session_id = match self.context.get_statement_session(&bind.prepared_statement) {
972-
Some(id) => Some(id),
973-
None => {
974-
let fallback = self.context.latest_session_id();
975-
if fallback.is_some() {
976-
warn!(
977-
target: CONTEXT,
978-
client_id = self.context.client_id,
979-
prepared_statement = ?bind.prepared_statement,
980-
msg = "Session lookup failed for prepared statement, using latest session"
981-
);
982-
}
983-
fallback
984-
}
985-
};
965+
let session_id = self
966+
.context
967+
.get_statement_session_or_latest(&bind.prepared_statement);
986968

987969
// Track param bytes for diagnostics
988970
let param_bytes: usize = bind.param_values.iter().map(|p| p.bytes.len()).sum();
989-
if let Some(session_id) = session_id {
990-
self.context
991-
.update_statement_metadata(session_id, |m| m.set_param_bytes(param_bytes));
992-
}
971+
self.context
972+
.with_session(session_id, |m| m.metadata.set_param_bytes(param_bytes));
993973

994974
debug!(target: PROTOCOL, client_id = self.context.client_id, bind = ?bind);
995975

@@ -1008,10 +988,8 @@ where
1008988
bind.result_columns_format_codes.to_owned(),
1009989
session_id,
1010990
);
1011-
if let Some(session_id) = session_id {
1012-
self.context
1013-
.update_statement_metadata(session_id, |m| m.encrypted = true);
1014-
}
991+
self.context
992+
.with_session(session_id, |m| m.metadata.encrypted = true);
1015993
}
1016994
};
1017995

@@ -1065,16 +1043,13 @@ where
10651043

10661044
// Record timing and metadata for this encryption operation
10671045
let encrypted_count = encrypted.iter().filter(|e| e.is_some()).count();
1068-
if let Some(sid) = session_id {
1046+
self.context.with_session(session_id, |m| {
10691047
// Add to phase timing diagnostics (accumulate)
1070-
self.context.add_encrypt_duration(sid, duration);
1071-
1048+
m.phase_timing.add_encrypt(duration);
10721049
// Always update metadata for slow-statement logging
1073-
self.context.update_statement_metadata(sid, |m| {
1074-
m.encrypted = true;
1075-
m.set_encrypted_values_count(encrypted_count);
1076-
});
1077-
}
1050+
m.metadata.encrypted = true;
1051+
m.metadata.set_encrypted_values_count(encrypted_count);
1052+
});
10781053

10791054
// Prometheus metrics remain gated
10801055
if self.context.prometheus_enabled() {

0 commit comments

Comments
 (0)