Skip to content

Commit 7df652a

Browse files
committed
sweepbatcher: harden AddSweep against ctx closure
1 parent 787519e commit 7df652a

2 files changed

Lines changed: 95 additions & 6 deletions

File tree

sweepbatcher/sweep_batcher.go

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -780,16 +780,19 @@ func (b *Batcher) PresignSweepsGroup(ctx context.Context, inputs []Input,
780780
// times, but the sweeps (including the order of them) must be the same. If
781781
// notifier is provided, the batcher sends back sweeping results through it.
782782
func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
783-
// If the batcher is shutting down, quit now.
784-
select {
785-
case <-b.quit:
786-
return ErrBatcherShuttingDown
787-
788-
default:
783+
// If the batcher or the caller is shutting down, quit now.
784+
err := b.addSweepExitErr(ctx)
785+
if err != nil {
786+
return err
789787
}
790788

791789
sweeps, err := b.fetchSweeps(ctx, *sweepReq)
792790
if err != nil {
791+
exitErr := b.addSweepExitErr(ctx)
792+
if exitErr != nil {
793+
err = exitErr
794+
}
795+
793796
return fmt.Errorf("fetchSweeps failed: %w", err)
794797
}
795798

@@ -803,6 +806,11 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
803806

804807
completed, err := b.store.GetSweepStatus(ctx, sweep.outpoint)
805808
if err != nil {
809+
exitErr := b.addSweepExitErr(ctx)
810+
if exitErr != nil {
811+
err = exitErr
812+
}
813+
806814
return fmt.Errorf("failed to get the status of sweep %v: %w",
807815
sweep.outpoint, err)
808816
}
@@ -816,6 +824,11 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
816824
// on-chain confirmations to prevent issues caused by reorgs.
817825
parentBatch, err = b.store.GetParentBatch(ctx, sweep.outpoint)
818826
if err != nil {
827+
exitErr := b.addSweepExitErr(ctx)
828+
if exitErr != nil {
829+
err = exitErr
830+
}
831+
819832
return fmt.Errorf("unable to get parent batch for "+
820833
"sweep %x: %w", sweep.swapHash[:6], err)
821834
}
@@ -827,6 +840,11 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
827840

828841
minRelayFeeRate, err := b.wallet.MinRelayFee(ctx)
829842
if err != nil {
843+
exitErr := b.addSweepExitErr(ctx)
844+
if exitErr != nil {
845+
err = exitErr
846+
}
847+
830848
return fmt.Errorf("failed to get min relay fee: %w", err)
831849
}
832850

@@ -839,6 +857,11 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
839857
b.chainParams,
840858
)
841859
if err != nil {
860+
exitErr := b.addSweepExitErr(ctx)
861+
if exitErr != nil {
862+
err = exitErr
863+
}
864+
842865
return fmt.Errorf("inputs with primarySweep %v were "+
843866
"not presigned (call PresignSweepsGroup "+
844867
"first): %w", sweep.outpoint, err)
@@ -861,7 +884,23 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
861884

862885
case <-b.quit:
863886
return ErrBatcherShuttingDown
887+
888+
case <-ctx.Done():
889+
return b.addSweepExitErr(ctx)
890+
}
891+
}
892+
893+
// addSweepExitErr returns the terminal error to use when AddSweep races with
894+
// shutdown or caller cancellation.
895+
func (b *Batcher) addSweepExitErr(ctx context.Context) error {
896+
select {
897+
case <-b.quit:
898+
return ErrBatcherShuttingDown
899+
900+
default:
864901
}
902+
903+
return ctx.Err()
865904
}
866905

867906
// testRunInEventLoop runs a function in the event loop blocking until

sweepbatcher/sweep_batcher_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package sweepbatcher
33
import (
44
"context"
55
"database/sql"
6+
"database/sql/driver"
67
"errors"
78
"fmt"
89
"maps"
@@ -3711,6 +3712,55 @@ func (f *sweepFetcherMock) FetchSweep(ctx context.Context, _ lntypes.Hash,
37113712
return f.store[outpoint], nil
37123713
}
37133714

3715+
type cancelingSweepFetcher struct {
3716+
cancel context.CancelFunc
3717+
}
3718+
3719+
func (f *cancelingSweepFetcher) FetchSweep(context.Context, lntypes.Hash,
3720+
wire.OutPoint) (*SweepInfo, error) {
3721+
3722+
// Simulate the caller canceling while the backend returns a
3723+
// driver-level error.
3724+
f.cancel()
3725+
3726+
return nil, driver.ErrBadConn
3727+
}
3728+
3729+
func testAddSweepReturnsContextErrorOnFetchCancellation(t *testing.T,
3730+
_ testStore, batcherStore testBatcherStore) {
3731+
3732+
defer test.Guard(t)()
3733+
3734+
lnd := test.NewMockLnd()
3735+
ctx, cancel := context.WithCancel(context.Background())
3736+
3737+
batcher := NewBatcher(
3738+
lnd.WalletKit, lnd.ChainNotifier, lnd.Signer,
3739+
testMuSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams,
3740+
batcherStore, &cancelingSweepFetcher{cancel: cancel},
3741+
)
3742+
3743+
err := batcher.AddSweep(ctx, &SweepRequest{
3744+
SwapHash: lntypes.Hash{1, 1, 1},
3745+
Inputs: []Input{{
3746+
Value: 1111,
3747+
Outpoint: wire.OutPoint{
3748+
Hash: chainhash.Hash{1, 1},
3749+
Index: 1,
3750+
},
3751+
}},
3752+
})
3753+
require.ErrorIs(t, err, context.Canceled)
3754+
require.NotErrorIs(t, err, driver.ErrBadConn)
3755+
}
3756+
3757+
// TestAddSweepReturnsContextErrorOnFetchCancellation asserts that AddSweep
3758+
// returns the context cancellation error if sweep fetching fails while the
3759+
// caller context is being canceled.
3760+
func TestAddSweepReturnsContextErrorOnFetchCancellation(t *testing.T) {
3761+
runTests(t, testAddSweepReturnsContextErrorOnFetchCancellation)
3762+
}
3763+
37143764
// testSweepFetcher tests providing custom sweep fetcher to Batcher.
37153765
func testSweepFetcher(t *testing.T, store testStore,
37163766
batcherStore testBatcherStore) {

0 commit comments

Comments
 (0)