Skip to content

Commit 0be61f3

Browse files
authored
Add rate limiter for total number of subscriber per ip (#435)
1 parent f2d38f5 commit 0be61f3

5 files changed

Lines changed: 108 additions & 25 deletions

File tree

auction-server/src/api.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,8 @@ pub enum RestError {
369369
InvalidFirstSigner(String),
370370
/// Duplicate bid
371371
DuplicateBid,
372+
/// Too many open websocket connections
373+
TooManyOpenWebsocketConnections,
372374
}
373375

374376

@@ -482,6 +484,10 @@ impl RestError {
482484
StatusCode::BAD_REQUEST,
483485
"Duplicate bid".to_string(),
484486
),
487+
RestError::TooManyOpenWebsocketConnections => (
488+
StatusCode::TOO_MANY_REQUESTS,
489+
"Too many open websocket connections".to_string(),
490+
),
485491
}
486492
}
487493
}

auction-server/src/api/ws.rs

Lines changed: 91 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use {
22
super::{
33
Auth,
4+
RestError,
45
WrappedRouter,
56
},
67
crate::{
@@ -32,6 +33,7 @@ use {
3233
State,
3334
WebSocketUpgrade,
3435
},
36+
http::HeaderMap,
3537
response::IntoResponse,
3638
Router,
3739
},
@@ -68,8 +70,12 @@ use {
6870
StreamExt,
6971
},
7072
std::{
71-
collections::HashSet,
73+
collections::{
74+
HashMap,
75+
HashSet,
76+
},
7277
future::Future,
78+
net::IpAddr,
7379
sync::{
7480
atomic::{
7581
AtomicUsize,
@@ -82,6 +88,7 @@ use {
8288
time::OffsetDateTime,
8389
tokio::sync::{
8490
broadcast,
91+
RwLock,
8592
Semaphore,
8693
},
8794
tracing::{
@@ -91,26 +98,102 @@ use {
9198
};
9299

93100
pub struct WsState {
94-
pub subscriber_counter: AtomicUsize,
95-
pub broadcast_sender: broadcast::Sender<UpdateEvent>,
96-
pub broadcast_receiver: broadcast::Receiver<UpdateEvent>,
101+
pub requester_ip_header_name: String,
102+
subscriber_counter: AtomicUsize,
103+
subscriber_per_ip: RwLock<HashMap<IpAddr, HashSet<SubscriberId>>>,
104+
pub broadcast_sender: broadcast::Sender<UpdateEvent>,
105+
pub broadcast_receiver: broadcast::Receiver<UpdateEvent>,
106+
}
107+
108+
const MAXIMUM_SUBSCRIBERS_PER_IP: usize = 10;
109+
110+
impl WsState {
111+
pub fn new(requester_ip_header_name: String, broadcast_channel_size: usize) -> Self {
112+
let (broadcast_sender, broadcast_receiver) = broadcast::channel(broadcast_channel_size);
113+
Self {
114+
requester_ip_header_name,
115+
subscriber_counter: AtomicUsize::new(0),
116+
subscriber_per_ip: RwLock::new(HashMap::new()),
117+
broadcast_sender,
118+
broadcast_receiver,
119+
}
120+
}
121+
122+
/// If the specified IP address has too many open websocket connections, this function will
123+
/// return none. Otherwise, it will return the new subscriber id.
124+
pub async fn get_new_subscriber_id(&self, ip: Option<IpAddr>) -> Option<SubscriberId> {
125+
let id = self.subscriber_counter.fetch_add(1, Ordering::SeqCst);
126+
if let Some(ip) = ip {
127+
let mut write_gaurd = self.subscriber_per_ip.write().await;
128+
let ids = write_gaurd.entry(ip).or_insert_with(HashSet::new);
129+
if ids.len() >= MAXIMUM_SUBSCRIBERS_PER_IP {
130+
return None;
131+
}
132+
ids.insert(id);
133+
}
134+
Some(id)
135+
}
136+
137+
pub async fn remove_subscriber(&self, id: SubscriberId, ip: Option<IpAddr>) {
138+
if let Some(ip) = ip {
139+
let mut write_gaurd = self.subscriber_per_ip.write().await;
140+
if let Some(ids) = write_gaurd.get_mut(&ip) {
141+
ids.remove(&id);
142+
if ids.is_empty() {
143+
write_gaurd.remove(&ip);
144+
}
145+
}
146+
}
147+
}
97148
}
98149

99150
pub async fn ws_route_handler(
100151
auth: Auth,
101152
ws: WebSocketUpgrade,
102153
State(store): State<Arc<StoreNew>>,
154+
headers: HeaderMap,
103155
) -> impl IntoResponse {
104-
ws.on_upgrade(move |socket| websocket_handler(socket, store, auth))
156+
let ws_state = &store.store.ws;
157+
let requester_ip = headers
158+
.get(ws_state.requester_ip_header_name.as_str())
159+
.and_then(|value| value.to_str().ok())
160+
.and_then(|value| value.split(',').next()) // Only take the first ip if there are multiple
161+
.and_then(|value| value.parse().ok());
162+
163+
if requester_ip.is_none() {
164+
tracing::warn!("Failed to get requester IP address");
165+
}
166+
167+
match ws_state.get_new_subscriber_id(requester_ip).await {
168+
Some(subscriber_id) => ws.on_upgrade(move |socket| {
169+
websocket_handler(socket, store, subscriber_id, auth, requester_ip)
170+
}),
171+
None => RestError::TooManyOpenWebsocketConnections.into_response(),
172+
}
105173
}
106174

107-
async fn websocket_handler(stream: WebSocket, state: Arc<StoreNew>, auth: Auth) {
175+
async fn websocket_handler(
176+
stream: WebSocket,
177+
state: Arc<StoreNew>,
178+
subscriber_id: SubscriberId,
179+
auth: Auth,
180+
requester_ip: Option<IpAddr>,
181+
) {
108182
let ws_state = &state.store.ws;
109-
let id = ws_state.subscriber_counter.fetch_add(1, Ordering::SeqCst);
110183
let (sender, receiver) = stream.split();
111184
let new_receiver = ws_state.broadcast_receiver.resubscribe();
112-
let mut subscriber = Subscriber::new(id, state, new_receiver, receiver, sender, auth);
185+
let mut subscriber = Subscriber::new(
186+
subscriber_id,
187+
state.clone(),
188+
new_receiver,
189+
receiver,
190+
sender,
191+
auth,
192+
);
113193
subscriber.run().await;
194+
ws_state
195+
.remove_subscriber(subscriber_id, requester_ip)
196+
.await;
114197
}
115198

116199
#[derive(Clone, PartialEq, Debug)]

auction-server/src/config/server.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use {
66
const DEFAULT_LISTEN_ADDR: &str = "127.0.0.1:9000";
77
const DEFAULT_METRICS_ADDR: &str = "127.0.0.1:9001";
88
const DEFAULT_DATABASE_MAX_CONNECTIONS: &str = "10";
9+
const DEAFULT_REQUESTER_IP_HEADER_NAME: &str = "X-Forwarded-For";
910

1011
#[derive(Args, Clone, Debug)]
1112
#[command(next_help_heading = "Server Options")]
@@ -30,4 +31,9 @@ pub struct Options {
3031
#[arg(default_value = DEFAULT_METRICS_ADDR)]
3132
#[arg(env = "METRICS_ADDR")]
3233
pub metrics_addr: SocketAddr,
34+
/// The header name to use for the requester IP address.
35+
#[arg(long = "requester-ip-header-name")]
36+
#[arg(default_value = DEAFULT_REQUESTER_IP_HEADER_NAME)]
37+
#[arg(env = "REQUESTER_IP_HEADER_NAME")]
38+
pub requester_ip_header_name: String,
3339
}

auction-server/src/opportunity/service/mod.rs

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,6 @@ pub mod tests {
305305
opportunity::repository::MockDatabase,
306306
server::setup_metrics_recorder,
307307
},
308-
std::sync::atomic::AtomicUsize,
309308
tokio::sync::broadcast::Receiver,
310309
};
311310

@@ -325,20 +324,14 @@ pub mod tests {
325324
ordered_fee_tokens: vec![],
326325
};
327326

328-
let (broadcast_sender, broadcast_receiver) = tokio::sync::broadcast::channel(100);
329-
330327
let mut chains_svm = HashMap::new();
331328
chains_svm.insert(chain_id.clone(), config_svm);
332329

333330
let store = Arc::new(Store {
334331
db: DB::connect_lazy("https://test").unwrap(),
335332
chains_evm: HashMap::new(),
336333
chains_svm: HashMap::new(),
337-
ws: ws::WsState {
338-
subscriber_counter: AtomicUsize::new(0),
339-
broadcast_sender,
340-
broadcast_receiver,
341-
},
334+
ws: ws::WsState::new("X-Forwarded-For".to_string(), 100),
342335
secret_key: "test".to_string(),
343336
access_tokens: RwLock::new(HashMap::new()),
344337
metrics_recorder: setup_metrics_recorder().unwrap(),

auction-server/src/server.rs

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ use {
7171
sync::{
7272
atomic::{
7373
AtomicBool,
74-
AtomicUsize,
7574
Ordering,
7675
},
7776
Arc,
@@ -292,9 +291,6 @@ pub async fn start_server(run_options: RunOptions) -> Result<()> {
292291

293292
let chains_svm = setup_chain_store_svm(config_map)?;
294293

295-
let (broadcast_sender, broadcast_receiver) =
296-
tokio::sync::broadcast::channel(NOTIFICATIONS_CHAN_LEN);
297-
298294
let pool = create_pg_pool(
299295
&run_options.server.database_url,
300296
run_options.server.database_max_connections,
@@ -321,11 +317,10 @@ pub async fn start_server(run_options: RunOptions) -> Result<()> {
321317
db: pool.clone(),
322318
chains_evm: chains_evm.clone(),
323319
chains_svm: chains_svm.clone(),
324-
ws: ws::WsState {
325-
subscriber_counter: AtomicUsize::new(0),
326-
broadcast_sender,
327-
broadcast_receiver,
328-
},
320+
ws: ws::WsState::new(
321+
run_options.server.requester_ip_header_name.clone(),
322+
NOTIFICATIONS_CHAN_LEN,
323+
),
329324
secret_key: run_options.secret_key.clone(),
330325
access_tokens: RwLock::new(access_tokens),
331326
metrics_recorder: setup_metrics_recorder()?,

0 commit comments

Comments
 (0)