Skip to content

Commit f6dd88b

Browse files
committed
refactor: Avoid using tokio locks to reduce performance loss
1 parent 74fdcad commit f6dd88b

1 file changed

Lines changed: 18 additions & 12 deletions

File tree

src/lib.rs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use tokio::{
1212

1313
pub(crate) type PacketSender = UnboundedSender<NetworkPacket>;
1414
pub(crate) type PacketReceiver = UnboundedReceiver<NetworkPacket>;
15-
pub(crate) type SessionCollection = std::sync::Arc<tokio::sync::Mutex<AHashMap<NetworkTuple, PacketSender>>>;
15+
pub(crate) type SessionCollection = AHashMap<NetworkTuple, PacketSender>;
1616

1717
mod error;
1818
mod packet;
@@ -105,7 +105,8 @@ fn run<Device: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
105105
mut device: Device,
106106
accept_sender: UnboundedSender<IpStackStream>,
107107
) -> JoinHandle<Result<()>> {
108-
let sessions: SessionCollection = std::sync::Arc::new(tokio::sync::Mutex::new(AHashMap::new()));
108+
let mut sessions: SessionCollection = AHashMap::new();
109+
let (session_remove_tx, mut session_remove_rx) = mpsc::unbounded_channel::<NetworkTuple>();
109110
let pi = config.packet_information;
110111
let offset = if pi && cfg!(unix) { 4 } else { 0 };
111112
let mut buffer = vec![0_u8; u16::MAX as usize + offset];
@@ -115,8 +116,7 @@ fn run<Device: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
115116
loop {
116117
select! {
117118
Ok(n) = device.read(&mut buffer) => {
118-
let u = up_pkt_sender.clone();
119-
if let Err(e) = process_device_read(&buffer[offset..n], sessions.clone(), u, &config, &accept_sender).await {
119+
if let Err(e) = process_device_read(&buffer[offset..n], &mut sessions,&session_remove_tx, &up_pkt_sender, &config, &accept_sender).await {
120120
let io_err: std::io::Error = e.into();
121121
if io_err.kind() == std::io::ErrorKind::ConnectionRefused {
122122
log::trace!("Received junk data: {io_err}");
@@ -125,6 +125,12 @@ fn run<Device: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
125125
}
126126
}
127127
}
128+
network_tuple = session_remove_rx.recv() => {
129+
if let Some(network_tuple) = network_tuple {
130+
sessions.remove(&network_tuple);
131+
log::debug!("session destroyed: {network_tuple}");
132+
}
133+
}
128134
Some(packet) = up_pkt_receiver.recv() => {
129135
process_upstream_recv(packet, &mut device, #[cfg(unix)]pi).await?;
130136
}
@@ -135,8 +141,9 @@ fn run<Device: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
135141

136142
async fn process_device_read(
137143
data: &[u8],
138-
sessions: SessionCollection,
139-
up_pkt_sender: PacketSender,
144+
sessions: &mut SessionCollection,
145+
session_remove_tx: &UnboundedSender<NetworkTuple>,
146+
up_pkt_sender: &PacketSender,
140147
config: &IpStackConfig,
141148
accept_sender: &UnboundedSender<IpStackStream>,
142149
) -> Result<()> {
@@ -153,27 +160,26 @@ async fn process_device_read(
153160
packet.payload.unwrap_or_default(),
154161
&packet.ip,
155162
config.mtu,
156-
up_pkt_sender,
163+
up_pkt_sender.clone(),
157164
));
158165
accept_sender.send(stream)?;
159166
return Ok(());
160167
}
161168

162-
let sessions_clone = sessions.clone();
163169
let network_tuple = packet.network_tuple();
164-
match sessions.lock().await.entry(network_tuple) {
170+
match sessions.entry(network_tuple) {
165171
std::collections::hash_map::Entry::Occupied(entry) => {
166172
let len = packet.payload.as_ref().map(|p| p.len()).unwrap_or(0);
167173
log::trace!("packet sent to stream: {network_tuple} len {len}");
168174
entry.get().send(packet).map_err(std::io::Error::other)?;
169175
}
170176
std::collections::hash_map::Entry::Vacant(entry) => {
171177
let (tx, rx) = tokio::sync::oneshot::channel::<()>();
172-
let ip_stack_stream = create_stream(packet, config, up_pkt_sender, Some(tx))?;
178+
let ip_stack_stream = create_stream(packet, config, up_pkt_sender.clone(), Some(tx))?;
179+
let session_remove_tx = session_remove_tx.clone();
173180
tokio::spawn(async move {
174181
rx.await.ok();
175-
sessions_clone.lock().await.remove(&network_tuple);
176-
log::debug!("session destroyed: {network_tuple}");
182+
session_remove_tx.send(network_tuple).ok();
177183
});
178184
let packet_sender = ip_stack_stream.stream_sender()?;
179185
accept_sender.send(ip_stack_stream)?;

0 commit comments

Comments
 (0)