Skip to content

Commit 4840954

Browse files
committed
native-tls https fix
1 parent 774a655 commit 4840954

4 files changed

Lines changed: 114 additions & 35 deletions

File tree

bitreq/src/connection.rs

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@ use crate::{Error, Method, ResponseLazy};
2929

3030
type UnsecuredStream = TcpStream;
3131

32-
#[cfg(feature = "rustls")]
32+
#[cfg(any(feature = "rustls", feature = "native-tls"))]
3333
mod rustls_stream;
34-
#[cfg(feature = "rustls")]
34+
#[cfg(any(feature = "rustls", feature = "native-tls"))]
3535
type SecuredStream = rustls_stream::SecuredStream;
3636

3737
pub(crate) enum HttpStream {
3838
Unsecured(UnsecuredStream, Option<Instant>),
39-
#[cfg(feature = "rustls")]
39+
#[cfg(any(feature = "rustls", feature = "native-tls"))]
4040
Secured(Box<SecuredStream>, Option<Instant>),
4141
#[cfg(feature = "async")]
4242
Buffer(std::io::Cursor<Vec<u8>>),
@@ -81,7 +81,7 @@ impl Read for HttpStream {
8181
timeout(inner, *timeout_at)?;
8282
inner.read(buf)
8383
}
84-
#[cfg(feature = "rustls")]
84+
#[cfg(any(feature = "rustls", feature = "native-tls"))]
8585
HttpStream::Secured(inner, timeout_at) => {
8686
timeout(inner.get_ref(), *timeout_at)?;
8787
inner.read(buf)
@@ -111,7 +111,7 @@ impl Write for HttpStream {
111111
set_socket_write_timeout(inner, *timeout_at)?;
112112
inner.write(buf)
113113
}
114-
#[cfg(feature = "rustls")]
114+
#[cfg(any(feature = "rustls", feature = "native-tls"))]
115115
HttpStream::Secured(inner, timeout_at) => {
116116
set_socket_write_timeout(inner.get_ref(), *timeout_at)?;
117117
inner.write(buf)
@@ -137,7 +137,7 @@ impl Write for HttpStream {
137137
set_socket_write_timeout(inner, *timeout_at)?;
138138
inner.flush()
139139
}
140-
#[cfg(feature = "rustls")]
140+
#[cfg(any(feature = "rustls", feature = "native-tls"))]
141141
HttpStream::Secured(inner, timeout_at) => {
142142
set_socket_write_timeout(inner.get_ref(), *timeout_at)?;
143143
inner.flush()
@@ -158,13 +158,13 @@ impl Write for HttpStream {
158158
}
159159
}
160160

161-
#[cfg(feature = "tokio-rustls")]
161+
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
162162
type AsyncSecuredStream = rustls_stream::AsyncSecuredStream;
163163

164164
#[cfg(feature = "async")]
165165
pub(crate) enum AsyncHttpStream {
166166
Unsecured(AsyncTcpStream),
167-
#[cfg(feature = "tokio-rustls")]
167+
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
168168
Secured(Box<AsyncSecuredStream>),
169169
}
170170

@@ -177,7 +177,7 @@ impl AsyncRead for AsyncHttpStream {
177177
) -> Poll<io::Result<()>> {
178178
match &mut *self {
179179
AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_read(cx, buf),
180-
#[cfg(feature = "tokio-rustls")]
180+
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
181181
AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_read(cx, buf),
182182
}
183183
}
@@ -192,23 +192,23 @@ impl AsyncWrite for AsyncHttpStream {
192192
) -> Poll<io::Result<usize>> {
193193
match &mut *self {
194194
AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_write(cx, buf),
195-
#[cfg(feature = "tokio-rustls")]
195+
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
196196
AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_write(cx, buf),
197197
}
198198
}
199199

200200
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
201201
match &mut *self {
202202
AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_flush(cx),
203-
#[cfg(feature = "tokio-rustls")]
203+
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
204204
AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_flush(cx),
205205
}
206206
}
207207

208208
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
209209
match &mut *self {
210210
AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_shutdown(cx),
211-
#[cfg(feature = "tokio-rustls")]
211+
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
212212
AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_shutdown(cx),
213213
}
214214
}
@@ -271,10 +271,8 @@ impl AsyncConnection {
271271
let socket = Self::connect(params).await?;
272272

273273
if params.https {
274-
#[cfg(not(feature = "tokio-rustls"))]
275-
return Err(Error::HttpsFeatureNotEnabled);
276-
#[cfg(feature = "tokio-rustls")]
277-
rustls_stream::wrap_async_stream(socket, params.host).await
274+
// temp call
275+
Self::wrap_async_stream(socket, params.host).await
278276
} else {
279277
Ok(AsyncHttpStream::Unsecured(socket))
280278
}
@@ -298,6 +296,23 @@ impl AsyncConnection {
298296
}))))
299297
}
300298

299+
/// Temp Method implementation
300+
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
301+
async fn wrap_async_stream(
302+
socket: AsyncTcpStream,
303+
host: &str,
304+
) -> Result<AsyncHttpStream, Error> {
305+
rustls_stream::wrap_async_stream(socket, host).await
306+
}
307+
308+
/// Temp Method implementation
309+
#[cfg(not(any(feature = "tokio-rustls", feature = "tokio-native-tls")))]
310+
async fn wrap_async_stream(
311+
_socket: AsyncTcpStream,
312+
_host: &str,
313+
) -> Result<AsyncHttpStream, Error> {
314+
Err(Error::HttpsFeatureNotEnabled)
315+
}
301316
async fn tcp_connect(host: &str, port: u16) -> Result<AsyncTcpStream, Error> {
302317
#[cfg(feature = "log")]
303318
log::trace!("Looking up host {host}");
@@ -653,13 +668,10 @@ impl Connection {
653668
let socket = Self::connect(params, timeout_at)?;
654669

655670
let stream = if params.https {
656-
#[cfg(not(feature = "rustls"))]
671+
#[cfg(not(any(feature = "rustls", feature = "native-tls")))]
657672
return Err(Error::HttpsFeatureNotEnabled);
658-
#[cfg(feature = "rustls")]
659-
{
660-
let tls = rustls_stream::wrap_stream(socket, params.host)?;
661-
HttpStream::Secured(Box::new(tls), timeout_at)
662-
}
673+
#[cfg(any(feature = "rustls", feature = "native-tls"))]
674+
rustls_stream::wrap_stream(socket, params.host)?
663675
} else {
664676
HttpStream::create_unsecured(socket, timeout_at)
665677
};

bitreq/src/connection/rustls_stream.rs

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
use alloc::sync::Arc;
66
#[cfg(feature = "rustls")]
77
use core::convert::TryFrom;
8+
#[cfg(any(feature = "rustls", feature = "native-tls"))]
89
use std::io;
910
use std::net::TcpStream;
1011
use std::sync::OnceLock;
@@ -20,9 +21,12 @@ use tokio_rustls::{client::TlsStream, TlsConnector};
2021
#[cfg(feature = "rustls-webpki")]
2122
use webpki_roots::TLS_SERVER_ROOTS;
2223

23-
#[cfg(feature = "tokio-rustls")]
24-
use super::{AsyncHttpStream, AsyncTcpStream};
25-
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
24+
#[cfg(any(feature = "rustls", feature = "native-tls"))]
25+
use super::HttpStream;
26+
#[cfg(any(
27+
all(feature = "native-tls", feature = "tokio-native-tls"),
28+
all(feature = "rustls", feature = "tokio-rustls")
29+
))]
2630
use super::{AsyncHttpStream, AsyncTcpStream};
2731
use crate::Error;
2832

@@ -64,7 +68,7 @@ fn build_client_config() -> Arc<ClientConfig> {
6468
}
6569

6670
#[cfg(feature = "rustls")]
67-
pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result<SecuredStream, Error> {
71+
pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result<HttpStream, Error> {
6872
#[cfg(feature = "log")]
6973
log::trace!("Setting up TLS parameters for {host}.");
7074
let dns_name = match ServerName::try_from(host) {
@@ -73,10 +77,12 @@ pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result<SecuredStream, E
7377
};
7478
let sess = ClientConnection::new(CONFIG.get_or_init(build_client_config).clone(), dns_name)
7579
.map_err(Error::RustlsCreateConnection)?;
80+
let tls = StreamOwned::new(sess, tcp);
7681

7782
#[cfg(feature = "log")]
7883
log::trace!("Establishing TLS session to {host}.");
79-
Ok(StreamOwned::new(sess, tcp))
84+
85+
Ok(HttpStream::Secured(Box::new(tls), None))
8086
}
8187

8288
// Async rustls TLS implementation
@@ -115,7 +121,7 @@ static CONNECTOR: OnceLock<Result<TlsConnector, Error>> = OnceLock::new();
115121
#[cfg(all(feature = "native-tls", not(feature = "rustls")))]
116122
fn native_tls_err<S>(e: HandshakeError<S>) -> Error {
117123
match e {
118-
HandshakeError::Failure(e) => Error::NativeTlsError(e),
124+
HandshakeError::Failure(err) => Error::NativeTlsCreateConnection(err),
119125
HandshakeError::WouldBlock(_) => {
120126
debug_assert!(false, "We shouldn't hit a blocking error");
121127
Error::Other("Got a WouldBlock error from native-tls")
@@ -125,22 +131,27 @@ fn native_tls_err<S>(e: HandshakeError<S>) -> Error {
125131

126132
#[cfg(all(feature = "native-tls", not(feature = "rustls")))]
127133
fn build_tls_connector() -> Result<TlsConnector, Error> {
128-
TlsConnector::builder().build().map_err(Error::NativeTlsError)
134+
TlsConnector::builder().build().map_err(Error::from)
129135
}
130136

131137
#[cfg(all(feature = "native-tls", not(feature = "rustls")))]
132-
pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result<SecuredStream, Error> {
138+
pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result<HttpStream, Error> {
133139
#[cfg(feature = "log")]
134140
log::trace!("Setting up TLS parameters for {host}.");
135141

136142
// TODO: Once we can `get_or_try_init`, so that instead
137143
// https://github.com/rust-lang/rust/issues/109737
138-
let connector = CONNECTOR.get_or_init(build_tls_connector)?;
144+
let connector = match CONNECTOR.get_or_init(build_tls_connector) {
145+
Ok(c) => c.clone(),
146+
Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))),
147+
};
139148

140149
#[cfg(feature = "log")]
141150
log::trace!("Establishing TLS session to {host}.");
142151

143-
connector.connect(host, tcp).map_err(native_tls_err)
152+
let tls = connector.connect(host, tcp).map_err(native_tls_err)?;
153+
154+
Ok(HttpStream::Secured(Box::new(tls), None))
144155
}
145156

146157
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
@@ -156,12 +167,36 @@ pub(super) async fn wrap_async_stream(
156167

157168
// TODO: Once we can `get_or_try_init`, so that instead
158169
// https://github.com/rust-lang/rust/issues/109737
159-
let connector = AsyncTlsConnector::from(CONNECTOR.get_or_init(build_tls_connector)?.clone());
170+
let sync_connector = match CONNECTOR.get_or_init(build_tls_connector) {
171+
Ok(c) => c.clone(),
172+
Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))),
173+
};
174+
175+
let async_connector = AsyncTlsConnector::from(sync_connector);
160176

161177
#[cfg(feature = "log")]
162178
log::trace!("Establishing TLS session to {host}.");
163179

164-
let tls = connector.connect(host, tcp).await.map_err(native_tls_err)?;
180+
let tls = async_connector.connect(host, tcp).await?;
165181

166182
Ok(AsyncHttpStream::Secured(Box::new(tls)))
167183
}
184+
185+
// #[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
186+
// pub(super) async fn wrap_async_stream_with_configs(
187+
// tcp: AsyncTcpStream,
188+
// host: &str,
189+
// client_configs: Certificates,
190+
// ) -> Result<AsyncHttpStream, Error> {
191+
// #[cfg(feature = "log")]
192+
// log::trace!("Setting up TLS parameters for {host}.");
193+
194+
// let async_connector = client_configs.0;
195+
196+
// #[cfg(feature = "log")]
197+
// log::trace!("Establishing TLS session to {host}.");
198+
199+
// let tls = async_connector.connect(host, tcp).await?;
200+
201+
// Ok(AsyncHttpStream::Secured(Box::new(tls)))
202+
// }

bitreq/src/error.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ impl fmt::Display for Error {
105105
#[cfg(feature = "rustls")]
106106
RustlsCreateConnection(err) => write!(f, "error creating rustls connection: {}", err),
107107
#[cfg(feature = "native-tls")]
108-
NativeTlsCreateConnection(err) => write!(f, "error creating native-tls connection: {err}"),
108+
NativeTlsCreateConnection(err) => write!(f, "error creating native-tls connection: {}", err),
109109
MalformedChunkLength => write!(f, "non-usize chunk length with transfer-encoding: chunked"),
110110
MalformedChunkEnd => write!(f, "chunk did not end after reading the expected amount of bytes"),
111111
MalformedContentLength => write!(f, "non-usize content length"),
@@ -160,3 +160,8 @@ impl From<io::Error> for Error {
160160
impl From<UrlParseError> for Error {
161161
fn from(other: UrlParseError) -> Error { Error::InvalidUrl(other) }
162162
}
163+
164+
#[cfg(feature = "native-tls")]
165+
impl From<native_tls::Error> for Error {
166+
fn from(err: native_tls::Error) -> Error { Error::NativeTlsCreateConnection(err) }
167+
}

bitreq/tests/main.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,33 @@ async fn test_https() {
1616
assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200);
1717
}
1818

19+
#[tokio::test]
20+
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
21+
async fn test_https() {
22+
// TODO: Implement this locally.
23+
assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200);
24+
// Test reusing the existing connection in client:
25+
assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200);
26+
}
27+
28+
#[tokio::test]
29+
#[cfg(all(feature = "rustls", feature = "tokio-rustls"))]
30+
async fn test_https_with_client() {
31+
setup();
32+
let client = bitreq::Client::new(1);
33+
let response = client.send_async(bitreq::get("https://example.com")).await.unwrap();
34+
assert_eq!(response.status_code, 200);
35+
}
36+
37+
#[tokio::test]
38+
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
39+
async fn test_https_with_client() {
40+
setup();
41+
let client = bitreq::Client::new(1);
42+
let response = client.send_async(bitreq::get("https://example.com")).await.unwrap();
43+
assert_eq!(response.status_code, 200);
44+
}
45+
1946
#[tokio::test]
2047
#[cfg(feature = "json-using-serde")]
2148
async fn test_json_using_serde() {

0 commit comments

Comments
 (0)