@@ -66,25 +66,11 @@ type (
6666 registerer prometheus.Registerer
6767 }
6868
69- ExecFunc func (context.Context , Conn ) error
69+ ExecFunc [ Q Querier ] func (context.Context , Q ) error
7070
7171 AdvisoryLock = uint32
72-
73- txKey struct {}
7472)
7573
76- func txFromContext (ctx context.Context ) pgx.Tx {
77- tx , _ := ctx .Value (txKey {}).(pgx.Tx )
78- return tx
79- }
80-
81- // WithoutTx returns a copy of ctx with the active transaction
82- // removed. Use this when a nested call should start an independent
83- // transaction instead of reusing the parent as a savepoint.
84- func WithoutTx (ctx context.Context ) context.Context {
85- return context .WithValue (ctx , txKey {}, nil )
86- }
87-
8874const (
8975 BaseAdvisoryLockId uint32 = 42
9076)
@@ -279,7 +265,7 @@ func (c *Client) Close() {
279265//
280266// Example:
281267//
282- // err := client.WithConn(ctx, func(conn pg.Conn ) error {
268+ // err := client.WithConn(ctx, func(ctx context.Context, conn pg.Querier ) error {
283269// _, err := conn.Exec(ctx, "SELECT * FROM users")
284270// return err
285271// })
@@ -288,7 +274,7 @@ func (c *Client) Close() {
288274// and logs any errors.
289275func (c * Client ) WithConn (
290276 ctx context.Context ,
291- exec ExecFunc ,
277+ exec ExecFunc [ Querier ] ,
292278) error {
293279 var (
294280 rootSpan = trace .SpanFromContext (ctx )
@@ -326,27 +312,24 @@ func (c *Client) WithConn(
326312 return nil
327313}
328314
329- // WithTx executes the given ExecFunc within a transaction. This
330- // method begins a transaction, executing `exec` within it. If `exec`
331- // returns an error, the transaction is rolled back; otherwise, it
332- // commits.
315+ // WithTx executes the given ExecFunc within a new transaction. This
316+ // method acquires a connection, begins a transaction, and executes
317+ // exec within it. If exec returns an error, the transaction is rolled
318+ // back; otherwise, it commits.
333319//
334- // When called within an existing transaction (i.e. nested WithTx),
335- // a savepoint is created instead of a new transaction. If `exec`
336- // returns an error, only the savepoint is rolled back; the outer
337- // transaction remains active and can still be committed.
320+ // Use Tx.Savepoint inside the callback to create savepoints within
321+ // the transaction.
338322//
339323// Example:
340324//
341- // err := client.WithTx(ctx, func(ctx context.Context, tx pg.Conn ) error {
325+ // err := client.WithTx(ctx, func(ctx context.Context, tx pg.Tx ) error {
342326// if _, err := tx.Exec(ctx, "DELETE FROM users WHERE id = $1", id); err != nil {
343327// return err
344328// }
345329//
346- // // Nested call creates a savepoint; failure here does not
347- // // roll back the DELETE above.
348- // if err := client.WithTx(ctx, func(ctx context.Context, tx pg.Conn) error {
349- // _, err := tx.Exec(ctx, "INSERT INTO audit_log (...) VALUES (...)")
330+ // // Savepoint failure does not roll back the DELETE above.
331+ // if err := tx.Savepoint(ctx, func(ctx context.Context, q pg.Querier) error {
332+ // _, err := q.Exec(ctx, "INSERT INTO audit_log (...) VALUES (...)")
350333// return err
351334// }); err != nil {
352335// log.Warn("audit failed, continuing", "err", err)
@@ -359,7 +342,7 @@ func (c *Client) WithConn(
359342// and logs any errors.
360343func (c * Client ) WithTx (
361344 ctx context.Context ,
362- exec ExecFunc ,
345+ exec ExecFunc [ Tx ] ,
363346) error {
364347 var (
365348 rootSpan = trace .SpanFromContext (ctx )
@@ -375,67 +358,6 @@ func (c *Client) WithTx(
375358 defer span .End ()
376359 }
377360
378- if parentTx := txFromContext (ctx ); parentTx != nil {
379- if span != nil {
380- span .SetAttributes (attribute .Bool ("db.savepoint" , true ))
381- }
382-
383- return c .withSavepoint (ctx , span , parentTx , exec )
384- }
385-
386- return c .withNewTx (ctx , span , exec )
387- }
388-
389- func (c * Client ) withSavepoint (
390- ctx context.Context ,
391- span trace.Span ,
392- parentTx pgx.Tx ,
393- exec ExecFunc ,
394- ) error {
395- tx , err := parentTx .Begin (ctx )
396- if err != nil {
397- err := fmt .Errorf ("cannot create savepoint: %w" , err )
398- if span != nil {
399- recordError (span , err )
400- }
401-
402- return err
403- }
404-
405- txCtx := context .WithValue (ctx , txKey {}, tx )
406-
407- if err := exec (txCtx , tx ); err != nil {
408- if err2 := tx .Rollback (ctx ); err2 != nil {
409- err = errors .Join (
410- err ,
411- fmt .Errorf ("cannot rollback savepoint: %w" , err2 ),
412- )
413- }
414-
415- if span != nil {
416- recordError (span , err )
417- }
418-
419- return err
420- }
421-
422- if err := tx .Commit (ctx ); err != nil {
423- err := fmt .Errorf ("cannot release savepoint: %w" , err )
424- if span != nil {
425- recordError (span , err )
426- }
427-
428- return err
429- }
430-
431- return nil
432- }
433-
434- func (c * Client ) withNewTx (
435- ctx context.Context ,
436- span trace.Span ,
437- exec ExecFunc ,
438- ) error {
439361 conn , err := c .pool .Acquire (ctx )
440362 if err != nil {
441363 err := fmt .Errorf ("cannot acquire connection: %w" , err )
@@ -447,7 +369,7 @@ func (c *Client) withNewTx(
447369 }
448370 defer conn .Release ()
449371
450- tx , err := conn .Begin (ctx )
372+ innerTx , err := conn .Begin (ctx )
451373 if err != nil {
452374 err := fmt .Errorf ("cannot begin transaction: %w" , err )
453375 if span != nil {
@@ -457,10 +379,10 @@ func (c *Client) withNewTx(
457379 return err
458380 }
459381
460- txCtx := context . WithValue ( ctx , txKey {}, tx )
382+ tx := & pgxTx { inner : innerTx , tracer : c . tracer }
461383
462- if err := exec (txCtx , tx ); err != nil {
463- if err2 := tx .Rollback (ctx ); err2 != nil {
384+ if err := exec (ctx , tx ); err != nil {
385+ if err2 := innerTx .Rollback (ctx ); err2 != nil {
464386 err = errors .Join (
465387 err ,
466388 fmt .Errorf ("cannot rollback transaction: %w" , err2 ),
@@ -474,7 +396,7 @@ func (c *Client) withNewTx(
474396 return err
475397 }
476398
477- if err := tx .Commit (ctx ); err != nil {
399+ if err := innerTx .Commit (ctx ); err != nil {
478400 err := fmt .Errorf ("cannot commit transaction: %w" , err )
479401 if span != nil {
480402 recordError (span , err )
@@ -489,7 +411,7 @@ func (c *Client) withNewTx(
489411func (c * Client ) WithAdvisoryLock (
490412 ctx context.Context ,
491413 id AdvisoryLock ,
492- f func (Conn ) error ,
414+ f func (Querier ) error ,
493415) error {
494416 var (
495417 rootSpan = trace .SpanFromContext (ctx )
@@ -510,9 +432,9 @@ func (c *Client) WithAdvisoryLock(
510432
511433 return c .WithTx (
512434 ctx ,
513- func (txCtx context.Context , conn Conn ) error {
435+ func (txCtx context.Context , tx Tx ) error {
514436 q := "SELECT pg_advisory_xact_lock($1, $2)"
515- _ , err := conn .Exec (txCtx , q , BaseAdvisoryLockId , id )
437+ _ , err := tx .Exec (txCtx , q , BaseAdvisoryLockId , id )
516438 if err != nil {
517439 err = fmt .Errorf ("cannot acquire advisory lock: %w" , err )
518440 if rootSpan .IsRecording () {
@@ -523,7 +445,7 @@ func (c *Client) WithAdvisoryLock(
523445 return err
524446 }
525447
526- err = f (conn )
448+ err = f (tx )
527449 if err != nil {
528450 if rootSpan .IsRecording () {
529451 span .SetStatus (codes .Error , err .Error ())
0 commit comments