@@ -122,75 +122,44 @@ impl Body {
122122 }
123123}
124124
125- /// The kind of body, used for error reporting
126- pub ( crate ) enum BodyKind {
127- Request ,
128- Response ,
129- }
130-
131- /// Represents `Content-Length` limit and state
132- #[ derive( Copy , Clone , Debug , Eq , PartialEq , Hash ) ]
133- struct ContentLength {
125+ /// [StreamConsumer] implementation for bodies originating in the guest with `Content-Length`
126+ /// header set.
127+ struct LimitedGuestBodyConsumer {
128+ contents_tx : PollSender < Result < Bytes , ErrorCode > > ,
129+ error_tx : Option < oneshot:: Sender < ErrorCode > > ,
130+ make_error : fn ( Option < u64 > ) -> ErrorCode ,
134131 /// Limit of bytes to be sent
135132 limit : u64 ,
136133 /// Number of bytes sent
137134 sent : u64 ,
138- }
139-
140- impl ContentLength {
141- /// Constructs new [ContentLength]
142- fn new ( limit : u64 ) -> Self {
143- Self { limit, sent : 0 }
144- }
145- }
146-
147- /// [StreamConsumer] implementation for bodies originating in the guest.
148- struct GuestBodyConsumer {
149- contents_tx : PollSender < Result < Bytes , ErrorCode > > ,
150- result_tx : Option < oneshot:: Sender < Result < ( ) , ErrorCode > > > ,
151- content_length : Option < ContentLength > ,
152- kind : BodyKind ,
153135 // `true` when the other side of `contents_tx` was unexpectedly closed
154136 closed : bool ,
155137}
156138
157- impl GuestBodyConsumer {
158- /// Constructs the approprite body size error given the [BodyKind]
159- fn body_size_error ( & self , n : Option < u64 > ) -> ErrorCode {
160- match self . kind {
161- BodyKind :: Request => ErrorCode :: HttpRequestBodySize ( n) ,
162- BodyKind :: Response => ErrorCode :: HttpResponseBodySize ( n) ,
163- }
164- }
165-
166- // Sends the corresponding error constructed by [Self::body_size_error] on both
167- // error channels.
168- // [`PollSender::poll_reserve`] on `contents_tx` must have succeeed prior to this being called.
169- fn send_body_size_error ( & mut self , n : Option < u64 > ) {
170- if let Some ( result_tx) = self . result_tx . take ( ) {
171- _ = result_tx. send ( Err ( self . body_size_error ( n) ) ) ;
172- _ = self . contents_tx . send_item ( Err ( self . body_size_error ( n) ) ) ;
139+ impl LimitedGuestBodyConsumer {
140+ /// Sends the error constructed by [Self::make_error] on both error channels.
141+ /// Does nothing if an error has already been sent on [Self::error_tx].
142+ fn send_error ( & mut self , sent : Option < u64 > ) {
143+ if let Some ( error_tx) = self . error_tx . take ( ) {
144+ _ = error_tx. send ( ( self . make_error ) ( sent) ) ;
145+ self . contents_tx . abort_send ( ) ;
146+ if let Some ( tx) = self . contents_tx . get_ref ( ) {
147+ _ = tx. try_send ( Err ( ( self . make_error ) ( sent) ) )
148+ }
149+ self . contents_tx . close ( ) ;
173150 }
174151 }
175152}
176153
177- impl Drop for GuestBodyConsumer {
154+ impl Drop for LimitedGuestBodyConsumer {
178155 fn drop ( & mut self ) {
179- if let Some ( result_tx) = self . result_tx . take ( ) {
180- if let Some ( ContentLength { limit, sent } ) = self . content_length {
181- if !self . closed && limit != sent {
182- _ = result_tx. send ( Err ( self . body_size_error ( Some ( sent) ) ) ) ;
183- self . contents_tx . abort_send ( ) ;
184- if let Some ( tx) = self . contents_tx . get_ref ( ) {
185- _ = tx. try_send ( Err ( self . body_size_error ( Some ( sent) ) ) )
186- }
187- }
188- }
156+ if !self . closed && self . limit != self . sent {
157+ self . send_error ( Some ( self . sent ) )
189158 }
190159 }
191160}
192161
193- impl < D > StreamConsumer < D > for GuestBodyConsumer {
162+ impl < D > StreamConsumer < D > for LimitedGuestBodyConsumer {
194163 type Item = u8 ;
195164
196165 fn poll_consume (
@@ -201,27 +170,31 @@ impl<D> StreamConsumer<D> for GuestBodyConsumer {
201170 finish : bool ,
202171 ) -> Poll < wasmtime:: Result < StreamResult > > {
203172 debug_assert ! ( !self . closed) ;
173+ let mut src = src. as_direct ( store) ;
174+ let buf = src. remaining ( ) ;
175+ let n = buf. len ( ) ;
176+
177+ // Perform `content-length` check early and precompute the next value
178+ let Ok ( sent) = n. try_into ( ) else {
179+ self . send_error ( None ) ;
180+ return Poll :: Ready ( Ok ( StreamResult :: Dropped ) ) ;
181+ } ;
182+ let Some ( sent) = self . sent . checked_add ( sent) else {
183+ self . send_error ( None ) ;
184+ return Poll :: Ready ( Ok ( StreamResult :: Dropped ) ) ;
185+ } ;
186+ if sent > self . limit {
187+ self . send_error ( Some ( sent) ) ;
188+ return Poll :: Ready ( Ok ( StreamResult :: Dropped ) ) ;
189+ }
204190 match self . contents_tx . poll_reserve ( cx) {
205191 Poll :: Ready ( Ok ( ( ) ) ) => {
206- let mut src = src. as_direct ( store) ;
207- let buf = src. remaining ( ) ;
208- if let Some ( ContentLength { limit, sent } ) = self . content_length . as_mut ( ) {
209- let Some ( n) = buf. len ( ) . try_into ( ) . ok ( ) . and_then ( |n| sent. checked_add ( n) )
210- else {
211- self . send_body_size_error ( None ) ;
212- return Poll :: Ready ( Ok ( StreamResult :: Dropped ) ) ;
213- } ;
214- if n > * limit {
215- self . send_body_size_error ( Some ( n) ) ;
216- return Poll :: Ready ( Ok ( StreamResult :: Dropped ) ) ;
217- }
218- * sent = n;
219- }
220192 let buf = Bytes :: copy_from_slice ( buf) ;
221- let n = buf. len ( ) ;
222193 match self . contents_tx . send_item ( Ok ( buf) ) {
223194 Ok ( ( ) ) => {
224195 src. mark_read ( n) ;
196+ // Record new `content-length` only on successful send
197+ self . sent = sent;
225198 Poll :: Ready ( Ok ( StreamResult :: Completed ) )
226199 }
227200 Err ( ..) => {
@@ -240,6 +213,41 @@ impl<D> StreamConsumer<D> for GuestBodyConsumer {
240213 }
241214}
242215
216+ /// [StreamConsumer] implementation for bodies originating in the guest without `Content-Length`
217+ /// header set.
218+ struct UnlimitedGuestBodyConsumer ( PollSender < Result < Bytes , ErrorCode > > ) ;
219+
220+ impl < D > StreamConsumer < D > for UnlimitedGuestBodyConsumer {
221+ type Item = u8 ;
222+
223+ fn poll_consume (
224+ mut self : Pin < & mut Self > ,
225+ cx : & mut Context < ' _ > ,
226+ store : StoreContextMut < D > ,
227+ src : Source < Self :: Item > ,
228+ finish : bool ,
229+ ) -> Poll < wasmtime:: Result < StreamResult > > {
230+ match self . 0 . poll_reserve ( cx) {
231+ Poll :: Ready ( Ok ( ( ) ) ) => {
232+ let mut src = src. as_direct ( store) ;
233+ let buf = src. remaining ( ) ;
234+ let n = buf. len ( ) ;
235+ let buf = Bytes :: copy_from_slice ( buf) ;
236+ match self . 0 . send_item ( Ok ( buf) ) {
237+ Ok ( ( ) ) => {
238+ src. mark_read ( n) ;
239+ Poll :: Ready ( Ok ( StreamResult :: Completed ) )
240+ }
241+ Err ( ..) => Poll :: Ready ( Ok ( StreamResult :: Dropped ) ) ,
242+ }
243+ }
244+ Poll :: Ready ( Err ( ..) ) => Poll :: Ready ( Ok ( StreamResult :: Dropped ) ) ,
245+ Poll :: Pending if finish => Poll :: Ready ( Ok ( StreamResult :: Cancelled ) ) ,
246+ Poll :: Pending => Poll :: Pending ,
247+ }
248+ }
249+ }
250+
243251/// [http_body::Body] implementation for bodies originating in the guest.
244252pub ( crate ) struct GuestBody {
245253 contents_rx : Option < mpsc:: Receiver < Result < Bytes , ErrorCode > > > ,
@@ -253,9 +261,10 @@ impl GuestBody {
253261 mut store : impl AsContextMut < Data = T > ,
254262 contents_rx : Option < StreamReader < u8 > > ,
255263 trailers_rx : FutureReader < Result < Option < Resource < Trailers > > , ErrorCode > > ,
256- result_tx : oneshot:: Sender < Result < ( ) , ErrorCode > > ,
264+ result_tx : oneshot:: Sender < Box < dyn Future < Output = Result < ( ) , ErrorCode > > + Send > > ,
265+ result_fut : impl Future < Output = Result < ( ) , ErrorCode > > + Send + ' static ,
257266 content_length : Option < u64 > ,
258- kind : BodyKind ,
267+ make_error : fn ( Option < u64 > ) -> ErrorCode ,
259268 getter : fn ( & mut T ) -> WasiHttpCtxView < ' _ > ,
260269 ) -> Self {
261270 let ( trailers_http_tx, trailers_http_rx) = oneshot:: channel ( ) ;
@@ -266,20 +275,38 @@ impl GuestBody {
266275 getter,
267276 } ,
268277 ) ;
269- let contents_rx = contents_rx. map ( |rx| {
278+
279+ let contents_rx = if let Some ( rx) = contents_rx {
270280 let ( http_tx, http_rx) = mpsc:: channel ( 1 ) ;
271- rx. pipe (
272- store,
273- GuestBodyConsumer {
274- contents_tx : PollSender :: new ( http_tx) ,
275- result_tx : Some ( result_tx) ,
276- content_length : content_length. map ( ContentLength :: new) ,
277- kind,
278- closed : false ,
279- } ,
280- ) ;
281- http_rx
282- } ) ;
281+ let contents_tx = PollSender :: new ( http_tx) ;
282+ if let Some ( limit) = content_length {
283+ let ( error_tx, error_rx) = oneshot:: channel ( ) ;
284+ _ = result_tx. send ( Box :: new ( async move {
285+ if let Ok ( err) = error_rx. await {
286+ return Err ( err) ;
287+ } ;
288+ result_fut. await
289+ } ) ) ;
290+ rx. pipe (
291+ store,
292+ LimitedGuestBodyConsumer {
293+ contents_tx,
294+ error_tx : Some ( error_tx) ,
295+ make_error,
296+ limit,
297+ sent : 0 ,
298+ closed : false ,
299+ } ,
300+ ) ;
301+ } else {
302+ _ = result_tx. send ( Box :: new ( result_fut) ) ;
303+ rx. pipe ( store, UnlimitedGuestBodyConsumer ( contents_tx) ) ;
304+ } ;
305+ Some ( http_rx)
306+ } else {
307+ _ = result_tx. send ( Box :: new ( result_fut) ) ;
308+ None
309+ } ;
283310 Self {
284311 trailers_rx : Some ( trailers_http_rx) ,
285312 contents_rx,
@@ -303,7 +330,7 @@ impl http_body::Body for GuestBody {
303330 Ok ( buf) => {
304331 if let Some ( n) = self . content_length . as_mut ( ) {
305332 // Substract frame length from `content_length`,
306- // [GuestBodyConsumer ] already performs the validation, so
333+ // [LimitedGuestBodyConsumer ] already performs the validation, so
307334 // just keep count as optimization for
308335 // `is_end_stream` and `size_hint`
309336 * n = n. saturating_sub ( buf. len ( ) . try_into ( ) . unwrap_or ( u64:: MAX ) ) ;
0 commit comments