Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ exclude = ["/.github", "/examples", "/scripts", "/tests/"]

[dependencies]
rustls = { version = "0.23.27", default-features = false, features = ["std"] }
tokio = "1.0"
tokio = { version = "1.0", features = ["time"] }

[features]
default = ["logging", "tls12", "aws_lc_rs"]
Expand Down
134 changes: 106 additions & 28 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,36 @@ use std::pin::Pin;
#[cfg(feature = "early-data")]
use std::task::Waker;
use std::task::{Context, Poll};
use std::time::Duration;
use std::{
io::{self, BufRead as _},
sync::Arc,
};

use rustls::{pki_types::ServerName, ClientConfig, ClientConnection};
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use tokio::time::{self, Sleep};

use crate::common::{IoSession, MidHandshake, Stream, TlsState};

/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
#[derive(Clone)]
pub struct TlsConnector {
inner: Arc<ClientConfig>,
handshake_timeout: Option<Duration>,
#[cfg(feature = "early-data")]
early_data: bool,
}

impl TlsConnector {
/// Set the maximum amount of time to allow for TLS handshakes.
///
/// `None` disables the handshake timeout.
pub fn with_handshake_timeout(mut self, timeout: Option<Duration>) -> Self {
self.handshake_timeout = timeout;
self
}

/// Enable 0-RTT.
///
/// If you want to use 0-RTT,
Expand Down Expand Up @@ -68,36 +79,42 @@ impl TlsConnector {
let mut session = match ClientConnection::new_with_alpn(self.inner.clone(), domain, alpn) {
Ok(session) => session,
Err(error) => {
return Connect(MidHandshake::Error {
io: stream,
// TODO(eliza): should this really return an `io::Error`?
// Probably not...
error: io::Error::new(io::ErrorKind::Other, error),
});
return Connect {
inner: MidHandshake::Error {
io: stream,
// TODO(eliza): should this really return an `io::Error`?
// Probably not...
error: io::Error::new(io::ErrorKind::Other, error),
},
timeout: self.handshake_timeout.map(HandshakeTimeout::new),
};
}
};
f(&mut session);

Connect(MidHandshake::Handshaking(TlsStream {
io: stream,
Connect {
inner: MidHandshake::Handshaking(TlsStream {
io: stream,

#[cfg(not(feature = "early-data"))]
state: TlsState::Stream,
#[cfg(not(feature = "early-data"))]
state: TlsState::Stream,

#[cfg(feature = "early-data")]
state: if self.early_data && session.early_data().is_some() {
TlsState::EarlyData(0, Vec::new())
} else {
TlsState::Stream
},
#[cfg(feature = "early-data")]
state: if self.early_data && session.early_data().is_some() {
TlsState::EarlyData(0, Vec::new())
} else {
TlsState::Stream
},

need_flush: false,
need_flush: false,

#[cfg(feature = "early-data")]
early_waker: None,
#[cfg(feature = "early-data")]
early_waker: None,

session,
}))
session,
}),
timeout: self.handshake_timeout.map(HandshakeTimeout::new),
}
}

pub fn with_alpn(&self, alpn_protocols: Vec<Vec<u8>>) -> TlsConnectorWithAlpn<'_> {
Expand All @@ -117,6 +134,7 @@ impl From<Arc<ClientConfig>> for TlsConnector {
fn from(inner: Arc<ClientConfig>) -> Self {
Self {
inner,
handshake_timeout: None,
#[cfg(feature = "early-data")]
early_data: false,
}
Expand Down Expand Up @@ -151,16 +169,43 @@ impl TlsConnectorWithAlpn<'_> {

/// Future returned from `TlsConnector::connect` which will resolve
/// once the connection handshake has finished.
pub struct Connect<IO>(MidHandshake<TlsStream<IO>>);
pub struct Connect<IO> {
inner: MidHandshake<TlsStream<IO>>,
timeout: Option<HandshakeTimeout>,
}

struct HandshakeTimeout {
duration: Duration,
sleep: Option<Pin<Box<Sleep>>>,
}

impl HandshakeTimeout {
fn new(duration: Duration) -> Self {
Self {
duration,
sleep: None,
}
}

fn poll(&mut self, cx: &mut Context<'_>) -> Poll<()> {
let sleep = self
.sleep
.get_or_insert_with(|| Box::pin(time::sleep(self.duration)));
sleep.as_mut().poll(cx)
}
}

impl<IO> Connect<IO> {
#[inline]
pub fn into_fallible(self) -> FallibleConnect<IO> {
FallibleConnect(self.0)
FallibleConnect {
inner: self.inner,
timeout: self.timeout,
}
}

pub fn get_ref(&self) -> Option<&IO> {
match &self.0 {
match &self.inner {
MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
MidHandshake::SendAlert { io, .. } => Some(io),
MidHandshake::Error { io, .. } => Some(io),
Expand All @@ -169,7 +214,7 @@ impl<IO> Connect<IO> {
}

pub fn get_mut(&mut self) -> Option<&mut IO> {
match &mut self.0 {
match &mut self.inner {
MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
MidHandshake::SendAlert { io, .. } => Some(io),
MidHandshake::Error { io, .. } => Some(io),
Expand All @@ -183,7 +228,8 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {

#[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
let this = self.as_mut().get_mut();
poll_fallible_connect(&mut this.inner, &mut this.timeout, cx).map_err(|(err, _)| err)
}
}

Expand All @@ -192,12 +238,44 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleConnect<IO> {

#[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.0).poll(cx)
let this = self.as_mut().get_mut();
poll_fallible_connect(&mut this.inner, &mut this.timeout, cx)
}
}

/// Like [Connect], but returns `IO` on failure.
pub struct FallibleConnect<IO>(MidHandshake<TlsStream<IO>>);
pub struct FallibleConnect<IO> {
inner: MidHandshake<TlsStream<IO>>,
timeout: Option<HandshakeTimeout>,
}

fn handshake_timeout_error() -> io::Error {
io::Error::new(io::ErrorKind::TimedOut, "TLS handshake timed out")
}

fn poll_fallible_connect<IO>(
inner: &mut MidHandshake<TlsStream<IO>>,
timeout: &mut Option<HandshakeTimeout>,
cx: &mut Context<'_>,
) -> Poll<Result<TlsStream<IO>, (io::Error, IO)>>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
match Pin::new(&mut *inner).poll(cx) {
Poll::Ready(result) => Poll::Ready(result),
Poll::Pending => match timeout {
Some(timeout) => {
if timeout.poll(cx).is_pending() {
return Poll::Pending;
}

let io = inner.take_io().expect("handshake missing IO");
Poll::Ready(Err((handshake_timeout_error(), io)))
}
_ => Poll::Pending,
},
}
}

/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
Expand Down
10 changes: 10 additions & 0 deletions src/common/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ pub(crate) enum MidHandshake<IS: IoSession> {
},
}

impl<IS: IoSession> MidHandshake<IS> {
pub(crate) fn take_io(&mut self) -> Option<IS::Io> {
match mem::replace(self, Self::End) {
Self::Handshaking(stream) => Some(stream.into_io()),
Self::SendAlert { io, .. } | Self::Error { io, .. } => Some(io),
Self::End => None,
}
}
}

impl<IS, SD> Future for MidHandshake<IS>
where
IS: IoSession + Unpin,
Expand Down
14 changes: 14 additions & 0 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,20 @@ async fn fail() -> io::Result<()> {
Ok(())
}

#[tokio::test]
async fn handshake_timeout() {
let (_, config) = utils::make_configs();
let config = TlsConnector::from(Arc::new(config))
.with_handshake_timeout(Some(Duration::from_millis(10)));
let domain = ServerName::try_from(utils::TEST_SERVER_DOMAIN)
.unwrap()
.to_owned();
let (client, _server) = tokio::io::duplex(4096);

let err = config.connect(domain, client).await.unwrap_err();
assert_eq!(err.kind(), ErrorKind::TimedOut);
}

#[tokio::test]
async fn test_lazy_config_acceptor() -> io::Result<()> {
let (sconfig, cconfig) = utils::make_configs();
Expand Down
Loading