@@ -12,7 +12,7 @@ use tokio::{
1212
1313pub ( crate ) type PacketSender = UnboundedSender < NetworkPacket > ;
1414pub ( crate ) type PacketReceiver = UnboundedReceiver < NetworkPacket > ;
15- pub ( crate ) type SessionCollection = std :: sync :: Arc < tokio :: sync :: Mutex < AHashMap < NetworkTuple , PacketSender > > > ;
15+ pub ( crate ) type SessionCollection = AHashMap < NetworkTuple , PacketSender > ;
1616
1717mod error;
1818mod packet;
@@ -105,7 +105,8 @@ fn run<Device: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
105105 mut device : Device ,
106106 accept_sender : UnboundedSender < IpStackStream > ,
107107) -> JoinHandle < Result < ( ) > > {
108- let sessions: SessionCollection = std:: sync:: Arc :: new ( tokio:: sync:: Mutex :: new ( AHashMap :: new ( ) ) ) ;
108+ let mut sessions: SessionCollection = AHashMap :: new ( ) ;
109+ let ( session_remove_tx, mut session_remove_rx) = mpsc:: unbounded_channel :: < NetworkTuple > ( ) ;
109110 let pi = config. packet_information ;
110111 let offset = if pi && cfg ! ( unix) { 4 } else { 0 } ;
111112 let mut buffer = vec ! [ 0_u8 ; u16 :: MAX as usize + offset] ;
@@ -115,8 +116,7 @@ fn run<Device: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
115116 loop {
116117 select ! {
117118 Ok ( n) = device. read( & mut buffer) => {
118- let u = up_pkt_sender. clone( ) ;
119- if let Err ( e) = process_device_read( & buffer[ offset..n] , sessions. clone( ) , u, & config, & accept_sender) . await {
119+ if let Err ( e) = process_device_read( & buffer[ offset..n] , & mut sessions, & session_remove_tx, & up_pkt_sender, & config, & accept_sender) . await {
120120 let io_err: std:: io:: Error = e. into( ) ;
121121 if io_err. kind( ) == std:: io:: ErrorKind :: ConnectionRefused {
122122 log:: trace!( "Received junk data: {io_err}" ) ;
@@ -125,6 +125,12 @@ fn run<Device: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
125125 }
126126 }
127127 }
128+ network_tuple = session_remove_rx. recv( ) => {
129+ if let Some ( network_tuple) = network_tuple {
130+ sessions. remove( & network_tuple) ;
131+ log:: debug!( "session destroyed: {network_tuple}" ) ;
132+ }
133+ }
128134 Some ( packet) = up_pkt_receiver. recv( ) => {
129135 process_upstream_recv( packet, & mut device, #[ cfg( unix) ] pi) . await ?;
130136 }
@@ -135,8 +141,9 @@ fn run<Device: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
135141
136142async fn process_device_read (
137143 data : & [ u8 ] ,
138- sessions : SessionCollection ,
139- up_pkt_sender : PacketSender ,
144+ sessions : & mut SessionCollection ,
145+ session_remove_tx : & UnboundedSender < NetworkTuple > ,
146+ up_pkt_sender : & PacketSender ,
140147 config : & IpStackConfig ,
141148 accept_sender : & UnboundedSender < IpStackStream > ,
142149) -> Result < ( ) > {
@@ -153,27 +160,26 @@ async fn process_device_read(
153160 packet. payload . unwrap_or_default ( ) ,
154161 & packet. ip ,
155162 config. mtu ,
156- up_pkt_sender,
163+ up_pkt_sender. clone ( ) ,
157164 ) ) ;
158165 accept_sender. send ( stream) ?;
159166 return Ok ( ( ) ) ;
160167 }
161168
162- let sessions_clone = sessions. clone ( ) ;
163169 let network_tuple = packet. network_tuple ( ) ;
164- match sessions. lock ( ) . await . entry ( network_tuple) {
170+ match sessions. entry ( network_tuple) {
165171 std:: collections:: hash_map:: Entry :: Occupied ( entry) => {
166172 let len = packet. payload . as_ref ( ) . map ( |p| p. len ( ) ) . unwrap_or ( 0 ) ;
167173 log:: trace!( "packet sent to stream: {network_tuple} len {len}" ) ;
168174 entry. get ( ) . send ( packet) . map_err ( std:: io:: Error :: other) ?;
169175 }
170176 std:: collections:: hash_map:: Entry :: Vacant ( entry) => {
171177 let ( tx, rx) = tokio:: sync:: oneshot:: channel :: < ( ) > ( ) ;
172- let ip_stack_stream = create_stream ( packet, config, up_pkt_sender, Some ( tx) ) ?;
178+ let ip_stack_stream = create_stream ( packet, config, up_pkt_sender. clone ( ) , Some ( tx) ) ?;
179+ let session_remove_tx = session_remove_tx. clone ( ) ;
173180 tokio:: spawn ( async move {
174181 rx. await . ok ( ) ;
175- sessions_clone. lock ( ) . await . remove ( & network_tuple) ;
176- log:: debug!( "session destroyed: {network_tuple}" ) ;
182+ session_remove_tx. send ( network_tuple) . ok ( ) ;
177183 } ) ;
178184 let packet_sender = ip_stack_stream. stream_sender ( ) ?;
179185 accept_sender. send ( ip_stack_stream) ?;
0 commit comments