2424// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
2525// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626
27- use tokio:: sync:: mpsc;
27+ use dashmap:: DashMap ;
28+ use parking_lot:: Mutex ;
2829use tokio:: sync:: watch;
30+ use tokio:: task;
2931use tokio:: task:: AbortHandle ;
3032
31- use std:: collections:: HashMap ;
3233use std:: future:: Future ;
33- use std:: sync:: atomic:: AtomicU64 ;
34+ use std:: sync:: atomic:: AtomicBool ;
3435use std:: sync:: atomic:: Ordering ;
3536use std:: sync:: LazyLock ;
3637
37- enum ActiveTaskOp {
38- Add { id : u64 , handle : AbortHandle } ,
39- Remove { id : u64 } ,
40- }
41-
4238/// Drop guard for task removal. If a task panics, this makes sure
4339/// it is removed from [`ActiveTasks`] properly.
4440struct RemoveOnDrop {
45- id : u64 ,
46- task_tx_weak : mpsc :: WeakUnboundedSender < ActiveTaskOp > ,
41+ id : task :: Id ,
42+ storage : & ' static ActiveTasks ,
4743}
4844impl Drop for RemoveOnDrop {
4945 fn drop ( & mut self ) {
50- if let Some ( tx) = self . task_tx_weak . upgrade ( ) {
51- let _ = tx. send ( ActiveTaskOp :: Remove { id : self . id } ) ;
52- }
46+ self . storage . remove_task ( self . id ) ;
5347 }
5448}
5549
5650/// A task killswitch that allows aborting all the tasks spawned with it at
57- /// once. The implementation strives to not introduce any in-band locking, so
58- /// spawning the future doesn't require acquiring a global lock, keeping the
59- /// regular pace of operation.
51+ /// once. The implementation strives to minimize in-band locking. Spawning a
52+ /// future requires a single sharded lock from an internal [`DashMap`].
53+ /// Conflicts are expected to be very rare (dashmap defaults to `4 * nproc`
54+ /// shards, while each thread can only spawn one task at a time.)
6055struct TaskKillswitch {
61- // NOTE: use a lock without poisoning here to not panic all the threads if
62- // one of the worker threads panic.
63- task_tx : parking_lot:: RwLock < Option < mpsc:: UnboundedSender < ActiveTaskOp > > > ,
64- task_counter : AtomicU64 ,
56+ // Invariant: If `activated` is true, we don't add new tasks anymore.
57+ activated : AtomicBool ,
58+ storage : & ' static ActiveTasks ,
59+
60+ /// Watcher that is triggered after all kill signals have been sent (by
61+ /// dropping `signal_killed`.) Currently-running tasks are killed after
62+ /// their next yield, which may be after this triggers.
6563 all_killed : watch:: Receiver < ( ) > ,
64+ // NOTE: All we want here is to take ownership of `signal_killed` when
65+ // activating the killswitch. That code path only runs once per instance, but
66+ // requires interior mutability. Using `Mutex` is easier than bothering with
67+ // an `UnsafeCell`. The mutex is guaranteed to be unlocked.
68+ signal_killed : Mutex < Option < watch:: Sender < ( ) > > > ,
6669}
6770
6871impl TaskKillswitch {
69- fn new ( ) -> Self {
70- let ( task_tx, task_rx) = mpsc:: unbounded_channel ( ) ;
72+ fn new ( storage : & ' static ActiveTasks ) -> Self {
7173 let ( signal_killed, all_killed) = watch:: channel ( ( ) ) ;
72-
73- let active_tasks = ActiveTasks {
74- task_rx,
75- tasks : Default :: default ( ) ,
76- signal_killed,
77- } ;
78- tokio:: spawn ( active_tasks. collect ( ) ) ;
74+ let signal_killed = Mutex :: new ( Some ( signal_killed) ) ;
7975
8076 Self {
81- task_tx : parking_lot:: RwLock :: new ( Some ( task_tx) ) ,
82- task_counter : Default :: default ( ) ,
77+ activated : AtomicBool :: new ( false ) ,
78+ storage,
79+ signal_killed,
8380 all_killed,
8481 }
8582 }
8683
84+ /// Creates a killswitch by allocating and leaking the task storage.
85+ ///
86+ /// **NOTE:** This is intended for use in `static`s and tests. It should not
87+ /// be exposed publicly!
88+ fn with_leaked_storage ( ) -> Self {
89+ let storage = Box :: leak ( Box :: new ( ActiveTasks :: default ( ) ) ) ;
90+ Self :: new ( storage)
91+ }
92+
93+ fn was_activated ( & self ) -> bool {
94+ // All synchronization is done using locks,
95+ // so we can use relaxed for our atomics.
96+ self . activated . load ( Ordering :: Relaxed )
97+ }
98+
8799 fn spawn_task ( & self , fut : impl Future < Output = ( ) > + Send + ' static ) {
88- // NOTE: acquiring the lock here is very cheap, as unless the killswitch
89- // is activated, this one is always unlocked and this is just a
90- // few atomic operations.
91- let Some ( task_tx) = self . task_tx . read ( ) . as_ref ( ) . cloned ( ) else {
100+ if self . was_activated ( ) {
92101 return ;
93- } ;
94-
95- let id = self . task_counter . fetch_add ( 1 , Ordering :: SeqCst ) ;
96- assert ! ( id < u64 :: MAX , "task-killswitch ID counter wrapped around!" ) ;
97- let task_tx_weak = task_tx. downgrade ( ) ;
102+ }
98103
104+ let storage = self . storage ;
99105 let handle = tokio:: spawn ( async move {
100- // NOTE: we use a weak sender inside the spawned task - dropping
101- // all strong senders activates the killswitch. In that case,
102- // we don't need to remove anything from ActiveTasks anymore.
103- let _guard = RemoveOnDrop { task_tx_weak, id } ;
106+ let id = task:: id ( ) ;
107+ let _guard = RemoveOnDrop { id, storage } ;
104108 fut. await ;
105109 } )
106110 . abort_handle ( ) ;
107111
108- let _ = task_tx. send ( ActiveTaskOp :: Add { id, handle } ) ;
112+ let res = self . storage . add_task_if ( handle, || !self . was_activated ( ) ) ;
113+ if let Err ( handle) = res {
114+ // Killswitch was activated by the time we got a lock on the map shard
115+ handle. abort ( ) ;
116+ }
109117 }
110118
111119 fn activate ( & self ) {
112- // take()ing the sender here drops it and thereby triggers the killswitch.
113- // Concurrent spawn_task calls may still hold strong senders, which
114- // ensures those tasks are added to ActiveTasks before the killing
115- // starts .
120+ // We check `activated` after locking the map shard and before inserting
121+ // an element. This ensures in-progress spawns either complete before
122+ // `tasks.kill_all()` obtains the lock for that shard, or they abort
123+ // afterwards .
116124 assert ! (
117- self . task_tx . write ( ) . take ( ) . is_some ( ) ,
125+ ! self . activated . swap ( true , Ordering :: Relaxed ) ,
118126 "killswitch can't be used twice"
119127 ) ;
128+
129+ let tasks = self . storage ;
130+ let signal_killed = self . signal_killed . lock ( ) . take ( ) ;
131+ std:: thread:: spawn ( move || {
132+ tasks. kill_all ( ) ;
133+ drop ( signal_killed) ;
134+ } ) ;
120135 }
121136
122137 fn killed ( & self ) -> impl Future < Output = ( ) > + Send + ' static {
@@ -130,57 +145,52 @@ impl TaskKillswitch {
130145enum TaskEntry {
131146 /// Task was added and not yet removed.
132147 Handle ( AbortHandle ) ,
133- /// Task was removed before it was added. This can happen
134- /// if a spawned future completes before the `Add` message is sent .
148+ /// Task was removed before it was added. This can happen if a spawned
149+ /// future completes before the spawning thread can add it to the map .
135150 Tombstone ,
136151}
137152
153+ #[ derive( Default ) ]
138154struct ActiveTasks {
139- task_rx : mpsc:: UnboundedReceiver < ActiveTaskOp > ,
140- tasks : HashMap < u64 , TaskEntry > ,
141- signal_killed : watch:: Sender < ( ) > ,
155+ tasks : DashMap < task:: Id , TaskEntry > ,
142156}
143157
144158impl ActiveTasks {
145- async fn collect ( mut self ) {
146- while let Some ( op) = self . task_rx . recv ( ) . await {
147- self . handle_task_op ( op) ;
148- }
149-
150- for entry in self . tasks . into_values ( ) {
159+ fn kill_all ( & self ) {
160+ self . tasks . retain ( |_, entry| {
151161 if let TaskEntry :: Handle ( task) = entry {
152162 task. abort ( ) ;
153163 }
154- }
155- drop ( self . signal_killed ) ;
164+ false // remove all elements
165+ } ) ;
156166 }
157167
158- fn handle_task_op ( & mut self , op : ActiveTaskOp ) {
159- match op {
160- ActiveTaskOp :: Add { id, handle } => self . add_task ( id, handle) ,
161- ActiveTaskOp :: Remove { id } => self . remove_task ( id) ,
162- }
163- }
168+ fn add_task_if (
169+ & self , handle : AbortHandle , cond : impl FnOnce ( ) -> bool ,
170+ ) -> Result < ( ) , AbortHandle > {
171+ use dashmap:: Entry :: * ;
172+ let id = handle. id ( ) ;
164173
165- fn add_task ( & mut self , id : u64 , handle : AbortHandle ) {
166- use std:: collections:: hash_map:: Entry :: Occupied ;
167174 match self . tasks . entry ( id) {
175+ Vacant ( e) => {
176+ if !cond ( ) {
177+ return Err ( handle) ;
178+ }
179+ e. insert ( TaskEntry :: Handle ( handle) ) ;
180+ } ,
168181 Occupied ( e) if matches ! ( e. get( ) , TaskEntry :: Tombstone ) => {
169182 // Task was removed before it was added. Clear the map entry and
170183 // drop the handle.
171184 e. remove ( ) ;
172185 } ,
173- e => {
174- // We assert against duplicate IDs in `spawn_task`. Panicing here
175- // wouldn't do anything besides stopping the killswitch loop, so
176- // just overwrite the handle.
177- e. insert_entry ( TaskEntry :: Handle ( handle) ) ;
178- } ,
186+ Occupied ( _) => panic ! ( "tokio task ID already in use: {id}" ) ,
179187 }
188+
189+ Ok ( ( ) )
180190 }
181191
182- fn remove_task ( & mut self , id : u64 ) {
183- use std :: collections :: hash_map :: Entry :: * ;
192+ fn remove_task ( & self , id : task :: Id ) {
193+ use dashmap :: Entry :: * ;
184194 match self . tasks . entry ( id) {
185195 Vacant ( e) => {
186196 // Task was not added yet, set a tombstone instead.
@@ -196,7 +206,7 @@ impl ActiveTasks {
196206
197207/// The global [`TaskKillswitch`] exposed publicly from the crate.
198208static TASK_KILLSWITCH : LazyLock < TaskKillswitch > =
199- LazyLock :: new ( TaskKillswitch :: new ) ;
209+ LazyLock :: new ( TaskKillswitch :: with_leaked_storage ) ;
200210
201211/// Spawns a new asynchronous task and registers it in the crate's global
202212/// killswitch.
@@ -276,7 +286,7 @@ mod tests {
276286
277287 #[ tokio:: test]
278288 async fn activate_killswitch_early ( ) {
279- let killswitch = TaskKillswitch :: new ( ) ;
289+ let killswitch = TaskKillswitch :: with_leaked_storage ( ) ;
280290 let abort_signals = start_test_tasks ( & killswitch) ;
281291
282292 killswitch. activate ( ) ;
@@ -291,7 +301,7 @@ mod tests {
291301
292302 #[ tokio:: test]
293303 async fn activate_killswitch_with_delay ( ) {
294- let killswitch = TaskKillswitch :: new ( ) ;
304+ let killswitch = TaskKillswitch :: with_leaked_storage ( ) ;
295305 let abort_signals = start_test_tasks ( & killswitch) ;
296306 let signal_handle = tokio:: spawn ( killswitch. killed ( ) ) ;
297307
0 commit comments