Skip to content

Commit 20f4564

Browse files
committed
refactor!: expect auth_provider as parameter
- update the `RawClient` methods to expect `auth_provider` as parameter. - update both `.with_auth()` and `.negotiate_protocol_version()` to private. - always call `.with_auth()` internally when creating the `RawClient`. - always call `.negotiate_protocol_version()` internally when creating the `RawClient`. It fixes the misleading API from the previous commit, the protocol version negotiation is a mandatory step in the protocol, and shouldn't be relied on the user reading the documentation that it MUST call the `.negotiate_protocol_version()` as previously stated, now it's always called internally.
1 parent eacf92d commit 20f4564

File tree

2 files changed

+77
-62
lines changed

2 files changed

+77
-62
lines changed

src/client.rs

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -124,20 +124,22 @@ impl ClientType {
124124
config.validate_domain(),
125125
socks5,
126126
config.timeout(),
127-
)?
128-
.with_auth(auth_provider)
129-
.negotiate_protocol_version()?,
130-
None => {
131-
RawClient::new_ssl(url.as_str(), config.validate_domain(), config.timeout())?
132-
.with_auth(auth_provider)
133-
.negotiate_protocol_version()?
134-
}
127+
auth_provider,
128+
)?,
129+
None => RawClient::new_ssl(
130+
url.as_str(),
131+
config.validate_domain(),
132+
config.timeout(),
133+
auth_provider,
134+
)?,
135135
};
136136
#[cfg(not(feature = "proxy"))]
137-
let raw_client =
138-
RawClient::new_ssl(url.as_str(), config.validate_domain(), config.timeout())?
139-
.with_auth(auth_provider)
140-
.negotiate_protocol_version()?;
137+
let raw_client = RawClient::new_ssl(
138+
url.as_str(),
139+
config.validate_domain(),
140+
config.timeout(),
141+
auth_provider,
142+
)?;
141143

142144
return Ok(ClientType::SSL(raw_client));
143145
}
@@ -154,24 +156,25 @@ impl ClientType {
154156

155157
#[cfg(feature = "proxy")]
156158
let client = match config.socks5() {
157-
Some(socks5) => ClientType::Socks5(
158-
RawClient::new_proxy(url.as_str(), socks5, config.timeout())?
159-
.with_auth(auth_provider)
160-
.negotiate_protocol_version()?,
161-
),
162-
None => ClientType::TCP(
163-
RawClient::new(url.as_str(), config.timeout())?
164-
.with_auth(auth_provider)
165-
.negotiate_protocol_version()?,
166-
),
159+
Some(socks5) => ClientType::Socks5(RawClient::new_proxy(
160+
url.as_str(),
161+
socks5,
162+
config.timeout(),
163+
auth_provider,
164+
)?),
165+
None => ClientType::TCP(RawClient::new(
166+
url.as_str(),
167+
config.timeout(),
168+
auth_provider,
169+
)?),
167170
};
168171

169172
#[cfg(not(feature = "proxy"))]
170-
let client = ClientType::TCP(
171-
RawClient::new(url.as_str(), config.timeout())?
172-
.with_auth(auth_provider)
173-
.negotiate_protocol_version()?,
174-
);
173+
let client = ClientType::TCP(RawClient::new(
174+
url.as_str(),
175+
config.timeout(),
176+
auth_provider,
177+
)?);
175178

176179
Ok(client)
177180
}

src/raw_client.rs

Lines changed: 47 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ impl RawClient<ElectrumPlaintextStream> {
222222
pub fn new<A: ToSocketAddrs>(
223223
socket_addrs: A,
224224
timeout: Option<Duration>,
225+
auth_provider: Option<AuthProvider>,
225226
) -> Result<Self, Error> {
226227
let stream = match timeout {
227228
Some(timeout) => {
@@ -233,7 +234,11 @@ impl RawClient<ElectrumPlaintextStream> {
233234
None => TcpStream::connect(socket_addrs)?,
234235
};
235236

236-
Ok(stream.into())
237+
let client = Self::from(stream)
238+
.with_auth(auth_provider)
239+
.negotiate_protocol_version()?;
240+
241+
Ok(client)
237242
}
238243
}
239244

@@ -285,6 +290,7 @@ impl RawClient<ElectrumSslStream> {
285290
socket_addrs: A,
286291
validate_domain: bool,
287292
timeout: Option<Duration>,
293+
auth_provider: Option<AuthProvider>,
288294
) -> Result<Self, Error> {
289295
debug!(
290296
"new_ssl socket_addrs.domain():{:?} validate_domain:{} timeout:{:?}",
@@ -300,11 +306,11 @@ impl RawClient<ElectrumSslStream> {
300306
let stream = connect_with_total_timeout(socket_addrs.clone(), timeout)?;
301307
stream.set_read_timeout(Some(timeout))?;
302308
stream.set_write_timeout(Some(timeout))?;
303-
Self::new_ssl_from_stream(socket_addrs, validate_domain, stream)
309+
Self::new_ssl_from_stream(socket_addrs, validate_domain, stream, auth_provider)
304310
}
305311
None => {
306312
let stream = TcpStream::connect(socket_addrs.clone())?;
307-
Self::new_ssl_from_stream(socket_addrs, validate_domain, stream)
313+
Self::new_ssl_from_stream(socket_addrs, validate_domain, stream, auth_provider)
308314
}
309315
}
310316
}
@@ -314,6 +320,7 @@ impl RawClient<ElectrumSslStream> {
314320
socket_addrs: A,
315321
validate_domain: bool,
316322
stream: TcpStream,
323+
auth_provider: Option<AuthProvider>,
317324
) -> Result<Self, Error> {
318325
let mut builder =
319326
SslConnector::builder(SslMethod::tls()).map_err(Error::InvalidSslMethod)?;
@@ -332,7 +339,11 @@ impl RawClient<ElectrumSslStream> {
332339
.connect(&domain, stream)
333340
.map_err(Error::SslHandshakeError)?;
334341

335-
Ok(stream.into())
342+
let client = Self::from(stream)
343+
.with_auth(auth_provider)
344+
.negotiate_protocol_version()?;
345+
346+
Ok(client)
336347
}
337348
}
338349

@@ -407,6 +418,7 @@ impl RawClient<ElectrumSslStream> {
407418
socket_addrs: A,
408419
validate_domain: bool,
409420
timeout: Option<Duration>,
421+
auth_provider: Option<AuthProvider>,
410422
) -> Result<Self, Error> {
411423
debug!(
412424
"new_ssl socket_addrs.domain():{:?} validate_domain:{} timeout:{:?}",
@@ -424,11 +436,11 @@ impl RawClient<ElectrumSslStream> {
424436
let stream = connect_with_total_timeout(socket_addrs.clone(), timeout)?;
425437
stream.set_read_timeout(Some(timeout))?;
426438
stream.set_write_timeout(Some(timeout))?;
427-
Self::new_ssl_from_stream(socket_addrs, validate_domain, stream)
439+
Self::new_ssl_from_stream(socket_addrs, validate_domain, stream, auth_provider)
428440
}
429441
None => {
430442
let stream = TcpStream::connect(socket_addrs.clone())?;
431-
Self::new_ssl_from_stream(socket_addrs, validate_domain, stream)
443+
Self::new_ssl_from_stream(socket_addrs, validate_domain, stream, auth_provider)
432444
}
433445
}
434446
}
@@ -438,6 +450,7 @@ impl RawClient<ElectrumSslStream> {
438450
socket_addr: A,
439451
validate_domain: bool,
440452
tcp_stream: TcpStream,
453+
auth_provider: Option<AuthProvider>,
441454
) -> Result<Self, Error> {
442455
use std::convert::TryFrom;
443456

@@ -501,7 +514,11 @@ impl RawClient<ElectrumSslStream> {
501514
.map_err(Error::CouldNotCreateConnection)?;
502515
let stream = StreamOwned::new(session, tcp_stream);
503516

504-
Ok(stream.into())
517+
let client = Self::from(stream)
518+
.with_auth(auth_provider)
519+
.negotiate_protocol_version()?;
520+
521+
Ok(client)
505522
}
506523
}
507524

@@ -517,6 +534,7 @@ impl RawClient<ElectrumProxyStream> {
517534
target_addr: T,
518535
proxy: &crate::Socks5Config,
519536
timeout: Option<Duration>,
537+
auth_provider: Option<AuthProvider>,
520538
) -> Result<Self, Error> {
521539
let mut stream = match proxy.credentials.as_ref() {
522540
Some(cred) => Socks5Stream::connect_with_password(
@@ -531,7 +549,11 @@ impl RawClient<ElectrumProxyStream> {
531549
stream.get_mut().set_read_timeout(timeout)?;
532550
stream.get_mut().set_write_timeout(timeout)?;
533551

534-
Ok(stream.into())
552+
let client = Self::from(stream)
553+
.with_auth(auth_provider)
554+
.negotiate_protocol_version()?;
555+
556+
Ok(client)
535557
}
536558

537559
#[cfg(all(
@@ -546,6 +568,7 @@ impl RawClient<ElectrumProxyStream> {
546568
validate_domain: bool,
547569
proxy: &crate::Socks5Config,
548570
timeout: Option<Duration>,
571+
auth_provider: Option<AuthProvider>,
549572
) -> Result<RawClient<ElectrumSslStream>, Error> {
550573
let target = target_addr.to_target_addr()?;
551574

@@ -563,7 +586,7 @@ impl RawClient<ElectrumProxyStream> {
563586
stream.get_mut().set_read_timeout(timeout)?;
564587
stream.get_mut().set_write_timeout(timeout)?;
565588

566-
RawClient::new_ssl_from_stream(target, validate_domain, stream.into_inner())
589+
RawClient::new_ssl_from_stream(target, validate_domain, stream.into_inner(), auth_provider)
567590
}
568591
}
569592

@@ -600,7 +623,7 @@ impl<S: Read + Write> RawClient<S> {
600623
/// This method should be called **before** [`RawClient::negotiate_protocol_version`],
601624
/// as the initial `server.version` handshake also requires authentication
602625
/// on protected servers.
603-
pub fn with_auth(mut self, auth_provider: Option<AuthProvider>) -> Self {
626+
fn with_auth(mut self, auth_provider: Option<AuthProvider>) -> Self {
604627
self.auth_provider = auth_provider;
605628
self
606629
}
@@ -613,11 +636,8 @@ impl<S: Read + Write> RawClient<S> {
613636
/// As of Electrum Protocol v1.6 it's a mandatory step, see:
614637
/// <https://electrum-protocol.readthedocs.io/en/latest/protocol-changes.html#version-1-6>
615638
///
616-
/// **NOTE:** It's only called automatically when building the client through [`ClientType`] constructors.
617-
/// If you are building the client by [`RawClient`] constructors you MUST call this before any other request.
618-
///
619639
/// [`ClientType`]: crate::ClientType
620-
pub fn negotiate_protocol_version(self) -> Result<Self, Error> {
640+
fn negotiate_protocol_version(self) -> Result<Self, Error> {
621641
let version_range = vec![
622642
PROTOCOL_VERSION_MIN.to_string(),
623643
PROTOCOL_VERSION_MAX.to_string(),
@@ -1389,18 +1409,22 @@ mod test {
13891409
// here's an useful list of live servers: https://1209k.com/bitcoin-eye/ele.php.
13901410
const DEFAULT_TEST_ELECTRUM_SERVER: &str = "fortress.qtornado.com:443";
13911411

1392-
fn get_test_raw_client() -> RawClient<ElectrumSslStream> {
1412+
fn get_test_auth_client(
1413+
authorization_provider: Option<AuthProvider>,
1414+
) -> RawClient<ElectrumSslStream> {
13931415
let server = std::env::var("TEST_ELECTRUM_SERVER")
13941416
.unwrap_or(DEFAULT_TEST_ELECTRUM_SERVER.to_owned());
13951417

1396-
RawClient::new_ssl(&*server, false, None)
1418+
RawClient::new_ssl(&*server, false, None, authorization_provider)
13971419
.expect("should build the `RawClient` successfully!")
13981420
}
13991421

14001422
fn get_test_client() -> RawClient<ElectrumSslStream> {
1401-
get_test_raw_client()
1402-
.negotiate_protocol_version()
1403-
.expect("should negotiate `server.version` successfully!")
1423+
let server = std::env::var("TEST_ELECTRUM_SERVER")
1424+
.unwrap_or(DEFAULT_TEST_ELECTRUM_SERVER.to_owned());
1425+
1426+
RawClient::new_ssl(&*server, false, None, None)
1427+
.expect("should build the `RawClient` successfully!")
14041428
}
14051429

14061430
#[test]
@@ -1895,10 +1919,7 @@ mod test {
18951919
Some("Bearer test-token-123".to_string())
18961920
});
18971921

1898-
let client = get_test_raw_client()
1899-
.with_auth(Some(auth_provider))
1900-
.negotiate_protocol_version()
1901-
.expect("should negotiate `server.version` successfully!");
1922+
let client = get_test_auth_client(Some(auth_provider));
19021923

19031924
// Make a request - provider should be called
19041925
let _ = client.server_features();
@@ -1918,10 +1939,7 @@ mod test {
19181939
let auth_provider: AuthProvider =
19191940
Arc::new(move || Some(token_clone.read().unwrap().clone()));
19201941

1921-
let client = get_test_raw_client()
1922-
.with_auth(Some(auth_provider.clone()))
1923-
.negotiate_protocol_version()
1924-
.expect("should negotiate `server.version` successfully!");
1942+
let client = get_test_auth_client(Some(auth_provider.clone()));
19251943

19261944
// Make first request with initial token
19271945
let _ = client.server_features();
@@ -1942,10 +1960,7 @@ mod test {
19421960

19431961
let auth_provider: AuthProvider = Arc::new(|| None);
19441962

1945-
let client = get_test_raw_client()
1946-
.with_auth(Some(auth_provider))
1947-
.negotiate_protocol_version()
1948-
.expect("should negotiate `server.version` successfully!");
1963+
let client = get_test_auth_client(Some(auth_provider));
19491964

19501965
// Should still work when provider returns None
19511966
let result = client.server_features();
@@ -1958,10 +1973,7 @@ mod test {
19581973

19591974
let auth_provider: AuthProvider = Arc::new(|| Some("Bearer test".to_string()));
19601975

1961-
let client = get_test_raw_client()
1962-
.with_auth(Some(auth_provider))
1963-
.negotiate_protocol_version()
1964-
.expect("should negotiate `server.version` successfully!");
1976+
let client = get_test_auth_client(Some(auth_provider));
19651977

19661978
// Verify the provider was set
19671979
let result = client.server_features();

0 commit comments

Comments
 (0)