Skip to content

Commit 174bf4d

Browse files
committed
Fix UNNEST bulk insert column mapping
1 parent 69d8a95 commit 174bf4d

3 files changed

Lines changed: 84 additions & 1 deletion

File tree

server/ast/select_clause.go

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
package ast
1616

1717
import (
18-
"github.com/dolthub/go-mysql-server/sql/expression"
18+
"strings"
1919

20+
"github.com/dolthub/go-mysql-server/sql/expression"
2021
vitess "github.com/dolthub/vitess/go/vt/sqlparser"
2122

2223
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
@@ -158,6 +159,36 @@ PostJoinRewrite:
158159
}
159160
}
160161
}
162+
// Handle multi-argument UNNEST: UNNEST(arr1, arr2, ...) produces a table with one column per array,
163+
// where corresponding elements are "zipped" together. PostgreSQL pads shorter arrays with NULLs.
164+
// We transform: SELECT * FROM UNNEST(arr1, arr2)
165+
// Into: SELECT * FROM (SELECT unnest(arr1), unnest(arr2)) AS unnest
166+
// GMS's ProjectRowWithNestedIters handles multiple SRFs by zipping them together correctly.
167+
if tableFuncExpr, ok := from[i].(*vitess.TableFuncExpr); ok {
168+
if strings.EqualFold(tableFuncExpr.Name, "unnest") && len(tableFuncExpr.Exprs) > 1 {
169+
selectExprs := make(vitess.SelectExprs, 0, len(tableFuncExpr.Exprs))
170+
for _, argExpr := range tableFuncExpr.Exprs {
171+
selectExprs = append(selectExprs, &vitess.AliasedExpr{
172+
Expr: &vitess.FuncExpr{
173+
Name: vitess.NewColIdent("unnest"),
174+
Exprs: vitess.SelectExprs{argExpr},
175+
},
176+
})
177+
}
178+
alias := tableFuncExpr.Alias
179+
if alias.IsEmpty() {
180+
alias = vitess.NewTableIdent("unnest")
181+
}
182+
from[i] = &vitess.AliasedTableExpr{
183+
Expr: &vitess.Subquery{
184+
Select: &vitess.Select{
185+
SelectExprs: selectExprs,
186+
},
187+
},
188+
As: alias,
189+
}
190+
}
191+
}
161192
}
162193
distinct := node.Distinct
163194
var distinctOn vitess.Exprs

testing/go/functions_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1552,6 +1552,36 @@ func TestArrayFunctions(t *testing.T) {
15521552
},
15531553
},
15541554
},
1555+
{
1556+
Name: "multi-argument unnest",
1557+
Assertions: []ScriptTestAssertion{
1558+
{
1559+
// Basic multi-argument UNNEST with equal-length arrays
1560+
Query: `SELECT * FROM UNNEST(ARRAY['a','b','c'], ARRAY[1,2,3])`,
1561+
Expected: []sql.Row{{"a", int64(1)}, {"b", int64(2)}, {"c", int64(3)}},
1562+
},
1563+
{
1564+
// Multi-argument UNNEST with unequal-length arrays (shorter padded with NULL)
1565+
Query: `SELECT * FROM UNNEST(ARRAY['a','b'], ARRAY[1,2,3])`,
1566+
Expected: []sql.Row{{"a", int64(1)}, {"b", int64(2)}, {nil, int64(3)}},
1567+
},
1568+
{
1569+
// Multi-argument UNNEST with empty array
1570+
Query: `SELECT * FROM UNNEST(ARRAY['a','b'], ARRAY[]::int[])`,
1571+
Expected: []sql.Row{{"a", nil}, {"b", nil}},
1572+
},
1573+
{
1574+
// Multi-argument UNNEST with three arrays (booleans come as "t"/"f" strings from PostgreSQL wire protocol)
1575+
Query: `SELECT * FROM UNNEST(ARRAY[1,2], ARRAY['x','y'], ARRAY[true,false])`,
1576+
Expected: []sql.Row{{int64(1), "x", "t"}, {int64(2), "y", "f"}},
1577+
},
1578+
{
1579+
// Multi-argument UNNEST with alias
1580+
Query: `SELECT u.* FROM UNNEST(ARRAY['a','b'], ARRAY[1,2]) AS u`,
1581+
Expected: []sql.Row{{"a", int64(1)}, {"b", int64(2)}},
1582+
},
1583+
},
1584+
},
15551585
{
15561586
Name: "array_to_string",
15571587
SetUpScript: []string{},

testing/go/insert_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,5 +293,27 @@ ON CONFLICT (id) do update set c1 = $4`,
293293
},
294294
},
295295
},
296+
{
297+
Name: "insert from unnest",
298+
SetUpScript: []string{
299+
`CREATE TABLE "django_content_type" (id serial primary key, app_label varchar, model varchar)`,
300+
},
301+
Assertions: []ScriptTestAssertion{
302+
{
303+
Query: `INSERT INTO "django_content_type" ("app_label", "model")
304+
SELECT * FROM UNNEST(('{debug_app,debug_app}')::varchar[],
305+
('{debugmodel1,debugmodel2}')::varchar[])
306+
RETURNING "django_content_type"."id"`,
307+
Expected: []sql.Row{{1}, {2}},
308+
},
309+
{
310+
Query: `SELECT "app_label", "model" FROM "django_content_type" ORDER BY "id"`,
311+
Expected: []sql.Row{
312+
{"debug_app", "debugmodel1"},
313+
{"debug_app", "debugmodel2"},
314+
},
315+
},
316+
},
317+
},
296318
})
297319
}

0 commit comments

Comments
 (0)