Skip to content

Commit bc2ff0d

Browse files
committed
sweepbatcher: harden AddSweep against ctx closure
1 parent 787519e commit bc2ff0d

2 files changed

Lines changed: 91 additions & 6 deletions

File tree

sweepbatcher/sweep_batcher.go

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -780,16 +780,18 @@ func (b *Batcher) PresignSweepsGroup(ctx context.Context, inputs []Input,
780780
// times, but the sweeps (including the order of them) must be the same. If
781781
// notifier is provided, the batcher sends back sweeping results through it.
782782
func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
783-
// If the batcher is shutting down, quit now.
784-
select {
785-
case <-b.quit:
786-
return ErrBatcherShuttingDown
787-
788-
default:
783+
// If the batcher or the caller is shutting down, quit now.
784+
err := b.addSweepExitErrIfAny(ctx)
785+
if err != nil {
786+
return err
789787
}
790788

791789
sweeps, err := b.fetchSweeps(ctx, *sweepReq)
792790
if err != nil {
791+
if exitErr := b.addSweepExitErrIfAny(ctx); exitErr != nil {
792+
return exitErr
793+
}
794+
793795
return fmt.Errorf("fetchSweeps failed: %w", err)
794796
}
795797

@@ -803,6 +805,10 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
803805

804806
completed, err := b.store.GetSweepStatus(ctx, sweep.outpoint)
805807
if err != nil {
808+
if exitErr := b.addSweepExitErrIfAny(ctx); exitErr != nil {
809+
return exitErr
810+
}
811+
806812
return fmt.Errorf("failed to get the status of sweep %v: %w",
807813
sweep.outpoint, err)
808814
}
@@ -816,6 +822,10 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
816822
// on-chain confirmations to prevent issues caused by reorgs.
817823
parentBatch, err = b.store.GetParentBatch(ctx, sweep.outpoint)
818824
if err != nil {
825+
if exitErr := b.addSweepExitErrIfAny(ctx); exitErr != nil {
826+
return exitErr
827+
}
828+
819829
return fmt.Errorf("unable to get parent batch for "+
820830
"sweep %x: %w", sweep.swapHash[:6], err)
821831
}
@@ -827,6 +837,10 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
827837

828838
minRelayFeeRate, err := b.wallet.MinRelayFee(ctx)
829839
if err != nil {
840+
if exitErr := b.addSweepExitErrIfAny(ctx); exitErr != nil {
841+
return exitErr
842+
}
843+
830844
return fmt.Errorf("failed to get min relay fee: %w", err)
831845
}
832846

@@ -839,6 +853,10 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
839853
b.chainParams,
840854
)
841855
if err != nil {
856+
if exitErr := b.addSweepExitErrIfAny(ctx); exitErr != nil {
857+
return exitErr
858+
}
859+
842860
return fmt.Errorf("inputs with primarySweep %v were "+
843861
"not presigned (call PresignSweepsGroup "+
844862
"first): %w", sweep.outpoint, err)
@@ -861,7 +879,24 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
861879

862880
case <-b.quit:
863881
return ErrBatcherShuttingDown
882+
883+
case <-ctx.Done():
884+
return b.addSweepExitErrIfAny(ctx)
885+
}
886+
}
887+
888+
// addSweepExitErrIfAny returns the terminal error to use when AddSweep races
889+
// with shutdown or caller cancellation. It returns nil if AddSweep should
890+
// continue.
891+
func (b *Batcher) addSweepExitErrIfAny(ctx context.Context) error {
892+
select {
893+
case <-b.quit:
894+
return ErrBatcherShuttingDown
895+
896+
default:
864897
}
898+
899+
return ctx.Err()
865900
}
866901

867902
// testRunInEventLoop runs a function in the event loop blocking until

sweepbatcher/sweep_batcher_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package sweepbatcher
33
import (
44
"context"
55
"database/sql"
6+
"database/sql/driver"
67
"errors"
78
"fmt"
89
"maps"
@@ -3711,6 +3712,55 @@ func (f *sweepFetcherMock) FetchSweep(ctx context.Context, _ lntypes.Hash,
37113712
return f.store[outpoint], nil
37123713
}
37133714

3715+
type cancelingSweepFetcher struct {
3716+
cancel context.CancelFunc
3717+
}
3718+
3719+
func (f *cancelingSweepFetcher) FetchSweep(context.Context, lntypes.Hash,
3720+
wire.OutPoint) (*SweepInfo, error) {
3721+
3722+
// Simulate the caller canceling while the backend returns a
3723+
// driver-level error.
3724+
f.cancel()
3725+
3726+
return nil, driver.ErrBadConn
3727+
}
3728+
3729+
func testAddSweepReturnsContextErrorOnFetchCancellation(t *testing.T,
3730+
_ testStore, batcherStore testBatcherStore) {
3731+
3732+
defer test.Guard(t)()
3733+
3734+
lnd := test.NewMockLnd()
3735+
ctx, cancel := context.WithCancel(context.Background())
3736+
3737+
batcher := NewBatcher(
3738+
lnd.WalletKit, lnd.ChainNotifier, lnd.Signer,
3739+
testMuSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams,
3740+
batcherStore, &cancelingSweepFetcher{cancel: cancel},
3741+
)
3742+
3743+
err := batcher.AddSweep(ctx, &SweepRequest{
3744+
SwapHash: lntypes.Hash{1, 1, 1},
3745+
Inputs: []Input{{
3746+
Value: 1111,
3747+
Outpoint: wire.OutPoint{
3748+
Hash: chainhash.Hash{1, 1},
3749+
Index: 1,
3750+
},
3751+
}},
3752+
})
3753+
require.ErrorIs(t, err, context.Canceled)
3754+
require.NotErrorIs(t, err, driver.ErrBadConn)
3755+
}
3756+
3757+
// TestAddSweepReturnsContextErrorOnFetchCancellation asserts that AddSweep
3758+
// returns the context cancellation error if sweep fetching fails while the
3759+
// caller context is being canceled.
3760+
func TestAddSweepReturnsContextErrorOnFetchCancellation(t *testing.T) {
3761+
runTests(t, testAddSweepReturnsContextErrorOnFetchCancellation)
3762+
}
3763+
37143764
// testSweepFetcher tests providing custom sweep fetcher to Batcher.
37153765
func testSweepFetcher(t *testing.T, store testStore,
37163766
batcherStore testBatcherStore) {

0 commit comments

Comments
 (0)