Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- Change SQLite driver operations over to use bulk inserts where possible now that sqlc has better support for `json_each`. [PR #1276](https://github.com/riverqueue/river/pull/1276)
- Detect duplicate step names across `river.ResumableStep` and return a validation error. [PR #1281](https://github.com/riverqueue/river/pull/1281)

### Fixed

Expand Down
2 changes: 2 additions & 0 deletions internal/rivermiddleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func (*ResumableMiddleware) Work(ctx context.Context, job *rivertype.JobRow, doI
}

state := &ResumableState{
AllStepNames: make(map[string]struct{}),
Cursors: make(map[string]json.RawMessage),
ResumeMatched: true,
ResumeStep: gjson.GetBytes(job.Metadata, rivercommon.MetadataKeyResumableStep).Str,
Expand Down Expand Up @@ -80,6 +81,7 @@ func (*ResumableMiddleware) Work(ctx context.Context, job *rivertype.JobRow, doI
// ResumableState holds the state for a resumable job execution. It is stored in
// the context and accessed by ResumableStep and ResumableStepCursor.
type ResumableState struct {
AllStepNames map[string]struct{}
CompletedStep string
Cursors map[string]json.RawMessage
Err error
Expand Down
20 changes: 20 additions & 0 deletions resumable.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ type StepOpts struct{}

// ResumableStep runs a resumable step, skipping the step on a later retry if
// an earlier attempt already completed it successfully.
// Step names must be unique across all ResumableStep and ResumableStepCursor
// calls in the same Worker execution.
//
// After a step returns an error, no subsequent steps will be run and the
// overall job will be marked as failed with that error. Be careful to put all
Expand All @@ -57,6 +59,9 @@ func ResumableStep(ctx context.Context, name string, opts *StepOpts, stepFunc fu
if state.Err != nil {
return
}
if !registerResumableStepName(state, name) {
return
}

if !state.ResumeMatched {
if name == state.ResumeStep {
Expand All @@ -81,6 +86,8 @@ func ResumableStep(ctx context.Context, name string, opts *StepOpts, stepFunc fu
// ResumableStepCursor runs a resumable step that also receives a persisted
// cursor value from an earlier failed attempt, if one was recorded with
// ResumableSetCursor.
// Step names must be unique across all ResumableStep and ResumableStepCursor
// calls in the same Worker execution.
//
// The cursor type T is user-specified. It may be a primitive value like an
// integer ID, or a more complex type like a struct with multiple fields. It's
Expand All @@ -102,6 +109,9 @@ func ResumableStepCursor[TCursor any](ctx context.Context, name string, opts *St
if state.Err != nil {
return
}
if !registerResumableStepName(state, name) {
return
}

if !state.ResumeMatched {
if name == state.ResumeStep {
Expand Down Expand Up @@ -149,6 +159,16 @@ func mustResumableState(ctx context.Context) *rivermiddleware.ResumableState {
return state
}

func registerResumableStepName(state *rivermiddleware.ResumableState, name string) bool {
if _, ok := state.AllStepNames[name]; ok {
state.Err = fmt.Errorf("river: duplicate resumable step name %q", name)
return false
}

state.AllStepNames[name] = struct{}{}
return true
}

func resumableStateFromContext(ctx context.Context) (*rivermiddleware.ResumableState, bool) {
state := ctx.Value(rivermiddleware.ResumableContextKey{})
if state == nil {
Expand Down
70 changes: 70 additions & 0 deletions resumable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,54 @@ func TestResumableStep(t *testing.T) {
return ctx, metadataUpdates, &rivertype.JobRow{Metadata: []byte(metadata)}
}

t.Run("DuplicateStepName", func(t *testing.T) {
t.Parallel()

ctx, _, job := setup(t, `{}`)

var ran []string
err := (&rivermiddleware.ResumableMiddleware{}).Work(ctx, job, func(ctx context.Context) error {
ResumableStep(ctx, "step1", nil, func(ctx context.Context) error {
ran = append(ran, "first")
return nil
})
ResumableStep(ctx, "step1", nil, func(ctx context.Context) error {
ran = append(ran, "second")
return nil
})

return nil
})
require.EqualError(t, err, `river: duplicate resumable step name "step1"`)
require.Equal(t, []string{"first"}, ran)
})

t.Run("DuplicateStepNameWhenSkippingCompletedSteps", func(t *testing.T) {
t.Parallel()

ctx, _, job := setup(t, `{"river:resumable_step":"step2"}`)

var ran []string
err := (&rivermiddleware.ResumableMiddleware{}).Work(ctx, job, func(ctx context.Context) error {
ResumableStep(ctx, "step1", nil, func(ctx context.Context) error {
ran = append(ran, "first")
return nil
})
ResumableStep(ctx, "step1", nil, func(ctx context.Context) error {
ran = append(ran, "second")
return nil
})
ResumableStep(ctx, "step2", nil, func(ctx context.Context) error {
ran = append(ran, "third")
return nil
})

return nil
})
require.EqualError(t, err, `river: duplicate resumable step name "step1"`)
require.Empty(t, ran)
})

t.Run("PanicsOutsideWorker", func(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -131,6 +179,28 @@ func TestResumableStepCursor(t *testing.T) {
return ctx, metadataUpdates, &rivertype.JobRow{Metadata: []byte(metadata)}
}

t.Run("DuplicateStepNameSharedWithCursorStep", func(t *testing.T) {
t.Parallel()

ctx, _, job := setup(t, `{}`)

var ran []string
err := (&rivermiddleware.ResumableMiddleware{}).Work(ctx, job, func(ctx context.Context) error {
ResumableStep(ctx, "step1", nil, func(ctx context.Context) error {
ran = append(ran, "step")
return nil
})
ResumableStepCursor(ctx, "step1", nil, func(ctx context.Context, cursor resumableCursor) error {
ran = append(ran, "cursor")
return nil
})

return nil
})
require.EqualError(t, err, `river: duplicate resumable step name "step1"`)
require.Equal(t, []string{"step"}, ran)
})

t.Run("ResumesCursor", func(t *testing.T) {
t.Parallel()

Expand Down
Loading