|
| 1 | +//! OS-backed [`NetProvider`] — tokio TCP sockets for `runtime:net` (SPEC §12). |
| 2 | +//! |
| 3 | +//! Each socket's I/O runs in **spawned runtime tasks** (a reader and a writer) |
| 4 | +//! that move bytes over channels; the ops just send/recv on those channels. |
| 5 | +//! This is the same shape the HTTP client uses: the actual I/O is driven by the |
| 6 | +//! runtime's reactor (via spawned tasks), so reads that must wait for bytes make |
| 7 | +//! progress — polling the raw socket future inline from the op loop would not. |
| 8 | +//! TLS is a follow-up; `connect(tls = true)` errors rather than downgrading. |
| 9 | +
|
| 10 | +use std::collections::HashMap; |
| 11 | +use std::net::SocketAddr; |
| 12 | +use std::sync::atomic::{AtomicU64, Ordering}; |
| 13 | +use std::sync::{Arc, Mutex}; |
| 14 | + |
| 15 | +use es_runtime_providers::{BoxFuture, NetProvider, ProviderError, SocketInfo}; |
| 16 | +use tokio::io::{AsyncReadExt, AsyncWriteExt}; |
| 17 | +use tokio::net::{TcpListener, TcpStream}; |
| 18 | +use tokio::sync::mpsc; |
| 19 | + |
| 20 | +type ReadRx = mpsc::Receiver<Result<Vec<u8>, String>>; |
| 21 | +type WriteTx = mpsc::Sender<Vec<u8>>; |
| 22 | +type AcceptRx = mpsc::Receiver<(TcpStream, SocketAddr)>; |
| 23 | + |
| 24 | +/// A connection's channel ends. `read_rx` is taken out during a read; `write_tx` |
| 25 | +/// is cloned to send and dropped (set to `None`) to half-close. |
| 26 | +struct Slot { |
| 27 | + read_rx: Option<ReadRx>, |
| 28 | + write_tx: Option<WriteTx>, |
| 29 | +} |
| 30 | + |
| 31 | +/// A [`NetProvider`] over real tokio TCP sockets. The `Arc`s are cloned into each |
| 32 | +/// returned future so the futures stay `'static`. |
| 33 | +#[derive(Clone, Default)] |
| 34 | +pub struct SystemNet { |
| 35 | + sockets: Arc<Mutex<HashMap<u64, Slot>>>, |
| 36 | + listeners: Arc<Mutex<HashMap<u64, AcceptRx>>>, |
| 37 | + next_id: Arc<AtomicU64>, |
| 38 | +} |
| 39 | + |
| 40 | +impl SystemNet { |
| 41 | + /// Builds an empty socket registry. |
| 42 | + pub fn new() -> Self { |
| 43 | + Self::default() |
| 44 | + } |
| 45 | + |
| 46 | + fn id(&self) -> u64 { |
| 47 | + self.next_id.fetch_add(1, Ordering::Relaxed) + 1 |
| 48 | + } |
| 49 | + |
| 50 | + /// Splits `stream` and spawns its reader + writer tasks, returning the |
| 51 | + /// channel ends to register. |
| 52 | + fn spawn_socket(stream: TcpStream) -> Slot { |
| 53 | + let (mut r, mut w) = stream.into_split(); |
| 54 | + let (read_tx, read_rx) = mpsc::channel::<Result<Vec<u8>, String>>(8); |
| 55 | + let (write_tx, mut write_rx) = mpsc::channel::<Vec<u8>>(8); |
| 56 | + |
| 57 | + tokio::spawn(async move { |
| 58 | + let mut buf = vec![0u8; 64 * 1024]; |
| 59 | + loop { |
| 60 | + match r.read(&mut buf).await { |
| 61 | + Ok(0) => break, // EOF — dropping read_tx signals it |
| 62 | + Ok(n) => { |
| 63 | + if read_tx.send(Ok(buf[..n].to_vec())).await.is_err() { |
| 64 | + break; // consumer gone |
| 65 | + } |
| 66 | + } |
| 67 | + Err(e) => { |
| 68 | + let _ = read_tx.send(Err(e.to_string())).await; |
| 69 | + break; |
| 70 | + } |
| 71 | + } |
| 72 | + } |
| 73 | + }); |
| 74 | + |
| 75 | + tokio::spawn(async move { |
| 76 | + while let Some(data) = write_rx.recv().await { |
| 77 | + if w.write_all(&data).await.is_err() { |
| 78 | + break; |
| 79 | + } |
| 80 | + } |
| 81 | + let _ = w.shutdown().await; // write_tx dropped (half-close / close) |
| 82 | + }); |
| 83 | + |
| 84 | + Slot { |
| 85 | + read_rx: Some(read_rx), |
| 86 | + write_tx: Some(write_tx), |
| 87 | + } |
| 88 | + } |
| 89 | +} |
| 90 | + |
| 91 | +fn err(e: impl ToString) -> ProviderError { |
| 92 | + ProviderError::Other(e.to_string()) |
| 93 | +} |
| 94 | + |
| 95 | +fn info_of(local: Option<SocketAddr>, remote: Option<SocketAddr>) -> SocketInfo { |
| 96 | + SocketInfo { |
| 97 | + remote_address: remote.map(|a| a.ip().to_string()).unwrap_or_default(), |
| 98 | + remote_port: remote.map(|a| a.port()).unwrap_or(0), |
| 99 | + local_address: local.map(|a| a.ip().to_string()).unwrap_or_default(), |
| 100 | + local_port: local.map(|a| a.port()).unwrap_or(0), |
| 101 | + } |
| 102 | +} |
| 103 | + |
| 104 | +impl NetProvider for SystemNet { |
| 105 | + fn connect( |
| 106 | + &self, |
| 107 | + host: String, |
| 108 | + port: u16, |
| 109 | + tls: bool, |
| 110 | + ) -> BoxFuture<Result<(u64, SocketInfo), ProviderError>> { |
| 111 | + let this = self.clone(); |
| 112 | + Box::pin(async move { |
| 113 | + if tls { |
| 114 | + return Err(err( |
| 115 | + "runtime:net TLS is not supported yet (plaintext TCP only)", |
| 116 | + )); |
| 117 | + } |
| 118 | + let stream = TcpStream::connect((host.as_str(), port)) |
| 119 | + .await |
| 120 | + .map_err(err)?; |
| 121 | + let _ = stream.set_nodelay(true); |
| 122 | + let info = info_of(stream.local_addr().ok(), stream.peer_addr().ok()); |
| 123 | + let id = this.id(); |
| 124 | + this.sockets |
| 125 | + .lock() |
| 126 | + .unwrap() |
| 127 | + .insert(id, SystemNet::spawn_socket(stream)); |
| 128 | + Ok((id, info)) |
| 129 | + }) |
| 130 | + } |
| 131 | + |
| 132 | + fn read(&self, id: u64) -> BoxFuture<Result<Option<Vec<u8>>, ProviderError>> { |
| 133 | + let sockets = self.sockets.clone(); |
| 134 | + Box::pin(async move { |
| 135 | + let mut rx = match sockets |
| 136 | + .lock() |
| 137 | + .unwrap() |
| 138 | + .get_mut(&id) |
| 139 | + .and_then(|s| s.read_rx.take()) |
| 140 | + { |
| 141 | + Some(rx) => rx, |
| 142 | + None => return Ok(None), // closed or already at EOF |
| 143 | + }; |
| 144 | + match rx.recv().await { |
| 145 | + Some(Ok(buf)) => { |
| 146 | + if let Some(slot) = sockets.lock().unwrap().get_mut(&id) { |
| 147 | + slot.read_rx = Some(rx); |
| 148 | + } |
| 149 | + Ok(Some(buf)) |
| 150 | + } |
| 151 | + Some(Err(e)) => Err(err(e)), |
| 152 | + None => Ok(None), // reader task ended (EOF) — leave it taken |
| 153 | + } |
| 154 | + }) |
| 155 | + } |
| 156 | + |
| 157 | + fn write(&self, id: u64, data: Vec<u8>) -> BoxFuture<Result<(), ProviderError>> { |
| 158 | + let sockets = self.sockets.clone(); |
| 159 | + Box::pin(async move { |
| 160 | + let tx = sockets |
| 161 | + .lock() |
| 162 | + .unwrap() |
| 163 | + .get(&id) |
| 164 | + .and_then(|s| s.write_tx.clone()); |
| 165 | + match tx { |
| 166 | + Some(tx) => tx.send(data).await.map_err(|_| err("socket is closed")), |
| 167 | + None => Err(err("socket is closed")), |
| 168 | + } |
| 169 | + }) |
| 170 | + } |
| 171 | + |
| 172 | + fn shutdown(&self, id: u64) -> BoxFuture<Result<(), ProviderError>> { |
| 173 | + let sockets = self.sockets.clone(); |
| 174 | + Box::pin(async move { |
| 175 | + // Drop the sender: the writer task's recv() ends and it shuts down |
| 176 | + // the write half (FIN). The read half keeps working. |
| 177 | + if let Some(slot) = sockets.lock().unwrap().get_mut(&id) { |
| 178 | + slot.write_tx = None; |
| 179 | + } |
| 180 | + Ok(()) |
| 181 | + }) |
| 182 | + } |
| 183 | + |
| 184 | + fn close(&self, id: u64) -> BoxFuture<Result<(), ProviderError>> { |
| 185 | + let sockets = self.sockets.clone(); |
| 186 | + Box::pin(async move { |
| 187 | + // Dropping the slot drops both channel ends, ending both tasks. |
| 188 | + sockets.lock().unwrap().remove(&id); |
| 189 | + Ok(()) |
| 190 | + }) |
| 191 | + } |
| 192 | + |
| 193 | + fn listen( |
| 194 | + &self, |
| 195 | + host: String, |
| 196 | + port: u16, |
| 197 | + ) -> BoxFuture<Result<(u64, SocketInfo), ProviderError>> { |
| 198 | + let this = self.clone(); |
| 199 | + Box::pin(async move { |
| 200 | + let listener = TcpListener::bind((host.as_str(), port)) |
| 201 | + .await |
| 202 | + .map_err(err)?; |
| 203 | + let local = listener.local_addr().ok(); |
| 204 | + let (tx, rx) = mpsc::channel::<(TcpStream, SocketAddr)>(8); |
| 205 | + tokio::spawn(async move { |
| 206 | + while let Ok(conn) = listener.accept().await { |
| 207 | + if tx.send(conn).await.is_err() { |
| 208 | + break; // listener closed (rx dropped) |
| 209 | + } |
| 210 | + } |
| 211 | + }); |
| 212 | + let id = this.id(); |
| 213 | + this.listeners.lock().unwrap().insert(id, rx); |
| 214 | + Ok((id, info_of(local, None))) |
| 215 | + }) |
| 216 | + } |
| 217 | + |
| 218 | + fn accept(&self, id: u64) -> BoxFuture<Result<Option<(u64, SocketInfo)>, ProviderError>> { |
| 219 | + let this = self.clone(); |
| 220 | + Box::pin(async move { |
| 221 | + let mut rx = match this.listeners.lock().unwrap().remove(&id) { |
| 222 | + Some(rx) => rx, |
| 223 | + None => return Ok(None), // listener closed |
| 224 | + }; |
| 225 | + let conn = rx.recv().await; |
| 226 | + this.listeners.lock().unwrap().insert(id, rx); // keep accepting |
| 227 | + match conn { |
| 228 | + Some((stream, remote)) => { |
| 229 | + let _ = stream.set_nodelay(true); |
| 230 | + let info = info_of(stream.local_addr().ok(), Some(remote)); |
| 231 | + let sid = this.id(); |
| 232 | + this.sockets |
| 233 | + .lock() |
| 234 | + .unwrap() |
| 235 | + .insert(sid, SystemNet::spawn_socket(stream)); |
| 236 | + Ok(Some((sid, info))) |
| 237 | + } |
| 238 | + None => Ok(None), |
| 239 | + } |
| 240 | + }) |
| 241 | + } |
| 242 | + |
| 243 | + fn close_listener(&self, id: u64) -> BoxFuture<Result<(), ProviderError>> { |
| 244 | + let listeners = self.listeners.clone(); |
| 245 | + Box::pin(async move { |
| 246 | + listeners.lock().unwrap().remove(&id); |
| 247 | + Ok(()) |
| 248 | + }) |
| 249 | + } |
| 250 | +} |
0 commit comments