Skip to content

Commit 0757da0

Browse files
committed
fix(http1): flush buffered data before shutdown
Ensures poll_shutdown() flushes buffered write data before shutting down the socket to prevent data loss with slower clients.
1 parent 5778745 commit 0757da0

File tree

4 files changed

+217
-1
lines changed

4 files changed

+217
-1
lines changed

Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,8 @@ required-features = ["full"]
243243
name = "server"
244244
path = "tests/server.rs"
245245
required-features = ["full"]
246+
247+
[[test]]
248+
name = "h1_shutdown_while_buffered"
249+
path = "tests/h1_shutdown_while_buffered.rs"
250+
required-features = ["full"]

src/proto/h1/conn.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,7 @@ where
832832
}
833833

834834
pub(crate) fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
835-
match ready!(Pin::new(self.io.io_mut()).poll_shutdown(cx)) {
835+
match ready!(self.io.poll_shutdown(cx)) {
836836
Ok(()) => {
837837
trace!("shut down IO complete");
838838
Poll::Ready(Ok(()))

src/proto/h1/io.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,11 @@ where
323323
Pin::new(&mut self.io).poll_flush(cx)
324324
}
325325

326+
pub(crate) fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
327+
ready!(self.poll_flush(cx))?;
328+
Pin::new(&mut self.io).poll_shutdown(cx)
329+
}
330+
326331
#[cfg(test)]
327332
fn flush(&mut self) -> impl std::future::Future<Output = io::Result<()>> + '_ {
328333
futures_util::future::poll_fn(move |cx| self.poll_flush(cx))
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
// Test: Ensures poll_shutdown() is never called with buffered data
2+
//
3+
// Reproduces rare timing bug where HTTP/1.1 server calls shutdown() on a socket while response
4+
// data is still buffered (not flushed), leading to data loss.
5+
//
6+
// Scenario:
7+
// 1. Request fully received and read.
8+
// 2. Server computes a "large" response with Full::new()
9+
// 3. Socket accepts only a chunk of response and then pends
10+
// 3. Flush returns Pending (remaining data still buffered), result ignored
11+
// 4. self.conn.wants_read_again() is false and poll_loop returns Ready
12+
// 5. BUG: poll_shutdown called prematurely and buffered body is lost
13+
// 6. FIX: poll_loop checks flush result and returns Pending, giving the chance for poll_loop to
14+
// run again
15+
16+
use std::{
17+
pin::Pin,
18+
sync::{Arc, Mutex},
19+
task::Poll,
20+
time::Duration,
21+
};
22+
23+
use bytes::Bytes;
24+
use http::{Request, Response};
25+
use http_body_util::Full;
26+
use hyper::{body::Incoming, service::service_fn};
27+
use support::TokioIo;
28+
use tokio::{
29+
io::{AsyncRead, AsyncWrite},
30+
net::{TcpListener, TcpStream},
31+
time::{sleep, timeout},
32+
};
33+
mod support;
34+
35+
#[derive(Debug, Default)]
36+
struct PendingStreamStatistics {
37+
bytes_written: usize,
38+
total_attempted: usize,
39+
shutdown_called_with_buffered: bool,
40+
buffered_at_shutdown: usize,
41+
}
42+
43+
// Simple struct that simply does one write and then pends perpetually
44+
struct PendingStream {
45+
inner: TcpStream,
46+
// Keep track of how many times we entered poll_write so as to be able to write only the first
47+
// time out
48+
write_count: usize,
49+
// Only write this chunk size out of full buffer
50+
write_chunk_size: usize,
51+
stats: Arc<Mutex<PendingStreamStatistics>>,
52+
}
53+
54+
impl PendingStream {
55+
fn new(
56+
inner: TcpStream,
57+
write_chunk_size: usize,
58+
stats: Arc<Mutex<PendingStreamStatistics>>,
59+
) -> Self {
60+
Self {
61+
inner,
62+
stats,
63+
write_chunk_size,
64+
write_count: 0,
65+
}
66+
}
67+
}
68+
69+
impl AsyncRead for PendingStream {
70+
fn poll_read(
71+
mut self: Pin<&mut Self>,
72+
cx: &mut std::task::Context<'_>,
73+
buf: &mut tokio::io::ReadBuf<'_>,
74+
) -> Poll<std::io::Result<()>> {
75+
Pin::new(&mut self.inner).poll_read(cx, buf)
76+
}
77+
}
78+
79+
impl AsyncWrite for PendingStream {
80+
fn poll_write(
81+
mut self: Pin<&mut Self>,
82+
cx: &mut std::task::Context<'_>,
83+
buf: &[u8],
84+
) -> Poll<std::io::Result<usize>> {
85+
self.write_count += 1;
86+
87+
let mut stats = self.stats.lock().unwrap();
88+
stats.total_attempted += buf.len();
89+
90+
if self.write_count == 1 {
91+
// First write: partial only
92+
let partial = std::cmp::min(buf.len(), self.write_chunk_size);
93+
drop(stats);
94+
95+
let result = Pin::new(&mut self.inner).poll_write(cx, &buf[..partial]);
96+
if let Poll::Ready(Ok(n)) = result {
97+
self.stats.lock().unwrap().bytes_written += n;
98+
}
99+
return result;
100+
}
101+
102+
// Block all further writes to simulate pending buffer
103+
Poll::Pending
104+
}
105+
106+
fn poll_shutdown(
107+
mut self: Pin<&mut Self>,
108+
cx: &mut std::task::Context<'_>,
109+
) -> Poll<std::io::Result<()>> {
110+
let mut stats = self.stats.lock().unwrap();
111+
let buffered = stats.total_attempted - stats.bytes_written;
112+
113+
if buffered > 0 {
114+
eprintln!(
115+
"\n❌BUG: shutdown() called with {} bytes buffered",
116+
buffered
117+
);
118+
stats.shutdown_called_with_buffered = true;
119+
stats.buffered_at_shutdown = buffered;
120+
}
121+
drop(stats);
122+
Pin::new(&mut self.inner).poll_shutdown(cx)
123+
}
124+
125+
fn poll_flush(
126+
mut self: Pin<&mut Self>,
127+
cx: &mut std::task::Context<'_>,
128+
) -> Poll<std::io::Result<()>> {
129+
let stats = self.stats.lock().unwrap();
130+
let buffered = stats.total_attempted - stats.bytes_written;
131+
132+
if buffered > 0 {
133+
return Poll::Pending;
134+
}
135+
136+
drop(stats);
137+
Pin::new(&mut self.inner).poll_flush(cx)
138+
}
139+
}
140+
141+
// Test doesn't necessarily check that the connections ended successfully but mainly that shutdown
142+
// wasn't called with data still remaining within hyper's internal buffer
143+
#[tokio::test]
144+
async fn test_no_premature_shutdown_while_buffered() {
145+
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
146+
let addr = listener.local_addr().unwrap();
147+
let stats = Arc::new(Mutex::new(PendingStreamStatistics::default()));
148+
149+
let stats_clone = stats.clone();
150+
let server = tokio::spawn(async move {
151+
let (stream, _) = listener.accept().await.unwrap();
152+
let pending_stream = PendingStream::new(stream, 212_992, stats_clone);
153+
let io = TokioIo::new(pending_stream);
154+
155+
let service = service_fn(|_req: Request<Incoming>| async move {
156+
// Larger Full response than write_chunk_size
157+
let body = Full::new(Bytes::from(vec![b'X'; 500_000]));
158+
Ok::<_, hyper::Error>(Response::new(body))
159+
});
160+
161+
hyper::server::conn::http1::Builder::new()
162+
.serve_connection(io, service)
163+
.await
164+
});
165+
166+
// Wait for server to be ready
167+
sleep(Duration::from_millis(50)).await;
168+
169+
// Client sends request
170+
tokio::spawn(async move {
171+
let mut stream = TcpStream::connect(addr).await.unwrap();
172+
173+
use tokio::io::AsyncWriteExt;
174+
175+
stream
176+
.write_all(
177+
b"POST / HTTP/1.1\r\n\
178+
Host: localhost\r\n\
179+
Transfer-Encoding: chunked\r\n\
180+
\r\n",
181+
)
182+
.await
183+
.unwrap();
184+
185+
stream.write_all(b"A\r\nHello World\r\n").await.unwrap();
186+
stream.write_all(b"0\r\n\r\n").await.unwrap();
187+
stream.flush().await.unwrap();
188+
189+
// keep connection open
190+
sleep(Duration::from_secs(2)).await;
191+
});
192+
193+
// Wait for completion
194+
let result = timeout(Duration::from_millis(900), server).await;
195+
196+
let stats = stats.lock().unwrap();
197+
198+
assert!(
199+
!stats.shutdown_called_with_buffered,
200+
"shutdown() called with {} bytes still buffered (wrote {} of {} bytes)",
201+
stats.buffered_at_shutdown, stats.bytes_written, stats.total_attempted
202+
);
203+
if let Ok(Ok(conn_result)) = result {
204+
conn_result.ok();
205+
}
206+
}

0 commit comments

Comments
 (0)