@@ -46,8 +46,15 @@ export function resolveTableSources(
4646 }
4747
4848 // At least one is a query — wrap both in CTEs
49- const srcExpr = source_is_query ? source : `SELECT * FROM ${ source } `
50- const tgtExpr = target_is_query ? target : `SELECT * FROM ${ target } `
49+ // Quote identifier parts so table names with special chars don't inject SQL.
50+ // Use double-quote escaping (ANSI SQL standard, works in Postgres/Snowflake/DuckDB/etc.)
51+ const quoteIdent = ( name : string ) =>
52+ name
53+ . split ( "." )
54+ . map ( ( p ) => `"${ p . replace ( / " / g, '""' ) } "` )
55+ . join ( "." )
56+ const srcExpr = source_is_query ? source : `SELECT * FROM ${ quoteIdent ( source ) } `
57+ const tgtExpr = target_is_query ? target : `SELECT * FROM ${ quoteIdent ( target ) } `
5158
5259 const ctePrefix = `WITH __diff_source AS (\n${ srcExpr } \n), __diff_target AS (\n${ tgtExpr } \n)`
5360 return {
@@ -247,16 +254,16 @@ function buildColumnDiscoverySQL(tableName: string, dialect: string): string {
247254 }
248255 case "oracle" : {
249256 // Oracle uses ALL_TAB_COLUMNS (no information_schema)
250- const oracleTable = parts [ parts . length - 1 ]
257+ const oracleTable = esc ( parts [ parts . length - 1 ] )
251258 const conditions = [ `TABLE_NAME = '${ oracleTable . toUpperCase ( ) } '` ]
252259 if ( parts . length >= 2 ) {
253- conditions . push ( `OWNER = '${ parts [ parts . length - 2 ] . toUpperCase ( ) } '` )
260+ conditions . push ( `OWNER = '${ esc ( parts [ parts . length - 2 ] ) . toUpperCase ( ) } '` )
254261 }
255262 return `SELECT COLUMN_NAME, DATA_DEFAULT FROM ALL_TAB_COLUMNS WHERE ${ conditions . join ( " AND " ) } ORDER BY COLUMN_ID`
256263 }
257264 case "sqlite" : {
258265 // PRAGMA table_info returns: cid, name, type, notnull, dflt_value, pk
259- const table = parts [ parts . length - 1 ]
266+ const table = esc ( parts [ parts . length - 1 ] )
260267 return `PRAGMA table_info('${ table } ')`
261268 }
262269 default : {
@@ -393,9 +400,19 @@ function dateTruncExpr(granularity: string, column: string, dialect: string): st
393400 const fmt = { day : "%Y-%m-%d" , week : "%Y-%u" , month : "%Y-%m-01" , year : "%Y-01-01" } [ g ] ?? "%Y-%m-01"
394401 return `DATE_FORMAT(${ column } , '${ fmt } ')`
395402 }
396- case "oracle" :
397- // Oracle uses TRUNC(), not DATE_TRUNC()
398- return `TRUNC(${ column } , '${ g . toUpperCase ( ) } ')`
403+ case "oracle" : {
404+ // Oracle uses TRUNC() with format models — 'WEEK' is invalid, use 'IW' for ISO week
405+ const oracleFmt : Record < string , string > = {
406+ day : "DDD" ,
407+ week : "IW" ,
408+ month : "MM" ,
409+ year : "YYYY" ,
410+ quarter : "Q" ,
411+ hour : "HH" ,
412+ minute : "MI" ,
413+ }
414+ return `TRUNC(${ column } , '${ oracleFmt [ g ] ?? g . toUpperCase ( ) } ')`
415+ }
399416 default :
400417 // Postgres, Snowflake, Redshift, DuckDB, etc.
401418 return `DATE_TRUNC('${ g } ', ${ column } )`
@@ -455,21 +472,23 @@ function buildPartitionWhereClause(
455472 dialect : string ,
456473) : string {
457474 const mode = partitionMode ( granularity , bucketSize )
475+ // Quote the column identifier to handle special characters and reserved words
476+ const quotedCol = `"${ partitionColumn . replace ( / " / g, '""' ) } "`
458477
459478 if ( mode === "numeric" ) {
460479 const lo = Number ( partitionValue )
461480 const hi = lo + bucketSize !
462- return `${ partitionColumn } >= ${ lo } AND ${ partitionColumn } < ${ hi } `
481+ return `${ quotedCol } >= ${ lo } AND ${ quotedCol } < ${ hi } `
463482 }
464483
465484 if ( mode === "categorical" ) {
466485 // Quote the value — works for strings, enums, booleans
467486 const escaped = partitionValue . replace ( / ' / g, "''" )
468- return `${ partitionColumn } = '${ escaped } '`
487+ return `${ quotedCol } = '${ escaped } '`
469488 }
470489
471490 // date mode
472- const expr = dateTruncExpr ( granularity ! , partitionColumn , dialect )
491+ const expr = dateTruncExpr ( granularity ! , quotedCol , dialect )
473492
474493 // Cast the literal appropriately per dialect
475494 switch ( dialect ) {
@@ -779,21 +798,32 @@ export async function runDataDiff(params: DataDiffParams): Promise<DataDiffResul
779798
780799 // Execute all SQL tasks in parallel
781800 const tasks = action . tasks ?? [ ]
782- const responses = await Promise . all (
801+ const taskResults = await Promise . all (
783802 tasks . map ( async ( task ) => {
784803 const warehouse = warehouseFor ( task . table_side )
785804 // Inject CTE definitions if we're in query-comparison mode
786805 const sql = ctePrefix ? injectCte ( task . sql , ctePrefix ) : task . sql
787806 try {
788807 const rows = await executeQuery ( sql , warehouse )
789- return { id : task . id , rows }
808+ return { id : task . id , rows, error : null }
790809 } catch ( e ) {
791- // Return error shape — engine will produce an Error action on next step
792- return { id : task . id , rows : [ ] , error : String ( e ) }
810+ return { id : task . id , rows : [ ] as ( string | null ) [ ] [ ] , error : String ( e ) }
793811 }
794812 } ) ,
795813 )
796814
815+ // Surface any SQL execution errors before feeding to the engine
816+ const sqlError = taskResults . find ( ( r ) => r . error !== null )
817+ if ( sqlError ) {
818+ return {
819+ success : false ,
820+ error : `SQL execution failed for task ${ sqlError . id } : ${ sqlError . error } ` ,
821+ steps : stepCount ,
822+ }
823+ }
824+
825+ const responses = taskResults . map ( ( { id, rows } ) => ( { id, rows } ) )
826+
797827 actionJson = session . step ( JSON . stringify ( responses ) )
798828 }
799829
0 commit comments