Skip to content

Commit 6287222

Browse files
committed
fix(sqlite): allow to use overrides for columns which were created using "alter table .. add column"
This fixes the failing test which was added in the previous commit. Normalize sqlite `ALTER TABLE` column identifiers and preserve analyzer origin metadata when merging query analysis. This allows exact column overrides such as `chat_messages.updated_at` to match columns added via `ALTER TABLE .. ADD COLUMN`. Using the sqlc config: overrides: - column: "*.*_at" go_type: import: "time" type: "Time" and the SQL schema: CREATE TABLE chat_messages ( id TEXT PRIMARY KEY NOT NULL, body TEXT NOT NULL, created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP ) STRICT; ALTER TABLE chat_messages ADD COLUMN "updated_at" TEXT NOT NULL DEFAULT ''; sqlc would previously generate the following code: type ChatMessage struct { ID string Body string CreatedAt time.Time UpdatedAt string // Bug: This should be a time.Time! } _With_ this fix, the following code is generated: type ChatMessage struct { ID string Body string CreatedAt time.Time UpdatedAt time.Time // Fixed, now a time.Time, even for retroactively added columns }
1 parent 1851c0d commit 6287222

2 files changed

Lines changed: 45 additions & 5 deletions

File tree

internal/compiler/analyze.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,41 @@ func convertColumn(c *analyzer.Column) *Column {
6565
}
6666
}
6767

68+
func mergeColumnOrigin(dst, src *Column) {
69+
if dst == nil || src == nil {
70+
return
71+
}
72+
73+
// Column overrides in the Go generator depend on the column's original
74+
// table identity. The analyzer can fill in missing origin metadata, but it
75+
// must not overwrite catalog-inferred origin metadata.
76+
//
77+
// In particular, CTE output columns are deliberately re-scoped to the CTE
78+
// relation by buildQueryCatalog. If we overwrite that with the analyzer's
79+
// underlying base table, queries like:
80+
//
81+
// WITH expensive AS (SELECT * FROM products)
82+
// SELECT * FROM expensive
83+
//
84+
// start looking like they return products directly, causing the Go generator
85+
// to reuse Product instead of emitting ListExpensiveProductsRow.
86+
if dst.OriginalName == "" && src.OriginalName != "" {
87+
dst.OriginalName = src.OriginalName
88+
}
89+
if dst.Table == nil && src.Table != nil {
90+
dst.Table = src.Table
91+
}
92+
if dst.TableAlias == "" && src.TableAlias != "" {
93+
dst.TableAlias = src.TableAlias
94+
}
95+
if dst.Scope == "" && src.Scope != "" {
96+
dst.Scope = src.Scope
97+
}
98+
if dst.EmbedTable == nil && src.EmbedTable != nil {
99+
dst.EmbedTable = src.EmbedTable
100+
}
101+
}
102+
68103
func combineAnalysis(prev *analysis, a *analyzer.Analysis) *analysis {
69104
var cols []*Column
70105
for _, c := range a.Columns {
@@ -79,6 +114,7 @@ func combineAnalysis(prev *analysis, a *analyzer.Analysis) *analysis {
79114
}
80115
if len(prev.Columns) == len(cols) {
81116
for i := range prev.Columns {
117+
mergeColumnOrigin(prev.Columns[i], cols[i])
82118
// Only override column types if the analyzer provides a specific type
83119
// (not "any"), since the catalog-based inference may have better info
84120
if cols[i].DataType != "any" {

internal/engine/sqlite/convert.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,18 @@ func (c *cc) convertAlter_table_stmtContext(n *parser.Alter_table_stmtContext) a
6767
Table: parseTableName(n),
6868
Cmds: &ast.List{},
6969
}
70-
name := def.Column_name().GetText()
70+
name := identifier(def.Column_name().GetText())
71+
typeName := "any"
72+
if def.Type_name() != nil {
73+
typeName = def.Type_name().GetText()
74+
}
7175
stmt.Cmds.Items = append(stmt.Cmds.Items, &ast.AlterTableCmd{
7276
Name: &name,
7377
Subtype: ast.AT_AddColumn,
7478
Def: &ast.ColumnDef{
7579
Colname: name,
7680
TypeName: &ast.TypeName{
77-
Name: def.Type_name().GetText(),
81+
Name: typeName,
7882
},
7983
IsNotNull: hasNotNullConstraint(def.AllColumn_constraint()),
8084
},
@@ -88,7 +92,7 @@ func (c *cc) convertAlter_table_stmtContext(n *parser.Alter_table_stmtContext) a
8892
Table: parseTableName(n),
8993
Cmds: &ast.List{},
9094
}
91-
name := n.Column_name(0).GetText()
95+
name := identifier(n.Column_name(0).GetText())
9296
stmt.Cmds.Items = append(stmt.Cmds.Items, &ast.AlterTableCmd{
9397
Name: &name,
9498
Subtype: ast.AT_DropColumn,
@@ -826,7 +830,7 @@ func (c *cc) convertUnaryExpr(n *parser.Expr_unaryContext) ast.Node {
826830
if opCtx.MINUS() != nil {
827831
// Negative number: -expr
828832
return &ast.A_Expr{
829-
Name: &ast.List{Items: []ast.Node{&ast.String{Str: "-"}}},
833+
Name: &ast.List{Items: []ast.Node{&ast.String{Str: "-"}}},
830834
Rexpr: expr,
831835
}
832836
}
@@ -837,7 +841,7 @@ func (c *cc) convertUnaryExpr(n *parser.Expr_unaryContext) ast.Node {
837841
if opCtx.TILDE() != nil {
838842
// Bitwise NOT: ~expr
839843
return &ast.A_Expr{
840-
Name: &ast.List{Items: []ast.Node{&ast.String{Str: "~"}}},
844+
Name: &ast.List{Items: []ast.Node{&ast.String{Str: "~"}}},
841845
Rexpr: expr,
842846
}
843847
}

0 commit comments

Comments
 (0)