Skip to content

Commit 56d742c

Browse files
committed
Rewrite savepoint implementation
Signed-off-by: Bryan Frimin <bryan@getprobo.com>
1 parent f6a7338 commit 56d742c

5 files changed

Lines changed: 485 additions & 254 deletions

File tree

migrator/migrator.go

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,13 @@ func (m *Migrator) Run(ctx context.Context, dirname string) error {
7171
return nil
7272
}
7373

74-
independentCtx := pg.WithoutTx(ctx)
75-
7674
err := m.pg.WithAdvisoryLock(
7775
ctx,
7876
MigrationAdvisoryLock,
79-
func(conn pg.Conn) error {
77+
func(conn pg.Querier) error {
8078
err := m.pg.WithConn(
81-
independentCtx,
82-
func(connCtx context.Context, conn pg.Conn) error {
79+
ctx,
80+
func(connCtx context.Context, conn pg.Querier) error {
8381
return createIfNotExistVersionsTable(connCtx, conn)
8482
},
8583
)
@@ -100,9 +98,9 @@ func (m *Migrator) Run(ctx context.Context, dirname string) error {
10098
m.logger.Info("applying migration", log.String("version", migration.Version))
10199

102100
err := m.pg.WithTx(
103-
independentCtx,
104-
func(txCtx context.Context, conn pg.Conn) error {
105-
return migration.Apply(txCtx, conn)
101+
ctx,
102+
func(txCtx context.Context, tx pg.Tx) error {
103+
return migration.Apply(txCtx, tx)
106104
},
107105
)
108106
if err != nil {
@@ -166,14 +164,14 @@ func (pms *Migrations) LoadFromDir(disk FS, dirname string) error {
166164
return nil
167165
}
168166

169-
func (m *Migration) Apply(ctx context.Context, conn pg.Conn) error {
170-
_, err := conn.Exec(ctx, m.SQL)
167+
func (m *Migration) Apply(ctx context.Context, tx pg.Tx) error {
168+
_, err := tx.Exec(ctx, m.SQL)
171169
if err != nil {
172170
return fmt.Errorf("cannot execute migration: %w", err)
173171
}
174172

175173
q := "INSERT INTO schema_versions (version) VALUES ($1)"
176-
_, err = conn.Exec(ctx, q, m.Version)
174+
_, err = tx.Exec(ctx, q, m.Version)
177175
if err != nil {
178176
return fmt.Errorf("cannot insert schema version: %w", err)
179177
}
@@ -197,7 +195,7 @@ func (m *Migration) LoadFromFile(disk fs.ReadFileFS, filename string) error {
197195
return nil
198196
}
199197

200-
func createIfNotExistVersionsTable(ctx context.Context, conn pg.Conn) error {
198+
func createIfNotExistVersionsTable(ctx context.Context, conn pg.Querier) error {
201199
q := `
202200
CREATE TABLE IF NOT EXISTS schema_versions (
203201
version VARCHAR PRIMARY KEY,
@@ -209,7 +207,7 @@ CREATE TABLE IF NOT EXISTS schema_versions (
209207
return err
210208
}
211209

212-
func loadSchemaVersions(ctx context.Context, conn pg.Conn) (map[string]struct{}, error) {
210+
func loadSchemaVersions(ctx context.Context, conn pg.Querier) (map[string]struct{}, error) {
213211
q := "SELECT version FROM schema_versions"
214212
r, err := conn.Query(ctx, q)
215213
if err != nil {

migrator/migrator_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func newTestPGClient(t *testing.T) *pg.Client {
3333
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
3434
defer cancel()
3535

36-
err = client.WithConn(ctx, func(ctx context.Context, conn pg.Conn) error {
36+
err = client.WithConn(ctx, func(ctx context.Context, conn pg.Querier) error {
3737
_, err := conn.Exec(ctx, "SELECT 1")
3838
return err
3939
})
@@ -50,7 +50,7 @@ func newTestPGClient(t *testing.T) *pg.Client {
5050
func dropTables(t *testing.T, client *pg.Client, tables ...string) {
5151
t.Helper()
5252
ctx := context.Background()
53-
_ = client.WithConn(ctx, func(ctx context.Context, conn pg.Conn) error {
53+
_ = client.WithConn(ctx, func(ctx context.Context, conn pg.Querier) error {
5454
for _, tbl := range tables {
5555
_, _ = conn.Exec(ctx, "DROP TABLE IF EXISTS "+tbl+" CASCADE")
5656
}
@@ -139,7 +139,7 @@ func TestMigrator_Run(t *testing.T) {
139139
m := migrator.NewMigrator(client, disk, logger)
140140
require.NoError(t, m.Run(ctx, "migrations"))
141141

142-
err := client.WithConn(ctx, func(ctx context.Context, conn pg.Conn) error {
142+
err := client.WithConn(ctx, func(ctx context.Context, conn pg.Querier) error {
143143
var exists bool
144144
err := conn.QueryRow(ctx,
145145
"SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name = 'test_mig_users')").
@@ -171,7 +171,7 @@ func TestMigrator_Run(t *testing.T) {
171171
require.NoError(t, m.Run(ctx, "migrations"))
172172
require.NoError(t, m.Run(ctx, "migrations"))
173173

174-
err := client.WithConn(ctx, func(ctx context.Context, conn pg.Conn) error {
174+
err := client.WithConn(ctx, func(ctx context.Context, conn pg.Querier) error {
175175
var count int
176176
err := conn.QueryRow(ctx, "SELECT count(*) FROM schema_versions").Scan(&count)
177177
require.NoError(t, err)
@@ -197,7 +197,7 @@ func TestMigrator_Run(t *testing.T) {
197197
m := migrator.NewMigrator(client, disk, logger)
198198
require.NoError(t, m.Run(ctx, "migrations"))
199199

200-
err := client.WithConn(ctx, func(ctx context.Context, conn pg.Conn) error {
200+
err := client.WithConn(ctx, func(ctx context.Context, conn pg.Querier) error {
201201
var count int
202202
err := conn.QueryRow(ctx, "SELECT count(*) FROM schema_versions").Scan(&count)
203203
require.NoError(t, err)
@@ -254,7 +254,7 @@ func TestMigrator_Run(t *testing.T) {
254254
m2 := migrator.NewMigrator(client, disk2, logger)
255255
require.NoError(t, m2.Run(ctx, "migrations"))
256256

257-
err := client.WithConn(ctx, func(ctx context.Context, conn pg.Conn) error {
257+
err := client.WithConn(ctx, func(ctx context.Context, conn pg.Querier) error {
258258
var count int
259259
err := conn.QueryRow(ctx, "SELECT count(*) FROM schema_versions").Scan(&count)
260260
require.NoError(t, err)

pg/client.go

Lines changed: 23 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -66,25 +66,11 @@ type (
6666
registerer prometheus.Registerer
6767
}
6868

69-
ExecFunc func(context.Context, Conn) error
69+
ExecFunc[Q Querier] func(context.Context, Q) error
7070

7171
AdvisoryLock = uint32
72-
73-
txKey struct{}
7472
)
7573

76-
func txFromContext(ctx context.Context) pgx.Tx {
77-
tx, _ := ctx.Value(txKey{}).(pgx.Tx)
78-
return tx
79-
}
80-
81-
// WithoutTx returns a copy of ctx with the active transaction
82-
// removed. Use this when a nested call should start an independent
83-
// transaction instead of reusing the parent as a savepoint.
84-
func WithoutTx(ctx context.Context) context.Context {
85-
return context.WithValue(ctx, txKey{}, nil)
86-
}
87-
8874
const (
8975
BaseAdvisoryLockId uint32 = 42
9076
)
@@ -279,7 +265,7 @@ func (c *Client) Close() {
279265
//
280266
// Example:
281267
//
282-
// err := client.WithConn(ctx, func(conn pg.Conn) error {
268+
// err := client.WithConn(ctx, func(ctx context.Context, conn pg.Querier) error {
283269
// _, err := conn.Exec(ctx, "SELECT * FROM users")
284270
// return err
285271
// })
@@ -288,7 +274,7 @@ func (c *Client) Close() {
288274
// and logs any errors.
289275
func (c *Client) WithConn(
290276
ctx context.Context,
291-
exec ExecFunc,
277+
exec ExecFunc[Querier],
292278
) error {
293279
var (
294280
rootSpan = trace.SpanFromContext(ctx)
@@ -326,27 +312,24 @@ func (c *Client) WithConn(
326312
return nil
327313
}
328314

329-
// WithTx executes the given ExecFunc within a transaction. This
330-
// method begins a transaction, executing `exec` within it. If `exec`
331-
// returns an error, the transaction is rolled back; otherwise, it
332-
// commits.
315+
// WithTx executes the given ExecFunc within a new transaction. This
316+
// method acquires a connection, begins a transaction, and executes
317+
// exec within it. If exec returns an error, the transaction is rolled
318+
// back; otherwise, it commits.
333319
//
334-
// When called within an existing transaction (i.e. nested WithTx),
335-
// a savepoint is created instead of a new transaction. If `exec`
336-
// returns an error, only the savepoint is rolled back; the outer
337-
// transaction remains active and can still be committed.
320+
// Use Tx.Savepoint inside the callback to create savepoints within
321+
// the transaction.
338322
//
339323
// Example:
340324
//
341-
// err := client.WithTx(ctx, func(ctx context.Context, tx pg.Conn) error {
325+
// err := client.WithTx(ctx, func(ctx context.Context, tx pg.Tx) error {
342326
// if _, err := tx.Exec(ctx, "DELETE FROM users WHERE id = $1", id); err != nil {
343327
// return err
344328
// }
345329
//
346-
// // Nested call creates a savepoint; failure here does not
347-
// // roll back the DELETE above.
348-
// if err := client.WithTx(ctx, func(ctx context.Context, tx pg.Conn) error {
349-
// _, err := tx.Exec(ctx, "INSERT INTO audit_log (...) VALUES (...)")
330+
// // Savepoint failure does not roll back the DELETE above.
331+
// if err := tx.Savepoint(ctx, func(ctx context.Context, q pg.Querier) error {
332+
// _, err := q.Exec(ctx, "INSERT INTO audit_log (...) VALUES (...)")
350333
// return err
351334
// }); err != nil {
352335
// log.Warn("audit failed, continuing", "err", err)
@@ -359,7 +342,7 @@ func (c *Client) WithConn(
359342
// and logs any errors.
360343
func (c *Client) WithTx(
361344
ctx context.Context,
362-
exec ExecFunc,
345+
exec ExecFunc[Tx],
363346
) error {
364347
var (
365348
rootSpan = trace.SpanFromContext(ctx)
@@ -375,67 +358,6 @@ func (c *Client) WithTx(
375358
defer span.End()
376359
}
377360

378-
if parentTx := txFromContext(ctx); parentTx != nil {
379-
if span != nil {
380-
span.SetAttributes(attribute.Bool("db.savepoint", true))
381-
}
382-
383-
return c.withSavepoint(ctx, span, parentTx, exec)
384-
}
385-
386-
return c.withNewTx(ctx, span, exec)
387-
}
388-
389-
func (c *Client) withSavepoint(
390-
ctx context.Context,
391-
span trace.Span,
392-
parentTx pgx.Tx,
393-
exec ExecFunc,
394-
) error {
395-
tx, err := parentTx.Begin(ctx)
396-
if err != nil {
397-
err := fmt.Errorf("cannot create savepoint: %w", err)
398-
if span != nil {
399-
recordError(span, err)
400-
}
401-
402-
return err
403-
}
404-
405-
txCtx := context.WithValue(ctx, txKey{}, tx)
406-
407-
if err := exec(txCtx, tx); err != nil {
408-
if err2 := tx.Rollback(ctx); err2 != nil {
409-
err = errors.Join(
410-
err,
411-
fmt.Errorf("cannot rollback savepoint: %w", err2),
412-
)
413-
}
414-
415-
if span != nil {
416-
recordError(span, err)
417-
}
418-
419-
return err
420-
}
421-
422-
if err := tx.Commit(ctx); err != nil {
423-
err := fmt.Errorf("cannot release savepoint: %w", err)
424-
if span != nil {
425-
recordError(span, err)
426-
}
427-
428-
return err
429-
}
430-
431-
return nil
432-
}
433-
434-
func (c *Client) withNewTx(
435-
ctx context.Context,
436-
span trace.Span,
437-
exec ExecFunc,
438-
) error {
439361
conn, err := c.pool.Acquire(ctx)
440362
if err != nil {
441363
err := fmt.Errorf("cannot acquire connection: %w", err)
@@ -447,7 +369,7 @@ func (c *Client) withNewTx(
447369
}
448370
defer conn.Release()
449371

450-
tx, err := conn.Begin(ctx)
372+
innerTx, err := conn.Begin(ctx)
451373
if err != nil {
452374
err := fmt.Errorf("cannot begin transaction: %w", err)
453375
if span != nil {
@@ -457,10 +379,10 @@ func (c *Client) withNewTx(
457379
return err
458380
}
459381

460-
txCtx := context.WithValue(ctx, txKey{}, tx)
382+
tx := &pgxTx{inner: innerTx, tracer: c.tracer}
461383

462-
if err := exec(txCtx, tx); err != nil {
463-
if err2 := tx.Rollback(ctx); err2 != nil {
384+
if err := exec(ctx, tx); err != nil {
385+
if err2 := innerTx.Rollback(ctx); err2 != nil {
464386
err = errors.Join(
465387
err,
466388
fmt.Errorf("cannot rollback transaction: %w", err2),
@@ -474,7 +396,7 @@ func (c *Client) withNewTx(
474396
return err
475397
}
476398

477-
if err := tx.Commit(ctx); err != nil {
399+
if err := innerTx.Commit(ctx); err != nil {
478400
err := fmt.Errorf("cannot commit transaction: %w", err)
479401
if span != nil {
480402
recordError(span, err)
@@ -489,7 +411,7 @@ func (c *Client) withNewTx(
489411
func (c *Client) WithAdvisoryLock(
490412
ctx context.Context,
491413
id AdvisoryLock,
492-
f func(Conn) error,
414+
f func(Querier) error,
493415
) error {
494416
var (
495417
rootSpan = trace.SpanFromContext(ctx)
@@ -510,9 +432,9 @@ func (c *Client) WithAdvisoryLock(
510432

511433
return c.WithTx(
512434
ctx,
513-
func(txCtx context.Context, conn Conn) error {
435+
func(txCtx context.Context, tx Tx) error {
514436
q := "SELECT pg_advisory_xact_lock($1, $2)"
515-
_, err := conn.Exec(txCtx, q, BaseAdvisoryLockId, id)
437+
_, err := tx.Exec(txCtx, q, BaseAdvisoryLockId, id)
516438
if err != nil {
517439
err = fmt.Errorf("cannot acquire advisory lock: %w", err)
518440
if rootSpan.IsRecording() {
@@ -523,7 +445,7 @@ func (c *Client) WithAdvisoryLock(
523445
return err
524446
}
525447

526-
err = f(conn)
448+
err = f(tx)
527449
if err != nil {
528450
if rootSpan.IsRecording() {
529451
span.SetStatus(codes.Error, err.Error())

0 commit comments

Comments
 (0)