diff --git a/server/cast/int32.go b/server/cast/int32.go index 6fecd8198c..5e95521918 100644 --- a/server/cast/int32.go +++ b/server/cast/int32.go @@ -15,6 +15,9 @@ package cast import ( + "strconv" + "strings" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" @@ -34,6 +37,24 @@ func initInt32(builtInCasts map[id.Cast]casts.Cast) { // int32Explicit registers all explicit casts. This comprises only the source types. func int32Explicit(builtInCasts map[id.Cast]casts.Cast) { + framework.MustAddExplicitTypeCast(builtInCasts, framework.TypeCast{ + FromType: pgtypes.Int32, + ToType: pgtypes.Bit, + Function: func(ctx *sql.Context, val any, _, targetType *pgtypes.DoltgresType) (any, error) { + width := 1 + if attTypMod := targetType.GetAttTypMod(); attTypMod != -1 { + width = int(attTypMod) + } + bitStr := strconv.FormatInt(int64(val.(int32)), 2) + if len(bitStr) > width { + return bitStr[len(bitStr)-width:], nil + } else if len(bitStr) < width { + return strings.Repeat("0", width-len(bitStr)) + bitStr, nil + } else { + return bitStr, nil + } + }, + }) framework.MustAddExplicitTypeCast(builtInCasts, framework.TypeCast{ FromType: pgtypes.Int32, ToType: pgtypes.Bool, diff --git a/testing/go/pgcatalog_test.go b/testing/go/pgcatalog_test.go index 1eecaf5fef..c24819736c 100644 --- a/testing/go/pgcatalog_test.go +++ b/testing/go/pgcatalog_test.go @@ -484,7 +484,7 @@ func TestPgCast(t *testing.T) { Assertions: []ScriptTestAssertion{ { Query: `SELECT COUNT(*) FROM "pg_catalog"."pg_cast";`, - Expected: []sql.Row{{117}}, + Expected: []sql.Row{{118}}, }, { // Different cases and quoted, so it fails Query: `SELECT * FROM "PG_catalog"."pg_cast";`, @@ -496,7 +496,7 @@ func TestPgCast(t *testing.T) { }, { // Different cases but non-quoted, so it works Query: "SELECT COUNT(*) FROM PG_catalog.pg_CAST ORDER BY oid;", - Expected: []sql.Row{{117}}, + Expected: []sql.Row{{118}}, }, }, }, @@ -5082,7 +5082,7 @@ ORDER BY 1;`, { // This is to make sure a full range scan works (we don't support a full range scan on the index yet) Query: `SELECT relname from pg_catalog.pg_class order by oid limit 1;`, - Expected: []sql.Row{sql.Row{"pg_publication_namespace"}}, + Expected: []sql.Row{{"pg_publication_namespace"}}, }, { Query: `EXPLAIN SELECT c.oid @@ -6045,9 +6045,9 @@ func TestSystemTablesInPgcatalog(t *testing.T) { {245736992, "dolt_conflicts", 2200, "r"}, {1932298159, "dolt_constraint_violations", 2200, "r"}, {2357712556, "dolt_diff", 2200, "r"}, - sql.Row{101228732, "dolt_diff_commit_hash_key", 2200, "i"}, + {101228732, "dolt_diff_commit_hash_key", 2200, "i"}, {3491847678, "dolt_log", 2200, "r"}, - sql.Row{2292720014, "dolt_log_commit_hash_key", 2200, "i"}, + {2292720014, "dolt_log_commit_hash_key", 2200, "i"}, {604995978, "dolt_merge_status", 2200, "r"}, {887648921, "dolt_remote_branches", 2200, "r"}, {1471391189, "dolt_remote_branches_dolt_branches_name_idx_key", 2200, "i"}, diff --git a/testing/go/types_test.go b/testing/go/types_test.go index 90598f8a7b..48742d7353 100644 --- a/testing/go/types_test.go +++ b/testing/go/types_test.go @@ -95,6 +95,27 @@ var typesTests = []ScriptTest{ {2, pgtype.Bits{Bytes: []uint8{0x2b}, Len: 8, Valid: true}, pgtype.Bits{Bytes: []uint8{0x0}, Len: 3, Valid: true}}, }, }, + { + Query: "SELECT 0::bit, 1::bit, 2::bit, 3::bit, 4::bit, 5::bit(2), 6::bit(2);", + Expected: []sql.Row{{ + pgtype.Bits{Bytes: []uint8{0x00}, Len: 1, Valid: true}, + pgtype.Bits{Bytes: []uint8{0x80}, Len: 1, Valid: true}, + pgtype.Bits{Bytes: []uint8{0x00}, Len: 1, Valid: true}, + pgtype.Bits{Bytes: []uint8{0x80}, Len: 1, Valid: true}, + pgtype.Bits{Bytes: []uint8{0x00}, Len: 1, Valid: true}, + pgtype.Bits{Bytes: []uint8{0x40}, Len: 2, Valid: true}, + pgtype.Bits{Bytes: []uint8{0x80}, Len: 2, Valid: true}, + }}, + }, + { + Query: "SELECT (-1)::bit, (-2)::bit, (-5)::bit(2), (-6::int4)::bit(2);", + Expected: []sql.Row{{ + pgtype.Bits{Bytes: []uint8{0x80}, Len: 1, Valid: true}, + pgtype.Bits{Bytes: []uint8{0x00}, Len: 1, Valid: true}, + pgtype.Bits{Bytes: []uint8{0x40}, Len: 2, Valid: true}, + pgtype.Bits{Bytes: []uint8{0x80}, Len: 2, Valid: true}, + }}, + }, { Query: "INSERT INTO t_bit VALUES (3, B'101', '111');", ExpectedErr: "bit string length 3 does not match type bit(8)",