Skip to content

Commit 2f66e96

Browse files
committed
fix queries
1 parent 873c31c commit 2f66e96

3 files changed

Lines changed: 29 additions & 32 deletions

File tree

examples/authors/ydb-database-sql/queries.sql.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/authors/ydb-go-sdk/queries.sql.go

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/codegen/dbsql/gen.go

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ package dbsql
33
import (
44
"bytes"
55
"context"
6-
"regexp"
7-
"strconv"
86
"strings"
97
"text/template"
108

@@ -155,21 +153,13 @@ type Queries struct {
155153
}
156154
`
157155

158-
// sqlToPositional converts $name placeholders to $1, $2 in param order.
159-
func sqlToPositional(sql string, paramNames []string) string {
160-
for i, name := range paramNames {
161-
re := regexp.MustCompile(`\$` + regexp.QuoteMeta(name) + `\b`)
162-
sql = re.ReplaceAllString(sql, "$"+strconv.Itoa(i+1))
163-
}
164-
return sql
165-
}
166-
167156
func genQueryFile(cat *pb.Catalog, pkg, sourceName string, queries []*pb.Query) ([]byte, error) {
168157
var buf bytes.Buffer
169158
buf.WriteString("// Code generated by sqlc. DO NOT EDIT.\n// source: " + sourceName + "\n\n")
170159
buf.WriteString("package " + pkg + "\n\n")
171160
buf.WriteString("import (\n")
172-
buf.WriteString("\t\"context\"\n\n")
161+
buf.WriteString("\t\"context\"\n")
162+
buf.WriteString("\t\"database/sql\"\n\n")
173163
buf.WriteString("\t_ \"github.com/ydb-platform/ydb-go-sdk/v3\"\n")
174164
buf.WriteString("\t\"github.com/ydb-platform/ydb-go-sdk/v3/pkg/xerrors\"\n")
175165
buf.WriteString("\t\"github.com/ydb-platform/ydb-go-sdk/v3/retry\"\n")
@@ -194,16 +184,8 @@ func genQuery(cat *pb.Catalog, pkg string, q *pb.Query) ([]byte, error) {
194184
cmd := strings.ToLower(q.GetCmd())
195185
text := q.GetText()
196186

197-
paramNames := make([]string, 0, len(q.GetParams()))
198-
for _, p := range q.GetParams() {
199-
if col := p.GetColumn(); col != nil && col.GetName() != "" {
200-
paramNames = append(paramNames, col.GetName())
201-
}
202-
}
203-
positionalSQL := sqlToPositional(text, paramNames)
204-
205187
constName := toGoConst(name)
206-
buf.WriteString("const " + constName + " = `" + positionalSQL + "`\n\n")
188+
buf.WriteString("const " + constName + " = `" + text + "`\n\n")
207189

208190
if len(q.GetParams()) > 0 {
209191
buf.WriteString("type " + name + "Params struct {\n")
@@ -263,14 +245,14 @@ func genOne(cat *pb.Catalog, name, constName string, q *pb.Query) string {
263245
}
264246

265247
sig := "ctx context.Context, id int64"
266-
args := "id"
248+
args := ""
267249
if len(params) > 0 {
268250
sig = "ctx context.Context, arg " + name + "Params"
269-
args = argList(params)
251+
args = namedArgList(params)
270252
}
271253
b.WriteString("func (q *Queries) " + name + "(" + sig + ") (*" + retType + ", error) {\n")
272254
b.WriteString("\ti, err := retry.RetryWithResult(ctx, func(ctx context.Context) (*" + retType + ", error) {\n")
273-
b.WriteString("\t\trow := q.db.QueryRowContext(ctx, " + constName + ", " + args + ")\n")
255+
b.WriteString("\t\trow := q.db.QueryRowContext(ctx, " + constName + args + ")\n")
274256
b.WriteString("\t\tvar i " + retType + "\n")
275257
b.WriteString("\t\terr := row.Scan(" + scanTargets(cols) + ")\n")
276258
b.WriteString("\t\tif err != nil {\n")
@@ -306,12 +288,12 @@ func genMany(cat *pb.Catalog, name, constName string, q *pb.Query) string {
306288
args := ""
307289
if len(params) > 0 {
308290
sig = "ctx context.Context, arg " + name + "Params"
309-
args = argList(params)
291+
args = namedArgList(params)
310292
}
311293
b.WriteString("func (q *Queries) " + name + "(" + sig + ") ([]" + rowType + ", error) {\n")
312294
b.WriteString("\titems, err := retry.RetryWithResult(ctx, func(ctx context.Context) ([]" + rowType + ", error) {\n")
313295
if len(params) > 0 {
314-
b.WriteString("\t\trows, err := q.db.QueryContext(ctx, " + constName + ", " + args + ")\n")
296+
b.WriteString("\t\trows, err := q.db.QueryContext(ctx, " + constName + args + ")\n")
315297
} else {
316298
b.WriteString("\t\trows, err := q.db.QueryContext(ctx, " + constName + ")\n")
317299
}
@@ -347,12 +329,12 @@ func genExec(name, constName string, q *pb.Query) string {
347329
args := ""
348330
if len(params) > 0 {
349331
sig = "ctx context.Context, arg " + name + "Params"
350-
args = argList(params)
332+
args = namedArgList(params)
351333
}
352334
b.WriteString("func (q *Queries) " + name + "(" + sig + ") error {\n")
353335
b.WriteString("\terr := retry.Retry(ctx, func(ctx context.Context) error {\n")
354336
if len(params) > 0 {
355-
b.WriteString("\t\t_, err := q.db.ExecContext(ctx, " + constName + ", " + args + ")\n")
337+
b.WriteString("\t\t_, err := q.db.ExecContext(ctx, " + constName + args + ")\n")
356338
} else {
357339
b.WriteString("\t\t_, err := q.db.ExecContext(ctx, " + constName + ")\n")
358340
}
@@ -369,16 +351,27 @@ func genExec(name, constName string, q *pb.Query) string {
369351
return b.String()
370352
}
371353

372-
func argList(params []*pb.Parameter) string {
354+
// namedArgList returns the args for QueryRowContext/QueryContext/ExecContext as sql.Named(..., arg.Field).
355+
// Format: ",\n\t\t\tsql.Named(\"id\", arg.ID),\n\t\t\tsql.Named(\"name\", arg.Name),\n\t\t"
356+
func namedArgList(params []*pb.Parameter) string {
357+
if len(params) == 0 {
358+
return ""
359+
}
373360
var parts []string
374361
for _, p := range params {
375362
col := p.GetColumn()
376363
if col == nil {
377364
continue
378365
}
379-
parts = append(parts, "arg."+toGoField(col.GetName()))
366+
paramName := col.GetName()
367+
field := "arg." + toGoField(paramName)
368+
parts = append(parts, "sql.Named(\""+paramName+"\", "+field+")")
380369
}
381-
return strings.Join(parts, ", ")
370+
if len(parts) == 0 {
371+
return ""
372+
}
373+
// Indent: row/rows/exec is 2 tabs; sql.Named continuation 3 tabs; closing paren 2 tabs.
374+
return ",\n\t\t\t" + strings.Join(parts, ",\n\t\t\t") + ",\n\t\t"
382375
}
383376

384377
func toGoStruct(s string) string {

0 commit comments

Comments
 (0)