Skip to content

Commit 516c58c

Browse files
committed
Add savepoint support
Signed-off-by: Bryan Frimin <bryan@getprobo.com>
1 parent 9380419 commit 516c58c

File tree

2 files changed

+111
-17
lines changed

2 files changed

+111
-17
lines changed

migrator/migrator.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,16 @@ func (m *Migrator) Run(ctx context.Context, dirname string) error {
7171
return nil
7272
}
7373

74+
independentCtx := pg.WithoutTx(ctx)
75+
7476
err := m.pg.WithAdvisoryLock(
7577
ctx,
7678
MigrationAdvisoryLock,
7779
func(conn pg.Conn) error {
7880
err := m.pg.WithConn(
79-
ctx,
80-
func(conn pg.Conn) error {
81-
return createIfNotExistVersionsTable(ctx, conn)
81+
independentCtx,
82+
func(connCtx context.Context, conn pg.Conn) error {
83+
return createIfNotExistVersionsTable(connCtx, conn)
8284
},
8385
)
8486
if err != nil {
@@ -98,9 +100,9 @@ func (m *Migrator) Run(ctx context.Context, dirname string) error {
98100
m.logger.Info("applying migration", log.String("version", migration.Version))
99101

100102
err := m.pg.WithTx(
101-
ctx,
102-
func(conn pg.Conn) error {
103-
return migration.Apply(ctx, conn)
103+
independentCtx,
104+
func(txCtx context.Context, conn pg.Conn) error {
105+
return migration.Apply(txCtx, conn)
104106
},
105107
)
106108
if err != nil {

pg/client.go

Lines changed: 103 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,25 @@ type (
6666
registerer prometheus.Registerer
6767
}
6868

69-
ExecFunc func(Conn) error
69+
ExecFunc func(context.Context, Conn) error
7070

7171
AdvisoryLock = uint32
72+
73+
txKey struct{}
7274
)
7375

76+
func txFromContext(ctx context.Context) pgx.Tx {
77+
tx, _ := ctx.Value(txKey{}).(pgx.Tx)
78+
return tx
79+
}
80+
81+
// WithoutTx returns a copy of ctx with the active transaction
82+
// removed. Use this when a nested call should start an independent
83+
// transaction instead of reusing the parent as a savepoint.
84+
func WithoutTx(ctx context.Context) context.Context {
85+
return context.WithValue(ctx, txKey{}, nil)
86+
}
87+
7488
const (
7589
BaseAdvisoryLockId uint32 = 42
7690
)
@@ -301,7 +315,7 @@ func (c *Client) WithConn(
301315
}
302316
defer conn.Release()
303317

304-
if err := exec(conn); err != nil {
318+
if err := exec(ctx, conn); err != nil {
305319
if rootSpan.IsRecording() {
306320
recordError(span, err)
307321
}
@@ -317,12 +331,27 @@ func (c *Client) WithConn(
317331
// returns an error, the transaction is rolled back; otherwise, it
318332
// commits.
319333
//
334+
// When called within an existing transaction (i.e. nested WithTx),
335+
// a savepoint is created instead of a new transaction. If `exec`
336+
// returns an error, only the savepoint is rolled back; the outer
337+
// transaction remains active and can still be committed.
338+
//
320339
// Example:
321340
//
322-
// err := client.WithTx(ctx, func(tx pg.Conn) error {
323-
// if _, err := tx.Exec(ctx, "DELETE FROM users WHERE id = $1 ", id); err != nil {
341+
// err := client.WithTx(ctx, func(ctx context.Context, tx pg.Conn) error {
342+
// if _, err := tx.Exec(ctx, "DELETE FROM users WHERE id = $1", id); err != nil {
324343
// return err
325344
// }
345+
//
346+
// // Nested call creates a savepoint; failure here does not
347+
// // roll back the DELETE above.
348+
// if err := client.WithTx(ctx, func(ctx context.Context, tx pg.Conn) error {
349+
// _, err := tx.Exec(ctx, "INSERT INTO audit_log (...) VALUES (...)")
350+
// return err
351+
// }); err != nil {
352+
// log.Warn("audit failed, continuing", "err", err)
353+
// }
354+
//
326355
// return nil
327356
// })
328357
//
@@ -346,10 +375,71 @@ func (c *Client) WithTx(
346375
defer span.End()
347376
}
348377

378+
if parentTx := txFromContext(ctx); parentTx != nil {
379+
if span != nil {
380+
span.SetAttributes(attribute.Bool("db.savepoint", true))
381+
}
382+
383+
return c.withSavepoint(ctx, span, parentTx, exec)
384+
}
385+
386+
return c.withNewTx(ctx, span, exec)
387+
}
388+
389+
func (c *Client) withSavepoint(
390+
ctx context.Context,
391+
span trace.Span,
392+
parentTx pgx.Tx,
393+
exec ExecFunc,
394+
) error {
395+
tx, err := parentTx.Begin(ctx)
396+
if err != nil {
397+
err := fmt.Errorf("cannot create savepoint: %w", err)
398+
if span != nil {
399+
recordError(span, err)
400+
}
401+
402+
return err
403+
}
404+
405+
txCtx := context.WithValue(ctx, txKey{}, tx)
406+
407+
if err := exec(txCtx, tx); err != nil {
408+
if err2 := tx.Rollback(ctx); err2 != nil {
409+
err = errors.Join(
410+
err,
411+
fmt.Errorf("cannot rollback savepoint: %w", err2),
412+
)
413+
}
414+
415+
if span != nil {
416+
recordError(span, err)
417+
}
418+
419+
return err
420+
}
421+
422+
if err := tx.Commit(ctx); err != nil {
423+
err := fmt.Errorf("cannot release savepoint: %w", err)
424+
if span != nil {
425+
recordError(span, err)
426+
}
427+
428+
return err
429+
}
430+
431+
return nil
432+
}
433+
434+
func (c *Client) withNewTx(
435+
ctx context.Context,
436+
span trace.Span,
437+
exec ExecFunc,
438+
) error {
349439
conn, err := c.pool.Acquire(ctx)
350440
if err != nil {
351441
err := fmt.Errorf("cannot acquire connection: %w", err)
352-
if rootSpan.IsRecording() {
442+
if span != nil {
353443
recordError(span, err)
354444
}
355445

@@ -360,22 +450,24 @@ func (c *Client) WithTx(
360450
tx, err := conn.Begin(ctx)
361451
if err != nil {
362452
err := fmt.Errorf("cannot begin transaction: %w", err)
363-
if rootSpan.IsRecording() {
453+
if span != nil {
364454
recordError(span, err)
365455
}
366456

367457
return err
368458
}
369459

370-
if err := exec(tx); err != nil {
460+
txCtx := context.WithValue(ctx, txKey{}, tx)
461+
462+
if err := exec(txCtx, tx); err != nil {
371463
if err2 := tx.Rollback(ctx); err2 != nil {
372464
err = errors.Join(
373465
err,
374466
fmt.Errorf("cannot rollback transaction: %w", err2),
375467
)
376468
}
377469

378-
if rootSpan.IsRecording() {
470+
if span != nil {
379471
recordError(span, err)
380472
}
381473

@@ -384,7 +476,7 @@ func (c *Client) WithTx(
384476

385477
if err := tx.Commit(ctx); err != nil {
386478
err := fmt.Errorf("cannot commit transaction: %w", err)
387-
if rootSpan.IsRecording() {
479+
if span != nil {
388480
recordError(span, err)
389481
}
390482

@@ -418,9 +510,9 @@ func (c *Client) WithAdvisoryLock(
418510

419511
return c.WithTx(
420512
ctx,
421-
func(conn Conn) error {
513+
func(txCtx context.Context, conn Conn) error {
422514
q := "SELECT pg_advisory_xact_lock($1, $2)"
423-
_, err := conn.Exec(ctx, q, BaseAdvisoryLockId, id)
515+
_, err := conn.Exec(txCtx, q, BaseAdvisoryLockId, id)
424516
if err != nil {
425517
err = fmt.Errorf("cannot acquire advisory lock: %w", err)
426518
if rootSpan.IsRecording() {

0 commit comments

Comments
 (0)