Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions Source/FunicularSwitch.Generators.Common/RoslynExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,24 @@ public static bool IsAnyKeyWord(this string identifier) =>
|| SyntaxFacts.GetReservedKeywordKinds().Contains(SyntaxFactory.ParseToken(identifier).Kind());


public static bool InheritsFrom(this INamedTypeSymbol symbol, ITypeSymbol type)
public static bool InheritsFrom(this INamedTypeSymbol symbol, INamedTypeSymbol type)
{
var baseType = symbol.BaseType;
while (baseType != null)
{
if (type.Equals(baseType, SymbolEqualityComparer.Default))
{
return true;
}

// If the derived type is not declared as a nested type and the base type is generic, then the generic <T>s will not be equal to each other.
// That is why we compare to the fully qualified display string without generics and then assert the number of type arguments is the same, which also prevents issues when the type arguments are named differently in the deriving class
if (type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat.WithGenericsOptions(SymbolDisplayGenericsOptions.None)) ==
baseType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat.WithGenericsOptions(SymbolDisplayGenericsOptions.None))
&& type.TypeParameters.Length == baseType.TypeParameters.Length)
{
return true;
}

baseType = baseType.BaseType;
}
Expand Down Expand Up @@ -101,7 +112,7 @@ public static QualifiedTypeName QualifiedName(this BaseTypeDeclarationSyntax dec
return new(dec.Name(), typeNames);
}

public static QualifiedTypeName QualifiedNameWithGenerics(this BaseTypeDeclarationSyntax dec)
public static QualifiedTypeName QualifiedNameWithGenerics(this BaseTypeDeclarationSyntax dec, INamedTypeSymbol typeSymbol, INamedTypeSymbol baseType)
{
var current = dec.Parent as BaseTypeDeclarationSyntax;
var typeNames = new Stack<string>();
Expand All @@ -110,8 +121,25 @@ public static QualifiedTypeName QualifiedNameWithGenerics(this BaseTypeDeclarati
typeNames.Push(current.Name() + FormatTypeParameters(current.GetTypeParameterList()));
current = current.Parent as BaseTypeDeclarationSyntax;
}

// If the base type has type parameters with different names then the generics are not resolved properly e.g. in the match method
// If though they are the same number of parameters and are passed in the same order as they are declared, we can replace the type argument list with the one from the base type and it should still work
var typeParameters = typeSymbol.TypeParameters;
var argumentsOnBaseType = typeSymbol.BaseType?.TypeArguments ?? [];
var baseTypeTypeParameters = baseType.TypeParameters;
EquatableArray<string> typeParametersForFormatting;
if (typeParameters.Length == argumentsOnBaseType.Length
&& typeParameters.Zip(argumentsOnBaseType, ValueTuple.Create).All(x => x.Item1.Equals(x.Item2, SymbolEqualityComparer.Default))
&& typeParameters.Length == baseTypeTypeParameters.Length)
{
typeParametersForFormatting = baseTypeTypeParameters.Select(tp => tp.Name).ToImmutableArray();
}
else
{
typeParametersForFormatting = typeParameters.Select(tp => tp.Name).ToImmutableArray();
}

return new(dec.Name() + FormatTypeParameters(dec.GetTypeParameterList()), typeNames);
return new(dec.Name() + FormatTypeParameters(typeParametersForFormatting), typeNames);
}

public static EquatableArray<string> GetTypeParameterList(this BaseTypeDeclarationSyntax dec)
Expand Down
36 changes: 25 additions & 11 deletions Source/FunicularSwitch.Generators/UnionType/Parser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,27 @@ public static GenerationResult<UnionTypeSchema> GetUnionTypeSchema(
var treeSemanticModel = syntaxTree != unionTypeClass.SyntaxTree ? compilation.GetSemanticModel(syntaxTree) : semanticModel;

return FindConcreteDerivedTypesWalker.Get(root, unionTypeSymbol, treeSemanticModel);
});
})
.ToList();


var isPartial = unionTypeClass.Modifiers.HasModifier(SyntaxKind.PartialKeyword)
&& unionTypeSymbol.ContainingType == null; //for now do not generate factory methods for nested types, we could support that if all containing types are partial
var generateFactoryMethods = isPartial && staticFactoryMethods;
var anyDerivedTypeInheritsFromBaseWithResolvedGenericArgument = derivedTypes.Any(tuple =>
tuple.symbol.BaseType?.TypeArguments.Any(ta => ta is not ITypeParameterSymbol) ?? false);
var anyDerivedTypeHasDifferentParameterNames = derivedTypes.Any(tuple =>
{
if (unionTypeSymbol.Equals(tuple.symbol.ContainingType, SymbolEqualityComparer.Default))
{
return false;
}
return !tuple.symbol.TypeParameters.Select(tp => tp.Name)
.SequenceEqual(unionTypeSymbol.TypeParameters.Select(tp => tp.Name));
});
var generateFactoryMethods = isPartial && staticFactoryMethods && !anyDerivedTypeInheritsFromBaseWithResolvedGenericArgument && !anyDerivedTypeHasDifferentParameterNames;

return
ToOrderedCases(caseOrder, derivedTypes, compilation, generateFactoryMethods, unionTypeSymbol.Name)
ToOrderedCases(caseOrder, derivedTypes, compilation, generateFactoryMethods, unionTypeSymbol)
.Map(cases => new UnionTypeSchema(
Namespace: fullNamespace,
TypeName: unionTypeSymbol.Name,
Expand Down Expand Up @@ -127,12 +139,13 @@ PropertyDeclarationSyntax p when p.Modifiers.HasModifier(SyntaxKind.StaticKeywor

static GenerationResult<ImmutableArray<DerivedType>> ToOrderedCases(CaseOrder caseOrder,
IEnumerable<(INamedTypeSymbol symbol, BaseTypeDeclarationSyntax node, int? caseIndex, int
numberOfConctreteBaseTypes)> derivedTypes, Compilation compilation, bool getConstructors, string baseTypeName)
numberOfConctreteBaseTypes)> derivedTypes, Compilation compilation, bool getConstructors, INamedTypeSymbol baseType)
{
var baseTypeName = baseType.Name;
var ordered = derivedTypes.OrderByDescending(d => d.numberOfConctreteBaseTypes);
ordered = caseOrder switch
{
CaseOrder.Alphabetic => ordered.ThenBy(d => d.node.QualifiedNameWithGenerics().Name),
CaseOrder.Alphabetic => ordered.ThenBy(d => d.node.QualifiedNameWithGenerics(d.symbol, baseType).Name),
CaseOrder.AsDeclared => ordered.ThenBy(d => d.node.SyntaxTree.FilePath)
.ThenBy(d => d.node.Span.Start),
CaseOrder.Explicit => ordered.ThenBy(d => d.caseIndex),
Expand Down Expand Up @@ -175,7 +188,8 @@ static GenerationResult<ImmutableArray<DerivedType>> ToOrderedCases(CaseOrder ca

var derived = result.Select(d =>
{
var qualifiedTypeName = d.node.QualifiedNameWithGenerics();
var qualifiedNameWithGenerics = d.node.QualifiedNameWithGenerics(d.symbol, baseType);
var qualifiedName = d.node.QualifiedName();
var fullNamespace = d.symbol.GetFullNamespace();
var constructors = ImmutableArray<CallableMemberInfo>.Empty;
var requiredMembers = ImmutableArray<PropertyOrFieldInfo>.Empty;
Expand Down Expand Up @@ -205,10 +219,10 @@ static GenerationResult<ImmutableArray<DerivedType>> ToOrderedCases(CaseOrder ca
}

var (parameterName, staticMethodName) =
DeriveParameterAndStaticMethodName(qualifiedTypeName.Name, baseTypeName);
DeriveParameterAndStaticMethodName(qualifiedName.Name, baseTypeName);

return new DerivedType(
fullTypeName: $"{(fullNamespace != null ? $"{fullNamespace}." : "")}{qualifiedTypeName}",
fullTypeName: $"{(fullNamespace != null ? $"{fullNamespace}." : "")}{qualifiedNameWithGenerics}",
constructors: constructors,
requiredMembers: requiredMembers,
parameterName: parameterName,
Expand All @@ -230,15 +244,15 @@ class FindConcreteDerivedTypesWalker : CSharpSyntaxWalker
{
readonly List<(INamedTypeSymbol symbol, BaseTypeDeclarationSyntax node, int? caseIndex)> m_DerivedClasses = new();
readonly SemanticModel m_SemanticModel;
readonly ITypeSymbol m_BaseClass;
readonly INamedTypeSymbol m_BaseClass;

FindConcreteDerivedTypesWalker(SemanticModel semanticModel, ITypeSymbol baseClass)
FindConcreteDerivedTypesWalker(SemanticModel semanticModel, INamedTypeSymbol baseClass)
{
m_SemanticModel = semanticModel;
m_BaseClass = baseClass;
}

public static IEnumerable<(INamedTypeSymbol symbol, BaseTypeDeclarationSyntax node, int? caseIndex, int numberOfConctreteBaseTypes)> Get(SyntaxNode node, ITypeSymbol baseClass, SemanticModel semanticModel)
public static IEnumerable<(INamedTypeSymbol symbol, BaseTypeDeclarationSyntax node, int? caseIndex, int numberOfConctreteBaseTypes)> Get(SyntaxNode node, INamedTypeSymbol baseClass, SemanticModel semanticModel)
{
var me = new FindConcreteDerivedTypesWalker(semanticModel, baseClass);
me.Visit(node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public static class MyErrorExtension
public static string MergeErrors(this string error, string other) => $""{error}{System.Environment.NewLine}{other}"";
}
";
return Verify(code);
return Verify(code, numberOfGeneratedFiles: 4);
}

[TestMethod]
Expand Down Expand Up @@ -82,7 +82,7 @@ public static class MyErrorExtension
}
}
";
return Verify(code);
return Verify(code, numberOfGeneratedFiles: 2);
}

[TestMethod]
Expand All @@ -103,7 +103,7 @@ public enum MyError
Unauthorized
}
";
return Verify(code);
return Verify(code, numberOfGeneratedFiles: 0);
}


Expand Down Expand Up @@ -231,6 +231,6 @@ public override bool Equals(object obj)
public override int GetHashCode() => (int)UnionCase;
}
";
return Verify(code);
return Verify(code, numberOfGeneratedFiles: 2);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//HintName: Attributes.g.cs
#nullable enable
using System;

// ReSharper disable once CheckNamespace
namespace FunicularSwitch.Generators
{
/// <summary>
/// Mark an abstract partial type with a single generic argument with the ResultType attribute.
/// This type from now on has Ok | Error semantics with map and bind operations.
/// </summary>
[AttributeUsage(AttributeTargets.Class, Inherited = false)]
sealed class ResultTypeAttribute : Attribute
{
public ResultTypeAttribute() => ErrorType = typeof(string);
public ResultTypeAttribute(Type errorType) => ErrorType = errorType;

public Type ErrorType { get; set; }
}

/// <summary>
/// Mark a static method or a member method or you error type with the MergeErrorAttribute attribute.
/// Static signature: TError -> TError -> TError. Member signature: TError -> TError
/// We are now able to collect errors and methods like Validate, Aggregate, FirstOk that are useful to combine results are generated.
/// </summary>
[AttributeUsage(AttributeTargets.Method, Inherited = false)]
sealed class MergeErrorAttribute : Attribute
{
}

/// <summary>
/// Mark a static method with the ExceptionToError attribute.
/// Signature: Exception -> TError
/// This method is always called, when an exception happens in a bind operation.
/// So a call like result.Map(i => i/0) will return an Error produced by the factory method instead of throwing the DivisionByZero exception.
/// </summary>
[AttributeUsage(AttributeTargets.Method, Inherited = false)]
sealed class ExceptionToError : Attribute
{
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
//HintName: Attributes.g.cs
#nullable enable
using System;

// ReSharper disable once CheckNamespace
namespace FunicularSwitch.Generators
{
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Interface, Inherited = false)]
sealed class UnionTypeAttribute : Attribute
{
public CaseOrder CaseOrder { get; set; } = CaseOrder.Alphabetic;
public bool StaticFactoryMethods { get; set; } = true;
}

enum CaseOrder
{
Alphabetic,
AsDeclared,
Explicit
}

[AttributeUsage(AttributeTargets.Class, Inherited = false)]
sealed class UnionCaseAttribute : Attribute
{
public UnionCaseAttribute(int index) => Index = index;

public int Index { get; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//HintName: Attributes.g.cs
#nullable enable
using System;

// ReSharper disable once CheckNamespace
namespace FunicularSwitch.Generators
{
[AttributeUsage(AttributeTargets.Enum)]
sealed class ExtendedEnumAttribute : Attribute
{
public EnumCaseOrder CaseOrder { get; set; } = EnumCaseOrder.AsDeclared;
public ExtensionAccessibility Accessibility { get; set; } = ExtensionAccessibility.Public;
}

enum EnumCaseOrder
{
Alphabetic,
AsDeclared
}

/// <summary>
/// Generate match methods for all enums defined in assembly that contains AssemblySpecifier.
/// </summary>
[AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
class ExtendEnumsAttribute : Attribute
{
public Type AssemblySpecifier { get; }
public EnumCaseOrder CaseOrder { get; set; } = EnumCaseOrder.AsDeclared;
public ExtensionAccessibility Accessibility { get; set; } = ExtensionAccessibility.Public;

public ExtendEnumsAttribute() => AssemblySpecifier = typeof(ExtendEnumsAttribute);

public ExtendEnumsAttribute(Type assemblySpecifier)
{
AssemblySpecifier = assemblySpecifier;
}
}

/// <summary>
/// Generate match methods for Type. Must be enum.
/// </summary>
[AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
class ExtendEnumAttribute : Attribute
{
public Type Type { get; }

public EnumCaseOrder CaseOrder { get; set; } = EnumCaseOrder.AsDeclared;

public ExtensionAccessibility Accessibility { get; set; } = ExtensionAccessibility.Public;

public ExtendEnumAttribute(Type type)
{
Type = type;
}
}

enum ExtensionAccessibility
{
Internal,
Public
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
//HintName: FunicularSwitchTestBaseTypeOfTMatchExtension.g.cs
#pragma warning disable 1591
#nullable enable
namespace FunicularSwitch.Test
{
public static partial class BaseTypeMatchExtension
{
public static TMatchResult Match<T, TMatchResult>(this global::FunicularSwitch.Test.BaseType<T> baseType, global::System.Func<FunicularSwitch.Test.Deriving<T>, TMatchResult> deriving, global::System.Func<FunicularSwitch.Test.Deriving2<T>, TMatchResult> deriving2) =>
baseType switch
{
FunicularSwitch.Test.Deriving<T> deriving1 => deriving(deriving1),
FunicularSwitch.Test.Deriving2<T> deriving22 => deriving2(deriving22),
_ => throw new global::System.ArgumentException($"Unknown type derived from FunicularSwitch.Test.BaseType: {baseType.GetType().Name}")
};

public static global::System.Threading.Tasks.Task<TMatchResult> Match<T, TMatchResult>(this global::FunicularSwitch.Test.BaseType<T> baseType, global::System.Func<FunicularSwitch.Test.Deriving<T>, global::System.Threading.Tasks.Task<TMatchResult>> deriving, global::System.Func<FunicularSwitch.Test.Deriving2<T>, global::System.Threading.Tasks.Task<TMatchResult>> deriving2) =>
baseType switch
{
FunicularSwitch.Test.Deriving<T> deriving1 => deriving(deriving1),
FunicularSwitch.Test.Deriving2<T> deriving22 => deriving2(deriving22),
_ => throw new global::System.ArgumentException($"Unknown type derived from FunicularSwitch.Test.BaseType: {baseType.GetType().Name}")
};

public static async global::System.Threading.Tasks.Task<TMatchResult> Match<T, TMatchResult>(this global::System.Threading.Tasks.Task<global::FunicularSwitch.Test.BaseType<T>> baseType, global::System.Func<FunicularSwitch.Test.Deriving<T>, TMatchResult> deriving, global::System.Func<FunicularSwitch.Test.Deriving2<T>, TMatchResult> deriving2) =>
(await baseType.ConfigureAwait(false)).Match(deriving, deriving2);

public static async global::System.Threading.Tasks.Task<TMatchResult> Match<T, TMatchResult>(this global::System.Threading.Tasks.Task<global::FunicularSwitch.Test.BaseType<T>> baseType, global::System.Func<FunicularSwitch.Test.Deriving<T>, global::System.Threading.Tasks.Task<TMatchResult>> deriving, global::System.Func<FunicularSwitch.Test.Deriving2<T>, global::System.Threading.Tasks.Task<TMatchResult>> deriving2) =>
await (await baseType.ConfigureAwait(false)).Match(deriving, deriving2).ConfigureAwait(false);

public static void Switch<T>(this global::FunicularSwitch.Test.BaseType<T> baseType, global::System.Action<FunicularSwitch.Test.Deriving<T>> deriving, global::System.Action<FunicularSwitch.Test.Deriving2<T>> deriving2)
{
switch (baseType)
{
case FunicularSwitch.Test.Deriving<T> deriving1:
deriving(deriving1);
break;
case FunicularSwitch.Test.Deriving2<T> deriving22:
deriving2(deriving22);
break;
default:
throw new global::System.ArgumentException($"Unknown type derived from FunicularSwitch.Test.BaseType: {baseType.GetType().Name}");
}
}

public static async global::System.Threading.Tasks.Task Switch<T>(this global::FunicularSwitch.Test.BaseType<T> baseType, global::System.Func<FunicularSwitch.Test.Deriving<T>, global::System.Threading.Tasks.Task> deriving, global::System.Func<FunicularSwitch.Test.Deriving2<T>, global::System.Threading.Tasks.Task> deriving2)
{
switch (baseType)
{
case FunicularSwitch.Test.Deriving<T> deriving1:
await deriving(deriving1).ConfigureAwait(false);
break;
case FunicularSwitch.Test.Deriving2<T> deriving22:
await deriving2(deriving22).ConfigureAwait(false);
break;
default:
throw new global::System.ArgumentException($"Unknown type derived from FunicularSwitch.Test.BaseType: {baseType.GetType().Name}");
}
}

public static async global::System.Threading.Tasks.Task Switch<T>(this global::System.Threading.Tasks.Task<global::FunicularSwitch.Test.BaseType<T>> baseType, global::System.Action<FunicularSwitch.Test.Deriving<T>> deriving, global::System.Action<FunicularSwitch.Test.Deriving2<T>> deriving2) =>
(await baseType.ConfigureAwait(false)).Switch(deriving, deriving2);

public static async global::System.Threading.Tasks.Task Switch<T>(this global::System.Threading.Tasks.Task<global::FunicularSwitch.Test.BaseType<T>> baseType, global::System.Func<FunicularSwitch.Test.Deriving<T>, global::System.Threading.Tasks.Task> deriving, global::System.Func<FunicularSwitch.Test.Deriving2<T>, global::System.Threading.Tasks.Task> deriving2) =>
await (await baseType.ConfigureAwait(false)).Switch(deriving, deriving2).ConfigureAwait(false);
}

public abstract partial record BaseType<T>
{
public static FunicularSwitch.Test.BaseType<T> Deriving(string Value, T Other) => new FunicularSwitch.Test.Deriving<T>(Value, Other);
public static FunicularSwitch.Test.BaseType<T> Deriving2(string Value) => new FunicularSwitch.Test.Deriving2<T>(Value);
}
}
#pragma warning restore 1591
Loading
Loading