Skip to content

Commit b20e4b4

Browse files
committed
implement client builder for native-tls
1 parent 3f9c11c commit b20e4b4

File tree

6 files changed

+130
-38
lines changed

6 files changed

+130
-38
lines changed

bitreq/src/client.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,21 @@
99
use std::collections::{hash_map, HashMap, VecDeque};
1010
use std::sync::{Arc, Mutex};
1111

12-
#[cfg(feature = "tokio-rustls")]
12+
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
1313
use crate::connection::tls_config::{TlsConfig, TlsConfigBuilder};
1414
use crate::connection::AsyncConnection;
1515
use crate::request::{OwnedConnectionParams as ConnectionKey, ParsedRequest};
1616
use crate::{Error, Request, Response};
1717

1818
#[derive(Clone)]
1919
pub(crate) struct ClientConfig {
20-
#[cfg(feature = "tokio-rustls")]
20+
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
2121
pub(crate) tls: Option<TlsConfig>,
2222
}
2323

2424
pub struct ClientBuilder {
2525
capacity: usize,
26-
#[cfg(feature = "tokio-rustls")]
26+
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
2727
tls_config: Option<TlsConfigBuilder>,
2828
}
2929

@@ -48,7 +48,7 @@ impl ClientBuilder {
4848
pub fn new() -> Self {
4949
Self {
5050
capacity: 10,
51-
#[cfg(feature = "tokio-rustls")]
51+
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
5252
tls_config: None,
5353
}
5454
}
@@ -59,8 +59,8 @@ impl ClientBuilder {
5959
self
6060
}
6161

62-
#[cfg(feature = "tokio-rustls")]
6362
/// Builds the `Client` with the configured settings.
63+
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
6464
pub fn build(self) -> Result<Client, Error> {
6565
let build_config = if let Some(builder) = self.tls_config {
6666
let tls_config = builder.build()?;
@@ -81,7 +81,7 @@ impl ClientBuilder {
8181
}
8282

8383
/// Builds the `Client` with the configured settings.
84-
#[cfg(not(feature = "tokio-rustls"))]
84+
#[cfg(not(any(feature = "tokio-rustls", feature = "tokio-native-tls")))]
8585
pub fn build(self) -> Result<Client, Error> {
8686
Ok(Client {
8787
r#async: Arc::new(Mutex::new(ClientImpl {
@@ -111,7 +111,7 @@ impl ClientBuilder {
111111
/// # Ok(())
112112
/// # }
113113
/// ```
114-
#[cfg(feature = "tokio-rustls")]
114+
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
115115
pub fn with_root_certificate<T: Into<Vec<u8>>>(mut self, cert_der: T) -> Result<Self, Error> {
116116
let cert_der = cert_der.into();
117117
if let Some(ref mut tls_config) = self.tls_config {
@@ -126,7 +126,7 @@ impl ClientBuilder {
126126

127127
/// Disables default root certificates for TLS connections.
128128
/// Returns [`Error::InvalidTlsConfig`] if TLS has not been configured.
129-
#[cfg(feature = "tokio-rustls")]
129+
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
130130
pub fn disable_default_certificates(mut self) -> Result<Self, Error> {
131131
match self.tls_config {
132132
Some(ref mut tls_config) => tls_config.disable_default_certificates()?,

bitreq/src/connection.rs

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ type UnsecuredStream = TcpStream;
3333

3434
#[cfg(any(feature = "rustls", feature = "native-tls"))]
3535
mod rustls_stream;
36-
#[cfg(feature = "tokio-rustls")]
36+
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
3737
pub(crate) mod tls_config;
3838
#[cfg(any(feature = "rustls", feature = "native-tls"))]
3939
type SecuredStream = rustls_stream::SecuredStream;
@@ -304,20 +304,8 @@ impl AsyncConnection {
304304
}))))
305305
}
306306

307-
// =======
308-
/// Temp method. Required to compile
309-
#[cfg(all(feature = "tokio-native-tls", not(feature = "tokio-rustls")))]
310-
async fn wrap_async_stream(
311-
socket: AsyncTcpStream,
312-
host: &str,
313-
_client_config: &Option<Arc<ClientConfig>>,
314-
) -> Result<AsyncHttpStream, Error> {
315-
rustls_stream::wrap_async_stream(socket, host).await
316-
}
317-
// =======
318-
319307
/// Call the correct wrapper function depending on whether client_configs are present
320-
#[cfg(feature = "tokio-rustls")]
308+
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
321309
async fn wrap_async_stream(
322310
socket: AsyncTcpStream,
323311
host: &str,

bitreq/src/connection/rustls_stream.rs

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use std::sync::OnceLock;
1414
use native_tls::{HandshakeError, TlsConnector, TlsStream};
1515
#[cfg(feature = "rustls")]
1616
use rustls::{self, ClientConfig, ClientConnection, RootCertStore, ServerName, StreamOwned};
17-
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
17+
#[cfg(all(feature = "tokio-native-tls", not(feature = "rustls")))]
1818
use tokio_native_tls::TlsConnector as AsyncTlsConnector;
1919
#[cfg(feature = "tokio-rustls")]
2020
use tokio_rustls::{client::TlsStream, TlsConnector};
@@ -28,7 +28,7 @@ use super::HttpStream;
2828
all(feature = "rustls", feature = "tokio-rustls")
2929
))]
3030
use super::{AsyncHttpStream, AsyncTcpStream};
31-
#[cfg(all(feature = "rustls", feature = "tokio-rustls"))]
31+
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
3232
use crate::connection::tls_config::TlsConfig;
3333
use crate::Error;
3434

@@ -189,22 +189,10 @@ pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result<HttpStream, Erro
189189
Ok(HttpStream::Secured(Box::new(tls), None))
190190
}
191191

192-
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
192+
#[cfg(all(feature = "tokio-native-tls", not(feature = "rustls")))]
193193
pub type AsyncSecuredStream = tokio_native_tls::TlsStream<tokio::net::TcpStream>;
194194

195-
// =======
196-
// Temp method, required for compilation
197195
#[cfg(all(feature = "tokio-native-tls", not(feature = "rustls")))]
198-
pub(super) async fn wrap_async_stream_with_configs(
199-
tcp: AsyncTcpStream,
200-
host: &str,
201-
_client_configs: Option<()>,
202-
) -> Result<AsyncHttpStream, Error> {
203-
wrap_async_stream(tcp, host).await
204-
}
205-
// =======
206-
207-
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
208196
pub(super) async fn wrap_async_stream(
209197
tcp: AsyncTcpStream,
210198
host: &str,
@@ -228,3 +216,22 @@ pub(super) async fn wrap_async_stream(
228216

229217
Ok(AsyncHttpStream::Secured(Box::new(tls)))
230218
}
219+
220+
#[cfg(all(feature = "tokio-native-tls", not(feature = "rustls")))]
221+
pub(super) async fn wrap_async_stream_with_configs(
222+
tcp: AsyncTcpStream,
223+
host: &str,
224+
tls_config: TlsConfig,
225+
) -> Result<AsyncHttpStream, Error> {
226+
#[cfg(feature = "log")]
227+
log::trace!("Setting up TLS parameters for {host}.");
228+
229+
let async_connector = tls_config.connector;
230+
231+
#[cfg(feature = "log")]
232+
log::trace!("Establishing TLS session to {host}.");
233+
234+
let tls = async_connector.connect(host, tcp).await?;
235+
236+
Ok(AsyncHttpStream::Secured(Box::new(tls)))
237+
}

bitreq/src/connection/tls_config.rs

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
use std::sync::Arc;
22

3+
#[cfg(not(feature = "rustls"))]
4+
use native_tls::{Certificate, TlsConnector, TlsConnectorBuilder};
35
#[cfg(feature = "rustls")]
46
use rustls::RootCertStore;
7+
#[cfg(not(feature = "rustls"))]
8+
use tokio_native_tls::TlsConnector as AsyncTlsConnector;
59
#[cfg(feature = "rustls-webpki")]
610
use webpki_roots::TLS_SERVER_ROOTS;
711

@@ -13,7 +17,12 @@ pub(crate) struct TlsConfigBuilder {
1317
pub(crate) disable_default: bool,
1418
}
1519

16-
#[cfg(feature = "tokio-rustls")]
20+
#[cfg(not(feature = "rustls"))]
21+
pub(crate) struct TlsConfigBuilder {
22+
pub(crate) inner: TlsConnectorBuilder,
23+
}
24+
25+
#[cfg(feature = "rustls")]
1726
impl TlsConfigBuilder {
1827
pub(crate) fn new(cert_der: Option<Vec<u8>>) -> Result<Self, Error> {
1928
let mut tls_config = Self { inner: RootCertStore::empty(), disable_default: false };
@@ -71,8 +80,47 @@ impl TlsConfigBuilder {
7180
}
7281
}
7382

83+
#[cfg(not(feature = "rustls"))]
84+
impl TlsConfigBuilder {
85+
pub(crate) fn new(cert_der: Option<Vec<u8>>) -> Result<Self, Error> {
86+
let builder = TlsConnector::builder();
87+
let mut tls_config = Self { inner: builder };
88+
89+
if let Some(cert_der) = cert_der {
90+
tls_config.append_certificate(cert_der)?;
91+
}
92+
93+
Ok(tls_config)
94+
}
95+
96+
pub(crate) fn append_certificate(&mut self, cert_der: Vec<u8>) -> Result<&mut Self, Error> {
97+
let certificate = Certificate::from_der(&cert_der)?;
98+
self.inner.add_root_certificate(certificate);
99+
100+
Ok(self)
101+
}
102+
103+
pub(crate) fn disable_default_certificates(&mut self) -> Result<&mut Self, Error> {
104+
self.inner.disable_built_in_roots(true);
105+
Ok(self)
106+
}
107+
108+
pub(crate) fn build(self) -> Result<TlsConfig, Error> {
109+
let connector = self.inner.build()?;
110+
let async_connector = AsyncTlsConnector::from(connector);
111+
112+
Ok(TlsConfig { connector: Arc::new(async_connector) })
113+
}
114+
}
115+
74116
#[derive(Clone)]
75117
#[cfg(feature = "rustls")]
76118
pub(crate) struct TlsConfig {
77119
pub(crate) certificates: Arc<RootCertStore>,
78120
}
121+
122+
#[derive(Clone)]
123+
#[cfg(not(feature = "rustls"))]
124+
pub(crate) struct TlsConfig {
125+
pub(crate) connector: Arc<AsyncTlsConnector>,
126+
}

bitreq/src/error.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ pub enum Error {
2828
#[cfg(feature = "native-tls")]
2929
/// Ran into a native-tls error while creating the connection.
3030
NativeTlsCreateConnection(native_tls::Error),
31+
#[cfg(feature = "native-tls")]
32+
/// Ran into a native-tls error while appending a certificate.
33+
NativeTlsAppendCert,
3134
#[cfg(any(feature = "rustls", feature = "native-tls"))]
3235
/// The current TLS configuration is invalid.
3336
InvalidTlsConfig,
@@ -114,6 +117,8 @@ impl fmt::Display for Error {
114117
RustlsAppendCert(err) => write!(f, "error appending certificate: {}", err),
115118
#[cfg(feature = "native-tls")]
116119
NativeTlsCreateConnection(err) => write!(f, "error creating native-tls connection: {}", err),
120+
#[cfg(feature = "native-tls")]
121+
NativeTlsAppendCert => write!(f, "error appending certificate"),
117122
#[cfg(any(feature = "rustls", feature = "native-tls"))]
118123
InvalidTlsConfig => write!(f, "error disabling default certificates. Must have custom cert."),
119124
MalformedChunkLength => write!(f, "non-usize chunk length with transfer-encoding: chunked"),
@@ -159,6 +164,8 @@ impl error::Error for Error {
159164
RustlsCreateConnection(err) => Some(err),
160165
#[cfg(feature = "rustls")]
161166
RustlsAppendCert(err) => Some(err),
167+
#[cfg(feature = "native-tls")]
168+
NativeTlsCreateConnection(err) => Some(err),
162169
_ => None,
163170
}
164171
}

bitreq/tests/main.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ async fn test_https_with_client_builder() {
5252
assert_eq!(response.status_code, 200);
5353
}
5454

55+
#[tokio::test]
56+
#[cfg(all(feature = "tokio-native-tls", not(feature = "rustls")))]
57+
async fn test_https_with_client_builder() {
58+
setup();
59+
let client = bitreq::Client::builder().build().unwrap();
60+
let response = client.send_async(bitreq::get("https://example.com")).await.unwrap();
61+
assert_eq!(response.status_code, 200);
62+
}
63+
5564
#[tokio::test]
5665
#[cfg(feature = "tokio-rustls")]
5766
async fn test_https_with_client_builder_and_cert() {
@@ -66,6 +75,39 @@ async fn test_https_with_client_builder_and_cert() {
6675
assert_eq!(response.status_code, 200);
6776
}
6877

78+
#[tokio::test]
79+
#[cfg(all(feature = "tokio-native-tls", not(feature = "rustls")))]
80+
async fn test_https_with_client_builder_and_cert() {
81+
setup();
82+
let cert_der = include_bytes!("test_cert.der");
83+
let client = bitreq::Client::builder()
84+
.with_root_certificate(cert_der.as_slice())
85+
.unwrap()
86+
.build()
87+
.unwrap();
88+
let response = client.send_async(bitreq::get("https://example.com")).await.unwrap();
89+
assert_eq!(response.status_code, 200);
90+
}
91+
92+
#[tokio::test]
93+
#[cfg(all(feature = "tokio-native-tls", not(feature = "rustls")))]
94+
async fn test_https_with_multiple_certs() {
95+
setup();
96+
let cert_der = include_bytes!("test_cert.der");
97+
let ca_der = include_bytes!("ca_cert.der");
98+
99+
let client = bitreq::Client::builder()
100+
.with_root_certificate(cert_der.as_slice())
101+
.unwrap()
102+
.with_root_certificate(ca_der.as_slice())
103+
.unwrap()
104+
.build()
105+
.unwrap();
106+
107+
let response = client.send_async(bitreq::get("https://example.com")).await.unwrap();
108+
assert_eq!(response.status_code, 200);
109+
}
110+
69111
#[tokio::test]
70112
#[cfg(feature = "tokio-rustls")]
71113
async fn test_https_with_multiple_certs() {

0 commit comments

Comments
 (0)