Skip to content

Commit 1ed27ef

Browse files
committed
move to validate
1 parent 95e0552 commit 1ed27ef

File tree

2 files changed

+86
-93
lines changed

2 files changed

+86
-93
lines changed

internal/compiler/analyze.go

Lines changed: 2 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"github.com/sqlc-dev/sqlc/internal/sql/ast"
1010
"github.com/sqlc-dev/sqlc/internal/sql/named"
1111
"github.com/sqlc-dev/sqlc/internal/sql/rewrite"
12-
"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
1312
"github.com/sqlc-dev/sqlc/internal/sql/validate"
1413
)
1514

@@ -143,11 +142,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
143142
raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw, numbers, dollar)
144143

145144
var table *ast.TableName
146-
switch n := raw.Stmt.(type) {
147-
case *ast.InsertStmt:
148-
if err := check(validate.InsertStmt(n)); err != nil {
149-
return nil, err
150-
}
145+
if n, ok := raw.Stmt.(*ast.InsertStmt); ok {
151146
var err error
152147
table, err = ParseTableName(n.Relation)
153148
if err := check(err); err != nil {
@@ -187,7 +182,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
187182
return nil, err
188183
}
189184
if n, ok := raw.Stmt.(*ast.InsertStmt); ok {
190-
if err := check(c.validateOnConflictClause(n)); err != nil {
185+
if err := check(validate.InsertStmt(n, table, c.catalog)); err != nil {
191186
return nil, err
192187
}
193188
}
@@ -219,88 +214,3 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
219214
Named: namedParams,
220215
}, rerr
221216
}
222-
223-
// validateOnConflictClause validates an ON CONFLICT DO UPDATE clause against
224-
// the target table. It checks:
225-
// - ON CONFLICT (col, ...) conflict target columns exist
226-
// - DO UPDATE SET col = ... assignment target columns exist
227-
// - EXCLUDED.col references exist
228-
func (c *Compiler) validateOnConflictClause(n *ast.InsertStmt) error {
229-
if n.OnConflictClause == nil || n.OnConflictClause.Action != ast.OnConflictActionUpdate {
230-
return nil
231-
}
232-
233-
fqn, err := ParseTableName(n.Relation)
234-
if err != nil {
235-
return err
236-
}
237-
238-
table, err := c.catalog.GetTable(fqn)
239-
if err != nil {
240-
return err
241-
}
242-
243-
// Build set of column names for existence checks.
244-
colNames := make(map[string]struct{}, len(table.Columns))
245-
for _, col := range table.Columns {
246-
colNames[col.Name] = struct{}{}
247-
}
248-
249-
// Validate ON CONFLICT (col, ...) conflict target columns.
250-
if n.OnConflictClause.Infer != nil && n.OnConflictClause.Infer.IndexElems != nil {
251-
for _, item := range n.OnConflictClause.Infer.IndexElems.Items {
252-
elem, ok := item.(*ast.IndexElem)
253-
if !ok || elem.Name == nil {
254-
continue
255-
}
256-
if _, exists := colNames[*elem.Name]; !exists {
257-
e := sqlerr.ColumnNotFound(table.Rel.Name, *elem.Name)
258-
e.Location = n.OnConflictClause.Infer.Location
259-
return e
260-
}
261-
}
262-
}
263-
264-
// Validate DO UPDATE SET col = ... assignment target columns and EXCLUDED.col references.
265-
if n.OnConflictClause.TargetList == nil {
266-
return nil
267-
}
268-
for _, item := range n.OnConflictClause.TargetList.Items {
269-
target, ok := item.(*ast.ResTarget)
270-
if !ok || target.Name == nil {
271-
continue
272-
}
273-
if _, exists := colNames[*target.Name]; !exists {
274-
e := sqlerr.ColumnNotFound(table.Rel.Name, *target.Name)
275-
e.Location = target.Location
276-
return e
277-
}
278-
if ref, ok := target.Val.(*ast.ColumnRef); ok {
279-
if excludedCol, ok := excludedColumn(ref); ok {
280-
if _, exists := colNames[excludedCol]; !exists {
281-
e := sqlerr.ColumnNotFound(table.Rel.Name, excludedCol)
282-
e.Location = ref.Location
283-
return e
284-
}
285-
}
286-
}
287-
}
288-
return nil
289-
}
290-
291-
// excludedColumn returns the column name if the ColumnRef is an EXCLUDED.col
292-
// reference, and ok=true. Returns "", false otherwise.
293-
func excludedColumn(ref *ast.ColumnRef) (string, bool) {
294-
if ref.Fields == nil || len(ref.Fields.Items) != 2 {
295-
return "", false
296-
}
297-
first, ok := ref.Fields.Items[0].(*ast.String)
298-
if !ok || first.Str != "excluded" {
299-
return "", false
300-
}
301-
second, ok := ref.Fields.Items[1].(*ast.String)
302-
if !ok {
303-
return "", false
304-
}
305-
return second.Str, true
306-
}

internal/sql/validate/insert_stmt.go

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package validate
22

33
import (
4+
"strings"
5+
46
"github.com/sqlc-dev/sqlc/internal/sql/ast"
7+
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
58
"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
69
)
710

8-
func InsertStmt(stmt *ast.InsertStmt) error {
11+
func InsertStmt(stmt *ast.InsertStmt, fqn *ast.TableName, c *catalog.Catalog) error {
912
sel, ok := stmt.SelectStmt.(*ast.SelectStmt)
1013
if !ok {
1114
return nil
@@ -35,5 +38,85 @@ func InsertStmt(stmt *ast.InsertStmt) error {
3538
Message: "INSERT has more expressions than target columns",
3639
}
3740
}
41+
return onConflictClause(stmt, fqn, c)
42+
}
43+
44+
// onConflictClause validates an ON CONFLICT DO UPDATE clause against the target
45+
// table. It checks:
46+
// - ON CONFLICT (col, ...) conflict target columns exist
47+
// - DO UPDATE SET col = ... assignment target columns exist
48+
// - EXCLUDED.col references exist
49+
func onConflictClause(n *ast.InsertStmt, fqn *ast.TableName, c *catalog.Catalog) error {
50+
if n.OnConflictClause == nil || n.OnConflictClause.Action != ast.OnConflictActionUpdate {
51+
return nil
52+
}
53+
54+
table, err := c.GetTable(fqn)
55+
if err != nil {
56+
return err
57+
}
58+
59+
// Build set of column names for existence checks.
60+
colNames := make(map[string]struct{}, len(table.Columns))
61+
for _, col := range table.Columns {
62+
colNames[col.Name] = struct{}{}
63+
}
64+
65+
// Validate ON CONFLICT (col, ...) conflict target columns.
66+
if n.OnConflictClause.Infer != nil && n.OnConflictClause.Infer.IndexElems != nil {
67+
for _, item := range n.OnConflictClause.Infer.IndexElems.Items {
68+
elem, ok := item.(*ast.IndexElem)
69+
if !ok || elem.Name == nil {
70+
continue
71+
}
72+
if _, exists := colNames[*elem.Name]; !exists {
73+
e := sqlerr.ColumnNotFound(table.Rel.Name, *elem.Name)
74+
e.Location = n.OnConflictClause.Infer.Location
75+
return e
76+
}
77+
}
78+
}
79+
80+
// Validate DO UPDATE SET col = ... assignment target columns and EXCLUDED.col references.
81+
if n.OnConflictClause.TargetList == nil {
82+
return nil
83+
}
84+
for _, item := range n.OnConflictClause.TargetList.Items {
85+
target, ok := item.(*ast.ResTarget)
86+
if !ok || target.Name == nil {
87+
continue
88+
}
89+
if _, exists := colNames[*target.Name]; !exists {
90+
e := sqlerr.ColumnNotFound(table.Rel.Name, *target.Name)
91+
e.Location = target.Location
92+
return e
93+
}
94+
if ref, ok := target.Val.(*ast.ColumnRef); ok {
95+
if excludedCol, ok := excludedColumnRef(ref); ok {
96+
if _, exists := colNames[excludedCol]; !exists {
97+
e := sqlerr.ColumnNotFound(table.Rel.Name, excludedCol)
98+
e.Location = ref.Location
99+
return e
100+
}
101+
}
102+
}
103+
}
38104
return nil
39105
}
106+
107+
// excludedColumnRef returns the column name if the ColumnRef is an EXCLUDED.col
108+
// reference, and ok=true. Returns "", false otherwise.
109+
func excludedColumnRef(ref *ast.ColumnRef) (string, bool) {
110+
if ref.Fields == nil || len(ref.Fields.Items) != 2 {
111+
return "", false
112+
}
113+
first, ok := ref.Fields.Items[0].(*ast.String)
114+
if !ok || !strings.EqualFold(first.Str, "excluded") {
115+
return "", false
116+
}
117+
second, ok := ref.Fields.Items[1].(*ast.String)
118+
if !ok {
119+
return "", false
120+
}
121+
return second.Str, true
122+
}

0 commit comments

Comments
 (0)