@@ -14,8 +14,8 @@ use io_lifetimes::AsSocketlike as _;
1414use rustix:: io:: Errno ;
1515use tokio:: net:: { TcpListener , TcpStream } ;
1616use wasmtime:: component:: {
17- Accessor , AccessorTask , FutureWriter , HostFuture , HostStream , Resource , ResourceTable ,
18- StreamWriter ,
17+ Accessor , AccessorTask , FutureReader , FutureWriter , GuardedFutureWriter , GuardedStreamWriter ,
18+ Resource , ResourceTable , StreamReader , StreamWriter ,
1919} ;
2020
2121use crate :: p3:: DEFAULT_BUFFER_CAPACITY ;
@@ -57,16 +57,17 @@ fn get_socket_mut<'a>(
5757struct ListenTask {
5858 listener : Arc < TcpListener > ,
5959 family : SocketAddressFamily ,
60- tx : StreamWriter < Option < Resource < TcpSocket > > > ,
60+ tx : StreamWriter < Resource < TcpSocket > > ,
6161 options : NonInheritedOptions ,
6262}
6363
6464impl < T > AccessorTask < T , WasiSockets , wasmtime:: Result < ( ) > > for ListenTask {
65- async fn run ( mut self , store : & Accessor < T , WasiSockets > ) -> wasmtime:: Result < ( ) > {
66- while !self . tx . is_closed ( ) {
65+ async fn run ( self , store : & Accessor < T , WasiSockets > ) -> wasmtime:: Result < ( ) > {
66+ let mut tx = GuardedStreamWriter :: new ( store, self . tx ) ;
67+ while !tx. is_closed ( ) {
6768 let Some ( res) = ( {
6869 let mut accept = pin ! ( self . listener. accept( ) ) ;
69- let mut tx = pin ! ( self . tx. watch_reader( store ) ) ;
70+ let mut tx = pin ! ( tx. watch_reader( ) ) ;
7071 poll_fn ( |cx| match tx. as_mut ( ) . poll ( cx) {
7172 Poll :: Ready ( ( ) ) => return Poll :: Ready ( None ) ,
7273 Poll :: Pending => accept. as_mut ( ) . poll ( cx) . map ( Some ) ,
@@ -121,8 +122,8 @@ impl<T> AccessorTask<T, WasiSockets, wasmtime::Result<()>> for ListenTask {
121122 . push ( TcpSocket :: from_state ( state, self . family ) )
122123 . context ( "failed to push socket resource to table" )
123124 } ) ?;
124- if let Some ( socket) = self . tx . write ( store , Some ( socket) ) . await {
125- debug_assert ! ( self . tx. is_closed( ) ) ;
125+ if let Some ( socket) = tx. write ( Some ( socket) ) . await {
126+ debug_assert ! ( tx. is_closed( ) ) ;
126127 store. with ( |mut view| {
127128 view. get ( )
128129 . table
@@ -143,40 +144,40 @@ struct ResultWriteTask {
143144
144145impl < T > AccessorTask < T , WasiSockets , wasmtime:: Result < ( ) > > for ResultWriteTask {
145146 async fn run ( self , store : & Accessor < T , WasiSockets > ) -> wasmtime:: Result < ( ) > {
146- self . result_tx . write ( store, self . result ) . await ;
147+ GuardedFutureWriter :: new ( store, self . result_tx )
148+ . write ( self . result )
149+ . await ;
147150 Ok ( ( ) )
148151 }
149152}
150153
151154struct ReceiveTask {
152155 stream : Arc < TcpStream > ,
153- data_tx : StreamWriter < Cursor < BytesMut > > ,
156+ data_tx : StreamWriter < u8 > ,
154157 result_tx : FutureWriter < Result < ( ) , ErrorCode > > ,
155158}
156159
157160impl < T > AccessorTask < T , WasiSockets , wasmtime:: Result < ( ) > > for ReceiveTask {
158- async fn run ( mut self , store : & Accessor < T , WasiSockets > ) -> wasmtime:: Result < ( ) > {
161+ async fn run ( self , store : & Accessor < T , WasiSockets > ) -> wasmtime:: Result < ( ) > {
159162 let mut buf = BytesMut :: with_capacity ( DEFAULT_BUFFER_CAPACITY ) ;
163+ let mut data_tx = GuardedStreamWriter :: new ( store, self . data_tx ) ;
164+ let result_tx = GuardedFutureWriter :: new ( store, self . result_tx ) ;
160165 let res = loop {
161166 match self . stream . try_read_buf ( & mut buf) {
162167 Ok ( 0 ) => {
163168 break Ok ( ( ) ) ;
164169 }
165170 Ok ( ..) => {
166- buf = self
167- . data_tx
168- . write_all ( store, Cursor :: new ( buf) )
169- . await
170- . into_inner ( ) ;
171- if self . data_tx . is_closed ( ) {
171+ buf = data_tx. write_all ( Cursor :: new ( buf) ) . await . into_inner ( ) ;
172+ if data_tx. is_closed ( ) {
172173 break Ok ( ( ) ) ;
173174 }
174175 buf. clear ( ) ;
175176 }
176177 Err ( err) if err. kind ( ) == std:: io:: ErrorKind :: WouldBlock => {
177178 let Some ( res) = ( {
178179 let mut readable = pin ! ( self . stream. readable( ) ) ;
179- let mut tx = pin ! ( self . data_tx. watch_reader( store ) ) ;
180+ let mut tx = pin ! ( data_tx. watch_reader( ) ) ;
180181 poll_fn ( |cx| match tx. as_mut ( ) . poll ( cx) {
181182 Poll :: Ready ( ( ) ) => return Poll :: Ready ( None ) ,
182183 Poll :: Pending => readable. as_mut ( ) . poll ( cx) . map ( Some ) ,
@@ -203,7 +204,7 @@ impl<T> AccessorTask<T, WasiSockets, wasmtime::Result<()>> for ReceiveTask {
203204 // task are freed
204205 store. spawn ( ResultWriteTask {
205206 result : res,
206- result_tx : self . result_tx ,
207+ result_tx : result_tx. into ( ) ,
207208 } ) ;
208209 Ok ( ( ) )
209210 }
@@ -284,14 +285,10 @@ impl HostTcpSocketWithStore for WasiSockets {
284285 async fn listen < T : ' static > (
285286 store : & Accessor < T , Self > ,
286287 socket : Resource < TcpSocket > ,
287- ) -> wasmtime:: Result < Result < HostStream < Resource < TcpSocket > > , ErrorCode > > {
288+ ) -> wasmtime:: Result < Result < StreamReader < Resource < TcpSocket > > , ErrorCode > > {
288289 store. with ( |mut view| {
289- let ( tx, rx) = view
290- . instance ( )
291- . stream :: < _ , _ , Option < _ > > ( & mut view)
292- . context ( "failed to create stream" ) ?;
293290 if !view. get ( ) . ctx . allowed_network_uses . tcp {
294- return Ok ( Err ( ErrorCode :: AccessDenied ) ) ;
291+ return anyhow :: Ok ( Err ( ErrorCode :: AccessDenied ) ) ;
295292 }
296293 let TcpSocket {
297294 tcp_state,
@@ -328,24 +325,29 @@ impl HostTcpSocketWithStore for WasiSockets {
328325 } ;
329326 let listener = Arc :: new ( listener) ;
330327 * tcp_state = TcpState :: Listening ( Arc :: clone ( & listener) ) ;
328+ let family = * family;
329+ let options = options. clone ( ) ;
330+ let ( tx, rx) = view
331+ . instance ( )
332+ . stream ( & mut view)
333+ . context ( "failed to create stream" ) ?;
331334 let task = ListenTask {
332335 listener,
333- family : * family ,
336+ family,
334337 tx,
335- options : options . clone ( ) ,
338+ options,
336339 } ;
337340 view. spawn ( task) ;
338- Ok ( Ok ( rx. into ( ) ) )
341+ Ok ( Ok ( rx) )
339342 } )
340343 }
341344
342345 async fn send < T : ' static > (
343346 store : & Accessor < T , Self > ,
344347 socket : Resource < TcpSocket > ,
345- data : HostStream < u8 > ,
348+ data : StreamReader < u8 > ,
346349 ) -> wasmtime:: Result < Result < ( ) , ErrorCode > > {
347350 let ( stream, mut data) = match store. with ( |mut view| -> wasmtime:: Result < _ > {
348- let data = data. into_reader :: < Vec < _ > > ( & mut view) ;
349351 let sock = get_socket ( view. get ( ) . table , & socket) ?;
350352 if let TcpState :: Connected ( stream) | TcpState :: Receiving ( stream) = & sock. tcp_state {
351353 Ok ( Ok ( ( Arc :: clone ( & stream) , data) ) )
@@ -387,32 +389,34 @@ impl HostTcpSocketWithStore for WasiSockets {
387389 async fn receive < T : ' static > (
388390 store : & Accessor < T , Self > ,
389391 socket : Resource < TcpSocket > ,
390- ) -> wasmtime:: Result < ( HostStream < u8 > , HostFuture < Result < ( ) , ErrorCode > > ) > {
392+ ) -> wasmtime:: Result < ( StreamReader < u8 > , FutureReader < Result < ( ) , ErrorCode > > ) > {
391393 store. with ( |mut view| {
392394 let instance = view. instance ( ) ;
393395 let ( data_tx, data_rx) = instance
394- . stream :: < _ , _ , BytesMut > ( & mut view)
396+ . stream ( & mut view)
395397 . context ( "failed to create stream" ) ?;
396398 let TcpSocket { tcp_state, .. } = get_socket_mut ( view. get ( ) . table , & socket) ?;
397399 match mem:: replace ( tcp_state, TcpState :: Closed ) {
398400 TcpState :: Connected ( stream) => {
399401 * tcp_state = TcpState :: Receiving ( Arc :: clone ( & stream) ) ;
400402 let ( result_tx, result_rx) = instance
401- . future ( || unreachable ! ( ) , & mut view )
403+ . future ( & mut view , || unreachable ! ( ) )
402404 . context ( "failed to create future" ) ?;
403405 view. spawn ( ReceiveTask {
404406 stream,
405407 data_tx,
406408 result_tx,
407409 } ) ;
408- Ok ( ( data_rx. into ( ) , result_rx. into ( ) ) )
410+ Ok ( ( data_rx, result_rx) )
409411 }
410412 prev => {
411413 * tcp_state = prev;
412- let ( _ , result_rx) = instance
413- . future ( || Err ( ErrorCode :: InvalidState ) , & mut view )
414+ let ( result_tx , result_rx) = instance
415+ . future ( & mut view , || Err ( ErrorCode :: InvalidState ) )
414416 . context ( "failed to create future" ) ?;
415- Ok ( ( data_rx. into ( ) , result_rx. into ( ) ) )
417+ result_tx. close ( & mut view) ?;
418+ data_tx. close ( & mut view) ?;
419+ Ok ( ( data_rx, result_rx) )
416420 }
417421 }
418422 } )
0 commit comments