Skip to content

Commit 4bbc9da

Browse files
committed
chore: refactored query paths to Querier interface
1 parent ed3423b commit 4bbc9da

6 files changed

Lines changed: 83 additions & 38 deletions

File tree

cmd/dryrun/init.go

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"os"
99
"path/filepath"
1010

11+
"github.com/jackc/pgx/v5"
1112
"github.com/jackc/pgx/v5/pgxpool"
1213
"github.com/spf13/cobra"
1314

@@ -34,22 +35,42 @@ type initWriter interface {
3435
PutActivity(ctx context.Context, key history.SnapshotKey, a *schema.ActivityStatsSnapshot) (history.PutOutcome, error)
3536
}
3637

37-
type pgxCapturer struct{ pool *pgxpool.Pool }
38+
// one REPEATABLE READ, READ ONLY tx (as pg_dump uses) for the whole capture:
39+
// consistent snapshot, no writes, one connection
40+
type pgxCapturer struct{ tx pgx.Tx }
41+
42+
func newPgxCapturer(ctx context.Context, pool *pgxpool.Pool) (pgxCapturer, error) {
43+
tx, err := pool.BeginTx(ctx, pgx.TxOptions{
44+
IsoLevel: pgx.RepeatableRead,
45+
AccessMode: pgx.ReadOnly,
46+
})
47+
if err != nil {
48+
return pgxCapturer{}, fmt.Errorf("begin read-only transaction: %w", err)
49+
}
50+
return pgxCapturer{tx: tx}, nil
51+
}
52+
53+
// read-only, so there is nothing to commit; rollback releases the snapshot
54+
func (c pgxCapturer) Close(ctx context.Context) {
55+
if err := c.tx.Rollback(ctx); err != nil && !errors.Is(err, pgx.ErrTxClosed) {
56+
slog.Warn("rollback capture transaction", "error", err)
57+
}
58+
}
3859

3960
func (c pgxCapturer) IsStandby(ctx context.Context) (bool, error) {
40-
return schema.FetchIsStandby(ctx, c.pool)
61+
return schema.FetchIsStandby(ctx, c.tx)
4162
}
4263

4364
func (c pgxCapturer) Introspect(ctx context.Context) (*schema.SchemaSnapshot, error) {
44-
return schema.IntrospectSchema(ctx, c.pool)
65+
return schema.IntrospectSchema(ctx, c.tx)
4566
}
4667

4768
func (c pgxCapturer) CapturePlanner(ctx context.Context, schemaRefHash string) (*schema.PlannerStatsSnapshot, error) {
48-
return schema.CapturePlannerStats(ctx, c.pool, schemaRefHash)
69+
return schema.CapturePlannerStats(ctx, c.tx, schemaRefHash)
4970
}
5071

5172
func (c pgxCapturer) CaptureActivity(ctx context.Context, schemaRefHash, source string) (*schema.ActivityStatsSnapshot, error) {
52-
return schema.CaptureActivityStats(ctx, c.pool, schemaRefHash, source)
73+
return schema.CaptureActivityStats(ctx, c.tx, schemaRefHash, source)
5374
}
5475

5576
// init owns the masking flag surface; other subcommands don't mask anything.
@@ -98,6 +119,12 @@ func initCmd() *cobra.Command {
98119
}
99120
defer conn.Close()
100121

122+
cap, err := newPgxCapturer(ctx, conn.Pool())
123+
if err != nil {
124+
return err
125+
}
126+
defer cap.Close(ctx)
127+
101128
store, err := openHistoryStore("")
102129
if err != nil {
103130
return fmt.Errorf("open history store: %w", err)
@@ -113,7 +140,7 @@ func initCmd() *cobra.Command {
113140
slog.Warn("masking disabled by --no-masks; raw planner stats will be written to history.db")
114141
}
115142

116-
return runInitCapture(ctx, pgxCapturer{pool: conn.Pool()}, store, key, dataDir, initOptions{
143+
return runInitCapture(ctx, cap, store, key, dataDir, initOptions{
117144
AllowReplica: allowReplica,
118145
Source: source,
119146
Policy: policy,

cmd/dryrun/main.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,12 @@ func snapshotCmd() *cobra.Command {
346346
}
347347
defer conn.Close()
348348

349+
cap, err := newPgxCapturer(cmd.Context(), conn.Pool())
350+
if err != nil {
351+
return err
352+
}
353+
defer cap.Close(cmd.Context())
354+
349355
store, err := openHistoryStore(historyDB)
350356
if err != nil {
351357
return err
@@ -361,7 +367,7 @@ func snapshotCmd() *cobra.Command {
361367
slog.Warn("masking disabled by --no-masks; raw planner stats will be written to history.db")
362368
}
363369

364-
snap, planner, activity, masked, err := runSnapshotTake(cmd.Context(), pgxCapturer{pool: conn.Pool()}, store, key, policy)
370+
snap, planner, activity, masked, err := runSnapshotTake(cmd.Context(), cap, store, key, policy)
365371
if err != nil {
366372
return err
367373
}

cmd/dryrun/snapshot.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,19 @@ func snapshotActivityCmd() *cobra.Command {
4242
}
4343
defer conn.Close()
4444

45+
cap, err := newPgxCapturer(ctx, conn.Pool())
46+
if err != nil {
47+
return err
48+
}
49+
defer cap.Close(ctx)
50+
4551
store, err := openHistoryStore(historyDB)
4652
if err != nil {
4753
return err
4854
}
4955
defer store.Close()
5056

51-
return runSnapshotActivity(ctx, pgxCapturer{pool: conn.Pool()}, store, resolveSnapshotKey(), activityOptions{
57+
return runSnapshotActivity(ctx, cap, store, resolveSnapshotKey(), activityOptions{
5258
Label: label,
5359
AllowOrphan: allowOrphan,
5460
})

internal/schema/connection.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,19 @@ import (
66
"log/slog"
77
"time"
88

9+
"github.com/jackc/pgx/v5"
910
"github.com/jackc/pgx/v5/pgxpool"
1011

1112
"github.com/boringsql/dryrun/internal/dryrun"
1213
)
1314

15+
// satisfied by both *pgxpool.Pool and pgx.Tx, so capture can run straight on
16+
// the pool or inside a read-only tx
17+
type Querier interface {
18+
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
19+
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
20+
}
21+
1422
type DryRun struct {
1523
pool *pgxpool.Pool
1624
}

internal/schema/introspect.go

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99

1010
"github.com/boringsql/queries"
1111
"github.com/jackc/pgx/v5"
12-
"github.com/jackc/pgx/v5/pgxpool"
1312
)
1413

1514
// scanAll wraps the standard rows.Next loop. The scan callback receives the
@@ -27,7 +26,7 @@ func scanAll[T any](rows pgx.Rows, scan func(pgx.Rows) (T, error)) ([]T, error)
2726
return out, rows.Err()
2827
}
2928

30-
func query(ctx context.Context, pool *pgxpool.Pool, name string) (pgx.Rows, error) {
29+
func query(ctx context.Context, pool Querier, name string) (pgx.Rows, error) {
3130
return pool.Query(ctx, q(name))
3231
}
3332

@@ -48,7 +47,7 @@ func q(name string) string {
4847
}
4948

5049
// DDL-only introspection; planner/activity stats now flow through CapturePlannerStats / CaptureActivityStats
51-
func IntrospectSchema(ctx context.Context, pool *pgxpool.Pool) (*SchemaSnapshot, error) {
50+
func IntrospectSchema(ctx context.Context, pool Querier) (*SchemaSnapshot, error) {
5251
var pgVersion string
5352
if err := pool.QueryRow(ctx, "SELECT version()").Scan(&pgVersion); err != nil {
5453
return nil, fmt.Errorf("query pg version: %w", err)
@@ -253,7 +252,7 @@ type (
253252

254253
// Fetchers - each uses a named query from sql/introspect.sql
255254

256-
func fetchTables(ctx context.Context, pool *pgxpool.Pool) ([]rawTable, error) {
255+
func fetchTables(ctx context.Context, pool Querier) ([]rawTable, error) {
257256
rows, err := query(ctx, pool, "fetch-tables")
258257
if err != nil {
259258
return nil, err
@@ -267,7 +266,7 @@ func fetchTables(ctx context.Context, pool *pgxpool.Pool) ([]rawTable, error) {
267266
})
268267
}
269268

270-
func fetchColumns(ctx context.Context, pool *pgxpool.Pool) ([]rawColumn, error) {
269+
func fetchColumns(ctx context.Context, pool Querier) ([]rawColumn, error) {
271270
rows, err := query(ctx, pool, "fetch-columns")
272271
if err != nil {
273272
return nil, err
@@ -281,7 +280,7 @@ func fetchColumns(ctx context.Context, pool *pgxpool.Pool) ([]rawColumn, error)
281280
})
282281
}
283282

284-
func fetchConstraints(ctx context.Context, pool *pgxpool.Pool) ([]rawConstraint, error) {
283+
func fetchConstraints(ctx context.Context, pool Querier) ([]rawConstraint, error) {
285284
rows, err := query(ctx, pool, "fetch-constraints")
286285
if err != nil {
287286
return nil, err
@@ -295,7 +294,7 @@ func fetchConstraints(ctx context.Context, pool *pgxpool.Pool) ([]rawConstraint,
295294
})
296295
}
297296

298-
func fetchTableComments(ctx context.Context, pool *pgxpool.Pool) ([]rawTableComment, error) {
297+
func fetchTableComments(ctx context.Context, pool Querier) ([]rawTableComment, error) {
299298
rows, err := query(ctx, pool, "fetch-table-comments")
300299
if err != nil {
301300
return nil, err
@@ -309,7 +308,7 @@ func fetchTableComments(ctx context.Context, pool *pgxpool.Pool) ([]rawTableComm
309308
})
310309
}
311310

312-
func fetchColumnComments(ctx context.Context, pool *pgxpool.Pool) ([]rawColumnComment, error) {
311+
func fetchColumnComments(ctx context.Context, pool Querier) ([]rawColumnComment, error) {
313312
rows, err := query(ctx, pool, "fetch-column-comments")
314313
if err != nil {
315314
return nil, err
@@ -323,7 +322,7 @@ func fetchColumnComments(ctx context.Context, pool *pgxpool.Pool) ([]rawColumnCo
323322
})
324323
}
325324

326-
func fetchEnums(ctx context.Context, pool *pgxpool.Pool) ([]EnumType, error) {
325+
func fetchEnums(ctx context.Context, pool Querier) ([]EnumType, error) {
327326
rows, err := query(ctx, pool, "fetch-enums")
328327
if err != nil {
329328
return nil, err
@@ -335,7 +334,7 @@ func fetchEnums(ctx context.Context, pool *pgxpool.Pool) ([]EnumType, error) {
335334
})
336335
}
337336

338-
func fetchDomains(ctx context.Context, pool *pgxpool.Pool) ([]DomainType, error) {
337+
func fetchDomains(ctx context.Context, pool Querier) ([]DomainType, error) {
339338
rows, err := query(ctx, pool, "fetch-domains")
340339
if err != nil {
341340
return nil, err
@@ -349,7 +348,7 @@ func fetchDomains(ctx context.Context, pool *pgxpool.Pool) ([]DomainType, error)
349348
})
350349
}
351350

352-
func fetchComposites(ctx context.Context, pool *pgxpool.Pool) ([]CompositeType, error) {
351+
func fetchComposites(ctx context.Context, pool Querier) ([]CompositeType, error) {
353352
rows, err := pool.Query(ctx, q("fetch-composites"))
354353
if err != nil {
355354
return nil, err
@@ -400,7 +399,7 @@ func fetchComposites(ctx context.Context, pool *pgxpool.Pool) ([]CompositeType,
400399
return out, nil
401400
}
402401

403-
func fetchIndexes(ctx context.Context, pool *pgxpool.Pool) ([]rawIndex, error) {
402+
func fetchIndexes(ctx context.Context, pool Querier) ([]rawIndex, error) {
404403
rows, err := query(ctx, pool, "fetch-indexes")
405404
if err != nil {
406405
return nil, err
@@ -432,7 +431,7 @@ func fetchIndexes(ctx context.Context, pool *pgxpool.Pool) ([]rawIndex, error) {
432431
})
433432
}
434433

435-
func fetchPartitionInfo(ctx context.Context, pool *pgxpool.Pool) ([]rawPartitionInfo, error) {
434+
func fetchPartitionInfo(ctx context.Context, pool Querier) ([]rawPartitionInfo, error) {
436435
rows, err := query(ctx, pool, "fetch-partition-info")
437436
if err != nil {
438437
return nil, err
@@ -446,7 +445,7 @@ func fetchPartitionInfo(ctx context.Context, pool *pgxpool.Pool) ([]rawPartition
446445
})
447446
}
448447

449-
func fetchPartitionChildren(ctx context.Context, pool *pgxpool.Pool) ([]rawPartitionChild, error) {
448+
func fetchPartitionChildren(ctx context.Context, pool Querier) ([]rawPartitionChild, error) {
450449
rows, err := query(ctx, pool, "fetch-partition-children")
451450
if err != nil {
452451
return nil, err
@@ -466,7 +465,7 @@ func fetchPartitionChildren(ctx context.Context, pool *pgxpool.Pool) ([]rawParti
466465
})
467466
}
468467

469-
func fetchPolicies(ctx context.Context, pool *pgxpool.Pool) ([]rawPolicy, error) {
468+
func fetchPolicies(ctx context.Context, pool Querier) ([]rawPolicy, error) {
470469
rows, err := query(ctx, pool, "fetch-policies")
471470
if err != nil {
472471
return nil, err
@@ -480,7 +479,7 @@ func fetchPolicies(ctx context.Context, pool *pgxpool.Pool) ([]rawPolicy, error)
480479
})
481480
}
482481

483-
func fetchTriggers(ctx context.Context, pool *pgxpool.Pool) ([]rawTrigger, error) {
482+
func fetchTriggers(ctx context.Context, pool Querier) ([]rawTrigger, error) {
484483
rows, err := query(ctx, pool, "fetch-triggers")
485484
if err != nil {
486485
return nil, err
@@ -495,7 +494,7 @@ func fetchTriggers(ctx context.Context, pool *pgxpool.Pool) ([]rawTrigger, error
495494
}
496495

497496

498-
func fetchViews(ctx context.Context, pool *pgxpool.Pool) ([]View, error) {
497+
func fetchViews(ctx context.Context, pool Querier) ([]View, error) {
499498
rows, err := query(ctx, pool, "fetch-views")
500499
if err != nil {
501500
return nil, err
@@ -513,7 +512,7 @@ func fetchViews(ctx context.Context, pool *pgxpool.Pool) ([]View, error) {
513512
})
514513
}
515514

516-
func fetchFunctions(ctx context.Context, pool *pgxpool.Pool) ([]Function, error) {
515+
func fetchFunctions(ctx context.Context, pool Querier) ([]Function, error) {
517516
rows, err := query(ctx, pool, "fetch-functions")
518517
if err != nil {
519518
return nil, err
@@ -541,7 +540,7 @@ func fetchFunctions(ctx context.Context, pool *pgxpool.Pool) ([]Function, error)
541540
})
542541
}
543542

544-
func fetchExtensions(ctx context.Context, pool *pgxpool.Pool) ([]Extension, error) {
543+
func fetchExtensions(ctx context.Context, pool Querier) ([]Extension, error) {
545544
rows, err := query(ctx, pool, "fetch-extensions")
546545
if err != nil {
547546
return nil, err
@@ -553,7 +552,7 @@ func fetchExtensions(ctx context.Context, pool *pgxpool.Pool) ([]Extension, erro
553552
})
554553
}
555554

556-
func fetchGUCs(ctx context.Context, pool *pgxpool.Pool) ([]GucSetting, error) {
555+
func fetchGUCs(ctx context.Context, pool Querier) ([]GucSetting, error) {
557556
rows, err := query(ctx, pool, "fetch-gucs")
558557
if err != nil {
559558
return nil, err

0 commit comments

Comments
 (0)