Skip to content

Commit 84ad2cb

Browse files
committed
fix: inferred column names
1 parent f758c7e commit 84ad2cb

File tree

3 files changed

+289
-54
lines changed

3 files changed

+289
-54
lines changed

internal/compiler/find_params.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,11 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
140140
p.parent = node
141141

142142
case *ast.SelectStmt:
143+
if n.FromClause != nil && len(n.FromClause.Items) == 1 {
144+
if rv, ok := n.FromClause.Items[0].(*ast.RangeVar); ok {
145+
p.rangeVar = rv
146+
}
147+
}
143148
if n.LimitCount != nil {
144149
p.limitCount = n.LimitCount
145150
}

internal/compiler/parse_test.go

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
package compiler
2+
3+
import (
4+
"context"
5+
"strings"
6+
"testing"
7+
8+
analysispb "github.com/sqlc-dev/sqlc/internal/analysis"
9+
"github.com/sqlc-dev/sqlc/internal/config"
10+
"github.com/sqlc-dev/sqlc/internal/engine/postgresql"
11+
"github.com/sqlc-dev/sqlc/internal/opts"
12+
"github.com/sqlc-dev/sqlc/internal/sql/ast"
13+
"github.com/sqlc-dev/sqlc/internal/sql/named"
14+
)
15+
16+
const batchParameterTypeSchema = `
17+
CREATE TABLE public.solar_commcard_mapping (
18+
id INT8 NOT NULL,
19+
"deviceId" INT8 NOT NULL,
20+
version VARCHAR(32) DEFAULT ''::VARCHAR NOT NULL,
21+
sn VARCHAR(32) DEFAULT ''::VARCHAR NOT NULL,
22+
"createdAt" TIMESTAMPTZ DEFAULT now(),
23+
"updatedAt" TIMESTAMPTZ DEFAULT now()
24+
);
25+
`
26+
27+
const batchParameterTypeQuery = `-- name: InsertMappping :batchexec
28+
WITH
29+
table1 AS (
30+
SELECT version
31+
FROM solar_commcard_mapping
32+
WHERE "deviceId" = $1
33+
ORDER BY "updatedAt" DESC
34+
LIMIT 1
35+
)
36+
INSERT INTO solar_commcard_mapping ("deviceId", version, sn, "updatedAt")
37+
SELECT $1, @version::text, $3, $4
38+
WHERE NOT EXISTS (
39+
SELECT *
40+
FROM table1
41+
WHERE table1.version = @version::text
42+
) OR NOT EXISTS (SELECT * FROM table1);
43+
`
44+
45+
type stubAnalyzer struct {
46+
analyze func(context.Context, ast.Node, string, []string, *named.ParamSet) (*analysispb.Analysis, error)
47+
}
48+
49+
func (s stubAnalyzer) Analyze(ctx context.Context, n ast.Node, q string, schema []string, np *named.ParamSet) (*analysispb.Analysis, error) {
50+
return s.analyze(ctx, n, q, schema, np)
51+
}
52+
53+
func (stubAnalyzer) Close(context.Context) error { return nil }
54+
func (stubAnalyzer) EnsureConn(context.Context, []string) error { return nil }
55+
func (stubAnalyzer) GetColumnNames(context.Context, string) ([]string, error) { return nil, nil }
56+
57+
func newBatchParameterTypeCompiler(t *testing.T) (*Compiler, *ast.RawStmt) {
58+
t.Helper()
59+
60+
parser := postgresql.NewParser()
61+
catalog := postgresql.NewCatalog()
62+
63+
schema, err := parser.Parse(strings.NewReader(batchParameterTypeSchema))
64+
if err != nil {
65+
t.Fatal(err)
66+
}
67+
if err := catalog.Build(schema); err != nil {
68+
t.Fatal(err)
69+
}
70+
71+
stmts, err := parser.Parse(strings.NewReader(batchParameterTypeQuery))
72+
if err != nil {
73+
t.Fatal(err)
74+
}
75+
if len(stmts) != 1 {
76+
t.Fatalf("expected 1 statement, got %d", len(stmts))
77+
}
78+
79+
return &Compiler{
80+
conf: config.SQL{Engine: config.EnginePostgreSQL},
81+
parser: parser,
82+
catalog: catalog,
83+
selector: newDefaultSelector(),
84+
}, stmts[0].Raw
85+
}
86+
87+
func assertBatchParameterNames(t *testing.T, params []Parameter) {
88+
t.Helper()
89+
90+
checks := []struct {
91+
idx int
92+
number int
93+
name string
94+
original string
95+
named bool
96+
}{
97+
{idx: 0, number: 1, name: "deviceId", original: "deviceId"},
98+
{idx: 1, number: 2, name: "version", original: "version", named: true},
99+
{idx: 2, number: 3, name: "sn", original: "sn"},
100+
{idx: 3, number: 4, name: "updatedAt", original: "updatedAt"},
101+
}
102+
if len(params) != len(checks) {
103+
t.Fatalf("expected %d params, got %d", len(checks), len(params))
104+
}
105+
106+
for _, check := range checks {
107+
param := params[check.idx]
108+
if param.Number != check.number {
109+
t.Fatalf("param %d number mismatch: got %d want %d", check.idx, param.Number, check.number)
110+
}
111+
if param.Column == nil {
112+
t.Fatalf("param %d column is nil", check.idx)
113+
}
114+
if param.Column.Name != check.name {
115+
t.Fatalf("param %d name mismatch: got %q want %q", check.idx, param.Column.Name, check.name)
116+
}
117+
if param.Column.OriginalName != check.original {
118+
t.Fatalf("param %d original name mismatch: got %q want %q", check.idx, param.Column.OriginalName, check.original)
119+
}
120+
if param.Column.IsNamedParam != check.named {
121+
t.Fatalf("param %d named mismatch: got %v want %v", check.idx, param.Column.IsNamedParam, check.named)
122+
}
123+
if param.Column.DataType == "" || param.Column.DataType == "any" {
124+
t.Fatalf("param %d type was not inferred: %+v", check.idx, param.Column)
125+
}
126+
}
127+
}
128+
129+
func TestInferQueryPreservesInsertSelectParamNamesWithCTEAndMixedParams(t *testing.T) {
130+
t.Parallel()
131+
132+
comp, raw := newBatchParameterTypeCompiler(t)
133+
anlys, err := comp.inferQuery(raw, batchParameterTypeQuery)
134+
if err != nil && !strings.Contains(err.Error(), "parameter $2") {
135+
t.Fatalf("unexpected infer error: %v", err)
136+
}
137+
if anlys == nil {
138+
t.Fatal("expected non-nil analysis")
139+
}
140+
if !strings.Contains(anlys.Query, "$2::text") {
141+
t.Fatalf("expected rewritten query to contain $2::text, got %q", anlys.Query)
142+
}
143+
144+
assertBatchParameterNames(t, anlys.Parameters)
145+
}
146+
147+
func TestParseQueryManagedDBPreservesInferredParamNames(t *testing.T) {
148+
t.Parallel()
149+
150+
comp, raw := newBatchParameterTypeCompiler(t)
151+
comp.analyzer = stubAnalyzer{analyze: func(_ context.Context, _ ast.Node, query string, _ []string, np *named.ParamSet) (*analysispb.Analysis, error) {
152+
if np == nil {
153+
t.Fatal("expected named param set")
154+
}
155+
if got, ok := np.NameFor(2); !ok || got != "version" {
156+
t.Fatalf("expected param 2 to be named version, got %q %v", got, ok)
157+
}
158+
if !strings.Contains(query, "$2::text") {
159+
t.Fatalf("expected analyzer query to contain rewritten named param, got %q", query)
160+
}
161+
return &analysispb.Analysis{Params: []*analysispb.Parameter{
162+
{Number: 1, Column: &analysispb.Column{DataType: "pg_catalog.int8"}},
163+
{Number: 2, Column: &analysispb.Column{Name: "version", DataType: "text", IsNamedParam: true}},
164+
{Number: 3, Column: &analysispb.Column{DataType: "text"}},
165+
{Number: 4, Column: &analysispb.Column{DataType: "pg_catalog.timestamptz"}},
166+
}}, nil
167+
}}
168+
169+
query, err := comp.parseQuery(raw, batchParameterTypeQuery, opts.Parser{})
170+
if err != nil {
171+
t.Fatal(err)
172+
}
173+
174+
assertBatchParameterNames(t, query.Params)
175+
}

0 commit comments

Comments
 (0)