diff --git a/docs/commands/data-diff.md b/docs/commands/data-diff.md index ad66501ad..b1053ffde 100644 --- a/docs/commands/data-diff.md +++ b/docs/commands/data-diff.md @@ -152,6 +152,7 @@ The `data-diff` command includes specialized type mapping support for the follow - **BigQuery** - Native support for BigQuery types including `INT64`, `FLOAT64`, `BIGNUMERIC`, and BigQuery-specific formatting - **PostgreSQL & AWS Redshift** - Complete support for PostgreSQL types including `SERIAL` types, `MONEY`, network types (`CIDR`, `INET`), and `JSONB` - **Snowflake** - Full support for Snowflake types including `NUMBER`, `VARIANT`, and timezone-aware timestamp types +- **Athena** - Supports AWS Athena (Trino/Presto) type system including `DECIMAL`, `DOUBLE`, `BOOLEAN`, `JSON`, and timestamp types for data-diff comparisons ### Type Mapping Features @@ -230,4 +231,4 @@ Either use the format `connection:table` or add the `--connection` flag. ### Error: "failed to get connection" The specified connection name doesn't exist in your configuration. -Check your `.bruin.yml` or other [secrets backend](../secrets/overview.md) file for available connections. \ No newline at end of file +Check your `.bruin.yml` or other [secrets backend](../secrets/overview.md) file for available connections. diff --git a/pkg/athena/db.go b/pkg/athena/db.go index 07279fa27..d50adc131 100644 --- a/pkg/athena/db.go +++ b/pkg/athena/db.go @@ -4,10 +4,13 @@ import ( "context" "fmt" "sort" + "strconv" "strings" "sync" + "time" "github.com/bruin-data/bruin/pkg/ansisql" + "github.com/bruin-data/bruin/pkg/diff" "github.com/bruin-data/bruin/pkg/query" "github.com/jmoiron/sqlx" "github.com/pkg/errors" @@ -18,12 +21,15 @@ type DB struct { conn *sqlx.DB config *Config mutex sync.Mutex + // typeMapper normalizes Athena column types for data-diff support. 
+ typeMapper *diff.DatabaseTypeMapper } func NewDB(c *Config) *DB { return &DB{ - config: c, - mutex: sync.Mutex{}, + config: c, + mutex: sync.Mutex{}, + typeMapper: diff.NewAthenaTypeMapper(), } } @@ -309,6 +315,623 @@ ORDER BY ordinal_position; return columns, nil } +func (db *DB) GetTableSummary(ctx context.Context, tableName string, schemaOnly bool) (*diff.TableSummaryResult, error) { + if tableName == "" { + return nil, errors.New("table name cannot be empty") + } + + if err := db.initializeDB(); err != nil { + return nil, err + } + + if db.config == nil { + return nil, errors.New("athena config is not initialized") + } + + schemaName, tableNameOnly, err := db.parseTableName(tableName) + if err != nil { + return nil, err + } + + fullTableIdentifier := buildFullyQualifiedTableName(schemaName, tableNameOnly) + + var rowCount int64 + if !schemaOnly { + rowCount, err = db.fetchRowCount(ctx, fullTableIdentifier) + if err != nil { + return nil, err + } + } + + columns, err := db.fetchColumns(ctx, schemaName, tableNameOnly, fullTableIdentifier, schemaOnly) + if err != nil { + return nil, err + } + + table := &diff.Table{ + Name: tableName, + Columns: columns, + } + + return &diff.TableSummaryResult{ + RowCount: rowCount, + Table: table, + }, nil +} + +func (db *DB) fetchRowCount(ctx context.Context, fullTableIdentifier string) (int64, error) { + queryString := fmt.Sprintf(`SELECT COUNT(*) AS row_count FROM %s`, fullTableIdentifier) + rows, err := db.conn.QueryContext(ctx, queryString) + if err != nil { + return 0, fmt.Errorf("failed to execute row count query for '%s': %w", fullTableIdentifier, err) + } + defer rows.Close() + + if !rows.Next() { + return 0, errors.New("row count query returned no rows") + } + + var countValue interface{} + if err := rows.Scan(&countValue); err != nil { + return 0, fmt.Errorf("failed to scan row count for '%s': %w", fullTableIdentifier, err) + } + + count, err := asInt64(countValue) + if err != nil { + return 0, fmt.Errorf("failed to 
parse row count for '%s': %w", fullTableIdentifier, err) + } + + return count, rows.Err() +} + +func (db *DB) fetchColumns(ctx context.Context, schemaName, tableName, fullTableIdentifier string, schemaOnly bool) ([]*diff.Column, error) { + schemaQuery := strings.TrimSpace(fmt.Sprintf(` +SELECT + column_name, + data_type, + is_nullable +FROM information_schema.columns +WHERE table_schema = '%s' AND table_name = '%s' +ORDER BY ordinal_position; +`, schemaName, tableName)) + + result, err := db.Select(ctx, &query.Query{Query: schemaQuery}) + if err != nil { + return nil, fmt.Errorf("failed to fetch column metadata for '%s.%s': %w", schemaName, tableName, err) + } + + columns := make([]*diff.Column, 0, len(result)) + for _, row := range result { + if len(row) < 3 { + continue + } + + columnName := toString(row[0]) + if columnName == "" { + continue + } + + dataType := toString(row[1]) + if dataType == "" { + continue + } + + isNullable := strings.EqualFold(toString(row[2]), "YES") + + normalizedType := db.getTypeMapper().MapType(dataType) + + var stats diff.ColumnStatistics + if !schemaOnly { + stats, err = db.fetchColumnStatistics(ctx, normalizedType, fullTableIdentifier, columnName) + if err != nil { + return nil, fmt.Errorf("failed to fetch statistics for column '%s': %w", columnName, err) + } + } + + columns = append(columns, &diff.Column{ + Name: columnName, + Type: dataType, + NormalizedType: normalizedType, + Nullable: isNullable, + PrimaryKey: false, + Unique: false, + Stats: stats, + }) + } + + return columns, nil +} + +func (db *DB) fetchColumnStatistics(ctx context.Context, normalizedType diff.CommonDataType, fullTableIdentifier, columnName string) (diff.ColumnStatistics, error) { + switch normalizedType { + case diff.CommonTypeNumeric: + return db.fetchNumericalStats(ctx, fullTableIdentifier, columnName) + case diff.CommonTypeString: + return db.fetchStringStats(ctx, fullTableIdentifier, columnName) + case diff.CommonTypeBoolean: + return 
db.fetchBooleanStats(ctx, fullTableIdentifier, columnName) + case diff.CommonTypeDateTime: + return db.fetchDateTimeStats(ctx, fullTableIdentifier, columnName) + case diff.CommonTypeJSON: + return db.fetchJSONStats(ctx, fullTableIdentifier, columnName) + default: + return &diff.UnknownStatistics{}, nil + } +} + +func (db *DB) fetchNumericalStats(ctx context.Context, fullTableIdentifier, columnName string) (*diff.NumericalStatistics, error) { + columnExpr := quoteIdentifier(columnName) + queryString := strings.TrimSpace(fmt.Sprintf(` +SELECT + COUNT(*) AS count, + COUNT(*) - COUNT(%[1]s) AS null_count, + MIN(%[1]s) AS min_val, + MAX(%[1]s) AS max_val, + AVG(CAST(%[1]s AS DOUBLE)) AS avg_val, + SUM(CAST(%[1]s AS DOUBLE)) AS sum_val, + STDDEV_POP(CAST(%[1]s AS DOUBLE)) AS stddev_val +FROM %[2]s +`, columnExpr, fullTableIdentifier)) + + row := db.conn.QueryRowContext(ctx, queryString) + + var ( + countVal, nullCountVal interface{} + minVal, maxVal, avgVal, sumVal, stddevVal interface{} + ) + + if err := row.Scan(&countVal, &nullCountVal, &minVal, &maxVal, &avgVal, &sumVal, &stddevVal); err != nil { + return nil, fmt.Errorf("failed to scan numerical stats for column '%s': %w", columnName, err) + } + + count, err := asInt64(countVal) + if err != nil { + return nil, fmt.Errorf("failed to parse count value for column '%s': %w", columnName, err) + } + + nullCount, err := asInt64(nullCountVal) + if err != nil { + return nil, fmt.Errorf("failed to parse null count for column '%s': %w", columnName, err) + } + + return &diff.NumericalStatistics{ + Count: count, + NullCount: nullCount, + Min: asFloatPointer(minVal), + Max: asFloatPointer(maxVal), + Avg: asFloatPointer(avgVal), + Sum: asFloatPointer(sumVal), + StdDev: asFloatPointer(stddevVal), + }, nil +} + +func (db *DB) fetchStringStats(ctx context.Context, fullTableIdentifier, columnName string) (*diff.StringStatistics, error) { + columnExpr := quoteIdentifier(columnName) + queryString := strings.TrimSpace(fmt.Sprintf(` 
+SELECT + COUNT(*) AS count, + COUNT(*) - COUNT(%[1]s) AS null_count, + APPROX_DISTINCT(%[1]s) AS distinct_count, + COUNT_IF(%[1]s = '') AS empty_count, + MIN(LENGTH(%[1]s)) AS min_length, + MAX(LENGTH(%[1]s)) AS max_length, + AVG(CAST(LENGTH(%[1]s) AS DOUBLE)) AS avg_length +FROM %[2]s +`, columnExpr, fullTableIdentifier)) + + row := db.conn.QueryRowContext(ctx, queryString) + + var ( + countVal, nullCountVal, distinctCountVal, emptyCountVal interface{} + minLengthVal, maxLengthVal interface{} + avgLengthVal interface{} + ) + + if err := row.Scan(&countVal, &nullCountVal, &distinctCountVal, &emptyCountVal, &minLengthVal, &maxLengthVal, &avgLengthVal); err != nil { + return nil, fmt.Errorf("failed to scan string stats for column '%s': %w", columnName, err) + } + + count, err := asInt64(countVal) + if err != nil { + return nil, fmt.Errorf("failed to parse count value for column '%s': %w", columnName, err) + } + + nullCount, err := asInt64(nullCountVal) + if err != nil { + return nil, fmt.Errorf("failed to parse null count for column '%s': %w", columnName, err) + } + + distinctCount, err := asInt64(distinctCountVal) + if err != nil { + return nil, fmt.Errorf("failed to parse distinct count for column '%s': %w", columnName, err) + } + + emptyCount, err := asInt64(emptyCountVal) + if err != nil { + return nil, fmt.Errorf("failed to parse empty count for column '%s': %w", columnName, err) + } + + minLength, err := asInt(minLengthVal) + if err != nil { + return nil, fmt.Errorf("failed to parse min length for column '%s': %w", columnName, err) + } + + maxLength, err := asInt(maxLengthVal) + if err != nil { + return nil, fmt.Errorf("failed to parse max length for column '%s': %w", columnName, err) + } + + avgLength := asFloatPointer(avgLengthVal) + avgLengthValue := 0.0 + if avgLength != nil { + avgLengthValue = *avgLength + } + + return &diff.StringStatistics{ + Count: count, + NullCount: nullCount, + DistinctCount: distinctCount, + EmptyCount: emptyCount, + MinLength: 
minLength, + MaxLength: maxLength, + AvgLength: avgLengthValue, + }, nil +} + +func (db *DB) fetchBooleanStats(ctx context.Context, fullTableIdentifier, columnName string) (*diff.BooleanStatistics, error) { + columnExpr := quoteIdentifier(columnName) + queryString := strings.TrimSpace(fmt.Sprintf(` +SELECT + COUNT(*) AS count, + COUNT(*) - COUNT(%[1]s) AS null_count, + COUNT_IF(%[1]s = TRUE) AS true_count, + COUNT_IF(%[1]s = FALSE) AS false_count +FROM %[2]s +`, columnExpr, fullTableIdentifier)) + + row := db.conn.QueryRowContext(ctx, queryString) + + var ( + countVal, nullCountVal, trueCountVal, falseCountVal interface{} + ) + + if err := row.Scan(&countVal, &nullCountVal, &trueCountVal, &falseCountVal); err != nil { + return nil, fmt.Errorf("failed to scan boolean stats for column '%s': %w", columnName, err) + } + + count, err := asInt64(countVal) + if err != nil { + return nil, fmt.Errorf("failed to parse count value for column '%s': %w", columnName, err) + } + + nullCount, err := asInt64(nullCountVal) + if err != nil { + return nil, fmt.Errorf("failed to parse null count for column '%s': %w", columnName, err) + } + + trueCount, err := asInt64(trueCountVal) + if err != nil { + return nil, fmt.Errorf("failed to parse true count for column '%s': %w", columnName, err) + } + + falseCount, err := asInt64(falseCountVal) + if err != nil { + return nil, fmt.Errorf("failed to parse false count for column '%s': %w", columnName, err) + } + + return &diff.BooleanStatistics{ + Count: count, + NullCount: nullCount, + TrueCount: trueCount, + FalseCount: falseCount, + }, nil +} + +func (db *DB) fetchDateTimeStats(ctx context.Context, fullTableIdentifier, columnName string) (*diff.DateTimeStatistics, error) { + columnExpr := quoteIdentifier(columnName) + queryString := strings.TrimSpace(fmt.Sprintf(` +SELECT + COUNT(*) AS count, + COUNT(*) - COUNT(%[1]s) AS null_count, + APPROX_DISTINCT(%[1]s) AS unique_count, + MIN(%[1]s) AS earliest_date, + MAX(%[1]s) AS latest_date +FROM 
%[2]s +`, columnExpr, fullTableIdentifier)) + + row := db.conn.QueryRowContext(ctx, queryString) + + var ( + countVal, nullCountVal, uniqueCountVal interface{} + earliestDateVal, latestDateVal interface{} + ) + + if err := row.Scan(&countVal, &nullCountVal, &uniqueCountVal, &earliestDateVal, &latestDateVal); err != nil { + return nil, fmt.Errorf("failed to scan datetime stats for column '%s': %w", columnName, err) + } + + count, err := asInt64(countVal) + if err != nil { + return nil, fmt.Errorf("failed to parse count value for column '%s': %w", columnName, err) + } + + nullCount, err := asInt64(nullCountVal) + if err != nil { + return nil, fmt.Errorf("failed to parse null count for column '%s': %w", columnName, err) + } + + uniqueCount, err := asInt64(uniqueCountVal) + if err != nil { + return nil, fmt.Errorf("failed to parse unique count for column '%s': %w", columnName, err) + } + + earliest, err := asTimePointer(earliestDateVal) + if err != nil { + return nil, fmt.Errorf("failed to parse earliest date for column '%s': %w", columnName, err) + } + + latest, err := asTimePointer(latestDateVal) + if err != nil { + return nil, fmt.Errorf("failed to parse latest date for column '%s': %w", columnName, err) + } + + return &diff.DateTimeStatistics{ + Count: count, + NullCount: nullCount, + UniqueCount: uniqueCount, + EarliestDate: earliest, + LatestDate: latest, + }, nil +} + +func (db *DB) fetchJSONStats(ctx context.Context, fullTableIdentifier, columnName string) (*diff.JSONStatistics, error) { + columnExpr := quoteIdentifier(columnName) + queryString := strings.TrimSpace(fmt.Sprintf(` +SELECT + COUNT(*) AS count, + COUNT(*) - COUNT(%[1]s) AS null_count +FROM %[2]s +`, columnExpr, fullTableIdentifier)) + + row := db.conn.QueryRowContext(ctx, queryString) + + var ( + countVal, nullCountVal interface{} + ) + + if err := row.Scan(&countVal, &nullCountVal); err != nil { + return nil, fmt.Errorf("failed to scan JSON stats for column '%s': %w", columnName, err) + } + + 
count, err := asInt64(countVal) + if err != nil { + return nil, fmt.Errorf("failed to parse count value for column '%s': %w", columnName, err) + } + + nullCount, err := asInt64(nullCountVal) + if err != nil { + return nil, fmt.Errorf("failed to parse null count for column '%s': %w", columnName, err) + } + + return &diff.JSONStatistics{ + Count: count, + NullCount: nullCount, + }, nil +} + +func (db *DB) parseTableName(tableName string) (string, string, error) { + parts := strings.Split(tableName, ".") + switch len(parts) { + case 1: + if db.config == nil || db.config.Database == "" { + return "", "", errors.New("athena database (schema) must be specified") + } + return normalizeIdentifier(db.config.Database), normalizeIdentifier(strings.TrimSpace(parts[0])), nil + case 2: + return normalizeIdentifier(strings.TrimSpace(parts[0])), normalizeIdentifier(strings.TrimSpace(parts[1])), nil + default: + return "", "", fmt.Errorf("invalid table name format '%s', expected schema.table or table", tableName) + } +} + +func (db *DB) getTypeMapper() *diff.DatabaseTypeMapper { + if db.typeMapper == nil { + db.typeMapper = diff.NewAthenaTypeMapper() + } + return db.typeMapper +} + +func buildFullyQualifiedTableName(schema, table string) string { + if schema == "" { + return quoteIdentifier(table) + } + return fmt.Sprintf("%s.%s", quoteIdentifier(schema), quoteIdentifier(table)) +} + +func quoteIdentifier(identifier string) string { + escaped := strings.ReplaceAll(identifier, `"`, `""`) + return fmt.Sprintf(`"%s"`, escaped) +} + +func normalizeIdentifier(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return value + } + value = strings.TrimPrefix(value, `"`) + value = strings.TrimSuffix(value, `"`) + return value +} + +func toString(value interface{}) string { + switch v := value.(type) { + case nil: + return "" + case string: + return v + case []byte: + return string(v) + default: + return fmt.Sprint(v) + } +} + +func asInt64(value interface{}) 
// asInt64 converts a value scanned from a query result into int64.
//
// nil, the empty string and an empty byte slice all yield 0. Signed and
// unsigned integers of every width are accepted; unsigned values that do not
// fit in int64 are rejected instead of wrapping to a negative number. Floats
// are truncated toward zero; NaN, infinities and magnitudes outside the int64
// range are rejected. Strings and byte slices are parsed as base-10 integers.
// Any other type is an error.
func asInt64(value interface{}) (int64, error) {
	switch v := value.(type) {
	case nil:
		return 0, nil
	case int64:
		return v, nil
	case int32:
		return int64(v), nil
	case int16:
		return int64(v), nil
	case int8:
		return int64(v), nil
	case int:
		return int64(v), nil
	case uint64:
		return uint64ToInt64(v)
	case uint:
		return uint64ToInt64(uint64(v))
	case uint32:
		return int64(v), nil
	case uint16:
		return int64(v), nil
	case uint8:
		return int64(v), nil
	case float64:
		return float64ToInt64(v)
	case float32:
		return float64ToInt64(float64(v))
	case string:
		return parseInt64Text(v)
	case []byte:
		return parseInt64Text(string(v))
	default:
		return 0, fmt.Errorf("unsupported integer type %T", value)
	}
}

// uint64ToInt64 converts u, failing when it exceeds the largest int64.
func uint64ToInt64(u uint64) (int64, error) {
	if u > 1<<63-1 {
		return 0, fmt.Errorf("value %d overflows int64", u)
	}
	return int64(u), nil
}

// float64ToInt64 truncates f toward zero, failing for NaN, ±Inf and values
// outside the int64 range, where the plain conversion is not well defined.
func float64ToInt64(f float64) (int64, error) {
	const limit = 1 << 63 // 2^63, exactly representable as a float64

	// Written as a negated conjunction so that NaN, which fails every
	// comparison, is rejected too.
	if !(f >= -limit && f < limit) {
		return 0, fmt.Errorf("value %v is out of range for int64", f)
	}
	return int64(f), nil
}

// parseInt64Text parses a base-10 integer, treating the empty string as 0.
func parseInt64Text(s string) (int64, error) {
	if s == "" {
		return 0, nil
	}
	parsed, err := strconv.ParseInt(s, 10, 64)
	if err != nil {
		return 0, err
	}
	return parsed, nil
}

// asInt converts a scanned value to int using the same rules as asInt64, and
// additionally fails when the result does not fit in the platform's int.
func asInt(value interface{}) (int, error) {
	n, err := asInt64(value)
	if err != nil {
		return 0, err
	}
	if int64(int(n)) != n {
		return 0, fmt.Errorf("value %d overflows int", n)
	}
	return int(n), nil
}

// asFloatPointer converts a scanned value to *float64. It returns nil for
// nil, for empty text, and for anything that cannot be parsed as a float, so
// callers cannot tell a NULL aggregate from an unreadable one. Types without
// a dedicated case are formatted with fmt.Sprint and parsed from that text.
func asFloatPointer(value interface{}) *float64 {
	var f float64

	switch v := value.(type) {
	case nil:
		return nil
	case float64:
		f = v
	case float32:
		f = float64(v)
	case int64:
		f = float64(v)
	case int32:
		f = float64(v)
	case int:
		f = float64(v)
	case uint64:
		f = float64(v)
	case uint32:
		f = float64(v)
	case string:
		return parseFloatText(v)
	case []byte:
		return parseFloatText(string(v))
	default:
		return parseFloatText(fmt.Sprint(v))
	}

	return &f
}

// parseFloatText parses s as a float64, returning nil when s is empty or is
// not a valid number.
func parseFloatText(s string) *float64 {
	if s == "" {
		return nil
	}
	parsed, err := strconv.ParseFloat(s, 64)
	if err != nil {
		return nil
	}
	return &parsed
}
case []byte: + if len(v) == 0 { + return nil + } + parsed, err := strconv.ParseFloat(string(v), 64) + if err != nil { + return nil + } + return &parsed + default: + parsed, err := strconv.ParseFloat(fmt.Sprint(v), 64) + if err != nil { + return nil + } + return &parsed + } +} + +func asTimePointer(value interface{}) (*time.Time, error) { + switch v := value.(type) { + case nil: + return nil, nil + case time.Time: + return &v, nil + case *time.Time: + return v, nil + case string: + if v == "" { + return nil, nil + } + return diff.ParseDateTime(v) + case []byte: + if len(v) == 0 { + return nil, nil + } + return diff.ParseDateTime(string(v)) + default: + return diff.ParseDateTime(fmt.Sprint(v)) + } +} + func (db *DB) GetDatabaseSummary(ctx context.Context) (*ansisql.DBDatabase, error) { // Athena uses AWS Glue Data Catalog // We'll query INFORMATION_SCHEMA to get all schemas and tables diff --git a/pkg/athena/db_test.go b/pkg/athena/db_test.go index 8edc85df6..653647dfe 100644 --- a/pkg/athena/db_test.go +++ b/pkg/athena/db_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/DATA-DOG/go-sqlmock" + "github.com/bruin-data/bruin/pkg/diff" "github.com/bruin-data/bruin/pkg/query" "github.com/jmoiron/sqlx" "github.com/stretchr/testify/assert" @@ -264,6 +265,137 @@ func TestDB_SelectWithSchema(t *testing.T) { } } +func TestDB_GetTableSummary_WithStatistics(t *testing.T) { + t.Parallel() + + mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + require.NoError(t, err) + defer mockDB.Close() + + sqlxDB := sqlx.NewDb(mockDB, "sqlmock") + + rowCountQuery := `SELECT COUNT(*) AS row_count FROM "default"."orders"` + mock.ExpectQuery(rowCountQuery). 
+ WillReturnRows(sqlmock.NewRows([]string{"row_count"}).AddRow(int64(42))) + + schemaQuery := `SELECT + column_name, + data_type, + is_nullable +FROM information_schema.columns +WHERE table_schema = 'default' AND table_name = 'orders' +ORDER BY ordinal_position;` + + mock.ExpectQuery(schemaQuery).WillReturnRows( + sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable"}). + AddRow("id", "integer", "YES"), + ) + + numericStatsQuery := `SELECT + COUNT(*) AS count, + COUNT(*) - COUNT("id") AS null_count, + MIN("id") AS min_val, + MAX("id") AS max_val, + AVG(CAST("id" AS DOUBLE)) AS avg_val, + SUM(CAST("id" AS DOUBLE)) AS sum_val, + STDDEV_POP(CAST("id" AS DOUBLE)) AS stddev_val +FROM "default"."orders"` + + mock.ExpectQuery(numericStatsQuery).WillReturnRows( + sqlmock.NewRows([]string{ + "count", + "null_count", + "min_val", + "max_val", + "avg_val", + "sum_val", + "stddev_val", + }).AddRow( + int64(100), + int64(0), + int64(1), + int64(100), + float64(50.5), + float64(5050.0), + float64(10.0), + ), + ) + + db := NewDB(&Config{Database: "default"}) + db.conn = sqlxDB + + summary, err := db.GetTableSummary(context.Background(), "orders", false) + require.NoError(t, err) + + require.Equal(t, int64(42), summary.RowCount) + require.NotNil(t, summary.Table) + require.Equal(t, "orders", summary.Table.Name) + require.Len(t, summary.Table.Columns, 1) + + column := summary.Table.Columns[0] + require.Equal(t, "id", column.Name) + require.Equal(t, diff.CommonTypeNumeric, column.NormalizedType) + require.True(t, column.Nullable) + require.NotNil(t, column.Stats) + + numericStats, ok := column.Stats.(*diff.NumericalStatistics) + require.True(t, ok) + require.Equal(t, int64(100), numericStats.Count) + require.Equal(t, int64(0), numericStats.NullCount) + require.NotNil(t, numericStats.Min) + require.Equal(t, 1.0, *numericStats.Min) + require.NotNil(t, numericStats.Max) + require.Equal(t, 100.0, *numericStats.Max) + require.NotNil(t, numericStats.Avg) + require.Equal(t, 
50.5, *numericStats.Avg) + require.NotNil(t, numericStats.Sum) + require.Equal(t, 5050.0, *numericStats.Sum) + require.NotNil(t, numericStats.StdDev) + require.Equal(t, 10.0, *numericStats.StdDev) + + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestDB_GetTableSummary_SchemaOnly(t *testing.T) { + t.Parallel() + + mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + require.NoError(t, err) + defer mockDB.Close() + + sqlxDB := sqlx.NewDb(mockDB, "sqlmock") + + schemaQuery := `SELECT + column_name, + data_type, + is_nullable +FROM information_schema.columns +WHERE table_schema = 'default' AND table_name = 'orders' +ORDER BY ordinal_position;` + + mock.ExpectQuery(schemaQuery).WillReturnRows( + sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable"}). + AddRow("customer_name", "varchar", "NO"), + ) + + db := NewDB(&Config{Database: "default"}) + db.conn = sqlxDB + + summary, err := db.GetTableSummary(context.Background(), "orders", true) + require.NoError(t, err) + + require.Equal(t, int64(0), summary.RowCount) + require.Len(t, summary.Table.Columns, 1) + + column := summary.Table.Columns[0] + require.Equal(t, "customer_name", column.Name) + require.Equal(t, diff.CommonTypeString, column.NormalizedType) + require.False(t, column.Nullable) + require.Nil(t, column.Stats) + + require.NoError(t, mock.ExpectationsWereMet()) +} + func TestDB_BuildTableExistsQuery(t *testing.T) { t.Parallel() tests := []struct { diff --git a/pkg/diff/types.go b/pkg/diff/types.go index 5bbbfbc5e..163aec84e 100644 --- a/pkg/diff/types.go +++ b/pkg/diff/types.go @@ -352,6 +352,41 @@ func NewSnowflakeTypeMapper() *DatabaseTypeMapper { return mapper } +// NewAthenaTypeMapper provides Athena-specific type mapping. 
+func NewAthenaTypeMapper() *DatabaseTypeMapper { + mapper := NewDatabaseTypeMapper() + + // Numeric types supported by Athena (Presto/Trino compatible) + mapper.AddNumericTypes( + "tinyint", "smallint", "integer", "int", "bigint", + "decimal", "numeric", + "double", "double precision", + "float", "real", + ) + + // String and binary-like types + mapper.AddStringTypes( + "varchar", "char", "character", "string", + ) + + // Boolean support + mapper.AddBooleanTypes( + "boolean", "bool", + ) + + // Temporal types + mapper.AddDateTimeTypes( + "date", "time", "time with time zone", + "timestamp", "timestamp with time zone", + ) + + // Binary/JSON types + mapper.AddBinaryTypes("varbinary") + mapper.AddJSONTypes("json") + + return mapper +} + type Table struct { Name string Columns []*Column