diff --git a/example/src/email_address.rs b/example/src/email_address.rs new file mode 100644 index 0000000..9c2a8fc --- /dev/null +++ b/example/src/email_address.rs @@ -0,0 +1,18 @@ +use fortifier::Validate; + +#[derive(Validate)] +pub enum ChangeEmailAddressRelation { + Create { + #[validate(email)] + email_address: String, + }, + Update { + id: String, + + #[validate(email)] + email_address: String, + }, + Delete { + id: String, + }, +} diff --git a/example/src/main.rs b/example/src/main.rs index 707dddd..c8bcb49 100644 --- a/example/src/main.rs +++ b/example/src/main.rs @@ -1,40 +1,11 @@ -use std::{error::Error, sync::LazyLock}; +mod email_address; +mod user; -use fortifier::Validate; -use regex::Regex; - -static COUNTRY_CODE_REGEX: LazyLock = - LazyLock::new(|| Regex::new(r"[A-Z]{2}").expect("Regex should be valid.")); - -#[derive(Validate)] -struct CreateUser { - #[validate(email)] - email: String, - - #[validate(length(min = 1, max = 256))] - name: String, - - #[validate(url)] - url: String, - - #[validate(regex = &COUNTRY_CODE_REGEX)] - country_code: String, +use std::error::Error; - #[validate(custom(function = validate_one_locale_required, error = OneLocaleRequiredError))] - #[validate(length(min = 1))] - locales: Vec, -} +use fortifier::Validate; -#[derive(Debug)] -struct OneLocaleRequiredError; - -fn validate_one_locale_required(locales: &[String]) -> Result<(), OneLocaleRequiredError> { - if locales.is_empty() { - Err(OneLocaleRequiredError) - } else { - Ok(()) - } -} +use crate::{email_address::ChangeEmailAddressRelation, user::CreateUser}; #[tokio::main] async fn main() -> Result<(), Box> { @@ -48,5 +19,22 @@ async fn main() -> Result<(), Box> { data.validate().await?; + let data = ChangeEmailAddressRelation::Create { + email_address: "john@doe.com".to_owned(), + }; + + data.validate().await?; + + let data = ChangeEmailAddressRelation::Update { + id: "1".to_owned(), + email_address: "john@doe.com".to_owned(), + }; + + data.validate().await?; + + let data = ChangeEmailAddressRelation::Delete { id: "1".to_owned() }; + + data.validate().await?; + Ok(()) } diff --git a/example/src/user.rs b/example/src/user.rs new file mode 100644 index 0000000..40a0f6e --- /dev/null +++ b/example/src/user.rs @@ -0,0 +1,37 @@ +use std::sync::LazyLock; + +use fortifier::Validate; +use regex::Regex; + +static COUNTRY_CODE_REGEX: LazyLock = + LazyLock::new(|| Regex::new(r"[A-Z]{2}").expect("Regex should be valid.")); + +#[derive(Validate)] +pub struct CreateUser { + #[validate(email)] + pub email: String, + + #[validate(length(min = 1, max = 256))] + pub name: String, + + #[validate(url)] + pub url: String, + + #[validate(regex = &COUNTRY_CODE_REGEX)] + pub country_code: String, + + #[validate(custom(function = validate_one_locale_required, error = OneLocaleRequiredError))] + #[validate(length(min = 1))] + pub locales: Vec, +} + +#[derive(Debug)] +pub struct OneLocaleRequiredError; + +fn validate_one_locale_required(locales: &[String]) -> Result<(), OneLocaleRequiredError> { + if locales.is_empty() { + Err(OneLocaleRequiredError) + } else { + Ok(()) + } +} diff --git a/packages/fortifier-macros/src/validate.rs b/packages/fortifier-macros/src/validate.rs index 5272580..3ea26fd 100644 --- a/packages/fortifier-macros/src/validate.rs +++ b/packages/fortifier-macros/src/validate.rs @@ -1,5 +1,6 @@ mod r#enum; mod field; +mod fields; mod r#struct; mod r#union; diff --git a/packages/fortifier-macros/src/validate/enum.rs b/packages/fortifier-macros/src/validate/enum.rs index 3d549f5..aeb6731 100644 --- a/packages/fortifier-macros/src/validate/enum.rs +++ b/packages/fortifier-macros/src/validate/enum.rs @@ -1,17 +1,276 @@ +use std::{collections::HashSet, str::FromStr}; + use proc_macro2::TokenStream; -use quote::ToTokens; -use syn::{DataEnum, DeriveInput, Result}; +use quote::{ToTokens, TokenStreamExt, format_ident, quote}; +use syn::{DataEnum, DeriveInput, Generics, Ident, Result, Variant, Visibility}; + +use crate::validate::{ + field::{LiteralOrIdent, ValidateFieldPrefix}, + fields::ValidateFields, +}; -pub struct ValidateEnum {} +pub struct ValidateEnum { + visibility: Visibility, + ident: Ident, + error_ident: Ident, + generics: Generics, + variants: Vec, +} impl ValidateEnum { - pub fn parse(_input: &DeriveInput, _data: &DataEnum) -> Result { - todo!("enum") + pub fn parse(input: &DeriveInput, data: &DataEnum) -> Result { + let mut result = ValidateEnum { + visibility: input.vis.clone(), + ident: input.ident.clone(), + error_ident: format_ident!("{}ValidationError", input.ident), + generics: input.generics.clone(), + variants: Vec::with_capacity(data.variants.len()), + }; + + for variant in &data.variants { + result.variants.push(ValidateEnumVariant::parse( + &input.vis, + &result.ident, + &result.error_ident, + variant, + )?); + } + + Ok(result) + } + + fn uses(&self) -> HashSet { + self.variants + .iter() + .flat_map(|variant| variant.uses()) + .collect() + } + + fn error_type(&self) -> (Ident, TokenStream) { + let visibility = &self.visibility; + let error_ident = &self.error_ident; + + let error_variant_idents = self + .variants + .iter() + .map(|variant| &variant.ident) + .collect::>(); + let error_variant_types = self + .variants + .iter() + .map(|variant| variant.error_type().0) + .collect::>(); + + ( + error_ident.clone(), + quote! { + #[allow(dead_code)] + #[derive(Debug)] + #visibility enum #error_ident { + #( #error_variant_idents(#error_variant_types) ),* + } + + #[automatically_derived] + impl ::std::fmt::Display for #error_ident { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + write!(f, "{self:#?}") + } + } + + #[automatically_derived] + impl ::std::error::Error for #error_ident {} + }, + ) } } impl ToTokens for ValidateEnum { - fn to_tokens(&self, _tokens: &mut TokenStream) { - // TODO + fn to_tokens(&self, tokens: &mut TokenStream) { + let ident = &self.ident; + let (impl_generics, type_generics, where_clause) = &self.generics.split_for_impl(); + + let uses = self.uses().into_iter().map(|r#use| { + let tokens = TokenStream::from_str(&r#use).expect("Tokens should be valid."); + quote!(use #tokens;) + }); + let (error_ident, error_type) = self.error_type(); + let variant_error_types = self.variants.iter().map(|variant| variant.error_type().1); + let sync_variant_match_arms = self + .variants + .iter() + .map(|variant| variant.sync_match_arm_tokens()); + let async_variant_match_arms = self + .variants + .iter() + .map(|variant| variant.async_match_arm_tokens()); + + tokens.append_all(quote! { + #( #uses )* + + // TODO: Replace with granular uses. + use fortifier::*; + + #error_type + + #( #variant_error_types )* + + #[automatically_derived] + impl #impl_generics Validate for #ident #type_generics #where_clause { + type Error = #error_ident; + + fn validate_sync(&self) -> Result<(), ValidationErrors> { + match &self { + #( #sync_variant_match_arms ),* + } + } + + fn validate_async(&self) -> ::std::pin::Pin>>>> { + Box::pin(async move { + match &self { + #( #async_variant_match_arms ),* + } + }) + } + } + }) + } +} + +pub struct ValidateEnumVariant { + enum_ident: Ident, + enum_error_ident: Ident, + ident: Ident, + fields: ValidateFields, +} + +impl ValidateEnumVariant { + pub fn parse( + visibility: &Visibility, + enum_ident: &Ident, + enum_error_ident: &Ident, + variant: &Variant, + ) -> Result { + let result = ValidateEnumVariant { + enum_ident: enum_ident.clone(), + enum_error_ident: enum_error_ident.clone(), + ident: variant.ident.clone(), + fields: ValidateFields::parse( + visibility, + &format_ident!("{}{}", enum_ident, variant.ident), + &variant.fields, + )?, + }; + + Ok(result) + } + + fn uses(&self) -> HashSet { + self.fields.uses() + } + + fn error_type(&self) -> (Ident, TokenStream) { + self.fields.error_type() + } + + fn sync_match_arm_tokens(&self) -> TokenStream { + let enum_ident = &self.enum_ident; + let enum_error_ident = &self.enum_error_ident; + let ident = &self.ident; + + let error_wrapper = |tokens| quote!(#enum_error_ident::#ident(#tokens)); + + match &self.fields { + ValidateFields::Named(fields) => { + let field_idents = fields.idents(); + let sync_validations = + fields.sync_validations(ValidateFieldPrefix::None, &error_wrapper); + + // TODO: Only destructure fields required for validation. + quote! { + #[allow(unused_variables)] + #enum_ident::#ident { + #( #field_idents ),* + } => { + #sync_validations + } + } + } + ValidateFields::Unnamed(fields) => { + let field_idents = fields.idents().map(|ident| match ident { + LiteralOrIdent::Literal(literal) => format_ident!("f{literal}"), + LiteralOrIdent::Ident(ident) => ident.clone(), + }); + let sync_validations = + fields.sync_validations(ValidateFieldPrefix::F, &error_wrapper); + + quote! { + #enum_ident::#ident( + #( #field_idents ),* + ) => { + #sync_validations + } + } + } + ValidateFields::Unit(fields) => { + let sync_validations = fields.sync_validations(); + + quote! { + #enum_ident::#ident => { + #sync_validations + } + } + } + } + } + + fn async_match_arm_tokens(&self) -> TokenStream { + let enum_ident = &self.enum_ident; + let enum_error_ident = &self.enum_error_ident; + let ident = &self.ident; + + let error_wrapper = |tokens| quote!(#enum_error_ident::#ident(#tokens)); + + match &self.fields { + ValidateFields::Named(fields) => { + let field_idents = fields.idents(); + let async_validations = + fields.async_validations(ValidateFieldPrefix::None, &error_wrapper); + + // TODO: Only destructure fields required for validation. + quote! { + #[allow(unused_variables)] + #enum_ident::#ident { + #( #field_idents ),* + } => { + #async_validations + } + } + } + ValidateFields::Unnamed(fields) => { + let field_idents = fields.idents().map(|ident| match ident { + LiteralOrIdent::Literal(literal) => format_ident!("f{literal}"), + LiteralOrIdent::Ident(ident) => ident.clone(), + }); + let async_validations = + fields.async_validations(ValidateFieldPrefix::F, &error_wrapper); + + quote! { + #enum_ident::#ident( + #( #field_idents ),* + ) => { + #async_validations + } + } + } + ValidateFields::Unit(fields) => { + let async_validations = fields.async_validations(); + + quote! { + #enum_ident::#ident => { + #async_validations + } + } + } + } } } diff --git a/packages/fortifier-macros/src/validate/field.rs b/packages/fortifier-macros/src/validate/field.rs index 7f91731..a4ec772 100644 --- a/packages/fortifier-macros/src/validate/field.rs +++ b/packages/fortifier-macros/src/validate/field.rs @@ -1,32 +1,62 @@ use convert_case::{Case, Casing}; -use proc_macro2::TokenStream; +use proc_macro2::{Literal, TokenStream}; use quote::{ToTokens, format_ident, quote}; -use syn::{Field, Ident, Result}; +use syn::{Field, Ident, Result, Visibility}; use crate::{ validation::Validation, validations::{Custom, Email, Length, Regex, Url}, }; +pub enum LiteralOrIdent { + Literal(Literal), + Ident(Ident), +} + +impl ToTokens for LiteralOrIdent { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + LiteralOrIdent::Literal(literal) => literal.to_tokens(tokens), + LiteralOrIdent::Ident(ident) => ident.to_tokens(tokens), + } + } +} + +#[derive(Clone, Copy)] +pub enum ValidateFieldPrefix { + None, + SelfKeyword, + F, +} + pub struct ValidateField { + visibility: Visibility, + ident: LiteralOrIdent, + error_ident: Ident, error_type_ident: Ident, - expr: TokenStream, validations: Vec>, } impl ValidateField { pub fn parse( + visibility: &Visibility, type_prefix: &Ident, - ident: Ident, - expr: TokenStream, + ident: LiteralOrIdent, field: &Field, ) -> Result { - let error_ident = format_ident!("{}", ident.to_string().to_case(Case::UpperCamel)); + let error_ident = match &ident { + LiteralOrIdent::Literal(literal) => format_ident!("F{literal}"), + LiteralOrIdent::Ident(ident) => { + format_ident!("{}", ident.to_string().to_case(Case::UpperCamel)) + } + }; let error_type_ident = format_ident!("{type_prefix}{error_ident}ValidationError"); let mut result = Self { + visibility: visibility.clone(), + ident, + error_ident, error_type_ident, - expr, validations: vec![], }; @@ -63,13 +93,18 @@ impl ValidateField { Ok(result) } - pub fn error_type( - &self, - ident: &Ident, - field_error_ident: &Ident, - ) -> (TokenStream, Option) { + pub fn ident(&self) -> &LiteralOrIdent { + &self.ident + } + + pub fn error_ident(&self) -> &Ident { + &self.error_ident + } + + pub fn error_type(&self, ident: &Ident) -> (TokenStream, Option) { if self.validations.len() > 1 { - let ident = format_ident!("{}{}ValidationError", ident, field_error_ident); + let visibility = &self.visibility; + let ident = format_ident!("{}{}ValidationError", ident, self.error_ident); let variant_ident = self.validations.iter().map(|validation| validation.ident()); let variant_type = self .validations @@ -80,7 +115,7 @@ impl ValidateField { ident.to_token_stream(), Some(quote! { #[derive(Debug)] - enum #ident { + #visibility enum #ident { #( #variant_ident(#variant_type) ),* } }), @@ -92,15 +127,25 @@ impl ValidateField { } } - pub fn sync_validations(&self) -> Vec { + pub fn sync_validations(&self, field_prefix: ValidateFieldPrefix) -> Vec { let error_type_ident = &self.error_type_ident; + let ident = &self.ident; self.validations .iter() .filter(|validation| !validation.is_async()) .map(|validation| { let validation_ident = validation.ident(); - let tokens = validation.tokens(&self.expr); + let tokens = validation.tokens(&match field_prefix { + ValidateFieldPrefix::None => self.ident.to_token_stream(), + ValidateFieldPrefix::SelfKeyword => quote!(self.#ident), + ValidateFieldPrefix::F => match &self.ident { + LiteralOrIdent::Literal(literal) => { + format_ident!("f{literal}").to_token_stream() + } + LiteralOrIdent::Ident(ident) => ident.to_token_stream(), + }, + }); if self.validations.len() > 1 { quote! { @@ -113,7 +158,8 @@ impl ValidateField { .collect() } - pub fn async_validations(&self) -> Vec { + pub fn async_validations(&self, field_prefix: ValidateFieldPrefix) -> Vec { + let ident = &self.ident; let error_type_ident = &self.error_type_ident; self.validations @@ -121,7 +167,16 @@ impl ValidateField { .filter(|validation| validation.is_async()) .map(|validation| { let validation_ident = validation.ident(); - let tokens = validation.tokens(&self.expr); + let tokens = validation.tokens(&match field_prefix { + ValidateFieldPrefix::None => self.ident.to_token_stream(), + ValidateFieldPrefix::SelfKeyword => quote!(self.#ident), + ValidateFieldPrefix::F => match &self.ident { + LiteralOrIdent::Literal(literal) => { + format_ident!("f{literal}").to_token_stream() + } + LiteralOrIdent::Ident(ident) => ident.to_token_stream(), + }, + }); if self.validations.len() > 1 { quote! { diff --git a/packages/fortifier-macros/src/validate/fields.rs b/packages/fortifier-macros/src/validate/fields.rs new file mode 100644 index 0000000..217eb43 --- /dev/null +++ b/packages/fortifier-macros/src/validate/fields.rs @@ -0,0 +1,364 @@ +use std::collections::HashSet; + +use proc_macro2::{Literal, TokenStream}; +use quote::{format_ident, quote}; +use syn::{Fields, FieldsNamed, FieldsUnnamed, Ident, Result, Visibility}; + +use crate::validate::field::{LiteralOrIdent, ValidateField, ValidateFieldPrefix}; + +pub enum ValidateFields { + Named(ValidateNamedFields), + Unnamed(ValidateUnnamedFields), + Unit(ValidateUnitFields), +} + +impl ValidateFields { + pub fn parse(visibility: &Visibility, ident: &Ident, fields: &Fields) -> Result { + Ok(match fields { + Fields::Named(fields) => { + Self::Named(ValidateNamedFields::parse(visibility, ident, fields)?) + } + Fields::Unnamed(fields) => { + Self::Unnamed(ValidateUnnamedFields::parse(visibility, ident, fields)?) + } + Fields::Unit => Self::Unit(ValidateUnitFields::parse()?), + }) + } + + pub fn uses(&self) -> HashSet { + match self { + ValidateFields::Named(named) => named.uses(), + ValidateFields::Unnamed(unnamed) => unnamed.uses(), + ValidateFields::Unit(unit) => unit.uses(), + } + } + + pub fn error_type(&self) -> (Ident, TokenStream) { + match self { + ValidateFields::Named(named) => named.error_type(), + ValidateFields::Unnamed(unnamed) => unnamed.error_type(), + ValidateFields::Unit(unit) => unit.error_type(), + } + } + + pub fn sync_validations( + &self, + field_prefix: ValidateFieldPrefix, + error_wrapper: &impl Fn(TokenStream) -> TokenStream, + ) -> TokenStream { + match self { + ValidateFields::Named(named) => named.sync_validations(field_prefix, error_wrapper), + ValidateFields::Unnamed(unnamed) => { + unnamed.sync_validations(field_prefix, error_wrapper) + } + ValidateFields::Unit(unit) => unit.sync_validations(), + } + } + + pub fn async_validations( + &self, + field_prefix: ValidateFieldPrefix, + error_wrapper: &impl Fn(TokenStream) -> TokenStream, + ) -> TokenStream { + match self { + ValidateFields::Named(named) => named.async_validations(field_prefix, error_wrapper), + ValidateFields::Unnamed(unnamed) => { + unnamed.async_validations(field_prefix, error_wrapper) + } + ValidateFields::Unit(unit) => unit.async_validations(), + } + } +} + +pub struct ValidateNamedFields { + visibility: Visibility, + ident: Ident, + error_ident: Ident, + fields: Vec, +} + +impl ValidateNamedFields { + fn parse(visibility: &Visibility, ident: &Ident, fields: &FieldsNamed) -> Result { + let mut result = Self { + visibility: visibility.clone(), + ident: ident.clone(), + error_ident: format_ident!("{}ValidationError", ident), + fields: Vec::with_capacity(fields.named.len()), + }; + + for field in &fields.named { + let Some(field_ident) = &field.ident else { + continue; + }; + + result.fields.push(ValidateField::parse( + visibility, + ident, + LiteralOrIdent::Ident(field_ident.clone()), + field, + )?); + } + + Ok(result) + } + + pub fn idents(&self) -> impl Iterator { + self.fields.iter().map(|field| field.ident()) + } + + fn uses(&self) -> HashSet { + HashSet::default() + } + + fn error_type(&self) -> (Ident, TokenStream) { + let visibility = &self.visibility; + let ident = &self.ident; + let error_ident = &self.error_ident; + + let mut error_field_idents = vec![]; + let mut error_field_types = vec![]; + let mut error_field_enums = vec![]; + + for field in &self.fields { + let field_error_ident = field.error_ident(); + let (field_error_type, field_error_enum) = field.error_type(ident); + + error_field_idents.push(field_error_ident.clone()); + error_field_types.push(field_error_type); + if let Some(error_enum) = field_error_enum { + error_field_enums.push(error_enum); + } + } + + ( + error_ident.clone(), + quote! { + #[allow(dead_code)] + #[derive(Debug)] + #visibility enum #error_ident { + #( #error_field_idents(#error_field_types) ),* + } + + #[automatically_derived] + impl ::std::fmt::Display for #error_ident { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + write!(f, "{self:#?}") + } + } + + #[automatically_derived] + impl ::std::error::Error for #error_ident {} + + #( #error_field_enums )* + }, + ) + } + + pub fn sync_validations( + &self, + field_prefix: ValidateFieldPrefix, + error_wrapper: &impl Fn(TokenStream) -> TokenStream, + ) -> TokenStream { + validations( + &self.error_ident, + error_wrapper, + self.fields + .iter() + .map(|field| (field, field.sync_validations(field_prefix))), + ) + } + + pub fn async_validations( + &self, + field_prefix: ValidateFieldPrefix, + error_wrapper: &impl Fn(TokenStream) -> TokenStream, + ) -> TokenStream { + validations( + &self.error_ident, + error_wrapper, + self.fields + .iter() + .map(|field| (field, field.async_validations(field_prefix))), + ) + } +} + +pub struct ValidateUnnamedFields { + visibility: Visibility, + ident: Ident, + error_ident: Ident, + fields: Vec, +} + +impl ValidateUnnamedFields { + fn parse(visibility: &Visibility, ident: &Ident, fields: &FieldsUnnamed) -> Result { + let mut result = Self { + visibility: visibility.clone(), + ident: ident.clone(), + error_ident: format_ident!("{}ValidationError", ident), + fields: Vec::with_capacity(fields.unnamed.len()), + }; + + for (index, field) in fields.unnamed.iter().enumerate() { + result.fields.push(ValidateField::parse( + visibility, + ident, + LiteralOrIdent::Literal(Literal::usize_unsuffixed(index)), + field, + )?); + } + + Ok(result) + } + + pub fn idents(&self) -> impl Iterator { + self.fields.iter().map(|field| field.ident()) + } + + fn uses(&self) -> HashSet { + HashSet::default() + } + + fn error_type(&self) -> (Ident, TokenStream) { + let visibility = &self.visibility; + let ident = &self.ident; + let error_ident = &self.error_ident; + + let mut error_field_idents = vec![]; + let mut error_field_types = vec![]; + let mut error_field_enums = vec![]; + + for field in &self.fields { + let field_error_ident = field.error_ident(); + let (field_error_type, field_error_enum) = field.error_type(ident); + + error_field_idents.push(field_error_ident.clone()); + error_field_types.push(field_error_type); + if let Some(error_enum) = field_error_enum { + error_field_enums.push(error_enum); + } + } + + ( + error_ident.clone(), + quote! { + #[allow(dead_code)] + #[derive(Debug)] + #visibility enum #error_ident { + #( #error_field_idents(#error_field_types) ),* + } + + #[automatically_derived] + impl ::std::fmt::Display for #error_ident { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + write!(f, "{self:#?}") + } + } + + #[automatically_derived] + impl ::std::error::Error for #error_ident {} + + #( #error_field_enums )* + }, + ) + } + + pub fn sync_validations( + &self, + field_prefix: ValidateFieldPrefix, + error_wrapper: &impl Fn(TokenStream) -> TokenStream, + ) -> TokenStream { + validations( + &self.error_ident, + error_wrapper, + self.fields + .iter() + .map(|field| (field, field.sync_validations(field_prefix))), + ) + } + + pub fn async_validations( + &self, + field_prefix: ValidateFieldPrefix, + error_wrapper: &impl Fn(TokenStream) -> TokenStream, + ) -> TokenStream { + validations( + &self.error_ident, + error_wrapper, + self.fields + .iter() + .map(|field| (field, field.async_validations(field_prefix))), + ) + } +} + +pub struct ValidateUnitFields {} + +impl ValidateUnitFields { + fn parse() -> Result { + Ok(Self {}) + } + + fn uses(&self) -> HashSet { + HashSet::from(["std::convert::Infallible".to_owned()]) + } + + fn error_type(&self) -> (Ident, TokenStream) { + (format_ident!("Infallible"), TokenStream::new()) + } + + pub fn sync_validations(&self) -> TokenStream { + quote! { + Ok(()) + } + } + + pub fn async_validations(&self) -> TokenStream { + quote! { + Ok(()) + } + } +} + +fn validations<'a>( + error_ident: &Ident, + error_wrapper: &impl Fn(TokenStream) -> TokenStream, + iterator: impl Iterator)>, +) -> TokenStream { + let validations = iterator + .flat_map(|(field, validations)| { + let field_error_ident = field.error_ident(); + + validations + .iter() + .map(|validation| { + let error = error_wrapper(quote!(#error_ident::#field_error_ident(err))); + + quote! { + if let Err(err) = #validation { + errors.push(#error); + } + } + }) + .collect::>() + }) + .collect::>(); + + if validations.is_empty() { + quote! { + Ok(()) + } + } else { + quote! { + let mut errors = vec![]; + + #(#validations)* + + if !errors.is_empty() { + Err(errors.into()) + } else { + Ok(()) + } + } + } +} diff --git a/packages/fortifier-macros/src/validate/struct.rs b/packages/fortifier-macros/src/validate/struct.rs index c828954..5c7c58e 100644 --- a/packages/fortifier-macros/src/validate/struct.rs +++ b/packages/fortifier-macros/src/validate/struct.rs @@ -1,333 +1,68 @@ -use std::collections::HashMap; +use std::str::FromStr; -use convert_case::{Case, Casing}; -use proc_macro2::{Literal, TokenStream}; -use quote::{ToTokens, TokenStreamExt, format_ident, quote}; -use syn::{DataStruct, DeriveInput, Fields, FieldsNamed, FieldsUnnamed, Generics, Ident, Result}; +use proc_macro2::TokenStream; +use quote::{ToTokens, TokenStreamExt, quote}; +use syn::{DataStruct, DeriveInput, Generics, Ident, Result}; -use crate::validate::field::ValidateField; +use crate::validate::{field::ValidateFieldPrefix, fields::ValidateFields}; -pub enum ValidateStruct { - Named(ValidateNamedStruct), - Unnamed(ValidateUnnamedStruct), - Unit(ValidateUnitStruct), -} - -impl ValidateStruct { - pub fn parse(input: &DeriveInput, data: &DataStruct) -> Result { - Ok(match &data.fields { - Fields::Named(fields) => Self::Named(ValidateNamedStruct::parse(input, data, fields)?), - Fields::Unnamed(fields) => { - Self::Unnamed(ValidateUnnamedStruct::parse(input, data, fields)?) - } - Fields::Unit => Self::Unit(ValidateUnitStruct::parse(input)?), - }) - } -} - -impl ToTokens for ValidateStruct { - fn to_tokens(&self, tokens: &mut TokenStream) { - match self { - ValidateStruct::Named(named) => named.to_tokens(tokens), - ValidateStruct::Unnamed(unnamed) => unnamed.to_tokens(tokens), - ValidateStruct::Unit(unit) => unit.to_tokens(tokens), - } - } -} - -pub struct ValidateNamedStruct { +pub struct ValidateStruct { ident: Ident, - error_ident: Ident, generics: Generics, - fields: HashMap, + fields: ValidateFields, } -impl ValidateNamedStruct { - fn parse(input: &DeriveInput, _data: &DataStruct, fields: &FieldsNamed) -> Result { - let mut result = Self { +impl ValidateStruct { + pub fn parse(input: &DeriveInput, data: &DataStruct) -> Result { + Ok(ValidateStruct { ident: input.ident.clone(), - error_ident: format_ident!("{}ValidationError", input.ident), generics: input.generics.clone(), - fields: HashMap::default(), - }; - - for field in &fields.named { - let Some(field_ident) = &field.ident else { - continue; - }; - - let expr = quote!(self.#field_ident); - - result.fields.insert( - field_ident.clone(), - ValidateField::parse(&input.ident, field_ident.clone(), expr, field)?, - ); - } - - Ok(result) - } -} - -impl ToTokens for ValidateNamedStruct { - fn to_tokens(&self, tokens: &mut TokenStream) { - let ident = &self.ident; - let error_ident = &self.error_ident; - let (impl_generics, type_generics, where_clause) = &self.generics.split_for_impl(); - - let mut error_field_idents = vec![]; - let mut error_field_types = vec![]; - let mut error_field_enums = vec![]; - let mut sync_validations = vec![]; - let mut async_validations = vec![]; - - for (field_ident, field) in &self.fields { - let field_error_ident = - format_ident!("{}", &field_ident.to_string().to_case(Case::UpperCamel)); - - let (error_type, error_enum) = field.error_type(ident, &field_error_ident); - - error_field_idents.push(field_error_ident.clone()); - error_field_types.push(error_type); - if let Some(error_enum) = error_enum { - error_field_enums.push(error_enum); - } - - for validation in field.sync_validations() { - sync_validations.push(quote! { - if let Err(err) = #validation { - errors.push(#error_ident::#field_error_ident(err)); - } - }); - } - - for validation in field.async_validations() { - async_validations.push(quote! { - if let Err(err) = #validation { - errors.push(#error_ident::#field_error_ident(err)); - } - }); - } - } - - tokens.append_all(quote! { - use fortifier::*; - - #[allow(dead_code)] - #[derive(Debug)] - enum #error_ident { - #( #error_field_idents(#error_field_types) ),* - } - - #[automatically_derived] - impl ::std::fmt::Display for #error_ident { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - write!(f, "{self:#?}") - } - } - - #[automatically_derived] - impl ::std::error::Error for #error_ident {} - - #(#error_field_enums)* - - #[automatically_derived] - impl #impl_generics Validate for #ident #type_generics #where_clause { - type Error = #error_ident; - - fn validate_sync(&self) -> Result<(), ValidationErrors> { - let mut errors = vec![]; - - #(#sync_validations)* - - if !errors.is_empty() { - Err(errors.into()) - } else { - Ok(()) - } - } - - fn validate_async(&self) -> ::std::pin::Pin>>>> { - Box::pin(async { - let mut errors = vec![]; - - #(#async_validations)* - - if !errors.is_empty() { - Err(errors.into()) - } else { - Ok(()) - } - }) - } - } + fields: ValidateFields::parse(&input.vis, &input.ident, &data.fields)?, }) } } -pub struct ValidateUnnamedStruct { - ident: Ident, - error_ident: Ident, - generics: Generics, - fields: Vec, -} - -impl ValidateUnnamedStruct { - fn parse(input: &DeriveInput, _data: &DataStruct, fields: &FieldsUnnamed) -> Result { - let mut result = Self { - ident: input.ident.clone(), - error_ident: format_ident!("{}ValidationError", input.ident), - generics: input.generics.clone(), - fields: Vec::default(), - }; - - for (index, field) in fields.unnamed.iter().enumerate() { - let index = Literal::usize_unsuffixed(index); - let field_ident = format_ident!("F{index}"); - let expr = quote!(self.#index); - - result.fields.push(ValidateField::parse( - &input.ident, - field_ident, - expr, - field, - )?); - } - - Ok(result) - } -} - -impl ToTokens for ValidateUnnamedStruct { +impl ToTokens for ValidateStruct { fn to_tokens(&self, tokens: &mut TokenStream) { let ident = &self.ident; - let error_ident = &self.error_ident; - let (impl_generics, type_generics, where_clause) = &self.generics.split_for_impl(); - - let mut error_field_idents = vec![]; - let mut error_field_types = vec![]; - let mut error_field_enums = vec![]; - let mut sync_validations = vec![]; - let mut async_validations = vec![]; - - for (index, field) in self.fields.iter().enumerate() { - let field_error_ident = format_ident!("F{index}"); - - let (error_type, error_enum) = field.error_type(ident, &field_error_ident); - - error_field_idents.push(field_error_ident.clone()); - error_field_types.push(error_type); - if let Some(error_enum) = error_enum { - error_field_enums.push(error_enum); - } + let (impl_generics, type_generics, where_clause) = self.generics.split_for_impl(); - for validation in field.sync_validations() { - sync_validations.push(quote! { - if let Err(err) = #validation { - errors.push(#error_ident::#field_error_ident(err)); - } - }); - } + let error_wrapper = |tokens| tokens; - for validation in field.async_validations() { - async_validations.push(quote! { - if let Err(err) = #validation { - errors.push(#error_ident::#field_error_ident(err)); - } - }); - } - } + let uses = self.fields.uses().into_iter().map(|r#use| { + let tokens = TokenStream::from_str(&r#use).expect("Tokens should be valid."); + quote!(use #tokens;) + }); + let (error_ident, error_type) = self.fields.error_type(); + let sync_validations = self + .fields + .sync_validations(ValidateFieldPrefix::SelfKeyword, &error_wrapper); + let async_validations = self + .fields + .async_validations(ValidateFieldPrefix::SelfKeyword, &error_wrapper); tokens.append_all(quote! { - use fortifier::*; - - #[allow(dead_code)] - #[derive(Debug)] - enum #error_ident { - #( #error_field_idents(#error_field_types) ),* - } - - #[automatically_derived] - impl ::std::fmt::Display for #error_ident { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - write!(f, "{self:#?}") - } - } + #( #uses )* - #[automatically_derived] - impl ::std::error::Error for #error_ident {} + // TODO: Replace with granular uses. + use fortifier::*; - #(#error_field_enums)* + #error_type #[automatically_derived] impl #impl_generics Validate for #ident #type_generics #where_clause { type Error = #error_ident; fn validate_sync(&self) -> Result<(), ValidationErrors> { - let mut errors = vec![]; - - #(#sync_validations)* - - if !errors.is_empty() { - Err(errors.into()) - } else { - Ok(()) - } + #sync_validations } fn validate_async(&self) -> ::std::pin::Pin>>>> { Box::pin(async { - let mut errors = vec![]; - - #(#async_validations)* - - if !errors.is_empty() { - Err(errors.into()) - } else { - Ok(()) - } + #async_validations }) } } }) } } - -pub struct ValidateUnitStruct { - ident: Ident, - generics: Generics, -} - -impl ValidateUnitStruct { - fn parse(input: &DeriveInput) -> Result { - Ok(Self { - ident: input.ident.clone(), - generics: input.generics.clone(), - }) - } -} - -impl ToTokens for ValidateUnitStruct { - fn to_tokens(&self, tokens: &mut TokenStream) { - let ident = &self.ident; - let (impl_generics, type_generics, where_clause) = &self.generics.split_for_impl(); - - tokens.append_all(quote! { - use fortifier::ValidationErrors; - - #[automatically_derived] - impl #impl_generics Validate for #ident #type_generics #where_clause { - type Error = ::std::convert::Infallible; - - fn validate_sync(&self) -> Result<(), ValidationErrors> { - Ok(()) - } - - fn validate_async(&self) -> ::std::pin::Pin>>>> { - Box::pin(async { - Ok(()) - }) - } - } - }); - } -} diff --git a/packages/fortifier-macros/tests/derive/enum_named_pass.rs b/packages/fortifier-macros/tests/derive/enum_named_pass.rs new file mode 100644 index 0000000..6946dfb --- /dev/null +++ b/packages/fortifier-macros/tests/derive/enum_named_pass.rs @@ -0,0 +1,41 @@ +use std::error::Error; + +use fortifier::Validate; + +#[derive(Validate)] +enum ChangeEmailAddressRelation { + Create { + #[validate(email)] + email_address: String, + }, + Update { + id: String, + + #[validate(email)] + email_address: String, + }, + Delete { + id: String, + }, +} + +fn main() -> Result<(), Box> { + let data = ChangeEmailAddressRelation::Create { + email_address: "john@doe.com".to_owned(), + }; + + data.validate_sync()?; + + let data = ChangeEmailAddressRelation::Update { + id: "1".to_owned(), + email_address: "john@doe.com".to_owned(), + }; + + data.validate_sync()?; + + let data = ChangeEmailAddressRelation::Delete { id: "1".to_owned() }; + + data.validate_sync()?; + + Ok(()) +} diff --git a/packages/fortifier-macros/tests/derive/enum_unit_pass.rs b/packages/fortifier-macros/tests/derive/enum_unit_pass.rs new file mode 100644 index 0000000..20f19f7 --- /dev/null +++ b/packages/fortifier-macros/tests/derive/enum_unit_pass.rs @@ -0,0 +1,26 @@ +use std::error::Error; + +use fortifier::Validate; + +#[derive(Validate)] +enum ChangeEmailAddressRelation { + Create, + Update, + Delete, +} + +fn main() -> Result<(), Box> { + let data = ChangeEmailAddressRelation::Create; + + data.validate_sync()?; + + let data = ChangeEmailAddressRelation::Update; + + data.validate_sync()?; + + let data = ChangeEmailAddressRelation::Delete; + + data.validate_sync()?; + + Ok(()) +} diff --git a/packages/fortifier-macros/tests/derive/enum_unnamed_pass.rs b/packages/fortifier-macros/tests/derive/enum_unnamed_pass.rs new file mode 100644 index 0000000..c07374d --- /dev/null +++ b/packages/fortifier-macros/tests/derive/enum_unnamed_pass.rs @@ -0,0 +1,26 @@ +use std::error::Error; + +use fortifier::Validate; + +#[derive(Validate)] +enum ChangeEmailAddressRelation { + Create(#[validate(email)] String), + Update(String, #[validate(email)] String), + Delete(String), +} + +fn main() -> Result<(), Box> { + let data = ChangeEmailAddressRelation::Create("john@doe.com".to_owned()); + + data.validate_sync()?; + + let data = ChangeEmailAddressRelation::Update("1".to_owned(), "john@doe.com".to_owned()); + + data.validate_sync()?; + + let data = ChangeEmailAddressRelation::Delete("1".to_owned()); + + data.validate_sync()?; + + Ok(()) +}