Skip to content

Commit 1f30f7e

Browse files
committed
p3-http: tie lifetime of the spawned task to the bodies
Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net>
1 parent ed62e2e commit 1f30f7e

File tree

1 file changed

+99
-11
lines changed

1 file changed

+99
-11
lines changed

crates/wasi-http/src/p3/host/handler.rs

Lines changed: 99 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,71 @@ use crate::p3::body::{Body, ConsumedBody, GuestBody, GuestBodyKind};
44
use crate::p3::{HttpError, HttpResult, WasiHttp, WasiHttpCtxView, get_content_length};
55
use anyhow::Context as _;
66
use core::pin::Pin;
7+
use core::task::{Context, Poll, Waker};
78
use http::header::HOST;
89
use http::{HeaderValue, Uri};
910
use http_body_util::BodyExt as _;
1011
use std::sync::Arc;
1112
use tokio::sync::oneshot;
1213
use tracing::debug;
13-
use wasmtime::component::{Accessor, AccessorTask, Resource};
14+
use wasmtime::component::{Accessor, AccessorTask, JoinHandle, Resource};
15+
16+
/// A wrapper around [`JoinHandle`], which will [`JoinHandle::abort`] the task
17+
/// when dropped
18+
struct AbortOnDropJoinHandle(JoinHandle);
19+
20+
impl Drop for AbortOnDropJoinHandle {
21+
fn drop(&mut self) {
22+
self.0.abort();
23+
}
24+
}
25+
26+
/// A wrapper around [http_body::Body], which allows attaching arbitrary state to it
27+
struct BodyWithState<T, U> {
28+
body: T,
29+
_state: U,
30+
}
31+
32+
impl<T, U> http_body::Body for BodyWithState<T, U>
33+
where
34+
T: http_body::Body + Unpin,
35+
U: Unpin,
36+
{
37+
type Data = T::Data;
38+
type Error = T::Error;
39+
40+
#[inline]
41+
fn poll_frame(
42+
self: Pin<&mut Self>,
43+
cx: &mut Context<'_>,
44+
) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
45+
Pin::new(&mut self.get_mut().body).poll_frame(cx)
46+
}
47+
48+
#[inline]
49+
fn is_end_stream(&self) -> bool {
50+
self.body.is_end_stream()
51+
}
52+
53+
#[inline]
54+
fn size_hint(&self) -> http_body::SizeHint {
55+
self.body.size_hint()
56+
}
57+
}
58+
59+
trait BodyExt {
60+
fn with_state<T>(self, state: T) -> BodyWithState<Self, T>
61+
where
62+
Self: Sized,
63+
{
64+
BodyWithState {
65+
body: self,
66+
_state: state,
67+
}
68+
}
69+
}
70+
71+
impl<T> BodyExt for T {}
1472

1573
struct SendRequestTask {
1674
io: Pin<Box<dyn Future<Output = Result<(), ErrorCode>> + Send>>,
@@ -26,14 +84,35 @@ impl<T> AccessorTask<T, WasiHttp, wasmtime::Result<()>> for SendRequestTask {
2684
}
2785
}
2886

87+
async fn io_task_result(
88+
rx: oneshot::Receiver<(
89+
Arc<AbortOnDropJoinHandle>,
90+
oneshot::Receiver<Result<(), ErrorCode>>,
91+
)>,
92+
) -> Result<(), ErrorCode> {
93+
let Ok((_io, io_result_rx)) = rx.await else {
94+
return Ok(());
95+
};
96+
io_result_rx.await.unwrap_or(Ok(()))
97+
}
98+
2999
impl HostWithStore for WasiHttp {
30100
async fn handle<T>(
31101
store: &Accessor<T, Self>,
32102
req: Resource<Request>,
33103
) -> HttpResult<Resource<Response>> {
34-
let getter = store.getter();
104+
// A handle to the I/O task, if spawned, will be sent on this channel
105+
// and kept as part of request body state
106+
let (io_task_tx, io_task_rx) = oneshot::channel();
107+
108+
// A handle to the I/O task and, if spawned, will be sent on this channel
109+
// along with the result receiver
35110
let (io_result_tx, io_result_rx) = oneshot::channel();
111+
112+
// Response processing result will be sent on this channel
36113
let (res_result_tx, res_result_rx) = oneshot::channel();
114+
115+
let getter = store.getter();
37116
let fut = store.with(|mut store| {
38117
let WasiHttpCtxView { table, .. } = store.get();
39118
let Request {
@@ -62,7 +141,7 @@ impl HostWithStore for WasiHttp {
62141
if let Ok(Err(err)) = http_result_rx.await {
63142
return Err(err);
64143
};
65-
io_result_rx.await.unwrap_or(Ok(()))
144+
io_task_result(io_result_rx).await
66145
}));
67146
GuestBody::new(
68147
&mut store,
@@ -73,13 +152,12 @@ impl HostWithStore for WasiHttp {
73152
GuestBodyKind::Request,
74153
getter,
75154
)
155+
.with_state(io_task_rx)
76156
.boxed()
77157
}
78158
Body::Host { body, result_tx } => {
79-
_ = result_tx.send(Box::new(
80-
async move { io_result_rx.await.unwrap_or(Ok(())) },
81-
));
82-
body
159+
_ = result_tx.send(Box::new(io_task_result(io_result_rx)));
160+
body.with_state(io_task_rx).boxed()
83161
}
84162
Body::Consumed => ConsumedBody.boxed(),
85163
};
@@ -129,16 +207,26 @@ impl HostWithStore for WasiHttp {
129207
))
130208
})?;
131209
let (res, io) = Box::into_pin(fut).await?;
132-
store.spawn(SendRequestTask {
133-
io: Box::into_pin(io),
134-
result_tx: io_result_tx,
135-
});
136210
let (
137211
http::response::Parts {
138212
status, headers, ..
139213
},
140214
body,
141215
) = res.into_parts();
216+
217+
let mut io = Box::into_pin(io);
218+
let body = match io.as_mut().poll(&mut Context::from_waker(Waker::noop()))? {
219+
Poll::Ready(()) => body,
220+
Poll::Pending => {
221+
// I/O driver still needs to be polled, spawn a task and send handles to it
222+
let (tx, rx) = oneshot::channel();
223+
let io = store.spawn(SendRequestTask { io, result_tx: tx });
224+
let io = Arc::new(AbortOnDropJoinHandle(io));
225+
_ = io_result_tx.send((Arc::clone(&io), rx));
226+
_ = io_task_tx.send(Arc::clone(&io));
227+
body.with_state(io).boxed()
228+
}
229+
};
142230
let res = Response {
143231
status,
144232
headers: Arc::new(headers),

0 commit comments

Comments
 (0)