Skip to content

Commit 20e9f20

Browse files
authored
Add RollbackWithoutCancel helper to have rollbacks supersede context cancellation (#1062)
Here, add a `dbutil.RollbackWithoutCancel` helper. The purpose of this is in situations where an operation is in a transaction but ends up cancelling due to a context timeout. From there, there tends to be a rollback in a `defer`, but the rollback ends up not running because its context is already cancelled. The new helper removes a cancelled context, then adds a new 5 second context timeout to run the rollback command. This is mainly used for `JobSchedule` in the hopes of coming up with a fix for #1059, but we use the new function everywhere that we previously had a manual rollback, and it also gets add to the `dbutil.WithTx*` helpers so that it gets used anywhere those do (and these should probably be preferred most of the time because it makes rollbacks and commits harder to forget). Fixes #1059.
1 parent c40e99f commit 20e9f20

7 files changed

Lines changed: 126 additions & 23 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010
### Fixed
1111

1212
- Fix snoozed events emitted from `rivertest.Worker` when snooze duration is zero seconds. [PR #1057](https://github.com/riverqueue/river/pull/1057).
13+
- Rollbacks now use an uncancelled context so as to not leave transactions in an ambiguous state if a transaction in them fails due to context cancellation. [PR #1062](https://github.com/riverqueue/river/pull/1062).
1314

1415
## [0.26.0] - 2025-10-07
1516

client.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2403,7 +2403,7 @@ func (c *Client[TTx]) QueuePause(ctx context.Context, name string, opts *QueuePa
24032403
if err != nil {
24042404
return err
24052405
}
2406-
defer tx.Rollback(ctx)
2406+
defer dbutil.RollbackWithoutCancel(ctx, tx)
24072407

24082408
if err := tx.QueuePause(ctx, &riverdriver.QueuePauseParams{
24092409
Name: name,
@@ -2473,7 +2473,7 @@ func (c *Client[TTx]) QueueResume(ctx context.Context, name string, opts *QueueP
24732473
if err != nil {
24742474
return err
24752475
}
2476-
defer tx.Rollback(ctx)
2476+
defer dbutil.RollbackWithoutCancel(ctx, tx)
24772477

24782478
if err := tx.QueueResume(ctx, &riverdriver.QueueResumeParams{
24792479
Name: name,
@@ -2541,7 +2541,7 @@ func (c *Client[TTx]) QueueUpdate(ctx context.Context, name string, params *Queu
25412541
if err != nil {
25422542
return nil, err
25432543
}
2544-
defer tx.Rollback(ctx)
2544+
defer dbutil.RollbackWithoutCancel(ctx, tx)
25452545

25462546
queue, controlEvent, err := c.queueUpdate(ctx, tx, name, params)
25472547
if err != nil {

internal/maintenance/job_scheduler.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/riverqueue/river/rivershared/riversharedmaintenance"
1515
"github.com/riverqueue/river/rivershared/startstop"
1616
"github.com/riverqueue/river/rivershared/testsignal"
17+
"github.com/riverqueue/river/rivershared/util/dbutil"
1718
"github.com/riverqueue/river/rivershared/util/randutil"
1819
"github.com/riverqueue/river/rivershared/util/serviceutil"
1920
"github.com/riverqueue/river/rivershared/util/testutil"
@@ -37,7 +38,7 @@ func (ts *JobSchedulerTestSignals) Init(tb testutil.TestingTB) {
3738

3839
// NotifyInsert is a function to call to emit notifications for queues where
3940
// jobs were scheduled.
40-
type NotifyInsertFunc func(ctx context.Context, tx riverdriver.ExecutorTx, queues []string) error
41+
type NotifyInsertFunc func(ctx context.Context, execTx riverdriver.ExecutorTx, queues []string) error
4142

4243
type JobSchedulerConfig struct {
4344
riversharedmaintenance.BatchSizes
@@ -167,16 +168,16 @@ func (s *JobScheduler) runOnce(ctx context.Context) (*schedulerRunOnceResult, er
167168
ctx, cancelFunc := context.WithTimeout(ctx, riversharedmaintenance.TimeoutDefault)
168169
defer cancelFunc()
169170

170-
tx, err := s.exec.Begin(ctx)
171+
execTx, err := s.exec.Begin(ctx)
171172
if err != nil {
172173
return 0, fmt.Errorf("error starting transaction: %w", err)
173174
}
174-
defer tx.Rollback(ctx)
175+
defer dbutil.RollbackWithoutCancel(ctx, execTx)
175176

176177
now := s.Time.NowUTC()
177178
nowWithLookAhead := now.Add(s.config.Interval)
178179

179-
scheduledJobResults, err := tx.JobSchedule(ctx, &riverdriver.JobScheduleParams{
180+
scheduledJobResults, err := execTx.JobSchedule(ctx, &riverdriver.JobScheduleParams{
180181
Max: s.batchSize(),
181182
Now: &nowWithLookAhead,
182183
Schema: s.config.Schema,
@@ -204,13 +205,13 @@ func (s *JobScheduler) runOnce(ctx context.Context) (*schedulerRunOnceResult, er
204205
}
205206

206207
if len(queues) > 0 {
207-
if err := s.config.NotifyInsert(ctx, tx, queues); err != nil {
208+
if err := s.config.NotifyInsert(ctx, execTx, queues); err != nil {
208209
return 0, fmt.Errorf("error notifying insert: %w", err)
209210
}
210211
s.TestSignals.NotifiedQueues.Signal(queues)
211212
}
212213

213-
return len(scheduledJobResults), tx.Commit(ctx)
214+
return len(scheduledJobResults), execTx.Commit(ctx)
214215
}()
215216
if err != nil {
216217
if errors.Is(err, context.DeadlineExceeded) {

riverdriver/riversqlite/river_sqlite_driver.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ func (e *Executor) JobCancel(ctx context.Context, params *riverdriver.JobCancelP
243243
// exists and is not running, only one database operation is needed, but if
244244
// the initial update comes back empty, it does one more fetch to return the
245245
// most appropriate error.
246-
return dbutil.WithTxV(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) (*rivertype.JobRow, error) {
246+
return dbutil.WithTxV(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) (*rivertype.JobRow, error) { // TODO
247247
dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer}
248248

249249
cancelledAt, err := params.CancelAttemptedAt.UTC().MarshalJSON()
@@ -320,7 +320,7 @@ func (e *Executor) JobDelete(ctx context.Context, params *riverdriver.JobDeleteP
320320
// exists and is not running, only one database operation is needed, but if
321321
// the initial delete comes back empty, it does one more fetch to return the
322322
// most appropriate error.
323-
return dbutil.WithTxV(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) (*rivertype.JobRow, error) {
323+
return dbutil.WithTxV(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) (*rivertype.JobRow, error) { // TODO
324324
dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer}
325325

326326
job, err := dbsqlc.New().JobDelete(schemaTemplateParam(ctx, params.Schema), dbtx, params.ID)
@@ -495,7 +495,7 @@ func (e *Executor) JobInsertFastMany(ctx context.Context, params *riverdriver.Jo
495495
uniqueNonce = randutil.Hex(8)
496496
)
497497

498-
if err := dbutil.WithTx(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) error {
498+
if err := dbutil.WithTx(ctx, e, func(ctx context.Context, execTx riverdriver.ExecutorTx) error { // TODO
499499
ctx = schemaTemplateParam(ctx, params.Schema)
500500
dbtx := templateReplaceWrapper{dbtx: e.driver.UnwrapTx(execTx), replacer: &e.driver.replacer}
501501

rivermigrate/river_migrate_test.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,9 @@ import (
2727

2828
const (
2929
// The name of an actual migration line embedded in our test data below.
30-
migrationLineAlternate = "alternate"
31-
migrationLineAlternateMaxVersion = 6
32-
migrationLineCommitRequired = "commit_required"
33-
migrationLineCommitRequiredMaxVersion = 3
30+
migrationLineAlternate = "alternate"
31+
migrationLineAlternateMaxVersion = 6
32+
migrationLineCommitRequired = "commit_required"
3433
)
3534

3635
//go:embed migration/*/*.sql

rivershared/util/dbutil/db_util.go

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,44 @@ package dbutil
33
import (
44
"context"
55
"fmt"
6+
"time"
67

78
"github.com/riverqueue/river/riverdriver"
89
)
910

11+
// RollbackWithoutCancel initiates a rollback, but one in which context is
12+
// overridden with context.WithoutCancel so that the rollback can proceed even
13+
// if a previous operation was cancelled. This decreases the chance that a
14+
// transaction is accidentally left open and in an ambiguous state.
15+
//
16+
// A rollback error is returned the same way most driver rollbacks return an
17+
// error, but given this is normally expected to be used in a defer statement,
18+
// it's unusual for the error to be handled.
19+
func RollbackWithoutCancel[TExec riverdriver.ExecutorTx](ctx context.Context, execTx TExec) error {
20+
ctxWithoutCancel := context.WithoutCancel(ctx)
21+
22+
ctx, cancel := context.WithTimeout(ctxWithoutCancel, 5*time.Second)
23+
defer cancel()
24+
25+
// It might not be the worst idea to log an unexpected error on rollback
26+
// here instead of returning it. I had this in place initially, but there's
27+
// a number of common errors that need to be ignored like "conn closed",
28+
// `pgx.ErrTxClosed`, or `sql.ErrTxDone`. These all turn out to be
29+
// driver-specific when this function is meant to be driver agnostic.
30+
//
31+
// It'd still be possible to make it happen, but we'd have to have a driver
32+
// function like `ShouldIgnoreRollbackError` that'd need a lot of plumbing
33+
// and it becomes questionable as to whether it's all worth it as Rollback
34+
// producing a non-standard error would be quite unusual.
35+
return execTx.Rollback(ctx)
36+
}
37+
1038
// WithTx starts and commits a transaction on a driver executor around
1139
// the given function, allowing the return of a generic value.
40+
//
41+
// Rollbacks use RollbackWithoutCancel to maximize the chance of a successful
42+
// rollback even where an operation within the transaction timed out due to
43+
// context timeout.
1244
func WithTx[TExec riverdriver.Executor](ctx context.Context, exec TExec, innerFunc func(ctx context.Context, execTx riverdriver.ExecutorTx) error) error {
1345
_, err := WithTxV(ctx, exec, func(ctx context.Context, tx riverdriver.ExecutorTx) (struct{}, error) {
1446
return struct{}{}, innerFunc(ctx, tx)
@@ -18,14 +50,18 @@ func WithTx[TExec riverdriver.Executor](ctx context.Context, exec TExec, innerFu
1850

1951
// WithTxV starts and commits a transaction on a driver executor around
2052
// the given function, allowing the return of a generic value.
53+
//
54+
// Rollbacks use RollbackWithoutCancel to maximize the chance of a successful
55+
// rollback even where an operation within the transaction timed out due to
56+
// context timeout.
2157
func WithTxV[TExec riverdriver.Executor, T any](ctx context.Context, exec TExec, innerFunc func(ctx context.Context, execTx riverdriver.ExecutorTx) (T, error)) (T, error) {
2258
var defaultRes T
2359

2460
tx, err := exec.Begin(ctx)
2561
if err != nil {
2662
return defaultRes, fmt.Errorf("error beginning transaction: %w", err)
2763
}
28-
defer tx.Rollback(ctx)
64+
defer RollbackWithoutCancel(ctx, tx) //nolint:errcheck
2965

3066
res, err := innerFunc(ctx, tx)
3167
if err != nil {

rivershared/util/dbutil/db_util_test.go

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ package dbutil_test
22

33
import (
44
"context"
5+
"errors"
56
"testing"
67

8+
"github.com/jackc/pgx/v5"
79
"github.com/stretchr/testify/require"
810

911
"github.com/riverqueue/river/riverdbtest"
@@ -12,12 +14,74 @@ import (
1214
"github.com/riverqueue/river/rivershared/util/dbutil"
1315
)
1416

15-
func TestWithTx(t *testing.T) {
17+
func TestRollbackCancelOverride(t *testing.T) {
1618
t.Parallel()
1719

1820
ctx := context.Background()
19-
tx := riverdbtest.TestTxPgx(ctx, t)
20-
driver := riverpgxv5.New(nil)
21+
22+
type testBundle struct {
23+
driver *riverpgxv5.Driver
24+
tx pgx.Tx
25+
}
26+
27+
setup := func(t *testing.T) *testBundle {
28+
t.Helper()
29+
30+
return &testBundle{
31+
driver: riverpgxv5.New(nil),
32+
tx: riverdbtest.TestTxPgx(ctx, t),
33+
}
34+
}
35+
36+
t.Run("Success", func(t *testing.T) {
37+
t.Parallel()
38+
39+
bundle := setup(t)
40+
41+
dbutil.RollbackWithoutCancel(ctx, bundle.driver.UnwrapExecutor(bundle.tx))
42+
})
43+
44+
t.Run("WithCancelledContext", func(t *testing.T) {
45+
t.Parallel()
46+
47+
bundle := setup(t)
48+
49+
ctx, cancel := context.WithCancel(ctx)
50+
cancel()
51+
52+
dbutil.RollbackWithoutCancel(ctx, bundle.driver.UnwrapExecutor(bundle.tx))
53+
})
54+
55+
t.Run("RollbackError", func(t *testing.T) {
56+
t.Parallel()
57+
58+
bundle := setup(t)
59+
60+
execTx := &executorTxWithRollbackError{
61+
ExecutorTx: bundle.driver.UnwrapExecutor(bundle.tx),
62+
}
63+
64+
err := dbutil.RollbackWithoutCancel(ctx, execTx)
65+
require.EqualError(t, err, "rollback error")
66+
})
67+
}
68+
69+
type executorTxWithRollbackError struct {
70+
riverdriver.ExecutorTx
71+
}
72+
73+
func (e *executorTxWithRollbackError) Rollback(ctx context.Context) error {
74+
return errors.New("rollback error")
75+
}
76+
77+
func TestWithTx(t *testing.T) {
78+
t.Parallel()
79+
80+
var (
81+
ctx = context.Background()
82+
tx = riverdbtest.TestTxPgx(ctx, t)
83+
driver = riverpgxv5.New(nil)
84+
)
2185

2286
err := dbutil.WithTx(ctx, driver.UnwrapExecutor(tx), func(ctx context.Context, execTx riverdriver.ExecutorTx) error {
2387
require.NoError(t, execTx.Exec(ctx, "SELECT 1"))
@@ -29,9 +93,11 @@ func TestWithTx(t *testing.T) {
2993
func TestWithTxV(t *testing.T) {
3094
t.Parallel()
3195

32-
ctx := context.Background()
33-
tx := riverdbtest.TestTxPgx(ctx, t)
34-
driver := riverpgxv5.New(nil)
96+
var (
97+
ctx = context.Background()
98+
tx = riverdbtest.TestTxPgx(ctx, t)
99+
driver = riverpgxv5.New(nil)
100+
)
35101

36102
ret, err := dbutil.WithTxV(ctx, driver.UnwrapExecutor(tx), func(ctx context.Context, execTx riverdriver.ExecutorTx) (int, error) {
37103
require.NoError(t, execTx.Exec(ctx, "SELECT 1"))

0 commit comments

Comments
 (0)