From 28a7105ed3eb0770f806999a702cc2df8127f326 Mon Sep 17 00:00:00 2001 From: exveria1015 Date: Fri, 3 Apr 2026 09:13:08 +0900 Subject: [PATCH] fix(postgres): preserve NUL bytes in internal JSONB tracking state - add reversible zero-byte escaping for internal JSONB values and keys - apply the wrapper to tracking tables and setup metadata reads/writes - keep prefix-only literals unchanged for backward compatibility - add tests for JSON roundtrip and prefix-literal handling --- rust/cocoindex/src/execution/db_tracking.rs | 70 +++--- rust/cocoindex/src/execution/row_indexer.rs | 7 +- .../cocoindex/src/execution/source_indexer.rs | 2 +- rust/cocoindex/src/setup/db_metadata.rs | 78 +++++-- rust/utils/src/str_sanitize.rs | 215 +++++++++++++++++- 5 files changed, 319 insertions(+), 53 deletions(-) diff --git a/rust/cocoindex/src/execution/db_tracking.rs b/rust/cocoindex/src/execution/db_tracking.rs index ddb56cb11..dad51f8af 100644 --- a/rust/cocoindex/src/execution/db_tracking.rs +++ b/rust/cocoindex/src/execution/db_tracking.rs @@ -11,6 +11,12 @@ use sqlx::PgPool; use std::fmt; use utils::{db::WriteAction, fingerprint::Fingerprint}; +type EscapedJson = utils::str_sanitize::ZeroCodeEscapedJson; + +fn escaped_json(value: T) -> EscapedJson { + utils::str_sanitize::ZeroCodeEscapedJson(value) +} + //////////////////////////////////////////////////////////// // Access for the row tracking table //////////////////////////////////////////////////////////// @@ -86,7 +92,7 @@ pub type TrackedTargetKeyForSource = Vec<(i32, Vec)>; #[derive(sqlx::FromRow, Debug)] pub struct SourceTrackingInfoForProcessing { - pub memoization_info: Option>>, + pub memoization_info: Option>>, pub processed_source_ordinal: Option, pub processed_source_fp: Option>, @@ -113,7 +119,7 @@ pub async fn read_source_tracking_info_for_processing( ); let tracking_info = sqlx::query_as(&query_str) .bind(source_id) - .bind(source_key_json) + .bind(escaped_json(source_key_json)) .fetch_optional(pool) .await?; @@ -123,13 +129,13 @@ pub async fn read_source_tracking_info_for_processing( #[derive(sqlx::FromRow, Debug)] pub struct SourceTrackingInfoForPrecommit { pub max_process_ordinal: i64, - pub staging_target_keys: sqlx::types::Json, + pub staging_target_keys: EscapedJson, pub processed_source_ordinal: Option, pub processed_source_fp: Option>, pub process_logic_fingerprint: Option>, pub process_ordinal: Option, - pub target_keys: Option>, + pub target_keys: Option>, } pub async fn read_source_tracking_info_for_precommit( @@ -150,7 +156,7 @@ pub async fn read_source_tracking_info_for_precommit( ); let precommit_tracking_info = sqlx::query_as(&query_str) .bind(source_id) - .bind(source_key_json) + .bind(escaped_json(source_key_json)) .fetch_optional(db_executor) .await?; @@ -181,10 +187,10 @@ pub async fn precommit_source_tracking_info( }; sqlx::query(&query_str) .bind(source_id) // $1 - .bind(source_key_json) // $2 + .bind(escaped_json(source_key_json)) // $2 .bind(max_process_ordinal) // $3 - .bind(sqlx::types::Json(staging_target_keys)) // $4 - .bind(memoization_info.map(sqlx::types::Json)) // $5 + .bind(escaped_json(staging_target_keys)) // $4 + .bind(memoization_info.map(escaped_json)) // $5 .execute(db_executor) .await?; Ok(()) @@ -207,9 +213,9 @@ pub async fn touch_max_process_ordinal( ); sqlx::query(&query_str) .bind(source_id) - .bind(source_key_json) + .bind(escaped_json(source_key_json)) .bind(process_ordinal) - .bind(sqlx::types::Json(TrackedTargetKeyForSource::default())) + .bind(escaped_json(TrackedTargetKeyForSource::default())) .execute(db_executor) .await?; Ok(()) @@ -217,7 +223,7 @@ pub async fn touch_max_process_ordinal( #[derive(sqlx::FromRow, Debug)] pub struct SourceTrackingInfoForCommit { - pub staging_target_keys: sqlx::types::Json, + pub staging_target_keys: EscapedJson, pub process_ordinal: Option, } @@ -234,7 +240,7 @@ pub async fn read_source_tracking_info_for_commit( ); let commit_tracking_info = sqlx::query_as(&query_str) .bind(source_id) - .bind(source_key_json) + .bind(escaped_json(source_key_json)) .fetch_optional(db_executor) .await?; Ok(commit_tracking_info) @@ -287,13 +293,13 @@ pub async fn commit_source_tracking_info( }; let mut query = sqlx::query(&query_str) .bind(source_id) // $1 - .bind(source_key_json) // $2 - .bind(sqlx::types::Json(staging_target_keys)) // $3 + .bind(escaped_json(source_key_json)) // $2 + .bind(escaped_json(staging_target_keys)) // $3 .bind(processed_source_ordinal) // $4 .bind(logic_fingerprint) // $5 .bind(process_ordinal) // $6 .bind(process_time_micros) // $7 - .bind(sqlx::types::Json(target_keys)); // $8 + .bind(escaped_json(target_keys)); // $8 if db_setup.has_fast_fingerprint_column { query = query.bind(processed_source_fp); // $9 @@ -316,7 +322,7 @@ pub async fn delete_source_tracking_info( ); sqlx::query(&query_str) .bind(source_id) - .bind(source_key_json) + .bind(escaped_json(source_key_json)) .execute(db_executor) .await?; Ok(()) @@ -344,8 +350,8 @@ pub async fn read_tracking_entries_for_sources( let rows: Vec<( i32, - serde_json::Value, - Option>, + EscapedJson, + Option>, )> = sqlx::query_as(&query_str) .bind(source_ids) .fetch_all(pool) @@ -356,8 +362,8 @@ pub async fn read_tracking_entries_for_sources( .map( |(source_id, source_key, target_keys_json)| SourceTrackingEntryForCleanup { source_id, - source_key, - target_keys: target_keys_json.map(|j| j.0), + source_key: source_key.into_inner(), + target_keys: target_keys_json.map(EscapedJson::into_inner), }, ) .collect()) @@ -378,8 +384,8 @@ pub fn read_tracking_entries_for_sources_stream( let mut rows = sqlx::query_as::<_, ( i32, - serde_json::Value, - Option>, + EscapedJson, + Option>, )>(&query_str) .bind(&source_ids) .fetch(&pool); @@ -388,8 +394,8 @@ pub fn read_tracking_entries_for_sources_stream( let (source_id, source_key, target_keys_json) = row; yield SourceTrackingEntryForCleanup { source_id, - source_key, - target_keys: target_keys_json.map(|j| j.0), + source_key: source_key.into_inner(), + target_keys: target_keys_json.map(EscapedJson::into_inner), }; } } @@ -412,7 +418,7 @@ pub async fn delete_tracking_entries_for_sources( #[derive(sqlx::FromRow, Debug)] pub struct TrackedSourceKeyMetadata { - pub source_key: serde_json::Value, + pub source_key: EscapedJson, pub processed_source_ordinal: Option, pub processed_source_fp: Option>, pub process_logic_fingerprint: Option>, @@ -474,7 +480,7 @@ pub async fn read_source_last_processed_info( ); let last_processed_info = sqlx::query_as(&query_str) .bind(source_id) - .bind(source_key_json) + .bind(escaped_json(source_key_json)) .fetch_optional(pool) .await?; Ok(last_processed_info) @@ -494,7 +500,7 @@ pub async fn update_source_tracking_ordinal( ); sqlx::query(&query_str) .bind(source_id) // $1 - .bind(source_key_json) // $2 + .bind(escaped_json(source_key_json)) // $2 .bind(processed_source_ordinal) // $3 .execute(db_executor) .await?; @@ -521,12 +527,12 @@ pub async fn read_source_state( "SELECT value FROM {} WHERE source_id = $1 AND key = $2", table_name ); - let state: Option = sqlx::query_scalar(&query_str) + let state: Option> = sqlx::query_scalar(&query_str) .bind(source_id) - .bind(source_key_json) + .bind(escaped_json(source_key_json)) .fetch_optional(db_executor) .await?; - Ok(state) + Ok(state.map(EscapedJson::into_inner)) } #[allow(dead_code)] @@ -549,8 +555,8 @@ pub async fn upsert_source_state( ); sqlx::query(&query_str) .bind(source_id) - .bind(source_key_json) - .bind(sqlx::types::Json(state)) + .bind(escaped_json(source_key_json)) + .bind(escaped_json(state)) .execute(db_executor) .await?; Ok(()) diff --git a/rust/cocoindex/src/execution/row_indexer.rs b/rust/cocoindex/src/execution/row_indexer.rs index 032e182b2..720006a1e 100644 --- a/rust/cocoindex/src/execution/row_indexer.rs +++ b/rust/cocoindex/src/execution/row_indexer.rs @@ -589,7 +589,7 @@ impl<'a> RowIndexer<'a> { // Collect from existing tracking info. if let Some(info) = tracking_info { - let sqlx::types::Json(staging_target_keys) = info.staging_target_keys; + let staging_target_keys = info.staging_target_keys.into_inner(); for (target_id, keys_info) in staging_target_keys.into_iter() { let target_info = tracking_info_for_targets.entry(target_id).or_default(); for key_info in keys_info.into_iter() { @@ -604,7 +604,8 @@ impl<'a> RowIndexer<'a> { } } - if let Some(sqlx::types::Json(target_keys)) = info.target_keys { + if let Some(target_keys) = info.target_keys.map(|target_keys| target_keys.into_inner()) + { for (target_id, keys_info) in target_keys.into_iter() { let target_info = tracking_info_for_targets.entry(target_id).or_default(); for key_info in keys_info.into_iter() { @@ -802,7 +803,7 @@ impl<'a> RowIndexer<'a> { let cleaned_staging_target_keys = tracking_info .map(|info| { - let sqlx::types::Json(staging_target_keys) = info.staging_target_keys; + let staging_target_keys = info.staging_target_keys.into_inner(); staging_target_keys .into_iter() .filter_map(|(target_id, target_keys)| { diff --git a/rust/cocoindex/src/execution/source_indexer.rs b/rust/cocoindex/src/execution/source_indexer.rs index 680b68df5..c4eba87e0 100644 --- a/rust/cocoindex/src/execution/source_indexer.rs +++ b/rust/cocoindex/src/execution/source_indexer.rs @@ -293,7 +293,7 @@ impl SourceIndexingContext { while let Some(key_metadata) = key_metadata_stream.next().await { let key_metadata = key_metadata?; let source_pk = value::KeyValue::from_json( - key_metadata.source_key, + key_metadata.source_key.into_inner(), &import_op.primary_key_schema, )?; if let Some(rows_to_retry) = &mut rows_to_retry { diff --git a/rust/cocoindex/src/setup/db_metadata.rs b/rust/cocoindex/src/setup/db_metadata.rs index 3189f4bec..53ed7d562 100644 --- a/rust/cocoindex/src/setup/db_metadata.rs +++ b/rust/cocoindex/src/setup/db_metadata.rs @@ -9,6 +9,12 @@ use utils::db::WriteAction; const SETUP_METADATA_TABLE_NAME_UNQUALIFIED: &str = "cocoindex_setup_metadata"; pub const FLOW_VERSION_RESOURCE_TYPE: &str = "__FlowVersion"; +type EscapedJson = utils::str_sanitize::ZeroCodeEscapedJson; + +fn escaped_json(value: T) -> EscapedJson { + utils::str_sanitize::ZeroCodeEscapedJson(value) +} + #[derive(sqlx::FromRow, Debug)] pub struct SetupMetadataRecord { pub flow_name: String, @@ -44,9 +50,33 @@ pub async fn read_setup_metadata(pool: &PgPool) -> Result, + Option>, + EscapedJson>>, + ), + >(&query_str) + .fetch_all(&mut *db_conn) + .await; let result = match metadata { - Ok(metadata) => Some(metadata), + Ok(metadata) => Some( + metadata + .into_iter() + .map(|(flow_name, resource_type, key, state, staging_changes)| { + SetupMetadataRecord { + flow_name, + resource_type, + key: key.into_inner(), + state: state.map(EscapedJson::into_inner), + staging_changes: sqlx::types::Json(staging_changes.into_inner()), + } + }) + .collect(), + ), Err(err) => { let exists: Option = sqlx::query_scalar(&format!( "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE tablename = '{table_name}')" @@ -88,12 +118,30 @@ async fn read_metadata_records_for_flow( let query_str = format!( "SELECT flow_name, resource_type, key, state, staging_changes FROM {table_name} WHERE flow_name = $1", ); - let metadata: Vec = sqlx::query_as(&query_str) - .bind(flow_name) - .fetch_all(db_executor) - .await?; + let metadata = sqlx::query_as::< + _, + ( + String, + String, + EscapedJson, + Option>, + EscapedJson>>, + ), + >(&query_str) + .bind(flow_name) + .fetch_all(db_executor) + .await?; let result = metadata .into_iter() + .map( + |(flow_name, resource_type, key, state, staging_changes)| SetupMetadataRecord { + flow_name, + resource_type, + key: key.into_inner(), + state: state.map(EscapedJson::into_inner), + staging_changes: sqlx::types::Json(staging_changes.into_inner()), + }, + ) .map(|m| { ( ResourceTypeKey { @@ -116,13 +164,13 @@ async fn read_state( let query_str = format!( "SELECT state FROM {table_name} WHERE flow_name = $1 AND resource_type = $2 AND key = $3", ); - let state: Option = sqlx::query_scalar(&query_str) + let state: Option> = sqlx::query_scalar(&query_str) .bind(flow_name) .bind(&type_id.resource_type) - .bind(&type_id.key) + .bind(escaped_json(&type_id.key)) .fetch_optional(db_executor) .await?; - Ok(state) + Ok(state.map(EscapedJson::into_inner)) } async fn upsert_staging_changes( @@ -144,8 +192,8 @@ async fn upsert_staging_changes( sqlx::query(&query_str) .bind(flow_name) .bind(&type_id.resource_type) - .bind(&type_id.key) - .bind(sqlx::types::Json(staging_changes)) + .bind(escaped_json(&type_id.key)) + .bind(escaped_json(staging_changes)) .execute(db_executor) .await?; Ok(()) @@ -170,9 +218,9 @@ async fn upsert_state( sqlx::query(&query_str) .bind(flow_name) .bind(&type_id.resource_type) - .bind(&type_id.key) - .bind(sqlx::types::Json(state)) - .bind(sqlx::types::Json(Vec::::new())) + .bind(escaped_json(&type_id.key)) + .bind(escaped_json(state)) + .bind(escaped_json(Vec::::new())) .execute(db_executor) .await?; Ok(()) @@ -190,7 +238,7 @@ async fn delete_state( sqlx::query(&query_str) .bind(flow_name) .bind(&type_id.resource_type) - .bind(&type_id.key) + .bind(escaped_json(&type_id.key)) .execute(db_executor) .await?; Ok(()) diff --git a/rust/utils/src/str_sanitize.rs b/rust/utils/src/str_sanitize.rs index 17b483e13..40c79f160 100644 --- a/rust/utils/src/str_sanitize.rs +++ b/rust/utils/src/str_sanitize.rs @@ -1,15 +1,19 @@ use std::borrow::Cow; use std::fmt::Display; +use std::ops::{Deref, DerefMut}; +use base64::Engine; use serde::Serialize; +use serde::de::DeserializeOwned; use serde::ser::{ SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant, SerializeTuple, SerializeTupleStruct, SerializeTupleVariant, }; use sqlx::Type; +use sqlx::decode::Decode; use sqlx::encode::{Encode, IsNull}; use sqlx::error::BoxDynError; -use sqlx::postgres::{PgArgumentBuffer, Postgres}; +use sqlx::postgres::{PgArgumentBuffer, PgValueRef, Postgres}; pub fn strip_zero_code<'a>(s: Cow<'a, str>) -> Cow<'a, str> { if s.contains('\0') { @@ -25,6 +29,149 @@ pub fn strip_zero_code<'a>(s: Cow<'a, str>) -> Cow<'a, str> { } } +const ZERO_CODE_ESCAPE_PREFIX: &str = "__cocoindex_zero_code_b64_v1__:"; +const ZERO_CODE_ESCAPE_MAGIC: &[u8] = b"cocoindex-zero-code-v1\0"; + +fn encode_zero_code_text(s: &str) -> String { + let mut payload = Vec::with_capacity(ZERO_CODE_ESCAPE_MAGIC.len() + s.len()); + payload.extend_from_slice(ZERO_CODE_ESCAPE_MAGIC); + payload.extend_from_slice(s.as_bytes()); + format!( + "{ZERO_CODE_ESCAPE_PREFIX}{}", + base64::engine::general_purpose::STANDARD.encode(payload) + ) +} + +fn escape_zero_code_text(s: &str) -> Option { + s.contains('\0').then(|| encode_zero_code_text(s)) +} + +fn decode_zero_code_text(s: &str) -> Option { + let encoded = s.strip_prefix(ZERO_CODE_ESCAPE_PREFIX)?; + let decoded = base64::engine::general_purpose::STANDARD + .decode(encoded) + .ok()?; + let payload = decoded.strip_prefix(ZERO_CODE_ESCAPE_MAGIC)?; + String::from_utf8(payload.to_vec()).ok() +} + +pub fn escape_zero_codes_in_json(value: &mut serde_json::Value) { + match value { + serde_json::Value::String(s) => { + if let Some(escaped) = escape_zero_code_text(s) { + *s = escaped; + } + } + serde_json::Value::Array(values) => { + for value in values { + escape_zero_codes_in_json(value); + } + } + serde_json::Value::Object(values) => { + let mut escaped = serde_json::Map::new(); + for (key, mut value) in std::mem::take(values) { + escape_zero_codes_in_json(&mut value); + escaped.insert(escape_zero_code_text(&key).unwrap_or(key), value); + } + *values = escaped; + } + serde_json::Value::Null | serde_json::Value::Bool(_) | serde_json::Value::Number(_) => {} + } +} + +pub fn unescape_zero_codes_in_json(value: &mut serde_json::Value) { + match value { + serde_json::Value::String(s) => { + if let Some(decoded) = decode_zero_code_text(s) { + *s = decoded; + } + } + serde_json::Value::Array(values) => { + for value in values { + unescape_zero_codes_in_json(value); + } + } + serde_json::Value::Object(values) => { + let mut decoded = serde_json::Map::new(); + for (key, mut value) in std::mem::take(values) { + unescape_zero_codes_in_json(&mut value); + decoded.insert(decode_zero_code_text(&key).unwrap_or(key), value); + } + *values = decoded; + } + serde_json::Value::Null | serde_json::Value::Bool(_) | serde_json::Value::Number(_) => {} + } +} + +#[derive(Debug, Clone)] +pub struct ZeroCodeEscapedJson(pub T); + +impl ZeroCodeEscapedJson { + pub fn into_inner(self) -> T { + self.0 + } +} + +impl From for ZeroCodeEscapedJson { + fn from(value: T) -> Self { + Self(value) + } +} + +impl Deref for ZeroCodeEscapedJson { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for ZeroCodeEscapedJson { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Type for ZeroCodeEscapedJson { + fn type_info() -> ::TypeInfo { + as Type>::type_info() + } + + fn compatible(ty: &::TypeInfo) -> bool { + as Type>::compatible(ty) + } +} + +impl<'q, T> Encode<'q, Postgres> for ZeroCodeEscapedJson +where + T: Serialize, +{ + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + let mut value = serde_json::to_value(&self.0)?; + escape_zero_codes_in_json(&mut value); + as Encode<'q, Postgres>>::encode_by_ref( + &sqlx::types::Json(value), + buf, + ) + } + + fn size_hint(&self) -> usize { + 0 + } +} + +impl<'r, T> Decode<'r, Postgres> for ZeroCodeEscapedJson +where + T: DeserializeOwned, +{ + fn decode(value: PgValueRef<'r>) -> Result { + let sqlx::types::Json(mut json_value) = + as Decode<'r, Postgres>>::decode(value)?; + unescape_zero_codes_in_json(&mut json_value); + Ok(Self(serde_json::from_value(json_value)?)) + } +} + /// A thin wrapper for sqlx parameter binding that strips NUL (\0) bytes /// from the wrapped string before encoding. /// @@ -500,7 +647,7 @@ where #[cfg(test)] mod tests { use super::*; - use serde::Serialize; + use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; use std::borrow::Cow; use std::collections::BTreeMap; @@ -594,4 +741,68 @@ mod tests { assert!(var.contains_key("ke\0y")); assert_eq!(var.get("ke\0y").unwrap(), &json!("bar")); } + + fn json_contains_nul(value: &Value) -> bool { + match value { + Value::String(s) => s.contains('\0'), + Value::Array(values) => values.iter().any(json_contains_nul), + Value::Object(values) => values + .iter() + .any(|(key, value)| key.contains('\0') || json_contains_nul(value)), + Value::Null | Value::Bool(_) | Value::Number(_) => false, + } + } + + #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] + struct EscapedRoundtripFixture { + text: String, + nested: BTreeMap, + } + + #[test] + fn escaped_json_roundtrip_preserves_nuls_and_prefix_literals() { + let original = EscapedRoundtripFixture { + text: format!("{ZERO_CODE_ESCAPE_PREFIX}literal\0value"), + nested: BTreeMap::from([ + ( + "k\0ey".to_string(), + json!(["a\0b", ZERO_CODE_ESCAPE_PREFIX, "plain"]), + ), + ( + ZERO_CODE_ESCAPE_PREFIX.to_string(), + json!({ + "inner\0key": format!("{ZERO_CODE_ESCAPE_PREFIX}already-prefixed"), + }), + ), + ]), + }; + + let mut escaped = serde_json::to_value(&original).unwrap(); + escape_zero_codes_in_json(&mut escaped); + + assert!(!json_contains_nul(&escaped)); + + unescape_zero_codes_in_json(&mut escaped); + let roundtrip: EscapedRoundtripFixture = serde_json::from_value(escaped).unwrap(); + assert_eq!(roundtrip, original); + } + + #[test] + fn escaping_leaves_prefix_only_literals_unchanged() { + let mut value = json!({ + "literal": format!("{ZERO_CODE_ESCAPE_PREFIX}plain-text"), + "nested": [ZERO_CODE_ESCAPE_PREFIX], + }); + + let original = value.clone(); + escape_zero_codes_in_json(&mut value); + + assert_eq!(value, original); + } + + #[test] + fn decode_zero_code_text_ignores_unrelated_prefixed_values() { + let literal = format!("{ZERO_CODE_ESCAPE_PREFIX}not-base64"); + assert_eq!(decode_zero_code_text(&literal), None); + } }