Skip to content

Commit 3f9c11c

Browse files
committed
implement client builder for rustls
1 parent 4840954 commit 3f9c11c

9 files changed

Lines changed: 366 additions & 34 deletions

File tree

bitreq/src/client.rs

Lines changed: 141 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,138 @@
99
use std::collections::{hash_map, HashMap, VecDeque};
1010
use std::sync::{Arc, Mutex};
1111

12+
#[cfg(feature = "tokio-rustls")]
13+
use crate::connection::tls_config::{TlsConfig, TlsConfigBuilder};
1214
use crate::connection::AsyncConnection;
1315
use crate::request::{OwnedConnectionParams as ConnectionKey, ParsedRequest};
1416
use crate::{Error, Request, Response};
1517

18+
#[derive(Clone)]
19+
pub(crate) struct ClientConfig {
20+
#[cfg(feature = "tokio-rustls")]
21+
pub(crate) tls: Option<TlsConfig>,
22+
}
23+
24+
pub struct ClientBuilder {
25+
capacity: usize,
26+
#[cfg(feature = "tokio-rustls")]
27+
tls_config: Option<TlsConfigBuilder>,
28+
}
29+
30+
/// Builder for configuring a `Client` with custom settings.
31+
///
32+
/// # Example
33+
///
34+
/// ```no_run
35+
/// # async fn example() -> Result<(), bitreq::Error> {
36+
/// use bitreq::{Client, RequestExt};
37+
///
38+
/// let client = Client::builder().with_capacity(20).build()?;
39+
///
40+
/// let response = bitreq::get("https://example.com")
41+
/// .send_async_with_client(&client)
42+
/// .await?;
43+
/// # Ok(())
44+
/// # }
45+
/// ```
46+
impl ClientBuilder {
47+
/// Creates a new `ClientBuilder` with a default pool capacity of 10.
48+
pub fn new() -> Self {
49+
Self {
50+
capacity: 10,
51+
#[cfg(feature = "tokio-rustls")]
52+
tls_config: None,
53+
}
54+
}
55+
56+
/// Sets the maximum number of connections to keep in the pool.
57+
pub fn with_capacity(mut self, capacity: usize) -> Self {
58+
self.capacity = capacity;
59+
self
60+
}
61+
62+
#[cfg(feature = "tokio-rustls")]
63+
/// Builds the `Client` with the configured settings.
64+
pub fn build(self) -> Result<Client, Error> {
65+
let build_config = if let Some(builder) = self.tls_config {
66+
let tls_config = builder.build()?;
67+
Some(ClientConfig { tls: Some(tls_config) })
68+
} else {
69+
None
70+
};
71+
let client_config = build_config.map(Arc::new);
72+
73+
Ok(Client {
74+
r#async: Arc::new(Mutex::new(ClientImpl {
75+
connections: HashMap::new(),
76+
lru_order: VecDeque::new(),
77+
capacity: self.capacity,
78+
client_config,
79+
})),
80+
})
81+
}
82+
83+
/// Builds the `Client` with the configured settings.
84+
#[cfg(not(feature = "tokio-rustls"))]
85+
pub fn build(self) -> Result<Client, Error> {
86+
Ok(Client {
87+
r#async: Arc::new(Mutex::new(ClientImpl {
88+
connections: HashMap::new(),
89+
lru_order: VecDeque::new(),
90+
capacity: self.capacity,
91+
client_config: None,
92+
})),
93+
})
94+
}
95+
96+
/// Adds a custom DER-encoded root certificate for TLS verification.
97+
/// The certificate must be provided in DER format. This method accepts any type
98+
/// that can be converted into a `Vec<u8>`.
99+
/// The certificate is appended to the default trust store rather than replacing it.
100+
/// The trust store used depends on the TLS backend: system certificates for native-tls,
101+
/// Mozilla's root certificates(rustls-webpki) and/or system certificates(rustls-native-certs) for rustls.
102+
///
103+
/// # Example
104+
///
105+
/// ```no_run
106+
/// # use bitreq::Client;
107+
/// # async fn example() -> Result<(), bitreq::Error> {
108+
/// let client = Client::builder()
109+
/// .with_root_certificate(include_bytes!("../tests/test_cert.der"))?
110+
/// .build()?;
111+
/// # Ok(())
112+
/// # }
113+
/// ```
114+
#[cfg(feature = "tokio-rustls")]
115+
pub fn with_root_certificate<T: Into<Vec<u8>>>(mut self, cert_der: T) -> Result<Self, Error> {
116+
let cert_der = cert_der.into();
117+
if let Some(ref mut tls_config) = self.tls_config {
118+
tls_config.append_certificate(cert_der)?;
119+
120+
return Ok(self);
121+
}
122+
123+
self.tls_config = Some(TlsConfigBuilder::new(Some(cert_der))?);
124+
Ok(self)
125+
}
126+
127+
/// Disables default root certificates for TLS connections.
128+
/// Returns [`Error::InvalidTlsConfig`] if TLS has not been configured.
129+
#[cfg(feature = "tokio-rustls")]
130+
pub fn disable_default_certificates(mut self) -> Result<Self, Error> {
131+
match self.tls_config {
132+
Some(ref mut tls_config) => tls_config.disable_default_certificates()?,
133+
None => return Err(Error::InvalidTlsConfig),
134+
};
135+
136+
Ok(self)
137+
}
138+
}
139+
140+
impl Default for ClientBuilder {
141+
fn default() -> Self { Self::new() }
142+
}
143+
16144
/// A client that caches connections for reuse.
17145
///
18146
/// The client maintains a pool of up to `capacity` connections, evicting
@@ -39,10 +167,11 @@ struct ClientImpl<T> {
39167
connections: HashMap<ConnectionKey, Arc<T>>,
40168
lru_order: VecDeque<ConnectionKey>,
41169
capacity: usize,
170+
client_config: Option<Arc<ClientConfig>>,
42171
}
43172

44173
impl Client {
45-
/// Creates a new `Client` with the specified connection cache capacity.
174+
/// Creates a new `Client` with the specified connection pool capacity.
46175
///
47176
/// # Arguments
48177
///
@@ -54,10 +183,14 @@ impl Client {
54183
connections: HashMap::new(),
55184
lru_order: VecDeque::new(),
56185
capacity,
186+
client_config: None,
57187
})),
58188
}
59189
}
60190

191+
/// Create a builder for a client
192+
pub fn builder() -> ClientBuilder { ClientBuilder::new() }
193+
61194
/// Sends a request asynchronously using a cached connection if available.
62195
pub async fn send_async(&self, request: Request) -> Result<Response, Error> {
63196
let parsed_request = ParsedRequest::new(request)?;
@@ -77,7 +210,13 @@ impl Client {
77210
let conn = if let Some(conn) = conn_opt {
78211
conn
79212
} else {
80-
let connection = AsyncConnection::new(key, parsed_request.timeout_at).await?;
213+
let client_config = {
214+
let state = self.r#async.lock().unwrap();
215+
state.client_config.as_ref().map(Arc::clone)
216+
};
217+
218+
let connection =
219+
AsyncConnection::new(key, parsed_request.timeout_at, client_config).await?;
81220
let connection = Arc::new(connection);
82221

83222
let mut state = self.r#async.lock().unwrap();

bitreq/src/connection.rs

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ use tokio::net::TcpStream as AsyncTcpStream;
2222
#[cfg(feature = "async")]
2323
use tokio::sync::Mutex as AsyncMutex;
2424

25+
#[cfg(feature = "async")]
26+
use crate::client::ClientConfig;
2527
use crate::request::{ConnectionParams, OwnedConnectionParams, ParsedRequest};
2628
#[cfg(feature = "async")]
2729
use crate::Response;
@@ -31,6 +33,8 @@ type UnsecuredStream = TcpStream;
3133

3234
#[cfg(any(feature = "rustls", feature = "native-tls"))]
3335
mod rustls_stream;
36+
#[cfg(feature = "tokio-rustls")]
37+
pub(crate) mod tls_config;
3438
#[cfg(any(feature = "rustls", feature = "native-tls"))]
3539
type SecuredStream = rustls_stream::SecuredStream;
3640

@@ -238,6 +242,7 @@ struct AsyncConnectionState {
238242
/// Defaults to 60 seconds after open to align with nginx's default timeout of 75 seconds, but
239243
/// can be overridden by the `Keep-Alive` header.
240244
socket_new_requests_timeout: Mutex<Instant>,
245+
client_config: Option<Arc<ClientConfig>>,
241246
}
242247

243248
#[cfg(feature = "async")]
@@ -266,13 +271,15 @@ impl AsyncConnection {
266271
pub(crate) async fn new(
267272
params: ConnectionParams<'_>,
268273
timeout_at: Option<Instant>,
274+
client_config: Option<Arc<ClientConfig>>,
269275
) -> Result<AsyncConnection, Error> {
276+
let client_config_ref = &client_config;
277+
270278
let future = async move {
271279
let socket = Self::connect(params).await?;
272280

273281
if params.https {
274-
// temp call
275-
Self::wrap_async_stream(socket, params.host).await
282+
Self::wrap_async_stream(socket, params.host, client_config_ref).await
276283
} else {
277284
Ok(AsyncHttpStream::Unsecured(socket))
278285
}
@@ -293,26 +300,47 @@ impl AsyncConnection {
293300
readable_request_id: AtomicUsize::new(0),
294301
min_dropped_reader_id: AtomicUsize::new(usize::MAX),
295302
socket_new_requests_timeout: Mutex::new(Instant::now() + Duration::from_secs(60)),
303+
client_config,
296304
}))))
297305
}
298306

299-
/// Temp Method implementation
300-
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
307+
// =======
308+
/// Temp method. Required to compile
309+
#[cfg(all(feature = "tokio-native-tls", not(feature = "tokio-rustls")))]
301310
async fn wrap_async_stream(
302311
socket: AsyncTcpStream,
303312
host: &str,
313+
_client_config: &Option<Arc<ClientConfig>>,
304314
) -> Result<AsyncHttpStream, Error> {
305315
rustls_stream::wrap_async_stream(socket, host).await
306316
}
317+
// =======
318+
319+
/// Call the correct wrapper function depending on whether client_configs are present
320+
#[cfg(feature = "tokio-rustls")]
321+
async fn wrap_async_stream(
322+
socket: AsyncTcpStream,
323+
host: &str,
324+
client_config: &Option<Arc<ClientConfig>>,
325+
) -> Result<AsyncHttpStream, Error> {
326+
if let Some(client_config) = client_config {
327+
let tls_config = client_config.tls.as_ref().unwrap().clone();
328+
rustls_stream::wrap_async_stream_with_configs(socket, host, tls_config).await
329+
} else {
330+
rustls_stream::wrap_async_stream(socket, host).await
331+
}
332+
}
307333

308-
/// Temp Method implementation
334+
/// Error treatment function, should not be called under normal circustances
309335
#[cfg(not(any(feature = "tokio-rustls", feature = "tokio-native-tls")))]
310336
async fn wrap_async_stream(
311337
_socket: AsyncTcpStream,
312338
_host: &str,
339+
_client_config: &Option<Arc<ClientConfig>>,
313340
) -> Result<AsyncHttpStream, Error> {
314341
Err(Error::HttpsFeatureNotEnabled)
315342
}
343+
316344
async fn tcp_connect(host: &str, port: u16) -> Result<AsyncTcpStream, Error> {
317345
#[cfg(feature = "log")]
318346
log::trace!("Looking up host {host}");
@@ -461,9 +489,13 @@ impl AsyncConnection {
461489
retry_new_connection!(_internal);
462490
};
463491
(_internal) => {
464-
let new_connection =
465-
AsyncConnection::new(request.connection_params(), request.timeout_at)
466-
.await?;
492+
let config = conn.client_config.as_ref().map(Arc::clone);
493+
let new_connection = AsyncConnection::new(
494+
request.connection_params(),
495+
request.timeout_at,
496+
config,
497+
)
498+
.await?;
467499
*self.0.lock().unwrap() = Arc::clone(&*new_connection.0.lock().unwrap());
468500
core::mem::drop(read);
469501
// Note that this cannot recurse infinitely as we'll always be able to send at
@@ -818,7 +850,8 @@ async fn async_handle_redirects(
818850
let new_connection;
819851
if needs_new_connection {
820852
new_connection =
821-
AsyncConnection::new(request.connection_params(), request.timeout_at).await?;
853+
AsyncConnection::new(request.connection_params(), request.timeout_at, None)
854+
.await?;
822855
connection = &new_connection;
823856
}
824857
connection.send(request).await

0 commit comments

Comments
 (0)