Skip to content

Commit e20d8af

Browse files
committed
feat(client): allow to set a specific sni hostname per request
1 parent 7645226 commit e20d8af

6 files changed

Lines changed: 64 additions & 16 deletions

File tree

client/src/client.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ impl Client {
269269

270270
let (conn, version) = self
271271
.connector
272-
.call((connect.hostname(), conn))
272+
.call((connect.sni_hostname(), conn))
273273
.timeout(timer.as_mut())
274274
.await
275275
.map_err(|_| TimeoutError::TlsHandshake)??;

client/src/connect.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,19 @@ pub struct Connect<'a> {
8080
pub(crate) uri: Uri<'a>,
8181
pub(crate) port: u16,
8282
pub(crate) addr: Addrs,
83+
pub(crate) sni_hostname: Option<&'a str>,
8384
}
8485

8586
impl<'a> Connect<'a> {
8687
/// Create `Connect` instance by splitting the string by ':' and convert the second part to u16
87-
pub fn new(uri: Uri<'a>) -> Self {
88+
pub fn new(uri: Uri<'a>, sni_hostname: Option<&'a str>) -> Self {
8889
let (_, port) = parse_host(uri.hostname());
8990

9091
Self {
9192
uri,
9293
port: port.unwrap_or(0),
9394
addr: Addrs::None,
95+
sni_hostname,
9496
}
9597
}
9698

@@ -112,6 +114,11 @@ impl<'a> Connect<'a> {
112114
self.uri.hostname()
113115
}
114116

117+
/// Get sni hostname.
118+
pub fn sni_hostname(&self) -> &str {
119+
self.sni_hostname.unwrap_or_else(|| self.hostname())
120+
}
121+
115122
/// Get request port.
116123
pub fn port(&self) -> u16 {
117124
Address::port(&self.uri).unwrap_or(self.port)

client/src/connection.rs

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use core::hash::{Hash, Hasher};
22

33
use xitca_http::http::uri::{Authority, PathAndQuery};
44

5-
use super::{tls::TlsStream, uri::Uri};
5+
use super::{connect::Connect, tls::TlsStream, uri::Uri};
66

77
/// exclusive connection for http1 and in certain case they can be upgraded to [ConnectionShared]
88
pub type ConnectionExclusive = TlsStream;
@@ -34,10 +34,17 @@ impl From<crate::h3::Connection> for ConnectionShared {
3434
#[doc(hidden)]
3535
#[derive(PartialEq, Eq, Debug, Clone, Hash)]
3636
pub enum ConnectionKey {
37-
Regular(Authority),
37+
Regular(AuthorityWithSni),
3838
Unix(AuthorityWithPath),
3939
}
4040

41+
#[doc(hidden)]
42+
#[derive(PartialEq, Eq, Debug, Clone, Hash)]
43+
pub struct AuthorityWithSni {
44+
authority: Authority,
45+
sni: Option<String>,
46+
}
47+
4148
#[doc(hidden)]
4249
#[derive(Eq, Debug, Clone)]
4350
pub struct AuthorityWithPath {
@@ -58,10 +65,13 @@ impl Hash for AuthorityWithPath {
5865
}
5966
}
6067

61-
impl From<&Uri<'_>> for ConnectionKey {
62-
fn from(uri: &Uri<'_>) -> Self {
63-
match *uri {
64-
Uri::Tcp(uri) | Uri::Tls(uri) => ConnectionKey::Regular(uri.authority().unwrap().clone()),
68+
impl From<&Connect<'_>> for ConnectionKey {
69+
fn from(connect: &Connect<'_>) -> Self {
70+
match connect.uri {
71+
Uri::Tcp(uri) | Uri::Tls(uri) => ConnectionKey::Regular(AuthorityWithSni {
72+
authority: uri.authority().unwrap().clone(),
73+
sni: connect.sni_hostname.map(|s| s.to_string()),
74+
}),
6575
Uri::Unix(uri) => ConnectionKey::Unix(AuthorityWithPath {
6676
authority: uri.authority().unwrap().clone(),
6777
path_and_query: uri.path_and_query().unwrap().clone(),

client/src/middleware/redirect.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,26 @@ where
2828
type Error = Error;
2929

3030
async fn call(&self, req: ServiceRequest<'r, 'c>) -> Result<Self::Response, Self::Error> {
31-
let ServiceRequest { req, client, timeout } = req;
31+
let ServiceRequest {
32+
req,
33+
client,
34+
timeout,
35+
sni_hostname,
36+
} = req;
3237
let mut headers = req.headers().clone();
3338
let mut method = req.method().clone();
3439
let mut uri = req.uri().clone();
3540
let ext = req.extensions().clone();
3641
loop {
37-
let mut res = self.service.call(ServiceRequest { req, client, timeout }).await?;
42+
let mut res = self
43+
.service
44+
.call(ServiceRequest {
45+
req,
46+
client,
47+
timeout,
48+
sni_hostname,
49+
})
50+
.await?;
3851
match res.status() {
3952
StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND | StatusCode::SEE_OTHER => {
4053
if method != Method::HEAD {

client/src/request.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ pub struct RequestBuilder<'a, M = marker::Http> {
2222
err: Vec<Error>,
2323
client: &'a Client,
2424
timeout: Duration,
25+
sni_hostname: Option<String>,
2526
_marker: PhantomData<M>,
2627
}
2728

@@ -104,6 +105,7 @@ impl<'a, M> RequestBuilder<'a, M> {
104105
err: Vec::new(),
105106
client,
106107
timeout: client.timeout_config.request_timeout,
108+
sni_hostname: None,
107109
_marker: PhantomData,
108110
}
109111
}
@@ -114,6 +116,7 @@ impl<'a, M> RequestBuilder<'a, M> {
114116
err: self.err,
115117
client: self.client,
116118
timeout: self.timeout,
119+
sni_hostname: self.sni_hostname,
117120
_marker: PhantomData,
118121
}
119122
}
@@ -138,6 +141,7 @@ impl<'a, M> RequestBuilder<'a, M> {
138141
req: &mut req,
139142
client,
140143
timeout,
144+
sni_hostname: self.sni_hostname.as_deref(),
141145
})
142146
.await
143147
}
@@ -210,6 +214,13 @@ impl<'a, M> RequestBuilder<'a, M> {
210214
self
211215
}
212216

217+
/// Set SNI hostname of this request.
218+
#[inline]
219+
pub fn sni_hostname(mut self, sni_hostname: String) -> Self {
220+
self.sni_hostname = Some(sni_hostname);
221+
self
222+
}
223+
213224
fn map_body<B, E>(mut self, b: B) -> RequestBuilder<'a, M>
214225
where
215226
B: Stream<Item = Result<Bytes, E>> + Send + 'static,

client/src/service.rs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ pub struct ServiceRequest<'r, 'c> {
6868
pub req: &'r mut Request<BoxBody>,
6969
pub client: &'c Client,
7070
pub timeout: Duration,
71+
pub sni_hostname: Option<&'r str>,
7172
}
7273

7374
/// type alias for object safe wrapper of type implement [Service] trait.
@@ -85,7 +86,12 @@ pub(crate) fn base_service() -> HttpService {
8586
#[cfg(any(feature = "http1", feature = "http2", feature = "http3"))]
8687
use crate::{error::TimeoutError, timeout::Timeout};
8788

88-
let ServiceRequest { req, client, timeout } = req;
89+
let ServiceRequest {
90+
req,
91+
client,
92+
timeout,
93+
sni_hostname,
94+
} = req;
8995

9096
let uri = Uri::try_parse(req.uri())?;
9197

@@ -94,13 +100,13 @@ pub(crate) fn base_service() -> HttpService {
94100
#[allow(unused_mut)]
95101
let mut version = req.version();
96102

97-
let mut connect = Connect::new(uri);
103+
let mut connect = Connect::new(uri, sni_hostname);
98104

99105
let _date = client.date_service.handle();
100106

101107
loop {
102108
match version {
103-
Version::HTTP_2 | Version::HTTP_3 => match client.shared_pool.acquire(&connect.uri).await {
109+
Version::HTTP_2 | Version::HTTP_3 => match client.shared_pool.acquire(&connect).await {
104110
shared::AcquireOutput::Conn(mut _conn) => {
105111
let mut _timer = Box::pin(tokio::time::sleep(timeout));
106112
*req.version_mut() = version;
@@ -155,7 +161,7 @@ pub(crate) fn base_service() -> HttpService {
155161
if let Ok(Ok(conn)) = crate::h3::proto::connect(
156162
&client.h3_client,
157163
connect.addrs(),
158-
connect.hostname(),
164+
connect.sni_hostname(),
159165
)
160166
.timeout(timer.as_mut())
161167
.await
@@ -197,7 +203,7 @@ pub(crate) fn base_service() -> HttpService {
197203

198204
#[cfg(feature = "http1")]
199205
{
200-
client.exclusive_pool.try_add(&connect.uri, conn);
206+
client.exclusive_pool.try_add(&connect, conn);
201207
// downgrade request version to what alpn protocol suggested from make_exclusive.
202208
version = alpn_version;
203209
}
@@ -212,7 +218,7 @@ pub(crate) fn base_service() -> HttpService {
212218
_ => unreachable!("outer match didn't handle version correctly."),
213219
},
214220
},
215-
version => match client.exclusive_pool.acquire(&connect.uri).await {
221+
version => match client.exclusive_pool.acquire(&connect).await {
216222
exclusive::AcquireOutput::Conn(mut _conn) => {
217223
*req.version_mut() = version;
218224

@@ -307,6 +313,7 @@ mod test {
307313
req,
308314
client: &self.0,
309315
timeout: self.0.timeout_config.request_timeout,
316+
sni_hostname: None,
310317
}
311318
}
312319
}

0 commit comments

Comments
 (0)