@@ -780,16 +780,18 @@ func (b *Batcher) PresignSweepsGroup(ctx context.Context, inputs []Input,
780780// times, but the sweeps (including the order of them) must be the same. If
781781// notifier is provided, the batcher sends back sweeping results through it.
782782func (b * Batcher ) AddSweep (ctx context.Context , sweepReq * SweepRequest ) error {
783- // If the batcher is shutting down, quit now.
784- select {
785- case <- b .quit :
786- return ErrBatcherShuttingDown
787-
788- default :
783+ // If the batcher or the caller is shutting down, quit now.
784+ err := b .addSweepExitErrIfAny (ctx )
785+ if err != nil {
786+ return err
789787 }
790788
791789 sweeps , err := b .fetchSweeps (ctx , * sweepReq )
792790 if err != nil {
791+ if exitErr := b .addSweepExitErrIfAny (ctx ); exitErr != nil {
792+ return exitErr
793+ }
794+
793795 return fmt .Errorf ("fetchSweeps failed: %w" , err )
794796 }
795797
@@ -803,6 +805,10 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
803805
804806 completed , err := b .store .GetSweepStatus (ctx , sweep .outpoint )
805807 if err != nil {
808+ if exitErr := b .addSweepExitErrIfAny (ctx ); exitErr != nil {
809+ return exitErr
810+ }
811+
806812 return fmt .Errorf ("failed to get the status of sweep %v: %w" ,
807813 sweep .outpoint , err )
808814 }
@@ -816,6 +822,10 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
816822 // on-chain confirmations to prevent issues caused by reorgs.
817823 parentBatch , err = b .store .GetParentBatch (ctx , sweep .outpoint )
818824 if err != nil {
825+ if exitErr := b .addSweepExitErrIfAny (ctx ); exitErr != nil {
826+ return exitErr
827+ }
828+
819829 return fmt .Errorf ("unable to get parent batch for " +
820830 "sweep %x: %w" , sweep .swapHash [:6 ], err )
821831 }
@@ -827,6 +837,10 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
827837
828838 minRelayFeeRate , err := b .wallet .MinRelayFee (ctx )
829839 if err != nil {
840+ if exitErr := b .addSweepExitErrIfAny (ctx ); exitErr != nil {
841+ return exitErr
842+ }
843+
830844 return fmt .Errorf ("failed to get min relay fee: %w" , err )
831845 }
832846
@@ -839,6 +853,10 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
839853 b .chainParams ,
840854 )
841855 if err != nil {
856+ if exitErr := b .addSweepExitErrIfAny (ctx ); exitErr != nil {
857+ return exitErr
858+ }
859+
842860 return fmt .Errorf ("inputs with primarySweep %v were " +
843861 "not presigned (call PresignSweepsGroup " +
844862 "first): %w" , sweep .outpoint , err )
@@ -861,7 +879,24 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
861879
862880 case <- b .quit :
863881 return ErrBatcherShuttingDown
882+
883+ case <- ctx .Done ():
884+ return b .addSweepExitErrIfAny (ctx )
885+ }
886+ }
887+
888+ // addSweepExitErrIfAny returns the terminal error to use when AddSweep races
889+ // with shutdown or caller cancellation. It returns nil if AddSweep should
890+ // continue.
891+ func (b * Batcher ) addSweepExitErrIfAny (ctx context.Context ) error {
892+ select {
893+ case <- b .quit :
894+ return ErrBatcherShuttingDown
895+
896+ default :
864897 }
898+
899+ return ctx .Err ()
865900}
866901
867902// testRunInEventLoop runs a function in the event loop blocking until
0 commit comments