From 3771223bb6a4b2f0e3aa124e5ee465671f276411 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABlle=20Huisman?= Date: Tue, 19 Aug 2025 22:20:08 +0200 Subject: [PATCH] feat(workos): add OAuth and SSO buttons to index action --- Cargo.lock | 1 + examples/dioxus-axum/.env.example | 2 + examples/dioxus-axum/Cargo.toml | 2 + examples/dioxus-axum/src/main.rs | 53 +++-- packages/core/shield/src/action.rs | 11 +- packages/core/shield/src/actions/sign_out.rs | 6 +- packages/core/shield/src/shield.rs | 2 +- .../shield-credentials/src/actions/sign_in.rs | 4 +- .../src/actions/sign_out.rs | 4 +- .../shield-oauth/src/actions/sign_in.rs | 15 +- .../src/actions/sign_in_callback.rs | 4 +- .../shield-oauth/src/actions/sign_out.rs | 4 +- packages/methods/shield-oauth/src/options.rs | 2 +- .../shield-oidc/src/actions/sign_in.rs | 6 +- .../src/actions/sign_in_callback.rs | 4 +- .../shield-oidc/src/actions/sign_out.rs | 4 +- packages/methods/shield-oidc/src/options.rs | 2 +- .../shield-workos/src/actions/index.rs | 198 ++++++++++++++---- .../shield-workos/src/actions/sign_in.rs | 6 +- .../shield-workos/src/actions/sign_out.rs | 4 +- .../shield-workos/src/actions/sign_up.rs | 6 +- packages/methods/shield-workos/src/lib.rs | 6 + packages/methods/shield-workos/src/method.rs | 11 +- packages/methods/shield-workos/src/options.rs | 6 +- 24 files changed, 269 insertions(+), 94 deletions(-) create mode 100644 examples/dioxus-axum/.env.example diff --git a/Cargo.lock b/Cargo.lock index 62d09fc..1960946 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5612,6 +5612,7 @@ dependencies = [ "shield-dioxus-axum", "shield-memory", "shield-oidc", + "shield-workos", "tokio", "tower-sessions", "tracing", diff --git a/examples/dioxus-axum/.env.example b/examples/dioxus-axum/.env.example new file mode 100644 index 0000000..babe14d --- /dev/null +++ b/examples/dioxus-axum/.env.example @@ -0,0 +1,2 @@ +# WorkOS (optional) +# WORKOS_API_KEY = diff --git a/examples/dioxus-axum/Cargo.toml b/examples/dioxus-axum/Cargo.toml index 6ef10ff..78f7a33 100644 --- a/examples/dioxus-axum/Cargo.toml +++ b/examples/dioxus-axum/Cargo.toml @@ -18,6 +18,7 @@ server = [ "dep:shield-dioxus-axum", "dep:shield-memory", "dep:shield-oidc", + "dep:shield-workos", "dep:tokio", "dep:tower-sessions", "dioxus/server", @@ -35,6 +36,7 @@ shield-dioxus.workspace = true shield-dioxus-axum = { workspace = true, optional = true } shield-memory = { workspace = true, optional = true } shield-oidc = { workspace = true, features = ["native-tls"], optional = true } +shield-workos = { workspace = true, optional = true } tokio = { workspace = true, features = ["rt-multi-thread"], optional = true } tower-sessions = { workspace = true, optional = true } tracing.workspace = true diff --git a/examples/dioxus-axum/src/main.rs b/examples/dioxus-axum/src/main.rs index 36dab48..8a75e20 100644 --- a/examples/dioxus-axum/src/main.rs +++ b/examples/dioxus-axum/src/main.rs @@ -15,18 +15,19 @@ fn main() { #[cfg(feature = "server")] #[tokio::main] async fn main() { - use std::sync::Arc; + use std::{env, sync::Arc}; use axum::Router; use dioxus::{ cli_config::fullstack_address_or_localhost, prelude::{DioxusRouterExt, *}, }; - use shield::{Shield, ShieldOptions}; + use shield::{ErasedMethod, Method, Shield, ShieldOptions}; use shield_bootstrap::BootstrapDioxusStyle; use shield_dioxus_axum::{AxumDioxusIntegration, ShieldLayer}; use shield_memory::{MemoryStorage, User}; use shield_oidc::{Keycloak, OidcMethod}; + use shield_workos::{WorkosMethod, WorkosOauthProvider, WorkosOptions}; use tokio::net::TcpListener; use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer, cookie::time::Duration}; use tracing::{Level, info}; @@ -45,21 +46,39 @@ async fn main() { let storage = MemoryStorage::new(); let shield = Shield::new( storage.clone(), - vec![Arc::new( - OidcMethod::new(storage).with_providers([Keycloak::builder( - "keycloak", - "http://localhost:18080/realms/Shield", - "client1", - ) - .client_secret("xcpQsaGbRILTljPtX4npjmYMBjKrariJ") - .redirect_url(format!( - "http://localhost:{}/api/auth/oidc/sign-in-callback/keycloak", - dioxus::cli_config::devserver_raw_addr() - .map(|addr| addr.port()) - .unwrap_or_else(|| addr.port()) - )) - .build()]), - )], + [ + Some(Arc::new( + OidcMethod::new(storage).with_providers([Keycloak::builder( + "keycloak", + "http://localhost:18080/realms/Shield", + "client1", + ) + .client_secret("xcpQsaGbRILTljPtX4npjmYMBjKrariJ") + .redirect_url(format!( + "http://localhost:{}/api/auth/oidc/sign-in-callback/keycloak", + dioxus::cli_config::devserver_raw_addr() + .map(|addr| addr.port()) + .unwrap_or_else(|| addr.port()) + )) + .build()]), + ) as Arc), + env::var("WORKOS_API_KEY").ok().map(|api_key| { + Arc::new( + WorkosMethod::from_api_key(&api_key).with_options( + WorkosOptions::builder() + .oauth_providers(vec![ + WorkosOauthProvider::AppleOAuth, + WorkosOauthProvider::GoogleOAuth, + WorkosOauthProvider::MicrosoftOAuth, + ]) + .build(), + ), + ) as Arc + }), + ] + .into_iter() + .flatten() + .collect(), ShieldOptions::default(), ); let shield_layer = ShieldLayer::new(shield.clone()); diff --git a/packages/core/shield/src/action.rs b/packages/core/shield/src/action.rs index 8e9f36a..8097daf 100644 --- a/packages/core/shield/src/action.rs +++ b/packages/core/shield/src/action.rs @@ -40,7 +40,7 @@ pub trait Action: ErasedAction + Send + Sync { Ok(true) } - fn forms(&self, provider: P) -> Vec
; + async fn forms(&self, provider: P) -> Result, ShieldError>; async fn call( &self, @@ -62,7 +62,10 @@ pub trait ErasedAction: Send + Sync { session: Session, ) -> Result; - fn erased_forms(&self, provider: Box) -> Vec; + async fn erased_forms( + &self, + provider: Box, + ) -> Result, ShieldError>; async fn erased_call( &self, @@ -89,8 +92,8 @@ macro_rules! erased_action { self.condition(provider.downcast_ref().expect("TODO"), session) } - fn erased_forms(&self, provider: Box) -> Vec<$crate::Form> { - self.forms(*provider.downcast().expect("TODO")) + async fn erased_forms(&self, provider: Box) -> Result, $crate::ShieldError> { + self.forms(*provider.downcast().expect("TODO")).await } async fn erased_call( diff --git a/packages/core/shield/src/actions/sign_out.rs b/packages/core/shield/src/actions/sign_out.rs index b956511..3d90026 100644 --- a/packages/core/shield/src/actions/sign_out.rs +++ b/packages/core/shield/src/actions/sign_out.rs @@ -31,14 +31,14 @@ impl SignOutAction { })) } - pub fn forms(_provider: P) -> Vec { - vec![Form { + pub async fn forms(_provider: P) -> Result, ShieldError> { + Ok(vec![Form { inputs: vec![Input { name: "submit".to_owned(), label: None, r#type: InputType::Submit(InputTypeSubmit {}), value: Some(Self::name()), }], - }] + }]) } } diff --git a/packages/core/shield/src/shield.rs b/packages/core/shield/src/shield.rs index 3ed2436..2b72df7 100644 --- a/packages/core/shield/src/shield.rs +++ b/packages/core/shield/src/shield.rs @@ -107,7 +107,7 @@ impl Shield { continue; } - let forms = action.erased_forms(provider); + let forms = action.erased_forms(provider).await?; for form in forms { provider_forms.push(ActionProviderForm { id: provider_id.clone(), diff --git a/packages/methods/shield-credentials/src/actions/sign_in.rs b/packages/methods/shield-credentials/src/actions/sign_in.rs index 14ef4f3..29da31e 100644 --- a/packages/methods/shield-credentials/src/actions/sign_in.rs +++ b/packages/methods/shield-credentials/src/actions/sign_in.rs @@ -31,8 +31,8 @@ impl Action Vec { - vec![self.credentials.form()] + async fn forms(&self, _provider: CredentialsProvider) -> Result, ShieldError> { + Ok(vec![self.credentials.form()]) } async fn call( diff --git a/packages/methods/shield-credentials/src/actions/sign_out.rs b/packages/methods/shield-credentials/src/actions/sign_out.rs index b7003ff..1cd86f8 100644 --- a/packages/methods/shield-credentials/src/actions/sign_out.rs +++ b/packages/methods/shield-credentials/src/actions/sign_out.rs @@ -23,8 +23,8 @@ impl Action for CredentialsSignOutAction { SignOutAction::condition(provider, session) } - fn forms(&self, provider: CredentialsProvider) -> Vec { - SignOutAction::forms(provider) + async fn forms(&self, provider: CredentialsProvider) -> Result, ShieldError> { + SignOutAction::forms(provider).await } async fn call( diff --git a/packages/methods/shield-oauth/src/actions/sign_in.rs b/packages/methods/shield-oauth/src/actions/sign_in.rs index 816e361..0f3b0a9 100644 --- a/packages/methods/shield-oauth/src/actions/sign_in.rs +++ b/packages/methods/shield-oauth/src/actions/sign_in.rs @@ -1,8 +1,8 @@ use async_trait::async_trait; use oauth2::{CsrfToken, PkceCodeChallenge, Scope, url::form_urlencoded::parse}; use shield::{ - Action, ConfigurationError, Form, Request, Response, Session, SessionError, ShieldError, - SignInAction, erased_action, + Action, ConfigurationError, Form, Input, InputType, InputTypeSubmit, Provider, Request, + Response, Session, SessionError, ShieldError, SignInAction, erased_action, }; use crate::{ @@ -23,8 +23,15 @@ impl Action for OauthSignInAction { SignInAction::name() } - fn forms(&self, _provider: OauthProvider) -> Vec { - vec![Form { inputs: vec![] }] + async fn forms(&self, provider: OauthProvider) -> Result, ShieldError> { + Ok(vec![Form { + inputs: vec![Input { + name: "submit".to_owned(), + label: None, + r#type: InputType::Submit(InputTypeSubmit::default()), + value: Some(format!("Sign in with {}", provider.name())), + }], + }]) } async fn call( diff --git a/packages/methods/shield-oauth/src/actions/sign_in_callback.rs b/packages/methods/shield-oauth/src/actions/sign_in_callback.rs index 5fec982..5714ef1 100644 --- a/packages/methods/shield-oauth/src/actions/sign_in_callback.rs +++ b/packages/methods/shield-oauth/src/actions/sign_in_callback.rs @@ -143,8 +143,8 @@ impl Action for OauthSignInCallbackAction { SignInCallbackAction::condition(provider, session) } - fn forms(&self, _provider: OauthProvider) -> Vec { - vec![Form { inputs: vec![] }] + async fn forms(&self, _provider: OauthProvider) -> Result, ShieldError> { + Ok(vec![]) } async fn call( diff --git a/packages/methods/shield-oauth/src/actions/sign_out.rs b/packages/methods/shield-oauth/src/actions/sign_out.rs index 117916c..53275ed 100644 --- a/packages/methods/shield-oauth/src/actions/sign_out.rs +++ b/packages/methods/shield-oauth/src/actions/sign_out.rs @@ -19,8 +19,8 @@ impl Action for OauthSignOutAction { SignOutAction::condition(provider, session) } - fn forms(&self, provider: OauthProvider) -> Vec { - SignOutAction::forms(provider) + async fn forms(&self, provider: OauthProvider) -> Result, ShieldError> { + SignOutAction::forms(provider).await } async fn call( diff --git a/packages/methods/shield-oauth/src/options.rs b/packages/methods/shield-oauth/src/options.rs index 3bc43e3..5045afc 100644 --- a/packages/methods/shield-oauth/src/options.rs +++ b/packages/methods/shield-oauth/src/options.rs @@ -4,7 +4,7 @@ use bon::Builder; #[builder(on(String, into), state_mod(vis = "pub(crate)"))] pub struct OauthOptions { #[builder(default = "/")] - pub sign_in_redirect: String, + pub(crate) sign_in_redirect: String, } impl Default for OauthOptions { diff --git a/packages/methods/shield-oidc/src/actions/sign_in.rs b/packages/methods/shield-oidc/src/actions/sign_in.rs index f2bfa6c..c5f8314 100644 --- a/packages/methods/shield-oidc/src/actions/sign_in.rs +++ b/packages/methods/shield-oidc/src/actions/sign_in.rs @@ -26,15 +26,15 @@ impl Action for OidcSignInAction { SignInAction::name() } - fn forms(&self, provider: OidcProvider) -> Vec { - vec![Form { + async fn forms(&self, provider: OidcProvider) -> Result, ShieldError> { + Ok(vec![Form { inputs: vec![Input { name: "submit".to_owned(), label: None, r#type: InputType::Submit(InputTypeSubmit::default()), value: Some(format!("Sign in with {}", provider.name())), }], - }] + }]) } async fn call( diff --git a/packages/methods/shield-oidc/src/actions/sign_in_callback.rs b/packages/methods/shield-oidc/src/actions/sign_in_callback.rs index a0dd488..437a80c 100644 --- a/packages/methods/shield-oidc/src/actions/sign_in_callback.rs +++ b/packages/methods/shield-oidc/src/actions/sign_in_callback.rs @@ -154,8 +154,8 @@ impl Action for OidcSignInCallbackAction { SignInCallbackAction::condition(provider, session) } - fn forms(&self, _provider: OidcProvider) -> Vec { - vec![Form { inputs: vec![] }] + async fn forms(&self, _provider: OidcProvider) -> Result, ShieldError> { + Ok(vec![]) } async fn call( diff --git a/packages/methods/shield-oidc/src/actions/sign_out.rs b/packages/methods/shield-oidc/src/actions/sign_out.rs index ff57851..4c804a3 100644 --- a/packages/methods/shield-oidc/src/actions/sign_out.rs +++ b/packages/methods/shield-oidc/src/actions/sign_out.rs @@ -19,8 +19,8 @@ impl Action for OidcSignOutAction { SignOutAction::condition(provider, session) } - fn forms(&self, provider: OidcProvider) -> Vec { - SignOutAction::forms(provider) + async fn forms(&self, provider: OidcProvider) -> Result, ShieldError> { + SignOutAction::forms(provider).await } async fn call( diff --git a/packages/methods/shield-oidc/src/options.rs b/packages/methods/shield-oidc/src/options.rs index 336e268..a97ae45 100644 --- a/packages/methods/shield-oidc/src/options.rs +++ b/packages/methods/shield-oidc/src/options.rs @@ -4,7 +4,7 @@ use bon::Builder; #[builder(on(String, into), state_mod(vis = "pub(crate)"))] pub struct OidcOptions { #[builder(default = "/")] - pub sign_in_redirect: String, + pub(crate) sign_in_redirect: String, } impl Default for OidcOptions { diff --git a/packages/methods/shield-workos/src/actions/index.rs b/packages/methods/shield-workos/src/actions/index.rs index 74b6f7a..54d4966 100644 --- a/packages/methods/shield-workos/src/actions/index.rs +++ b/packages/methods/shield-workos/src/actions/index.rs @@ -3,16 +3,17 @@ use std::sync::Arc; use async_trait::async_trait; use serde::Deserialize; use shield::{ - Action, Form, Input, InputType, InputTypeEmail, Request, Response, Session, ShieldError, - erased_action, + Action, Form, Input, InputType, InputTypeEmail, InputTypeHidden, InputTypeSubmit, Request, + Response, Session, ShieldError, erased_action, }; use tracing::info; use workos_sdk::{ PaginationParams, WorkOs, - user_management::{ListUsers, ListUsersParams}, + sso::{ConnectionId, ListConnections, ListConnectionsParams}, + user_management::{ListUsers, ListUsersParams, OauthProvider}, }; -use crate::provider::WorkosProvider; +use crate::{WorkosOptions, provider::WorkosProvider}; // TODO: Make a special case for an index action reachable at the `/auth` root URL. @@ -20,17 +21,30 @@ const ACTION_ID: &str = "index"; const ACTION_NAME: &str = "Index"; #[derive(Debug, Deserialize)] -pub struct EmailData { - pub email: String, +#[serde(untagged, rename_all = "camelCase", rename_all_fields = "camelCase")] +pub enum IndexData { + Email { + // TODO: Dioxus records multiple values per field, but most of the time only a single value is expected. + email: Vec, + }, + Oauth { + // TODO: See above. + oauth_provider: Vec, + }, + Sso { + // TODO: See above. + connection_id: Vec, + }, } pub struct WorkosIndexAction { + options: WorkosOptions, client: Arc, } impl WorkosIndexAction { - pub fn new(client: Arc) -> Self { - Self { client } + pub fn new(options: WorkosOptions, client: Arc) -> Self { + Self { options, client } } } @@ -44,22 +58,99 @@ impl Action for WorkosIndexAction { ACTION_NAME.to_owned() } - fn forms(&self, _provider: WorkosProvider) -> Vec { - // TODO: SSO buttons. - - vec![Form { - inputs: vec![Input { - name: "email".to_owned(), - label: Some("Email address".to_owned()), - r#type: InputType::Email(InputTypeEmail { - autocomplete: Some("email".to_owned()), - placeholder: Some("Email address".to_owned()), - required: Some(true), + async fn forms(&self, _provider: WorkosProvider) -> Result, ShieldError> { + let connections = self + .client + .sso() + .list_connections(&ListConnectionsParams { + pagination: PaginationParams { + limit: Some(100), ..Default::default() - }), - value: None, - }], + }, + ..Default::default() + }) + .await + .expect("TODO: handle error"); + + info!("{connections:#?}"); + + Ok([Form { + inputs: vec![ + Input { + name: "email".to_owned(), + label: Some("Email address".to_owned()), + r#type: InputType::Email(InputTypeEmail { + autocomplete: Some("email".to_owned()), + placeholder: Some("Email address".to_owned()), + required: Some(true), + ..Default::default() + }), + value: None, + }, + Input { + name: "submit".to_owned(), + label: None, + r#type: InputType::Submit(InputTypeSubmit::default()), + value: Some("Continue".to_owned()), + }, + ], }] + .into_iter() + .chain( + self.options + .oauth_providers + .iter() + .map(|oauth_provider| Form { + inputs: vec![ + Input { + name: "oauthProvider".to_owned(), + label: None, + r#type: InputType::Hidden(InputTypeHidden { + required: Some(true), + ..Default::default() + }), + value: Some(oauth_provider.to_string()), + }, + Input { + name: "submit".to_owned(), + label: None, + r#type: InputType::Submit(InputTypeSubmit::default()), + value: Some( + format!( + "Continue with {}", + match oauth_provider { + OauthProvider::AppleOAuth => "Apple", + OauthProvider::GithubOAuth => "GitHub", + OauthProvider::GoogleOAuth => "Google", + OauthProvider::MicrosoftOAuth => "Microsoft", + } + ) + .to_owned(), + ), + }, + ], + }), + ) + .chain(connections.data.into_iter().map(|connection| Form { + inputs: vec![ + Input { + name: "connectionId".to_owned(), + label: None, + r#type: InputType::Hidden(InputTypeHidden { + required: Some(true), + ..Default::default() + }), + value: Some(connection.id.to_string()), + }, + Input { + name: "submit".to_owned(), + label: None, + r#type: InputType::Submit(InputTypeSubmit::default()), + value: Some(format!("Continue with {}", connection.name).to_owned()), + }, + ], + })) + .collect()) } async fn call( @@ -71,23 +162,56 @@ impl Action for WorkosIndexAction { // TODO: Check email address and redirect to sign-in/sign-up action with prefilled email address. // TODO: Only check if enabled in options. - let data = serde_json::from_value::(request.form_data) + let data = serde_json::from_value::(request.form_data) .map_err(|err| ShieldError::Validation(err.to_string()))?; - let result = self - .client - .user_management() - .list_users(&ListUsersParams { - pagination: PaginationParams { - limit: Some(1), - ..Default::default() - }, - email: Some(&data.email), - ..Default::default() - }) - .await; - - info!("{result:?}"); + match data { + IndexData::Email { email } => { + info!("email: {email:#?}"); + + let users = self + .client + .user_management() + .list_users(&ListUsersParams { + pagination: PaginationParams { + limit: Some(1), + ..Default::default() + }, + // TODO: Remove [0] once email is a single value. + email: Some(&email[0]), + ..Default::default() + }) + .await + .expect("TODO: handle error"); + + info!("{users:#?}"); + + if users.data.is_empty() { + // TODO: Redirect to sign up action. + } else { + // TODO: Redirect to sign in action. + } + } + IndexData::Oauth { oauth_provider } => { + info!("oauth {oauth_provider:#?}"); + + // TODO: Add client ID to method. + // self.client + // .user_management() + // .get_authorization_url(&GetAuthorizationUrlParams { + // client_id: todo!(), + // redirect_uri: todo!(), + // connection_selector: todo!(), + // state: todo!(), + // code_challenge: todo!(), + // login_hint: todo!(), + // domain_hint: todo!(), + // }) + } + IndexData::Sso { connection_id } => { + info!("sso {connection_id:#?}"); + } + } Ok(Response::Default) } diff --git a/packages/methods/shield-workos/src/actions/sign_in.rs b/packages/methods/shield-workos/src/actions/sign_in.rs index e5692a2..df6655b 100644 --- a/packages/methods/shield-workos/src/actions/sign_in.rs +++ b/packages/methods/shield-workos/src/actions/sign_in.rs @@ -31,11 +31,11 @@ impl Action for WorkosSignInAction { SignInAction::name() } - fn forms(&self, _provider: WorkosProvider) -> Vec { + async fn forms(&self, _provider: WorkosProvider) -> Result, ShieldError> { // TODO: Magic auth and SSO buttons. // TODO: Prefill email address. - vec![ + Ok(vec![ Form { inputs: vec![ Input { @@ -87,7 +87,7 @@ impl Action for WorkosSignInAction { }, ], }, - ] + ]) } async fn call( diff --git a/packages/methods/shield-workos/src/actions/sign_out.rs b/packages/methods/shield-workos/src/actions/sign_out.rs index 42abe06..89ee70d 100644 --- a/packages/methods/shield-workos/src/actions/sign_out.rs +++ b/packages/methods/shield-workos/src/actions/sign_out.rs @@ -32,8 +32,8 @@ impl Action for WorkosSignOutAction { SignOutAction::condition(provider, session) } - fn forms(&self, provider: WorkosProvider) -> Vec { - SignOutAction::forms(provider) + async fn forms(&self, provider: WorkosProvider) -> Result, ShieldError> { + SignOutAction::forms(provider).await } async fn call( diff --git a/packages/methods/shield-workos/src/actions/sign_up.rs b/packages/methods/shield-workos/src/actions/sign_up.rs index fe709b7..afbb617 100644 --- a/packages/methods/shield-workos/src/actions/sign_up.rs +++ b/packages/methods/shield-workos/src/actions/sign_up.rs @@ -31,11 +31,11 @@ impl Action for WorkosSignUpAction { SignUpAction::name() } - fn forms(&self, _provider: WorkosProvider) -> Vec { + async fn forms(&self, _provider: WorkosProvider) -> Result, ShieldError> { // TODO: Magic auth and SSO buttons. // TODO: Prefill email address. - vec![ + Ok(vec![ Form { inputs: vec![ Input { @@ -87,7 +87,7 @@ impl Action for WorkosSignUpAction { }, ], }, - ] + ]) } async fn call( diff --git a/packages/methods/shield-workos/src/lib.rs b/packages/methods/shield-workos/src/lib.rs index 011cc55..aa7a206 100644 --- a/packages/methods/shield-workos/src/lib.rs +++ b/packages/methods/shield-workos/src/lib.rs @@ -6,4 +6,10 @@ mod provider; pub use method::*; pub use options::*; +#[doc(no_inline)] +pub use workos_sdk::{ + ApiKey as WorkosApiKey, WorkOs as Workos, WorkOsBuilder as WorkosBuilder, + user_management::OauthProvider as WorkosOauthProvider, +}; + // TODO: Support both AuthKit method and self hosted method. diff --git a/packages/methods/shield-workos/src/method.rs b/packages/methods/shield-workos/src/method.rs index e8c66eb..bc7b4d3 100644 --- a/packages/methods/shield-workos/src/method.rs +++ b/packages/methods/shield-workos/src/method.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use async_trait::async_trait; use shield::{Action, Method, ShieldError, erased_method}; -use workos_sdk::WorkOs; +use workos_sdk::{ApiKey, WorkOs}; use crate::{ actions::{WorkosIndexAction, WorkosSignInAction, WorkosSignOutAction, WorkosSignUpAction}, @@ -25,6 +25,10 @@ impl WorkosMethod { } } + pub fn from_api_key(api_key: &str) -> Self { + Self::new(WorkOs::new(&ApiKey::from(api_key))) + } + pub fn with_options(mut self, options: WorkosOptions) -> Self { self.options = options; self @@ -39,7 +43,10 @@ impl Method for WorkosMethod { fn actions(&self) -> Vec>> { vec![ - Box::new(WorkosIndexAction::new(self.client.clone())), + Box::new(WorkosIndexAction::new( + self.options.clone(), + self.client.clone(), + )), Box::new(WorkosSignInAction::new(self.client.clone())), Box::new(WorkosSignUpAction::new(self.client.clone())), Box::new(WorkosSignOutAction::new(self.client.clone())), diff --git a/packages/methods/shield-workos/src/options.rs b/packages/methods/shield-workos/src/options.rs index 43fd6a2..cb4c647 100644 --- a/packages/methods/shield-workos/src/options.rs +++ b/packages/methods/shield-workos/src/options.rs @@ -1,8 +1,12 @@ use bon::Builder; +use workos_sdk::user_management::OauthProvider; #[derive(Builder, Clone, Debug)] #[builder(on(String, into), state_mod(vis = "pub(crate)"))] -pub struct WorkosOptions {} +pub struct WorkosOptions { + #[builder(default)] + pub(crate) oauth_providers: Vec, +} impl Default for WorkosOptions { fn default() -> Self {