@@ -4,13 +4,71 @@ use crate::p3::body::{Body, ConsumedBody, GuestBody, GuestBodyKind};
44use crate :: p3:: { HttpError , HttpResult , WasiHttp , WasiHttpCtxView , get_content_length} ;
55use anyhow:: Context as _;
66use core:: pin:: Pin ;
7+ use core:: task:: { Context , Poll , Waker } ;
78use http:: header:: HOST ;
89use http:: { HeaderValue , Uri } ;
910use http_body_util:: BodyExt as _;
1011use std:: sync:: Arc ;
1112use tokio:: sync:: oneshot;
1213use tracing:: debug;
13- use wasmtime:: component:: { Accessor , AccessorTask , Resource } ;
14+ use wasmtime:: component:: { Accessor , AccessorTask , JoinHandle , Resource } ;
15+
16+ /// A wrapper around [`JoinHandle`], which will [`JoinHandle::abort`] the task
17+ /// when dropped
18+ struct AbortOnDropJoinHandle ( JoinHandle ) ;
19+
20+ impl Drop for AbortOnDropJoinHandle {
21+ fn drop ( & mut self ) {
22+ self . 0 . abort ( ) ;
23+ }
24+ }
25+
26+ /// A wrapper around [http_body::Body], which allows attaching arbitrary state to it
27+ struct BodyWithState < T , U > {
28+ body : T ,
29+ _state : U ,
30+ }
31+
32+ impl < T , U > http_body:: Body for BodyWithState < T , U >
33+ where
34+ T : http_body:: Body + Unpin ,
35+ U : Unpin ,
36+ {
37+ type Data = T :: Data ;
38+ type Error = T :: Error ;
39+
40+ #[ inline]
41+ fn poll_frame (
42+ self : Pin < & mut Self > ,
43+ cx : & mut Context < ' _ > ,
44+ ) -> Poll < Option < Result < http_body:: Frame < Self :: Data > , Self :: Error > > > {
45+ Pin :: new ( & mut self . get_mut ( ) . body ) . poll_frame ( cx)
46+ }
47+
48+ #[ inline]
49+ fn is_end_stream ( & self ) -> bool {
50+ self . body . is_end_stream ( )
51+ }
52+
53+ #[ inline]
54+ fn size_hint ( & self ) -> http_body:: SizeHint {
55+ self . body . size_hint ( )
56+ }
57+ }
58+
59+ trait BodyExt {
60+ fn with_state < T > ( self , state : T ) -> BodyWithState < Self , T >
61+ where
62+ Self : Sized ,
63+ {
64+ BodyWithState {
65+ body : self ,
66+ _state : state,
67+ }
68+ }
69+ }
70+
71+ impl < T > BodyExt for T { }
1472
1573struct SendRequestTask {
1674 io : Pin < Box < dyn Future < Output = Result < ( ) , ErrorCode > > + Send > > ,
@@ -26,14 +84,35 @@ impl<T> AccessorTask<T, WasiHttp, wasmtime::Result<()>> for SendRequestTask {
2684 }
2785}
2886
87+ async fn io_task_result (
88+ rx : oneshot:: Receiver < (
89+ Arc < AbortOnDropJoinHandle > ,
90+ oneshot:: Receiver < Result < ( ) , ErrorCode > > ,
91+ ) > ,
92+ ) -> Result < ( ) , ErrorCode > {
93+ let Ok ( ( _io, io_result_rx) ) = rx. await else {
94+ return Ok ( ( ) ) ;
95+ } ;
96+ io_result_rx. await . unwrap_or ( Ok ( ( ) ) )
97+ }
98+
2999impl HostWithStore for WasiHttp {
30100 async fn handle < T > (
31101 store : & Accessor < T , Self > ,
32102 req : Resource < Request > ,
33103 ) -> HttpResult < Resource < Response > > {
34- let getter = store. getter ( ) ;
104+ // A handle to the I/O task, if spawned, will be sent on this channel
105+ // and kept as part of request body state
106+ let ( io_task_tx, io_task_rx) = oneshot:: channel ( ) ;
107+
108+ // A handle to the I/O task and, if spawned, will be sent on this channel
109+ // along with the result receiver
35110 let ( io_result_tx, io_result_rx) = oneshot:: channel ( ) ;
111+
112+ // Response processing result will be sent on this channel
36113 let ( res_result_tx, res_result_rx) = oneshot:: channel ( ) ;
114+
115+ let getter = store. getter ( ) ;
37116 let fut = store. with ( |mut store| {
38117 let WasiHttpCtxView { table, .. } = store. get ( ) ;
39118 let Request {
@@ -62,7 +141,7 @@ impl HostWithStore for WasiHttp {
62141 if let Ok ( Err ( err) ) = http_result_rx. await {
63142 return Err ( err) ;
64143 } ;
65- io_result_rx. await . unwrap_or ( Ok ( ( ) ) )
144+ io_task_result ( io_result_rx) . await
66145 } ) ) ;
67146 GuestBody :: new (
68147 & mut store,
@@ -73,13 +152,12 @@ impl HostWithStore for WasiHttp {
73152 GuestBodyKind :: Request ,
74153 getter,
75154 )
155+ . with_state ( io_task_rx)
76156 . boxed ( )
77157 }
78158 Body :: Host { body, result_tx } => {
79- _ = result_tx. send ( Box :: new (
80- async move { io_result_rx. await . unwrap_or ( Ok ( ( ) ) ) } ,
81- ) ) ;
82- body
159+ _ = result_tx. send ( Box :: new ( io_task_result ( io_result_rx) ) ) ;
160+ body. with_state ( io_task_rx) . boxed ( )
83161 }
84162 Body :: Consumed => ConsumedBody . boxed ( ) ,
85163 } ;
@@ -129,16 +207,26 @@ impl HostWithStore for WasiHttp {
129207 ) )
130208 } ) ?;
131209 let ( res, io) = Box :: into_pin ( fut) . await ?;
132- store. spawn ( SendRequestTask {
133- io : Box :: into_pin ( io) ,
134- result_tx : io_result_tx,
135- } ) ;
136210 let (
137211 http:: response:: Parts {
138212 status, headers, ..
139213 } ,
140214 body,
141215 ) = res. into_parts ( ) ;
216+
217+ let mut io = Box :: into_pin ( io) ;
218+ let body = match io. as_mut ( ) . poll ( & mut Context :: from_waker ( Waker :: noop ( ) ) ) ? {
219+ Poll :: Ready ( ( ) ) => body,
220+ Poll :: Pending => {
221+ // I/O driver still needs to be polled, spawn a task and send handles to it
222+ let ( tx, rx) = oneshot:: channel ( ) ;
223+ let io = store. spawn ( SendRequestTask { io, result_tx : tx } ) ;
224+ let io = Arc :: new ( AbortOnDropJoinHandle ( io) ) ;
225+ _ = io_result_tx. send ( ( Arc :: clone ( & io) , rx) ) ;
226+ _ = io_task_tx. send ( Arc :: clone ( & io) ) ;
227+ body. with_state ( io) . boxed ( )
228+ }
229+ } ;
142230 let res = Response {
143231 status,
144232 headers : Arc :: new ( headers) ,
0 commit comments