Skip to content

Commit a7d1fe4

Browse files
committed
bloated ai stuff but will fix
1 parent 48dcbdf commit a7d1fe4

File tree

2 files changed

+236
-1
lines changed

2 files changed

+236
-1
lines changed

internal/compiler/parse_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,18 @@ const lowerSwitchedOrderQuery = `-- name: LowerSwitchedOrder :many
8484
SELECT bar FROM foo WHERE bar = $1 AND bat = LOWER($2);
8585
`
8686

87+
const orderByBindsSchema = `
88+
CREATE TABLE authors (
89+
id BIGSERIAL PRIMARY KEY,
90+
name text NOT NULL,
91+
bio text
92+
);
93+
`
94+
95+
const orderByBindsQuery = `-- name: ListAuthorsColumnSortFnWtihArg :many
96+
SELECT * FROM authors ORDER BY MOD(id, $1);
97+
`
98+
8799
type stubAnalyzer struct {
88100
analyze func(context.Context, ast.Node, string, []string, *named.ParamSet) (*analysispb.Analysis, error)
89101
}
@@ -216,6 +228,36 @@ func newLowerSwitchedOrderCompiler(t *testing.T) (*Compiler, *ast.RawStmt) {
216228
}, stmts[0].Raw
217229
}
218230

231+
func newOrderByBindsCompiler(t *testing.T) (*Compiler, *ast.RawStmt) {
232+
t.Helper()
233+
234+
parser := postgresql.NewParser()
235+
catalog := postgresql.NewCatalog()
236+
237+
schema, err := parser.Parse(strings.NewReader(orderByBindsSchema))
238+
if err != nil {
239+
t.Fatal(err)
240+
}
241+
if err := catalog.Build(schema); err != nil {
242+
t.Fatal(err)
243+
}
244+
245+
stmts, err := parser.Parse(strings.NewReader(orderByBindsQuery))
246+
if err != nil {
247+
t.Fatal(err)
248+
}
249+
if len(stmts) != 1 {
250+
t.Fatalf("expected 1 statement, got %d", len(stmts))
251+
}
252+
253+
return &Compiler{
254+
conf: config.SQL{Engine: config.EnginePostgreSQL},
255+
parser: parser,
256+
catalog: catalog,
257+
selector: newDefaultSelector(),
258+
}, stmts[0].Raw
259+
}
260+
219261
func assertBatchParameterNames(t *testing.T, params []Parameter) {
220262
t.Helper()
221263

@@ -325,6 +367,28 @@ func assertLowerSwitchedOrderParams(t *testing.T, params []Parameter) {
325367
}
326368
}
327369

370+
func assertOrderByBindsParams(t *testing.T, params []Parameter) {
371+
t.Helper()
372+
373+
if len(params) != 1 {
374+
t.Fatalf("expected 1 param, got %d", len(params))
375+
}
376+
377+
param := params[0]
378+
if param.Number != 1 {
379+
t.Fatalf("param number mismatch: got %d want %d", param.Number, 1)
380+
}
381+
if param.Column == nil {
382+
t.Fatal("param column is nil")
383+
}
384+
if param.Column.Name != "mod" {
385+
t.Fatalf("param name mismatch: got %q want %q", param.Column.Name, "mod")
386+
}
387+
if param.Column.DataType != "bigint" && param.Column.DataType != "pg_catalog.int8" {
388+
t.Fatalf("param type mismatch: got %q want %q or %q", param.Column.DataType, "bigint", "pg_catalog.int8")
389+
}
390+
}
391+
328392
func TestInferQueryPreservesInsertSelectParamNamesWithCTEAndMixedParams(t *testing.T) {
329393
t.Parallel()
330394

@@ -451,3 +515,15 @@ func TestParseQueryPreservesLowerSwitchedOrderParamTypes(t *testing.T) {
451515

452516
assertLowerSwitchedOrderParams(t, query.Params)
453517
}
518+
519+
func TestParseQueryPreservesOrderByBindsModParamType(t *testing.T) {
520+
t.Parallel()
521+
522+
comp, raw := newOrderByBindsCompiler(t)
523+
query, err := comp.parseQuery(raw, orderByBindsQuery, opts.Parser{})
524+
if err != nil {
525+
t.Fatal(err)
526+
}
527+
528+
assertOrderByBindsParams(t, query.Params)
529+
}

internal/compiler/resolve.go

Lines changed: 160 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,164 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
401401
})
402402
}
403403

404+
resolveColumnType := func(node *ast.ColumnRef) *ast.TypeName {
405+
if node == nil {
406+
return nil
407+
}
408+
409+
items := stringSlice(node.Fields)
410+
var schema, key, alias string
411+
switch len(items) {
412+
case 1:
413+
key = items[0]
414+
case 2:
415+
alias = items[0]
416+
key = items[1]
417+
case 3:
418+
schema = items[0]
419+
alias = items[1]
420+
key = items[2]
421+
default:
422+
return nil
423+
}
424+
425+
search := tables
426+
if alias != "" {
427+
if original, ok := aliasMap[alias]; ok {
428+
search = []*ast.TableName{original}
429+
} else {
430+
var matches []*ast.TableName
431+
for _, fqn := range tables {
432+
if fqn.Name == alias {
433+
matches = append(matches, fqn)
434+
}
435+
}
436+
if len(matches) == 0 {
437+
return nil
438+
}
439+
search = matches
440+
}
441+
}
442+
443+
var found *catalog.Column
444+
var foundCount int
445+
for _, table := range search {
446+
tableSchema := table.Schema
447+
if tableSchema == "" {
448+
tableSchema = c.DefaultSchema
449+
}
450+
if schema != "" && tableSchema != schema {
451+
continue
452+
}
453+
tablesBySchema, ok := typeMap[tableSchema]
454+
if !ok {
455+
continue
456+
}
457+
colsByTable, ok := tablesBySchema[table.Name]
458+
if !ok {
459+
continue
460+
}
461+
col, ok := colsByTable[key]
462+
if !ok {
463+
continue
464+
}
465+
found = col
466+
foundCount++
467+
}
468+
if foundCount != 1 || found == nil {
469+
return nil
470+
}
471+
472+
typ := found.Type
473+
return &typ
474+
}
475+
476+
normalizedFuncArgTypeKey := func(t *ast.TypeName) string {
477+
if t == nil {
478+
return ""
479+
}
480+
481+
name := strings.ToLower(comp.parser.TypeName(t.Schema, t.Name))
482+
switch name {
483+
case "bigserial", "serial8":
484+
name = "bigint"
485+
case "serial", "serial4":
486+
name = "integer"
487+
case "smallserial", "serial2":
488+
name = "smallint"
489+
}
490+
491+
schema := t.Schema
492+
if schema == "pg_catalog" {
493+
schema = ""
494+
}
495+
return schema + "|" + name + "|" + strconv.Itoa(arrayDims(t))
496+
}
497+
498+
sameFuncArgType := func(a, b *ast.TypeName) bool {
499+
if a == nil || b == nil {
500+
return a == nil && b == nil
501+
}
502+
return normalizedFuncArgTypeKey(a) == normalizedFuncArgTypeKey(b)
503+
}
504+
505+
var resolveFuncExprType func(node ast.Node, targetParamNumber int) *ast.TypeName
506+
resolveFuncExprType = func(node ast.Node, targetParamNumber int) *ast.TypeName {
507+
switch n := node.(type) {
508+
case *ast.NamedArgExpr:
509+
return resolveFuncExprType(n.Arg, targetParamNumber)
510+
case *ast.TypeCast:
511+
if pr, ok := n.Arg.(*ast.ParamRef); ok && pr.Number == targetParamNumber {
512+
return nil
513+
}
514+
return n.TypeName
515+
case *ast.ColumnRef:
516+
return resolveColumnType(n)
517+
default:
518+
return nil
519+
}
520+
}
521+
522+
filterFuncCallCandidates := func(call *ast.FuncCall, funcs []catalog.Function, targetIdx int, targetNamedArg string, targetParamNumber int) []catalog.Function {
523+
if call == nil || call.Args == nil || len(funcs) <= 1 {
524+
return funcs
525+
}
526+
527+
filtered := funcs
528+
for i, item := range call.Args.Items {
529+
if i == targetIdx {
530+
continue
531+
}
532+
533+
namedArg := ""
534+
if named, ok := item.(*ast.NamedArgExpr); ok && named.Name != nil {
535+
namedArg = *named.Name
536+
}
537+
if targetNamedArg != "" && namedArg == targetNamedArg {
538+
continue
539+
}
540+
541+
typ := resolveFuncExprType(item, targetParamNumber)
542+
if typ == nil {
543+
continue
544+
}
545+
546+
narrowed := make([]catalog.Function, 0, len(filtered))
547+
for _, fn := range filtered {
548+
arg := funcCallArg(fn, i, namedArg)
549+
if arg == nil || !sameFuncArgType(arg.Type, typ) {
550+
continue
551+
}
552+
narrowed = append(narrowed, fn)
553+
}
554+
if len(narrowed) > 0 {
555+
filtered = narrowed
556+
}
557+
}
558+
559+
return filtered
560+
}
561+
404562
for _, ref := range args {
405563
switch n := ref.parent.(type) {
406564

@@ -697,7 +855,8 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
697855
continue
698856
}
699857

700-
paramName, paramType := funcCallArgMetadata(funcs, i, argName)
858+
candidateFuncs := filterFuncCallCandidates(n, funcs, i, argName, ref.ref.Number)
859+
paramName, paramType := funcCallArgMetadata(candidateFuncs, i, argName)
701860
if argName != "" {
702861
paramName = argName
703862
}

0 commit comments

Comments
 (0)