Skip to content

Commit c9d1370

Browse files
committed
[Analyzer] Multiple restrictions for regular unions
* Union records must be sealed (because of copy constructor) * Union class must be sealed or have private constructors * Derived type of a union must not be less accessible than the base union * Derived type of a union must not have unbound generic parameters
1 parent 6c5568d commit c9d1370

18 files changed

Lines changed: 553 additions & 101 deletions

File tree

Directory.Build.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
<PropertyGroup>
44
<Copyright>(c) $([System.DateTime]::Now.Year), Pawel Gerr. All rights reserved.</Copyright>
5-
<VersionPrefix>8.6.0</VersionPrefix>
5+
<VersionPrefix>8.6.1</VersionPrefix>
66
<Authors>Pawel Gerr</Authors>
77
<GenerateDocumentationFile>true</GenerateDocumentationFile>
88
<PackageProjectUrl>https://github.com/PawelGerr/Thinktecture.Runtime.Extensions</PackageProjectUrl>

docs

Submodule docs updated from e92bf0f to 365caf2

samples/Thinktecture.Runtime.Extensions.EntityFrameworkCore.Samples/MessageState.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ public abstract partial record MessageState
77
{
88
public int Order { get; }
99

10-
public record Initial : MessageState;
10+
public sealed record Initial : MessageState;
1111

12-
public record Parsed(DateTime CreatedAt) : MessageState;
12+
public sealed record Parsed(DateTime CreatedAt) : MessageState;
1313

14-
public record Processed(DateTime CreatedAt) : MessageState;
14+
public sealed record Processed(DateTime CreatedAt) : MessageState;
1515

16-
public record Error(string Message) : MessageState;
16+
public sealed record Error(string Message) : MessageState;
1717
}

samples/Thinktecture.Runtime.Extensions.Samples/Unions/Result.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ namespace Thinktecture.Unions;
44
MapMethods = SwitchMapMethodsGeneration.DefaultWithPartialOverloads)]
55
public partial record Result<T>
66
{
7-
public record Success(T Value) : Result<T>;
7+
public sealed record Success(T Value) : Result<T>;
88

9-
public record Failure(string Error) : Result<T>;
9+
public sealed record Failure(string Error) : Result<T>;
1010
}

src/Thinktecture.Runtime.Extensions.SourceGenerator/CodeAnalysis/CodeFixes/ThinktectureRuntimeExtensionsCodeFixProvider.cs

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ public sealed class ThinktectureRuntimeExtensionsCodeFixProvider : CodeFixProvid
3838
DiagnosticsDescriptors.ExplicitComparerWithoutEqualityComparer.Id,
3939
DiagnosticsDescriptors.ExplicitEqualityComparerWithoutComparer.Id,
4040
DiagnosticsDescriptors.MethodWithUseDelegateFromConstructorMustBePartial.Id,
41+
DiagnosticsDescriptors.UnionRecordMustBeSealed.Id,
4142
];
4243

4344
/// <inheritdoc />
@@ -93,7 +94,7 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
9394
}
9495
else if (diagnostic.Id == DiagnosticsDescriptors.EnumWithoutDerivedTypesMustBeSealed.Id)
9596
{
96-
context.RegisterCodeFix(CodeAction.Create(_SEAL_CLASS, _ => AddTypeModifierAsync(context.Document, root, GetCodeFixesContext().TypeDeclaration, SyntaxKind.SealedKeyword), _SEAL_CLASS), diagnostic);
97+
context.RegisterCodeFix(CodeAction.Create(_SEAL_CLASS, _ => ReplaceOrAddTypeModifierAsync(context.Document, root, GetCodeFixesContext().TypeDeclaration, SyntaxKind.AbstractKeyword, SyntaxKind.SealedKeyword), _SEAL_CLASS), diagnostic);
9798
}
9899
else if (diagnostic.Id == DiagnosticsDescriptors.StringBasedValueObjectNeedsEqualityComparer.Id)
99100
{
@@ -115,6 +116,10 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
115116
{
116117
context.RegisterCodeFix(CodeAction.Create(_MAKE_METHOD_PARTIAL, _ => AddTypeModifierAsync(context.Document, root, GetCodeFixesContext().MethodDeclaration, SyntaxKind.PartialKeyword), _MAKE_METHOD_PARTIAL), diagnostic);
117118
}
119+
else if (diagnostic.Id == DiagnosticsDescriptors.UnionRecordMustBeSealed.Id)
120+
{
121+
context.RegisterCodeFix(CodeAction.Create(_SEAL_CLASS, _ => ReplaceOrAddTypeModifierAsync(context.Document, root, GetCodeFixesContext().TypeDeclaration, SyntaxKind.AbstractKeyword, SyntaxKind.SealedKeyword), _SEAL_CLASS), diagnostic);
122+
}
118123
}
119124
}
120125

@@ -136,6 +141,38 @@ private static Task<Document> AddTypeModifierAsync(
136141
return Task.FromResult(newDoc);
137142
}
138143

144+
private static Task<Document> ReplaceOrAddTypeModifierAsync(
145+
Document document,
146+
SyntaxNode root,
147+
MemberDeclarationSyntax? declaration,
148+
SyntaxKind oldModifier,
149+
SyntaxKind newModifier)
150+
{
151+
if (declaration is null)
152+
return Task.FromResult(document);
153+
154+
var newModifierToken = SyntaxFactory.Token(newModifier);
155+
var indexOfOldModifier = declaration.Modifiers.IndexOf(oldModifier);
156+
157+
MemberDeclarationSyntax newDeclaration;
158+
159+
if (indexOfOldModifier >= 0)
160+
{
161+
var oldToken = declaration.Modifiers[indexOfOldModifier];
162+
newDeclaration = declaration.ReplaceToken(oldToken, newModifierToken);
163+
}
164+
else
165+
{
166+
var indexOfPartialKeyword = declaration.Modifiers.IndexOf(SyntaxKind.PartialKeyword);
167+
newDeclaration = indexOfPartialKeyword < 0 ? declaration.AddModifiers(newModifierToken) : declaration.WithModifiers(declaration.Modifiers.Insert(indexOfPartialKeyword, newModifierToken));
168+
}
169+
170+
var newRoot = root.ReplaceNode(declaration, newDeclaration);
171+
var newDoc = document.WithSyntaxRoot(newRoot);
172+
173+
return Task.FromResult(newDoc);
174+
}
175+
139176
private static Task<Document> ChangeAccessibilityAsync(
140177
Document document,
141178
SyntaxNode root,

src/Thinktecture.Runtime.Extensions.SourceGenerator/CodeAnalysis/Diagnostics/ThinktectureRuntimeExtensionsAnalyzer.cs

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ public sealed class ThinktectureRuntimeExtensionsAnalyzer : DiagnosticAnalyzer
5151
DiagnosticsDescriptors.MethodWithUseDelegateFromConstructorMustBePartial,
5252
DiagnosticsDescriptors.MethodWithUseDelegateFromConstructorMustNotHaveGenerics,
5353
DiagnosticsDescriptors.TypeMustNotBeInsideGenericType,
54-
DiagnosticsDescriptors.NonAbstractUnionDerivedTypesMustNotBeGeneric
54+
DiagnosticsDescriptors.UnionDerivedTypesMustNotBeGeneric,
55+
DiagnosticsDescriptors.UnionMustBeSealedOrHavePrivateConstructorsOnly,
56+
DiagnosticsDescriptors.UnionRecordMustBeSealed,
57+
DiagnosticsDescriptors.NonAbstractDerivedUnionIsLessAccessibleThanBaseUnion,
5558
];
5659

5760
/// <inheritdoc />
@@ -429,20 +432,7 @@ private static void ValidateUnion(
429432
{
430433
CheckConstructors(context, type, mustBePrivate: true, canHavePrimaryConstructor: false);
431434
TypeMustBePartial(context, type);
432-
NonAbstractDerivedTypesMustNotBeGeneric(context, type);
433-
}
434-
435-
private static void NonAbstractDerivedTypesMustNotBeGeneric(OperationAnalysisContext context, INamedTypeSymbol unionType)
436-
{
437-
var derivedTypes = unionType.FindDerivedInnerTypes();
438-
439-
for (var i = 0; i < derivedTypes.Count; i++)
440-
{
441-
var (type, _, _) = derivedTypes[i];
442-
443-
if (!type.IsAbstract && type.Arity != 0)
444-
ReportDiagnostic(context, DiagnosticsDescriptors.NonAbstractUnionDerivedTypesMustNotBeGeneric, GetDerivedTypeLocation(context, type), type);
445-
}
435+
ValidateUnionDerivedTypes(context, type);
446436
}
447437

448438
private static void ValidateKeyedValueObject(
@@ -717,7 +707,7 @@ private static void ValidateEnum(
717707
baseClass = baseClass.BaseType;
718708
}
719709

720-
ValidateDerivedTypes(context, enumType);
710+
ValidateEnumDerivedTypes(context, enumType);
721711

722712
EnumKeyMemberNameMustNotBeItem(context, attribute, locationOfFirstDeclaration);
723713

@@ -786,27 +776,27 @@ private static void Check_ItemLike_StaticProperties(
786776
}
787777
}
788778

789-
private static void ValidateDerivedTypes(OperationAnalysisContext context, INamedTypeSymbol enumType)
779+
private static void ValidateEnumDerivedTypes(OperationAnalysisContext context, INamedTypeSymbol type)
790780
{
791-
var derivedTypes = enumType.FindDerivedInnerTypes();
781+
var derivedTypes = type.FindDerivedInnerTypes();
792782
var typesToLeaveOpen = ImmutableArray.Create<INamedTypeSymbol>();
793783

794784
for (var i = 0; i < derivedTypes.Count; i++)
795785
{
796-
var (type, _, level) = derivedTypes[i];
786+
var (derivedType, _, level) = derivedTypes[i];
797787

798788
if (level == 1)
799789
{
800-
if (type.DeclaredAccessibility != Accessibility.Private)
801-
ReportDiagnostic(context, DiagnosticsDescriptors.InnerEnumOnFirstLevelMustBePrivate, GetDerivedTypeLocation(context, type), type);
790+
if (derivedType.DeclaredAccessibility != Accessibility.Private)
791+
ReportDiagnostic(context, DiagnosticsDescriptors.InnerEnumOnFirstLevelMustBePrivate, GetDerivedTypeLocation(context, derivedType), derivedType);
802792
}
803-
else if (type.DeclaredAccessibility != Accessibility.Public)
793+
else if (derivedType.DeclaredAccessibility != Accessibility.Public)
804794
{
805-
ReportDiagnostic(context, DiagnosticsDescriptors.InnerEnumOnNonFirstLevelMustBePublic, GetDerivedTypeLocation(context, type), type);
795+
ReportDiagnostic(context, DiagnosticsDescriptors.InnerEnumOnNonFirstLevelMustBePublic, GetDerivedTypeLocation(context, derivedType), derivedType);
806796
}
807797

808-
if (!type.BaseType.IsNullOrObject())
809-
typesToLeaveOpen = typesToLeaveOpen.Add(type.BaseType);
798+
if (!derivedType.BaseType.IsNullOrObject())
799+
typesToLeaveOpen = typesToLeaveOpen.Add(derivedType.BaseType);
810800
}
811801

812802
for (var i = 0; i < derivedTypes.Count; i++)
@@ -818,6 +808,34 @@ private static void ValidateDerivedTypes(OperationAnalysisContext context, IName
818808
}
819809
}
820810

811+
private static void ValidateUnionDerivedTypes(OperationAnalysisContext context, INamedTypeSymbol type)
812+
{
813+
var derivedTypes = type.FindDerivedInnerTypes();
814+
815+
for (var i = 0; i < derivedTypes.Count; i++)
816+
{
817+
var (derivedType, _, _) = derivedTypes[i];
818+
819+
if (derivedType.Arity != 0)
820+
ReportDiagnostic(context, DiagnosticsDescriptors.UnionDerivedTypesMustNotBeGeneric, GetDerivedTypeLocation(context, derivedType), derivedType);
821+
822+
if (!derivedType.IsAbstract && derivedType.HasLowerAccessibility(type.DeclaredAccessibility, type))
823+
ReportDiagnostic(context, DiagnosticsDescriptors.NonAbstractDerivedUnionIsLessAccessibleThanBaseUnion, GetDerivedTypeLocation(context, derivedType), derivedType, type);
824+
825+
if (!derivedType.IsSealed)
826+
{
827+
if (derivedType.IsRecord)
828+
{
829+
ReportDiagnostic(context, DiagnosticsDescriptors.UnionRecordMustBeSealed, GetDerivedTypeLocation(context, derivedType), derivedType);
830+
}
831+
else if (derivedType.Constructors.Any(ctor => ctor.DeclaredAccessibility != Accessibility.Private))
832+
{
833+
ReportDiagnostic(context, DiagnosticsDescriptors.UnionMustBeSealedOrHavePrivateConstructorsOnly, GetDerivedTypeLocation(context, derivedType), derivedType);
834+
}
835+
}
836+
}
837+
}
838+
821839
private static Location GetDerivedTypeLocation(OperationAnalysisContext context, INamedTypeSymbol derivedType)
822840
{
823841
return ((TypeDeclarationSyntax)derivedType.DeclaringSyntaxReferences.First().GetSyntax(context.CancellationToken)).Identifier.GetLocation();

src/Thinktecture.Runtime.Extensions.SourceGenerator/CodeAnalysis/DiagnosticsDescriptors.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@ internal static class DiagnosticsDescriptors
3434
public static readonly DiagnosticDescriptor MethodWithUseDelegateFromConstructorMustBePartial = new("TTRESG050", $"Method with '{Constants.Attributes.UseDelegateFromConstructorAttribute.NAME}' must be partial", $"The method '{{0}}' with '{Constants.Attributes.UseDelegateFromConstructorAttribute.NAME}' must be marked as partial", nameof(ThinktectureRuntimeExtensionsAnalyzer), DiagnosticSeverity.Error, true);
3535
public static readonly DiagnosticDescriptor MethodWithUseDelegateFromConstructorMustNotHaveGenerics = new("TTRESG051", $"Method with '{Constants.Attributes.UseDelegateFromConstructorAttribute.NAME}' must not have generics", $"The method '{{0}}' with '{Constants.Attributes.UseDelegateFromConstructorAttribute.NAME}' must not have generic type parameters. Use inheritance approach instead.", nameof(ThinktectureRuntimeExtensionsAnalyzer), DiagnosticSeverity.Error, true);
3636
public static readonly DiagnosticDescriptor TypeMustNotBeInsideGenericType = new("TTRESG052", "The type must not be inside generic type", "Type '{0}' must not be inside a generic type", nameof(ThinktectureRuntimeExtensionsAnalyzer), DiagnosticSeverity.Error, true);
37-
public static readonly DiagnosticDescriptor NonAbstractUnionDerivedTypesMustNotBeGeneric = new("TTRESG053", "Non-abstract derived type of a union must not be generic", "Non-abstract derived type '{0}' of a union must not be generic", nameof(ThinktectureRuntimeExtensionsAnalyzer), DiagnosticSeverity.Error, true);
37+
public static readonly DiagnosticDescriptor UnionDerivedTypesMustNotBeGeneric = new("TTRESG053", "Derived type of a union must not be generic", "Derived type '{0}' of a union must not be generic", nameof(ThinktectureRuntimeExtensionsAnalyzer), DiagnosticSeverity.Error, true);
38+
public static readonly DiagnosticDescriptor UnionMustBeSealedOrHavePrivateConstructorsOnly = new("TTRESG054", "Discriminated union must be sealed or have private constructors only", "Discriminated union '{0}' must be sealed or have private constructors only", nameof(ThinktectureRuntimeExtensionsAnalyzer), DiagnosticSeverity.Error, true);
39+
public static readonly DiagnosticDescriptor UnionRecordMustBeSealed = new("TTRESG055", "Discriminated union implemented using a record must be sealed", "Discriminated union '{0}' using a record must be sealed", nameof(ThinktectureRuntimeExtensionsAnalyzer), DiagnosticSeverity.Error, true);
40+
public static readonly DiagnosticDescriptor NonAbstractDerivedUnionIsLessAccessibleThanBaseUnion = new("TTRESG056", "Non-abstract derived union is less accessible than base union", "Non-abstract derived union '{0}' is less accessible than base union '{1}'", nameof(ThinktectureRuntimeExtensionsAnalyzer), DiagnosticSeverity.Error, true);
3841

3942
public static readonly DiagnosticDescriptor ErrorDuringCodeAnalysis = new("TTRESG098", "Error during code analysis", "Error during code analysis of '{0}': '{1}'", nameof(ThinktectureRuntimeExtensionsAnalyzer), DiagnosticSeverity.Warning, true);
4043
public static readonly DiagnosticDescriptor ErrorDuringGeneration = new("TTRESG099", "Error during code generation", "Error during code generation for '{0}': '{1}'", nameof(ThinktectureRuntimeExtensionsAnalyzer), DiagnosticSeverity.Error, true);

src/Thinktecture.Runtime.Extensions.SourceGenerator/CodeAnalysis/Unions/UnionSourceGenerator.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,9 @@ private bool IsCandidate(SyntaxNode syntaxNode, CancellationToken cancellationTo
115115
{
116116
var derivedTypeInfo = derivedTypeInfos[i];
117117

118-
if (!derivedTypeInfo.Type.IsAbstract && derivedTypeInfo.Type.Arity != 0)
118+
if (derivedTypeInfo.Type.Arity != 0)
119119
{
120-
Logger.LogDebug("Derived type of a union must not have generic parameters, unless it is abstract", tds);
120+
Logger.LogDebug("Derived type of a union must not have generic parameters", tds);
121121
return null;
122122
}
123123

src/Thinktecture.Runtime.Extensions.SourceGenerator/Extensions/NamedTypeSymbolExtensions.cs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public static bool IsNestedInGenericClass(this INamedTypeSymbol type)
8787
{
8888
var containingType = type.ContainingType;
8989

90-
while (containingType != null)
90+
while (containingType is not null)
9191
{
9292
if (!containingType.TypeParameters.IsDefaultOrEmpty)
9393
return true;
@@ -97,4 +97,23 @@ public static bool IsNestedInGenericClass(this INamedTypeSymbol type)
9797

9898
return false;
9999
}
100+
101+
public static bool HasLowerAccessibility(
102+
this INamedTypeSymbol type,
103+
Accessibility accessibility,
104+
INamedTypeSymbol stopType)
105+
{
106+
var containingType = type;
107+
108+
while (containingType is not null
109+
&& !SymbolEqualityComparer.Default.Equals(containingType, stopType))
110+
{
111+
if (containingType.DeclaredAccessibility < accessibility)
112+
return true;
113+
114+
containingType = containingType.ContainingType;
115+
}
116+
117+
return false;
118+
}
100119
}

test/Thinktecture.Runtime.Extensions.SourceGenerator.Tests/AnalyzerAndCodeFixTests/TTRESG053_NonAbstractUnionDerivedTypesMustNotBeGeneric.cs

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,22 @@ public class TTRESG053_NonAbstractUnionDerivedTypesMustNotBeGeneric
1010

1111
public class Non_abstract_unions_must_not_be_generic
1212
{
13-
[Fact]
14-
public async Task Should_trigger_on_generic_class()
13+
[Theory]
14+
[InlineData("class")]
15+
[InlineData("record")]
16+
public async Task Should_trigger_on_generic_class(string type)
1517
{
16-
var code = """
18+
var code = $$"""
1719
1820
using System;
1921
using Thinktecture;
2022
2123
namespace TestNamespace
2224
{
2325
[Union]
24-
public partial class TestUnion<T>
26+
public partial {{type}} TestUnion<T>
2527
{
26-
public class {|#0:First|}<T>(T Value) : TestUnion<T>;
28+
public sealed {{type}} {|#0:First|}<T>(T Value) : TestUnion<T>;
2729
}
2830
}
2931
""";
@@ -33,7 +35,7 @@ public class {|#0:First|}<T>(T Value) : TestUnion<T>;
3335
}
3436

3537
[Fact]
36-
public async Task Should_not_trigger_on_non_generic_class()
38+
public async Task Should_trigger_on_generic_abstract_class()
3739
{
3840
var code = """
3941
@@ -45,7 +47,36 @@ namespace TestNamespace
4547
[Union]
4648
public partial class TestUnion<T>
4749
{
48-
public class {|#0:First|}(T Value) : TestUnion<T>;
50+
public abstract class {|#0:First|}<T> : TestUnion<T>
51+
{
52+
private First(T Value)
53+
{
54+
}
55+
}
56+
}
57+
}
58+
""";
59+
60+
var expected = Verifier.Diagnostic(_DIAGNOSTIC_ID).WithLocation(0).WithArguments("First<T>");
61+
await Verifier.VerifyAnalyzerAsync(code, [typeof(UnionAttribute).Assembly], expected);
62+
}
63+
64+
[Theory]
65+
[InlineData("class")]
66+
[InlineData("record")]
67+
public async Task Should_not_trigger_on_non_generic_class(string type)
68+
{
69+
var code = $$"""
70+
71+
using System;
72+
using Thinktecture;
73+
74+
namespace TestNamespace
75+
{
76+
[Union]
77+
public partial {{type}} TestUnion<T>
78+
{
79+
public sealed {{type}} {|#0:First|}(T Value) : TestUnion<T>;
4980
}
5081
}
5182
""";

0 commit comments

Comments
 (0)