Skip to content

Commit d1139ae

Browse files
authored
p3-http: implementation follow-up (#11649)
* p3: refactor future producers/consumers Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net> * p3-http: tie lifetime of the spawned task to the bodies Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net> * p3-http: improve docs Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net> --------- Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net>
1 parent f3d7256 commit d1139ae

8 files changed

Lines changed: 207 additions & 55 deletions

File tree

crates/wasi-http/src/p3/body.rs

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ pub(crate) enum Body {
4242
}
4343

4444
impl Body {
45+
/// Implementation of `consume-body` shared between requests and responses
4546
pub(crate) fn consume<T>(
4647
self,
4748
mut store: Access<'_, T, WasiHttp>,
@@ -105,6 +106,7 @@ impl Body {
105106
}
106107
}
107108

109+
/// Implementation of `drop` shared between requests and responses
108110
pub(crate) fn drop(self, mut store: impl AsContextMut) {
109111
if let Body::Guest {
110112
contents_rx,
@@ -120,7 +122,8 @@ impl Body {
120122
}
121123
}
122124

123-
pub(crate) enum GuestBodyKind {
125+
/// The kind of body, used for error reporting
126+
pub(crate) enum BodyKind {
124127
Request,
125128
Response,
126129
}
@@ -141,20 +144,22 @@ impl ContentLength {
141144
}
142145
}
143146

147+
/// [StreamConsumer] implementation for bodies originating in the guest.
144148
struct GuestBodyConsumer {
145149
contents_tx: PollSender<Result<Bytes, ErrorCode>>,
146150
result_tx: Option<oneshot::Sender<Result<(), ErrorCode>>>,
147151
content_length: Option<ContentLength>,
148-
kind: GuestBodyKind,
152+
kind: BodyKind,
149153
// `true` when the other side of `contents_tx` was unexpectedly closed
150154
closed: bool,
151155
}
152156

153157
impl GuestBodyConsumer {
158+
/// Constructs the approprite body size error given the [BodyKind]
154159
fn body_size_error(&self, n: Option<u64>) -> ErrorCode {
155160
match self.kind {
156-
GuestBodyKind::Request => ErrorCode::HttpRequestBodySize(n),
157-
GuestBodyKind::Response => ErrorCode::HttpResponseBodySize(n),
161+
BodyKind::Request => ErrorCode::HttpRequestBodySize(n),
162+
BodyKind::Response => ErrorCode::HttpResponseBodySize(n),
158163
}
159164
}
160165

@@ -235,20 +240,22 @@ impl<D> StreamConsumer<D> for GuestBodyConsumer {
235240
}
236241
}
237242

243+
/// [http_body::Body] implementation for bodies originating in the guest.
238244
pub(crate) struct GuestBody {
239245
contents_rx: Option<mpsc::Receiver<Result<Bytes, ErrorCode>>>,
240246
trailers_rx: Option<oneshot::Receiver<Result<Option<Arc<http::HeaderMap>>, ErrorCode>>>,
241247
content_length: Option<u64>,
242248
}
243249

244250
impl GuestBody {
251+
/// Construct a new [GuestBody]
245252
pub(crate) fn new<T: 'static>(
246253
mut store: impl AsContextMut<Data = T>,
247254
contents_rx: Option<StreamReader<u8>>,
248255
trailers_rx: FutureReader<Result<Option<Resource<Trailers>>, ErrorCode>>,
249256
result_tx: oneshot::Sender<Result<(), ErrorCode>>,
250257
content_length: Option<u64>,
251-
kind: GuestBodyKind,
258+
kind: BodyKind,
252259
getter: fn(&mut T) -> WasiHttpCtxView<'_>,
253260
) -> Self {
254261
let (trailers_http_tx, trailers_http_rx) = oneshot::channel();
@@ -290,10 +297,15 @@ impl http_body::Body for GuestBody {
290297
cx: &mut Context<'_>,
291298
) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
292299
if let Some(contents_rx) = self.contents_rx.as_mut() {
300+
// `contents_rx` has not been closed yet, poll it
293301
while let Some(res) = ready!(contents_rx.poll_recv(cx)) {
294302
match res {
295303
Ok(buf) => {
296304
if let Some(n) = self.content_length.as_mut() {
305+
// Substract frame length from `content_length`,
306+
// [GuestBodyConsumer] already performs the validation, so
307+
// just keep count as optimization for
308+
// `is_end_stream` and `size_hint`
297309
*n = n.saturating_sub(buf.len().try_into().unwrap_or(u64::MAX));
298310
}
299311
return Poll::Ready(Some(Ok(http_body::Frame::data(buf))));
@@ -303,14 +315,17 @@ impl http_body::Body for GuestBody {
303315
}
304316
}
305317
}
318+
// Record that `contents_rx` is closed
306319
self.contents_rx = None;
307320
}
308321

309322
let Some(trailers_rx) = self.trailers_rx.as_mut() else {
323+
// `trailers_rx` has already terminated - this is the end of stream
310324
return Poll::Ready(None);
311325
};
312326

313327
let res = ready!(Pin::new(trailers_rx).poll(cx));
328+
// Record that `trailers_rx` has terminated
314329
self.trailers_rx = None;
315330
match res {
316331
Ok(Ok(Some(trailers))) => Poll::Ready(Some(Ok(http_body::Frame::trailers(
@@ -328,14 +343,18 @@ impl http_body::Body for GuestBody {
328343
|| !contents_rx.is_closed()
329344
|| self.content_length.is_some_and(|n| n > 0)
330345
{
346+
// `contents_rx` might still produce data frames
331347
return false;
332348
}
333349
}
334350
if let Some(trailers_rx) = self.trailers_rx.as_ref() {
335351
if !trailers_rx.is_terminated() {
352+
// `trailers_rx` has not terminated yet
336353
return false;
337354
}
338355
}
356+
357+
// no data left
339358
return true;
340359
}
341360

@@ -348,6 +367,7 @@ impl http_body::Body for GuestBody {
348367
}
349368
}
350369

370+
/// [http_body::Body] that has been consumed.
351371
pub(crate) struct ConsumedBody;
352372

353373
impl http_body::Body for ConsumedBody {
@@ -372,9 +392,10 @@ impl http_body::Body for ConsumedBody {
372392
}
373393
}
374394

375-
pub(crate) struct GuestTrailerConsumer<T> {
376-
pub(crate) tx: Option<oneshot::Sender<Result<Option<Arc<HeaderMap>>, ErrorCode>>>,
377-
pub(crate) getter: fn(&mut T) -> WasiHttpCtxView<'_>,
395+
/// [FutureConsumer] implementation for trailers originating in the guest.
396+
struct GuestTrailerConsumer<T> {
397+
tx: Option<oneshot::Sender<Result<Option<Arc<HeaderMap>>, ErrorCode>>>,
398+
getter: fn(&mut T) -> WasiHttpCtxView<'_>,
378399
}
379400

380401
impl<D> FutureConsumer<D> for GuestTrailerConsumer<D>
@@ -387,12 +408,13 @@ where
387408
mut self: Pin<&mut Self>,
388409
_: &mut Context<'_>,
389410
mut store: StoreContextMut<D>,
390-
mut source: Source<'_, Self::Item>,
411+
mut src: Source<'_, Self::Item>,
391412
_: bool,
392413
) -> Poll<wasmtime::Result<()>> {
393-
let value = &mut None;
394-
source.read(store.as_context_mut(), value)?;
395-
let res = match value.take().unwrap() {
414+
let mut result = None;
415+
src.read(store.as_context_mut(), &mut result)
416+
.context("failed to read result")?;
417+
let res = match result.context("result value missing")? {
396418
Ok(Some(trailers)) => {
397419
let WasiHttpCtxView { table, .. } = (self.getter)(store.data_mut());
398420
let trailers = table
@@ -408,6 +430,7 @@ where
408430
}
409431
}
410432

433+
/// [StreamProducer] implementation for bodies originating in the host.
411434
struct HostBodyStreamProducer<T> {
412435
body: BoxBody<Bytes, ErrorCode>,
413436
trailers: Option<oneshot::Sender<Result<Option<Resource<Trailers>>, ErrorCode>>>,
@@ -446,6 +469,8 @@ where
446469
let cap = match dst.remaining(&mut store).map(NonZeroUsize::new) {
447470
Some(Some(cap)) => Some(cap),
448471
Some(None) => {
472+
// On 0-length the best we can do is check that underlying stream has not
473+
// reached the end yet
449474
if self.body.is_end_stream() {
450475
break 'result Ok(None);
451476
} else {
@@ -462,11 +487,13 @@ where
462487
let n = frame.len();
463488
let cap = cap.into();
464489
if n > cap {
490+
// data frame does not fit in destination, fill it and buffer the rest
465491
dst.set_buffer(Cursor::new(frame.split_off(cap)));
466492
let mut dst = dst.as_direct(store, cap);
467493
dst.remaining().copy_from_slice(&frame);
468494
dst.mark_written(cap);
469495
} else {
496+
// copy the whole frame into the destination
470497
let mut dst = dst.as_direct(store, n);
471498
dst.remaining()[..n].copy_from_slice(&frame);
472499
dst.mark_written(n);

crates/wasi-http/src/p3/host/handler.rs

Lines changed: 103 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,74 @@
11
use crate::p3::bindings::http::handler::{Host, HostWithStore};
22
use crate::p3::bindings::http::types::{ErrorCode, Request, Response};
3-
use crate::p3::body::{Body, ConsumedBody, GuestBody, GuestBodyKind};
3+
use crate::p3::body::{Body, BodyKind, ConsumedBody, GuestBody};
44
use crate::p3::{HttpError, HttpResult, WasiHttp, WasiHttpCtxView, get_content_length};
55
use anyhow::Context as _;
66
use core::pin::Pin;
7+
use core::task::{Context, Poll, Waker};
78
use http::header::HOST;
89
use http::{HeaderValue, Uri};
910
use http_body_util::BodyExt as _;
1011
use std::sync::Arc;
1112
use tokio::sync::oneshot;
1213
use tracing::debug;
13-
use wasmtime::component::{Accessor, AccessorTask, Resource};
14+
use wasmtime::component::{Accessor, AccessorTask, JoinHandle, Resource};
15+
16+
/// A wrapper around [`JoinHandle`], which will [`JoinHandle::abort`] the task
17+
/// when dropped
18+
struct AbortOnDropJoinHandle(JoinHandle);
19+
20+
impl Drop for AbortOnDropJoinHandle {
21+
fn drop(&mut self) {
22+
self.0.abort();
23+
}
24+
}
25+
26+
/// A wrapper around [http_body::Body], which allows attaching arbitrary state to it
27+
struct BodyWithState<T, U> {
28+
body: T,
29+
_state: U,
30+
}
31+
32+
impl<T, U> http_body::Body for BodyWithState<T, U>
33+
where
34+
T: http_body::Body + Unpin,
35+
U: Unpin,
36+
{
37+
type Data = T::Data;
38+
type Error = T::Error;
39+
40+
#[inline]
41+
fn poll_frame(
42+
self: Pin<&mut Self>,
43+
cx: &mut Context<'_>,
44+
) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
45+
Pin::new(&mut self.get_mut().body).poll_frame(cx)
46+
}
47+
48+
#[inline]
49+
fn is_end_stream(&self) -> bool {
50+
self.body.is_end_stream()
51+
}
52+
53+
#[inline]
54+
fn size_hint(&self) -> http_body::SizeHint {
55+
self.body.size_hint()
56+
}
57+
}
58+
59+
trait BodyExt {
60+
fn with_state<T>(self, state: T) -> BodyWithState<Self, T>
61+
where
62+
Self: Sized,
63+
{
64+
BodyWithState {
65+
body: self,
66+
_state: state,
67+
}
68+
}
69+
}
70+
71+
impl<T> BodyExt for T {}
1472

1573
struct SendRequestTask {
1674
io: Pin<Box<dyn Future<Output = Result<(), ErrorCode>> + Send>>,
@@ -26,14 +84,35 @@ impl<T> AccessorTask<T, WasiHttp, wasmtime::Result<()>> for SendRequestTask {
2684
}
2785
}
2886

87+
async fn io_task_result(
88+
rx: oneshot::Receiver<(
89+
Arc<AbortOnDropJoinHandle>,
90+
oneshot::Receiver<Result<(), ErrorCode>>,
91+
)>,
92+
) -> Result<(), ErrorCode> {
93+
let Ok((_io, io_result_rx)) = rx.await else {
94+
return Ok(());
95+
};
96+
io_result_rx.await.unwrap_or(Ok(()))
97+
}
98+
2999
impl HostWithStore for WasiHttp {
30100
async fn handle<T>(
31101
store: &Accessor<T, Self>,
32102
req: Resource<Request>,
33103
) -> HttpResult<Resource<Response>> {
34-
let getter = store.getter();
104+
// A handle to the I/O task, if spawned, will be sent on this channel
105+
// and kept as part of request body state
106+
let (io_task_tx, io_task_rx) = oneshot::channel();
107+
108+
// A handle to the I/O task, if spawned, will be sent on this channel
109+
// along with the result receiver
35110
let (io_result_tx, io_result_rx) = oneshot::channel();
111+
112+
// Response processing result will be sent on this channel
36113
let (res_result_tx, res_result_rx) = oneshot::channel();
114+
115+
let getter = store.getter();
37116
let fut = store.with(|mut store| {
38117
let WasiHttpCtxView { table, .. } = store.get();
39118
let Request {
@@ -56,30 +135,30 @@ impl HostWithStore for WasiHttp {
56135
result_tx,
57136
} => {
58137
let (http_result_tx, http_result_rx) = oneshot::channel();
138+
// `Content-Length` header value is validated in `fields` implementation
59139
let content_length = get_content_length(&headers)
60140
.map_err(|err| ErrorCode::InternalError(Some(format!("{err:#}"))))?;
61141
_ = result_tx.send(Box::new(async move {
62142
if let Ok(Err(err)) = http_result_rx.await {
63143
return Err(err);
64144
};
65-
io_result_rx.await.unwrap_or(Ok(()))
145+
io_task_result(io_result_rx).await
66146
}));
67147
GuestBody::new(
68148
&mut store,
69149
contents_rx,
70150
trailers_rx,
71151
http_result_tx,
72152
content_length,
73-
GuestBodyKind::Request,
153+
BodyKind::Request,
74154
getter,
75155
)
156+
.with_state(io_task_rx)
76157
.boxed()
77158
}
78159
Body::Host { body, result_tx } => {
79-
_ = result_tx.send(Box::new(
80-
async move { io_result_rx.await.unwrap_or(Ok(())) },
81-
));
82-
body
160+
_ = result_tx.send(Box::new(io_task_result(io_result_rx)));
161+
body.with_state(io_task_rx).boxed()
83162
}
84163
Body::Consumed => ConsumedBody.boxed(),
85164
};
@@ -121,6 +200,7 @@ impl HostWithStore for WasiHttp {
121200
req,
122201
options.as_deref().copied(),
123202
Box::new(async {
203+
// Forward the response processing result to `WasiHttpCtx` implementation
124204
let Ok(fut) = res_result_rx.await else {
125205
return Ok(());
126206
};
@@ -129,16 +209,26 @@ impl HostWithStore for WasiHttp {
129209
))
130210
})?;
131211
let (res, io) = Box::into_pin(fut).await?;
132-
store.spawn(SendRequestTask {
133-
io: Box::into_pin(io),
134-
result_tx: io_result_tx,
135-
});
136212
let (
137213
http::response::Parts {
138214
status, headers, ..
139215
},
140216
body,
141217
) = res.into_parts();
218+
219+
let mut io = Box::into_pin(io);
220+
let body = match io.as_mut().poll(&mut Context::from_waker(Waker::noop()))? {
221+
Poll::Ready(()) => body,
222+
Poll::Pending => {
223+
// I/O driver still needs to be polled, spawn a task and send handles to it
224+
let (tx, rx) = oneshot::channel();
225+
let io = store.spawn(SendRequestTask { io, result_tx: tx });
226+
let io = Arc::new(AbortOnDropJoinHandle(io));
227+
_ = io_result_tx.send((Arc::clone(&io), rx));
228+
_ = io_task_tx.send(Arc::clone(&io));
229+
body.with_state(io).boxed()
230+
}
231+
};
142232
let res = Response {
143233
status,
144234
headers: Arc::new(headers),

0 commit comments

Comments
 (0)