diff --git a/drivers/postgres/internal/backfill.go b/drivers/postgres/internal/backfill.go index fbcc3c257..2f5e58214 100644 --- a/drivers/postgres/internal/backfill.go +++ b/drivers/postgres/internal/backfill.go @@ -52,7 +52,8 @@ func (p *Postgres) backfill(pool *protocol.WriterPool, stream protocol.Stream) e defer tx.Rollback() splitColumn := stream.Self().StreamMetadata.SplitColumn splitColumn = utils.Ternary(splitColumn == "", "ctid", splitColumn).(string) - stmt := jdbc.PostgresChunkScanQuery(stream, splitColumn, chunk) + splitColType, _ := stream.Schema().GetType(splitColumn) + stmt := jdbc.PostgresChunkScanQuery(stream, splitColumn, chunk, splitColType) setter := jdbc.NewReader(backfillCtx, stmt, p.config.BatchSize, func(ctx context.Context, query string, args ...any) (*sql.Rows, error) { return tx.Query(query, args...) @@ -163,8 +164,8 @@ func (p *Postgres) splitTableIntoChunks(stream protocol.Stream) ([]types.Chunk, splitColumn := stream.Self().StreamMetadata.SplitColumn if splitColumn != "" { var minValue, maxValue interface{} - minMaxRowCountQuery := jdbc.MinMaxQuery(stream, splitColumn) - // TODO: Fails on UUID type (Good First Issue) + splitColType, _ := stream.Schema().GetType(splitColumn) + minMaxRowCountQuery := jdbc.PostgresMinMaxQuery(stream, splitColumn, splitColType) err := p.client.QueryRow(minMaxRowCountQuery).Scan(&minValue, &maxValue) if err != nil { return nil, fmt.Errorf("failed to fetch table min max: %s", err) @@ -180,7 +181,6 @@ func (p *Postgres) splitTableIntoChunks(stream protocol.Stream) ([]types.Chunk, return nil, fmt.Errorf("provided split column is not a primary key") } - splitColType, _ := stream.Schema().GetType(splitColumn) // evenly distirbution only available for float and int types if splitColType == types.Int64 || splitColType == types.Float64 { return splitViaBatchSize(minValue, maxValue, p.config.BatchSize) @@ -193,7 +193,8 @@ func (p *Postgres) splitTableIntoChunks(stream protocol.Stream) ([]types.Chunk, func (p *Postgres) nextChunkEnd(stream protocol.Stream, previousChunkEnd interface{}, splitColumn string) (interface{}, error) { var chunkEnd interface{} - nextChunkEnd := jdbc.PostgresNextChunkEndQuery(stream, splitColumn, previousChunkEnd, p.config.BatchSize) + splitColType, _ := stream.Schema().GetType(splitColumn) + nextChunkEnd := jdbc.PostgresNextChunkEndQuery(stream, splitColumn, previousChunkEnd, p.config.BatchSize, splitColType) err := p.client.QueryRow(nextChunkEnd).Scan(&chunkEnd) if err != nil { return nil, fmt.Errorf("failed to query[%s] next chunk end: %s", nextChunkEnd, err) diff --git a/pkg/jdbc/jdbc.go b/pkg/jdbc/jdbc.go index 959d4460e..376d21d7a 100644 --- a/pkg/jdbc/jdbc.go +++ b/pkg/jdbc/jdbc.go @@ -19,19 +19,47 @@ func NextChunkEndQuery(stream protocol.Stream, column string, chunkSize int) str return fmt.Sprintf(`SELECT MAX(%[1]s) FROM (SELECT %[1]s FROM %[2]s.%[3]s WHERE %[1]s > ? ORDER BY %[1]s LIMIT %[4]d) AS subquery`, column, stream.Namespace(), stream.Name(), chunkSize) } -// buildChunkCondition builds the condition for a chunk -func buildChunkCondition(filterColumn string, chunk types.Chunk) string { - if chunk.Min != nil && chunk.Max != nil { - return fmt.Sprintf("%s >= %v AND %s <= %v", filterColumn, chunk.Min, filterColumn, chunk.Max) - } else if chunk.Min != nil { - return fmt.Sprintf("%s >= %v", filterColumn, chunk.Min) +// buildChunkCondition creates SQL conditions for filtering based on chunk boundaries +// with formatting determined by the provided formatter function +func buildChunkCondition( + filterColumn string, + chunk types.Chunk, + formatter func(column string, operator string, value interface{}) string, +) string { + // If formatter is nil, use default formatting + if formatter == nil { + formatter = func(column string, operator string, value interface{}) string { + return fmt.Sprintf("%s %s %v", column, operator, value) + } + } + + // Only Min condition + if chunk.Min != nil && chunk.Max == nil { + return formatter(filterColumn, ">=", chunk.Min) } - return fmt.Sprintf("%s <= %v", filterColumn, chunk.Max) + + // Only Max condition + if chunk.Min == nil && chunk.Max != nil { + return formatter(filterColumn, "<=", chunk.Max) + } + + // Both Min and Max conditions + return fmt.Sprintf("%s AND %s", + formatter(filterColumn, ">=", chunk.Min), + formatter(filterColumn, "<=", chunk.Max)) } // PostgreSQL-Specific Queries // TODO: Rewrite queries for taking vars as arguments while execution. +// PostgresMinMaxQuery returns the query to fetch MIN and MAX values of a column in a table +func PostgresMinMaxQuery(stream protocol.Stream, column string, columnType types.DataType) string { + if columnType == types.String { + return fmt.Sprintf(`SELECT MIN(%[1]s::text) AS min_value, MAX(%[1]s::text) AS max_value FROM %[2]s.%[3]s`, column, stream.Namespace(), stream.Name()) + } + return fmt.Sprintf(`SELECT MIN(%[1]s) AS min_value, MAX(%[1]s) AS max_value FROM %[2]s.%[3]s`, column, stream.Namespace(), stream.Name()) +} + // PostgresWithoutState returns the query for a simple SELECT without state func PostgresWithoutState(stream protocol.Stream) string { return fmt.Sprintf(`SELECT * FROM "%s"."%s" ORDER BY %s`, stream.Namespace(), stream.Name(), stream.Cursor()) @@ -58,7 +86,10 @@ func PostgresWalLSNQuery() string { } // PostgresNextChunkEndQuery generates a SQL query to fetch the maximum value of a specified column -func PostgresNextChunkEndQuery(stream protocol.Stream, filterColumn string, filterValue interface{}, batchSize int) string { +func PostgresNextChunkEndQuery(stream protocol.Stream, filterColumn string, filterValue interface{}, batchSize int, filterColumnType types.DataType) string { + if filterColumnType == types.String { + return fmt.Sprintf(`SELECT MAX(%s::text) FROM (SELECT %s FROM "%s"."%s" WHERE %s::text > $$%v$$ ORDER BY %s ASC LIMIT %d) AS T`, filterColumn, filterColumn, stream.Namespace(), stream.Name(), filterColumn, filterValue, filterColumn, batchSize) + } return fmt.Sprintf(`SELECT MAX(%s) FROM (SELECT %s FROM "%s"."%s" WHERE %s > %v ORDER BY %s ASC LIMIT %d) AS T`, filterColumn, filterColumn, stream.Namespace(), stream.Name(), filterColumn, filterValue, filterColumn, batchSize) } @@ -68,8 +99,14 @@ func PostgresMinQuery(stream protocol.Stream, filterColumn string, filterValue i } // PostgresBuildSplitScanQuery builds a chunk scan query for PostgreSQL -func PostgresChunkScanQuery(stream protocol.Stream, filterColumn string, chunk types.Chunk) string { - condition := buildChunkCondition(filterColumn, chunk) +func PostgresChunkScanQuery(stream protocol.Stream, filterColumn string, chunk types.Chunk, filterColumnType types.DataType) string { + postgresFormatter := func(column string, operator string, value interface{}) string { + if filterColumnType == types.String { + return fmt.Sprintf("%s::text %s $$%v$$", column, operator, value) + } + return fmt.Sprintf("%s %s %v", column, operator, value) + } + condition := buildChunkCondition(filterColumn, chunk, postgresFormatter) return fmt.Sprintf(`SELECT * FROM "%s"."%s" WHERE %s`, stream.Namespace(), stream.Name(), condition) } @@ -77,7 +114,7 @@ func PostgresChunkScanQuery(stream protocol.Stream, filterColumn string, chunk t // MySQLWithoutState builds a chunk scan query for MySql func MysqlChunkScanQuery(stream protocol.Stream, filterColumn string, chunk types.Chunk) string { - condition := buildChunkCondition(filterColumn, chunk) + condition := buildChunkCondition(filterColumn, chunk, nil) return fmt.Sprintf("SELECT * FROM `%s`.`%s` WHERE %s", stream.Namespace(), stream.Name(), condition) } @@ -149,6 +186,7 @@ func MySQLTableColumnsQuery() string { ORDER BY ORDINAL_POSITION ` } + func WithIsolation(ctx context.Context, client *sql.DB, fn func(tx *sql.Tx) error) error { tx, err := client.BeginTx(ctx, &sql.TxOptions{ Isolation: sql.LevelRepeatableRead,