From 40f48646b35af95d81578ef128422b7c5b53eb8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABlle=20Huisman?= Date: Sun, 14 Jun 2026 15:23:36 +0200 Subject: [PATCH] feat: add user connections --- packages/core/shield/src/method.rs | 28 +++++++++++++++++-- packages/core/shield/src/shield.rs | 28 ++++++++++++++++++- .../methods/shield-credentials/src/method.rs | 9 ++++++ packages/methods/shield-dummy/src/method.rs | 9 ++++++ packages/methods/shield-email/src/method.rs | 9 ++++++ packages/methods/shield-oauth/src/method.rs | 13 +++++++++ packages/methods/shield-oauth/src/storage.rs | 6 ++++ packages/methods/shield-oidc/src/method.rs | 13 +++++++++ packages/methods/shield-oidc/src/storage.rs | 6 ++++ packages/methods/shield-workos/src/method.rs | 10 +++++++ .../shield-memory/src/methods/oauth.rs | 19 +++++++++++++ .../storage/shield-memory/src/methods/oidc.rs | 19 +++++++++++++ .../shield-sea-orm/src/methods/oauth.rs | 19 +++++++++++++ .../shield-sea-orm/src/methods/oidc.rs | 19 +++++++++++++ 14 files changed, 204 insertions(+), 3 deletions(-) diff --git a/packages/core/shield/src/method.rs b/packages/core/shield/src/method.rs index 42950cc..b8a1133 100644 --- a/packages/core/shield/src/method.rs +++ b/packages/core/shield/src/method.rs @@ -4,8 +4,7 @@ use async_trait::async_trait; use serde::{Serialize, de::DeserializeOwned}; use crate::{ - ErasedMethodAction, - action::MethodAction, + action::{ErasedMethodAction, MethodAction}, error::{SessionError, ShieldError}, provider::Provider, }; @@ -13,6 +12,7 @@ use crate::{ #[async_trait] pub trait Method: Send + Sync { type Provider: Provider; + type Connection; type Session: DeserializeOwned + Serialize; fn id(&self) -> String; @@ -40,6 +40,12 @@ pub trait Method: Send + Sync { .into_iter() .find(|provider| provider.id().as_deref() == provider_id)) } + + async fn user_connections( + &self, + user: &str, + provider_id: Option<&str>, + ) -> Result, ShieldError>; } #[async_trait] @@ -59,6 +65,12 @@ pub trait ErasedMethod: Send + Sync { provider_id: Option<&str>, ) -> Result>, ShieldError>; + async fn erased_user_connections( + &self, + user_id: &str, + provider_id: Option<&str>, + ) -> Result>, ShieldError>; + fn erased_deserialize_session( &self, value: Option<&str>, @@ -110,6 +122,18 @@ macro_rules! erased_method { }) } + async fn erased_user_connections( + &self, + user_id: &str, + provider_id: Option<&str>, + ) -> Result>, $crate::ShieldError> { + Ok(self.user_connections(user_id, provider_id) + .await? + .into_iter() + .map(|connection| Box::new(connection) as Box) + .collect()) + } + fn erased_deserialize_session( &self, value: Option<&str> diff --git a/packages/core/shield/src/shield.rs b/packages/core/shield/src/shield.rs index 503e6e3..91e2735 100644 --- a/packages/core/shield/src/shield.rs +++ b/packages/core/shield/src/shield.rs @@ -16,8 +16,8 @@ use utoipa::{ #[cfg(feature = "utoipa")] use crate::path::{ActionPathParams, MethodActionPathParams}; use crate::{ - SignOutAction, action::{Action, ActionForms, ActionMethodForm, ActionProviderForm}, + actions::SignOutAction, error::{ActionError, MethodError, ProviderError, SessionError, ShieldError}, method::ErasedMethod, options::ShieldOptions, @@ -277,6 +277,32 @@ impl Shield { } } + pub async fn user_connections( + &self, + user: &U, + method_id: &str, + provider_id: Option<&str>, + ) -> Result, ShieldError> { + let method = + self.method_by_id(method_id) + .ok_or(ShieldError::Method(MethodError::NotFound( + method_id.to_owned(), + )))?; + + let connections = method + .erased_user_connections(&user.id(), provider_id) + .await?; + + Ok(connections + .into_iter() + .map(|connection| { + *connection + .downcast::() + .expect("Connection should be downcast") + }) + .collect()) + } + #[cfg(feature = "utoipa")] pub fn openapi(&self) -> OpenApi { use utoipa::openapi::Response; diff --git a/packages/methods/shield-credentials/src/method.rs b/packages/methods/shield-credentials/src/method.rs index ebb0a66..cea746a 100644 --- a/packages/methods/shield-credentials/src/method.rs +++ b/packages/methods/shield-credentials/src/method.rs @@ -25,6 +25,7 @@ impl CredentialsMethod { #[async_trait] impl Method for CredentialsMethod { type Provider = CredentialsProvider; + type Connection = (); type Session = (); fn id(&self) -> String { @@ -40,6 +41,14 @@ impl Method for CredentialsMet async fn providers(&self) -> Result, ShieldError> { Ok(vec![CredentialsProvider]) } + + async fn user_connections( + &self, + _user_id: &str, + _provider_id: Option<&str>, + ) -> Result, ShieldError> { + Ok(vec![]) + } } erased_method!(CredentialsMethod, ); diff --git a/packages/methods/shield-dummy/src/method.rs b/packages/methods/shield-dummy/src/method.rs index 9c5b068..3f2ede1 100644 --- a/packages/methods/shield-dummy/src/method.rs +++ b/packages/methods/shield-dummy/src/method.rs @@ -22,6 +22,7 @@ impl DummyMethod { #[async_trait] impl Method for DummyMethod { type Provider = DummyProvider; + type Connection = (); type Session = (); fn id(&self) -> String { @@ -35,6 +36,14 @@ impl Method for DummyMethod { async fn providers(&self) -> Result, ShieldError> { Ok(vec![DummyProvider]) } + + async fn user_connections( + &self, + _user_id: &str, + _provider_id: Option<&str>, + ) -> Result, ShieldError> { + Ok(vec![]) + } } erased_method!(DummyMethod, ); diff --git a/packages/methods/shield-email/src/method.rs b/packages/methods/shield-email/src/method.rs index ab50a21..da8ee77 100644 --- a/packages/methods/shield-email/src/method.rs +++ b/packages/methods/shield-email/src/method.rs @@ -29,6 +29,7 @@ impl EmailMethod { #[async_trait] impl Method for EmailMethod { type Provider = EmailProvider; + type Connection = (); type Session = (); fn id(&self) -> String { @@ -51,6 +52,14 @@ impl Method for EmailMethod { async fn providers(&self) -> Result, ShieldError> { Ok(vec![EmailProvider]) } + + async fn user_connections( + &self, + _user_id: &str, + _provider_id: Option<&str>, + ) -> Result, ShieldError> { + Ok(vec![]) + } } erased_method!(EmailMethod, ); diff --git a/packages/methods/shield-oauth/src/method.rs b/packages/methods/shield-oauth/src/method.rs index 8e371a0..c9a1951 100644 --- a/packages/methods/shield-oauth/src/method.rs +++ b/packages/methods/shield-oauth/src/method.rs @@ -4,6 +4,7 @@ use async_trait::async_trait; use shield::{Method, MethodAction, ShieldError, User, erased_method}; use crate::{ + OauthConnection, actions::{OauthSignInAction, OauthSignInCallbackAction}, options::OauthOptions, provider::OauthProvider, @@ -67,6 +68,7 @@ impl OauthMethod { #[async_trait] impl Method for OauthMethod { type Provider = OauthProvider; + type Connection = OauthConnection; type Session = OauthSession; fn id(&self) -> String { @@ -102,6 +104,17 @@ impl Method for OauthMethod { Ok(None) } } + + async fn user_connections( + &self, + user_id: &str, + provider_id: Option<&str>, + ) -> Result, ShieldError> { + Ok(self + .storage + .user_oauth_connections(user_id, provider_id) + .await?) + } } erased_method!(OauthMethod, ); diff --git a/packages/methods/shield-oauth/src/storage.rs b/packages/methods/shield-oauth/src/storage.rs index 7a97ec0..0382dac 100644 --- a/packages/methods/shield-oauth/src/storage.rs +++ b/packages/methods/shield-oauth/src/storage.rs @@ -38,4 +38,10 @@ pub trait OauthStorage: Storage + Sync { ) -> Result; async fn delete_oauth_connection(&self, connection_id: &str) -> Result<(), StorageError>; + + async fn user_oauth_connections( + &self, + user_id: &str, + provider_id: Option<&str>, + ) -> Result, StorageError>; } diff --git a/packages/methods/shield-oidc/src/method.rs b/packages/methods/shield-oidc/src/method.rs index 6c115d9..2daf50d 100644 --- a/packages/methods/shield-oidc/src/method.rs +++ b/packages/methods/shield-oidc/src/method.rs @@ -4,6 +4,7 @@ use async_trait::async_trait; use shield::{Method, MethodAction, ShieldError, User, erased_method}; use crate::{ + OidcConnection, actions::{OidcSignInAction, OidcSignInCallbackAction}, options::OidcOptions, provider::OidcProvider, @@ -65,6 +66,7 @@ impl OidcMethod { #[async_trait] impl Method for OidcMethod { type Provider = OidcProvider; + type Connection = OidcConnection; type Session = OidcSession; fn id(&self) -> String { @@ -100,6 +102,17 @@ impl Method for OidcMethod { Ok(None) } } + + async fn user_connections( + &self, + user_id: &str, + provider_id: Option<&str>, + ) -> Result, ShieldError> { + Ok(self + .storage + .user_oidc_connections(user_id, provider_id) + .await?) + } } erased_method!(OidcMethod, ); diff --git a/packages/methods/shield-oidc/src/storage.rs b/packages/methods/shield-oidc/src/storage.rs index 526095e..3866a3d 100644 --- a/packages/methods/shield-oidc/src/storage.rs +++ b/packages/methods/shield-oidc/src/storage.rs @@ -38,4 +38,10 @@ pub trait OidcStorage: Storage + Sync { ) -> Result; async fn delete_oidc_connection(&self, connection_id: &str) -> Result<(), StorageError>; + + async fn user_oidc_connections( + &self, + user_id: &str, + provider_id: Option<&str>, + ) -> Result, StorageError>; } diff --git a/packages/methods/shield-workos/src/method.rs b/packages/methods/shield-workos/src/method.rs index d533adc..3be10fa 100644 --- a/packages/methods/shield-workos/src/method.rs +++ b/packages/methods/shield-workos/src/method.rs @@ -41,6 +41,7 @@ impl WorkosMethod { #[async_trait] impl Method for WorkosMethod { type Provider = WorkosProvider; + type Connection = (); type Session = (); fn id(&self) -> String { @@ -61,6 +62,15 @@ impl Method for WorkosMethod { async fn providers(&self) -> Result, ShieldError> { Ok(vec![WorkosProvider]) } + + async fn user_connections( + &self, + _user_id: &str, + _provider_id: Option<&str>, + ) -> Result, ShieldError> { + // TODO + Ok(vec![]) + } } erased_method!(WorkosMethod); diff --git a/packages/storage/shield-memory/src/methods/oauth.rs b/packages/storage/shield-memory/src/methods/oauth.rs index 9bde888..f299abe 100644 --- a/packages/storage/shield-memory/src/methods/oauth.rs +++ b/packages/storage/shield-memory/src/methods/oauth.rs @@ -126,4 +126,23 @@ impl OauthStorage for MemoryStorage { Ok(()) } + + async fn user_oauth_connections( + &self, + user_id: &str, + provider_id: Option<&str>, + ) -> Result, StorageError> { + Ok(self + .oauth + .connections + .lock() + .map_err(|err| StorageError::Engine(err.to_string()))? + .iter() + .filter(|connection| { + connection.user_id == user_id + && provider_id.is_none_or(|provider_id| connection.provider_id == provider_id) + }) + .cloned() + .collect()) + } } diff --git a/packages/storage/shield-memory/src/methods/oidc.rs b/packages/storage/shield-memory/src/methods/oidc.rs index 23de44a..3a23904 100644 --- a/packages/storage/shield-memory/src/methods/oidc.rs +++ b/packages/storage/shield-memory/src/methods/oidc.rs @@ -130,4 +130,23 @@ impl OidcStorage for MemoryStorage { Ok(()) } + + async fn user_oidc_connections( + &self, + user_id: &str, + provider_id: Option<&str>, + ) -> Result, StorageError> { + Ok(self + .oidc + .connections + .lock() + .map_err(|err| StorageError::Engine(err.to_string()))? + .iter() + .filter(|connection| { + connection.user_id == user_id + && provider_id.is_none_or(|provider_id| connection.provider_id == provider_id) + }) + .cloned() + .collect()) + } } diff --git a/packages/storage/shield-sea-orm/src/methods/oauth.rs b/packages/storage/shield-sea-orm/src/methods/oauth.rs index 576e163..720a4c3 100644 --- a/packages/storage/shield-sea-orm/src/methods/oauth.rs +++ b/packages/storage/shield-sea-orm/src/methods/oauth.rs @@ -142,6 +142,25 @@ impl OauthStorage for SeaOrmStorage { .map_err(|err| StorageError::Engine(err.to_string())) .map(|_| ()) } + + async fn user_oauth_connections( + &self, + user_id: &str, + provider_id: Option<&str>, + ) -> Result, StorageError> { + let mut query = oauth_provider_connection::Entity::find() + .filter(oauth_provider_connection::Column::UserId.eq(Self::parse_uuid(user_id)?)); + + if let Some(provider_id) = provider_id { + query = query.filter(oauth_provider_connection::Column::ProviderId.eq(provider_id)); + } + + query + .all(&self.database) + .await + .map_err(|err| StorageError::Engine(err.to_string())) + .map(|connections| connections.into_iter().map(OauthConnection::from).collect()) + } } impl From for OauthProviderVisibility { diff --git a/packages/storage/shield-sea-orm/src/methods/oidc.rs b/packages/storage/shield-sea-orm/src/methods/oidc.rs index 8b323cd..e5ed5b0 100644 --- a/packages/storage/shield-sea-orm/src/methods/oidc.rs +++ b/packages/storage/shield-sea-orm/src/methods/oidc.rs @@ -149,6 +149,25 @@ impl OidcStorage for SeaOrmStorage { .map_err(|err| StorageError::Engine(err.to_string())) .map(|_| ()) } + + async fn user_oidc_connections( + &self, + user_id: &str, + provider_id: Option<&str>, + ) -> Result, StorageError> { + let mut query = oidc_provider_connection::Entity::find() + .filter(oidc_provider_connection::Column::UserId.eq(Self::parse_uuid(user_id)?)); + + if let Some(provider_id) = provider_id { + query = query.filter(oidc_provider_connection::Column::ProviderId.eq(provider_id)); + } + + query + .all(&self.database) + .await + .map_err(|err| StorageError::Engine(err.to_string())) + .map(|connections| connections.into_iter().map(OidcConnection::from).collect()) + } } impl From for OidcProviderVisibility {