|
9 | 9 | "github.com/sqlc-dev/sqlc/internal/sql/ast" |
10 | 10 | "github.com/sqlc-dev/sqlc/internal/sql/named" |
11 | 11 | "github.com/sqlc-dev/sqlc/internal/sql/rewrite" |
12 | | - "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" |
13 | 12 | "github.com/sqlc-dev/sqlc/internal/sql/validate" |
14 | 13 | ) |
15 | 14 |
|
@@ -143,11 +142,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) |
143 | 142 | raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw, numbers, dollar) |
144 | 143 |
|
145 | 144 | var table *ast.TableName |
146 | | - switch n := raw.Stmt.(type) { |
147 | | - case *ast.InsertStmt: |
148 | | - if err := check(validate.InsertStmt(n)); err != nil { |
149 | | - return nil, err |
150 | | - } |
| 145 | + if n, ok := raw.Stmt.(*ast.InsertStmt); ok { |
151 | 146 | var err error |
152 | 147 | table, err = ParseTableName(n.Relation) |
153 | 148 | if err := check(err); err != nil { |
@@ -187,7 +182,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) |
187 | 182 | return nil, err |
188 | 183 | } |
189 | 184 | if n, ok := raw.Stmt.(*ast.InsertStmt); ok { |
190 | | - if err := check(c.validateOnConflictClause(n)); err != nil { |
| 185 | + if err := check(validate.InsertStmt(n, table, c.catalog)); err != nil { |
191 | 186 | return nil, err |
192 | 187 | } |
193 | 188 | } |
@@ -219,88 +214,3 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) |
219 | 214 | Named: namedParams, |
220 | 215 | }, rerr |
221 | 216 | } |
222 | | - |
223 | | -// validateOnConflictClause validates an ON CONFLICT DO UPDATE clause against |
224 | | -// the target table. It checks: |
225 | | -// - ON CONFLICT (col, ...) conflict target columns exist |
226 | | -// - DO UPDATE SET col = ... assignment target columns exist |
227 | | -// - EXCLUDED.col references exist |
228 | | -func (c *Compiler) validateOnConflictClause(n *ast.InsertStmt) error { |
229 | | - if n.OnConflictClause == nil || n.OnConflictClause.Action != ast.OnConflictActionUpdate { |
230 | | - return nil |
231 | | - } |
232 | | - |
233 | | - fqn, err := ParseTableName(n.Relation) |
234 | | - if err != nil { |
235 | | - return err |
236 | | - } |
237 | | - |
238 | | - table, err := c.catalog.GetTable(fqn) |
239 | | - if err != nil { |
240 | | - return err |
241 | | - } |
242 | | - |
243 | | - // Build set of column names for existence checks. |
244 | | - colNames := make(map[string]struct{}, len(table.Columns)) |
245 | | - for _, col := range table.Columns { |
246 | | - colNames[col.Name] = struct{}{} |
247 | | - } |
248 | | - |
249 | | - // Validate ON CONFLICT (col, ...) conflict target columns. |
250 | | - if n.OnConflictClause.Infer != nil && n.OnConflictClause.Infer.IndexElems != nil { |
251 | | - for _, item := range n.OnConflictClause.Infer.IndexElems.Items { |
252 | | - elem, ok := item.(*ast.IndexElem) |
253 | | - if !ok || elem.Name == nil { |
254 | | - continue |
255 | | - } |
256 | | - if _, exists := colNames[*elem.Name]; !exists { |
257 | | - e := sqlerr.ColumnNotFound(table.Rel.Name, *elem.Name) |
258 | | - e.Location = n.OnConflictClause.Infer.Location |
259 | | - return e |
260 | | - } |
261 | | - } |
262 | | - } |
263 | | - |
264 | | - // Validate DO UPDATE SET col = ... assignment target columns and EXCLUDED.col references. |
265 | | - if n.OnConflictClause.TargetList == nil { |
266 | | - return nil |
267 | | - } |
268 | | - for _, item := range n.OnConflictClause.TargetList.Items { |
269 | | - target, ok := item.(*ast.ResTarget) |
270 | | - if !ok || target.Name == nil { |
271 | | - continue |
272 | | - } |
273 | | - if _, exists := colNames[*target.Name]; !exists { |
274 | | - e := sqlerr.ColumnNotFound(table.Rel.Name, *target.Name) |
275 | | - e.Location = target.Location |
276 | | - return e |
277 | | - } |
278 | | - if ref, ok := target.Val.(*ast.ColumnRef); ok { |
279 | | - if excludedCol, ok := excludedColumn(ref); ok { |
280 | | - if _, exists := colNames[excludedCol]; !exists { |
281 | | - e := sqlerr.ColumnNotFound(table.Rel.Name, excludedCol) |
282 | | - e.Location = ref.Location |
283 | | - return e |
284 | | - } |
285 | | - } |
286 | | - } |
287 | | - } |
288 | | - return nil |
289 | | -} |
290 | | - |
291 | | -// excludedColumn returns the column name if the ColumnRef is an EXCLUDED.col |
292 | | -// reference, and ok=true. Returns "", false otherwise. |
293 | | -func excludedColumn(ref *ast.ColumnRef) (string, bool) { |
294 | | - if ref.Fields == nil || len(ref.Fields.Items) != 2 { |
295 | | - return "", false |
296 | | - } |
297 | | - first, ok := ref.Fields.Items[0].(*ast.String) |
298 | | - if !ok || first.Str != "excluded" { |
299 | | - return "", false |
300 | | - } |
301 | | - second, ok := ref.Fields.Items[1].(*ast.String) |
302 | | - if !ok { |
303 | | - return "", false |
304 | | - } |
305 | | - return second.Str, true |
306 | | -} |
0 commit comments