@@ -3,15 +3,16 @@ use core::mem;
33use core:: pin:: Pin ;
44use core:: task:: { Context , Poll , ready} ;
55
6- use bytes:: { Bytes , BytesMut } ;
6+ use bytes:: { Buf , Bytes , BytesMut } ;
77use http:: HeaderMap ;
8+ use http_body:: Frame ;
89use http_body_util:: BodyExt as _;
910use http_body_util:: combinators:: BoxBody ;
10- use tokio:: sync:: { mpsc, oneshot} ;
11+ use pin_project_lite:: pin_project;
12+ use tokio:: sync:: mpsc;
1113use wasmtime:: component:: { FutureWriter , Resource , StreamReader } ;
12- use wasmtime_wasi:: p3:: { AbortOnDropHandle , WithChildren } ;
14+ use wasmtime_wasi:: p3:: WithChildren ;
1315
14- use crate :: p3:: DEFAULT_BUFFER_CAPACITY ;
1516use crate :: p3:: bindings:: http:: types:: ErrorCode ;
1617
1718pub ( crate ) type OutgoingContentsStreamFuture =
@@ -104,33 +105,6 @@ impl Body {
104105 }
105106}
106107
107- pub ( crate ) struct OutgoingRequestTrailers {
108- pub trailers : Option < oneshot:: Receiver < Result < Option < HeaderMap > , ErrorCode > > > ,
109- #[ expect( dead_code, reason = "here for the dtor" ) ]
110- pub trailer_task : AbortOnDropHandle ,
111- }
112-
113- impl Future for OutgoingRequestTrailers {
114- type Output = Option < Result < HeaderMap , Option < ErrorCode > > > ;
115-
116- fn poll (
117- mut self : Pin < & mut Self > ,
118- cx : & mut Context < ' _ > ,
119- ) -> Poll < Option < Result < HeaderMap , Option < ErrorCode > > > > {
120- let Some ( trailers) = & mut self . trailers else {
121- return Poll :: Ready ( None ) ;
122- } ;
123- let trailers = ready ! ( Pin :: new( trailers) . poll( cx) ) ;
124- self . trailers = None ;
125- match trailers {
126- Ok ( Ok ( Some ( trailers) ) ) => Poll :: Ready ( Some ( Ok ( trailers) ) ) ,
127- Ok ( Ok ( None ) ) => Poll :: Ready ( None ) ,
128- Ok ( Err ( err) ) => Poll :: Ready ( Some ( Err ( Some ( err) ) ) ) ,
129- Err ( ..) => Poll :: Ready ( Some ( Err ( None ) ) ) , // future was dropped without writing a result
130- }
131- }
132- }
133-
134108/// Represents `Content-Length` limit and state
135109#[ derive( Copy , Clone , Debug , Eq , PartialEq , Hash ) ]
136110pub struct ContentLength {
@@ -240,101 +214,150 @@ impl http_body::Body for OutgoingResponseBody {
240214 }
241215}
242216
243- /// Request body constructed by the guest
244- pub ( crate ) struct OutgoingRequestBody {
245- pub contents : Option < OutgoingContentsStreamFuture > ,
246- pub buffer : Bytes ,
247- pub content_length : Option < ContentLength > ,
217+ /// Helper structure to validate that the body `B` provided matches the
218+ /// content length specified in its header.
219+ ///
220+ /// This will behave as if it were `B` except that an error will be
221+ /// generated if too much data is generated or if too little data is
222+ /// generated. This body will only succeed if the `body` contained produces
223+ /// exactly `remaining` bytes.
224+ pub ( crate ) struct BodyChannel < D , E > {
225+ rx : mpsc:: Receiver < Result < D , E > > ,
248226}
249227
250- impl OutgoingRequestBody {
251- pub fn new (
252- contents : OutgoingContentsStreamFuture ,
253- buffer : Bytes ,
228+ impl < D , E > BodyChannel < D , E > {
229+ pub ( crate ) fn new ( rx : mpsc:: Receiver < Result < D , E > > ) -> Self {
230+ BodyChannel { rx }
231+ }
232+ }
233+
234+ impl < D : Buf , E > http_body:: Body for BodyChannel < D , E > {
235+ type Data = D ;
236+ type Error = E ;
237+
238+ fn poll_frame (
239+ mut self : Pin < & mut Self > ,
240+ cx : & mut Context < ' _ > ,
241+ ) -> Poll < Option < Result < http_body:: Frame < Self :: Data > , Self :: Error > > > {
242+ match self . rx . poll_recv ( cx) {
243+ Poll :: Ready ( Some ( Ok ( frame) ) ) => Poll :: Ready ( Some ( Ok ( Frame :: data ( frame) ) ) ) ,
244+ Poll :: Ready ( Some ( Err ( err) ) ) => Poll :: Ready ( Some ( Err ( err) ) ) ,
245+ Poll :: Ready ( None ) => Poll :: Ready ( None ) ,
246+ Poll :: Pending => Poll :: Pending ,
247+ }
248+ }
249+ }
250+
251+ pin_project ! {
252+ /// Helper structure to validate that the body `B` provided matches the
253+ /// content length specified in its header.
254+ ///
255+ /// This will behave as if it were `B` except that an error will be
256+ /// generated if too much data is generated or if too little data is
257+ /// generated. This body will only succeed if the `body` contained produces
258+ /// exactly `remaining` bytes.
259+ pub ( crate ) struct BodyWithContentLength <B > {
260+ #[ pin]
261+ body: B ,
254262 content_length: Option <ContentLength >,
255- ) -> Self {
256- Self {
257- contents : Some ( contents) ,
258- buffer,
263+ body_length_mismatch: bool ,
264+ }
265+ }
266+
267+ impl < B > BodyWithContentLength < B > {
268+ pub ( crate ) fn new ( body : B , content_length : Option < ContentLength > ) -> BodyWithContentLength < B > {
269+ BodyWithContentLength {
270+ body,
259271 content_length,
272+ body_length_mismatch : false ,
260273 }
261274 }
262275}
263276
264- impl http_body:: Body for OutgoingRequestBody {
265- type Data = Bytes ;
266- type Error = Option < ErrorCode > ;
277+ pub ( crate ) trait ContentLengthError : Sized {
278+ fn body_too_long ( amt : Option < u64 > ) -> Self ;
279+ fn body_too_short ( amt : Option < u64 > ) -> Self ;
280+ }
281+
282+ impl < B > http_body:: Body for BodyWithContentLength < B >
283+ where
284+ B : http_body:: Body ,
285+ B :: Error : ContentLengthError ,
286+ {
287+ type Data = B :: Data ;
288+ type Error = B :: Error ;
267289
268290 fn poll_frame (
269- mut self : Pin < & mut Self > ,
291+ self : Pin < & mut Self > ,
270292 cx : & mut Context < ' _ > ,
271293 ) -> Poll < Option < Result < http_body:: Frame < Self :: Data > , Self :: Error > > > {
272- if !self . buffer . is_empty ( ) {
273- let buffer = mem:: take ( & mut self . buffer ) ;
274- if let Some ( ContentLength { limit, sent } ) = & mut self . content_length {
275- let Ok ( n) = buffer. len ( ) . try_into ( ) else {
276- return Poll :: Ready ( Some ( Err ( Some ( ErrorCode :: HttpRequestBodySize ( None ) ) ) ) ) ;
277- } ;
278- let Some ( n) = sent. checked_add ( n) else {
279- return Poll :: Ready ( Some ( Err ( Some ( ErrorCode :: HttpRequestBodySize ( None ) ) ) ) ) ;
280- } ;
281- if n > * limit {
282- return Poll :: Ready ( Some ( Err ( Some ( ErrorCode :: HttpRequestBodySize ( Some ( n) ) ) ) ) ) ;
283- }
284- * sent = n;
285- }
286- return Poll :: Ready ( Some ( Ok ( http_body:: Frame :: data ( buffer) ) ) ) ;
287- }
288- let Some ( stream) = & mut self . contents else {
294+ let mut this = self . project ( ) ;
295+ if * this. body_length_mismatch {
289296 return Poll :: Ready ( None ) ;
297+ }
298+ let frame = match Pin :: new ( & mut this. body ) . poll_frame ( cx) {
299+ Poll :: Ready ( frame) => frame,
300+ Poll :: Pending => return Poll :: Pending ,
290301 } ;
291- let ( tail, mut rx_buffer) = ready ! ( Pin :: new( stream) . poll( cx) ) ;
292- match tail {
293- Some ( tail) => {
294- let buffer = rx_buffer. split ( ) ;
295- rx_buffer. reserve ( DEFAULT_BUFFER_CAPACITY ) ;
296- self . contents = Some ( Box :: pin ( tail. read ( rx_buffer) ) ) ;
297- if let Some ( ContentLength { limit, sent } ) = & mut self . content_length {
298- let Ok ( n) = buffer. len ( ) . try_into ( ) else {
299- return Poll :: Ready ( Some ( Err ( Some ( ErrorCode :: HttpRequestBodySize ( None ) ) ) ) ) ;
300- } ;
301- let Some ( n) = sent. checked_add ( n) else {
302- return Poll :: Ready ( Some ( Err ( Some ( ErrorCode :: HttpRequestBodySize ( None ) ) ) ) ) ;
303- } ;
304- if n > * limit {
305- return Poll :: Ready ( Some ( Err ( Some ( ErrorCode :: HttpRequestBodySize ( Some (
306- n,
307- ) ) ) ) ) ) ;
302+ let content_length = match & mut this. content_length {
303+ Some ( content_length) => content_length,
304+ None => return Poll :: Ready ( frame) ,
305+ } ;
306+ let res = match frame {
307+ Some ( Ok ( frame) ) => {
308+ if let Some ( data) = frame. data_ref ( ) {
309+ let data_len = u64:: try_from ( data. remaining ( ) ) . unwrap ( ) ;
310+ content_length. sent = content_length. sent . saturating_add ( data_len) ;
311+ if content_length. sent > content_length. limit {
312+ * this. body_length_mismatch = true ;
313+ Some ( Err ( B :: Error :: body_too_long ( Some ( content_length. sent ) ) ) )
314+ } else {
315+ Some ( Ok ( frame) )
308316 }
309- * sent = n;
317+ } else {
318+ Some ( Ok ( frame) )
310319 }
311- Poll :: Ready ( Some ( Ok ( http_body:: Frame :: data ( buffer. freeze ( ) ) ) ) )
312320 }
321+ Some ( Err ( err) ) => Some ( Err ( err) ) ,
313322 None => {
314- debug_assert ! ( rx_buffer. is_empty( ) ) ;
315- self . contents = None ;
316- if let Some ( ContentLength { limit, sent } ) = self . content_length {
317- if limit != sent {
318- return Poll :: Ready ( Some ( Err ( Some ( ErrorCode :: HttpRequestBodySize ( Some (
319- sent,
320- ) ) ) ) ) ) ;
321- }
323+ if content_length. sent != content_length. limit {
324+ * this. body_length_mismatch = true ;
325+ Some ( Err ( B :: Error :: body_too_short ( Some ( content_length. sent ) ) ) )
326+ } else {
327+ None
322328 }
323- Poll :: Ready ( None )
324329 }
325- }
330+ } ;
331+
332+ Poll :: Ready ( res)
326333 }
327334
328335 fn is_end_stream ( & self ) -> bool {
329- self . contents . is_none ( )
336+ self . body . is_end_stream ( )
330337 }
331338
332339 fn size_hint ( & self ) -> http_body:: SizeHint {
333- if let Some ( ContentLength { limit, sent } ) = self . content_length {
334- http_body:: SizeHint :: with_exact ( limit. saturating_sub ( sent) )
335- } else {
336- http_body:: SizeHint :: default ( )
340+ let mut hint = self . body . size_hint ( ) ;
341+ if let Some ( content_length) = self . content_length {
342+ let remaining = content_length. limit . saturating_sub ( content_length. sent ) ;
343+ if hint. lower ( ) >= remaining {
344+ hint. set_exact ( remaining)
345+ } else if let Some ( max) = hint. upper ( ) {
346+ hint. set_upper ( remaining. min ( max) )
347+ } else {
348+ hint. set_upper ( remaining)
349+ }
337350 }
351+ hint
352+ }
353+ }
354+
355+ impl ContentLengthError for Option < ErrorCode > {
356+ fn body_too_long ( amt : Option < u64 > ) -> Self {
357+ Some ( ErrorCode :: HttpRequestBodySize ( amt) )
358+ }
359+ fn body_too_short ( amt : Option < u64 > ) -> Self {
360+ Some ( ErrorCode :: HttpRequestBodySize ( amt) )
338361 }
339362}
340363
0 commit comments