Skip to content

Commit ea03440

Browse files
authored
Merge pull request #1130 from hieblmi/fix-racae
sweepbatcher: harden AddSweep against ctx closure
2 parents 0102120 + 1b25d8a commit ea03440

2 files changed

Lines changed: 469 additions & 23 deletions

File tree

sweepbatcher/sweep_batcher.go

Lines changed: 124 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -675,12 +675,26 @@ func (b *Batcher) Run(ctx context.Context) error {
675675
// the database. We will then resume the execution of these batches.
676676
batches, err := b.FetchUnconfirmedBatches(runCtx)
677677
if err != nil {
678+
if ctxErr := runCtx.Err(); ctxErr != nil {
679+
infof("FetchUnconfirmedBatches failed during shutdown, "+
680+
"returning %v instead of %v.", ctxErr, err)
681+
682+
return ctxErr
683+
}
684+
678685
return err
679686
}
680687

681688
for _, batch := range batches {
682689
err := b.spinUpBatchFromDB(runCtx, batch)
683690
if err != nil {
691+
if ctxErr := runCtx.Err(); ctxErr != nil {
692+
infof("spinUpBatchFromDB failed during shutdown, "+
693+
"returning %v instead of %v.", ctxErr, err)
694+
695+
return ctxErr
696+
}
697+
684698
return err
685699
}
686700
}
@@ -695,6 +709,14 @@ func (b *Batcher) Run(ctx context.Context) error {
695709
runCtx, req.sweeps, req.notifier, req.fast,
696710
)
697711
if err != nil {
712+
if ctxErr := runCtx.Err(); ctxErr != nil {
713+
infof("handleSweeps failed during shutdown, "+
714+
"returning %v instead of %v.",
715+
ctxErr, err)
716+
717+
return ctxErr
718+
}
719+
698720
warnf("handleSweeps failed: %v.", err)
699721

700722
return err
@@ -705,6 +727,14 @@ func (b *Batcher) Run(ctx context.Context) error {
705727
close(testReq.quit)
706728

707729
case err := <-b.errChan:
730+
if ctxErr := runCtx.Err(); ctxErr != nil {
731+
infof("Batcher received an error during shutdown, "+
732+
"returning %v instead of %v.",
733+
ctxErr, err)
734+
735+
return ctxErr
736+
}
737+
708738
warnf("Batcher received an error: %v.", err)
709739

710740
return err
@@ -734,13 +764,33 @@ func (b *Batcher) PresignSweepsGroup(ctx context.Context, inputs []Input,
734764
return fmt.Errorf("presignedHelper is not installed")
735765
}
736766

767+
if err := b.shutdownOrCancelErrIfAny(ctx); err != nil {
768+
return err
769+
}
770+
737771
// Find the feerate needed to get into next block. Use conf_target=2,
738772
nextBlockFeeRate, err := b.wallet.EstimateFeeRate(ctx, 2)
739773
if err != nil {
774+
if exitErr := b.shutdownOrCancelErrIfAny(ctx); exitErr != nil {
775+
infof("PresignSweepsGroup EstimateFeeRate failed "+
776+
"during shutdown, returning %v instead of %v.",
777+
exitErr, err)
778+
779+
return exitErr
780+
}
781+
740782
return fmt.Errorf("failed to get nextBlockFeeRate: %w", err)
741783
}
742784
minRelayFeeRate, err := b.wallet.MinRelayFee(ctx)
743785
if err != nil {
786+
if exitErr := b.shutdownOrCancelErrIfAny(ctx); exitErr != nil {
787+
infof("PresignSweepsGroup MinRelayFee failed during "+
788+
"shutdown, returning %v instead of %v.",
789+
exitErr, err)
790+
791+
return exitErr
792+
}
793+
744794
return fmt.Errorf("failed to get minRelayFeeRate: %w", err)
745795
}
746796
destPkscript, err := txscript.PayToAddrScript(destAddress)
@@ -768,10 +818,23 @@ func (b *Batcher) PresignSweepsGroup(ctx context.Context, inputs []Input,
768818
// outpoint in the batch.
769819
primarySweepID := sweeps[0].outpoint
770820

771-
return presign(
821+
err = presign(
772822
ctx, b.presignedHelper, destAddress, primarySweepID, sweeps,
773823
nextBlockFeeRate, minRelayFeeRate,
774824
)
825+
if err != nil {
826+
if exitErr := b.shutdownOrCancelErrIfAny(ctx); exitErr != nil {
827+
infof("PresignSweepsGroup presign failed during "+
828+
"shutdown, returning %v instead of %v.",
829+
exitErr, err)
830+
831+
return exitErr
832+
}
833+
834+
return err
835+
}
836+
837+
return nil
775838
}
776839

777840
// AddSweep loads information about sweeps from the store and fee rate source,
@@ -780,16 +843,21 @@ func (b *Batcher) PresignSweepsGroup(ctx context.Context, inputs []Input,
780843
// times, but the sweeps (including the order of them) must be the same. If
781844
// notifier is provided, the batcher sends back sweeping results through it.
782845
func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
783-
// If the batcher is shutting down, quit now.
784-
select {
785-
case <-b.quit:
786-
return ErrBatcherShuttingDown
787-
788-
default:
846+
// If the batcher or the caller is shutting down, quit now.
847+
err := b.shutdownOrCancelErrIfAny(ctx)
848+
if err != nil {
849+
return err
789850
}
790851

791852
sweeps, err := b.fetchSweeps(ctx, *sweepReq)
792853
if err != nil {
854+
if exitErr := b.shutdownOrCancelErrIfAny(ctx); exitErr != nil {
855+
infof("fetchSweeps failed during shutdown, returning "+
856+
"%v instead of %v.", exitErr, err)
857+
858+
return exitErr
859+
}
860+
793861
return fmt.Errorf("fetchSweeps failed: %w", err)
794862
}
795863

@@ -803,6 +871,14 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
803871

804872
completed, err := b.store.GetSweepStatus(ctx, sweep.outpoint)
805873
if err != nil {
874+
if exitErr := b.shutdownOrCancelErrIfAny(ctx); exitErr != nil {
875+
infof("GetSweepStatus failed for sweep %v during "+
876+
"shutdown, returning %v instead of %v.",
877+
sweep.outpoint, exitErr, err)
878+
879+
return exitErr
880+
}
881+
806882
return fmt.Errorf("failed to get the status of sweep %v: %w",
807883
sweep.outpoint, err)
808884
}
@@ -816,6 +892,14 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
816892
// on-chain confirmations to prevent issues caused by reorgs.
817893
parentBatch, err = b.store.GetParentBatch(ctx, sweep.outpoint)
818894
if err != nil {
895+
if exitErr := b.shutdownOrCancelErrIfAny(ctx); exitErr != nil {
896+
infof("GetParentBatch failed for sweep %v "+
897+
"during shutdown, returning %v instead "+
898+
"of %v.", sweep.outpoint, exitErr, err)
899+
900+
return exitErr
901+
}
902+
819903
return fmt.Errorf("unable to get parent batch for "+
820904
"sweep %x: %w", sweep.swapHash[:6], err)
821905
}
@@ -827,6 +911,13 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
827911

828912
minRelayFeeRate, err := b.wallet.MinRelayFee(ctx)
829913
if err != nil {
914+
if exitErr := b.shutdownOrCancelErrIfAny(ctx); exitErr != nil {
915+
infof("MinRelayFee failed during shutdown, returning "+
916+
"%v instead of %v.", exitErr, err)
917+
918+
return exitErr
919+
}
920+
830921
return fmt.Errorf("failed to get min relay fee: %w", err)
831922
}
832923

@@ -839,6 +930,15 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
839930
b.chainParams,
840931
)
841932
if err != nil {
933+
if exitErr := b.shutdownOrCancelErrIfAny(ctx); exitErr != nil {
934+
infof("ensurePresigned failed for primary "+
935+
"sweep %v during shutdown, returning %v "+
936+
"instead of %v.", sweep.outpoint,
937+
exitErr, err)
938+
939+
return exitErr
940+
}
941+
842942
return fmt.Errorf("inputs with primarySweep %v were "+
843943
"not presigned (call PresignSweepsGroup "+
844944
"first): %w", sweep.outpoint, err)
@@ -861,9 +961,26 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
861961

862962
case <-b.quit:
863963
return ErrBatcherShuttingDown
964+
965+
case <-ctx.Done():
966+
return b.shutdownOrCancelErrIfAny(ctx)
864967
}
865968
}
866969

970+
// shutdownOrCancelErrIfAny returns the terminal error to use when caller-facing
971+
// batcher methods race with shutdown or caller cancellation. It returns nil if
972+
// the operation should continue.
973+
func (b *Batcher) shutdownOrCancelErrIfAny(ctx context.Context) error {
974+
select {
975+
case <-b.quit:
976+
return ErrBatcherShuttingDown
977+
978+
default:
979+
}
980+
981+
return ctx.Err()
982+
}
983+
867984
// testRunInEventLoop runs a function in the event loop blocking until
868985
// the function returns. For unit tests only!
869986
func (b *Batcher) testRunInEventLoop(ctx context.Context, handler func()) {

0 commit comments

Comments
 (0)