Skip to content

Commit 8db097b

Browse files
committed
sweepbatcher: harden AddSweep against ctx closure
1 parent 787519e commit 8db097b

2 files changed

Lines changed: 91 additions & 6 deletions

File tree

sweepbatcher/sweep_batcher.go

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -780,16 +780,19 @@ func (b *Batcher) PresignSweepsGroup(ctx context.Context, inputs []Input,
780780
// times, but the sweeps (including the order of them) must be the same. If
781781
// notifier is provided, the batcher sends back sweeping results through it.
782782
func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
783-
// If the batcher is shutting down, quit now.
784-
select {
785-
case <-b.quit:
786-
return ErrBatcherShuttingDown
787-
788-
default:
783+
// If the batcher or the caller is shutting down, quit now.
784+
err := b.addSweepExitErr(ctx)
785+
if err != nil {
786+
return err
789787
}
790788

791789
sweeps, err := b.fetchSweeps(ctx, *sweepReq)
792790
if err != nil {
791+
exitErr := b.addSweepExitErr(ctx)
792+
if exitErr != nil {
793+
return fmt.Errorf("fetchSweeps failed: %w", exitErr)
794+
}
795+
793796
return fmt.Errorf("fetchSweeps failed: %w", err)
794797
}
795798

@@ -803,6 +806,12 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
803806

804807
completed, err := b.store.GetSweepStatus(ctx, sweep.outpoint)
805808
if err != nil {
809+
exitErr := b.addSweepExitErr(ctx)
810+
if exitErr != nil {
811+
return fmt.Errorf("failed to get the status of "+
812+
"sweep %v: %w", sweep.outpoint, exitErr)
813+
}
814+
806815
return fmt.Errorf("failed to get the status of sweep %v: %w",
807816
sweep.outpoint, err)
808817
}
@@ -816,6 +825,13 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
816825
// on-chain confirmations to prevent issues caused by reorgs.
817826
parentBatch, err = b.store.GetParentBatch(ctx, sweep.outpoint)
818827
if err != nil {
828+
exitErr := b.addSweepExitErr(ctx)
829+
if exitErr != nil {
830+
return fmt.Errorf("unable to get parent "+
831+
"batch for sweep %x: %w",
832+
sweep.swapHash[:6], exitErr)
833+
}
834+
819835
return fmt.Errorf("unable to get parent batch for "+
820836
"sweep %x: %w", sweep.swapHash[:6], err)
821837
}
@@ -827,6 +843,12 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
827843

828844
minRelayFeeRate, err := b.wallet.MinRelayFee(ctx)
829845
if err != nil {
846+
exitErr := b.addSweepExitErr(ctx)
847+
if exitErr != nil {
848+
return fmt.Errorf("failed to get min relay fee: %w",
849+
exitErr)
850+
}
851+
830852
return fmt.Errorf("failed to get min relay fee: %w", err)
831853
}
832854

@@ -861,9 +883,30 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
861883

862884
case <-b.quit:
863885
return ErrBatcherShuttingDown
886+
887+
case <-ctx.Done():
888+
err := b.addSweepExitErr(ctx)
889+
if err != nil {
890+
return err
891+
}
892+
893+
return ctx.Err()
864894
}
865895
}
866896

897+
// addSweepExitErr returns the terminal error to use when AddSweep races with
898+
// shutdown or caller cancellation.
899+
func (b *Batcher) addSweepExitErr(ctx context.Context) error {
900+
select {
901+
case <-b.quit:
902+
return ErrBatcherShuttingDown
903+
904+
default:
905+
}
906+
907+
return ctx.Err()
908+
}
909+
867910
// testRunInEventLoop runs a function in the event loop blocking until
868911
// the function returns. For unit tests only!
869912
func (b *Batcher) testRunInEventLoop(ctx context.Context, handler func()) {

sweepbatcher/sweep_batcher_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package sweepbatcher
33
import (
44
"context"
55
"database/sql"
6+
"database/sql/driver"
67
"errors"
78
"fmt"
89
"maps"
@@ -3711,6 +3712,47 @@ func (f *sweepFetcherMock) FetchSweep(ctx context.Context, _ lntypes.Hash,
37113712
return f.store[outpoint], nil
37123713
}
37133714

3715+
type cancelingSweepFetcher struct {
3716+
cancel context.CancelFunc
3717+
}
3718+
3719+
func (f *cancelingSweepFetcher) FetchSweep(context.Context, lntypes.Hash,
3720+
wire.OutPoint) (*SweepInfo, error) {
3721+
3722+
f.cancel()
3723+
3724+
return nil, driver.ErrBadConn
3725+
}
3726+
3727+
// TestAddSweepReturnsContextErrorOnFetchCancellation asserts that AddSweep
3728+
// returns the context cancellation error if sweep fetching fails while the
3729+
// caller context is being canceled.
3730+
func TestAddSweepReturnsContextErrorOnFetchCancellation(t *testing.T) {
3731+
defer test.Guard(t)()
3732+
3733+
lnd := test.NewMockLnd()
3734+
ctx, cancel := context.WithCancel(context.Background())
3735+
3736+
batcher := NewBatcher(
3737+
lnd.WalletKit, lnd.ChainNotifier, lnd.Signer,
3738+
testMuSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams,
3739+
NewStoreMock(), &cancelingSweepFetcher{cancel: cancel},
3740+
)
3741+
3742+
err := batcher.AddSweep(ctx, &SweepRequest{
3743+
SwapHash: lntypes.Hash{1, 1, 1},
3744+
Inputs: []Input{{
3745+
Value: 1111,
3746+
Outpoint: wire.OutPoint{
3747+
Hash: chainhash.Hash{1, 1},
3748+
Index: 1,
3749+
},
3750+
}},
3751+
})
3752+
require.ErrorIs(t, err, context.Canceled)
3753+
require.NotErrorIs(t, err, driver.ErrBadConn)
3754+
}
3755+
37143756
// testSweepFetcher tests providing custom sweep fetcher to Batcher.
37153757
func testSweepFetcher(t *testing.T, store testStore,
37163758
batcherStore testBatcherStore) {

0 commit comments

Comments
 (0)