Skip to content

Commit 5a2860c

Browse files
authored
p3-http: rework content-length handling (#11658)
* p3-http: correctly handle `result` future cancellation Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net> * p3-http: restructure the `content-length` test a bit Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net> * test(http): keep accepting connections after errors prtest:full Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net> * test(p3-http): assert `handle` error on exceeding `content-length` Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net> * p3-http: perform `content-length` check early Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net> * test(p3-http): account for `handle` race condition Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net> * refactor(http): reuse `get_content_length` Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net> * p3-http: check `content-length` for host bodies Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net> * doc(p3-http): call out that host bodies are not validated Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net> * p3-http: refactor body size error send Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net> * fix(p3-http): do not rely on `Drop` for host body check Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net> * doc(p3-http): ensure non-default send request is documented Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net> * doc(p3-http): correct `send_request` doc Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net> --------- Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net>
1 parent 447efbd commit 5a2860c

File tree

11 files changed

+373
-229
lines changed

11 files changed

+373
-229
lines changed

crates/test-programs/src/bin/p3_http_outbound_request_content_length.rs

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,7 @@ impl test_programs::p3::exports::wasi::cli::run::Guest for Component {
5151
async { transmit.await },
5252
async {
5353
let remaining = contents_tx.write_all(b"long enough".to_vec()).await;
54-
assert!(
55-
remaining.is_empty(),
56-
"{}",
57-
String::from_utf8_lossy(&remaining)
58-
);
54+
assert_eq!(String::from_utf8_lossy(&remaining), "");
5955
trailers_tx.write(Ok(None)).await.unwrap();
6056
drop(contents_tx);
6157
},
@@ -72,19 +68,19 @@ impl test_programs::p3::exports::wasi::cli::run::Guest for Component {
7268
async { transmit.await },
7369
async {
7470
let remaining = contents_tx.write_all(b"msg".to_vec()).await;
75-
assert!(
76-
remaining.is_empty(),
77-
"{}",
78-
String::from_utf8_lossy(&remaining)
79-
);
80-
drop(contents_tx);
71+
assert_eq!(String::from_utf8_lossy(&remaining), "");
8172
trailers_tx.write(Ok(None)).await.unwrap();
73+
drop(contents_tx);
8274
},
8375
);
84-
let err = handle.expect_err("should have failed to send request");
76+
// The request body will be polled before `handle` returns.
77+
// Due to the way implementation is structured, by the time it happens
78+
// the error will be already available in most cases and `handle` will fail,
79+
// but it is a race condition, since `handle` may also succeed if
80+
// polling body returns `Poll::Pending`
8581
assert!(
86-
matches!(err, ErrorCode::HttpProtocolError),
87-
"unexpected error: {err:#?}"
82+
matches!(handle, Ok(..) | Err(ErrorCode::HttpProtocolError)),
83+
"unexpected handle result: {handle:#?}"
8884
);
8985
let err = transmit.expect_err("request transmission should have failed");
9086
assert!(
@@ -101,15 +97,22 @@ impl test_programs::p3::exports::wasi::cli::run::Guest for Component {
10197
async { transmit.await },
10298
async {
10399
let remaining = contents_tx.write_all(b"more than 11 bytes".to_vec()).await;
104-
assert_eq!(String::from_utf8_lossy(&remaining), "more than 11 bytes",);
105-
drop(contents_tx);
100+
assert_eq!(String::from_utf8_lossy(&remaining), "more than 11 bytes");
106101
_ = trailers_tx.write(Ok(None)).await;
107102
},
108103
);
109-
110-
// The the error returned by `handle` in this case is non-deterministic,
111-
// so just assert that it fails
112-
let _err = handle.expect_err("should have failed to send request");
104+
// The request body will be polled before `handle` returns.
105+
// Due to the way implementation is structured, by the time it happens
106+
// the error will be already available in most cases and `handle` will fail,
107+
// but it is a race condition, since `handle` may also succeed if
108+
// polling body returns `Poll::Pending`
109+
assert!(
110+
matches!(
111+
handle,
112+
Ok(..) | Err(ErrorCode::HttpRequestBodySize(Some(18)))
113+
),
114+
"unexpected handle result: {handle:#?}"
115+
);
113116
let err = transmit.expect_err("request transmission should have failed");
114117
assert!(
115118
matches!(err, ErrorCode::HttpRequestBodySize(Some(18))),

crates/wasi-http/src/lib.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ pub use crate::types::{
240240
DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS, DEFAULT_OUTGOING_BODY_CHUNK_SIZE, WasiHttpCtx,
241241
WasiHttpImpl, WasiHttpView,
242242
};
243+
use http::header::CONTENT_LENGTH;
243244
use wasmtime::component::{HasData, Linker};
244245

245246
/// Add all of the `wasi:http/proxy` world's interfaces to a [`wasmtime::component::Linker`].
@@ -391,3 +392,15 @@ where
391392

392393
Ok(())
393394
}
395+
396+
/// Extract the `Content-Length` header value from a [`http::HeaderMap`], returning `None` if it's not
397+
/// present. This function will return `Err` if it's not possible to parse the `Content-Length`
398+
/// header.
399+
fn get_content_length(headers: &http::HeaderMap) -> wasmtime::Result<Option<u64>> {
400+
let Some(v) = headers.get(CONTENT_LENGTH) else {
401+
return Ok(None);
402+
};
403+
let v = v.to_str()?;
404+
let v = v.parse()?;
405+
Ok(Some(v))
406+
}

crates/wasi-http/src/p3/body.rs

Lines changed: 110 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -122,75 +122,44 @@ impl Body {
122122
}
123123
}
124124

125-
/// The kind of body, used for error reporting
126-
pub(crate) enum BodyKind {
127-
Request,
128-
Response,
129-
}
130-
131-
/// Represents `Content-Length` limit and state
132-
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
133-
struct ContentLength {
125+
/// [StreamConsumer] implementation for bodies originating in the guest with `Content-Length`
126+
/// header set.
127+
struct LimitedGuestBodyConsumer {
128+
contents_tx: PollSender<Result<Bytes, ErrorCode>>,
129+
error_tx: Option<oneshot::Sender<ErrorCode>>,
130+
make_error: fn(Option<u64>) -> ErrorCode,
134131
/// Limit of bytes to be sent
135132
limit: u64,
136133
/// Number of bytes sent
137134
sent: u64,
138-
}
139-
140-
impl ContentLength {
141-
/// Constructs new [ContentLength]
142-
fn new(limit: u64) -> Self {
143-
Self { limit, sent: 0 }
144-
}
145-
}
146-
147-
/// [StreamConsumer] implementation for bodies originating in the guest.
148-
struct GuestBodyConsumer {
149-
contents_tx: PollSender<Result<Bytes, ErrorCode>>,
150-
result_tx: Option<oneshot::Sender<Result<(), ErrorCode>>>,
151-
content_length: Option<ContentLength>,
152-
kind: BodyKind,
153135
// `true` when the other side of `contents_tx` was unexpectedly closed
154136
closed: bool,
155137
}
156138

157-
impl GuestBodyConsumer {
158-
/// Constructs the approprite body size error given the [BodyKind]
159-
fn body_size_error(&self, n: Option<u64>) -> ErrorCode {
160-
match self.kind {
161-
BodyKind::Request => ErrorCode::HttpRequestBodySize(n),
162-
BodyKind::Response => ErrorCode::HttpResponseBodySize(n),
163-
}
164-
}
165-
166-
// Sends the corresponding error constructed by [Self::body_size_error] on both
167-
// error channels.
168-
// [`PollSender::poll_reserve`] on `contents_tx` must have succeeed prior to this being called.
169-
fn send_body_size_error(&mut self, n: Option<u64>) {
170-
if let Some(result_tx) = self.result_tx.take() {
171-
_ = result_tx.send(Err(self.body_size_error(n)));
172-
_ = self.contents_tx.send_item(Err(self.body_size_error(n)));
139+
impl LimitedGuestBodyConsumer {
140+
/// Sends the error constructed by [Self::make_error] on both error channels.
141+
/// Does nothing if an error has already been sent on [Self::error_tx].
142+
fn send_error(&mut self, sent: Option<u64>) {
143+
if let Some(error_tx) = self.error_tx.take() {
144+
_ = error_tx.send((self.make_error)(sent));
145+
self.contents_tx.abort_send();
146+
if let Some(tx) = self.contents_tx.get_ref() {
147+
_ = tx.try_send(Err((self.make_error)(sent)))
148+
}
149+
self.contents_tx.close();
173150
}
174151
}
175152
}
176153

177-
impl Drop for GuestBodyConsumer {
154+
impl Drop for LimitedGuestBodyConsumer {
178155
fn drop(&mut self) {
179-
if let Some(result_tx) = self.result_tx.take() {
180-
if let Some(ContentLength { limit, sent }) = self.content_length {
181-
if !self.closed && limit != sent {
182-
_ = result_tx.send(Err(self.body_size_error(Some(sent))));
183-
self.contents_tx.abort_send();
184-
if let Some(tx) = self.contents_tx.get_ref() {
185-
_ = tx.try_send(Err(self.body_size_error(Some(sent))))
186-
}
187-
}
188-
}
156+
if !self.closed && self.limit != self.sent {
157+
self.send_error(Some(self.sent))
189158
}
190159
}
191160
}
192161

193-
impl<D> StreamConsumer<D> for GuestBodyConsumer {
162+
impl<D> StreamConsumer<D> for LimitedGuestBodyConsumer {
194163
type Item = u8;
195164

196165
fn poll_consume(
@@ -201,27 +170,31 @@ impl<D> StreamConsumer<D> for GuestBodyConsumer {
201170
finish: bool,
202171
) -> Poll<wasmtime::Result<StreamResult>> {
203172
debug_assert!(!self.closed);
173+
let mut src = src.as_direct(store);
174+
let buf = src.remaining();
175+
let n = buf.len();
176+
177+
// Perform `content-length` check early and precompute the next value
178+
let Ok(sent) = n.try_into() else {
179+
self.send_error(None);
180+
return Poll::Ready(Ok(StreamResult::Dropped));
181+
};
182+
let Some(sent) = self.sent.checked_add(sent) else {
183+
self.send_error(None);
184+
return Poll::Ready(Ok(StreamResult::Dropped));
185+
};
186+
if sent > self.limit {
187+
self.send_error(Some(sent));
188+
return Poll::Ready(Ok(StreamResult::Dropped));
189+
}
204190
match self.contents_tx.poll_reserve(cx) {
205191
Poll::Ready(Ok(())) => {
206-
let mut src = src.as_direct(store);
207-
let buf = src.remaining();
208-
if let Some(ContentLength { limit, sent }) = self.content_length.as_mut() {
209-
let Some(n) = buf.len().try_into().ok().and_then(|n| sent.checked_add(n))
210-
else {
211-
self.send_body_size_error(None);
212-
return Poll::Ready(Ok(StreamResult::Dropped));
213-
};
214-
if n > *limit {
215-
self.send_body_size_error(Some(n));
216-
return Poll::Ready(Ok(StreamResult::Dropped));
217-
}
218-
*sent = n;
219-
}
220192
let buf = Bytes::copy_from_slice(buf);
221-
let n = buf.len();
222193
match self.contents_tx.send_item(Ok(buf)) {
223194
Ok(()) => {
224195
src.mark_read(n);
196+
// Record new `content-length` only on successful send
197+
self.sent = sent;
225198
Poll::Ready(Ok(StreamResult::Completed))
226199
}
227200
Err(..) => {
@@ -240,6 +213,41 @@ impl<D> StreamConsumer<D> for GuestBodyConsumer {
240213
}
241214
}
242215

216+
/// [StreamConsumer] implementation for bodies originating in the guest without `Content-Length`
217+
/// header set.
218+
struct UnlimitedGuestBodyConsumer(PollSender<Result<Bytes, ErrorCode>>);
219+
220+
impl<D> StreamConsumer<D> for UnlimitedGuestBodyConsumer {
221+
type Item = u8;
222+
223+
fn poll_consume(
224+
mut self: Pin<&mut Self>,
225+
cx: &mut Context<'_>,
226+
store: StoreContextMut<D>,
227+
src: Source<Self::Item>,
228+
finish: bool,
229+
) -> Poll<wasmtime::Result<StreamResult>> {
230+
match self.0.poll_reserve(cx) {
231+
Poll::Ready(Ok(())) => {
232+
let mut src = src.as_direct(store);
233+
let buf = src.remaining();
234+
let n = buf.len();
235+
let buf = Bytes::copy_from_slice(buf);
236+
match self.0.send_item(Ok(buf)) {
237+
Ok(()) => {
238+
src.mark_read(n);
239+
Poll::Ready(Ok(StreamResult::Completed))
240+
}
241+
Err(..) => Poll::Ready(Ok(StreamResult::Dropped)),
242+
}
243+
}
244+
Poll::Ready(Err(..)) => Poll::Ready(Ok(StreamResult::Dropped)),
245+
Poll::Pending if finish => Poll::Ready(Ok(StreamResult::Cancelled)),
246+
Poll::Pending => Poll::Pending,
247+
}
248+
}
249+
}
250+
243251
/// [http_body::Body] implementation for bodies originating in the guest.
244252
pub(crate) struct GuestBody {
245253
contents_rx: Option<mpsc::Receiver<Result<Bytes, ErrorCode>>>,
@@ -253,9 +261,10 @@ impl GuestBody {
253261
mut store: impl AsContextMut<Data = T>,
254262
contents_rx: Option<StreamReader<u8>>,
255263
trailers_rx: FutureReader<Result<Option<Resource<Trailers>>, ErrorCode>>,
256-
result_tx: oneshot::Sender<Result<(), ErrorCode>>,
264+
result_tx: oneshot::Sender<Box<dyn Future<Output = Result<(), ErrorCode>> + Send>>,
265+
result_fut: impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
257266
content_length: Option<u64>,
258-
kind: BodyKind,
267+
make_error: fn(Option<u64>) -> ErrorCode,
259268
getter: fn(&mut T) -> WasiHttpCtxView<'_>,
260269
) -> Self {
261270
let (trailers_http_tx, trailers_http_rx) = oneshot::channel();
@@ -266,20 +275,38 @@ impl GuestBody {
266275
getter,
267276
},
268277
);
269-
let contents_rx = contents_rx.map(|rx| {
278+
279+
let contents_rx = if let Some(rx) = contents_rx {
270280
let (http_tx, http_rx) = mpsc::channel(1);
271-
rx.pipe(
272-
store,
273-
GuestBodyConsumer {
274-
contents_tx: PollSender::new(http_tx),
275-
result_tx: Some(result_tx),
276-
content_length: content_length.map(ContentLength::new),
277-
kind,
278-
closed: false,
279-
},
280-
);
281-
http_rx
282-
});
281+
let contents_tx = PollSender::new(http_tx);
282+
if let Some(limit) = content_length {
283+
let (error_tx, error_rx) = oneshot::channel();
284+
_ = result_tx.send(Box::new(async move {
285+
if let Ok(err) = error_rx.await {
286+
return Err(err);
287+
};
288+
result_fut.await
289+
}));
290+
rx.pipe(
291+
store,
292+
LimitedGuestBodyConsumer {
293+
contents_tx,
294+
error_tx: Some(error_tx),
295+
make_error,
296+
limit,
297+
sent: 0,
298+
closed: false,
299+
},
300+
);
301+
} else {
302+
_ = result_tx.send(Box::new(result_fut));
303+
rx.pipe(store, UnlimitedGuestBodyConsumer(contents_tx));
304+
};
305+
Some(http_rx)
306+
} else {
307+
_ = result_tx.send(Box::new(result_fut));
308+
None
309+
};
283310
Self {
284311
trailers_rx: Some(trailers_http_rx),
285312
contents_rx,
@@ -303,7 +330,7 @@ impl http_body::Body for GuestBody {
303330
Ok(buf) => {
304331
if let Some(n) = self.content_length.as_mut() {
305332
// Substract frame length from `content_length`,
306-
// [GuestBodyConsumer] already performs the validation, so
333+
// [LimitedGuestBodyConsumer] already performs the validation, so
307334
// just keep count as optimization for
308335
// `is_end_stream` and `size_hint`
309336
*n = n.saturating_sub(buf.len().try_into().unwrap_or(u64::MAX));

0 commit comments

Comments
 (0)