Skip to content

Commit 74ee080

Browse files
committed
fix tcp.rs regressions
Signed-off-by: Joel Dice <joel.dice@fermyon.com>
1 parent f1450b8 commit 74ee080

File tree

1 file changed

+41
-37
lines changed
  • crates/wasi/src/p3/sockets/host/types

1 file changed

+41
-37
lines changed

crates/wasi/src/p3/sockets/host/types/tcp.rs

Lines changed: 41 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ use io_lifetimes::AsSocketlike as _;
1414
use rustix::io::Errno;
1515
use tokio::net::{TcpListener, TcpStream};
1616
use wasmtime::component::{
17-
Accessor, AccessorTask, FutureWriter, HostFuture, HostStream, Resource, ResourceTable,
18-
StreamWriter,
17+
Accessor, AccessorTask, FutureReader, FutureWriter, GuardedFutureWriter, GuardedStreamWriter,
18+
Resource, ResourceTable, StreamReader, StreamWriter,
1919
};
2020

2121
use crate::p3::DEFAULT_BUFFER_CAPACITY;
@@ -57,16 +57,17 @@ fn get_socket_mut<'a>(
5757
struct ListenTask {
5858
listener: Arc<TcpListener>,
5959
family: SocketAddressFamily,
60-
tx: StreamWriter<Option<Resource<TcpSocket>>>,
60+
tx: StreamWriter<Resource<TcpSocket>>,
6161
options: NonInheritedOptions,
6262
}
6363

6464
impl<T> AccessorTask<T, WasiSockets, wasmtime::Result<()>> for ListenTask {
65-
async fn run(mut self, store: &Accessor<T, WasiSockets>) -> wasmtime::Result<()> {
66-
while !self.tx.is_closed() {
65+
async fn run(self, store: &Accessor<T, WasiSockets>) -> wasmtime::Result<()> {
66+
let mut tx = GuardedStreamWriter::new(store, self.tx);
67+
while !tx.is_closed() {
6768
let Some(res) = ({
6869
let mut accept = pin!(self.listener.accept());
69-
let mut tx = pin!(self.tx.watch_reader(store));
70+
let mut tx = pin!(tx.watch_reader());
7071
poll_fn(|cx| match tx.as_mut().poll(cx) {
7172
Poll::Ready(()) => return Poll::Ready(None),
7273
Poll::Pending => accept.as_mut().poll(cx).map(Some),
@@ -121,8 +122,8 @@ impl<T> AccessorTask<T, WasiSockets, wasmtime::Result<()>> for ListenTask {
121122
.push(TcpSocket::from_state(state, self.family))
122123
.context("failed to push socket resource to table")
123124
})?;
124-
if let Some(socket) = self.tx.write(store, Some(socket)).await {
125-
debug_assert!(self.tx.is_closed());
125+
if let Some(socket) = tx.write(Some(socket)).await {
126+
debug_assert!(tx.is_closed());
126127
store.with(|mut view| {
127128
view.get()
128129
.table
@@ -143,40 +144,40 @@ struct ResultWriteTask {
143144

144145
impl<T> AccessorTask<T, WasiSockets, wasmtime::Result<()>> for ResultWriteTask {
145146
async fn run(self, store: &Accessor<T, WasiSockets>) -> wasmtime::Result<()> {
146-
self.result_tx.write(store, self.result).await;
147+
GuardedFutureWriter::new(store, self.result_tx)
148+
.write(self.result)
149+
.await;
147150
Ok(())
148151
}
149152
}
150153

151154
struct ReceiveTask {
152155
stream: Arc<TcpStream>,
153-
data_tx: StreamWriter<Cursor<BytesMut>>,
156+
data_tx: StreamWriter<u8>,
154157
result_tx: FutureWriter<Result<(), ErrorCode>>,
155158
}
156159

157160
impl<T> AccessorTask<T, WasiSockets, wasmtime::Result<()>> for ReceiveTask {
158-
async fn run(mut self, store: &Accessor<T, WasiSockets>) -> wasmtime::Result<()> {
161+
async fn run(self, store: &Accessor<T, WasiSockets>) -> wasmtime::Result<()> {
159162
let mut buf = BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY);
163+
let mut data_tx = GuardedStreamWriter::new(store, self.data_tx);
164+
let result_tx = GuardedFutureWriter::new(store, self.result_tx);
160165
let res = loop {
161166
match self.stream.try_read_buf(&mut buf) {
162167
Ok(0) => {
163168
break Ok(());
164169
}
165170
Ok(..) => {
166-
buf = self
167-
.data_tx
168-
.write_all(store, Cursor::new(buf))
169-
.await
170-
.into_inner();
171-
if self.data_tx.is_closed() {
171+
buf = data_tx.write_all(Cursor::new(buf)).await.into_inner();
172+
if data_tx.is_closed() {
172173
break Ok(());
173174
}
174175
buf.clear();
175176
}
176177
Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {
177178
let Some(res) = ({
178179
let mut readable = pin!(self.stream.readable());
179-
let mut tx = pin!(self.data_tx.watch_reader(store));
180+
let mut tx = pin!(data_tx.watch_reader());
180181
poll_fn(|cx| match tx.as_mut().poll(cx) {
181182
Poll::Ready(()) => return Poll::Ready(None),
182183
Poll::Pending => readable.as_mut().poll(cx).map(Some),
@@ -203,7 +204,7 @@ impl<T> AccessorTask<T, WasiSockets, wasmtime::Result<()>> for ReceiveTask {
203204
// task are freed
204205
store.spawn(ResultWriteTask {
205206
result: res,
206-
result_tx: self.result_tx,
207+
result_tx: result_tx.into(),
207208
});
208209
Ok(())
209210
}
@@ -284,14 +285,10 @@ impl HostTcpSocketWithStore for WasiSockets {
284285
async fn listen<T: 'static>(
285286
store: &Accessor<T, Self>,
286287
socket: Resource<TcpSocket>,
287-
) -> wasmtime::Result<Result<HostStream<Resource<TcpSocket>>, ErrorCode>> {
288+
) -> wasmtime::Result<Result<StreamReader<Resource<TcpSocket>>, ErrorCode>> {
288289
store.with(|mut view| {
289-
let (tx, rx) = view
290-
.instance()
291-
.stream::<_, _, Option<_>>(&mut view)
292-
.context("failed to create stream")?;
293290
if !view.get().ctx.allowed_network_uses.tcp {
294-
return Ok(Err(ErrorCode::AccessDenied));
291+
return anyhow::Ok(Err(ErrorCode::AccessDenied));
295292
}
296293
let TcpSocket {
297294
tcp_state,
@@ -328,24 +325,29 @@ impl HostTcpSocketWithStore for WasiSockets {
328325
};
329326
let listener = Arc::new(listener);
330327
*tcp_state = TcpState::Listening(Arc::clone(&listener));
328+
let family = *family;
329+
let options = options.clone();
330+
let (tx, rx) = view
331+
.instance()
332+
.stream(&mut view)
333+
.context("failed to create stream")?;
331334
let task = ListenTask {
332335
listener,
333-
family: *family,
336+
family,
334337
tx,
335-
options: options.clone(),
338+
options,
336339
};
337340
view.spawn(task);
338-
Ok(Ok(rx.into()))
341+
Ok(Ok(rx))
339342
})
340343
}
341344

342345
async fn send<T: 'static>(
343346
store: &Accessor<T, Self>,
344347
socket: Resource<TcpSocket>,
345-
data: HostStream<u8>,
348+
data: StreamReader<u8>,
346349
) -> wasmtime::Result<Result<(), ErrorCode>> {
347350
let (stream, mut data) = match store.with(|mut view| -> wasmtime::Result<_> {
348-
let data = data.into_reader::<Vec<_>>(&mut view);
349351
let sock = get_socket(view.get().table, &socket)?;
350352
if let TcpState::Connected(stream) | TcpState::Receiving(stream) = &sock.tcp_state {
351353
Ok(Ok((Arc::clone(&stream), data)))
@@ -387,32 +389,34 @@ impl HostTcpSocketWithStore for WasiSockets {
387389
async fn receive<T: 'static>(
388390
store: &Accessor<T, Self>,
389391
socket: Resource<TcpSocket>,
390-
) -> wasmtime::Result<(HostStream<u8>, HostFuture<Result<(), ErrorCode>>)> {
392+
) -> wasmtime::Result<(StreamReader<u8>, FutureReader<Result<(), ErrorCode>>)> {
391393
store.with(|mut view| {
392394
let instance = view.instance();
393395
let (data_tx, data_rx) = instance
394-
.stream::<_, _, BytesMut>(&mut view)
396+
.stream(&mut view)
395397
.context("failed to create stream")?;
396398
let TcpSocket { tcp_state, .. } = get_socket_mut(view.get().table, &socket)?;
397399
match mem::replace(tcp_state, TcpState::Closed) {
398400
TcpState::Connected(stream) => {
399401
*tcp_state = TcpState::Receiving(Arc::clone(&stream));
400402
let (result_tx, result_rx) = instance
401-
.future(|| unreachable!(), &mut view)
403+
.future(&mut view, || unreachable!())
402404
.context("failed to create future")?;
403405
view.spawn(ReceiveTask {
404406
stream,
405407
data_tx,
406408
result_tx,
407409
});
408-
Ok((data_rx.into(), result_rx.into()))
410+
Ok((data_rx, result_rx))
409411
}
410412
prev => {
411413
*tcp_state = prev;
412-
let (_, result_rx) = instance
413-
.future(|| Err(ErrorCode::InvalidState), &mut view)
414+
let (result_tx, result_rx) = instance
415+
.future(&mut view, || Err(ErrorCode::InvalidState))
414416
.context("failed to create future")?;
415-
Ok((data_rx.into(), result_rx.into()))
417+
result_tx.close(&mut view)?;
418+
data_tx.close(&mut view)?;
419+
Ok((data_rx, result_rx))
416420
}
417421
}
418422
})

0 commit comments

Comments
 (0)