From 4372ed82d877492667932dec824bf62171414b2a Mon Sep 17 00:00:00 2001 From: Jason Fulghum Date: Mon, 15 Jun 2026 13:55:37 -0700 Subject: [PATCH] Override coalesce with a Doltgres-specific implementation that works with DoltgresTypes --- server/analyzer/type_sanitizer.go | 26 +++++ server/expression/coalesce.go | 158 ++++++++++++++++++++++++++++++ testing/go/expressions_test.go | 63 ++++++++++++ 3 files changed, 247 insertions(+) create mode 100644 server/expression/coalesce.go diff --git a/server/analyzer/type_sanitizer.go b/server/analyzer/type_sanitizer.go index 245aeed2c5..c9dd97369e 100644 --- a/server/analyzer/type_sanitizer.go +++ b/server/analyzer/type_sanitizer.go @@ -78,6 +78,32 @@ func TypeSanitizer(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope // Some aggregation functions cannot be wrapped due to expectations in the analyzer, so we exclude them here. switch expr.FunctionName() { case "Count", "CountDistinct", "group_concat", "JSONObjectAgg", "Sum": + case "coalesce": + // Replace GMS Coalesce with a Doltgres-native implementation that uses + // Postgres type-resolution rules (FindCommonType) to infer the result type. + // GMS's Coalesce.Type() falls back to LongText when its arguments are + // DoltgresTypes because they don't satisfy GMS's IsNumber/IsText checks. + if _, isPgCoalesce := expr.(*pgexprs.PgCoalesce); !isPgCoalesce { + children := expr.Children() + allDoltgresTypes := true + for _, child := range children { + if _, ok := child.Type(ctx).(*pgtypes.DoltgresType); !ok { + allDoltgresTypes = false + break + } + } + if allDoltgresTypes { + pgCoalesce, err := pgexprs.NewPgCoalesce(ctx, children...) + if err != nil { + return nil, transform.NewTree, err + } + return pgCoalesce, transform.NewTree, nil + } + } + // Fall through to GMSCast if children aren't DoltgresTypes yet. + if _, ok := expr.Type(ctx).(*pgtypes.DoltgresType); !ok { + return pgexprs.NewGMSCast(expr), transform.NewTree, nil + } default: // Some GMS functions wrap Doltgres parameters, so we'll only handle those that return GMS types if _, ok := expr.Type(ctx).(*pgtypes.DoltgresType); !ok { diff --git a/server/expression/coalesce.go b/server/expression/coalesce.go new file mode 100644 index 0000000000..48a835abeb --- /dev/null +++ b/server/expression/coalesce.go @@ -0,0 +1,158 @@ +// Copyright 2026 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "fmt" + "strings" + + "github.com/dolthub/go-mysql-server/sql" + + "github.com/dolthub/doltgresql/server/functions/framework" + pgtypes "github.com/dolthub/doltgresql/server/types" +) + +// PgCoalesce is a Doltgres-native COALESCE implementation. It uses Postgres type-resolution rules +// (FindCommonType) to compute the correct result type. +type PgCoalesce struct { + args []sql.Expression + typ *pgtypes.DoltgresType +} + +var _ sql.Expression = (*PgCoalesce)(nil) +var _ sql.FunctionExpression = (*PgCoalesce)(nil) +var _ sql.CollationCoercible = (*PgCoalesce)(nil) + +// NewPgCoalesce creates a new PgCoalesce expression. +func NewPgCoalesce(ctx *sql.Context, args ...sql.Expression) (*PgCoalesce, error) { + if len(args) == 0 { + return nil, sql.ErrInvalidArgumentNumber.New("COALESCE", "1 or more", 0) + } + expr, err := (&PgCoalesce{typ: pgtypes.Unknown}).WithChildren(ctx, args...) + if err != nil { + return nil, err + } + return expr.(*PgCoalesce), nil +} + +// FunctionName implements sql.FunctionExpression. +func (c *PgCoalesce) FunctionName() string { return "coalesce" } + +// Description implements sql.FunctionExpression. +func (c *PgCoalesce) Description() string { return "returns the first non-null value in a list." } + +// Type implements sql.Expression. +func (c *PgCoalesce) Type(_ *sql.Context) sql.Type { + return c.typ +} + +// CollationCoercibility implements sql.CollationCoercible. +func (c *PgCoalesce) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { + if cc, ok := c.Type(ctx).(sql.CollationCoercible); ok { + return cc.CollationCoercibility(ctx) + } + return sql.Collation_binary, 6 +} + +// IsNullable implements sql.Expression. +func (c *PgCoalesce) IsNullable(_ *sql.Context) bool { + return true +} + +// Resolved implements sql.Expression. +func (c *PgCoalesce) Resolved() bool { + for _, arg := range c.args { + if arg == nil || !arg.Resolved() { + return false + } + } + return true +} + +// Children implements sql.Expression. +func (c *PgCoalesce) Children() []sql.Expression { return c.args } + +// WithChildren implements sql.Expression. +func (c *PgCoalesce) WithChildren(ctx *sql.Context, children ...sql.Expression) (sql.Expression, error) { + if len(children) == 0 { + return nil, sql.ErrInvalidArgumentNumber.New("COALESCE", "1 or more", 0) + } + newC := &PgCoalesce{args: children, typ: pgtypes.Unknown} + childTypes := make([]*pgtypes.DoltgresType, 0, len(children)) + for _, child := range children { + dt, ok := child.Type(ctx).(*pgtypes.DoltgresType) + if !ok { + return newC, nil + } + childTypes = append(childTypes, dt) + } + commonType, _, err := framework.FindCommonType(ctx, childTypes) + if err != nil { + return nil, err + } + if commonType != nil { + newC.typ = commonType + } + return newC, nil +} + +// Eval implements sql.Expression. Returns the first non-null argument value, cast to the common type. +func (c *PgCoalesce) Eval(ctx *sql.Context, row sql.Row) (any, error) { + commonType := c.typ + for _, arg := range c.args { + if arg == nil { + continue + } + val, err := arg.Eval(ctx, row) + if err != nil { + return nil, err + } + if val == nil { + continue + } + if commonType == pgtypes.Unknown { + return val, nil + } + argType, ok := arg.Type(ctx).(*pgtypes.DoltgresType) + if ok && argType.Equals(commonType) { + return val, nil + } + // Cast the value to the common type (handles mixed-type args, e.g. int2 and int4). + converted, _, err := commonType.Convert(ctx, val) + if err != nil { + return nil, err + } + return converted, nil + } + return nil, nil +} + +// String implements sql.Expression. +func (c *PgCoalesce) String() string { + args := make([]string, len(c.args)) + for i, arg := range c.args { + args[i] = arg.String() + } + return fmt.Sprintf("coalesce(%s)", strings.Join(args, ",")) +} + +// DebugString implements the sql.Debuggable interface. +func (c *PgCoalesce) DebugString(ctx *sql.Context) string { + args := make([]string, len(c.args)) + for i, arg := range c.args { + args[i] = sql.DebugString(ctx, arg) + } + return fmt.Sprintf("coalesce(%s)", strings.Join(args, ",")) +} diff --git a/testing/go/expressions_test.go b/testing/go/expressions_test.go index c031789c3c..5c2b9c04df 100644 --- a/testing/go/expressions_test.go +++ b/testing/go/expressions_test.go @@ -487,3 +487,66 @@ func TestSubscript(t *testing.T) { }, }) } + +func TestCoalesce(t *testing.T) { + RunScripts(t, []ScriptTest{ + { + // https://github.com/dolthub/doltgresql/issues/2332 + Name: "COALESCE(NULL, col) in UPDATE", + SetUpScript: []string{ + `CREATE TABLE t (id UUID PRIMARY KEY, val INTEGER NOT NULL DEFAULT 0, d DATE)`, + `INSERT INTO t VALUES ('00000000-0000-0000-0000-000000000001', 42, '2026-01-01')`, + }, + Assertions: []ScriptTestAssertion{ + { + // Should be a no-op; val stays 42. + Query: `UPDATE t SET val = COALESCE(NULL, val) WHERE id = '00000000-0000-0000-0000-000000000001'`, + SkipResultsCheck: true, + }, + { + Query: `SELECT val FROM t WHERE id = '00000000-0000-0000-0000-000000000001'`, + Expected: []sql.Row{{int32(42)}}, + }, + { + // Should be a no-op; d stays '2026-01-01'. + Query: `UPDATE t SET d = COALESCE(NULL, d) WHERE id = '00000000-0000-0000-0000-000000000001'`, + SkipResultsCheck: true, + }, + { + Query: `SELECT d FROM t WHERE id = '00000000-0000-0000-0000-000000000001'`, + Expected: []sql.Row{{"2026-01-01"}}, + }, + }, + }, + { + Name: "COALESCE type resolution in SELECT", + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT COALESCE(NULL, 42)`, + Expected: []sql.Row{{int32(42)}}, + }, + { + Query: `SELECT COALESCE(NULL, NULL)`, + Expected: []sql.Row{{nil}}, + }, + { + Query: `SELECT COALESCE(NULL, NULL, 'hello')`, + Expected: []sql.Row{{"hello"}}, + }, + { + Query: `SELECT COALESCE(1, 2, 3)`, + Expected: []sql.Row{{int32(1)}}, + }, + { + Query: `SELECT COALESCE(NULL, 2, 3)`, + Expected: []sql.Row{{int32(2)}}, + }, + { + // Explicit cast workaround still works. + Query: `SELECT COALESCE(NULL::integer, 42)`, + Expected: []sql.Row{{int32(42)}}, + }, + }, + }, + }) +}