diff --git a/drivers/postgres/internal/cdc.go b/drivers/postgres/internal/cdc.go index 961adee73..bf35c1575 100644 --- a/drivers/postgres/internal/cdc.go +++ b/drivers/postgres/internal/cdc.go @@ -26,6 +26,7 @@ func (p *Postgres) prepareWALJSConfig(streams ...types.StreamInterface) (*waljs. InitialWaitTime: time.Duration(p.cdcConfig.InitialWaitTime) * time.Second, Tables: types.NewSet(streams...), Publication: p.cdcConfig.Publication, + PluginArgs: p.cdcConfig.PluginArgs, }, nil } diff --git a/drivers/postgres/internal/config.go b/drivers/postgres/internal/config.go index eabe40beb..5c47ab200 100644 --- a/drivers/postgres/internal/config.go +++ b/drivers/postgres/internal/config.go @@ -31,6 +31,9 @@ type CDC struct { InitialWaitTime int `json:"initial_wait_time"` // Publications used when OutputPlugin is pgoutput Publication string `json:"publication"` + // PluginArgs allows custom replication plugin arguments + // Format: key-value pairs e.g., {"include-unchanged-toast": "false", "format-version": "2"} + PluginArgs map[string]string `json:"plugin_args"` } func (c *Config) Validate() error { diff --git a/drivers/postgres/resources/spec.json b/drivers/postgres/resources/spec.json index e13f774e7..3d2da8233 100644 --- a/drivers/postgres/resources/spec.json +++ b/drivers/postgres/resources/spec.json @@ -105,6 +105,14 @@ "type": "string", "title": "Publication", "description": "Publication defines which tables need to be consumed" + }, + "plugin_args": { + "type": "object", + "title": "Plugin Arguments", + "description": "Custom replication plugin arguments as key-value pairs (optional). Example: {\"include-unchanged-toast\": \"false\", \"format-version\": \"2\"}", + "additionalProperties": { + "type": "string" + } } }, "required": [ diff --git a/pkg/waljs/decoder.go b/pkg/waljs/decoder.go new file mode 100644 index 000000000..10a28932c --- /dev/null +++ b/pkg/waljs/decoder.go @@ -0,0 +1,303 @@ +package waljs + +import ( + "encoding/json" + "fmt" + + "github.com/datazip-inc/olake/utils/typeutils" + "github.com/jackc/pgtype" +) + +// PgtypeDecoder uses pgtype for structured decoding of PostgreSQL binary data +type PgtypeDecoder struct { + connInfo *pgtype.ConnInfo +} + +// NewPgtypeDecoder creates a decoder with registered pgtype handlers +func NewPgtypeDecoder() *PgtypeDecoder { + connInfo := pgtype.NewConnInfo() + return &PgtypeDecoder{ + connInfo: connInfo, + } +} + +// DecodeBinary decodes PostgreSQL binary data based on OID into a Go value +// This eliminates the need to convert binary -> string -> type +func (d *PgtypeDecoder) DecodeBinary(data []byte, oid uint32) (interface{}, error) { + if data == nil { + return nil, typeutils.ErrNullValue + } + + // Handle common types with direct decoding + switch oid { + case pgtype.JSONOID: + return d.decodeJSON(data) + case pgtype.JSONBOID: + return d.decodeJSONB(data) + case pgtype.UUIDOID: + return d.decodeUUID(data) + case pgtype.Int8OID: + return d.decodeInt8(data) + case pgtype.Int4OID: + return d.decodeInt4(data) + case pgtype.Int2OID: + return d.decodeInt2(data) + case pgtype.Float8OID: + return d.decodeFloat8(data) + case pgtype.Float4OID: + return d.decodeFloat4(data) + case pgtype.BoolOID: + return d.decodeBool(data) + case pgtype.TimestampOID: + return d.decodeTimestamp(data) + case pgtype.TimestamptzOID: + return d.decodeTimestamptz(data) + case pgtype.DateOID: + return d.decodeDate(data) + case pgtype.ByteaOID: + return d.decodeBytea(data) + case pgtype.NumericOID: + return d.decodeNumeric(data) + case pgtype.TextOID, pgtype.VarcharOID, pgtype.BPCharOID: + return d.decodeText(data) + } + + // For unknown types or arrays, try generic decode or fall back to string + dt, ok := d.connInfo.DataTypeForOID(oid) + if !ok { + // Try as text for unknown types + return string(data), nil + } + + value := dt.Value + if decoder, ok := value.(pgtype.BinaryDecoder); ok { + if err := decoder.DecodeBinary(d.connInfo, data); err == nil { + return d.extractGoValue(value, oid) + } + } + + // Fallback to text decode + if decoder, ok := value.(pgtype.TextDecoder); ok { + if err := decoder.DecodeText(d.connInfo, data); err == nil { + return d.extractGoValue(value, oid) + } + } + + // Final fallback to string + return string(data), nil +} + +// Type-specific decoders +func (d *PgtypeDecoder) decodeJSON(data []byte) (interface{}, error) { + var v pgtype.JSON + if err := v.DecodeBinary(d.connInfo, data); err != nil { + // Try text format + if err := v.DecodeText(d.connInfo, data); err != nil { + return string(data), nil + } + } + if v.Status != pgtype.Present { + return nil, typeutils.ErrNullValue + } + return d.parseJSON(v.Bytes) +} + +func (d *PgtypeDecoder) decodeJSONB(data []byte) (interface{}, error) { + var v pgtype.JSONB + if err := v.DecodeBinary(d.connInfo, data); err != nil { + // Try text format + if err := v.DecodeText(d.connInfo, data); err != nil { + return string(data), nil + } + } + if v.Status != pgtype.Present { + return nil, typeutils.ErrNullValue + } + return d.parseJSON(v.Bytes) +} + +func (d *PgtypeDecoder) parseJSON(data []byte) (interface{}, error) { + // Try to parse as map first + var mapResult map[string]interface{} + if err := json.Unmarshal(data, &mapResult); err == nil { + return mapResult, nil + } + + // Try to parse as array + var arrayResult []interface{} + if err := json.Unmarshal(data, &arrayResult); err == nil { + return arrayResult, nil + } + + // Fallback to string + return string(data), nil +} + +func (d *PgtypeDecoder) decodeUUID(data []byte) (interface{}, error) { + var v pgtype.UUID + if err := v.DecodeBinary(d.connInfo, data); err != nil { + return string(data), nil + } + if v.Status != pgtype.Present { + return nil, typeutils.ErrNullValue + } + return fmt.Sprintf("%x-%x-%x-%x-%x", v.Bytes[0:4], v.Bytes[4:6], v.Bytes[6:8], v.Bytes[8:10], v.Bytes[10:16]), nil +} + +func (d *PgtypeDecoder) decodeInt8(data []byte) (interface{}, error) { + var v pgtype.Int8 + if err := v.DecodeBinary(d.connInfo, data); err != nil { + return nil, err + } + if v.Status != pgtype.Present { + return nil, typeutils.ErrNullValue + } + return v.Int, nil +} + +func (d *PgtypeDecoder) decodeInt4(data []byte) (interface{}, error) { + var v pgtype.Int4 + if err := v.DecodeBinary(d.connInfo, data); err != nil { + return nil, err + } + if v.Status != pgtype.Present { + return nil, typeutils.ErrNullValue + } + return v.Int, nil +} + +func (d *PgtypeDecoder) decodeInt2(data []byte) (interface{}, error) { + var v pgtype.Int2 + if err := v.DecodeBinary(d.connInfo, data); err != nil { + return nil, err + } + if v.Status != pgtype.Present { + return nil, typeutils.ErrNullValue + } + return v.Int, nil +} + +func (d *PgtypeDecoder) decodeFloat8(data []byte) (interface{}, error) { + var v pgtype.Float8 + if err := v.DecodeBinary(d.connInfo, data); err != nil { + return nil, err + } + if v.Status != pgtype.Present { + return nil, typeutils.ErrNullValue + } + return v.Float, nil +} + +func (d *PgtypeDecoder) decodeFloat4(data []byte) (interface{}, error) { + var v pgtype.Float4 + if err := v.DecodeBinary(d.connInfo, data); err != nil { + return nil, err + } + if v.Status != pgtype.Present { + return nil, typeutils.ErrNullValue + } + return v.Float, nil +} + +func (d *PgtypeDecoder) decodeBool(data []byte) (interface{}, error) { + var v pgtype.Bool + if err := v.DecodeBinary(d.connInfo, data); err != nil { + return nil, err + } + if v.Status != pgtype.Present { + return nil, typeutils.ErrNullValue + } + return v.Bool, nil +} + +func (d *PgtypeDecoder) decodeTimestamp(data []byte) (interface{}, error) { + var v pgtype.Timestamp + if err := v.DecodeBinary(d.connInfo, data); err != nil { + return nil, err + } + if v.Status != pgtype.Present { + return nil, typeutils.ErrNullValue + } + return v.Time, nil +} + +func (d *PgtypeDecoder) decodeTimestamptz(data []byte) (interface{}, error) { + var v pgtype.Timestamptz + if err := v.DecodeBinary(d.connInfo, data); err != nil { + return nil, err + } + if v.Status != pgtype.Present { + return nil, typeutils.ErrNullValue + } + return v.Time, nil +} + +func (d *PgtypeDecoder) decodeDate(data []byte) (interface{}, error) { + var v pgtype.Date + if err := v.DecodeBinary(d.connInfo, data); err != nil { + return nil, err + } + if v.Status != pgtype.Present { + return nil, typeutils.ErrNullValue + } + return v.Time, nil +} + +func (d *PgtypeDecoder) decodeBytea(data []byte) (interface{}, error) { + var v pgtype.Bytea + if err := v.DecodeBinary(d.connInfo, data); err != nil { + return nil, err + } + if v.Status != pgtype.Present { + return nil, typeutils.ErrNullValue + } + return v.Bytes, nil +} + +func (d *PgtypeDecoder) decodeNumeric(data []byte) (interface{}, error) { + var v pgtype.Numeric + if err := v.DecodeBinary(d.connInfo, data); err != nil { + return nil, err + } + if v.Status != pgtype.Present { + return nil, typeutils.ErrNullValue + } + // Convert numeric to float64 + var f float64 + if err := v.AssignTo(&f); err != nil { + return string(data), nil + } + return f, nil +} + +func (d *PgtypeDecoder) decodeText(data []byte) (interface{}, error) { + var v pgtype.Text + if err := v.DecodeBinary(d.connInfo, data); err != nil { + // Try as raw string + return string(data), nil + } + if v.Status != pgtype.Present { + return nil, typeutils.ErrNullValue + } + return v.String, nil +} + +// extractGoValue converts pgtype values to appropriate Go types (for generic cases) +func (d *PgtypeDecoder) extractGoValue(value pgtype.Value, oid uint32) (interface{}, error) { + // For array types and other complex types, use Get() method + if getter, ok := value.(interface{ Get() interface{} }); ok { + result := getter.Get() + if result == nil { + return nil, typeutils.ErrNullValue + } + return result, nil + } + + // Fallback: convert to AssignTo string + var s string + if err := value.AssignTo(&s); err == nil { + return s, nil + } + + return fmt.Sprintf("%v", value), nil +} diff --git a/pkg/waljs/pgoutput.go b/pkg/waljs/pgoutput.go index cbf7faf3b..46d5efbf6 100644 --- a/pkg/waljs/pgoutput.go +++ b/pkg/waljs/pgoutput.go @@ -37,8 +37,15 @@ func (p *pgoutputReplicator) StreamChanges(ctx context.Context, db *sqlx.DB, ins } p.socket.CurrentWalPosition = slot.CurrentLSN + // Build plugin arguments with defaults and custom args + pluginArgs := []string{"proto_version '1'", fmt.Sprintf("publication_names '%s'", p.publication)} + // Merge custom plugin arguments + for key, value := range p.socket.pluginArgs { + pluginArgs = append(pluginArgs, fmt.Sprintf("%s '%s'", key, value)) + } + err := pglogrepl.StartReplication(ctx, p.socket.pgConn, p.socket.ReplicationSlot, p.socket.ConfirmedFlushLSN, pglogrepl.StartReplicationOptions{ - PluginArgs: []string{"proto_version '1'", fmt.Sprintf("publication_names '%s'", p.publication)}}) + PluginArgs: pluginArgs}) if err != nil { return fmt.Errorf("failed to start replication: %v", err) } @@ -159,9 +166,9 @@ func (p *pgoutputReplicator) tupleValuesToMap(rel *pglogrepl.RelationMessage, tu continue } - // Convert according to OID to string - typeName := oidToString(colType) - val, err := p.socket.changeFilter.converter(string(col.Data), typeName) + // Use pgtype decoder for structured binary decoding + // This eliminates binary -> string -> type conversion + val, err := p.socket.decoder.DecodeBinary(col.Data, colType) if err != nil && err != typeutils.ErrNullValue { return nil, err } @@ -227,51 +234,70 @@ func (p *pgoutputReplicator) emitDelete(ctx context.Context, m *pglogrepl.Delete return insertFn(ctx, abstract.CDCChange{Stream: stream, Timestamp: p.txnCommitTime, Kind: "delete", Data: values}) } -// OIDToString converts a PostgreSQL OID to its string representation -func oidToString(oid uint32) string { - if typeName, ok := oidToTypeName[oid]; ok { - return typeName - } - logger.Warnf("unknown oid[%d] falling back to string", oid) - // default to json, which will be converted to string - return "json" -} - // OidToTypeName maps PostgreSQL OIDs to their corresponding type names +// Using pgtype constants for automatic OID resolution reduces manual maintenance var oidToTypeName = map[uint32]string{ - pgtype.BoolOID: "bool", - pgtype.ByteaOID: "bytea", - pgtype.Int8OID: "int8", - pgtype.Int2OID: "int2", - pgtype.Int4OID: "int4", - pgtype.TextOID: "text", - pgtype.UUIDOID: "uuid", - pgtype.JSONOID: "json", - pgtype.Float4OID: "float4", - pgtype.Float8OID: "float8", + // Basic types + pgtype.BoolOID: "bool", + pgtype.ByteaOID: "bytea", + pgtype.Int8OID: "int8", + pgtype.Int2OID: "int2", + pgtype.Int4OID: "int4", + pgtype.TextOID: "text", + pgtype.UUIDOID: "uuid", + pgtype.JSONOID: "json", + pgtype.JSONBOID: "jsonb", + pgtype.Float4OID: "float4", + pgtype.Float8OID: "float8", + pgtype.NameOID: "name", + + // Character types + pgtype.BPCharOID: "bpchar", + pgtype.VarcharOID: "varchar", + + // Date/Time types + pgtype.DateOID: "date", + pgtype.TimeOID: "time", + pgtype.TimestampOID: "timestamp", + pgtype.TimestamptzOID: "timestamptz", + pgtype.IntervalOID: "interval", + + // Network types + pgtype.InetOID: "inet", + pgtype.CIDROID: "cidr", + pgtype.MacaddrOID: "macaddr", + + // Geometric types + pgtype.PointOID: "point", + pgtype.LineOID: "line", + pgtype.LsegOID: "lseg", + pgtype.BoxOID: "box", + pgtype.PathOID: "path", + pgtype.PolygonOID: "polygon", + pgtype.CircleOID: "circle", + + // Other types + pgtype.BitOID: "bit", + pgtype.VarbitOID: "varbit", + pgtype.NumericOID: "numeric", + pgtype.OIDOID: "oid", + + // Array types pgtype.BoolArrayOID: "bool[]", pgtype.Int2ArrayOID: "int2[]", pgtype.Int4ArrayOID: "int4[]", + pgtype.Int8ArrayOID: "int8[]", pgtype.TextArrayOID: "text[]", pgtype.ByteaArrayOID: "bytea[]", - pgtype.Int8ArrayOID: "int8[]", pgtype.Float4ArrayOID: "float4[]", pgtype.Float8ArrayOID: "float8[]", - pgtype.BPCharOID: "bpchar", - pgtype.VarcharOID: "varchar", - pgtype.DateOID: "date", - pgtype.TimeOID: "time", - pgtype.TimestampOID: "timestamp", pgtype.TimestampArrayOID: "timestamp[]", - pgtype.DateArrayOID: "date[]", - pgtype.TimestamptzOID: "timestamptz", pgtype.TimestamptzArrayOID: "timestamptz[]", - pgtype.IntervalOID: "interval", + pgtype.DateArrayOID: "date[]", pgtype.NumericArrayOID: "numeric[]", - pgtype.BitOID: "bit", - pgtype.VarbitOID: "varbit", - pgtype.NumericOID: "numeric", pgtype.UUIDArrayOID: "uuid[]", - pgtype.JSONBOID: "jsonb", pgtype.JSONBArrayOID: "jsonb[]", + pgtype.VarcharArrayOID: "varchar[]", + pgtype.InetArrayOID: "inet[]", + pgtype.CIDRArrayOID: "cidr[]", } diff --git a/pkg/waljs/replicator.go b/pkg/waljs/replicator.go index f026cc622..7eb876596 100644 --- a/pkg/waljs/replicator.go +++ b/pkg/waljs/replicator.go @@ -37,6 +37,10 @@ type Socket struct { ReplicationSlot string // initialWaitTime is the duration to wait for first wal log catchup before timing out initialWaitTime time.Duration + // pluginArgs holds custom replication plugin arguments + pluginArgs map[string]string + // decoder handles structured decoding of PostgreSQL binary data using pgtype + decoder *PgtypeDecoder } // Replicator defines an abstraction over different logical decoding plugins. @@ -111,6 +115,8 @@ func NewReplicator(ctx context.Context, db *sqlx.DB, config *Config, typeConvert CurrentWalPosition: slot.CurrentLSN, ReplicationSlot: config.ReplicationSlotName, initialWaitTime: config.InitialWaitTime, + pluginArgs: config.PluginArgs, + decoder: NewPgtypeDecoder(), } plugin := strings.ToLower(strings.TrimSpace(slot.Plugin)) diff --git a/pkg/waljs/types.go b/pkg/waljs/types.go index 8dca4bd1d..03ae0cb33 100644 --- a/pkg/waljs/types.go +++ b/pkg/waljs/types.go @@ -21,6 +21,8 @@ type Config struct { BatchSize int // Publications is used with pgoutput Publication string + // PluginArgs allows passing custom replication plugin arguments + PluginArgs map[string]string } type WALState struct { diff --git a/pkg/waljs/waljs.go b/pkg/waljs/waljs.go index 9b65f33c8..c7be57dca 100644 --- a/pkg/waljs/waljs.go +++ b/pkg/waljs/waljs.go @@ -16,12 +16,6 @@ import ( const AdvanceLSNTemplate = "SELECT * FROM pg_replication_slot_advance('%s', '%s')" -var pluginArguments = []string{ - "\"include-lsn\" 'on'", - "\"pretty-print\" 'off'", - "\"include-timestamp\" 'on'", -} - // wal2jsonReplicator implements Replicator for wal2json plugin type wal2jsonReplicator struct { socket *Socket @@ -41,12 +35,25 @@ func (w *wal2jsonReplicator) StreamChanges(ctx context.Context, db *sqlx.DB, cal // update current wal lsn w.socket.CurrentWalPosition = slot.CurrentLSN - // Start logical replication with wal2json plugin arguments. + // Build wal2json plugin arguments var tables []string for key := range w.socket.changeFilter.tables { tables = append(tables, key) } - pluginArguments = append(pluginArguments, fmt.Sprintf("\"add-tables\" '%s'", strings.Join(tables, ","))) + + // Default wal2json arguments + pluginArguments := []string{ + "\"include-lsn\" 'on'", + "\"pretty-print\" 'off'", + "\"include-timestamp\" 'on'", + fmt.Sprintf("\"add-tables\" '%s'", strings.Join(tables, ",")), + } + + // Merge custom plugin arguments + for key, value := range w.socket.pluginArgs { + pluginArguments = append(pluginArguments, fmt.Sprintf("\"%s\" '%s'", key, value)) + } + if err := pglogrepl.StartReplication( ctx, w.socket.pgConn, diff --git a/types/data_types.go b/types/data_types.go index 70ef5046b..380034283 100644 --- a/types/data_types.go +++ b/types/data_types.go @@ -107,9 +107,12 @@ func (d DataType) ToNewParquet() parquet.Node { n = parquet.Leaf(parquet.BooleanType) case Timestamp, TimestampMilli, TimestampMicro, TimestampNano: n = parquet.Timestamp(parquet.Microsecond) - case Object, Array: - // Ensure proper handling of nested structures - n = parquet.String() + case Object: + // JSON/JSONB objects stored as Parquet JSON type (preserves structure) + n = parquet.JSON() + case Array: + // Arrays stored as Parquet JSON type (preserves structure) + n = parquet.JSON() default: n = parquet.Leaf(parquet.ByteArrayType) } diff --git a/utils/typeutils/flatten.go b/utils/typeutils/flatten.go index a76d38761..9f564691f 100644 --- a/utils/typeutils/flatten.go +++ b/utils/typeutils/flatten.go @@ -5,8 +5,6 @@ import ( "reflect" "time" - "github.com/goccy/go-json" - "github.com/datazip-inc/olake/types" "github.com/datazip-inc/olake/utils" ) @@ -43,18 +41,10 @@ func (f *FlattenerImpl) flatten(key string, value any, destination types.Record) key = utils.Reformat(key) t := reflect.ValueOf(value) switch t.Kind() { - case reflect.Slice: // Stringify arrays - b, err := json.Marshal(value) - if err != nil { - return fmt.Errorf("error marshaling array with key %s: %v", key, err) - } - destination[key] = string(b) - case reflect.Map: // Stringify nested maps - b, err := json.Marshal(value) - if err != nil { - return fmt.Errorf("error marshaling array with key[%s] and value %v: %v", key, value, err) - } - destination[key] = string(b) + case reflect.Slice: // Preserve arrays as structured types + destination[key] = value + case reflect.Map: // Preserve maps as structured types (JSON/JSONB) + destination[key] = value case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String: diff --git a/utils/typeutils/reformat.go b/utils/typeutils/reformat.go index 0bfc2da5c..e3cc00a5b 100644 --- a/utils/typeutils/reformat.go +++ b/utils/typeutils/reformat.go @@ -2,6 +2,7 @@ package typeutils import ( "database/sql" + "encoding/json" "fmt" "strconv" "strings" @@ -105,13 +106,47 @@ func ReformatValue(dataType types.DataType, v any) (any, error) { return ReformatFloat32(v) case types.Float64: return ReformatFloat64(v) + case types.Object: + // Marshal maps to JSON for Parquet JSON logical type + // Parquet JSON columns expect JSON strings, not Go maps + switch val := v.(type) { + case map[string]interface{}: + jsonBytes, err := json.Marshal(val) + if err != nil { + return fmt.Sprintf("%v", v), fmt.Errorf("failed to marshal object to JSON: %w", err) + } + return string(jsonBytes), nil + case string: + // Already a JSON string + return val, nil + default: + // Fallback: try to marshal whatever it is + jsonBytes, err := json.Marshal(v) + if err != nil { + return fmt.Sprintf("%v", v), nil + } + return string(jsonBytes), nil + } case types.Array: - if value, isArray := v.([]any); isArray { - return value, nil + // Marshal arrays to JSON for Parquet JSON logical type + switch val := v.(type) { + case []interface{}: + jsonBytes, err := json.Marshal(val) + if err != nil { + return fmt.Sprintf("%v", v), fmt.Errorf("failed to marshal array to JSON: %w", err) + } + return string(jsonBytes), nil + case string: + // Already a JSON string + return val, nil + default: + // Fallback: try to marshal as single-element array + jsonBytes, err := json.Marshal([]any{v}) + if err != nil { + return fmt.Sprintf("%v", v), nil + } + return string(jsonBytes), nil } - - // make it an array - return []any{v}, nil default: return v, nil }