Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ An asynchronous Redis client for Rust.
* [LUA Scripts/Functions](hhttps://redis.io/docs/latest/develop/programmability/) support
* [Cluster](https://redis.io/docs/latest/operate/oss_and_stack/management/scaling/) support (minimus supported Redis version is 6)
* [Client-side caching](https://redis.io/docs/latest/develop/reference/client-side-caching/) support
* Support for AWS/GCP/other IAM auth using `with_credentials_provider`

# Protocol Compatibility

Expand Down
302 changes: 302 additions & 0 deletions src/client/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
use super::{Config, IntoConfig};
use crate::{Future, Result};
use std::{fmt, future::Future as StdFuture, sync::Arc};

/// Fresh authentication material for a newly established Redis TCP session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Credentials {
pub username: Option<String>,
pub password: String,
}

impl Credentials {
#[must_use]
pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
Self {
username: Some(username.into()),
password: password.into(),
}
}

#[must_use]
pub fn for_default_user(password: impl Into<String>) -> Self {
Self {
username: None,
password: password.into(),
}
}
}

/// Why a new TCP session is being authenticated.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CredentialsReason {
InitialConnect,
Reconnect,
TopologyRefresh,
}

/// Which kind of server socket is being authenticated.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CredentialsTarget {
DataNode,
SentinelNode,
}

/// The higher-level topology that triggered this authentication request.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerKind {
Standalone,
Sentinel,
Cluster,
}

/// Connection metadata passed to a [`CredentialsProvider`].
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CredentialsContext {
pub host: String,
pub port: u16,
pub reason: CredentialsReason,
pub target: CredentialsTarget,
pub server_kind: ServerKind,
pub tls_enabled: bool,
}

/// Async credentials source used to authenticate every new TCP session.
///
/// Providers are invoked immediately before authentication on each newly established TCP session,
/// including the initial connect, reconnects, pub/sub reconnects, and cluster or sentinel node
/// connections. This makes them suitable for short-lived cloud credentials such as IAM-issued
/// Redis tokens.
///
/// Providers own their own caching policy. If token minting is expensive, cache internally and
/// return a fresh-enough [`Credentials`] value when `rustis` asks for one.
pub trait CredentialsProvider: Send + Sync + 'static {
fn resolve(&self, context: CredentialsContext) -> Future<'_, Credentials>;
}

/// Cloneable handle to a shared [`CredentialsProvider`].
#[derive(Clone)]
pub struct SharedCredentialsProvider(Arc<dyn CredentialsProvider>);

impl SharedCredentialsProvider {
#[must_use]
pub fn new<P: CredentialsProvider>(provider: P) -> Self {
Self(Arc::new(provider))
}

pub(crate) fn resolve(&self, context: CredentialsContext) -> Future<'_, Credentials> {
self.0.resolve(context)
}
}

impl fmt::Debug for SharedCredentialsProvider {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("SharedCredentialsProvider(..)")
}
}

impl<P: CredentialsProvider> From<P> for SharedCredentialsProvider {
fn from(provider: P) -> Self {
Self::new(provider)
}
}

impl<F, Fut> CredentialsProvider for F
where
F: Fn(CredentialsContext) -> Fut + Send + Sync + 'static,
Fut: StdFuture<Output = Result<Credentials>> + Send + 'static,
{
fn resolve(&self, context: CredentialsContext) -> Future<'_, Credentials> {
Box::pin((self)(context))
}
}

/// Wrap an async closure into a [`SharedCredentialsProvider`].
#[must_use]
pub fn credentials_provider_fn<F, Fut>(f: F) -> SharedCredentialsProvider
where
F: Fn(CredentialsContext) -> Fut + Send + Sync + 'static,
Fut: StdFuture<Output = Result<Credentials>> + Send + 'static,
{
SharedCredentialsProvider::new(f)
}

/// `Internal Use`
///
/// Connection inputs after resolving an [`IntoConfig`] implementation and layering
/// any dynamic credentials providers on top of it.
///
/// This type is public because it appears in the hidden
/// [`IntoConfig::into_connection_setup`](crate::client::IntoConfig::into_connection_setup)
/// plumbing, but it is not intended to be constructed, matched on, or stored directly
/// by end users.
#[doc(hidden)]
#[derive(Debug, Clone)]
pub struct ConnectionSetup {
pub(crate) config: Config,
pub(crate) credentials_provider: Option<SharedCredentialsProvider>,
pub(crate) sentinel_credentials_provider: Option<SharedCredentialsProvider>,
}

impl ConnectionSetup {
#[must_use]
pub fn new(config: Config) -> Self {
Self {
config,
credentials_provider: None,
sentinel_credentials_provider: None,
}
}

#[must_use]
pub(crate) fn with_credentials_provider(mut self, provider: SharedCredentialsProvider) -> Self {
self.credentials_provider = Some(provider);
self
}

#[must_use]
pub(crate) fn with_sentinel_credentials_provider(
mut self,
provider: SharedCredentialsProvider,
) -> Self {
self.sentinel_credentials_provider = Some(provider);
self
}
}

impl IntoConfig for ConnectionSetup {
fn into_config(self) -> Result<Config> {
Ok(self.config)
}

fn into_connection_setup(self) -> Result<ConnectionSetup> {
Ok(self)
}
}

/// Wrapper returned by [`WithCredentialsProvider`] to attach dynamic auth providers
/// to any [`IntoConfig`] input.
///
/// End users should normally obtain this type via the extension-trait methods rather
/// than naming it directly.
#[derive(Debug, Clone)]
pub struct ConfigWithCredentialsProvider<C> {
inner: C,
credentials_provider: Option<SharedCredentialsProvider>,
sentinel_credentials_provider: Option<SharedCredentialsProvider>,
}

impl<C> ConfigWithCredentialsProvider<C> {
fn new(inner: C) -> Self {
Self {
inner,
credentials_provider: None,
sentinel_credentials_provider: None,
}
}

/// Use this provider for every new Redis data-node TCP session.
#[must_use]
pub fn with_credentials_provider(
mut self,
provider: impl Into<SharedCredentialsProvider>,
) -> Self {
self.credentials_provider = Some(provider.into());
self
}

/// Use this provider for every new Sentinel control-plane TCP session.
#[must_use]
pub fn with_sentinel_credentials_provider(
mut self,
provider: impl Into<SharedCredentialsProvider>,
) -> Self {
self.sentinel_credentials_provider = Some(provider.into());
self
}
}

impl<C: IntoConfig> IntoConfig for ConfigWithCredentialsProvider<C> {
fn into_config(self) -> Result<Config> {
Ok(self.into_connection_setup()?.config)
}

fn into_connection_setup(self) -> Result<ConnectionSetup> {
let mut setup = self.inner.into_connection_setup()?;

if let Some(provider) = self.credentials_provider {
setup = setup.with_credentials_provider(provider);
}

if let Some(provider) = self.sentinel_credentials_provider {
setup = setup.with_sentinel_credentials_provider(provider);
}

Ok(setup)
}
}

/// Extension methods for attaching dynamic credentials providers to any
/// [`IntoConfig`] input accepted by `rustis`.
///
/// This is the public entry point for dynamic authentication. The configured provider is called
/// every time `rustis` opens a new authenticated TCP session, so it integrates with the library's
/// built-in reconnect and resubscribe behavior.
///
/// # IAM example
///
/// ```no_run
/// use rustis::{
/// client::{Client, Credentials, CredentialsContext, WithCredentialsProvider},
/// commands::ConnectionCommands,
/// Result,
/// };
///
/// async fn fetch_iam_token(_ctx: &CredentialsContext) -> Result<String> {
/// todo!("Call your cloud SDK or metadata service here")
/// }
///
/// #[tokio::main]
/// async fn main() -> Result<()> {
/// let client = Client::connect(
/// "rediss://cache.example.com:6379".with_credentials_provider(|ctx| async move {
/// let token = fetch_iam_token(&ctx).await?;
/// Ok(Credentials::for_default_user(token))
/// }),
/// )
/// .await?;
///
/// let _: String = client.ping("hello").await?;
/// Ok(())
/// }
/// ```
pub trait WithCredentialsProvider: IntoConfig + Sized {
/// Use this provider for every new Redis data-node TCP session.
///
/// Static [`Config`](crate::client::Config) username and password fields are still supported,
/// but when a dynamic provider is attached, its return value is used for each new session.
#[must_use]
fn with_credentials_provider(
self,
provider: impl Into<SharedCredentialsProvider>,
) -> ConfigWithCredentialsProvider<Self> {
ConfigWithCredentialsProvider::new(self).with_credentials_provider(provider)
}

/// Use this provider for every new Sentinel control-plane TCP session.
///
/// This is useful when sentinel discovery and Redis data nodes use different credentials.
#[must_use]
fn with_sentinel_credentials_provider(
self,
provider: impl Into<SharedCredentialsProvider>,
) -> ConfigWithCredentialsProvider<Self> {
ConfigWithCredentialsProvider::new(self).with_sentinel_credentials_provider(provider)
}
}

impl<T: IntoConfig + Sized> WithCredentialsProvider for T {}
8 changes: 4 additions & 4 deletions src/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ impl Client {
/// Any Redis driver [`Error`](crate::Error) that occurs during the connection operation
#[inline]
pub async fn connect(config: impl IntoConfig) -> Result<Self> {
let config = config.into_config()?;
let command_timeout = config.command_timeout;
let retry_on_error = config.retry_on_error;
let setup = config.into_connection_setup()?;
let command_timeout = setup.config.command_timeout;
let retry_on_error = setup.config.retry_on_error;
let (msg_sender, network_task_join_handle, reconnect_sender, connection_tag) =
NetworkHandler::connect(config.into_config()?).await?;
NetworkHandler::connect(setup).await?;

Ok(Self {
msg_sender: Arc::new(Some(msg_sender)),
Expand Down
29 changes: 29 additions & 0 deletions src/client/config.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::{ConnectionSetup, ServerKind};
use crate::{ClientError, Error, Result};
#[cfg(feature = "native-tls")]
use native_tls::{Certificate, Identity, Protocol, TlsConnector, TlsConnectorBuilder};
Expand Down Expand Up @@ -150,6 +151,26 @@ impl Config {
Self::from_str(uri.as_str())
}

pub(crate) fn server_kind(&self) -> ServerKind {
match &self.server {
ServerConfig::Standalone { .. } => ServerKind::Standalone,
ServerConfig::Sentinel(_) => ServerKind::Sentinel,
ServerConfig::Cluster(_) => ServerKind::Cluster,
}
}

pub(crate) fn tls_enabled(&self) -> bool {
#[cfg(any(feature = "native-tls", feature = "rustls"))]
{
self.tls_config.is_some()
}

#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
{
false
}
}

/// Parse address in the standard formart `host`:`port`
fn parse_addr(str: &str) -> Option<(&str, u16)> {
let mut iter = str.split(':');
Expand Down Expand Up @@ -848,6 +869,14 @@ impl TlsConfig {
pub trait IntoConfig {
/// Converts this type into a [`Config`](crate::client::Config).
fn into_config(self) -> Result<Config>;

#[doc(hidden)]
fn into_connection_setup(self) -> Result<ConnectionSetup>
where
Self: Sized,
{
Ok(ConnectionSetup::new(self.into_config()?))
}
}

impl IntoConfig for Config {
Expand Down
Loading