@@ -66,11 +66,25 @@ type (
6666 registerer prometheus.Registerer
6767 }
6868
69- ExecFunc func (Conn ) error
69+ ExecFunc func (context. Context , Conn ) error
7070
7171 AdvisoryLock = uint32
72+
73+ txKey struct {}
7274)
7375
76+ func txFromContext (ctx context.Context ) pgx.Tx {
77+ tx , _ := ctx .Value (txKey {}).(pgx.Tx )
78+ return tx
79+ }
80+
81+ // WithoutTx returns a copy of ctx with the active transaction
82+ // removed. Use this when a nested call should start an independent
83+ // transaction instead of reusing the parent as a savepoint.
84+ func WithoutTx (ctx context.Context ) context.Context {
85+ return context .WithValue (ctx , txKey {}, nil )
86+ }
87+
7488const (
7589 BaseAdvisoryLockId uint32 = 42
7690)
@@ -301,7 +315,7 @@ func (c *Client) WithConn(
301315 }
302316 defer conn .Release ()
303317
304- if err := exec (conn ); err != nil {
318+ if err := exec (ctx , conn ); err != nil {
305319 if rootSpan .IsRecording () {
306320 recordError (span , err )
307321 }
@@ -317,12 +331,27 @@ func (c *Client) WithConn(
317331// returns an error, the transaction is rolled back; otherwise, it
318332// commits.
319333//
334+ // When called within an existing transaction (i.e. nested WithTx),
335+ // a savepoint is created instead of a new transaction. If `exec`
336+ // returns an error, only the savepoint is rolled back; the outer
337+ // transaction remains active and can still be committed.
338+ //
320339// Example:
321340//
322- // err := client.WithTx(ctx, func(tx pg.Conn) error {
323- // if _, err := tx.Exec(ctx, "DELETE FROM users WHERE id = $1 ", id); err != nil {
341+ // err := client.WithTx(ctx, func(ctx context.Context, tx pg.Conn) error {
342+ // if _, err := tx.Exec(ctx, "DELETE FROM users WHERE id = $1", id); err != nil {
324343// return err
325344// }
345+ //
346+ // // Nested call creates a savepoint; failure here does not
347+ // // roll back the DELETE above.
348+ // if err := client.WithTx(ctx, func(ctx context.Context, tx pg.Conn) error {
349+ // _, err := tx.Exec(ctx, "INSERT INTO audit_log (...) VALUES (...)")
350+ // return err
351+ // }); err != nil {
352+ // log.Warn("audit failed, continuing", "err", err)
353+ // }
354+ //
326355// return nil
327356// })
328357//
@@ -346,10 +375,71 @@ func (c *Client) WithTx(
346375 defer span .End ()
347376 }
348377
378+ if parentTx := txFromContext (ctx ); parentTx != nil {
379+ if span != nil {
380+ span .SetAttributes (attribute .Bool ("db.savepoint" , true ))
381+ }
382+
383+ return c .withSavepoint (ctx , span , parentTx , exec )
384+ }
385+
386+ return c .withNewTx (ctx , span , exec )
387+ }
388+
389+ func (c * Client ) withSavepoint (
390+ ctx context.Context ,
391+ span trace.Span ,
392+ parentTx pgx.Tx ,
393+ exec ExecFunc ,
394+ ) error {
395+ tx , err := parentTx .Begin (ctx )
396+ if err != nil {
397+ err := fmt .Errorf ("cannot create savepoint: %w" , err )
398+ if span != nil {
399+ recordError (span , err )
400+ }
401+
402+ return err
403+ }
404+
405+ txCtx := context .WithValue (ctx , txKey {}, tx )
406+
407+ if err := exec (txCtx , tx ); err != nil {
408+ if err2 := tx .Rollback (ctx ); err2 != nil {
409+ err = errors .Join (
410+ err ,
411+ fmt .Errorf ("cannot rollback savepoint: %w" , err2 ),
412+ )
413+ }
414+
415+ if span != nil {
416+ recordError (span , err )
417+ }
418+
419+ return err
420+ }
421+
422+ if err := tx .Commit (ctx ); err != nil {
423+ err := fmt .Errorf ("cannot release savepoint: %w" , err )
424+ if span != nil {
425+ recordError (span , err )
426+ }
427+
428+ return err
429+ }
430+
431+ return nil
432+ }
433+
434+ func (c * Client ) withNewTx (
435+ ctx context.Context ,
436+ span trace.Span ,
437+ exec ExecFunc ,
438+ ) error {
349439 conn , err := c .pool .Acquire (ctx )
350440 if err != nil {
351441 err := fmt .Errorf ("cannot acquire connection: %w" , err )
352- if rootSpan . IsRecording () {
442+ if span != nil {
353443 recordError (span , err )
354444 }
355445
@@ -360,22 +450,24 @@ func (c *Client) WithTx(
360450 tx , err := conn .Begin (ctx )
361451 if err != nil {
362452 err := fmt .Errorf ("cannot begin transaction: %w" , err )
363- if rootSpan . IsRecording () {
453+ if span != nil {
364454 recordError (span , err )
365455 }
366456
367457 return err
368458 }
369459
370- if err := exec (tx ); err != nil {
460+ txCtx := context .WithValue (ctx , txKey {}, tx )
461+
462+ if err := exec (txCtx , tx ); err != nil {
371463 if err2 := tx .Rollback (ctx ); err2 != nil {
372464 err = errors .Join (
373465 err ,
374466 fmt .Errorf ("cannot rollback transaction: %w" , err2 ),
375467 )
376468 }
377469
378- if rootSpan . IsRecording () {
470+ if span != nil {
379471 recordError (span , err )
380472 }
381473
@@ -384,7 +476,7 @@ func (c *Client) WithTx(
384476
385477 if err := tx .Commit (ctx ); err != nil {
386478 err := fmt .Errorf ("cannot commit transaction: %w" , err )
387- if rootSpan . IsRecording () {
479+ if span != nil {
388480 recordError (span , err )
389481 }
390482
@@ -418,9 +510,9 @@ func (c *Client) WithAdvisoryLock(
418510
419511 return c .WithTx (
420512 ctx ,
421- func (conn Conn ) error {
513+ func (txCtx context. Context , conn Conn ) error {
422514 q := "SELECT pg_advisory_xact_lock($1, $2)"
423- _ , err := conn .Exec (ctx , q , BaseAdvisoryLockId , id )
515+ _ , err := conn .Exec (txCtx , q , BaseAdvisoryLockId , id )
424516 if err != nil {
425517 err = fmt .Errorf ("cannot acquire advisory lock: %w" , err )
426518 if rootSpan .IsRecording () {
0 commit comments