Skip to content

Commit d7bd4d0

Browse files
committed
fix: prevent reuse of the stream after an error
When a stream timeouts, `tokio_io_timeout::TimeoutStream` returns an error once, but then allows to keep using the stream, e.g. calling `poll_read()` again. This can be dangerous if the error is ignored. For example in case of IMAP stream, if IMAP command is sent, but then reading the response times out and the error is ignored, it is possible to send another IMAP command. In this case leftover response from a previous command may be read and interpreted as the response to the new IMAP command. ErrorCapturingStream wraps the stream to prevent its reuse after an error.
1 parent e3973f6 commit d7bd4d0

4 files changed

Lines changed: 144 additions & 8 deletions

File tree

src/net.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@ use crate::sql::Sql;
1616
use crate::tools::time;
1717

1818
pub(crate) mod dns;
19+
pub(crate) mod error_capturing_stream;
1920
pub(crate) mod http;
2021
pub(crate) mod proxy;
2122
pub(crate) mod session;
2223
pub(crate) mod tls;
2324

2425
use dns::lookup_host_with_cache;
26+
pub(crate) use error_capturing_stream::ErrorCapturingStream;
2527
pub use http::{Response as HttpResponse, read_url, read_url_blob};
2628
use tls::wrap_tls;
2729

@@ -105,7 +107,7 @@ pub(crate) async fn load_connection_timestamp(
105107
/// to the network, which is important to reduce the latency of interactive protocols such as IMAP.
106108
pub(crate) async fn connect_tcp_inner(
107109
addr: SocketAddr,
108-
) -> Result<Pin<Box<TimeoutStream<TcpStream>>>> {
110+
) -> Result<Pin<Box<ErrorCapturingStream<TimeoutStream<TcpStream>>>>> {
109111
let tcp_stream = timeout(TIMEOUT, TcpStream::connect(addr))
110112
.await
111113
.context("connection timeout")?
@@ -118,7 +120,9 @@ pub(crate) async fn connect_tcp_inner(
118120
timeout_stream.set_write_timeout(Some(TIMEOUT));
119121
timeout_stream.set_read_timeout(Some(TIMEOUT));
120122

121-
Ok(Box::pin(timeout_stream))
123+
let error_capturing_stream = ErrorCapturingStream::new(timeout_stream);
124+
125+
Ok(Box::pin(error_capturing_stream))
122126
}
123127

124128
/// Attempts to establish TLS connection
@@ -235,7 +239,7 @@ pub(crate) async fn connect_tcp(
235239
host: &str,
236240
port: u16,
237241
load_cache: bool,
238-
) -> Result<Pin<Box<TimeoutStream<TcpStream>>>> {
242+
) -> Result<Pin<Box<ErrorCapturingStream<TimeoutStream<TcpStream>>>>> {
239243
let connection_futures = lookup_host_with_cache(context, host, port, "", load_cache)
240244
.await?
241245
.into_iter()

src/net/error_capturing_stream.rs

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
use std::io::IoSlice;
2+
use std::net::SocketAddr;
3+
use std::pin::Pin;
4+
use std::task::{Context, Poll};
5+
use std::time::Duration;
6+
use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf};
7+
8+
use pin_project::pin_project;
9+
10+
use crate::net::SessionStream;
11+
12+
/// Stream that remembers the first error
13+
/// and keeps returning it afterwards.
14+
///
15+
/// It is needed to avoid accidentally using
16+
/// the stream after read timeout.
17+
#[derive(Debug)]
18+
#[pin_project]
19+
pub(crate) struct ErrorCapturingStream<T: AsyncRead + AsyncWrite + std::fmt::Debug> {
20+
#[pin]
21+
inner: T,
22+
23+
/// If true, the stream has already returned an error once.
24+
///
25+
/// All read and write operations return error in this case.
26+
is_broken: bool,
27+
}
28+
29+
impl<T: AsyncRead + AsyncWrite + std::fmt::Debug> ErrorCapturingStream<T> {
30+
pub fn new(inner: T) -> Self {
31+
Self {
32+
inner,
33+
is_broken: false,
34+
}
35+
}
36+
37+
/// Gets a reference to the underlying stream.
38+
pub fn get_ref(&self) -> &T {
39+
&self.inner
40+
}
41+
42+
/// Gets a pinned mutable reference to the underlying stream.
43+
pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
44+
self.project().inner
45+
}
46+
}
47+
48+
impl<T: AsyncRead + AsyncWrite + std::fmt::Debug> AsyncRead for ErrorCapturingStream<T> {
49+
fn poll_read(
50+
self: Pin<&mut Self>,
51+
cx: &mut Context<'_>,
52+
buf: &mut ReadBuf,
53+
) -> Poll<io::Result<()>> {
54+
let this = self.project();
55+
if *this.is_broken {
56+
return Poll::Ready(Err(io::Error::other("Broken stream")));
57+
}
58+
match this.inner.poll_read(cx, buf) {
59+
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
60+
Poll::Ready(Err(err)) => {
61+
*this.is_broken = true;
62+
Poll::Ready(Err(err))
63+
}
64+
Poll::Pending => Poll::Pending,
65+
}
66+
}
67+
}
68+
69+
impl<T: AsyncRead + AsyncWrite + std::fmt::Debug> AsyncWrite for ErrorCapturingStream<T> {
70+
fn poll_write(
71+
self: Pin<&mut Self>,
72+
cx: &mut Context<'_>,
73+
buf: &[u8],
74+
) -> Poll<io::Result<usize>> {
75+
let this = self.project();
76+
if *this.is_broken {
77+
return Poll::Ready(Err(io::Error::other("Broken stream")));
78+
}
79+
match this.inner.poll_write(cx, buf) {
80+
Poll::Ready(Ok(size)) => Poll::Ready(Ok(size)),
81+
Poll::Ready(Err(err)) => {
82+
*this.is_broken = true;
83+
Poll::Ready(Err(err))
84+
}
85+
Poll::Pending => Poll::Pending,
86+
}
87+
}
88+
89+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
90+
let this = self.project();
91+
if *this.is_broken {
92+
return Poll::Ready(Err(io::Error::other("Broken stream")));
93+
}
94+
this.inner.poll_flush(cx)
95+
}
96+
97+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
98+
let this = self.project();
99+
if *this.is_broken {
100+
return Poll::Ready(Err(io::Error::other("Broken stream")));
101+
}
102+
this.inner.poll_shutdown(cx)
103+
}
104+
105+
fn poll_write_vectored(
106+
self: Pin<&mut Self>,
107+
cx: &mut Context<'_>,
108+
bufs: &[IoSlice<'_>],
109+
) -> Poll<io::Result<usize>> {
110+
let this = self.project();
111+
if *this.is_broken {
112+
return Poll::Ready(Err(io::Error::other("Broken stream")));
113+
}
114+
this.inner.poll_write_vectored(cx, bufs)
115+
}
116+
117+
fn is_write_vectored(&self) -> bool {
118+
self.inner.is_write_vectored()
119+
}
120+
}
121+
122+
impl<T: SessionStream> SessionStream for ErrorCapturingStream<T> {
123+
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
124+
self.inner.set_read_timeout(timeout)
125+
}
126+
127+
fn peer_addr(&self) -> anyhow::Result<SocketAddr> {
128+
self.inner.peer_addr()
129+
}
130+
}

src/net/proxy.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ use url::Url;
2121
use crate::config::Config;
2222
use crate::constants::NON_ALPHANUMERIC_WITHOUT_DOT;
2323
use crate::context::Context;
24-
use crate::net::connect_tcp;
2524
use crate::net::session::SessionStream;
2625
use crate::net::tls::wrap_rustls;
26+
use crate::net::{ErrorCapturingStream, connect_tcp};
2727
use crate::sql::Sql;
2828

2929
/// Default SOCKS5 port according to [RFC 1928](https://tools.ietf.org/html/rfc1928).
@@ -118,7 +118,7 @@ impl Socks5Config {
118118
target_host: &str,
119119
target_port: u16,
120120
load_dns_cache: bool,
121-
) -> Result<Socks5Stream<Pin<Box<TimeoutStream<TcpStream>>>>> {
121+
) -> Result<Socks5Stream<Pin<Box<ErrorCapturingStream<TimeoutStream<TcpStream>>>>>> {
122122
let tcp_stream = connect_tcp(context, &self.host, self.port, load_dns_cache)
123123
.await
124124
.context("Failed to connect to SOCKS5 proxy")?;

src/net/session.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufStream, BufWriter};
77
use tokio::net::TcpStream;
88
use tokio_io_timeout::TimeoutStream;
99

10+
use crate::net::ErrorCapturingStream;
11+
1012
pub(crate) trait SessionStream:
1113
AsyncRead + AsyncWrite + Unpin + Send + Sync + std::fmt::Debug
1214
{
@@ -61,13 +63,13 @@ impl<T: SessionStream> SessionStream for BufWriter<T> {
6163
self.get_ref().peer_addr()
6264
}
6365
}
64-
impl SessionStream for Pin<Box<TimeoutStream<TcpStream>>> {
66+
impl SessionStream for Pin<Box<ErrorCapturingStream<TimeoutStream<TcpStream>>>> {
6567
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
66-
self.as_mut().set_read_timeout_pinned(timeout);
68+
self.as_mut().get_pin_mut().set_read_timeout_pinned(timeout);
6769
}
6870

6971
fn peer_addr(&self) -> Result<SocketAddr> {
70-
Ok(self.get_ref().peer_addr()?)
72+
Ok(self.get_ref().get_ref().peer_addr()?)
7173
}
7274
}
7375
impl<T: SessionStream> SessionStream for Socks5Stream<T> {

0 commit comments

Comments
 (0)