Skip to content

Commit 44eca59

Browse files
TheJokr authored and
evanrittenhouse committed
task-killswitch: Switch to dashmap instead of worker task
As an alternative to #2194, it is worth considering a different design for task-killswitch. With a concurrent hashmap like `dashmap`, each `spawn_with_killswitch()` call pays the cost of adding its task to the map itself. Task cleanup, meanwhile, becomes the responsibility of the task itself when it returns its result. Internally, dashmap is set up as a set of `4 * nproc` RwLock-ed regular hashmaps, and keys are sharded into those maps by their hash. This means lock collisions between different tasks being inserted or removed concurrently are quite unlikely, since they use different keys (provided the hash function is good). To get an idea of potential performance, I ran the benchmarks linked from dashmap's repo locally. With a 50/50 insert/remove workload using random u64 keys, dashmap had an average latency per operation of 100-200 ns and a throughput of tens of millions of ops/s. I think this is more than fine for the task-killswitch use case.
1 parent 2d9a6ac commit 44eca59

3 files changed

Lines changed: 94 additions & 82 deletions

File tree

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@ anyhow = { version = "1" }
3434
boring = { version = "4.3" }
3535
buffer-pool = { version = "0.1.0", path = "./buffer-pool" }
3636
crossbeam = { version = "0.8.1", default-features = false }
37+
dashmap = { version = "6" }
3738
datagram-socket = { version = "0.5.0", path = "./datagram-socket" }
3839
env_logger = "0.10"
3940
foundations = { version = ">=4,<6", default-features = false }

task-killswitch/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@ categories = { workspace = true }
99
description = "Abort all tokio tasks at once"
1010

1111
[dependencies]
12+
dashmap = { workspace = true }
1213
parking_lot = { workspace = true }
1314
tokio = { workspace = true, features = ["rt", "sync"] }
1415

task-killswitch/src/lib.rs

Lines changed: 92 additions & 82 deletions
Original file line number | Diff line number | Diff line change
@@ -24,99 +24,114 @@
2424
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
2525
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

27-
use tokio::sync::mpsc;
27+
use dashmap::DashMap;
28+
use parking_lot::Mutex;
2829
use tokio::sync::watch;
30+
use tokio::task;
2931
use tokio::task::AbortHandle;
3032

31-
use std::collections::HashMap;
3233
use std::future::Future;
33-
use std::sync::atomic::AtomicU64;
34+
use std::sync::atomic::AtomicBool;
3435
use std::sync::atomic::Ordering;
3536
use std::sync::LazyLock;
3637

37-
enum ActiveTaskOp {
38-
Add { id: u64, handle: AbortHandle },
39-
Remove { id: u64 },
40-
}
41-
4238
/// Drop guard for task removal. If a task panics, this makes sure
4339
/// it is removed from [`ActiveTasks`] properly.
4440
struct RemoveOnDrop {
45-
id: u64,
46-
task_tx_weak: mpsc::WeakUnboundedSender<ActiveTaskOp>,
41+
id: task::Id,
42+
storage: &'static ActiveTasks,
4743
}
4844
impl Drop for RemoveOnDrop {
4945
fn drop(&mut self) {
50-
if let Some(tx) = self.task_tx_weak.upgrade() {
51-
let _ = tx.send(ActiveTaskOp::Remove { id: self.id });
52-
}
46+
self.storage.remove_task(self.id);
5347
}
5448
}
5549

5650
/// A task killswitch that allows aborting all the tasks spawned with it at
57-
/// once. The implementation strives to not introduce any in-band locking, so
58-
/// spawning the future doesn't require acquiring a global lock, keeping the
59-
/// regular pace of operation.
51+
/// once. The implementation strives to minimize in-band locking. Spawning a
52+
/// future requires a single sharded lock from an internal [`DashMap`].
53+
/// Conflicts are expected to be very rare (dashmap defaults to `4 * nproc`
54+
/// shards, while each thread can only spawn one task at a time.)
6055
struct TaskKillswitch {
61-
// NOTE: use a lock without poisoning here to not panic all the threads if
62-
// one of the worker threads panic.
63-
task_tx: parking_lot::RwLock<Option<mpsc::UnboundedSender<ActiveTaskOp>>>,
64-
task_counter: AtomicU64,
56+
// Invariant: If `activated` is true, we don't add new tasks anymore.
57+
activated: AtomicBool,
58+
storage: &'static ActiveTasks,
59+
60+
/// Watcher that is triggered after all kill signals have been sent (by
61+
/// dropping `signal_killed`.) Currently-running tasks are killed after
62+
/// their next yield, which may be after this triggers.
6563
all_killed: watch::Receiver<()>,
64+
// NOTE: All we want here is to take ownership of `signal_killed` when
65+
// activating the killswitch. That code path only runs once per instance, but
66+
// requires interior mutability. Using `Mutex` is easier than bothering with
67+
// an `UnsafeCell`. The mutex is guaranteed to be unlocked.
68+
signal_killed: Mutex<Option<watch::Sender<()>>>,
6669
}
6770

6871
impl TaskKillswitch {
69-
fn new() -> Self {
70-
let (task_tx, task_rx) = mpsc::unbounded_channel();
72+
fn new(storage: &'static ActiveTasks) -> Self {
7173
let (signal_killed, all_killed) = watch::channel(());
72-
73-
let active_tasks = ActiveTasks {
74-
task_rx,
75-
tasks: Default::default(),
76-
signal_killed,
77-
};
78-
tokio::spawn(active_tasks.collect());
74+
let signal_killed = Mutex::new(Some(signal_killed));
7975

8076
Self {
81-
task_tx: parking_lot::RwLock::new(Some(task_tx)),
82-
task_counter: Default::default(),
77+
activated: AtomicBool::new(false),
78+
storage,
79+
signal_killed,
8380
all_killed,
8481
}
8582
}
8683

84+
/// Creates a killswitch by allocating and leaking the task storage.
85+
///
86+
/// **NOTE:** This is intended for use in `static`s and tests. It should not
87+
/// be exposed publicly!
88+
fn with_leaked_storage() -> Self {
89+
let storage = Box::leak(Box::new(ActiveTasks::default()));
90+
Self::new(storage)
91+
}
92+
93+
fn was_activated(&self) -> bool {
94+
// All synchronization is done using locks,
95+
// so we can use relaxed for our atomics.
96+
self.activated.load(Ordering::Relaxed)
97+
}
98+
8799
fn spawn_task(&self, fut: impl Future<Output = ()> + Send + 'static) {
88-
// NOTE: acquiring the lock here is very cheap, as unless the killswitch
89-
// is activated, this one is always unlocked and this is just a
90-
// few atomic operations.
91-
let Some(task_tx) = self.task_tx.read().as_ref().cloned() else {
100+
if self.was_activated() {
92101
return;
93-
};
94-
95-
let id = self.task_counter.fetch_add(1, Ordering::SeqCst);
96-
assert!(id < u64::MAX, "task-killswitch ID counter wrapped around!");
97-
let task_tx_weak = task_tx.downgrade();
102+
}
98103

104+
let storage = self.storage;
99105
let handle = tokio::spawn(async move {
100-
// NOTE: we use a weak sender inside the spawned task - dropping
101-
// all strong senders activates the killswitch. In that case,
102-
// we don't need to remove anything from ActiveTasks anymore.
103-
let _guard = RemoveOnDrop { task_tx_weak, id };
106+
let id = task::id();
107+
let _guard = RemoveOnDrop { id, storage };
104108
fut.await;
105109
})
106110
.abort_handle();
107111

108-
let _ = task_tx.send(ActiveTaskOp::Add { id, handle });
112+
let res = self.storage.add_task_if(handle, || !self.was_activated());
113+
if let Err(handle) = res {
114+
// Killswitch was activated by the time we got a lock on the map shard
115+
handle.abort();
116+
}
109117
}
110118

111119
fn activate(&self) {
112-
// take()ing the sender here drops it and thereby triggers the killswitch.
113-
// Concurrent spawn_task calls may still hold strong senders, which
114-
// ensures those tasks are added to ActiveTasks before the killing
115-
// starts.
120+
// We check `activated` after locking the map shard and before inserting
121+
// an element. This ensures in-progress spawns either complete before
122+
// `tasks.kill_all()` obtains the lock for that shard, or they abort
123+
// afterwards.
116124
assert!(
117-
self.task_tx.write().take().is_some(),
125+
!self.activated.swap(true, Ordering::Relaxed),
118126
"killswitch can't be used twice"
119127
);
128+
129+
let tasks = self.storage;
130+
let signal_killed = self.signal_killed.lock().take();
131+
std::thread::spawn(move || {
132+
tasks.kill_all();
133+
drop(signal_killed);
134+
});
120135
}
121136

122137
fn killed(&self) -> impl Future<Output = ()> + Send + 'static {
@@ -130,57 +145,52 @@ impl TaskKillswitch {
130145
enum TaskEntry {
131146
/// Task was added and not yet removed.
132147
Handle(AbortHandle),
133-
/// Task was removed before it was added. This can happen
134-
/// if a spawned future completes before the `Add` message is sent.
148+
/// Task was removed before it was added. This can happen if a spawned
149+
/// future completes before the spawning thread can add it to the map.
135150
Tombstone,
136151
}
137152

153+
#[derive(Default)]
138154
struct ActiveTasks {
139-
task_rx: mpsc::UnboundedReceiver<ActiveTaskOp>,
140-
tasks: HashMap<u64, TaskEntry>,
141-
signal_killed: watch::Sender<()>,
155+
tasks: DashMap<task::Id, TaskEntry>,
142156
}
143157

144158
impl ActiveTasks {
145-
async fn collect(mut self) {
146-
while let Some(op) = self.task_rx.recv().await {
147-
self.handle_task_op(op);
148-
}
149-
150-
for entry in self.tasks.into_values() {
159+
fn kill_all(&self) {
160+
self.tasks.retain(|_, entry| {
151161
if let TaskEntry::Handle(task) = entry {
152162
task.abort();
153163
}
154-
}
155-
drop(self.signal_killed);
164+
false // remove all elements
165+
});
156166
}
157167

158-
fn handle_task_op(&mut self, op: ActiveTaskOp) {
159-
match op {
160-
ActiveTaskOp::Add { id, handle } => self.add_task(id, handle),
161-
ActiveTaskOp::Remove { id } => self.remove_task(id),
162-
}
163-
}
168+
fn add_task_if(
169+
&self, handle: AbortHandle, cond: impl FnOnce() -> bool,
170+
) -> Result<(), AbortHandle> {
171+
use dashmap::Entry::*;
172+
let id = handle.id();
164173

165-
fn add_task(&mut self, id: u64, handle: AbortHandle) {
166-
use std::collections::hash_map::Entry::Occupied;
167174
match self.tasks.entry(id) {
175+
Vacant(e) => {
176+
if !cond() {
177+
return Err(handle);
178+
}
179+
e.insert(TaskEntry::Handle(handle));
180+
},
168181
Occupied(e) if matches!(e.get(), TaskEntry::Tombstone) => {
169182
// Task was removed before it was added. Clear the map entry and
170183
// drop the handle.
171184
e.remove();
172185
},
173-
e => {
174-
// We assert against duplicate IDs in `spawn_task`. Panicking here
175-
// wouldn't do anything besides stopping the killswitch loop, so
176-
// just overwrite the handle.
177-
e.insert_entry(TaskEntry::Handle(handle));
178-
},
186+
Occupied(_) => panic!("tokio task ID already in use: {id}"),
179187
}
188+
189+
Ok(())
180190
}
181191

182-
fn remove_task(&mut self, id: u64) {
183-
use std::collections::hash_map::Entry::*;
192+
fn remove_task(&self, id: task::Id) {
193+
use dashmap::Entry::*;
184194
match self.tasks.entry(id) {
185195
Vacant(e) => {
186196
// Task was not added yet, set a tombstone instead.
@@ -196,7 +206,7 @@ impl ActiveTasks {
196206

197207
/// The global [`TaskKillswitch`] exposed publicly from the crate.
198208
static TASK_KILLSWITCH: LazyLock<TaskKillswitch> =
199-
LazyLock::new(TaskKillswitch::new);
209+
LazyLock::new(TaskKillswitch::with_leaked_storage);
200210

201211
/// Spawns a new asynchronous task and registers it in the crate's global
202212
/// killswitch.
@@ -276,7 +286,7 @@ mod tests {
276286

277287
#[tokio::test]
278288
async fn activate_killswitch_early() {
279-
let killswitch = TaskKillswitch::new();
289+
let killswitch = TaskKillswitch::with_leaked_storage();
280290
let abort_signals = start_test_tasks(&killswitch);
281291

282292
killswitch.activate();
@@ -291,7 +301,7 @@ mod tests {
291301

292302
#[tokio::test]
293303
async fn activate_killswitch_with_delay() {
294-
let killswitch = TaskKillswitch::new();
304+
let killswitch = TaskKillswitch::with_leaked_storage();
295305
let abort_signals = start_test_tasks(&killswitch);
296306
let signal_handle = tokio::spawn(killswitch.killed());
297307

0 commit comments

Comments
 (0)