@@ -160,3 +160,191 @@ func TestStrictNamedArgsRewriteQuery(t *testing.T) {
160160 }
161161 }
162162}
163+
164+ func TestStructArgs (t * testing.T ) {
165+ t .Parallel ()
166+
167+ for _ , tt := range []struct {
168+ name string
169+ input any
170+ sql string
171+ expectedSQL string
172+ expectedArgs []any
173+ expectError bool
174+ }{
175+ {
176+ name : "basic" ,
177+ input : struct {
178+ ID int `db:"id"`
179+ Name string `db:"name,omitempty"`
180+ Skip string `db:"-"`
181+ }{ID : 42 , Name : "x" , Skip : "ignored" },
182+ sql : "select * from t where id=@id and name=@name" ,
183+ expectedSQL : "select * from t where id=$1 and name=$2" ,
184+ expectedArgs : []any {42 , "x" },
185+ },
186+ {
187+ name : "pointer" ,
188+ input : func () any {
189+ type S struct {
190+ ID int `db:"id"`
191+ }
192+ return & S {ID : 7 }
193+ }(),
194+ sql : "select * from t where id=@id" ,
195+ expectedSQL : "select * from t where id=$1" ,
196+ expectedArgs : []any {7 },
197+ },
198+ {
199+ name : "unexported fields omitted (missing placeholders become nil)" ,
200+ input : struct {
201+ id int `db:"id"`
202+ ID int `db:"ID"`
203+ }{id : 1 , ID : 2 },
204+ sql : "select * from t where ID=@ID and id=@id" ,
205+ expectedSQL : "select * from t where ID=$1 and id=$2" ,
206+ expectedArgs : []any {2 , nil },
207+ },
208+ {
209+ name : "missing db tag falls back to field name" ,
210+ input : struct {
211+ ID int
212+ }{ID : 9 },
213+ sql : "select * from t where ID=@ID" ,
214+ expectedSQL : "select * from t where ID=$1" ,
215+ expectedArgs : []any {9 },
216+ },
217+ {
218+ name : "duplicate keys error" ,
219+ input : struct {
220+ A int `db:"x"`
221+ B int `db:"x"`
222+ }{A : 1 , B : 2 },
223+ sql : "select * from t where x=@x" ,
224+ expectError : true ,
225+ },
226+ {
227+ name : "nil pointer returns error" ,
228+ input : func () any {
229+ type S struct {
230+ ID int `db:"id"`
231+ }
232+ var s * S
233+ return s
234+ }(),
235+ sql : "select * from t where id=@id" ,
236+ expectError : true ,
237+ },
238+ {
239+ name : "non struct returns error" ,
240+ input : 42 ,
241+ sql : "select * from t where id=@id" ,
242+ expectError : true ,
243+ },
244+ {
245+ name : "nil input returns error" ,
246+ input : nil ,
247+ sql : "select * from t where id=@id" ,
248+ expectError : true ,
249+ },
250+ } {
251+ t .Run (tt .name , func (t * testing.T ) {
252+ t .Parallel ()
253+
254+ qr := pgx .StructArgs (tt .input )
255+ sql , args , err := qr .RewriteQuery (context .Background (), nil , tt .sql , nil )
256+ if tt .expectError {
257+ require .Error (t , err )
258+ return
259+ }
260+
261+ require .NoError (t , err )
262+ assert .Equal (t , tt .expectedSQL , sql )
263+ assert .EqualValues (t , tt .expectedArgs , args )
264+ })
265+ }
266+ }
267+
268+ func TestStrictStructArgs (t * testing.T ) {
269+ t .Parallel ()
270+
271+ type MyInt int
272+
273+ for _ , tt := range []struct {
274+ name string
275+ input any
276+ sql string
277+ expectedSQL string
278+ expectedArgs []any
279+ expectError bool
280+ }{
281+ {
282+ name : "fallback to field name without db tag" ,
283+ input : struct {
284+ ID int
285+ }{ID : 1 },
286+ sql : "select * from t where ID=@ID" ,
287+ expectedSQL : "select * from t where ID=$1" ,
288+ expectedArgs : []any {1 },
289+ },
290+ {
291+ name : "empty db tag errors" ,
292+ input : struct {
293+ ID int `db:","`
294+ }{ID : 1 },
295+ sql : "select * from t where ID=@ID" ,
296+ expectError : true ,
297+ },
298+ {
299+ name : "duplicate keys error" ,
300+ input : struct {
301+ A int `db:"x"`
302+ B int `db:"x"`
303+ }{A : 1 , B : 2 },
304+ sql : "select * from t where x=@x" ,
305+ expectError : true ,
306+ },
307+ {
308+ name : "skips anonymous embedded structs without flattening" ,
309+ input : func () any {
310+ type Embedded struct {
311+ ID int `db:"id"`
312+ }
313+ type S struct {
314+ Embedded
315+ Name string `db:"name"`
316+ }
317+ return S {Embedded : Embedded {ID : 1 }, Name : "x" }
318+ }(),
319+ sql : "select * from t where name=@name and id=@id" ,
320+ expectError : true ,
321+ },
322+ {
323+ name : "anonymous embedded non-struct still requires tag in strict mode" ,
324+ input : func () any {
325+ type S struct {
326+ MyInt
327+ }
328+ return S {MyInt : 1 }
329+ }(),
330+ sql : "select * from t where MyInt=@MyInt" ,
331+ expectedSQL : "select * from t where MyInt=$1" ,
332+ expectedArgs : []any {MyInt (1 )},
333+ },
334+ } {
335+ t .Run (tt .name , func (t * testing.T ) {
336+ t .Parallel ()
337+
338+ qr := pgx .StrictStructArgs (tt .input )
339+ sql , args , err := qr .RewriteQuery (context .Background (), nil , tt .sql , nil )
340+ if tt .expectError {
341+ require .Error (t , err )
342+ return
343+ }
344+
345+ require .NoError (t , err )
346+ assert .Equal (t , tt .expectedSQL , sql )
347+ assert .EqualValues (t , tt .expectedArgs , args )
348+ })
349+ }
350+ }
0 commit comments