Skip to content

Commit 82b212c

Browse files
authored
Merge pull request #2536 from Tubelight30/feat/struct-args-named-query
feat: Add StructArgs and StrictStructArgs for named query arguments
2 parents d2ed749 + 5a8eac9 commit 82b212c

2 files changed

Lines changed: 303 additions & 0 deletions

File tree

named_args.go

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package pgx
33
import (
44
"context"
55
"fmt"
6+
"reflect"
67
"strconv"
78
"strings"
89
"unicode/utf8"
@@ -18,6 +19,7 @@ import (
1819
//
1920
// Named placeholders are case sensitive and must start with a letter or underscore. Subsequent characters can be
2021
// letters, numbers, or underscores.
22+
2123
type NamedArgs map[string]any
2224

2325
// RewriteQuery implements the QueryRewriter interface.
@@ -34,6 +36,119 @@ func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql str
3436
return rewriteQuery(sna, sql, true)
3537
}
3638

39+
type errorQueryRewriter struct {
40+
err error
41+
}
42+
43+
func (r errorQueryRewriter) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
44+
return "", nil, r.err
45+
}
46+
47+
// StructArgs converts exported fields of a struct into a QueryRewriter so it can
48+
// be used as the first argument to a query method (e.g. "where id=@id").
49+
//
50+
// Field names are taken from the `db` struct tag if present. Tag values may
51+
// include comma-separated options (e.g. `db:"id,omitempty"`). A `db:"-"` field is
52+
// ignored. If no `db` tag is present, the Go field name is used.
53+
//
54+
// sa may be a struct or a pointer to a struct.
55+
func StructArgs(sa any) QueryRewriter {
56+
args, err := structArgs(sa)
57+
if err != nil {
58+
return errorQueryRewriter{err: err}
59+
}
60+
return NamedArgs(args)
61+
}
62+
63+
// StrictStructArgs is like StructArgs but uses StrictNamedArgs rewriting
64+
// semantics (i.e. errors if the SQL query references missing arguments or if
65+
// extra arguments are provided).
66+
func StrictStructArgs(sa any) QueryRewriter {
67+
args, err := structArgs(sa)
68+
if err != nil {
69+
return errorQueryRewriter{err: err}
70+
}
71+
return StrictNamedArgs(args)
72+
}
73+
74+
func structArgs(sa any) (map[string]any, error) {
75+
if sa == nil {
76+
return nil, fmt.Errorf("StructArgs requires a struct or pointer to struct, got nil")
77+
}
78+
79+
v := reflect.ValueOf(sa)
80+
t := v.Type()
81+
82+
if t.Kind() == reflect.Pointer {
83+
if v.IsNil() {
84+
return nil, fmt.Errorf("StructArgs requires a non-nil pointer to struct")
85+
}
86+
v = v.Elem()
87+
t = v.Type()
88+
}
89+
90+
if t.Kind() != reflect.Struct {
91+
return nil, fmt.Errorf("StructArgs requires a struct or pointer to struct, got %s", t)
92+
}
93+
94+
out := make(map[string]any, t.NumField())
95+
for i := 0; i < t.NumField(); i++ {
96+
sf := t.Field(i)
97+
98+
// Ignore unexported fields.
99+
if sf.PkgPath != "" {
100+
continue
101+
}
102+
103+
key, ok, err := dbTagKey(sf)
104+
if err != nil {
105+
return nil, err
106+
}
107+
if !ok {
108+
continue
109+
}
110+
111+
if _, exists := out[key]; exists {
112+
return nil, fmt.Errorf("duplicate StructArgs key %q", key)
113+
}
114+
115+
out[key] = v.Field(i).Interface()
116+
}
117+
118+
return out, nil
119+
}
120+
121+
// dbTagKey derives the named-argument key for a struct field. Tag parsing matches
122+
// RowToStructByName* in rows.go (structTagKey, Lookup, comma options, db:"-").
123+
// Anonymous embedded structs are skipped without flattening (unlike row scanning).
124+
func dbTagKey(sf reflect.StructField) (key string, ok bool, err error) {
125+
if sf.Anonymous {
126+
ft := sf.Type
127+
if ft.Kind() == reflect.Pointer {
128+
ft = ft.Elem()
129+
}
130+
if ft.Kind() == reflect.Struct {
131+
return "", false, nil
132+
}
133+
}
134+
135+
dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey)
136+
if dbTagPresent {
137+
dbTag, _, _ = strings.Cut(dbTag, ",")
138+
}
139+
if dbTag == "-" {
140+
return "", false, nil
141+
}
142+
if dbTagPresent {
143+
if dbTag == "" {
144+
return "", false, fmt.Errorf("field %s has empty `%s` tag", sf.Name, structTagKey)
145+
}
146+
return dbTag, true, nil
147+
}
148+
149+
return sf.Name, true, nil
150+
}
151+
37152
type namedArg string
38153

39154
type sqlLexer struct {

named_args_test.go

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,191 @@ func TestStrictNamedArgsRewriteQuery(t *testing.T) {
160160
}
161161
}
162162
}
163+
164+
func TestStructArgs(t *testing.T) {
165+
t.Parallel()
166+
167+
for _, tt := range []struct {
168+
name string
169+
input any
170+
sql string
171+
expectedSQL string
172+
expectedArgs []any
173+
expectError bool
174+
}{
175+
{
176+
name: "basic",
177+
input: struct {
178+
ID int `db:"id"`
179+
Name string `db:"name,omitempty"`
180+
Skip string `db:"-"`
181+
}{ID: 42, Name: "x", Skip: "ignored"},
182+
sql: "select * from t where id=@id and name=@name",
183+
expectedSQL: "select * from t where id=$1 and name=$2",
184+
expectedArgs: []any{42, "x"},
185+
},
186+
{
187+
name: "pointer",
188+
input: func() any {
189+
type S struct {
190+
ID int `db:"id"`
191+
}
192+
return &S{ID: 7}
193+
}(),
194+
sql: "select * from t where id=@id",
195+
expectedSQL: "select * from t where id=$1",
196+
expectedArgs: []any{7},
197+
},
198+
{
199+
name: "unexported fields omitted (missing placeholders become nil)",
200+
input: struct {
201+
id int `db:"id"`
202+
ID int `db:"ID"`
203+
}{id: 1, ID: 2},
204+
sql: "select * from t where ID=@ID and id=@id",
205+
expectedSQL: "select * from t where ID=$1 and id=$2",
206+
expectedArgs: []any{2, nil},
207+
},
208+
{
209+
name: "missing db tag falls back to field name",
210+
input: struct {
211+
ID int
212+
}{ID: 9},
213+
sql: "select * from t where ID=@ID",
214+
expectedSQL: "select * from t where ID=$1",
215+
expectedArgs: []any{9},
216+
},
217+
{
218+
name: "duplicate keys error",
219+
input: struct {
220+
A int `db:"x"`
221+
B int `db:"x"`
222+
}{A: 1, B: 2},
223+
sql: "select * from t where x=@x",
224+
expectError: true,
225+
},
226+
{
227+
name: "nil pointer returns error",
228+
input: func() any {
229+
type S struct {
230+
ID int `db:"id"`
231+
}
232+
var s *S
233+
return s
234+
}(),
235+
sql: "select * from t where id=@id",
236+
expectError: true,
237+
},
238+
{
239+
name: "non struct returns error",
240+
input: 42,
241+
sql: "select * from t where id=@id",
242+
expectError: true,
243+
},
244+
{
245+
name: "nil input returns error",
246+
input: nil,
247+
sql: "select * from t where id=@id",
248+
expectError: true,
249+
},
250+
} {
251+
t.Run(tt.name, func(t *testing.T) {
252+
t.Parallel()
253+
254+
qr := pgx.StructArgs(tt.input)
255+
sql, args, err := qr.RewriteQuery(context.Background(), nil, tt.sql, nil)
256+
if tt.expectError {
257+
require.Error(t, err)
258+
return
259+
}
260+
261+
require.NoError(t, err)
262+
assert.Equal(t, tt.expectedSQL, sql)
263+
assert.EqualValues(t, tt.expectedArgs, args)
264+
})
265+
}
266+
}
267+
268+
func TestStrictStructArgs(t *testing.T) {
269+
t.Parallel()
270+
271+
type MyInt int
272+
273+
for _, tt := range []struct {
274+
name string
275+
input any
276+
sql string
277+
expectedSQL string
278+
expectedArgs []any
279+
expectError bool
280+
}{
281+
{
282+
name: "fallback to field name without db tag",
283+
input: struct {
284+
ID int
285+
}{ID: 1},
286+
sql: "select * from t where ID=@ID",
287+
expectedSQL: "select * from t where ID=$1",
288+
expectedArgs: []any{1},
289+
},
290+
{
291+
name: "empty db tag errors",
292+
input: struct {
293+
ID int `db:","`
294+
}{ID: 1},
295+
sql: "select * from t where ID=@ID",
296+
expectError: true,
297+
},
298+
{
299+
name: "duplicate keys error",
300+
input: struct {
301+
A int `db:"x"`
302+
B int `db:"x"`
303+
}{A: 1, B: 2},
304+
sql: "select * from t where x=@x",
305+
expectError: true,
306+
},
307+
{
308+
name: "skips anonymous embedded structs without flattening",
309+
input: func() any {
310+
type Embedded struct {
311+
ID int `db:"id"`
312+
}
313+
type S struct {
314+
Embedded
315+
Name string `db:"name"`
316+
}
317+
return S{Embedded: Embedded{ID: 1}, Name: "x"}
318+
}(),
319+
sql: "select * from t where name=@name and id=@id",
320+
expectError: true,
321+
},
322+
{
323+
name: "anonymous embedded non-struct still requires tag in strict mode",
324+
input: func() any {
325+
type S struct {
326+
MyInt
327+
}
328+
return S{MyInt: 1}
329+
}(),
330+
sql: "select * from t where MyInt=@MyInt",
331+
expectedSQL: "select * from t where MyInt=$1",
332+
expectedArgs: []any{MyInt(1)},
333+
},
334+
} {
335+
t.Run(tt.name, func(t *testing.T) {
336+
t.Parallel()
337+
338+
qr := pgx.StrictStructArgs(tt.input)
339+
sql, args, err := qr.RewriteQuery(context.Background(), nil, tt.sql, nil)
340+
if tt.expectError {
341+
require.Error(t, err)
342+
return
343+
}
344+
345+
require.NoError(t, err)
346+
assert.Equal(t, tt.expectedSQL, sql)
347+
assert.EqualValues(t, tt.expectedArgs, args)
348+
})
349+
}
350+
}

0 commit comments

Comments
 (0)