@@ -780,16 +780,19 @@ func (b *Batcher) PresignSweepsGroup(ctx context.Context, inputs []Input,
780780// times, but the sweeps (including the order of them) must be the same. If
781781// notifier is provided, the batcher sends back sweeping results through it.
782782func (b * Batcher ) AddSweep (ctx context.Context , sweepReq * SweepRequest ) error {
783- // If the batcher is shutting down, quit now.
784- select {
785- case <- b .quit :
786- return ErrBatcherShuttingDown
787-
788- default :
783+ // If the batcher or the caller is shutting down, quit now.
784+ err := b .addSweepExitErr (ctx )
785+ if err != nil {
786+ return err
789787 }
790788
791789 sweeps , err := b .fetchSweeps (ctx , * sweepReq )
792790 if err != nil {
791+ exitErr := b .addSweepExitErr (ctx )
792+ if exitErr != nil {
793+ return fmt .Errorf ("fetchSweeps failed: %w" , exitErr )
794+ }
795+
793796 return fmt .Errorf ("fetchSweeps failed: %w" , err )
794797 }
795798
@@ -803,6 +806,12 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
803806
804807 completed , err := b .store .GetSweepStatus (ctx , sweep .outpoint )
805808 if err != nil {
809+ exitErr := b .addSweepExitErr (ctx )
810+ if exitErr != nil {
811+ return fmt .Errorf ("failed to get the status of " +
812+ "sweep %v: %w" , sweep .outpoint , exitErr )
813+ }
814+
806815 return fmt .Errorf ("failed to get the status of sweep %v: %w" ,
807816 sweep .outpoint , err )
808817 }
@@ -816,6 +825,13 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
816825 // on-chain confirmations to prevent issues caused by reorgs.
817826 parentBatch , err = b .store .GetParentBatch (ctx , sweep .outpoint )
818827 if err != nil {
828+ exitErr := b .addSweepExitErr (ctx )
829+ if exitErr != nil {
830+ return fmt .Errorf ("unable to get parent " +
831+ "batch for sweep %x: %w" ,
832+ sweep .swapHash [:6 ], exitErr )
833+ }
834+
819835 return fmt .Errorf ("unable to get parent batch for " +
820836 "sweep %x: %w" , sweep .swapHash [:6 ], err )
821837 }
@@ -827,6 +843,12 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
827843
828844 minRelayFeeRate , err := b .wallet .MinRelayFee (ctx )
829845 if err != nil {
846+ exitErr := b .addSweepExitErr (ctx )
847+ if exitErr != nil {
848+ return fmt .Errorf ("failed to get min relay fee: %w" ,
849+ exitErr )
850+ }
851+
830852 return fmt .Errorf ("failed to get min relay fee: %w" , err )
831853 }
832854
@@ -861,9 +883,30 @@ func (b *Batcher) AddSweep(ctx context.Context, sweepReq *SweepRequest) error {
861883
862884 case <- b .quit :
863885 return ErrBatcherShuttingDown
886+
887+ case <- ctx .Done ():
888+ err := b .addSweepExitErr (ctx )
889+ if err != nil {
890+ return err
891+ }
892+
893+ return ctx .Err ()
864894 }
865895}
866896
897+ // addSweepExitErr returns the terminal error to use when AddSweep races with
898+ // shutdown or caller cancellation.
899+ func (b * Batcher ) addSweepExitErr (ctx context.Context ) error {
900+ select {
901+ case <- b .quit :
902+ return ErrBatcherShuttingDown
903+
904+ default :
905+ }
906+
907+ return ctx .Err ()
908+ }
909+
867910// testRunInEventLoop runs a function in the event loop blocking until
868911// the function returns. For unit tests only!
869912func (b * Batcher ) testRunInEventLoop (ctx context.Context , handler func ()) {
0 commit comments