diff --git a/Cargo.toml b/Cargo.toml index f1a1315..ac3aa7f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ authors = ['Narrowlink '] description = 'Asynchronous lightweight userspace implementation of TCP/IP stack for Tun device' name = "ipstack" -version = "0.4.0" +version = "0.5.0" edition = "2024" license = "Apache-2.0" repository = 'https://github.com/narrowlink/ipstack' diff --git a/examples/tun.rs b/examples/tun.rs index d39a1a3..5b2fad3 100644 --- a/examples/tun.rs +++ b/examples/tun.rs @@ -93,7 +93,10 @@ async fn main() -> Result<(), Box> { let mut ipstack_config = ipstack::IpStackConfig::default(); ipstack_config.mtu(MTU); - ipstack_config.tcp_timeout(std::time::Duration::from_secs(args.tcp_timeout)); + let mut tcp_config = ipstack::TcpConfig::default(); + tcp_config.timeout = std::time::Duration::from_secs(args.tcp_timeout); + tcp_config.options = Some(vec![ipstack::TcpOptions::MaximumSegmentSize(1460)]); + ipstack_config.with_tcp_config(tcp_config); ipstack_config.udp_timeout(std::time::Duration::from_secs(args.udp_timeout)); let mut ip_stack = ipstack::IpStack::new(ipstack_config, tun::create_as_async(&tun_config)?); diff --git a/src/lib.rs b/src/lib.rs index ba01b71..7ef9533 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,7 @@ use ahash::AHashMap; use packet::{NetworkPacket, NetworkTuple, TransportHeader}; -use std::time::Duration; +use std::{sync::Arc, time::Duration}; use tokio::{ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, select, @@ -20,7 +20,8 @@ mod stream; pub use self::error::{IpStackError, Result}; pub use self::stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream, IpStackUnknownTransport}; -pub use ::etherparse::IpNumber; +pub use self::stream::{TcpConfig, TcpOptions}; +pub use etherparse::IpNumber; #[cfg(unix)] const TTL: u8 = 64; @@ -41,10 +42,11 @@ const TUN_PROTO_IP6: [u8; 2] = [0x00, 0x0A]; #[cfg(any(target_os = "macos", target_os = "ios"))] const TUN_PROTO_IP4: [u8; 2] = [0x00, 0x02]; +#[non_exhaustive] pub struct IpStackConfig { pub mtu: u16, pub packet_information: bool, - pub tcp_timeout: Duration, + pub tcp_config: Arc, pub udp_timeout: Duration, } @@ -53,15 +55,16 @@ impl Default for IpStackConfig { IpStackConfig { mtu: u16::MAX, packet_information: false, - tcp_timeout: Duration::from_secs(60), + tcp_config: Arc::new(TcpConfig::default()), udp_timeout: Duration::from_secs(30), } } } impl IpStackConfig { - pub fn tcp_timeout(&mut self, timeout: Duration) -> &mut Self { - self.tcp_timeout = timeout; + /// Set custom TCP configuration + pub fn with_tcp_config(&mut self, config: TcpConfig) -> &mut Self { + self.tcp_config = Arc::new(config); self } pub fn udp_timeout(&mut self, timeout: Duration) -> &mut Self { @@ -194,7 +197,7 @@ fn create_stream( let dst_addr = packet.dst_addr(); match packet.transport_header() { TransportHeader::Tcp(h) => { - let stream = IpStackTcpStream::new(src_addr, dst_addr, h.clone(), up_pkt_sender, cfg.mtu, cfg.tcp_timeout, msgr)?; + let stream = IpStackTcpStream::new(src_addr, dst_addr, h.clone(), up_pkt_sender, cfg.mtu, msgr, cfg.tcp_config.clone())?; Ok(IpStackStream::Tcp(stream)) } TransportHeader::Udp(_) => { diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 1f12f6a..467191f 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -1,6 +1,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; pub use self::tcp::IpStackTcpStream; +pub use self::tcp::{TcpConfig, TcpOptions}; pub use self::udp::IpStackUdpStream; pub use self::unknown::IpStackUnknownTransport; diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index e182105..4ece70d 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -9,12 +9,13 @@ use crate::{ }, stream::tcb::{PacketType, Tcb, TcpState}, }; -use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, TcpHeader}; +use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, TcpHeader, TcpOptionElement}; use std::{ future::Future, io::ErrorKind::{BrokenPipe, ConnectionRefused, InvalidInput, UnexpectedEof}, net::SocketAddr, pin::Pin, + sync::Arc, task::{Context, Poll, Waker}, time::Duration, }; @@ -26,6 +27,45 @@ const TWO_MSL: Duration = Duration::from_secs(2); const CLOSE_WAIT_TIMEOUT: Duration = Duration::from_secs(5); const LAST_ACK_MAX_RETRIES: usize = 3; const LAST_ACK_TIMEOUT: Duration = Duration::from_millis(500); +const TIMEOUT: Duration = Duration::from_secs(60); + +#[non_exhaustive] +#[derive(Debug, Clone)] +/// TCP configuration +pub struct TcpConfig { + /// Maximum number of retries for sending the last ACK in the LAST_ACK state. Default is 3. + pub last_ack_max_retries: usize, + /// Timeout for the last ACK in the LAST_ACK state. Default is 500ms. + pub last_ack_timeout: Duration, + /// Timeout for the CLOSE_WAIT state. Default is 5 seconds. + pub close_wait_timeout: Duration, + /// Timeout for TCP connections. Default is 60 seconds. + pub timeout: Duration, + /// Timeout for the TIME_WAIT state. Default is 2 seconds. + pub two_msl: Duration, + /// TCP options + pub options: Option>, +} + +#[non_exhaustive] +#[derive(Debug, Clone)] +pub enum TcpOptions { + /// Maximum segment size (MSS) for TCP connections. Default is 1460 bytes. + MaximumSegmentSize(u16), +} + +impl Default for TcpConfig { + fn default() -> Self { + TcpConfig { + last_ack_max_retries: LAST_ACK_MAX_RETRIES, + last_ack_timeout: LAST_ACK_TIMEOUT, + close_wait_timeout: CLOSE_WAIT_TIMEOUT, + timeout: TIMEOUT, + two_msl: TWO_MSL, + options: Default::default(), + } + } +} #[derive(Debug)] enum Shutdown { @@ -81,13 +121,13 @@ pub struct IpStackTcpStream { write_notify: std::sync::Arc>>, destroy_messenger: Option<::tokio::sync::oneshot::Sender<()>>, timeout: Pin>, - timeout_interval: Duration, data_tx: tokio::sync::mpsc::UnboundedSender>, data_rx: tokio::sync::mpsc::UnboundedReceiver>, read_notify: std::sync::Arc>>, task_handle: Option>>, exit_notifier: Option>, temp_read_buffer: Vec, + config: Arc, } impl IpStackTcpStream { @@ -97,14 +137,14 @@ impl IpStackTcpStream { tcp: TcpHeader, up_packet_sender: PacketSender, mtu: u16, - timeout_interval: Duration, destroy_messenger: Option<::tokio::sync::oneshot::Sender<()>>, + config: Arc, ) -> Result { let tcb = Tcb::new(SeqNum(tcp.sequence_number), mtu); let tuple = NetworkTuple::new(src_addr, dst_addr, true); if !tcp.syn { if !tcp.rst - && let Err(err) = write_packet_to_device(&up_packet_sender, tuple, &tcb, ACK | RST, None, None) + && let Err(err) = write_packet_to_device(&up_packet_sender, tuple, &tcb, None, ACK | RST, None, None) { log::warn!("Error sending RST/ACK packet: {err}"); } @@ -114,7 +154,7 @@ impl IpStackTcpStream { let (stream_sender, stream_receiver) = tokio::sync::mpsc::unbounded_channel::(); let (data_tx, data_rx) = tokio::sync::mpsc::unbounded_channel::>(); - let deadline = tokio::time::Instant::now() + timeout_interval; + let deadline = tokio::time::Instant::now() + config.timeout; let mut stream = IpStackTcpStream { src_addr, @@ -127,13 +167,13 @@ impl IpStackTcpStream { write_notify: std::sync::Arc::new(std::sync::Mutex::new(None)), destroy_messenger, timeout: Box::pin(tokio::time::sleep_until(deadline)), - timeout_interval, data_tx, data_rx, read_notify: std::sync::Arc::new(std::sync::Mutex::new(None)), task_handle: None, exit_notifier: None, temp_read_buffer: Vec::new(), + config, }; let sessions = SESSION_COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst).saturating_add(1); @@ -146,7 +186,7 @@ impl IpStackTcpStream { } fn reset_timeout(&mut self) { - let deadline = tokio::time::Instant::now() + self.timeout_interval; + let deadline = tokio::time::Instant::now() + self.config.timeout; self.timeout.as_mut().reset(deadline); } @@ -192,7 +232,7 @@ impl AsyncRead for IpStackTcpStream { let l_info = format!("local {{ seq: {seq}, ack: {ack} }}"); log::warn!("{network_tuple} {state:?}: [poll_read] {l_info}, session timeout reached, closing forcefully..."); let sender = &self.up_packet_sender; - write_packet_to_device(sender, network_tuple, &tcb, ACK | RST, None, None)?; + write_packet_to_device(sender, network_tuple, &tcb, None, ACK | RST, None, None)?; tcb.change_state(TcpState::Closed); let state = tcb.get_state(); log::warn!("{network_tuple} {state:?}: [poll_read] {l_info}, session notified to close"); @@ -251,7 +291,7 @@ impl AsyncWrite for IpStackTcpStream { let mut tcb = self.tcb.lock().unwrap(); let sender = &self.up_packet_sender; - let payload_len = write_packet_to_device(sender, nt, &tcb, ACK | PSH, None, Some(buf.to_vec()))?; + let payload_len = write_packet_to_device(sender, nt, &tcb, None, ACK | PSH, None, Some(buf.to_vec()))?; tcb.add_inflight_packet(buf[..payload_len].to_vec())?; let (state, seq, ack) = (tcb.get_state(), tcb.get_seq(), tcb.get_ack()); @@ -305,7 +345,7 @@ fn send_fin_n_change_state_to_fin_wait1(hint: &str, nt: NetworkTuple, sender: &P } log::debug!("{nt} {state:?}: {hint} actively send a farewell packet to the other side..."); - write_packet_to_device(sender, nt, tcb, ACK | FIN, None, None)?; + write_packet_to_device(sender, nt, tcb, None, ACK | FIN, None, None)?; tcb.increase_seq(); tcb.change_state(TcpState::FinWait1); let state = tcb.get_state(); @@ -350,11 +390,13 @@ impl IpStackTcpStream { let (exit_task_notifier, exit_monitor) = tokio::sync::mpsc::channel::<()>(10); let exit_notifier = exit_task_notifier.clone(); + let config = self.config.clone(); self.exit_notifier = Some(exit_task_notifier); let task_handle = tokio::spawn(async move { let v = tcp_main_logic_loop( tcb, + config, stream_receiver, up_packet_sender, exit_notifier, @@ -382,6 +424,7 @@ impl IpStackTcpStream { #[allow(clippy::too_many_arguments)] async fn tcp_main_logic_loop( tcb: TcbPtr, + config: Arc, mut stream_receiver: PacketReceiver, up_packet_sender: PacketSender, exit_notifier: tokio::sync::mpsc::Sender<()>, @@ -404,7 +447,15 @@ async fn tcp_main_logic_loop( let (seq, ack) = (tcb.get_seq().0, tcb.get_ack().0); let l_info = format!("local {{ seq: {seq}, ack: {ack} }}"); log::trace!("{network_tuple} {state:?}: {l_info} session begins"); - write_packet_to_device(&up_packet_sender, network_tuple, &tcb, ACK | SYN, None, None)?; + write_packet_to_device( + &up_packet_sender, + network_tuple, + &tcb, + config.options.as_ref(), + ACK | SYN, + None, + None, + )?; tcb.increase_seq(); tcb.change_state(TcpState::SynReceived); let state = tcb.get_state(); @@ -413,27 +464,34 @@ async fn tcp_main_logic_loop( let tcb_clone = tcb.clone(); - async fn task_wait_to_close(tcb: TcbPtr, exit_notifier: tokio::sync::mpsc::Sender<()>, nt: NetworkTuple) { - tokio::time::sleep(TWO_MSL).await; + async fn task_wait_to_close(tcb: TcbPtr, exit_notifier: tokio::sync::mpsc::Sender<()>, nt: NetworkTuple, two_msl: Duration) { + tokio::time::sleep(two_msl).await; { let mut tcb = tcb.lock().unwrap(); tcb.change_state(TcpState::Closed); let state = tcb.get_state(); - log::debug!("{nt} {state:?}: [task_wait_to_close] session closed after {TWO_MSL:?}"); + log::debug!("{nt} {state:?}: [task_wait_to_close] session closed after {two_msl:?}"); } exit_notifier.send(()).await.unwrap_or(()); } - async fn task_last_ack(tcb: TcbPtr, exit_notifier: tokio::sync::mpsc::Sender<()>, nt: NetworkTuple, pkt_sdr: PacketSender) { + async fn task_last_ack( + tcb: TcbPtr, + exit_notifier: tokio::sync::mpsc::Sender<()>, + nt: NetworkTuple, + pkt_sdr: PacketSender, + last_ack_timeout: Duration, + last_ack_max_retries: usize, + ) { let hint = "[task_last_ack]"; - for idx in 1..=LAST_ACK_MAX_RETRIES { + for idx in 1..=last_ack_max_retries { let state = { tcb.lock().unwrap().get_state() }; if state == TcpState::Closed { log::debug!("{nt} {state:?}: {hint} session closed, exiting 1..."); return; } - tokio::time::sleep(LAST_ACK_TIMEOUT).await; + tokio::time::sleep(last_ack_timeout).await; { let tcb = tcb.lock().unwrap(); @@ -442,8 +500,8 @@ async fn tcp_main_logic_loop( log::debug!("{nt} {state:?}: {hint} session closed, exiting 2..."); return; } - log::debug!("{nt} {state:?}: {hint} timer expired, resending ACK|FIN (retry {idx}/{LAST_ACK_MAX_RETRIES})"); - _ = write_packet_to_device(&pkt_sdr, nt, &tcb, ACK | FIN, None, None); + log::debug!("{nt} {state:?}: {hint} timer expired, resending ACK|FIN (retry {idx}/{last_ack_max_retries})"); + _ = write_packet_to_device(&pkt_sdr, nt, &tcb, None, ACK | FIN, None, None); } } { @@ -460,8 +518,11 @@ async fn tcp_main_logic_loop( exit_notifier: tokio::sync::mpsc::Sender<()>, nt: NetworkTuple, up_packet_sender: PacketSender, + close_wait_timeout: Duration, + last_ack_timeout: Duration, + last_ack_max_retries: usize, ) -> std::io::Result<()> { - tokio::time::sleep(CLOSE_WAIT_TIMEOUT).await; // Wait CLOSE_WAIT_TIMEOUT for upstream + tokio::time::sleep(close_wait_timeout).await; // Wait CLOSE_WAIT_TIMEOUT for upstream let tcb_clone = tcb.clone(); let mut tcb = tcb.lock().unwrap(); let state = tcb.get_state(); @@ -469,14 +530,21 @@ async fn tcp_main_logic_loop( return Ok(()); } log::warn!("{nt} {state:?}: Upstream timeout, forcing FIN"); - write_packet_to_device(&up_packet_sender, nt, &tcb, ACK | FIN, None, None)?; + write_packet_to_device(&up_packet_sender, nt, &tcb, None, ACK | FIN, None, None)?; tcb.increase_seq(); tcb.change_state(TcpState::LastAck); let new_state = tcb.get_state(); log::debug!("{nt} {state:?}: Forced transition to {new_state:?}"); // Here we set a timer to wait for the last ACK from the other side. - tokio::spawn(task_last_ack(tcb_clone, exit_notifier, nt, up_packet_sender)); + tokio::spawn(task_last_ack( + tcb_clone, + exit_notifier, + nt, + up_packet_sender, + last_ack_timeout, + last_ack_max_retries, + )); Ok::<(), std::io::Error>(()) } @@ -531,7 +599,15 @@ async fn tcp_main_logic_loop( for packet in tcb.collect_timed_out_inflight_packets() { let (seq, count) = (packet.seq, packet.retransmit_count); log::debug!("{network_tuple} inflight packet retransmission timeout: {seq:?}, retransmit_count: {count}",); - write_packet_to_device(&up_packet_sender, network_tuple, &tcb, ACK | PSH, Some(seq), Some(packet.payload))?; + write_packet_to_device( + &up_packet_sender, + network_tuple, + &tcb, + None, + ACK | PSH, + Some(seq), + Some(packet.payload), + )?; } let pkt_type = tcb.check_pkt_type(tcp_header, &payload); @@ -561,7 +637,7 @@ async fn tcp_main_logic_loop( write_notify.lock().unwrap().take().map(|w| w.wake_by_ref()).unwrap_or(()); } PacketType::KeepAlive => { - write_packet_to_device(&up_packet_sender, network_tuple, &tcb, ACK, None, None)?; + write_packet_to_device(&up_packet_sender, network_tuple, &tcb, None, ACK, None, None)?; } PacketType::RetransmissionRequest => { if let Some(packet) = tcb.find_inflight_packet(incoming_ack) { @@ -570,7 +646,7 @@ async fn tcp_main_logic_loop( "{network_tuple} {state:?}: {l_info}, {pkt_type:?}, retransmission request, seq = {s}, len = {}", p.len() ); - write_packet_to_device(&up_packet_sender, network_tuple, &tcb, ACK | PSH, Some(s), Some(p))?; + write_packet_to_device(&up_packet_sender, network_tuple, &tcb, None, ACK | PSH, Some(s), Some(p))?; } } PacketType::NewPacket => { @@ -587,7 +663,7 @@ async fn tcp_main_logic_loop( } else if flags == (ACK | FIN) { // The other side is closing the connection, we need to send an ACK and change state to CloseWait tcb.increase_ack(); - write_packet_to_device(&up_packet_sender, network_tuple, &tcb, ACK, None, None)?; + write_packet_to_device(&up_packet_sender, network_tuple, &tcb, None, ACK, None, None)?; tcb.change_state(TcpState::CloseWait); let s = tcb.get_state(); @@ -597,7 +673,7 @@ async fn tcp_main_logic_loop( log::trace!("{network_tuple} {s:?}: {l_info}, {pkt_type:?}, closed by the other side, no upstream data"); // Here we don't wait, just send FIN to the other side and change state to LastAck directly, - write_packet_to_device(&up_packet_sender, network_tuple, &tcb, ACK | FIN, None, None)?; + write_packet_to_device(&up_packet_sender, network_tuple, &tcb, None, ACK | FIN, None, None)?; tcb.increase_seq(); tcb.change_state(TcpState::LastAck); @@ -608,7 +684,14 @@ async fn tcp_main_logic_loop( // If the timer expires, we send an ACK|FIN packet to the other side again and wait anthoer timeout // till the retries reach the limit, and then close the session forcibly. let up = up_packet_sender.clone(); - tokio::spawn(task_last_ack(tcb_clone.clone(), exit_notifier, network_tuple, up)); + tokio::spawn(task_last_ack( + tcb_clone.clone(), + exit_notifier, + network_tuple, + up, + config.last_ack_timeout, + config.last_ack_max_retries, + )); } else { // Upstream data pending, wake write_notify and wait write_notify.lock().unwrap().take().map(|w| w.wake_by_ref()).unwrap_or(()); @@ -617,7 +700,15 @@ async fn tcp_main_logic_loop( // Spawn a timeout task to force FIN if upstream is unresponsive let tcb = tcb_clone.clone(); let up = up_packet_sender.clone(); - tokio::spawn(task_timed_out_for_close_wait(tcb, exit_notifier, network_tuple, up)); + tokio::spawn(task_timed_out_for_close_wait( + tcb, + exit_notifier, + network_tuple, + up, + config.close_wait_timeout, + config.last_ack_timeout, + config.last_ack_max_retries, + )); } } else if flags == (ACK | PSH) && pkt_type == PacketType::NewPacket { if !payload.is_empty() && tcb.get_ack() == incoming_seq { @@ -631,7 +722,7 @@ async fn tcp_main_logic_loop( } TcpState::CloseWait => { if flags & ACK == ACK && tcb.get_inflight_packets_total_len() == 0 { - write_packet_to_device(&up_packet_sender, network_tuple, &tcb, ACK | FIN, None, None)?; + write_packet_to_device(&up_packet_sender, network_tuple, &tcb, None, ACK | FIN, None, None)?; tcb.increase_seq(); tcb.change_state(TcpState::LastAck); let new_state = tcb.get_state(); @@ -641,7 +732,14 @@ async fn tcp_main_logic_loop( // If the timer expires, we send an ACK|FIN packet to the other side again and wait anthoer timeout // till the retries reach the limit, and then close the session forcibly. let up = up_packet_sender.clone(); - tokio::spawn(task_last_ack(tcb_clone.clone(), exit_notifier, network_tuple, up)); + tokio::spawn(task_last_ack( + tcb_clone.clone(), + exit_notifier, + network_tuple, + up, + config.last_ack_timeout, + config.last_ack_max_retries, + )); } else { write_notify.lock().unwrap().take().map(|w| w.wake_by_ref()).unwrap_or(()); } @@ -662,10 +760,10 @@ async fn tcp_main_logic_loop( if flags & (ACK | FIN) == (ACK | FIN) && len == 0 { // If the received packet is an ACK with FIN, we need to send an ACK and change state to TimeWait directly, not to FinWait2 tcb.increase_ack(); - write_packet_to_device(&up_packet_sender, network_tuple, &tcb, ACK, None, None)?; + write_packet_to_device(&up_packet_sender, network_tuple, &tcb, None, ACK, None, None)?; tcb.change_state(TcpState::TimeWait); - tokio::spawn(task_wait_to_close(tcb_clone.clone(), exit_notifier, network_tuple)); + tokio::spawn(task_wait_to_close(tcb_clone.clone(), exit_notifier, network_tuple, config.two_msl)); let new_state = tcb.get_state(); log::trace!("{network_tuple} {state:?}: Final ACK|FIN received too early, transitioned to {new_state:?} directly"); } else if flags & ACK == ACK { @@ -686,9 +784,9 @@ async fn tcp_main_logic_loop( TcpState::FinWait2 => { if flags & (ACK | FIN) == (ACK | FIN) && len == 0 { tcb.increase_ack(); - write_packet_to_device(&up_packet_sender, network_tuple, &tcb, ACK, None, None)?; + write_packet_to_device(&up_packet_sender, network_tuple, &tcb, None, ACK, None, None)?; tcb.change_state(TcpState::TimeWait); - tokio::spawn(task_wait_to_close(tcb_clone.clone(), exit_notifier, network_tuple)); + tokio::spawn(task_wait_to_close(tcb_clone.clone(), exit_notifier, network_tuple, config.two_msl)); let new_state = tcb.get_state(); log::trace!("{network_tuple} {state:?}: Received final ACK|FIN, transitioned to {new_state:?}"); } else if flags & ACK == ACK && len == 0 { @@ -699,7 +797,7 @@ async fn tcp_main_logic_loop( } } else if flags & ACK == ACK && len > 0 { if pkt_type == PacketType::KeepAlive { - write_packet_to_device(&up_packet_sender, network_tuple, &tcb, ACK, None, None)?; + write_packet_to_device(&up_packet_sender, network_tuple, &tcb, None, ACK, None, None)?; } else { // if the other side is still sending data, we need to deal with it like PacketStatus::NewPacket tcb.add_unordered_packet(incoming_seq, payload); @@ -708,7 +806,7 @@ async fn tcp_main_logic_loop( } if flags & FIN == FIN { tcb.change_state(TcpState::TimeWait); - tokio::spawn(task_wait_to_close(tcb_clone.clone(), exit_notifier, network_tuple)); + tokio::spawn(task_wait_to_close(tcb_clone.clone(), exit_notifier, network_tuple, config.two_msl)); let new_state = tcb.get_state(); log::trace!("{network_tuple} {state:?}: Received final ACK|FIN, transitioned to {new_state:?}"); } @@ -719,7 +817,7 @@ async fn tcp_main_logic_loop( } TcpState::TimeWait => { if flags & (ACK | FIN) == (ACK | FIN) { - write_packet_to_device(&up_packet_sender, network_tuple, &tcb, ACK, None, None)?; + write_packet_to_device(&up_packet_sender, network_tuple, &tcb, None, ACK, None, None)?; // wait to timeout, can't call `tcb.change_state(TcpState::Closed);` to change state here // now we need to wait for the timeout to reach... } @@ -752,7 +850,7 @@ fn extract_data_n_write_upstream( log::trace!("{network_tuple} {state:?}: {l_info} {hint} receiving data, len = {}", data.len()); data_tx.send(data).map_err(|e| std::io::Error::new(BrokenPipe, e))?; read_notify.lock().unwrap().take().map(|w| w.wake_by_ref()).unwrap_or(()); - write_packet_to_device(up_packet_sender, network_tuple, tcb, ACK, None, None)?; + write_packet_to_device(up_packet_sender, network_tuple, tcb, None, ACK, None, None)?; } Ok(()) } @@ -763,6 +861,7 @@ pub(crate) fn write_packet_to_device( up_packet_sender: &PacketSender, tuple: NetworkTuple, tcb: &Tcb, + options: Option<&Vec>, flags: u8, seq: Option, payload: Option>, @@ -772,7 +871,18 @@ pub(crate) fn write_packet_to_device( let (ack, window_size) = (tcb.get_ack().0, tcb.get_recv_window().max(tcb.get_mtu())); let (src, dst) = (tuple.dst, tuple.src); // Note: The address is reversed here let calc = |ip_header_len: usize, tcp_header_len: usize| tcb.calculate_payload_max_len(ip_header_len, tcp_header_len); - let packet = create_raw_packet(src, dst, calc, flags, TTL, seq, ack, window_size, payload.unwrap_or_default())?; + let packet = create_raw_packet( + src, + dst, + calc, + flags, + TTL, + seq, + ack, + window_size, + payload.unwrap_or_default(), + options, + )?; let len = packet.payload.as_ref().map(|p| p.len()).unwrap_or(0); up_packet_sender.send(packet).map_err(|e| Error::new(UnexpectedEof, e))?; Ok(len) @@ -789,6 +899,7 @@ pub(crate) fn create_raw_packet( ack: u32, win: u16, mut payload: Vec, + options: Option<&Vec>, ) -> std::io::Result { let mut tcp_header = etherparse::TcpHeader::new(src_addr.port(), dst_addr.port(), seq, win); tcp_header.acknowledgment_number = ack; @@ -798,6 +909,17 @@ pub(crate) fn create_raw_packet( tcp_header.fin = flags & FIN != 0; tcp_header.psh = flags & PSH != 0; + if let Some(opts) = options { + let mut tcp_options = Vec::new(); + for opt in opts { + match opt { + TcpOptions::MaximumSegmentSize(mss) => tcp_options.push(TcpOptionElement::MaximumSegmentSize(*mss)), + } + } + tcp_header + .set_options(&tcp_options) + .map_err(|e| std::io::Error::new(InvalidInput, e))?; + } let ip_header = match (src_addr.ip(), dst_addr.ip()) { (std::net::IpAddr::V4(src), std::net::IpAddr::V4(dst)) => { let mut ip_h =