diff --git a/protocols/relay/src/behaviour.rs b/protocols/relay/src/behaviour.rs index 2e395c3dac6..3c2d67dd9d3 100644 --- a/protocols/relay/src/behaviour.rs +++ b/protocols/relay/src/behaviour.rs @@ -410,7 +410,7 @@ impl NetworkBehaviour for Behaviour { .get(&event_source) .map(|cs| cs.len()) .unwrap_or(0) - > self.config.max_reservations_per_peer) + >= self.config.max_reservations_per_peer) // Deny if it exceeds `max_reservations`. || self .reservations diff --git a/protocols/relay/tests/lib.rs b/protocols/relay/tests/lib.rs index de3087b2903..6117b061553 100644 --- a/protocols/relay/tests/lib.rs +++ b/protocols/relay/tests/lib.rs @@ -345,6 +345,78 @@ async fn propagate_reservation_error_to_listener() { )); } +#[tokio::test] +async fn enforce_reservation_limit_per_peer_across_connections() { + let _ = tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .try_init(); + + let relay_addr = Multiaddr::empty().with(Protocol::Memory(rand::random::())); + let mut relay = build_relay_with_config(relay::Config { + max_reservations: 10, + max_reservations_per_peer: 1, + reservation_duration: Duration::from_secs(60), + ..relay::Config::default() + }); + let relay_peer_id = *relay.local_peer_id(); + + relay.listen_on(relay_addr.clone()).unwrap(); + relay.add_external_address(relay_addr.clone()); + tokio::spawn(async move { + relay.collect::>().await; + }); + + let attacker_key = identity::Keypair::generate_ed25519(); + let attacker_peer_id = attacker_key.public().to_peer_id(); + let relayed_addr = relay_addr + .clone() + .with(Protocol::P2p(relay_peer_id)) + .with(Protocol::P2pCircuit); + let relayed_addr_with_peer = relayed_addr.clone().with(Protocol::P2p(attacker_peer_id)); + + let mut attacker_conn_1 = + build_client_with_key(attacker_key.clone(), Config::with_tokio_executor()); + attacker_conn_1.listen_on(relayed_addr.clone()).unwrap(); + assert!(wait_for_dial(&mut attacker_conn_1, relay_peer_id).await); + wait_for_reservation( + &mut attacker_conn_1, + relayed_addr_with_peer.clone(), + relay_peer_id, + false, + ) + .await; + + tokio::spawn(async move { + attacker_conn_1.collect::>().await; + }); + + let mut attacker_conn_2 = build_client_with_key(attacker_key, Config::with_tokio_executor()); + let reservation_listener = attacker_conn_2.listen_on(relayed_addr).unwrap(); + assert!(wait_for_dial(&mut attacker_conn_2, relay_peer_id).await); + + let error = attacker_conn_2 + .wait(|e| match e { + SwarmEvent::ListenerClosed { + listener_id, + reason: Err(e), + .. + } if listener_id == reservation_listener => Some(e), + _ => None, + }) + .await; + + let error = error + .source() + .unwrap() + .downcast_ref::() + .unwrap(); + + assert!(matches!( + error, + relay::outbound::hop::ReserveError::ResourceLimitExceeded + )); +} + #[tokio::test] async fn propagate_connect_error_to_unknown_peer_to_dialer() { let _ = tracing_subscriber::fmt() @@ -471,7 +543,10 @@ fn build_client() -> Swarm { } fn build_client_with_config(config: Config) -> Swarm { - let local_key = identity::Keypair::generate_ed25519(); + build_client_with_key(identity::Keypair::generate_ed25519(), config) +} + +fn build_client_with_key(local_key: identity::Keypair, config: Config) -> Swarm { let local_peer_id = local_key.public().to_peer_id(); let (relay_transport, behaviour) = relay::client::new(local_peer_id);