11use {
22 super :: {
33 Auth ,
4+ RestError ,
45 WrappedRouter ,
56 } ,
67 crate :: {
3233 State ,
3334 WebSocketUpgrade ,
3435 } ,
36+ http:: HeaderMap ,
3537 response:: IntoResponse ,
3638 Router ,
3739 } ,
6870 StreamExt ,
6971 } ,
7072 std:: {
71- collections:: HashSet ,
73+ collections:: {
74+ HashMap ,
75+ HashSet ,
76+ } ,
7277 future:: Future ,
78+ net:: IpAddr ,
7379 sync:: {
7480 atomic:: {
7581 AtomicUsize ,
8288 time:: OffsetDateTime ,
8389 tokio:: sync:: {
8490 broadcast,
91+ RwLock ,
8592 Semaphore ,
8693 } ,
8794 tracing:: {
@@ -91,26 +98,102 @@ use {
9198} ;
9299
93100pub struct WsState {
94- pub subscriber_counter : AtomicUsize ,
95- pub broadcast_sender : broadcast:: Sender < UpdateEvent > ,
96- pub broadcast_receiver : broadcast:: Receiver < UpdateEvent > ,
101+ pub requester_ip_header_name : String ,
102+ subscriber_counter : AtomicUsize ,
103+ subscriber_per_ip : RwLock < HashMap < IpAddr , HashSet < SubscriberId > > > ,
104+ pub broadcast_sender : broadcast:: Sender < UpdateEvent > ,
105+ pub broadcast_receiver : broadcast:: Receiver < UpdateEvent > ,
106+ }
107+
108+ const MAXIMUM_SUBSCRIBERS_PER_IP : usize = 10 ;
109+
110+ impl WsState {
111+ pub fn new ( requester_ip_header_name : String , broadcast_channel_size : usize ) -> Self {
112+ let ( broadcast_sender, broadcast_receiver) = broadcast:: channel ( broadcast_channel_size) ;
113+ Self {
114+ requester_ip_header_name,
115+ subscriber_counter : AtomicUsize :: new ( 0 ) ,
116+ subscriber_per_ip : RwLock :: new ( HashMap :: new ( ) ) ,
117+ broadcast_sender,
118+ broadcast_receiver,
119+ }
120+ }
121+
122+ /// If the specified IP address has too many open websocket connections, this function will
123+ /// return none. Otherwise, it will return the new subscriber id.
124+ pub async fn get_new_subscriber_id ( & self , ip : Option < IpAddr > ) -> Option < SubscriberId > {
125+ let id = self . subscriber_counter . fetch_add ( 1 , Ordering :: SeqCst ) ;
126+ if let Some ( ip) = ip {
127+ let mut write_gaurd = self . subscriber_per_ip . write ( ) . await ;
128+ let ids = write_gaurd. entry ( ip) . or_insert_with ( HashSet :: new) ;
129+ if ids. len ( ) >= MAXIMUM_SUBSCRIBERS_PER_IP {
130+ return None ;
131+ }
132+ ids. insert ( id) ;
133+ }
134+ Some ( id)
135+ }
136+
137+ pub async fn remove_subscriber ( & self , id : SubscriberId , ip : Option < IpAddr > ) {
138+ if let Some ( ip) = ip {
139+ let mut write_gaurd = self . subscriber_per_ip . write ( ) . await ;
140+ if let Some ( ids) = write_gaurd. get_mut ( & ip) {
141+ ids. remove ( & id) ;
142+ if ids. is_empty ( ) {
143+ write_gaurd. remove ( & ip) ;
144+ }
145+ }
146+ }
147+ }
97148}
98149
99150pub async fn ws_route_handler (
100151 auth : Auth ,
101152 ws : WebSocketUpgrade ,
102153 State ( store) : State < Arc < StoreNew > > ,
154+ headers : HeaderMap ,
103155) -> impl IntoResponse {
104- ws. on_upgrade ( move |socket| websocket_handler ( socket, store, auth) )
156+ let ws_state = & store. store . ws ;
157+ let requester_ip = headers
158+ . get ( ws_state. requester_ip_header_name . as_str ( ) )
159+ . and_then ( |value| value. to_str ( ) . ok ( ) )
160+ . and_then ( |value| value. split ( ',' ) . next ( ) ) // Only take the first ip if there are multiple
161+ . and_then ( |value| value. parse ( ) . ok ( ) ) ;
162+
163+ if requester_ip. is_none ( ) {
164+ tracing:: warn!( "Failed to get requester IP address" ) ;
165+ }
166+
167+ match ws_state. get_new_subscriber_id ( requester_ip) . await {
168+ Some ( subscriber_id) => ws. on_upgrade ( move |socket| {
169+ websocket_handler ( socket, store, subscriber_id, auth, requester_ip)
170+ } ) ,
171+ None => RestError :: TooManyOpenWebsocketConnections . into_response ( ) ,
172+ }
105173}
106174
107- async fn websocket_handler ( stream : WebSocket , state : Arc < StoreNew > , auth : Auth ) {
175+ async fn websocket_handler (
176+ stream : WebSocket ,
177+ state : Arc < StoreNew > ,
178+ subscriber_id : SubscriberId ,
179+ auth : Auth ,
180+ requester_ip : Option < IpAddr > ,
181+ ) {
108182 let ws_state = & state. store . ws ;
109- let id = ws_state. subscriber_counter . fetch_add ( 1 , Ordering :: SeqCst ) ;
110183 let ( sender, receiver) = stream. split ( ) ;
111184 let new_receiver = ws_state. broadcast_receiver . resubscribe ( ) ;
112- let mut subscriber = Subscriber :: new ( id, state, new_receiver, receiver, sender, auth) ;
185+ let mut subscriber = Subscriber :: new (
186+ subscriber_id,
187+ state. clone ( ) ,
188+ new_receiver,
189+ receiver,
190+ sender,
191+ auth,
192+ ) ;
113193 subscriber. run ( ) . await ;
194+ ws_state
195+ . remove_subscriber ( subscriber_id, requester_ip)
196+ . await ;
114197}
115198
116199#[ derive( Clone , PartialEq , Debug ) ]
0 commit comments