diff --git a/src/lib.rs b/src/lib.rs index c1318de..5661942 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,10 +2,17 @@ use std::{ future::Future, - sync::{Arc, Condvar, Mutex}, + pin::Pin, + sync::Arc, task::{Context, Poll, Wake, Waker}, + thread, }; +thread_local! { + // A local reusable signal for each thread. + static LOCAL_THREAD_SIGNAL: Arc = Arc::new(Signal::new()); +} + #[cfg(feature = "macro")] pub use pollster_macro::{main, test}; @@ -22,79 +29,32 @@ pub trait FutureExt: Future { /// /// let result = my_fut.block_on(); /// ``` - fn block_on(self) -> Self::Output where Self: Sized { block_on(self) } + fn block_on(self) -> Self::Output + where + Self: Sized, + { + block_on(self) + } } impl FutureExt for F {} -enum SignalState { - Empty, - Waiting, - Notified, -} - struct Signal { - state: Mutex, - cond: Condvar, + /// The thread that owns the signal. + owning_thread: thread::Thread, } impl Signal { fn new() -> Self { Self { - state: Mutex::new(SignalState::Empty), - cond: Condvar::new(), - } - } - - fn wait(&self) { - let mut state = self.state.lock().unwrap(); - match *state { - SignalState::Notified => { - // Notify() was called before we got here, consume it here without waiting and return immediately. - *state = SignalState::Empty; - return; - } - // This should not be possible because our signal is created within a function and never handed out to any - // other threads. If this is the case, we have a serious problem so we panic immediately to avoid anything - // more problematic happening. - SignalState::Waiting => { - unreachable!("Multiple threads waiting on the same signal: Open a bug report!"); - } - SignalState::Empty => { - // Nothing has happened yet, and we're the only thread waiting (as should be the case!). Set the state - // accordingly and begin polling the condvar in a loop until it's no longer telling us to wait. The - // loop prevents incorrect spurious wakeups. - *state = SignalState::Waiting; - while let SignalState::Waiting = *state { - state = self.cond.wait(state).unwrap(); - } - } - } - } - - fn notify(&self) { - let mut state = self.state.lock().unwrap(); - match *state { - // The signal was already notified, no need to do anything because the thread will be waking up anyway - SignalState::Notified => {} - // The signal wasn't notified but a thread isn't waiting on it, so we can avoid doing unnecessary work by - // skipping the condvar and leaving behind a message telling the thread that a notification has already - // occurred should it come along in the future. - SignalState::Empty => *state = SignalState::Notified, - // The signal wasn't notified and there's a waiting thread. Reset the signal so it can be wait()'ed on again - // and wake up the thread. Because there should only be a single thread waiting, `notify_all` would also be - // valid. - SignalState::Waiting => { - *state = SignalState::Empty; - self.cond.notify_one(); - } + owning_thread: thread::current(), } } } impl Wake for Signal { fn wake(self: Arc) { - self.notify(); + self.owning_thread.unpark(); } } @@ -111,23 +71,20 @@ pub fn block_on(mut fut: F) -> F::Output { // SAFETY: We shadow `fut` so that it cannot be used again. The future is now pinned to the stack and will not be // moved until the end of this scope. This is, incidentally, exactly what the `pin_mut!` macro from `pin_utils` // does. - let mut fut = unsafe { std::pin::Pin::new_unchecked(&mut fut) }; - - // Signal used to wake up the thread for polling as the future moves to completion. We need to use an `Arc` - // because, although the lifetime of `fut` is limited to this function, the underlying IO abstraction might keep - // the signal alive for far longer. `Arc` is a thread-safe way to allow this to happen. - // TODO: Investigate ways to reuse this `Arc`... perhaps via a `static`? - let signal = Arc::new(Signal::new()); + let mut fut = unsafe { Pin::new_unchecked(&mut fut) }; - // Create a context that will be passed to the future. - let waker = Waker::from(Arc::clone(&signal)); - let mut context = Context::from_waker(&waker); + // A signal used to wake up the thread for polling as the future moves to completion. + LOCAL_THREAD_SIGNAL.with(|signal| { + // Create a waker and a context to be passed to the future. + let waker = Waker::from(Arc::clone(signal)); + let mut context = Context::from_waker(&waker); - // Poll the future to completion - loop { - match fut.as_mut().poll(&mut context) { - Poll::Pending => signal.wait(), - Poll::Ready(item) => break item, + // Poll the future to completion. + loop { + match fut.as_mut().poll(&mut context) { + Poll::Pending => thread::park(), + Poll::Ready(item) => break item, + } } - } + }) }