Skip to content

Commit ee83cc2

Browse files
committed
code review adjustment: streamline code and reduce complexity to avoid arcs/mutexes/clones when building ClientConfigs
1 parent 503702b commit ee83cc2

4 files changed

Lines changed: 113 additions & 147 deletions

File tree

bitreq/src/client.rs

Lines changed: 56 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use std::sync::{Arc, Mutex};
1313
all(feature = "native-tls", feature = "tokio-native-tls"),
1414
all(feature = "rustls", feature = "tokio-rustls")
1515
))]
16-
use crate::connection::certificates::Certificates;
16+
use crate::connection::certificates::{Certificates, CertificatesBuilder};
1717
use crate::connection::AsyncConnection;
1818
use crate::request::{OwnedConnectionParams as ConnectionKey, ParsedRequest};
1919
use crate::{Error, Request, Response};
@@ -27,23 +27,6 @@ pub(crate) struct ClientConfig {
2727
pub(crate) tls: Option<TlsConfig>,
2828
}
2929

30-
impl ClientConfig {
31-
#[cfg(any(
32-
all(feature = "native-tls", feature = "tokio-native-tls"),
33-
all(feature = "rustls", feature = "tokio-rustls")
34-
))]
35-
pub fn build(self) -> Result<Self, Error> {
36-
let tls = self.tls.map(|tls| tls.build()).transpose()?;
37-
Ok(Self { tls })
38-
}
39-
40-
#[cfg(not(any(
41-
all(feature = "native-tls", feature = "tokio-native-tls"),
42-
all(feature = "rustls", feature = "tokio-rustls")
43-
)))]
44-
pub fn build(self) -> Result<Self, Error> { Ok(Self {}) }
45-
}
46-
4730
#[cfg(any(
4831
all(feature = "native-tls", feature = "tokio-native-tls"),
4932
all(feature = "rustls", feature = "tokio-rustls")
@@ -58,34 +41,16 @@ pub(crate) struct TlsConfig {
5841
all(feature = "rustls", feature = "tokio-rustls")
5942
))]
6043
impl TlsConfig {
61-
fn new(cert_der: Vec<u8>) -> Result<Self, Error> {
62-
let certificates = Certificates::new(Some(cert_der))?;
63-
64-
Ok(Self { certificates })
65-
}
66-
67-
#[cfg(all(feature = "rustls", feature = "tokio-rustls"))]
68-
fn build(mut self) -> Result<Self, Error> {
69-
if self.certificates.disable_default {
70-
return Ok(self);
71-
}
72-
73-
self.certificates = self.certificates.with_root_certificates();
74-
Ok(self)
75-
}
76-
77-
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
78-
fn build(mut self) -> Result<Self, Error> {
79-
let certificates = self.certificates.build()?;
80-
81-
self.certificates = certificates;
82-
Ok(self)
83-
}
44+
fn new(certificates: Certificates) -> Self { Self { certificates } }
8445
}
8546

8647
pub struct ClientBuilder {
8748
capacity: usize,
88-
client_config: Option<ClientConfig>,
49+
#[cfg(any(
50+
all(feature = "native-tls", feature = "tokio-native-tls"),
51+
all(feature = "rustls", feature = "tokio-rustls")
52+
))]
53+
certificates: Option<CertificatesBuilder>,
8954
}
9055

9156
/// Builder for configuring a `Client` with custom settings.
@@ -106,22 +71,39 @@ pub struct ClientBuilder {
10671
/// ```
10772
impl ClientBuilder {
10873
/// Creates a new `ClientBuilder` with a default pool capacity of 10.
109-
pub fn new() -> Self { Self { capacity: 10, client_config: None } }
74+
#[cfg(any(
75+
all(feature = "native-tls", feature = "tokio-native-tls"),
76+
all(feature = "rustls", feature = "tokio-rustls")
77+
))]
78+
pub fn new() -> Self { Self { capacity: 10, certificates: None } }
79+
80+
/// Creates a new `ClientBuilder` with a default pool capacity of 10.
81+
#[cfg(not(any(
82+
all(feature = "native-tls", feature = "tokio-native-tls"),
83+
all(feature = "rustls", feature = "tokio-rustls")
84+
)))]
85+
pub fn new() -> Self { Self { capacity: 10 } }
11086

11187
/// Sets the maximum number of connections to keep in the pool.
11288
pub fn with_capacity(mut self, capacity: usize) -> Self {
11389
self.capacity = capacity;
11490
self
11591
}
11692

93+
#[cfg(any(
94+
all(feature = "native-tls", feature = "tokio-native-tls"),
95+
all(feature = "rustls", feature = "tokio-rustls")
96+
))]
11797
/// Builds the `Client` with the configured settings.
11898
pub fn build(self) -> Result<Client, Error> {
119-
let build = self.client_config.map(|c| c.build());
120-
let client_config = match build {
121-
Some(Ok(config)) => Some(Arc::new(config)),
122-
Some(Err(e)) => return Err(e),
123-
None => None,
99+
let build_config = if let Some(builder) = self.certificates {
100+
let certificates = builder.build()?;
101+
let tls_config = TlsConfig::new(certificates);
102+
Some(ClientConfig { tls: Some(tls_config) })
103+
} else {
104+
None
124105
};
106+
let client_config = build_config.map(Arc::new);
125107

126108
Ok(Client {
127109
r#async: Arc::new(Mutex::new(ClientImpl {
@@ -132,6 +114,23 @@ impl ClientBuilder {
132114
})),
133115
})
134116
}
117+
118+
/// Builds the `Client` with the configured settings.
119+
#[cfg(not(any(
120+
all(feature = "native-tls", feature = "tokio-native-tls"),
121+
all(feature = "rustls", feature = "tokio-rustls")
122+
)))]
123+
pub fn build(self) -> Result<Client, Error> {
124+
Ok(Client {
125+
r#async: Arc::new(Mutex::new(ClientImpl {
126+
connections: HashMap::new(),
127+
lru_order: VecDeque::new(),
128+
capacity: self.capacity,
129+
client_config: None,
130+
})),
131+
})
132+
}
133+
135134
/// Adds a custom DER-encoded root certificate for TLS verification.
136135
/// The certificate must be provided in DER format. This method accepts any type
137136
/// that can be converted into a `Vec<u8>`.
@@ -156,18 +155,13 @@ impl ClientBuilder {
156155
))]
157156
pub fn with_root_certificate<T: Into<Vec<u8>>>(mut self, cert_der: T) -> Result<Self, Error> {
158157
let cert_der = cert_der.into();
158+
if let Some(ref mut certificates) = self.certificates {
159+
certificates.append_certificate(cert_der)?;
159160

160-
if let Some(ref mut client_config) = self.client_config {
161-
if let Some(ref mut tls_config) = client_config.tls {
162-
let certificates = tls_config.certificates.clone().append_certificate(cert_der)?;
163-
tls_config.certificates = certificates;
164-
165-
return Ok(self);
166-
}
161+
return Ok(self);
167162
}
168163

169-
let tls_config = TlsConfig::new(cert_der)?;
170-
self.client_config = Some(ClientConfig { tls: Some(tls_config) });
164+
self.certificates = Some(CertificatesBuilder::new(Some(cert_der))?);
171165
Ok(self)
172166
}
173167

@@ -178,9 +172,11 @@ impl ClientBuilder {
178172
all(feature = "rustls", feature = "tokio-rustls")
179173
))]
180174
pub fn disable_default_certificates(mut self) -> Result<Self, Error> {
181-
let client_config = self.client_config.as_mut().ok_or(Error::InvalidTlsConfig)?;
182-
let tls_config = client_config.tls.as_mut().ok_or(Error::InvalidTlsConfig)?;
183-
tls_config.certificates = tls_config.certificates.clone().disable_default()?;
175+
match self.certificates {
176+
Some(ref mut certificates) => certificates.disable_default()?,
177+
None => return Err(Error::InvalidTlsConfig),
178+
};
179+
184180
Ok(self)
185181
}
186182
}

bitreq/src/connection.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,9 @@ impl AsyncConnection {
315315
client_config: Option<Arc<ClientConfig>>,
316316
) -> Result<AsyncHttpStream, Error> {
317317
if let Some(client_config) = client_config {
318-
rustls_stream::wrap_async_stream_with_configs(socket, host, client_config).await
318+
let tls_config = client_config.tls.as_ref().unwrap();
319+
let certificates = tls_config.certificates.clone();
320+
rustls_stream::wrap_async_stream_with_configs(socket, host, certificates).await
319321
} else {
320322
rustls_stream::wrap_async_stream(socket, host).await
321323
}
Lines changed: 48 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,125 +1,101 @@
1-
#[cfg(feature = "rustls")]
1+
#[cfg(any(feature = "rustls", feature = "native-tls"))]
22
use std::sync::Arc;
3-
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
4-
use std::sync::{Arc, Mutex};
53

64
#[cfg(all(feature = "native-tls", not(feature = "rustls")))]
75
use native_tls::{Certificate, TlsConnector, TlsConnectorBuilder};
86
#[cfg(feature = "rustls")]
97
use rustls::RootCertStore;
8+
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
9+
use tokio_native_tls::TlsConnector as AsyncTlsConnector;
1010
#[cfg(feature = "rustls-webpki")]
1111
use webpki_roots::TLS_SERVER_ROOTS;
1212

1313
use crate::Error;
1414

15-
#[derive(Clone)]
1615
#[cfg(feature = "rustls")]
17-
pub(crate) struct Certificates {
18-
pub(crate) inner: Arc<RootCertStore>,
16+
pub(crate) struct CertificatesBuilder {
17+
pub(crate) inner: RootCertStore,
1918
pub(crate) disable_default: bool,
2019
}
2120

22-
#[derive(Clone)]
2321
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
24-
pub(crate) struct Certificates {
25-
pub(crate) inner: CertificatesInner,
22+
pub(crate) struct CertificatesBuilder {
23+
pub(crate) inner: TlsConnectorBuilder,
2624
}
2725

28-
#[derive(Clone)]
29-
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
30-
pub(crate) enum CertificatesInner {
31-
Builder(Arc<Mutex<TlsConnectorBuilder>>),
32-
Built(TlsConnector),
33-
}
34-
35-
impl Certificates {
26+
impl CertificatesBuilder {
3627
#[cfg(feature = "rustls")]
3728
pub(crate) fn new(cert_der: Option<Vec<u8>>) -> Result<Self, Error> {
38-
let certificates = Self { inner: Arc::new(RootCertStore::empty()), disable_default: false };
29+
let mut certificates = Self { inner: RootCertStore::empty(), disable_default: false };
3930

4031
if let Some(cert_der) = cert_der {
41-
certificates.append_certificate(cert_der)
42-
} else {
43-
Ok(certificates)
32+
certificates.append_certificate(cert_der)?;
4433
}
34+
35+
Ok(certificates)
4536
}
4637

4738
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
4839
pub(crate) fn new(cert_der: Option<Vec<u8>>) -> Result<Self, Error> {
4940
let builder = TlsConnector::builder();
50-
let inner = CertificatesInner::Builder(Arc::new(Mutex::new(builder)));
51-
let certificates = Self { inner: inner };
41+
let mut certificates = Self { inner: builder };
5242

5343
if let Some(cert_der) = cert_der {
54-
certificates.append_certificate(cert_der)
55-
} else {
56-
Ok(certificates)
44+
certificates.append_certificate(cert_der)?;
5745
}
46+
47+
Ok(certificates)
5848
}
5949

6050
#[cfg(feature = "rustls")]
61-
pub(crate) fn append_certificate(mut self, cert_der: Vec<u8>) -> Result<Self, Error> {
62-
let certificates = Arc::make_mut(&mut self.inner);
63-
certificates.add(&rustls::Certificate(cert_der)).map_err(Error::RustlsAppendCert)?;
51+
pub(crate) fn append_certificate(&mut self, cert_der: Vec<u8>) -> Result<&mut Self, Error> {
52+
self.inner.add(&rustls::Certificate(cert_der)).map_err(Error::RustlsAppendCert)?;
6453

6554
Ok(self)
6655
}
6756

6857
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
69-
pub(crate) fn append_certificate(mut self, cert_der: Vec<u8>) -> Result<Self, Error> {
70-
let new_inner = match self.inner {
71-
CertificatesInner::Builder(builder_mutex) => {
72-
let certificate = Certificate::from_der(&cert_der)?;
73-
74-
{
75-
let mut builder_guard = builder_mutex.lock().unwrap();
76-
builder_guard.add_root_certificate(certificate);
77-
}
58+
pub(crate) fn append_certificate(&mut self, cert_der: Vec<u8>) -> Result<&mut Self, Error> {
59+
let certificate = Certificate::from_der(&cert_der)?;
60+
self.inner.add_root_certificate(certificate);
7861

79-
CertificatesInner::Builder(builder_mutex)
80-
}
81-
CertificatesInner::Built(_) => return Err(Error::NativeTlsAppendCert),
82-
};
83-
84-
self.inner = new_inner;
8562
Ok(self)
8663
}
8764

8865
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
89-
pub(crate) fn build(mut self) -> Result<Self, Error> {
90-
let new_inner = match self.inner {
91-
CertificatesInner::Builder(builder_mutex) => {
92-
let mut builder_guard = builder_mutex.lock().unwrap();
93-
let connector = builder_guard.build()?;
94-
95-
CertificatesInner::Built(connector)
96-
}
97-
CertificatesInner::Built(_) => return Ok(self),
98-
};
66+
pub(crate) fn build(self) -> Result<Certificates, Error> {
67+
let connector = self.inner.build()?;
68+
let async_connector = AsyncTlsConnector::from(connector);
9969

100-
self.inner = new_inner;
101-
Ok(self)
70+
Ok(Certificates(Arc::new(async_connector)))
10271
}
10372

10473
#[cfg(feature = "rustls")]
105-
pub(crate) fn with_root_certificates(mut self) -> Self {
106-
let root_certificates = Arc::make_mut(&mut self.inner);
74+
pub(crate) fn build(mut self) -> Result<Certificates, Error> {
75+
if !self.disable_default {
76+
self.with_root_certificates();
77+
}
78+
79+
Ok(Certificates(Arc::new(self.inner)))
80+
}
10781

82+
#[cfg(feature = "rustls")]
83+
fn with_root_certificates(&mut self) -> &mut Self {
10884
// Try to load native certs
10985
#[cfg(feature = "https-rustls-probe")]
11086
if let Ok(os_roots) = rustls_native_certs::load_native_certs() {
11187
for root_cert in os_roots {
11288
// Ignore erroneous OS certificates, there's nothing
11389
// to do differently in that situation anyways.
114-
let _ = root_certificates.add(&rustls::Certificate(root_cert.0));
90+
let _ = self.inner.add(&rustls::Certificate(root_cert.0));
11591
}
11692
}
11793

11894
#[cfg(feature = "rustls-webpki")]
11995
{
12096
#[allow(deprecated)]
12197
// Need to use add_server_trust_anchors to compile with rustls 0.21.1
122-
root_certificates.add_server_trust_anchors(TLS_SERVER_ROOTS.iter().map(|ta| {
98+
self.inner.add_server_trust_anchors(TLS_SERVER_ROOTS.iter().map(|ta| {
12399
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
124100
ta.subject,
125101
ta.spki,
@@ -131,19 +107,22 @@ impl Certificates {
131107
}
132108

133109
#[cfg(feature = "rustls")]
134-
pub(crate) fn disable_default(mut self) -> Result<Self, Error> {
110+
pub(crate) fn disable_default(&mut self) -> Result<&mut Self, Error> {
135111
self.disable_default = true;
136112
Ok(self)
137113
}
138114

139115
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
140-
pub(crate) fn disable_default(self) -> Result<Self, Error> {
141-
match self.inner {
142-
CertificatesInner::Builder(ref builder_mutex) => {
143-
builder_mutex.lock().unwrap().disable_built_in_roots(true);
144-
Ok(self)
145-
}
146-
CertificatesInner::Built(_) => return Err(Error::InvalidTlsConfig),
147-
}
116+
pub(crate) fn disable_default(&mut self) -> Result<&mut Self, Error> {
117+
self.inner.disable_built_in_roots(true);
118+
Ok(self)
148119
}
149120
}
121+
122+
#[derive(Clone)]
123+
#[cfg(feature = "rustls")]
124+
pub(crate) struct Certificates(pub(crate) Arc<RootCertStore>);
125+
126+
#[derive(Clone)]
127+
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
128+
pub(crate) struct Certificates(pub(crate) Arc<AsyncTlsConnector>);

0 commit comments

Comments
 (0)