Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions packages/core/shield/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ use async_trait::async_trait;
use serde::{Serialize, de::DeserializeOwned};

use crate::{
ErasedMethodAction,
action::MethodAction,
action::{ErasedMethodAction, MethodAction},
error::{SessionError, ShieldError},
provider::Provider,
};

#[async_trait]
pub trait Method: Send + Sync {
type Provider: Provider;
type Connection;
type Session: DeserializeOwned + Serialize;

fn id(&self) -> String;
Expand Down Expand Up @@ -40,6 +40,12 @@ pub trait Method: Send + Sync {
.into_iter()
.find(|provider| provider.id().as_deref() == provider_id))
}

async fn user_connections(
&self,
user: &str,
provider_id: Option<&str>,
) -> Result<Vec<Self::Connection>, ShieldError>;
}

#[async_trait]
Expand All @@ -59,6 +65,12 @@ pub trait ErasedMethod: Send + Sync {
provider_id: Option<&str>,
) -> Result<Option<Box<dyn Any + Send + Sync>>, ShieldError>;

async fn erased_user_connections(
&self,
user_id: &str,
provider_id: Option<&str>,
) -> Result<Vec<Box<dyn Any + Send + Sync>>, ShieldError>;

fn erased_deserialize_session(
&self,
value: Option<&str>,
Expand Down Expand Up @@ -110,6 +122,18 @@ macro_rules! erased_method {
})
}

async fn erased_user_connections(
&self,
user_id: &str,
provider_id: Option<&str>,
) -> Result<Vec<Box<dyn std::any::Any + Send + Sync>>, $crate::ShieldError> {
Ok(self.user_connections(user_id, provider_id)
.await?
.into_iter()
.map(|connection| Box::new(connection) as Box<dyn std::any::Any + Send + Sync>)
.collect())
}

fn erased_deserialize_session(
&self,
value: Option<&str>
Expand Down
28 changes: 27 additions & 1 deletion packages/core/shield/src/shield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ use utoipa::{
#[cfg(feature = "utoipa")]
use crate::path::{ActionPathParams, MethodActionPathParams};
use crate::{
SignOutAction,
action::{Action, ActionForms, ActionMethodForm, ActionProviderForm},
actions::SignOutAction,
error::{ActionError, MethodError, ProviderError, SessionError, ShieldError},
method::ErasedMethod,
options::ShieldOptions,
Expand Down Expand Up @@ -277,6 +277,32 @@ impl<U: User> Shield<U> {
}
}

pub async fn user_connections<C: 'static>(
&self,
user: &U,
method_id: &str,
provider_id: Option<&str>,
) -> Result<Vec<C>, ShieldError> {
let method =
self.method_by_id(method_id)
.ok_or(ShieldError::Method(MethodError::NotFound(
method_id.to_owned(),
)))?;

let connections = method
.erased_user_connections(&user.id(), provider_id)
.await?;

Ok(connections
.into_iter()
.map(|connection| {
*connection
.downcast::<C>()
.expect("Connection should be downcast")
})
.collect())
}

#[cfg(feature = "utoipa")]
pub fn openapi(&self) -> OpenApi {
use utoipa::openapi::Response;
Expand Down
9 changes: 9 additions & 0 deletions packages/methods/shield-credentials/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ impl<U: User, D: DeserializeOwned> CredentialsMethod<U, D> {
#[async_trait]
impl<U: User + 'static, D: DeserializeOwned + 'static> Method for CredentialsMethod<U, D> {
type Provider = CredentialsProvider;
type Connection = ();
type Session = ();

fn id(&self) -> String {
Expand All @@ -40,6 +41,14 @@ impl<U: User + 'static, D: DeserializeOwned + 'static> Method for CredentialsMet
async fn providers(&self) -> Result<Vec<Self::Provider>, ShieldError> {
Ok(vec![CredentialsProvider])
}

async fn user_connections(
&self,
_user_id: &str,
_provider_id: Option<&str>,
) -> Result<Vec<Self::Connection>, ShieldError> {
Ok(vec![])
}
}

erased_method!(CredentialsMethod, <U: User, D: DeserializeOwned>);
9 changes: 9 additions & 0 deletions packages/methods/shield-dummy/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ impl<U: User> DummyMethod<U> {
#[async_trait]
impl<U: User + 'static> Method for DummyMethod<U> {
type Provider = DummyProvider;
type Connection = ();
type Session = ();

fn id(&self) -> String {
Expand All @@ -35,6 +36,14 @@ impl<U: User + 'static> Method for DummyMethod<U> {
async fn providers(&self) -> Result<Vec<Self::Provider>, ShieldError> {
Ok(vec![DummyProvider])
}

async fn user_connections(
&self,
_user_id: &str,
_provider_id: Option<&str>,
) -> Result<Vec<Self::Connection>, ShieldError> {
Ok(vec![])
}
}

erased_method!(DummyMethod, <U: User>);
9 changes: 9 additions & 0 deletions packages/methods/shield-email/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ impl<U: User> EmailMethod<U> {
#[async_trait]
impl<U: User + 'static> Method for EmailMethod<U> {
type Provider = EmailProvider;
type Connection = ();
type Session = ();

fn id(&self) -> String {
Expand All @@ -51,6 +52,14 @@ impl<U: User + 'static> Method for EmailMethod<U> {
async fn providers(&self) -> Result<Vec<Self::Provider>, ShieldError> {
Ok(vec![EmailProvider])
}

async fn user_connections(
&self,
_user_id: &str,
_provider_id: Option<&str>,
) -> Result<Vec<Self::Connection>, ShieldError> {
Ok(vec![])
}
}

erased_method!(EmailMethod, <U: User>);
13 changes: 13 additions & 0 deletions packages/methods/shield-oauth/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use async_trait::async_trait;
use shield::{Method, MethodAction, ShieldError, User, erased_method};

use crate::{
OauthConnection,
actions::{OauthSignInAction, OauthSignInCallbackAction},
options::OauthOptions,
provider::OauthProvider,
Expand Down Expand Up @@ -67,6 +68,7 @@ impl<U: User> OauthMethod<U> {
#[async_trait]
impl<U: User + 'static> Method for OauthMethod<U> {
type Provider = OauthProvider;
type Connection = OauthConnection;
type Session = OauthSession;

fn id(&self) -> String {
Expand Down Expand Up @@ -102,6 +104,17 @@ impl<U: User + 'static> Method for OauthMethod<U> {
Ok(None)
}
}

async fn user_connections(
&self,
user_id: &str,
provider_id: Option<&str>,
) -> Result<Vec<Self::Connection>, ShieldError> {
Ok(self
.storage
.user_oauth_connections(user_id, provider_id)
.await?)
}
}

erased_method!(OauthMethod, <U: User>);
6 changes: 6 additions & 0 deletions packages/methods/shield-oauth/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,10 @@ pub trait OauthStorage<U: User>: Storage<U> + Sync {
) -> Result<OauthConnection, StorageError>;

async fn delete_oauth_connection(&self, connection_id: &str) -> Result<(), StorageError>;

async fn user_oauth_connections(
&self,
user_id: &str,
provider_id: Option<&str>,
) -> Result<Vec<OauthConnection>, StorageError>;
}
13 changes: 13 additions & 0 deletions packages/methods/shield-oidc/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use async_trait::async_trait;
use shield::{Method, MethodAction, ShieldError, User, erased_method};

use crate::{
OidcConnection,
actions::{OidcSignInAction, OidcSignInCallbackAction},
options::OidcOptions,
provider::OidcProvider,
Expand Down Expand Up @@ -65,6 +66,7 @@ impl<U: User> OidcMethod<U> {
#[async_trait]
impl<U: User + 'static> Method for OidcMethod<U> {
type Provider = OidcProvider;
type Connection = OidcConnection;
type Session = OidcSession;

fn id(&self) -> String {
Expand Down Expand Up @@ -100,6 +102,17 @@ impl<U: User + 'static> Method for OidcMethod<U> {
Ok(None)
}
}

async fn user_connections(
&self,
user_id: &str,
provider_id: Option<&str>,
) -> Result<Vec<Self::Connection>, ShieldError> {
Ok(self
.storage
.user_oidc_connections(user_id, provider_id)
.await?)
}
}

erased_method!(OidcMethod, <U: User>);
6 changes: 6 additions & 0 deletions packages/methods/shield-oidc/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,10 @@ pub trait OidcStorage<U: User>: Storage<U> + Sync {
) -> Result<OidcConnection, StorageError>;

async fn delete_oidc_connection(&self, connection_id: &str) -> Result<(), StorageError>;

async fn user_oidc_connections(
&self,
user_id: &str,
provider_id: Option<&str>,
) -> Result<Vec<OidcConnection>, StorageError>;
}
10 changes: 10 additions & 0 deletions packages/methods/shield-workos/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ impl WorkosMethod {
#[async_trait]
impl Method for WorkosMethod {
type Provider = WorkosProvider;
type Connection = ();
type Session = ();

fn id(&self) -> String {
Expand All @@ -61,6 +62,15 @@ impl Method for WorkosMethod {
async fn providers(&self) -> Result<Vec<Self::Provider>, ShieldError> {
Ok(vec![WorkosProvider])
}

async fn user_connections(
&self,
_user_id: &str,
_provider_id: Option<&str>,
) -> Result<Vec<Self::Connection>, ShieldError> {
// TODO
Ok(vec![])
}
}

erased_method!(WorkosMethod);
19 changes: 19 additions & 0 deletions packages/storage/shield-memory/src/methods/oauth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,23 @@ impl OauthStorage<User> for MemoryStorage {

Ok(())
}

async fn user_oauth_connections(
&self,
user_id: &str,
provider_id: Option<&str>,
) -> Result<Vec<OauthConnection>, StorageError> {
Ok(self
.oauth
.connections
.lock()
.map_err(|err| StorageError::Engine(err.to_string()))?
.iter()
.filter(|connection| {
connection.user_id == user_id
&& provider_id.is_none_or(|provider_id| connection.provider_id == provider_id)
})
.cloned()
.collect())
}
}
19 changes: 19 additions & 0 deletions packages/storage/shield-memory/src/methods/oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,23 @@ impl OidcStorage<User> for MemoryStorage {

Ok(())
}

async fn user_oidc_connections(
&self,
user_id: &str,
provider_id: Option<&str>,
) -> Result<Vec<OidcConnection>, StorageError> {
Ok(self
.oidc
.connections
.lock()
.map_err(|err| StorageError::Engine(err.to_string()))?
.iter()
.filter(|connection| {
connection.user_id == user_id
&& provider_id.is_none_or(|provider_id| connection.provider_id == provider_id)
})
.cloned()
.collect())
}
}
19 changes: 19 additions & 0 deletions packages/storage/shield-sea-orm/src/methods/oauth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,25 @@ impl OauthStorage<User> for SeaOrmStorage {
.map_err(|err| StorageError::Engine(err.to_string()))
.map(|_| ())
}

async fn user_oauth_connections(
&self,
user_id: &str,
provider_id: Option<&str>,
) -> Result<Vec<OauthConnection>, StorageError> {
let mut query = oauth_provider_connection::Entity::find()
.filter(oauth_provider_connection::Column::UserId.eq(Self::parse_uuid(user_id)?));

if let Some(provider_id) = provider_id {
query = query.filter(oauth_provider_connection::Column::ProviderId.eq(provider_id));
}

query
.all(&self.database)
.await
.map_err(|err| StorageError::Engine(err.to_string()))
.map(|connections| connections.into_iter().map(OauthConnection::from).collect())
}
}

impl From<oauth_provider::OauthProviderVisibility> for OauthProviderVisibility {
Expand Down
19 changes: 19 additions & 0 deletions packages/storage/shield-sea-orm/src/methods/oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,25 @@ impl OidcStorage<User> for SeaOrmStorage {
.map_err(|err| StorageError::Engine(err.to_string()))
.map(|_| ())
}

async fn user_oidc_connections(
&self,
user_id: &str,
provider_id: Option<&str>,
) -> Result<Vec<OidcConnection>, StorageError> {
let mut query = oidc_provider_connection::Entity::find()
.filter(oidc_provider_connection::Column::UserId.eq(Self::parse_uuid(user_id)?));

if let Some(provider_id) = provider_id {
query = query.filter(oidc_provider_connection::Column::ProviderId.eq(provider_id));
}

query
.all(&self.database)
.await
.map_err(|err| StorageError::Engine(err.to_string()))
.map(|connections| connections.into_iter().map(OidcConnection::from).collect())
}
}

impl From<oidc_provider::OidcProviderVisibility> for OidcProviderVisibility {
Expand Down