diff --git a/README.md b/README.md index b87d91c..3d06165 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ An asynchronous Redis client for Rust. * [LUA Scripts/Functions](hhttps://redis.io/docs/latest/develop/programmability/) support * [Cluster](https://redis.io/docs/latest/operate/oss_and_stack/management/scaling/) support (minimus supported Redis version is 6) * [Client-side caching](https://redis.io/docs/latest/develop/reference/client-side-caching/) support +* Support for AWS/GCP/other IAM auth using `with_credentials_provider` # Protocol Compatibility diff --git a/src/client/auth.rs b/src/client/auth.rs new file mode 100644 index 0000000..a89348f --- /dev/null +++ b/src/client/auth.rs @@ -0,0 +1,302 @@ +use super::{Config, IntoConfig}; +use crate::{Future, Result}; +use std::{fmt, future::Future as StdFuture, sync::Arc}; + +/// Fresh authentication material for a newly established Redis TCP session. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Credentials { + pub username: Option, + pub password: String, +} + +impl Credentials { + #[must_use] + pub fn new(username: impl Into, password: impl Into) -> Self { + Self { + username: Some(username.into()), + password: password.into(), + } + } + + #[must_use] + pub fn for_default_user(password: impl Into) -> Self { + Self { + username: None, + password: password.into(), + } + } +} + +/// Why a new TCP session is being authenticated. +#[non_exhaustive] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CredentialsReason { + InitialConnect, + Reconnect, + TopologyRefresh, +} + +/// Which kind of server socket is being authenticated. +#[non_exhaustive] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CredentialsTarget { + DataNode, + SentinelNode, +} + +/// The higher-level topology that triggered this authentication request. +#[non_exhaustive] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ServerKind { + Standalone, + Sentinel, + Cluster, +} + +/// Connection metadata passed to a [`CredentialsProvider`]. +#[non_exhaustive] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CredentialsContext { + pub host: String, + pub port: u16, + pub reason: CredentialsReason, + pub target: CredentialsTarget, + pub server_kind: ServerKind, + pub tls_enabled: bool, +} + +/// Async credentials source used to authenticate every new TCP session. +/// +/// Providers are invoked immediately before authentication on each newly established TCP session, +/// including the initial connect, reconnects, pub/sub reconnects, and cluster or sentinel node +/// connections. This makes them suitable for short-lived cloud credentials such as IAM-issued +/// Redis tokens. +/// +/// Providers own their own caching policy. If token minting is expensive, cache internally and +/// return a fresh-enough [`Credentials`] value when `rustis` asks for one. +pub trait CredentialsProvider: Send + Sync + 'static { + fn resolve(&self, context: CredentialsContext) -> Future<'_, Credentials>; +} + +/// Cloneable handle to a shared [`CredentialsProvider`]. +#[derive(Clone)] +pub struct SharedCredentialsProvider(Arc); + +impl SharedCredentialsProvider { + #[must_use] + pub fn new(provider: P) -> Self { + Self(Arc::new(provider)) + } + + pub(crate) fn resolve(&self, context: CredentialsContext) -> Future<'_, Credentials> { + self.0.resolve(context) + } +} + +impl fmt::Debug for SharedCredentialsProvider { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("SharedCredentialsProvider(..)") + } +} + +impl From

for SharedCredentialsProvider { + fn from(provider: P) -> Self { + Self::new(provider) + } +} + +impl CredentialsProvider for F +where + F: Fn(CredentialsContext) -> Fut + Send + Sync + 'static, + Fut: StdFuture> + Send + 'static, +{ + fn resolve(&self, context: CredentialsContext) -> Future<'_, Credentials> { + Box::pin((self)(context)) + } +} + +/// Wrap an async closure into a [`SharedCredentialsProvider`]. +#[must_use] +pub fn credentials_provider_fn(f: F) -> SharedCredentialsProvider +where + F: Fn(CredentialsContext) -> Fut + Send + Sync + 'static, + Fut: StdFuture> + Send + 'static, +{ + SharedCredentialsProvider::new(f) +} + +/// `Internal Use` +/// +/// Connection inputs after resolving an [`IntoConfig`] implementation and layering +/// any dynamic credentials providers on top of it. +/// +/// This type is public because it appears in the hidden +/// [`IntoConfig::into_connection_setup`](crate::client::IntoConfig::into_connection_setup) +/// plumbing, but it is not intended to be constructed, matched on, or stored directly +/// by end users. +#[doc(hidden)] +#[derive(Debug, Clone)] +pub struct ConnectionSetup { + pub(crate) config: Config, + pub(crate) credentials_provider: Option, + pub(crate) sentinel_credentials_provider: Option, +} + +impl ConnectionSetup { + #[must_use] + pub fn new(config: Config) -> Self { + Self { + config, + credentials_provider: None, + sentinel_credentials_provider: None, + } + } + + #[must_use] + pub(crate) fn with_credentials_provider(mut self, provider: SharedCredentialsProvider) -> Self { + self.credentials_provider = Some(provider); + self + } + + #[must_use] + pub(crate) fn with_sentinel_credentials_provider( + mut self, + provider: SharedCredentialsProvider, + ) -> Self { + self.sentinel_credentials_provider = Some(provider); + self + } +} + +impl IntoConfig for ConnectionSetup { + fn into_config(self) -> Result { + Ok(self.config) + } + + fn into_connection_setup(self) -> Result { + Ok(self) + } +} + +/// Wrapper returned by [`WithCredentialsProvider`] to attach dynamic auth providers +/// to any [`IntoConfig`] input. +/// +/// End users should normally obtain this type via the extension-trait methods rather +/// than naming it directly. +#[derive(Debug, Clone)] +pub struct ConfigWithCredentialsProvider { + inner: C, + credentials_provider: Option, + sentinel_credentials_provider: Option, +} + +impl ConfigWithCredentialsProvider { + fn new(inner: C) -> Self { + Self { + inner, + credentials_provider: None, + sentinel_credentials_provider: None, + } + } + + /// Use this provider for every new Redis data-node TCP session. + #[must_use] + pub fn with_credentials_provider( + mut self, + provider: impl Into, + ) -> Self { + self.credentials_provider = Some(provider.into()); + self + } + + /// Use this provider for every new Sentinel control-plane TCP session. + #[must_use] + pub fn with_sentinel_credentials_provider( + mut self, + provider: impl Into, + ) -> Self { + self.sentinel_credentials_provider = Some(provider.into()); + self + } +} + +impl IntoConfig for ConfigWithCredentialsProvider { + fn into_config(self) -> Result { + Ok(self.into_connection_setup()?.config) + } + + fn into_connection_setup(self) -> Result { + let mut setup = self.inner.into_connection_setup()?; + + if let Some(provider) = self.credentials_provider { + setup = setup.with_credentials_provider(provider); + } + + if let Some(provider) = self.sentinel_credentials_provider { + setup = setup.with_sentinel_credentials_provider(provider); + } + + Ok(setup) + } +} + +/// Extension methods for attaching dynamic credentials providers to any +/// [`IntoConfig`] input accepted by `rustis`. +/// +/// This is the public entry point for dynamic authentication. The configured provider is called +/// every time `rustis` opens a new authenticated TCP session, so it integrates with the library's +/// built-in reconnect and resubscribe behavior. +/// +/// # IAM example +/// +/// ```no_run +/// use rustis::{ +/// client::{Client, Credentials, CredentialsContext, WithCredentialsProvider}, +/// commands::ConnectionCommands, +/// Result, +/// }; +/// +/// async fn fetch_iam_token(_ctx: &CredentialsContext) -> Result { +/// todo!("Call your cloud SDK or metadata service here") +/// } +/// +/// #[tokio::main] +/// async fn main() -> Result<()> { +/// let client = Client::connect( +/// "rediss://cache.example.com:6379".with_credentials_provider(|ctx| async move { +/// let token = fetch_iam_token(&ctx).await?; +/// Ok(Credentials::for_default_user(token)) +/// }), +/// ) +/// .await?; +/// +/// let _: String = client.ping("hello").await?; +/// Ok(()) +/// } +/// ``` +pub trait WithCredentialsProvider: IntoConfig + Sized { + /// Use this provider for every new Redis data-node TCP session. + /// + /// Static [`Config`](crate::client::Config) username and password fields are still supported, + /// but when a dynamic provider is attached, its return value is used for each new session. + #[must_use] + fn with_credentials_provider( + self, + provider: impl Into, + ) -> ConfigWithCredentialsProvider { + ConfigWithCredentialsProvider::new(self).with_credentials_provider(provider) + } + + /// Use this provider for every new Sentinel control-plane TCP session. + /// + /// This is useful when sentinel discovery and Redis data nodes use different credentials. + #[must_use] + fn with_sentinel_credentials_provider( + self, + provider: impl Into, + ) -> ConfigWithCredentialsProvider { + ConfigWithCredentialsProvider::new(self).with_sentinel_credentials_provider(provider) + } +} + +impl WithCredentialsProvider for T {} diff --git a/src/client/client.rs b/src/client/client.rs index dcf3070..233fb3b 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -68,11 +68,11 @@ impl Client { /// Any Redis driver [`Error`](crate::Error) that occurs during the connection operation #[inline] pub async fn connect(config: impl IntoConfig) -> Result { - let config = config.into_config()?; - let command_timeout = config.command_timeout; - let retry_on_error = config.retry_on_error; + let setup = config.into_connection_setup()?; + let command_timeout = setup.config.command_timeout; + let retry_on_error = setup.config.retry_on_error; let (msg_sender, network_task_join_handle, reconnect_sender, connection_tag) = - NetworkHandler::connect(config.into_config()?).await?; + NetworkHandler::connect(setup).await?; Ok(Self { msg_sender: Arc::new(Some(msg_sender)), diff --git a/src/client/config.rs b/src/client/config.rs index c6fdf8c..95b8655 100644 --- a/src/client/config.rs +++ b/src/client/config.rs @@ -1,3 +1,4 @@ +use super::{ConnectionSetup, ServerKind}; use crate::{ClientError, Error, Result}; #[cfg(feature = "native-tls")] use native_tls::{Certificate, Identity, Protocol, TlsConnector, TlsConnectorBuilder}; @@ -150,6 +151,26 @@ impl Config { Self::from_str(uri.as_str()) } + pub(crate) fn server_kind(&self) -> ServerKind { + match &self.server { + ServerConfig::Standalone { .. } => ServerKind::Standalone, + ServerConfig::Sentinel(_) => ServerKind::Sentinel, + ServerConfig::Cluster(_) => ServerKind::Cluster, + } + } + + pub(crate) fn tls_enabled(&self) -> bool { + #[cfg(any(feature = "native-tls", feature = "rustls"))] + { + self.tls_config.is_some() + } + + #[cfg(not(any(feature = "native-tls", feature = "rustls")))] + { + false + } + } + /// Parse address in the standard formart `host`:`port` fn parse_addr(str: &str) -> Option<(&str, u16)> { let mut iter = str.split(':'); @@ -848,6 +869,14 @@ impl TlsConfig { pub trait IntoConfig { /// Converts this type into a [`Config`](crate::client::Config). fn into_config(self) -> Result; + + #[doc(hidden)] + fn into_connection_setup(self) -> Result + where + Self: Sized, + { + Ok(ConnectionSetup::new(self.into_config()?)) + } } impl IntoConfig for Config { diff --git a/src/client/mod.rs b/src/client/mod.rs index 1669dc2..0ccd559 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -120,6 +120,52 @@ A [`Client`] instance can be configured with the [`Config`] struct: * [`String`](https://doc.rust-lang.org/alloc/string/struct.String.html): host and port separated by a colon * [`Url`](https://docs.rs/url/latest/url/struct.Url.html): see Url syntax below. +## Dynamic Authentication + +For deployments that use short-lived credentials, such as cloud IAM auth, attach an async +credentials provider with [`WithCredentialsProvider::with_credentials_provider`]. + +The provider is called each time `rustis` establishes a new authenticated TCP session: +* initial connect +* reconnect +* pub/sub reconnect +* cluster node connect +* sentinel node connect + +The provider returns fresh [`Credentials`](crate::client::Credentials) immediately before +authentication. Any token caching policy stays inside the provider implementation. + +### IAM Example + +```no_run +use rustis::{ + client::{Client, Credentials, CredentialsContext, WithCredentialsProvider}, + commands::ConnectionCommands, + Result, +}; + +async fn fetch_iam_token(_ctx: &CredentialsContext) -> Result { + todo!("Call your cloud SDK or metadata service here") +} + +#[tokio::main] +async fn main() -> Result<()> { + let client = Client::connect( + "rediss://cache.example.com:6379".with_credentials_provider(|ctx| async move { + let token = fetch_iam_token(&ctx).await?; + Ok(Credentials::for_default_user(token)) + }), + ) + .await?; + + let _: String = client.ping("hello").await?; + Ok(()) +} +``` + +If Sentinel uses different credentials from the Redis data node, chain +[`WithCredentialsProvider::with_sentinel_credentials_provider`] as well. + ## Url Syntax The **rustis** [`Config`] can also be built from an URL @@ -414,6 +460,7 @@ async fn main() -> Result<()> { */ #[allow(clippy::module_inception)] +mod auth; mod client; mod client_tracking_invalidation_stream; mod config; @@ -427,6 +474,7 @@ mod prepared_command; mod pub_sub_stream; mod transaction; +pub use auth::*; pub use client::*; pub(crate) use client_tracking_invalidation_stream::*; pub use config::*; diff --git a/src/client/pooled_client_manager.rs b/src/client/pooled_client_manager.rs index 357424d..4f410df 100644 --- a/src/client/pooled_client_manager.rs +++ b/src/client/pooled_client_manager.rs @@ -1,19 +1,19 @@ use crate::{ Error, Result, - client::{Client, Config, IntoConfig}, + client::{Client, ConnectionSetup, IntoConfig}, commands::ConnectionCommands, }; use bb8::ManageConnection; /// An object which manages a pool of clients, based on [bb8](https://docs.rs/bb8/latest/bb8/) pub struct PooledClientManager { - config: Config, + setup: ConnectionSetup, } impl PooledClientManager { pub fn new(config: impl IntoConfig) -> Result { Ok(Self { - config: config.into_config()?, + setup: config.into_connection_setup()?, }) } } @@ -23,8 +23,7 @@ impl ManageConnection for PooledClientManager { type Error = Error; async fn connect(&self) -> Result { - let config = self.config.clone(); - Client::connect(config).await + Client::connect(self.setup.clone()).await } async fn is_valid(&self, client: &mut Client) -> Result<()> { diff --git a/src/network/cluster_connection.rs b/src/network/cluster_connection.rs index e1342bb..a30f6b1 100644 --- a/src/network/cluster_connection.rs +++ b/src/network/cluster_connection.rs @@ -1,6 +1,8 @@ use crate::{ ClientError, Error, RedisError, RedisErrorKind, Result, RetryReason, StandaloneConnection, - client::{ClusterConfig, Config}, + client::{ + ClusterConfig, Config, CredentialsReason, CredentialsTarget, SharedCredentialsProvider, + }, commands::{ ClusterCommands, ClusterHealthStatus, ClusterNodeResult, ClusterShardResult, LegacyClusterShardResult, RequestPolicy, ResponsePolicy, @@ -118,6 +120,7 @@ impl ClusterNodeResult { pub struct ClusterConnection { cluster_config: ClusterConfig, config: Config, + credentials_provider: Option, nodes: Vec, slot_ranges: Vec, pending_requests: VecDeque, @@ -130,8 +133,15 @@ impl ClusterConnection { pub async fn connect( cluster_config: &ClusterConfig, config: &Config, + credentials_provider: Option, ) -> Result { - let (mut nodes, slot_ranges) = Self::connect_to_cluster(cluster_config, config).await?; + let (mut nodes, slot_ranges) = Self::connect_to_cluster( + cluster_config, + config, + credentials_provider.clone(), + CredentialsReason::InitialConnect, + ) + .await?; let first_node = nodes .get_mut(0) .ok_or_else(|| Error::Client(ClientError::ClusterConfig))?; @@ -141,6 +151,7 @@ impl ClusterConnection { Ok(ClusterConnection { cluster_config: cluster_config.clone(), config: config.clone(), + credentials_provider, nodes, slot_ranges, pending_requests: VecDeque::new(), @@ -810,8 +821,13 @@ impl ClusterConnection { pub async fn reconnect(&mut self) -> Result<()> { info!("[{}] Reconnecting to cluster...", self.tag); - let (nodes, slot_ranges) = - Self::connect_to_cluster(&self.cluster_config, &self.config).await?; + let (nodes, slot_ranges) = Self::connect_to_cluster( + &self.cluster_config, + &self.config, + self.credentials_provider.clone(), + CredentialsReason::Reconnect, + ) + .await?; info!("[{}] Reconnected to cluster!", self.tag); self.nodes = nodes; @@ -825,13 +841,24 @@ impl ClusterConnection { async fn connect_to_cluster( cluster_config: &ClusterConfig, config: &Config, + credentials_provider: Option, + credentials_reason: CredentialsReason, ) -> Result<(Vec, Vec)> { debug!("Discovering cluster shard and slots..."); let mut shard_info_list: Option> = None; for node_config in &cluster_config.nodes { - match StandaloneConnection::connect(&node_config.0, node_config.1, config).await { + match StandaloneConnection::connect_with_context( + &node_config.0, + node_config.1, + config, + credentials_provider.clone(), + credentials_reason, + CredentialsTarget::DataNode, + ) + .await + { Ok(mut connection) => { let version: Result = connection.get_version().try_into(); let Ok(version) = version else { @@ -895,7 +922,15 @@ impl ClusterConnection { let port = master_info.get_port()?; - let connection = StandaloneConnection::connect(&master_info.ip, port, config).await?; + let connection = StandaloneConnection::connect_with_context( + &master_info.ip, + port, + config, + credentials_provider.clone(), + credentials_reason, + CredentialsTarget::DataNode, + ) + .await?; slot_ranges.extend(shard_info.slots.iter().map(|s| SlotRange { slot_range: *s, @@ -937,8 +972,15 @@ impl ClusterConnection { let port = node_info.get_port()?; let node_id: NodeId = node_info.id.as_str().into(); - let connection = - StandaloneConnection::connect(&node_info.ip, port, &self.config).await?; + let connection = StandaloneConnection::connect_with_context( + &node_info.ip, + port, + &self.config, + self.credentials_provider.clone(), + CredentialsReason::TopologyRefresh, + CredentialsTarget::DataNode, + ) + .await?; for slot_range_info in &shard_info.slots { if let Some(slot_range) = self.get_slot_range_by_slot_mut(slot_range_info.0) @@ -1032,8 +1074,15 @@ impl ClusterConnection { // add missing node let port = node_info.get_port()?; - let connection = - StandaloneConnection::connect(&node_info.ip, port, &self.config).await?; + let connection = StandaloneConnection::connect_with_context( + &node_info.ip, + port, + &self.config, + self.credentials_provider.clone(), + CredentialsReason::TopologyRefresh, + CredentialsTarget::DataNode, + ) + .await?; self.nodes.push(Node { id: node_id, diff --git a/src/network/connection.rs b/src/network/connection.rs index ebd0829..d1b2b59 100644 --- a/src/network/connection.rs +++ b/src/network/connection.rs @@ -1,7 +1,7 @@ use crate::{ ClusterConnection, Error, Future, Result, RetryReason, SentinelConnection, StandaloneConnection, - client::{Config, PreparedCommand, ServerConfig}, + client::{ConnectionSetup, PreparedCommand, ServerConfig}, commands::InternalPubSubCommands, resp::{Command, RespResponse}, }; @@ -17,16 +17,33 @@ pub enum Connection { impl Connection { #[inline] - pub async fn connect(config: Config) -> Result { - match &config.server { + pub async fn connect(setup: ConnectionSetup) -> Result { + match setup.config.server.clone() { ServerConfig::Standalone { host, port } => Ok(Connection::Standalone( - StandaloneConnection::connect(host, *port, &config).await?, + StandaloneConnection::connect( + &host, + port, + &setup.config, + setup.credentials_provider.clone(), + ) + .await?, )), ServerConfig::Sentinel(sentinel_config) => Ok(Connection::Sentinel( - SentinelConnection::connect(sentinel_config, &config).await?, + SentinelConnection::connect( + &sentinel_config, + &setup.config, + setup.credentials_provider.clone(), + setup.sentinel_credentials_provider.clone(), + ) + .await?, )), ServerConfig::Cluster(cluster_config) => Ok(Connection::Cluster( - ClusterConnection::connect(cluster_config, &config).await?, + ClusterConnection::connect( + &cluster_config, + &setup.config, + setup.credentials_provider.clone(), + ) + .await?, )), } } diff --git a/src/network/network_handler.rs b/src/network/network_handler.rs index 6e3a37c..4999046 100644 --- a/src/network/network_handler.rs +++ b/src/network/network_handler.rs @@ -1,7 +1,7 @@ use super::pub_sub_message::PubSubMessage; use crate::{ ClientError, Connection, Error, JoinHandle, ReconnectionState, Result, RetryReason, - client::{Config, Message, MessageKind}, + client::{ConnectionSetup, Message, MessageKind}, commands::InternalPubSubCommands, resp::{ClientReplyMode, CommandKind, RespResponse, SubscriptionType, cmd}, spawn, timeout, @@ -107,14 +107,14 @@ pub(crate) struct NetworkHandler { impl NetworkHandler { pub async fn connect( - config: Config, + setup: ConnectionSetup, ) -> Result<(MsgSender, JoinHandle<()>, ReconnectSender, Arc)> { // options - let auto_resubscribe = config.auto_resubscribe; - let auto_remonitor = config.auto_remonitor; - let reconnection_config = config.reconnection.clone(); + let auto_resubscribe = setup.config.auto_resubscribe; + let auto_remonitor = setup.config.auto_remonitor; + let reconnection_config = setup.config.reconnection.clone(); - let connection = Connection::connect(config).await?; + let connection = Connection::connect(setup).await?; let (msg_sender, msg_receiver): (MsgSender, MsgReceiver) = mpsc::unbounded(); let (reconnect_sender, _): (ReconnectSender, ReconnectReceiver) = broadcast::channel(32); let tag = connection.tag().to_owned(); diff --git a/src/network/sentinel_connection.rs b/src/network/sentinel_connection.rs index 4ecaca8..df839e4 100644 --- a/src/network/sentinel_connection.rs +++ b/src/network/sentinel_connection.rs @@ -1,6 +1,8 @@ use crate::{ Error, Result, RetryReason, StandaloneConnection, - client::{Config, SentinelConfig}, + client::{ + Config, CredentialsReason, CredentialsTarget, SentinelConfig, SharedCredentialsProvider, + }, commands::{RoleResult, SentinelCommands, ServerCommands}, resp::{Command, RespResponse}, sleep, @@ -11,6 +13,8 @@ use std::{sync::Arc, task::Poll}; pub struct SentinelConnection { sentinel_config: SentinelConfig, config: Config, + credentials_provider: Option, + sentinel_credentials_provider: Option, pub inner_connection: StandaloneConnection, } @@ -37,8 +41,14 @@ impl SentinelConnection { #[inline] pub async fn reconnect(&mut self) -> Result<()> { - self.inner_connection = - Self::connect_to_sentinel(&self.sentinel_config, &self.config).await?; + self.inner_connection = Self::connect_to_sentinel( + &self.sentinel_config, + &self.config, + self.credentials_provider.clone(), + self.sentinel_credentials_provider.clone(), + CredentialsReason::Reconnect, + ) + .await?; Ok(()) } @@ -52,12 +62,23 @@ impl SentinelConnection { pub async fn connect( sentinel_config: &SentinelConfig, config: &Config, + credentials_provider: Option, + sentinel_credentials_provider: Option, ) -> Result { - let inner_connection = Self::connect_to_sentinel(sentinel_config, config).await?; + let inner_connection = Self::connect_to_sentinel( + sentinel_config, + config, + credentials_provider.clone(), + sentinel_credentials_provider.clone(), + CredentialsReason::InitialConnect, + ) + .await?; Ok(SentinelConnection { sentinel_config: sentinel_config.clone(), config: config.clone(), + credentials_provider, + sentinel_credentials_provider, inner_connection, }) } @@ -65,6 +86,9 @@ impl SentinelConnection { async fn connect_to_sentinel( sentinel_config: &SentinelConfig, config: &Config, + credentials_provider: Option, + sentinel_credentials_provider: Option, + credentials_reason: CredentialsReason, ) -> Result { let mut restart = false; let mut unreachable_sentinel = true; @@ -82,14 +106,22 @@ impl SentinelConnection { // Step 1: connecting to Sentinel let (host, port) = sentinel_instance; - let mut sentinel_connection = - match StandaloneConnection::connect(host, *port, &sentinel_node_config).await { - Ok(sentinel_connection) => sentinel_connection, - Err(e) => { - debug!("Cannot connect to Sentinel {}:{} : {}", *host, *port, e); - continue; - } - }; + let mut sentinel_connection = match StandaloneConnection::connect_with_context( + host, + *port, + &sentinel_node_config, + sentinel_credentials_provider.clone(), + credentials_reason, + CredentialsTarget::SentinelNode, + ) + .await + { + Ok(sentinel_connection) => sentinel_connection, + Err(e) => { + debug!("Cannot connect to Sentinel {}:{} : {}", *host, *port, e); + continue; + } + }; // Step 2: ask for master address let (master_host, master_port) = match sentinel_connection @@ -115,8 +147,15 @@ impl SentinelConnection { }; // Step 3: call the ROLE command in the target instance - let mut master_connection = - StandaloneConnection::connect(&master_host, master_port, config).await?; + let mut master_connection = StandaloneConnection::connect_with_context( + &master_host, + master_port, + config, + credentials_provider.clone(), + credentials_reason, + CredentialsTarget::DataNode, + ) + .await?; let role: RoleResult = master_connection.role().await?; diff --git a/src/network/standalone_connection.rs b/src/network/standalone_connection.rs index fe2367e..9841d9a 100644 --- a/src/network/standalone_connection.rs +++ b/src/network/standalone_connection.rs @@ -1,6 +1,9 @@ use crate::{ Error, Future, Result, RetryReason, TcpStreamReader, TcpStreamWriter, - client::{Config, PreparedCommand}, + client::{ + Config, Credentials, CredentialsContext, CredentialsReason, CredentialsTarget, + PreparedCommand, SharedCredentialsProvider, + }, commands::{ ClusterCommands, ConnectionCommands, HelloOptions, SentinelCommands, ServerCommands, }, @@ -61,19 +64,47 @@ pub struct StandaloneConnection { host: String, port: u16, config: Config, + credentials_provider: Option, + credentials_target: CredentialsTarget, streams: Streams, version: String, tag: Arc, } impl StandaloneConnection { - pub async fn connect(host: &str, port: u16, config: &Config) -> Result { + pub async fn connect( + host: &str, + port: u16, + config: &Config, + credentials_provider: Option, + ) -> Result { + Self::connect_with_context( + host, + port, + config, + credentials_provider, + CredentialsReason::InitialConnect, + CredentialsTarget::DataNode, + ) + .await + } + + pub async fn connect_with_context( + host: &str, + port: u16, + config: &Config, + credentials_provider: Option, + credentials_reason: CredentialsReason, + credentials_target: CredentialsTarget, + ) -> Result { let streams = Streams::connect(host, port, config).await?; let mut connection = Self { host: host.to_owned(), port, config: config.clone(), + credentials_provider, + credentials_target, streams, version: String::new(), tag: if config.connection_name.is_empty() { @@ -83,7 +114,7 @@ impl StandaloneConnection { }, }; - connection.post_connect().await?; + connection.post_connect(credentials_reason).await?; Ok(connection) } @@ -105,8 +136,15 @@ impl StandaloneConnection { let client_id = self.client_id().await?; let mut config = self.config.clone(); "killer".clone_into(&mut config.connection_name); - let mut connection = - StandaloneConnection::connect(&self.host, self.port, &config).await?; + let mut connection = StandaloneConnection::connect_with_context( + &self.host, + self.port, + &config, + self.credentials_provider.clone(), + CredentialsReason::InitialConnect, + self.credentials_target, + ) + .await?; connection .client_kill(crate::commands::ClientKillOptions::default().id(client_id)) .await?; @@ -179,27 +217,23 @@ impl StandaloneConnection { pub async fn reconnect(&mut self) -> Result<()> { self.streams = Streams::connect(&self.host, self.port, &self.config).await?; - self.post_connect().await?; + self.post_connect(CredentialsReason::Reconnect).await?; Ok(()) } - async fn post_connect(&mut self) -> Result<()> { + async fn post_connect(&mut self, credentials_reason: CredentialsReason) -> Result<()> { // RESP3 let mut hello_options = HelloOptions::new(3); - let config_username = self.config.username.clone(); - let config_password = self.config.password.clone(); let config_connection_name = self.config.connection_name.clone(); + let resolved_credentials = self.resolve_credentials(credentials_reason).await?; // authentication - if let Some(password) = &config_password { + if let Some(credentials) = &resolved_credentials { hello_options = hello_options.auth( - match &config_username { - Some(username) => username, - None => "default", - }, - password, + credentials.username.as_deref().unwrap_or("default"), + &credentials.password, ); } @@ -219,6 +253,26 @@ impl StandaloneConnection { Ok(()) } + async fn resolve_credentials(&self, reason: CredentialsReason) -> Result> { + if let Some(provider) = &self.credentials_provider { + let context = CredentialsContext { + host: self.host.clone(), + port: self.port, + reason, + target: self.credentials_target, + server_kind: self.config.server_kind(), + tls_enabled: self.config.tls_enabled(), + }; + + return Ok(Some(provider.resolve(context).await?)); + } + + Ok(self.config.password.as_ref().map(|password| Credentials { + username: self.config.username.clone(), + password: password.clone(), + })) + } + pub fn get_version(&self) -> &str { &self.version } diff --git a/src/tests/client.rs b/src/tests/client.rs index 9ed3360..c6e541f 100644 --- a/src/tests/client.rs +++ b/src/tests/client.rs @@ -1,8 +1,15 @@ +use std::sync::{ + Arc, Mutex, + atomic::{AtomicUsize, Ordering}, +}; use std::time::Duration; use crate::{ Error, Result, - client::{Client, IntoConfig}, + client::{ + Client, Credentials, CredentialsReason, CredentialsTarget, IntoConfig, ServerKind, + WithCredentialsProvider, + }, commands::{ BlockingCommands, ClientKillOptions, ConnectionCommands, FlushingMode, LMoveWhere, ListCommands, ServerCommands, StringCommands, @@ -68,6 +75,59 @@ async fn on_reconnect() -> Result<()> { Ok(()) } +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +#[serial] +async fn dynamic_credentials_provider_is_used_for_connect_and_reconnect() -> Result<()> { + let admin = get_test_client().await?; + admin.config_set(("requirepass", "pwd")).await?; + + let addr = get_default_addr(); + let calls = Arc::new(AtomicUsize::new(0)); + let contexts = Arc::new(Mutex::new(Vec::new())); + let calls_clone = calls.clone(); + let contexts_clone = contexts.clone(); + + let config = addr.clone().with_credentials_provider(move |ctx| { + let calls_clone = calls_clone.clone(); + let contexts_clone = contexts_clone.clone(); + + async move { + calls_clone.fetch_add(1, Ordering::SeqCst); + contexts_clone.lock().unwrap().push(ctx); + Ok(Credentials::for_default_user("pwd")) + } + }); + + let client = Client::connect(config).await?; + assert_eq!(1, calls.load(Ordering::SeqCst)); + + let killer = Client::connect(format!("redis://:pwd@{}", addr)).await?; + let client_id = client.client_id().await?; + killer + .client_kill(ClientKillOptions::default().id(client_id)) + .await?; + + client + .set("provider:key", "value") + .retry_on_error(true) + .await?; + + let recorded = contexts.lock().unwrap().clone(); + assert_eq!(2, recorded.len()); + assert_eq!(CredentialsReason::InitialConnect, recorded[0].reason); + assert_eq!(CredentialsReason::Reconnect, recorded[1].reason); + assert_eq!(CredentialsTarget::DataNode, recorded[0].target); + assert_eq!(ServerKind::Standalone, recorded[0].server_kind); + + client.config_set(("requirepass", "")).await?; + killer.close().await?; + client.close().await?; + admin.close().await?; + + Ok(()) +} + #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] #[serial] diff --git a/src/tests/config.rs b/src/tests/config.rs index 4d2ba5c..a8a06ed 100644 --- a/src/tests/config.rs +++ b/src/tests/config.rs @@ -1,6 +1,6 @@ use crate::{ Result, - client::{Client, IntoConfig}, + client::{Client, Credentials, IntoConfig, WithCredentialsProvider}, commands::{ClientKillOptions, ConnectionCommands, FlushingMode, ServerCommands}, tests::{get_default_host, get_default_port, get_test_client, log_try_init}, }; @@ -240,6 +240,25 @@ fn into_config() -> Result<()> { Ok(()) } +#[test] +fn credentials_provider_wrappers_layer_without_changing_config() -> Result<()> { + let setup = "redis+sentinel://127.0.0.1:6379/myservice" + .with_credentials_provider(|_| async { Ok(Credentials::for_default_user("data-pwd")) }) + .with_sentinel_credentials_provider(|_| async { + Ok(Credentials::new("sentinel-user", "sentinel-pwd")) + }) + .into_connection_setup()?; + + assert_eq!( + "redis+sentinel://127.0.0.1:6379/myservice", + setup.config.to_string() + ); + assert!(setup.credentials_provider.is_some()); + assert!(setup.sentinel_credentials_provider.is_some()); + + Ok(()) +} + #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] #[serial]