diff --git a/protocols/kad/Cargo.toml b/protocols/kad/Cargo.toml index a5eb4ac2f1a..b6c06a4ce2b 100644 --- a/protocols/kad/Cargo.toml +++ b/protocols/kad/Cargo.toml @@ -33,7 +33,7 @@ thiserror = { workspace = true } tracing = { workspace = true } [dev-dependencies] -tokio = { workspace = true, features = ["macros", "rt-multi-thread", "time"] } +tokio = { workspace = true, features = ["macros", "rt-multi-thread", "time", "test-util"] } futures-timer = "3.0" libp2p-identify = { path = "../identify" } libp2p-noise = { workspace = true } diff --git a/protocols/kad/src/bootstrap.rs b/protocols/kad/src/bootstrap.rs index 1588838916e..478acc4c278 100644 --- a/protocols/kad/src/bootstrap.rs +++ b/protocols/kad/src/bootstrap.rs @@ -4,8 +4,54 @@ use std::{ }; use futures::FutureExt; + +#[cfg(not(test))] use futures_timer::Delay; +#[cfg(test)] +use mock_delay::Delay; + +#[cfg(test)] +mod mock_delay { + use std::future::Future; + use std::pin::Pin; + use std::task::{Context, Poll}; + use std::time::Duration; + + #[derive(Debug)] + pub(super) enum Delay { + Tokio(Pin>), + Futures(futures_timer::Delay), + } + + impl Delay { + pub(super) fn new(dur: Duration) -> Self { + if tokio::runtime::Handle::try_current().is_ok() { + Self::Tokio(Box::pin(tokio::time::sleep(dur))) + } else { + Self::Futures(futures_timer::Delay::new(dur)) + } + } + + pub(super) fn reset(&mut self, dur: Duration) { + match self { + Self::Tokio(sleep) => sleep.as_mut().reset(tokio::time::Instant::now() + dur), + Self::Futures(delay) => delay.reset(dur), + } + } + } + + impl Future for Delay { + type Output = (); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match &mut *self { + Self::Tokio(sleep) => sleep.as_mut().poll(cx), + Self::Futures(delay) => Pin::new(delay).poll(cx), + } + } + } +} + /// Default value chosen at ``. pub(crate) const DEFAULT_AUTOMATIC_THROTTLE: Duration = Duration::from_millis(500); @@ -181,7 +227,7 @@ impl futures::Future for ThrottleTimer { #[cfg(test)] mod tests { - use web_time::Instant; + use tokio::time::Instant; use super::*; @@ -198,7 +244,7 @@ mod tests { do_bootstrap(status); } - #[tokio::test] + #[tokio::test(start_paused = true)] async fn immediate_automatic_bootstrap_is_triggered_immediately() { let mut status = Status::new(Some(Duration::from_secs(1)), Some(Duration::ZERO)); @@ -223,7 +269,7 @@ mod tests { ); } - #[tokio::test] + #[tokio::test(start_paused = true)] async fn delayed_automatic_bootstrap_is_triggered_before_periodic_bootstrap() { let mut status = Status::new(Some(Duration::from_secs(1)), Some(MS_5)); @@ -261,7 +307,7 @@ mod tests { ) } - #[tokio::test] + #[tokio::test(start_paused = true)] async fn given_periodic_bootstrap_when_routing_table_updated_then_wont_bootstrap_until_next_interval() { let mut status = Status::new(Some(MS_100), Some(MS_5)); @@ -278,10 +324,10 @@ mod tests { await_and_do_bootstrap(&mut status).await; let elapsed = Instant::now().duration_since(start); - assert!(elapsed > MS_100); + assert!(elapsed >= MS_100); } - #[tokio::test] + #[tokio::test(start_paused = true)] async fn given_no_periodic_bootstrap_and_automatic_bootstrap_when_new_entry_then_will_bootstrap() { let mut status = Status::new(None, Some(Duration::ZERO)); @@ -291,7 +337,7 @@ mod tests { status.next().await; } - #[tokio::test] + #[tokio::test(start_paused = true)] async fn given_periodic_bootstrap_and_no_automatic_bootstrap_triggers_periodically() { let mut status = Status::new(Some(MS_100), None); @@ -306,7 +352,7 @@ mod tests { } } - #[tokio::test] + #[tokio::test(start_paused = true)] async fn given_no_periodic_bootstrap_and_automatic_bootstrap_reset_throttle_when_multiple_peers() { let mut status = Status::new(None, Some(MS_100)); @@ -330,7 +376,7 @@ mod tests { ); } - #[tokio::test] + #[tokio::test(start_paused = true)] async fn given_periodic_bootstrap_and_no_automatic_bootstrap_manually_triggering_prevent_periodic() { let mut status = Status::new(Some(MS_100), None);