Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/client/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,9 @@ where
}
};
trace!("send_when canceled");
// Tell pipe_task to reset the h2 stream so that
// RST_STREAM is sent and flow-control capacity freed.
this.when.as_mut().cancel();
Poll::Ready(())
}
Poll::Ready(Err((error, message))) => {
Expand Down
37 changes: 37 additions & 0 deletions src/proto/h2/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ pin_project! {
conn_drop_ref: Option<Sender<Infallible>>,
#[pin]
ping: Option<Recorder>,
cancel_rx: Option<oneshot::Receiver<()>>,
}
}

Expand All @@ -474,6 +475,26 @@ where
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> std::task::Poll<Self::Output> {
let mut this = self.project();

// Check if the client cancelled the request (e.g. dropped the
// response future due to a timeout). If so, reset the h2 stream
// so that a RST_STREAM is sent and flow-control capacity is freed.
let cancel_result = this.cancel_rx.as_mut().map(|rx| Pin::new(rx).poll(cx));
match cancel_result {
Some(Poll::Ready(Ok(()))) => {
debug!("client request body send cancelled, resetting stream");
this.pipe.as_mut().send_reset(h2::Reason::CANCEL);
drop(this.conn_drop_ref.take().expect("Future polled twice"));
drop(this.ping.take().expect("Future polled twice"));
return Poll::Ready(());
}
Some(Poll::Ready(Err(_))) => {
// Sender dropped without cancelling (normal response or error).
// Stop polling the receiver.
*this.cancel_rx = None;
}
Some(Poll::Pending) | None => {}
}

match Pin::new(&mut this.pipe).poll(cx) {
Poll::Ready(result) => {
if let Err(_e) = result {
Expand All @@ -500,6 +521,10 @@ where
fn poll_pipe(&mut self, f: FutCtx<B>, cx: &mut Context<'_>) {
let ping = self.ping.clone();

// A one-shot channel so that send_task can tell pipe_task to
// reset the stream when the client cancels the request.
let (cancel_tx, cancel_rx) = oneshot::channel::<()>();

let send_stream = if !f.is_connect {
if !f.eos {
let mut pipe = PipeToSendStream::new(f.body, f.body_tx);
Expand All @@ -519,6 +544,7 @@ where
pipe,
conn_drop_ref: Some(conn_drop_ref),
ping: Some(ping),
cancel_rx: Some(cancel_rx),
};
// Clear send task
self.executor
Expand All @@ -539,6 +565,7 @@ where
ping: Some(ping),
send_stream: Some(send_stream),
exec: self.executor.clone(),
cancel_tx: Some(cancel_tx),
},
call_back: Some(f.cb),
},
Expand All @@ -558,6 +585,16 @@ pin_project! {
#[pin]
send_stream: Option<Option<SendStream<SendBuf<<B as Body>::Data>>>>,
exec: E,
cancel_tx: Option<oneshot::Sender<()>>,
}
}

impl<B: Body + 'static, E> ResponseFutMap<B, E> {
/// Signal the pipe_task to reset the stream (e.g. on client cancellation).
pub(crate) fn cancel(self: Pin<&mut Self>) {
if let Some(cancel_tx) = self.project().cancel_tx.take() {
let _ = cancel_tx.send(());
}
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/proto/h2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ where
stream,
}
}

fn send_reset(self: Pin<&mut Self>, reason: h2::Reason) {
Comment thread
mmishra100 marked this conversation as resolved.
self.project().body_tx.send_reset(reason);
}
}

impl<S> Future for PipeToSendStream<S>
Expand Down
41 changes: 41 additions & 0 deletions tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2868,6 +2868,47 @@ mod conn {
Pin::new(&mut self.tcp).poll_read(cx, buf)
}
}

// https://github.com/hyperium/hyper/issues/4040
#[tokio::test]
async fn h2_pipe_task_cancelled_on_response_future_drop() {
let (client_io, server_io, _) = setup_duplex_test_server();
let (rst_tx, rst_rx) = oneshot::channel::<bool>();

tokio::spawn(async move {
let mut builder = h2::server::Builder::new();
builder.initial_window_size(0);
let mut h2 = builder.handshake::<_, Bytes>(server_io).await.unwrap();
let (req, _respond) = h2.accept().await.unwrap().unwrap();
tokio::spawn(async move {
let _ = poll_fn(|cx| h2.poll_closed(cx)).await;
});

let mut body = req.into_body();
let got_rst = tokio::time::timeout(Duration::from_secs(2), body.data())
.await
.map_or(false, |frame| matches!(frame, Some(Err(_)) | None));
let _ = rst_tx.send(got_rst);
});

let io = TokioIo::new(client_io);
let (mut client, conn) = conn::http2::Builder::new(TokioExecutor)
.handshake(io)
.await
.expect("http handshake");
tokio::spawn(async move {
let _ = conn.await;
});

let req = Request::post("http://localhost/")
.body(Full::new(Bytes::from(vec![b'x'; 50])))
.unwrap();
let res = tokio::time::timeout(Duration::from_millis(5), client.send_request(req)).await;
assert!(res.is_err(), "should timeout waiting for response");

let got_rst = rst_rx.await.expect("server task should complete");
assert!(got_rst, "server should receive RST_STREAM");
}
}

trait FutureHyperExt: TryFuture {
Expand Down
Loading