Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 21 additions & 91 deletions packages/fortifier-macros/src/validate/enum.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::{collections::HashSet, str::FromStr};

use proc_macro2::TokenStream;
use quote::{ToTokens, TokenStreamExt, format_ident, quote};
use syn::{DataEnum, DeriveInput, Generics, Ident, Result, Variant, Visibility};

use crate::validate::{
field::{LiteralOrIdent, ValidateFieldPrefix},
fields::ValidateFields,
use crate::{
validate::{
field::{LiteralOrIdent, ValidateFieldPrefix},
fields::ValidateFields,
},
validation::Execution,
};

pub struct ValidateEnum {
Expand Down Expand Up @@ -39,13 +40,6 @@ impl ValidateEnum {
Ok(result)
}

fn uses(&self) -> HashSet<String> {
self.variants
.iter()
.flat_map(|variant| variant.uses())
.collect()
}

fn error_type(&self) -> (Ident, TokenStream) {
let visibility = &self.visibility;
let error_ident = &self.error_ident;
Expand Down Expand Up @@ -89,42 +83,33 @@ impl ToTokens for ValidateEnum {
let ident = &self.ident;
let (impl_generics, type_generics, where_clause) = &self.generics.split_for_impl();

let uses = self.uses().into_iter().map(|r#use| {
let tokens = TokenStream::from_str(&r#use).expect("Tokens should be valid.");
quote!(use #tokens;)
});
let (error_ident, error_type) = self.error_type();
let variant_error_types = self.variants.iter().map(|variant| variant.error_type().1);
let sync_variant_match_arms = self
.variants
.iter()
.map(|variant| variant.sync_match_arm_tokens());
.map(|variant| variant.match_arm(Execution::Sync));
let async_variant_match_arms = self
.variants
.iter()
.map(|variant| variant.async_match_arm_tokens());
.map(|variant| variant.match_arm(Execution::Async));

tokens.append_all(quote! {
#( #uses )*

// TODO: Replace with granular uses.
use fortifier::*;

#error_type

#( #variant_error_types )*

#[automatically_derived]
impl #impl_generics Validate for #ident #type_generics #where_clause {
impl #impl_generics ::fortifier::Validate for #ident #type_generics #where_clause {
type Error = #error_ident;

fn validate_sync(&self) -> Result<(), ValidationErrors<Self::Error>> {
fn validate_sync(&self) -> Result<(), ::fortifier::ValidationErrors<Self::Error>> {
match &self {
#( #sync_variant_match_arms ),*
}
}

fn validate_async(&self) -> ::std::pin::Pin<Box<impl Future<Output = Result<(), ValidationErrors<Self::Error>>>>> {
fn validate_async(&self) -> ::std::pin::Pin<Box<impl Future<Output = Result<(), ::fortifier::ValidationErrors<Self::Error>>>>> {
Box::pin(async move {
match &self {
#( #async_variant_match_arms ),*
Expand Down Expand Up @@ -164,66 +149,11 @@ impl ValidateEnumVariant {
Ok(result)
}

fn uses(&self) -> HashSet<String> {
self.fields.uses()
}

fn error_type(&self) -> (Ident, TokenStream) {
fn error_type(&self) -> (TokenStream, TokenStream) {
self.fields.error_type()
}

fn sync_match_arm_tokens(&self) -> TokenStream {
let enum_ident = &self.enum_ident;
let enum_error_ident = &self.enum_error_ident;
let ident = &self.ident;

let error_wrapper = |tokens| quote!(#enum_error_ident::#ident(#tokens));

match &self.fields {
ValidateFields::Named(fields) => {
let field_idents = fields.idents();
let sync_validations =
fields.sync_validations(ValidateFieldPrefix::None, &error_wrapper);

// TODO: Only destructure fields required for validation.
quote! {
#[allow(unused_variables)]
#enum_ident::#ident {
#( #field_idents ),*
} => {
#sync_validations
}
}
}
ValidateFields::Unnamed(fields) => {
let field_idents = fields.idents().map(|ident| match ident {
LiteralOrIdent::Literal(literal) => format_ident!("f{literal}"),
LiteralOrIdent::Ident(ident) => ident.clone(),
});
let sync_validations =
fields.sync_validations(ValidateFieldPrefix::F, &error_wrapper);

quote! {
#enum_ident::#ident(
#( #field_idents ),*
) => {
#sync_validations
}
}
}
ValidateFields::Unit(fields) => {
let sync_validations = fields.sync_validations();

quote! {
#enum_ident::#ident => {
#sync_validations
}
}
}
}
}

fn async_match_arm_tokens(&self) -> TokenStream {
fn match_arm(&self, exeuction: Execution) -> TokenStream {
let enum_ident = &self.enum_ident;
let enum_error_ident = &self.enum_error_ident;
let ident = &self.ident;
Expand All @@ -233,16 +163,16 @@ impl ValidateEnumVariant {
match &self.fields {
ValidateFields::Named(fields) => {
let field_idents = fields.idents();
let async_validations =
fields.async_validations(ValidateFieldPrefix::None, &error_wrapper);
let validations =
fields.validations(exeuction, ValidateFieldPrefix::None, &error_wrapper);

// TODO: Only destructure fields required for validation.
quote! {
#[allow(unused_variables)]
#enum_ident::#ident {
#( #field_idents ),*
} => {
#async_validations
#validations
}
}
}
Expand All @@ -251,23 +181,23 @@ impl ValidateEnumVariant {
LiteralOrIdent::Literal(literal) => format_ident!("f{literal}"),
LiteralOrIdent::Ident(ident) => ident.clone(),
});
let async_validations =
fields.async_validations(ValidateFieldPrefix::F, &error_wrapper);
let validations =
fields.validations(exeuction, ValidateFieldPrefix::F, &error_wrapper);

quote! {
#enum_ident::#ident(
#( #field_idents ),*
) => {
#async_validations
#validations
}
}
}
ValidateFields::Unit(fields) => {
let async_validations = fields.async_validations();
let validations = fields.validations();

quote! {
#enum_ident::#ident => {
#async_validations
#validations
}
}
}
Expand Down
74 changes: 24 additions & 50 deletions packages/fortifier-macros/src/validate/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use quote::{ToTokens, format_ident, quote};
use syn::{Field, Ident, Result, Visibility};

use crate::{
validation::Validation,
validation::{Execution, Validation},
validations::{Custom, Email, Length, Regex, Url},
};

Expand Down Expand Up @@ -127,64 +127,38 @@ impl ValidateField {
}
}

pub fn sync_validations(&self, field_prefix: ValidateFieldPrefix) -> Vec<TokenStream> {
pub fn validations(
&self,
execution: Execution,
field_prefix: ValidateFieldPrefix,
) -> Vec<TokenStream> {
let error_type_ident = &self.error_type_ident;
let ident = &self.ident;

self.validations
.iter()
.filter(|validation| !validation.is_async())
.map(|validation| {
let validation_ident = validation.ident();
let tokens = validation.tokens(&match field_prefix {
ValidateFieldPrefix::None => self.ident.to_token_stream(),
ValidateFieldPrefix::SelfKeyword => quote!(self.#ident),
ValidateFieldPrefix::F => match &self.ident {
LiteralOrIdent::Literal(literal) => {
format_ident!("f{literal}").to_token_stream()
}
LiteralOrIdent::Ident(ident) => ident.to_token_stream(),
},
});

if self.validations.len() > 1 {
quote! {
#tokens.map_err(#error_type_ident::#validation_ident)
}
} else {
tokens
}
})
.collect()
}

pub fn async_validations(&self, field_prefix: ValidateFieldPrefix) -> Vec<TokenStream> {
let ident = &self.ident;
let error_type_ident = &self.error_type_ident;
let field_expr = match field_prefix {
ValidateFieldPrefix::None => self.ident.to_token_stream(),
ValidateFieldPrefix::SelfKeyword => quote!(self.#ident),
ValidateFieldPrefix::F => match &self.ident {
LiteralOrIdent::Literal(literal) => format_ident!("f{literal}").to_token_stream(),
LiteralOrIdent::Ident(ident) => ident.to_token_stream(),
},
};

self.validations
.iter()
.filter(|validation| validation.is_async())
.map(|validation| {
.flat_map(|validation| {
let validation_ident = validation.ident();
let tokens = validation.tokens(&match field_prefix {
ValidateFieldPrefix::None => self.ident.to_token_stream(),
ValidateFieldPrefix::SelfKeyword => quote!(self.#ident),
ValidateFieldPrefix::F => match &self.ident {
LiteralOrIdent::Literal(literal) => {
format_ident!("f{literal}").to_token_stream()
}
LiteralOrIdent::Ident(ident) => ident.to_token_stream(),
},
});
let expr = validation.expr(execution, &field_expr);

if self.validations.len() > 1 {
quote! {
#tokens.map_err(#error_type_ident::#validation_ident)
expr.map(|expr| {
if self.validations.len() > 1 {
quote! {
#expr.map_err(#error_type_ident::#validation_ident)
}
} else {
expr
}
} else {
tokens
}
})
})
.collect()
}
Expand Down
Loading