Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/tun.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let mut ipstack_config = ipstack::IpStackConfig::default();
ipstack_config.mtu(MTU);
ipstack_config.tcp_timeout(std::time::Duration::from_secs(args.tcp_timeout));
let mut tcp_config = ipstack::TcpConfig::default();
tcp_config.timeout = std::time::Duration::from_secs(args.tcp_timeout);
ipstack_config.with_tcp_config(tcp_config);
ipstack_config.udp_timeout(std::time::Duration::from_secs(args.udp_timeout));

let mut ip_stack = ipstack::IpStack::new(ipstack_config, tun::create_as_async(&tun_config)?);
Expand Down
18 changes: 10 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use ahash::AHashMap;
use packet::{NetworkPacket, NetworkTuple, TransportHeader};
use std::time::Duration;
use std::{sync::Arc, time::Duration};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
select,
Expand All @@ -19,8 +19,9 @@ mod packet;
mod stream;

pub use self::error::{IpStackError, Result};
pub use self::stream::TcpConfig;
pub use self::stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream, IpStackUnknownTransport};
pub use ::etherparse::IpNumber;
pub use etherparse::IpNumber;

#[cfg(unix)]
const TTL: u8 = 64;
Expand All @@ -41,28 +42,29 @@ const TUN_PROTO_IP6: [u8; 2] = [0x00, 0x0A];
#[cfg(any(target_os = "macos", target_os = "ios"))]
const TUN_PROTO_IP4: [u8; 2] = [0x00, 0x02];

#[non_exhaustive]
Comment thread
SajjadPourali marked this conversation as resolved.
pub struct IpStackConfig {
pub mtu: u16,
pub packet_information: bool,
pub tcp_timeout: Duration,
pub udp_timeout: Duration,
pub tcp_config: Arc<TcpConfig>,
}

impl Default for IpStackConfig {
fn default() -> Self {
IpStackConfig {
mtu: u16::MAX,
packet_information: false,
tcp_timeout: Duration::from_secs(60),
tcp_config: Arc::new(TcpConfig::default()),
udp_timeout: Duration::from_secs(30),
}
}
}

impl IpStackConfig {
pub fn tcp_timeout(&mut self, timeout: Duration) -> &mut Self {
self.tcp_timeout = timeout;
self
/// Set custom TCP configuration
pub fn with_tcp_config(&mut self, config: TcpConfig) {
self.tcp_config = Arc::new(config);
Comment thread
SajjadPourali marked this conversation as resolved.
Outdated
}
pub fn udp_timeout(&mut self, timeout: Duration) -> &mut Self {
self.udp_timeout = timeout;
Expand Down Expand Up @@ -194,7 +196,7 @@ fn create_stream(
let dst_addr = packet.dst_addr();
match packet.transport_header() {
TransportHeader::Tcp(h) => {
let stream = IpStackTcpStream::new(src_addr, dst_addr, h.clone(), up_pkt_sender, cfg.mtu, cfg.tcp_timeout, msgr)?;
let stream = IpStackTcpStream::new(src_addr, dst_addr, h.clone(), up_pkt_sender, cfg.mtu, msgr, cfg.tcp_config.clone())?;
Ok(IpStackStream::Tcp(stream))
}
TransportHeader::Udp(_) => {
Expand Down
1 change: 1 addition & 0 deletions src/stream/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6};

pub use self::tcp::IpStackTcpStream;
pub use self::tcp::TcpConfig;
pub use self::udp::IpStackUdpStream;
pub use self::unknown::IpStackUnknownTransport;

Expand Down
112 changes: 92 additions & 20 deletions src/stream/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use std::{
io::ErrorKind::{BrokenPipe, ConnectionRefused, InvalidInput, UnexpectedEof},
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll, Waker},
time::Duration,
};
Expand All @@ -26,6 +27,35 @@ const TWO_MSL: Duration = Duration::from_secs(2);
const CLOSE_WAIT_TIMEOUT: Duration = Duration::from_secs(5);
const LAST_ACK_MAX_RETRIES: usize = 3;
const LAST_ACK_TIMEOUT: Duration = Duration::from_millis(500);
const TIMEOUT: Duration = Duration::from_secs(60);

#[non_exhaustive]
#[derive(Debug, Clone)]
/// TCP configuration
pub struct TcpConfig {
/// Maximum number of retries for sending the last ACK in the LAST_ACK state. Default is 3.
pub last_ack_max_retries: usize,
/// Timeout for the last ACK in the LAST_ACK state. Default is 500ms.
pub last_ack_timeout: Duration,
/// Timeout for the CLOSE_WAIT state. Default is 5 seconds.
pub close_wait_timeout: Duration,
/// Timeout for TCP connections. Default is 60 seconds.
pub timeout: Duration,
/// Timeout for the TIME_WAIT state. Default is 2 seconds.
pub two_msl: Duration,
}

impl Default for TcpConfig {
fn default() -> Self {
TcpConfig {
last_ack_max_retries: LAST_ACK_MAX_RETRIES,
last_ack_timeout: LAST_ACK_TIMEOUT,
close_wait_timeout: CLOSE_WAIT_TIMEOUT,
timeout: TIMEOUT,
two_msl: TWO_MSL,
}
}
}

#[derive(Debug)]
enum Shutdown {
Expand Down Expand Up @@ -81,13 +111,13 @@ pub struct IpStackTcpStream {
write_notify: std::sync::Arc<std::sync::Mutex<Option<Waker>>>,
destroy_messenger: Option<::tokio::sync::oneshot::Sender<()>>,
timeout: Pin<Box<tokio::time::Sleep>>,
timeout_interval: Duration,
data_tx: tokio::sync::mpsc::UnboundedSender<Vec<u8>>,
data_rx: tokio::sync::mpsc::UnboundedReceiver<Vec<u8>>,
read_notify: std::sync::Arc<std::sync::Mutex<Option<Waker>>>,
task_handle: Option<tokio::task::JoinHandle<std::io::Result<()>>>,
exit_notifier: Option<tokio::sync::mpsc::Sender<()>>,
temp_read_buffer: Vec<u8>,
config: Arc<TcpConfig>,
}

impl IpStackTcpStream {
Expand All @@ -97,8 +127,8 @@ impl IpStackTcpStream {
tcp: TcpHeader,
up_packet_sender: PacketSender,
mtu: u16,
timeout_interval: Duration,
destroy_messenger: Option<::tokio::sync::oneshot::Sender<()>>,
config: Arc<TcpConfig>,
) -> Result<IpStackTcpStream, IpStackError> {
let tcb = Tcb::new(SeqNum(tcp.sequence_number), mtu);
let tuple = NetworkTuple::new(src_addr, dst_addr, true);
Expand All @@ -114,7 +144,7 @@ impl IpStackTcpStream {

let (stream_sender, stream_receiver) = tokio::sync::mpsc::unbounded_channel::<NetworkPacket>();
let (data_tx, data_rx) = tokio::sync::mpsc::unbounded_channel::<Vec<u8>>();
let deadline = tokio::time::Instant::now() + timeout_interval;
let deadline = tokio::time::Instant::now() + config.timeout;

let mut stream = IpStackTcpStream {
src_addr,
Expand All @@ -127,13 +157,13 @@ impl IpStackTcpStream {
write_notify: std::sync::Arc::new(std::sync::Mutex::new(None)),
destroy_messenger,
timeout: Box::pin(tokio::time::sleep_until(deadline)),
timeout_interval,
data_tx,
data_rx,
read_notify: std::sync::Arc::new(std::sync::Mutex::new(None)),
task_handle: None,
exit_notifier: None,
temp_read_buffer: Vec::new(),
config,
};

let sessions = SESSION_COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst).saturating_add(1);
Expand All @@ -146,7 +176,7 @@ impl IpStackTcpStream {
}

fn reset_timeout(&mut self) {
let deadline = tokio::time::Instant::now() + self.timeout_interval;
let deadline = tokio::time::Instant::now() + self.config.timeout;
self.timeout.as_mut().reset(deadline);
}

Expand Down Expand Up @@ -350,11 +380,13 @@ impl IpStackTcpStream {

let (exit_task_notifier, exit_monitor) = tokio::sync::mpsc::channel::<()>(10);
let exit_notifier = exit_task_notifier.clone();
let config = self.config.clone();
self.exit_notifier = Some(exit_task_notifier);

let task_handle = tokio::spawn(async move {
let v = tcp_main_logic_loop(
tcb,
config,
stream_receiver,
up_packet_sender,
exit_notifier,
Expand Down Expand Up @@ -382,6 +414,7 @@ impl IpStackTcpStream {
#[allow(clippy::too_many_arguments)]
async fn tcp_main_logic_loop(
tcb: TcbPtr,
config: Arc<TcpConfig>,
mut stream_receiver: PacketReceiver,
up_packet_sender: PacketSender,
exit_notifier: tokio::sync::mpsc::Sender<()>,
Expand Down Expand Up @@ -413,27 +446,34 @@ async fn tcp_main_logic_loop(

let tcb_clone = tcb.clone();

async fn task_wait_to_close(tcb: TcbPtr, exit_notifier: tokio::sync::mpsc::Sender<()>, nt: NetworkTuple) {
tokio::time::sleep(TWO_MSL).await;
async fn task_wait_to_close(tcb: TcbPtr, exit_notifier: tokio::sync::mpsc::Sender<()>, nt: NetworkTuple, two_msl: Duration) {
tokio::time::sleep(two_msl).await;
{
let mut tcb = tcb.lock().unwrap();
tcb.change_state(TcpState::Closed);
let state = tcb.get_state();
log::debug!("{nt} {state:?}: [task_wait_to_close] session closed after {TWO_MSL:?}");
log::debug!("{nt} {state:?}: [task_wait_to_close] session closed after {two_msl:?}");
}
exit_notifier.send(()).await.unwrap_or(());
}

async fn task_last_ack(tcb: TcbPtr, exit_notifier: tokio::sync::mpsc::Sender<()>, nt: NetworkTuple, pkt_sdr: PacketSender) {
async fn task_last_ack(
tcb: TcbPtr,
exit_notifier: tokio::sync::mpsc::Sender<()>,
nt: NetworkTuple,
pkt_sdr: PacketSender,
last_ack_timeout: Duration,
last_ack_max_retries: usize,
) {
let hint = "[task_last_ack]";
for idx in 1..=LAST_ACK_MAX_RETRIES {
for idx in 1..=last_ack_max_retries {
let state = { tcb.lock().unwrap().get_state() };
if state == TcpState::Closed {
log::debug!("{nt} {state:?}: {hint} session closed, exiting 1...");
return;
}

tokio::time::sleep(LAST_ACK_TIMEOUT).await;
tokio::time::sleep(last_ack_timeout).await;

{
let tcb = tcb.lock().unwrap();
Expand All @@ -442,7 +482,7 @@ async fn tcp_main_logic_loop(
log::debug!("{nt} {state:?}: {hint} session closed, exiting 2...");
return;
}
log::debug!("{nt} {state:?}: {hint} timer expired, resending ACK|FIN (retry {idx}/{LAST_ACK_MAX_RETRIES})");
log::debug!("{nt} {state:?}: {hint} timer expired, resending ACK|FIN (retry {idx}/{last_ack_max_retries})");
_ = write_packet_to_device(&pkt_sdr, nt, &tcb, ACK | FIN, None, None);
}
}
Expand All @@ -460,8 +500,11 @@ async fn tcp_main_logic_loop(
exit_notifier: tokio::sync::mpsc::Sender<()>,
nt: NetworkTuple,
up_packet_sender: PacketSender,
close_wait_timeout: Duration,
last_ack_timeout: Duration,
last_ack_max_retries: usize,
) -> std::io::Result<()> {
tokio::time::sleep(CLOSE_WAIT_TIMEOUT).await; // Wait CLOSE_WAIT_TIMEOUT for upstream
tokio::time::sleep(close_wait_timeout).await; // Wait CLOSE_WAIT_TIMEOUT for upstream
let tcb_clone = tcb.clone();
let mut tcb = tcb.lock().unwrap();
let state = tcb.get_state();
Expand All @@ -476,7 +519,14 @@ async fn tcp_main_logic_loop(
log::debug!("{nt} {state:?}: Forced transition to {new_state:?}");

// Here we set a timer to wait for the last ACK from the other side.
tokio::spawn(task_last_ack(tcb_clone, exit_notifier, nt, up_packet_sender));
tokio::spawn(task_last_ack(
tcb_clone,
exit_notifier,
nt,
up_packet_sender,
last_ack_timeout,
last_ack_max_retries,
));

Ok::<(), std::io::Error>(())
}
Expand Down Expand Up @@ -608,7 +658,14 @@ async fn tcp_main_logic_loop(
// If the timer expires, we send an ACK|FIN packet to the other side again and wait anthoer timeout
// till the retries reach the limit, and then close the session forcibly.
let up = up_packet_sender.clone();
tokio::spawn(task_last_ack(tcb_clone.clone(), exit_notifier, network_tuple, up));
tokio::spawn(task_last_ack(
tcb_clone.clone(),
exit_notifier,
network_tuple,
up,
config.last_ack_timeout,
config.last_ack_max_retries,
));
} else {
// Upstream data pending, wake write_notify and wait
write_notify.lock().unwrap().take().map(|w| w.wake_by_ref()).unwrap_or(());
Expand All @@ -617,7 +674,15 @@ async fn tcp_main_logic_loop(
// Spawn a timeout task to force FIN if upstream is unresponsive
let tcb = tcb_clone.clone();
let up = up_packet_sender.clone();
tokio::spawn(task_timed_out_for_close_wait(tcb, exit_notifier, network_tuple, up));
tokio::spawn(task_timed_out_for_close_wait(
tcb,
exit_notifier,
network_tuple,
up,
config.close_wait_timeout,
config.last_ack_timeout,
config.last_ack_max_retries,
));
}
} else if flags == (ACK | PSH) && pkt_type == PacketType::NewPacket {
if !payload.is_empty() && tcb.get_ack() == incoming_seq {
Expand All @@ -641,7 +706,14 @@ async fn tcp_main_logic_loop(
// If the timer expires, we send an ACK|FIN packet to the other side again and wait anthoer timeout
// till the retries reach the limit, and then close the session forcibly.
let up = up_packet_sender.clone();
tokio::spawn(task_last_ack(tcb_clone.clone(), exit_notifier, network_tuple, up));
tokio::spawn(task_last_ack(
tcb_clone.clone(),
exit_notifier,
network_tuple,
up,
config.last_ack_timeout,
config.last_ack_max_retries,
));
} else {
write_notify.lock().unwrap().take().map(|w| w.wake_by_ref()).unwrap_or(());
}
Expand All @@ -665,7 +737,7 @@ async fn tcp_main_logic_loop(
write_packet_to_device(&up_packet_sender, network_tuple, &tcb, ACK, None, None)?;
tcb.change_state(TcpState::TimeWait);

tokio::spawn(task_wait_to_close(tcb_clone.clone(), exit_notifier, network_tuple));
tokio::spawn(task_wait_to_close(tcb_clone.clone(), exit_notifier, network_tuple, config.two_msl));
let new_state = tcb.get_state();
log::trace!("{network_tuple} {state:?}: Final ACK|FIN received too early, transitioned to {new_state:?} directly");
} else if flags & ACK == ACK {
Expand All @@ -688,7 +760,7 @@ async fn tcp_main_logic_loop(
tcb.increase_ack();
write_packet_to_device(&up_packet_sender, network_tuple, &tcb, ACK, None, None)?;
tcb.change_state(TcpState::TimeWait);
tokio::spawn(task_wait_to_close(tcb_clone.clone(), exit_notifier, network_tuple));
tokio::spawn(task_wait_to_close(tcb_clone.clone(), exit_notifier, network_tuple, config.two_msl));
let new_state = tcb.get_state();
log::trace!("{network_tuple} {state:?}: Received final ACK|FIN, transitioned to {new_state:?}");
} else if flags & ACK == ACK && len == 0 {
Expand All @@ -708,7 +780,7 @@ async fn tcp_main_logic_loop(
}
if flags & FIN == FIN {
tcb.change_state(TcpState::TimeWait);
tokio::spawn(task_wait_to_close(tcb_clone.clone(), exit_notifier, network_tuple));
tokio::spawn(task_wait_to_close(tcb_clone.clone(), exit_notifier, network_tuple, config.two_msl));
let new_state = tcb.get_state();
log::trace!("{network_tuple} {state:?}: Received final ACK|FIN, transitioned to {new_state:?}");
}
Expand Down
Loading