diff --git a/src/OrchardCore/OrchardCore.ContentTypes.Abstractions/Shapes/ContentCardFieldsEditShape.cs b/src/OrchardCore/OrchardCore.ContentTypes.Abstractions/Shapes/ContentCardFieldsEditShape.cs index 3a263cbd431..f99de4f120d 100644 --- a/src/OrchardCore/OrchardCore.ContentTypes.Abstractions/Shapes/ContentCardFieldsEditShape.cs +++ b/src/OrchardCore/OrchardCore.ContentTypes.Abstractions/Shapes/ContentCardFieldsEditShape.cs @@ -2,7 +2,8 @@ namespace OrchardCore.ContentTypes.Shapes; -public class ContentCardFieldsEditShape +[GenerateShape] +public partial class ContentCardFieldsEditShape { public IShape CardShape { get; set; } } diff --git a/src/OrchardCore/OrchardCore.ContentTypes.Abstractions/Shapes/ContentCardFrameShape.cs b/src/OrchardCore/OrchardCore.ContentTypes.Abstractions/Shapes/ContentCardFrameShape.cs index b7b422d2deb..de9937d0713 100644 --- a/src/OrchardCore/OrchardCore.ContentTypes.Abstractions/Shapes/ContentCardFrameShape.cs +++ b/src/OrchardCore/OrchardCore.ContentTypes.Abstractions/Shapes/ContentCardFrameShape.cs @@ -2,7 +2,8 @@ namespace OrchardCore.ContentTypes.Shapes; -public class ContentCardFrameShape +[GenerateShape] +public partial class ContentCardFrameShape { public IShape ChildContent { get; set; } public int? ColumnSize { get; set; } diff --git a/src/OrchardCore/OrchardCore.ContentTypes.Abstractions/Shapes/ContentCardShape.cs b/src/OrchardCore/OrchardCore.ContentTypes.Abstractions/Shapes/ContentCardShape.cs index 41343f60824..a9671232213 100644 --- a/src/OrchardCore/OrchardCore.ContentTypes.Abstractions/Shapes/ContentCardShape.cs +++ b/src/OrchardCore/OrchardCore.ContentTypes.Abstractions/Shapes/ContentCardShape.cs @@ -5,7 +5,8 @@ namespace OrchardCore.ContentTypes.Shapes; -public class ContentCardShape +[GenerateShape] +public partial class ContentCardShape { public IUpdateModel Updater { get; set; } public string CollectionShapeType { get; set; } diff --git a/src/OrchardCore/OrchardCore.DisplayManagement/GenerateShapeAttribute.cs b/src/OrchardCore/OrchardCore.DisplayManagement/GenerateShapeAttribute.cs new file mode 100644 index 00000000000..933c97c0d52 --- /dev/null +++ b/src/OrchardCore/OrchardCore.DisplayManagement/GenerateShapeAttribute.cs @@ -0,0 +1,14 @@ +namespace OrchardCore.DisplayManagement; + +/// +/// Marks a model type to generate a compile-time implementation. +/// +/// +/// Apply this attribute to a partial class with an accessible parameterless constructor. +/// The source generator will implement and directly on the model type, +/// which avoids both interceptors and runtime proxy generation for that model. +/// +[AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = false)] +public sealed class GenerateShapeAttribute : Attribute +{ +} diff --git a/src/OrchardCore/OrchardCore.DisplayManagement/ShapeFactoryExtensions.cs b/src/OrchardCore/OrchardCore.DisplayManagement/ShapeFactoryExtensions.cs index dd59ff1b78b..438460bcfe4 100644 --- a/src/OrchardCore/OrchardCore.DisplayManagement/ShapeFactoryExtensions.cs +++ b/src/OrchardCore/OrchardCore.DisplayManagement/ShapeFactoryExtensions.cs @@ -7,7 +7,7 @@ namespace OrchardCore.DisplayManagement; public static class ShapeFactoryExtensions { - private static readonly ConcurrentDictionary _proxyTypesCache = []; + private static readonly ConcurrentDictionary _proxyTypeCache = []; private static readonly ProxyGenerator _proxyGenerator = new(); /// @@ -187,15 +187,12 @@ static async ValueTask Awaited(ValueTask task, IShape shape) /// private static IShape CreateStronglyTypedShape(Type baseType) { - var shapeType = baseType; - - // Don't generate a proxy for shape types. - if (typeof(IShape).IsAssignableFrom(shapeType)) + if (typeof(IShape).IsAssignableFrom(baseType)) { return (IShape)Activator.CreateInstance(baseType); } - if (_proxyTypesCache.TryGetValue(baseType, out var proxyType)) + if (_proxyTypeCache.TryGetValue(baseType, out var proxyType)) { var model = new ShapeViewModel(); @@ -206,7 +203,7 @@ private static IShape CreateStronglyTypedShape(Type baseType) options.AddMixinInstance(new ShapeViewModel()); var shape = (IShape)_proxyGenerator.CreateClassProxy(baseType, options); - _proxyTypesCache.TryAdd(baseType, shape.GetType()); + _proxyTypeCache.TryAdd(baseType, shape.GetType()); return shape; } diff --git a/src/OrchardCore/OrchardCore.SourceGenerators/ShapeFactoryGenerator.cs b/src/OrchardCore/OrchardCore.SourceGenerators/ShapeFactoryGenerator.cs new file mode 100644 index 00000000000..c426d41261d --- /dev/null +++ b/src/OrchardCore/OrchardCore.SourceGenerators/ShapeFactoryGenerator.cs @@ -0,0 +1,900 @@ +using System.Collections.Immutable; +using System.Security.Cryptography; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Text; + +#nullable enable + +namespace OrchardCore.DisplayManagement.SourceGenerators; + +[Generator] +public class ShapeFactoryGenerator : IIncrementalGenerator +{ + private const string GenerateShapeAttributeFullName = "OrchardCore.DisplayManagement.GenerateShapeAttribute"; + private const string ShapeFactoryExtensionsFullName = "OrchardCore.DisplayManagement.ShapeFactoryExtensions"; + private const string IShapeFactoryFullName = "OrchardCore.DisplayManagement.IShapeFactory"; + private const string IShapeFullName = "OrchardCore.DisplayManagement.IShape"; + private const string DisplayDriverBaseFullName = "OrchardCore.DisplayManagement.Handlers.DisplayDriverBase"; + private const string ValueTaskFullName = "System.Threading.Tasks.ValueTask"; + + public void Initialize(IncrementalGeneratorInitializationContext context) + { + var invocations = context.SyntaxProvider + .CreateSyntaxProvider( + predicate: static (node, _) => node is InvocationExpressionSyntax, + transform: static (context, ct) => GetInvocationInfo(context, ct)) + .Where(static info => info is not null); + + var attributedTypes = context.SyntaxProvider + .ForAttributeWithMetadataName( + GenerateShapeAttributeFullName, + predicate: static (node, _) => node is ClassDeclarationSyntax or RecordDeclarationSyntax, + transform: static (context, _) => GetAttributedModelType(context)) + .Where(static type => type is not null); + + context.RegisterSourceOutput( + invocations.Collect(), + static (spc, invocations) => Execute(invocations!, spc)); + + context.RegisterSourceOutput( + attributedTypes.Collect(), + static (spc, types) => ExecuteAttributedTypes(types!, spc)); + } + + private static InvocationInfo? GetInvocationInfo(GeneratorSyntaxContext context, CancellationToken cancellationToken) + { + var invocation = (InvocationExpressionSyntax)context.Node; + var symbolInfo = context.SemanticModel.GetSymbolInfo(invocation, cancellationToken); + + if (symbolInfo.Symbol is not IMethodSymbol methodSymbol) + { + return null; + } + + var targetMethod = methodSymbol.ReducedFrom ?? methodSymbol.OriginalDefinition; + + if (!methodSymbol.IsGenericMethod || + methodSymbol.TypeArguments.Length == 0 || + methodSymbol.TypeArguments[0] is not INamedTypeSymbol modelType || + modelType.TypeKind != TypeKind.Class || + modelType.IsAbstract || + ImplementsIShape(modelType) || + HasGenerateShapeAttribute(modelType) || + !HasAccessibleParameterlessConstructor(modelType)) + { + return null; + } + + var logicalParameters = GetLogicalParameters(methodSymbol); + var invocationKind = GetInvocationKind(targetMethod, logicalParameters, modelType); + + if (invocationKind is null) + { + return null; + } + + var location = context.SemanticModel.GetInterceptableLocation(invocation, cancellationToken); + + if (location is null) + { + return null; + } + + return new InvocationInfo(location, modelType, invocationKind.Value, GetStateType(invocationKind.Value, logicalParameters)); + } + + private static INamedTypeSymbol? GetAttributedModelType(GeneratorAttributeSyntaxContext context) + { + if (context.TargetSymbol is not INamedTypeSymbol modelType || + modelType.TypeKind != TypeKind.Class || + modelType.IsAbstract || + ImplementsIShape(modelType) || + !HasAccessibleParameterlessConstructor(modelType)) + { + return null; + } + + return modelType; + } + + private static ImmutableArray GetLogicalParameters(IMethodSymbol methodSymbol) + { + if (methodSymbol.Parameters.Length > 0 && + methodSymbol.Parameters[0].Type.ToDisplayString() == IShapeFactoryFullName) + { + return [.. methodSymbol.Parameters.Skip(1)]; + } + + return methodSymbol.Parameters; + } + + private static ITypeSymbol? GetStateType(InvocationKind invocationKind, ImmutableArray logicalParameters) + => (invocationKind is InvocationKind.ActionWithState or InvocationKind.FuncWithState) && + logicalParameters.Length >= 3 + ? logicalParameters[2].Type + : null; + + private static InvocationKind? GetInvocationKind(IMethodSymbol targetMethod, ImmutableArray logicalParameters, INamedTypeSymbol modelType) + { + if (targetMethod.Name == "CreateAsync" && + targetMethod.ContainingType?.ToDisplayString() == ShapeFactoryExtensionsFullName) + { + return GetShapeFactoryInvocationKind(logicalParameters, modelType); + } + + if (targetMethod.Name == "Initialize" && + InheritsFrom(targetMethod.ContainingType, DisplayDriverBaseFullName)) + { + return GetDisplayDriverInvocationKind(logicalParameters, modelType); + } + + return null; + } + + private static InvocationKind? GetShapeFactoryInvocationKind(ImmutableArray logicalParameters, INamedTypeSymbol modelType) + { + if (logicalParameters.IsDefaultOrEmpty) + { + return null; + } + + if (logicalParameters.Length == 1) + { + if (IsAction(logicalParameters[0].Type, modelType)) + { + return InvocationKind.ActionWithoutShapeType; + } + + if (IsFunc(logicalParameters[0].Type, modelType, null)) + { + return InvocationKind.FuncWithoutShapeType; + } + + return null; + } + + if (logicalParameters.Length == 2 && logicalParameters[0].Type.SpecialType == SpecialType.System_String) + { + if (IsAction(logicalParameters[1].Type, modelType)) + { + return InvocationKind.ActionWithShapeType; + } + + if (IsFunc(logicalParameters[1].Type, modelType, null)) + { + return InvocationKind.FuncWithShapeType; + } + + return null; + } + + if (logicalParameters.Length == 3 && logicalParameters[0].Type.SpecialType == SpecialType.System_String) + { + if (IsAction(logicalParameters[1].Type, modelType, logicalParameters[2].Type)) + { + return InvocationKind.ActionWithState; + } + + if (IsFunc(logicalParameters[1].Type, modelType, logicalParameters[2].Type)) + { + return InvocationKind.FuncWithState; + } + } + + return null; + } + + private static InvocationKind? GetDisplayDriverInvocationKind(ImmutableArray logicalParameters, INamedTypeSymbol modelType) + { + if (logicalParameters.IsDefault) + { + return null; + } + + if (logicalParameters.Length == 0) + { + return InvocationKind.DisplayDriverWithoutShapeType; + } + + if (logicalParameters.Length == 1) + { + if (logicalParameters[0].Type.SpecialType == SpecialType.System_String) + { + return InvocationKind.DisplayDriverWithoutInitialize; + } + + if (IsAction(logicalParameters[0].Type, modelType)) + { + return InvocationKind.DisplayDriverActionWithoutShapeType; + } + + if (IsFunc(logicalParameters[0].Type, modelType, null)) + { + return InvocationKind.DisplayDriverFuncWithoutShapeType; + } + + return null; + } + + if (logicalParameters.Length == 2 && logicalParameters[0].Type.SpecialType == SpecialType.System_String) + { + if (IsAction(logicalParameters[1].Type, modelType)) + { + return InvocationKind.DisplayDriverActionWithShapeType; + } + + if (IsFunc(logicalParameters[1].Type, modelType, null)) + { + return InvocationKind.DisplayDriverFuncWithShapeType; + } + } + + return null; + } + + private static bool IsAction(ITypeSymbol typeSymbol, INamedTypeSymbol modelType, ITypeSymbol? stateType = null) + => typeSymbol is INamedTypeSymbol namedType && + namedType.ContainingNamespace?.ToDisplayString() == "System" && + namedType.Name == "Action" && + ((stateType is null && + namedType.TypeArguments.Length == 1 && + SymbolEqualityComparer.Default.Equals(namedType.TypeArguments[0], modelType)) || + (stateType is not null && + namedType.TypeArguments.Length == 2 && + SymbolEqualityComparer.Default.Equals(namedType.TypeArguments[0], modelType) && + SymbolEqualityComparer.Default.Equals(namedType.TypeArguments[1], stateType))); + + private static bool IsFunc(ITypeSymbol typeSymbol, INamedTypeSymbol modelType, ITypeSymbol? stateType) + { + if (typeSymbol is not INamedTypeSymbol namedType || + namedType.ContainingNamespace?.ToDisplayString() != "System" || + namedType.Name != "Func") + { + return false; + } + + if (stateType is null) + { + return namedType.TypeArguments.Length == 2 && + SymbolEqualityComparer.Default.Equals(namedType.TypeArguments[0], modelType) && + namedType.TypeArguments[1].ToDisplayString() == ValueTaskFullName; + } + + return namedType.TypeArguments.Length == 3 && + SymbolEqualityComparer.Default.Equals(namedType.TypeArguments[0], modelType) && + SymbolEqualityComparer.Default.Equals(namedType.TypeArguments[1], stateType) && + namedType.TypeArguments[2].ToDisplayString() == ValueTaskFullName; + } + + private static bool ImplementsIShape(INamedTypeSymbol typeSymbol) + => typeSymbol.AllInterfaces.Any(i => i.ToDisplayString() == IShapeFullName); + + private static bool HasGenerateShapeAttribute(INamedTypeSymbol typeSymbol) + => typeSymbol.GetAttributes().Any(attribute => attribute.AttributeClass?.ToDisplayString() == GenerateShapeAttributeFullName); + + private static bool InheritsFrom(INamedTypeSymbol? typeSymbol, string fullName) + { + for (var current = typeSymbol; current is not null; current = current.BaseType) + { + if (current.ToDisplayString() == fullName) + { + return true; + } + } + + return false; + } + + private static bool HasAccessibleParameterlessConstructor(INamedTypeSymbol typeSymbol) + { + if (typeSymbol.InstanceConstructors.Length == 0) + { + return true; + } + + return typeSymbol.InstanceConstructors.Any(ctor => + ctor.Parameters.Length == 0 && + ctor.DeclaredAccessibility is Accessibility.Public or Accessibility.Internal or Accessibility.Protected or Accessibility.ProtectedOrInternal); + } + + private static void Execute(ImmutableArray invocations, SourceProductionContext context) + { + if (invocations.IsDefaultOrEmpty) + { + return; + } + + var validInvocations = invocations + .Where(static invocation => invocation is not null) + .Distinct(InvocationInfoComparer.Instance) + .ToArray(); + + if (validInvocations.Length == 0) + { + return; + } + + var modelTypes = validInvocations + .Select(static invocation => invocation.ModelType) + .Distinct(SymbolEqualityComparer.Default) + .ToArray(); + + var sb = new StringBuilder(); + + var needsAwaitedHelper = validInvocations.Any(static invocation => invocation.Kind is InvocationKind.FuncWithoutShapeType or InvocationKind.FuncWithShapeType or InvocationKind.FuncWithState); + + sb.Append(""" + // + #nullable enable + + namespace System.Runtime.CompilerServices + { + [global::System.Diagnostics.Conditional("DEBUG")] + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)] + sealed file class InterceptsLocationAttribute : global::System.Attribute + { + public InterceptsLocationAttribute(int version, string data) + { + _ = version; + _ = data; + } + } + + } + + namespace OrchardCore.DisplayManagement.Generated + { + """); + + if (needsAwaitedHelper) + { + sb.Append(""" + file static class ShapeFactoryInterceptorHelpers + { + public static async global::System.Threading.Tasks.ValueTask Awaited(global::System.Threading.Tasks.ValueTask task, global::OrchardCore.DisplayManagement.IShape shape) + { + await task; + return shape; + } + } + + """); + } + + sb.AppendLine(); + + foreach (var modelType in modelTypes) + { + GenerateShapeType(sb, modelType); + } + + foreach (var invocation in validInvocations) + { + GenerateInterceptor(sb, invocation); + } + + sb.AppendLine("}"); + + context.AddSource("ShapeFactoryGenerator.g.cs", SourceText.From(sb.ToString(), Encoding.UTF8)); + } + + private static void ExecuteAttributedTypes(ImmutableArray modelTypes, SourceProductionContext context) + { + if (modelTypes.IsDefaultOrEmpty) + { + return; + } + + foreach (var modelType in modelTypes.Distinct(SymbolEqualityComparer.Default)) + { + context.CancellationToken.ThrowIfCancellationRequested(); + + var source = GenerateAttributedShapeType(modelType); + + if (!string.IsNullOrEmpty(source)) + { + context.AddSource($"{modelType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat).Replace("global::", string.Empty).Replace(".", "_").Replace("<", "_").Replace(">", "_")}.Shape.g.cs", + SourceText.From(source, Encoding.UTF8)); + } + } + } + + private static void GenerateShapeType(StringBuilder sb, INamedTypeSymbol modelType) + { + var baseTypeName = modelType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var generatedTypeName = GetGeneratedTypeName(modelType); + + sb.AppendLine($" public sealed class {generatedTypeName} : {baseTypeName}, global::OrchardCore.DisplayManagement.IShape, global::OrchardCore.DisplayManagement.IPositioned"); + sb.AppendLine(" {"); + sb.AppendLine(" private global::OrchardCore.DisplayManagement.Shapes.ShapeMetadata? _metadata;"); + sb.AppendLine(); + sb.AppendLine(" private global::System.Collections.Generic.List? _classes;"); + sb.AppendLine(); + sb.AppendLine(" private global::System.Collections.Generic.Dictionary? _attributes;"); + sb.AppendLine(); + sb.AppendLine(" private global::System.Collections.Generic.Dictionary? _properties;"); + sb.AppendLine(); + sb.AppendLine(" private bool _sorted;"); + sb.AppendLine(); + sb.AppendLine(" private global::System.Collections.Generic.List? _items;"); + sb.AppendLine(); + sb.AppendLine($" public {GetMemberHidingModifier(modelType, "Metadata")}global::OrchardCore.DisplayManagement.Shapes.ShapeMetadata Metadata => _metadata ??= new global::OrchardCore.DisplayManagement.Shapes.ShapeMetadata();"); + sb.AppendLine(); + sb.AppendLine($" public {GetMemberHidingModifier(modelType, "Position")}string Position"); + sb.AppendLine(" {"); + sb.AppendLine(" get => Metadata.Position;"); + sb.AppendLine(" set => Metadata.Position = value;"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine($" public {GetMemberHidingModifier(modelType, "Id")}string Id {{ get; set; }} = null!;"); + sb.AppendLine(); + sb.AppendLine($" public {GetMemberHidingModifier(modelType, "TagName")}string TagName {{ get; set; }} = null!;"); + sb.AppendLine(); + sb.AppendLine($" public {GetMemberHidingModifier(modelType, "Classes")}global::System.Collections.Generic.IList Classes => _classes ??= [];"); + sb.AppendLine(); + sb.AppendLine($" public {GetMemberHidingModifier(modelType, "Attributes")}global::System.Collections.Generic.IDictionary Attributes => _attributes ??= [];"); + sb.AppendLine(); + sb.AppendLine($" public {GetMemberHidingModifier(modelType, "Properties")}global::System.Collections.Generic.IDictionary Properties => _properties ??= [];"); + sb.AppendLine(); + sb.AppendLine($" public {GetMemberHidingModifier(modelType, "Items")}global::System.Collections.Generic.IReadOnlyList Items"); + sb.AppendLine(" {"); + sb.AppendLine(" get"); + sb.AppendLine(" {"); + sb.AppendLine(" _items ??= [];"); + sb.AppendLine(); + sb.AppendLine(" if (!_sorted && _items.Count > 0)"); + sb.AppendLine(" {"); + sb.AppendLine(" _items = global::System.Linq.Enumerable.ToList(global::System.Linq.Enumerable.OrderBy(_items, static x => x, global::OrchardCore.DisplayManagement.Zones.FlatPositionComparer.Instance));"); + sb.AppendLine(" _sorted = true;"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" return _items;"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine($" public {GetMemberHidingModifier(modelType, "AddAsync", 2)}global::System.Threading.Tasks.ValueTask AddAsync(object item, string position)"); + sb.AppendLine(" {"); + sb.AppendLine(" if (item == null)"); + sb.AppendLine(" {"); + sb.AppendLine(" return global::System.Threading.Tasks.ValueTask.FromResult(this);"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" position ??= string.Empty;"); + sb.AppendLine(" _sorted = false;"); + sb.AppendLine(" _items ??= [];"); + sb.AppendLine(); + sb.AppendLine(" var wrapped = global::OrchardCore.DisplayManagement.PositionWrapper.TryWrap(item, position);"); + sb.AppendLine(" if (wrapped is not null)"); + sb.AppendLine(" {"); + sb.AppendLine(" _items.Add(wrapped);"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" return global::System.Threading.Tasks.ValueTask.FromResult(this);"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(); + } + + private static string GenerateAttributedShapeType(INamedTypeSymbol modelType) + { + var sb = new StringBuilder(); + + sb.AppendLine("// "); + sb.AppendLine("#nullable enable"); + sb.AppendLine(); + + var namespaceName = modelType.ContainingNamespace?.ToDisplayString(); + var hasNamespace = !string.IsNullOrEmpty(namespaceName) && namespaceName != ""; + var currentIndent = string.Empty; + + if (hasNamespace) + { + sb.AppendLine($"namespace {namespaceName}"); + sb.AppendLine("{"); + currentIndent = " "; + } + + AppendContainingTypeDeclarations(sb, modelType, ref currentIndent); + + var keyword = modelType.IsRecord ? "partial record" : "partial class"; + var accessibility = GetAccessibilityText(modelType.DeclaredAccessibility, modelType.ContainingType is not null); + + sb.AppendLine($"{currentIndent}{accessibility} {keyword} {modelType.Name} : global::OrchardCore.DisplayManagement.IShape, global::OrchardCore.DisplayManagement.IPositioned"); + sb.AppendLine($"{currentIndent}{{"); + sb.AppendLine($"{currentIndent} private global::OrchardCore.DisplayManagement.Shapes.ShapeMetadata? _metadata;"); + sb.AppendLine(); + sb.AppendLine($"{currentIndent} private global::System.Collections.Generic.List? _classes;"); + sb.AppendLine(); + sb.AppendLine($"{currentIndent} private global::System.Collections.Generic.Dictionary? _attributes;"); + sb.AppendLine(); + sb.AppendLine($"{currentIndent} private global::System.Collections.Generic.Dictionary? _properties;"); + sb.AppendLine(); + sb.AppendLine($"{currentIndent} private bool _sorted;"); + sb.AppendLine(); + sb.AppendLine($"{currentIndent} private global::System.Collections.Generic.List? _items;"); + sb.AppendLine(); + sb.AppendLine($"{currentIndent} public global::OrchardCore.DisplayManagement.Shapes.ShapeMetadata Metadata => _metadata ??= new global::OrchardCore.DisplayManagement.Shapes.ShapeMetadata();"); + sb.AppendLine(); + sb.AppendLine($"{currentIndent} public string Position"); + sb.AppendLine($"{currentIndent} {{"); + sb.AppendLine($"{currentIndent} get => Metadata.Position;"); + sb.AppendLine($"{currentIndent} set => Metadata.Position = value;"); + sb.AppendLine($"{currentIndent} }}"); + sb.AppendLine(); + sb.AppendLine($"{currentIndent} public string Id {{ get; set; }} = null!;"); + sb.AppendLine(); + sb.AppendLine($"{currentIndent} public string TagName {{ get; set; }} = null!;"); + sb.AppendLine(); + sb.AppendLine($"{currentIndent} public global::System.Collections.Generic.IList Classes => _classes ??= [];"); + sb.AppendLine(); + sb.AppendLine($"{currentIndent} public global::System.Collections.Generic.IDictionary Attributes => _attributes ??= [];"); + sb.AppendLine(); + sb.AppendLine($"{currentIndent} public global::System.Collections.Generic.IDictionary Properties => _properties ??= [];"); + sb.AppendLine(); + sb.AppendLine($"{currentIndent} public global::System.Collections.Generic.IReadOnlyList Items"); + sb.AppendLine($"{currentIndent} {{"); + sb.AppendLine($"{currentIndent} get"); + sb.AppendLine($"{currentIndent} {{"); + sb.AppendLine($"{currentIndent} _items ??= [];"); + sb.AppendLine(); + sb.AppendLine($"{currentIndent} if (!_sorted && _items.Count > 0)"); + sb.AppendLine($"{currentIndent} {{"); + sb.AppendLine($"{currentIndent} _items = global::System.Linq.Enumerable.ToList(global::System.Linq.Enumerable.OrderBy(_items, static x => x, global::OrchardCore.DisplayManagement.Zones.FlatPositionComparer.Instance));"); + sb.AppendLine($"{currentIndent} _sorted = true;"); + sb.AppendLine($"{currentIndent} }}"); + sb.AppendLine(); + sb.AppendLine($"{currentIndent} return _items;"); + sb.AppendLine($"{currentIndent} }}"); + sb.AppendLine($"{currentIndent} }}"); + sb.AppendLine(); + sb.AppendLine($"{currentIndent} public global::System.Threading.Tasks.ValueTask AddAsync(object item, string position)"); + sb.AppendLine($"{currentIndent} {{"); + sb.AppendLine($"{currentIndent} if (item == null)"); + sb.AppendLine($"{currentIndent} {{"); + sb.AppendLine($"{currentIndent} return global::System.Threading.Tasks.ValueTask.FromResult(this);"); + sb.AppendLine($"{currentIndent} }}"); + sb.AppendLine(); + sb.AppendLine($"{currentIndent} position ??= string.Empty;"); + sb.AppendLine($"{currentIndent} _sorted = false;"); + sb.AppendLine($"{currentIndent} _items ??= [];"); + sb.AppendLine(); + sb.AppendLine($"{currentIndent} var wrapped = global::OrchardCore.DisplayManagement.PositionWrapper.TryWrap(item, position);"); + sb.AppendLine($"{currentIndent} if (wrapped is not null)"); + sb.AppendLine($"{currentIndent} {{"); + sb.AppendLine($"{currentIndent} _items.Add(wrapped);"); + sb.AppendLine($"{currentIndent} }}"); + sb.AppendLine(); + sb.AppendLine($"{currentIndent} return global::System.Threading.Tasks.ValueTask.FromResult(this);"); + sb.AppendLine($"{currentIndent} }}"); + sb.AppendLine($"{currentIndent}}}"); + + CloseContainingTypeDeclarations(sb, modelType, ref currentIndent); + + if (hasNamespace) + { + sb.AppendLine("}"); + } + + return sb.ToString(); + } + + private static void AppendContainingTypeDeclarations(StringBuilder sb, INamedTypeSymbol typeSymbol, ref string currentIndent) + { + var containingTypes = new global::System.Collections.Generic.Stack(); + + for (var current = typeSymbol.ContainingType; current is not null; current = current.ContainingType) + { + containingTypes.Push(current); + } + + while (containingTypes.Count > 0) + { + var containingType = containingTypes.Pop(); + var keyword = containingType.IsRecord ? "partial record" : "partial class"; + var accessibility = GetAccessibilityText(containingType.DeclaredAccessibility, containingType.ContainingType is not null); + + sb.AppendLine($"{currentIndent}{accessibility} {keyword} {containingType.Name}"); + sb.AppendLine($"{currentIndent}{{"); + currentIndent += " "; + } + } + + private static void CloseContainingTypeDeclarations(StringBuilder sb, INamedTypeSymbol typeSymbol, ref string currentIndent) + { + for (var current = typeSymbol.ContainingType; current is not null; current = current.ContainingType) + { + currentIndent = currentIndent.Substring(4); + sb.AppendLine($"{currentIndent}}}"); + } + } + + private static string GetAccessibilityText(Accessibility accessibility, bool isNested) + { + if (!isNested) + { + return accessibility switch + { + Accessibility.Public => "public", + Accessibility.Internal => "internal", + _ => "internal", + }; + } + + return accessibility switch + { + Accessibility.Private => "private", + Accessibility.Protected => "protected", + Accessibility.Internal => "internal", + Accessibility.Public => "public", + Accessibility.ProtectedOrInternal => "protected internal", + Accessibility.ProtectedAndInternal => "private protected", + _ => "public", + }; + } + + private static string GetMemberHidingModifier(INamedTypeSymbol modelType, string memberName, int parameterCount = 0) + => HasConflictingMember(modelType, memberName, parameterCount) ? "new " : string.Empty; + + private static bool HasConflictingMember(INamedTypeSymbol modelType, string memberName, int parameterCount) + { + for (var current = modelType; current is not null; current = current.BaseType) + { + if (current.GetMembers(memberName).Any(member => + member switch + { + IPropertySymbol => parameterCount == 0, + IMethodSymbol method => method.Parameters.Length == parameterCount, + _ => false, + })) + { + return true; + } + } + + return false; + } + + private static void GenerateInterceptor(StringBuilder sb, InvocationInfo invocation) + { + var interceptorClassName = $"Interceptor_{GetStableId(invocation.Location.Data)}"; + var generatedTypeName = GetGeneratedTypeName(invocation.ModelType); + var modelTypeName = invocation.ModelType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + sb.AppendLine($" file static class {interceptorClassName}"); + sb.AppendLine(" {"); + sb.AppendLine($" [global::System.Runtime.CompilerServices.InterceptsLocation({invocation.Location.Version}, \"{invocation.Location.Data}\")]"); + + switch (invocation.Kind) + { + case InvocationKind.DisplayDriverWithoutShapeType: + sb.AppendLine(" public static global::OrchardCore.DisplayManagement.Views.ShapeResult InterceptInitialize(this global::OrchardCore.DisplayManagement.Handlers.DisplayDriverBase driver)"); + sb.AppendLine($" => driver.Factory(\"{invocation.ModelType.Name}\","); + sb.AppendLine($" static _ => global::System.Threading.Tasks.ValueTask.FromResult((global::OrchardCore.DisplayManagement.IShape)new {generatedTypeName}()),"); + sb.AppendLine(" null);"); + break; + + case InvocationKind.DisplayDriverWithoutInitialize: + sb.AppendLine(" public static global::OrchardCore.DisplayManagement.Views.ShapeResult InterceptInitialize(this global::OrchardCore.DisplayManagement.Handlers.DisplayDriverBase driver, string shapeType)"); + sb.AppendLine(" => driver.Factory(shapeType,"); + sb.AppendLine($" static _ => global::System.Threading.Tasks.ValueTask.FromResult((global::OrchardCore.DisplayManagement.IShape)new {generatedTypeName}()),"); + sb.AppendLine(" null);"); + break; + + case InvocationKind.ActionWithoutShapeType: + sb.AppendLine($" public static global::System.Threading.Tasks.ValueTask InterceptCreateAsync(this global::OrchardCore.DisplayManagement.IShapeFactory factory, global::System.Action<{modelTypeName}>? initialize = null)"); + sb.AppendLine(" => factory.CreateAsync("); + sb.AppendLine($" \"{invocation.ModelType.Name}\","); + sb.AppendLine(" static initialize => ShapeFactory(initialize),"); + sb.AppendLine(" initialize);"); + sb.AppendLine(); + sb.AppendLine($" private static global::System.Threading.Tasks.ValueTask ShapeFactory(global::System.Action<{modelTypeName}>? initialize)"); + sb.AppendLine(" {"); + sb.AppendLine($" var shape = (global::OrchardCore.DisplayManagement.IShape)new {generatedTypeName}();"); + sb.AppendLine($" initialize?.Invoke(({modelTypeName})shape);"); + sb.AppendLine(" return global::System.Threading.Tasks.ValueTask.FromResult(shape);"); + sb.AppendLine(" }"); + break; + + case InvocationKind.ActionWithShapeType: + sb.AppendLine($" public static global::System.Threading.Tasks.ValueTask InterceptCreateAsync(this global::OrchardCore.DisplayManagement.IShapeFactory factory, string shapeType, global::System.Action<{modelTypeName}>? initialize = null)"); + sb.AppendLine(" => factory.CreateAsync("); + sb.AppendLine(" shapeType,"); + sb.AppendLine(" static initialize => ShapeFactory(initialize),"); + sb.AppendLine(" initialize);"); + sb.AppendLine(); + sb.AppendLine($" private static global::System.Threading.Tasks.ValueTask ShapeFactory(global::System.Action<{modelTypeName}>? initialize)"); + sb.AppendLine(" {"); + sb.AppendLine($" var shape = (global::OrchardCore.DisplayManagement.IShape)new {generatedTypeName}();"); + sb.AppendLine($" initialize?.Invoke(({modelTypeName})shape);"); + sb.AppendLine(" return global::System.Threading.Tasks.ValueTask.FromResult(shape);"); + sb.AppendLine(" }"); + break; + + case InvocationKind.ActionWithState: + var actionStateType = invocation.StateType!.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + sb.AppendLine($" public static global::System.Threading.Tasks.ValueTask InterceptCreateAsync(this global::OrchardCore.DisplayManagement.IShapeFactory factory, string shapeType, global::System.Action<{modelTypeName}, {actionStateType}> initialize, {actionStateType} state)"); + sb.AppendLine(" => factory.CreateAsync("); + sb.AppendLine(" shapeType,"); + sb.AppendLine(" static state => ShapeFactory(state.initialize, state.state),"); + sb.AppendLine(" (initialize, state));"); + sb.AppendLine(); + sb.AppendLine($" private static global::System.Threading.Tasks.ValueTask ShapeFactory(global::System.Action<{modelTypeName}, {actionStateType}> initialize, {actionStateType} state)"); + sb.AppendLine(" {"); + sb.AppendLine($" var shape = (global::OrchardCore.DisplayManagement.IShape)new {generatedTypeName}();"); + sb.AppendLine($" initialize?.Invoke(({modelTypeName})shape, state);"); + sb.AppendLine(" return global::System.Threading.Tasks.ValueTask.FromResult(shape);"); + sb.AppendLine(" }"); + break; + + case InvocationKind.FuncWithoutShapeType: + sb.AppendLine($" public static global::System.Threading.Tasks.ValueTask InterceptCreateAsync(this global::OrchardCore.DisplayManagement.IShapeFactory factory, global::System.Func<{modelTypeName}, global::System.Threading.Tasks.ValueTask> initializeAsync)"); + sb.AppendLine(" => factory.CreateAsync("); + sb.AppendLine($" \"{invocation.ModelType.Name}\","); + sb.AppendLine(" static initializeAsync => ShapeFactory(initializeAsync),"); + sb.AppendLine(" initializeAsync);"); + sb.AppendLine(); + sb.AppendLine($" private static global::System.Threading.Tasks.ValueTask ShapeFactory(global::System.Func<{modelTypeName}, global::System.Threading.Tasks.ValueTask> initializeAsync)"); + sb.AppendLine(" {"); + sb.AppendLine($" var shape = (global::OrchardCore.DisplayManagement.IShape)new {generatedTypeName}();"); + sb.AppendLine($" var task = initializeAsync?.Invoke(({modelTypeName})shape) ?? global::System.Threading.Tasks.ValueTask.CompletedTask;"); + sb.AppendLine(); + sb.AppendLine(" if (!task.IsCompletedSuccessfully)"); + sb.AppendLine(" {"); + sb.AppendLine(" return ShapeFactoryInterceptorHelpers.Awaited(task, shape);"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" return global::System.Threading.Tasks.ValueTask.FromResult(shape);"); + sb.AppendLine(" }"); + break; + + case InvocationKind.FuncWithShapeType: + sb.AppendLine($" public static global::System.Threading.Tasks.ValueTask InterceptCreateAsync(this global::OrchardCore.DisplayManagement.IShapeFactory factory, string shapeType, global::System.Func<{modelTypeName}, global::System.Threading.Tasks.ValueTask> initializeAsync)"); + sb.AppendLine(" => factory.CreateAsync("); + sb.AppendLine(" shapeType,"); + sb.AppendLine(" static initializeAsync => ShapeFactory(initializeAsync),"); + sb.AppendLine(" initializeAsync);"); + sb.AppendLine(); + sb.AppendLine($" private static global::System.Threading.Tasks.ValueTask ShapeFactory(global::System.Func<{modelTypeName}, global::System.Threading.Tasks.ValueTask> initializeAsync)"); + sb.AppendLine(" {"); + sb.AppendLine($" var shape = (global::OrchardCore.DisplayManagement.IShape)new {generatedTypeName}();"); + sb.AppendLine($" var task = initializeAsync?.Invoke(({modelTypeName})shape) ?? global::System.Threading.Tasks.ValueTask.CompletedTask;"); + sb.AppendLine(); + sb.AppendLine(" if (!task.IsCompletedSuccessfully)"); + sb.AppendLine(" {"); + sb.AppendLine(" return ShapeFactoryInterceptorHelpers.Awaited(task, shape);"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" return global::System.Threading.Tasks.ValueTask.FromResult(shape);"); + sb.AppendLine(" }"); + break; + + case InvocationKind.FuncWithState: + var funcStateType = invocation.StateType!.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + sb.AppendLine($" public static global::System.Threading.Tasks.ValueTask InterceptCreateAsync(this global::OrchardCore.DisplayManagement.IShapeFactory factory, string shapeType, global::System.Func<{modelTypeName}, {funcStateType}, global::System.Threading.Tasks.ValueTask> initializeAsync, {funcStateType} state)"); + sb.AppendLine(" => factory.CreateAsync("); + sb.AppendLine(" shapeType,"); + sb.AppendLine(" static state => ShapeFactory(state.initializeAsync, state.state),"); + sb.AppendLine(" (initializeAsync, state));"); + sb.AppendLine(); + sb.AppendLine($" private static global::System.Threading.Tasks.ValueTask ShapeFactory(global::System.Func<{modelTypeName}, {funcStateType}, global::System.Threading.Tasks.ValueTask> initializeAsync, {funcStateType} state)"); + sb.AppendLine(" {"); + sb.AppendLine($" var shape = (global::OrchardCore.DisplayManagement.IShape)new {generatedTypeName}();"); + sb.AppendLine($" var task = initializeAsync?.Invoke(({modelTypeName})shape, state) ?? global::System.Threading.Tasks.ValueTask.CompletedTask;"); + sb.AppendLine(); + sb.AppendLine(" if (!task.IsCompletedSuccessfully)"); + sb.AppendLine(" {"); + sb.AppendLine(" return ShapeFactoryInterceptorHelpers.Awaited(task, shape);"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" return global::System.Threading.Tasks.ValueTask.FromResult(shape);"); + sb.AppendLine(" }"); + break; + + case InvocationKind.DisplayDriverActionWithoutShapeType: + sb.AppendLine($" public static global::OrchardCore.DisplayManagement.Views.ShapeResult InterceptInitialize(this global::OrchardCore.DisplayManagement.Handlers.DisplayDriverBase driver, global::System.Action<{modelTypeName}> initialize)"); + sb.AppendLine($" => driver.Factory(\"{invocation.ModelType.Name}\","); + sb.AppendLine($" static _ => global::System.Threading.Tasks.ValueTask.FromResult((global::OrchardCore.DisplayManagement.IShape)new {generatedTypeName}()),"); + sb.AppendLine($" shape => InterceptInitialize(({modelTypeName})shape, initialize));"); + sb.AppendLine(); + sb.AppendLine($" private static global::System.Threading.Tasks.Task InterceptInitialize({modelTypeName} shape, global::System.Action<{modelTypeName}> initialize)"); + sb.AppendLine(" {"); + sb.AppendLine(" initialize?.Invoke(shape);"); + sb.AppendLine(" return global::System.Threading.Tasks.Task.CompletedTask;"); + sb.AppendLine(" }"); + break; + + case InvocationKind.DisplayDriverActionWithShapeType: + sb.AppendLine($" public static global::OrchardCore.DisplayManagement.Views.ShapeResult InterceptInitialize(this global::OrchardCore.DisplayManagement.Handlers.DisplayDriverBase driver, string shapeType, global::System.Action<{modelTypeName}> initialize)"); + sb.AppendLine(" => driver.Factory(shapeType,"); + sb.AppendLine($" static _ => global::System.Threading.Tasks.ValueTask.FromResult((global::OrchardCore.DisplayManagement.IShape)new {generatedTypeName}()),"); + sb.AppendLine($" shape => InterceptInitialize(({modelTypeName})shape, initialize));"); + sb.AppendLine(); + sb.AppendLine($" private static global::System.Threading.Tasks.Task InterceptInitialize({modelTypeName} shape, global::System.Action<{modelTypeName}> initialize)"); + sb.AppendLine(" {"); + sb.AppendLine(" initialize?.Invoke(shape);"); + sb.AppendLine(" return global::System.Threading.Tasks.Task.CompletedTask;"); + sb.AppendLine(" }"); + break; + + case InvocationKind.DisplayDriverFuncWithoutShapeType: + sb.AppendLine($" public static global::OrchardCore.DisplayManagement.Views.ShapeResult InterceptInitialize(this global::OrchardCore.DisplayManagement.Handlers.DisplayDriverBase driver, global::System.Func<{modelTypeName}, global::System.Threading.Tasks.ValueTask> initializeAsync)"); + sb.AppendLine($" => driver.Factory(\"{invocation.ModelType.Name}\","); + sb.AppendLine($" static _ => global::System.Threading.Tasks.ValueTask.FromResult((global::OrchardCore.DisplayManagement.IShape)new {generatedTypeName}()),"); + sb.AppendLine($" shape => InterceptInitialize(({modelTypeName})shape, initializeAsync));"); + sb.AppendLine(); + sb.AppendLine($" private static global::System.Threading.Tasks.Task InterceptInitialize({modelTypeName} shape, global::System.Func<{modelTypeName}, global::System.Threading.Tasks.ValueTask> initializeAsync)"); + sb.AppendLine(" => initializeAsync?.Invoke(shape).AsTask() ?? global::System.Threading.Tasks.Task.CompletedTask;"); + break; + + case InvocationKind.DisplayDriverFuncWithShapeType: + sb.AppendLine($" public static global::OrchardCore.DisplayManagement.Views.ShapeResult InterceptInitialize(this global::OrchardCore.DisplayManagement.Handlers.DisplayDriverBase driver, string shapeType, global::System.Func<{modelTypeName}, global::System.Threading.Tasks.ValueTask> initializeAsync)"); + sb.AppendLine(" => driver.Factory(shapeType,"); + sb.AppendLine($" static _ => global::System.Threading.Tasks.ValueTask.FromResult((global::OrchardCore.DisplayManagement.IShape)new {generatedTypeName}()),"); + sb.AppendLine($" shape => InterceptInitialize(({modelTypeName})shape, initializeAsync));"); + sb.AppendLine(); + sb.AppendLine($" private static global::System.Threading.Tasks.Task InterceptInitialize({modelTypeName} shape, global::System.Func<{modelTypeName}, global::System.Threading.Tasks.ValueTask> initializeAsync)"); + sb.AppendLine(" => initializeAsync?.Invoke(shape).AsTask() ?? global::System.Threading.Tasks.Task.CompletedTask;"); + break; + } + + sb.AppendLine(" }"); + sb.AppendLine(); + } + + private static string GetGeneratedTypeName(INamedTypeSymbol typeSymbol) + => $"GeneratedShape_{GetStableId(typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))}"; + + private static string GetStableId(string value) + { + using var sha256 = SHA256.Create(); + var hash = sha256.ComputeHash(Encoding.UTF8.GetBytes(value)); + + return BitConverter.ToString(hash).Replace("-", ""); + } + + private sealed class InvocationInfo + { + public InvocationInfo(InterceptableLocation location, INamedTypeSymbol modelType, InvocationKind kind, ITypeSymbol? stateType) + { + Location = location; + ModelType = modelType; + Kind = kind; + StateType = stateType; + } + + public InterceptableLocation Location { get; } + public INamedTypeSymbol ModelType { get; } + public InvocationKind Kind { get; } + public ITypeSymbol? StateType { get; } + } + + private sealed class InvocationInfoComparer : IEqualityComparer + { + public static InvocationInfoComparer Instance { get; } = new(); + + public bool Equals(InvocationInfo? x, InvocationInfo? y) + => x?.Location.Data == y?.Location.Data && + x?.Location.Version == y?.Location.Version; + + public int GetHashCode(InvocationInfo obj) + => ((obj.Location.Data?.GetHashCode() ?? 0) * 397) ^ obj.Location.Version; + } + + private enum InvocationKind + { + DisplayDriverWithoutShapeType, + DisplayDriverWithoutInitialize, + ActionWithoutShapeType, + ActionWithShapeType, + ActionWithState, + FuncWithoutShapeType, + FuncWithShapeType, + FuncWithState, + DisplayDriverActionWithoutShapeType, + DisplayDriverActionWithShapeType, + DisplayDriverFuncWithoutShapeType, + DisplayDriverFuncWithShapeType, + } +} diff --git a/test/OrchardCore.Tests/DisplayManagement/ShapeFactoryTests.cs b/test/OrchardCore.Tests/DisplayManagement/ShapeFactoryTests.cs index 688090de40d..b4be136ded9 100644 --- a/test/OrchardCore.Tests/DisplayManagement/ShapeFactoryTests.cs +++ b/test/OrchardCore.Tests/DisplayManagement/ShapeFactoryTests.cs @@ -1,8 +1,10 @@ using OrchardCore.DisplayManagement; using OrchardCore.DisplayManagement.Descriptors; +using OrchardCore.DisplayManagement.Handlers; using OrchardCore.DisplayManagement.Implementation; using OrchardCore.DisplayManagement.Shapes; using OrchardCore.DisplayManagement.Theming; +using OrchardCore.DisplayManagement.Views; using OrchardCore.Environment.Extensions; using OrchardCore.Tests.Stubs; @@ -126,7 +128,218 @@ public async Task ShapeFactoryWithCustomShapeTypeAppliesArguments() Assert.Equal("Baz", foo.Baz); } + [Fact] + public async Task CreateStronglyTypedShapeUsesGeneratedShapeType() + { + var factory = _serviceProvider.GetRequiredService(); + + var shape = await factory.CreateAsync(model => + { + model.Title = "Generated"; + model.Count = 5; + }); + + var typedShape = Assert.IsAssignableFrom(shape); + var generatedShapeType = typedShape.GetType(); + + Assert.NotEqual(typeof(TestShapeViewModel), generatedShapeType); + Assert.Equal(typeof(ShapeFactoryTests).Assembly, generatedShapeType.Assembly); + Assert.False(generatedShapeType.Assembly.IsDynamic); + Assert.True(generatedShapeType.IsPublic); + Assert.DoesNotContain(generatedShapeType.GetFields(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public), field => field.FieldType == typeof(ShapeViewModel)); + Assert.Equal("Generated", typedShape.Title); + Assert.Equal(5, typedShape.Count); + Assert.Same(shape.Metadata, ((IShape)shape).Metadata); + } + + [Fact] + public async Task CreateStronglyTypedShapeDelegatesShapeMembers() + { + var factory = _serviceProvider.GetRequiredService(); + + var shape = await factory.CreateAsync(); + + shape.Id = "shape-id"; + shape.TagName = "section"; + shape.Classes.Add("test-class"); + shape.Attributes["data-test"] = "true"; + shape.Properties["answer"] = 42; + + await shape.AddAsync(new Shape(), "1"); + + var positionedShape = Assert.IsAssignableFrom(shape); + + positionedShape.Position = "3"; + + Assert.Equal("shape-id", shape.Id); + Assert.Equal("section", shape.TagName); + Assert.Contains("test-class", shape.Classes); + Assert.Equal("true", shape.Attributes["data-test"]); + Assert.Equal(42, shape.Properties["answer"]); + Assert.Equal("3", positionedShape.Position); + Assert.Single(shape.Items); + } + + [Fact] + public async Task CreateStronglyTypedShapeFallsBackToCastleProxy() + { + var factory = _serviceProvider.GetRequiredService(); + var shapeFactoryExtensionsType = typeof(IShapeFactory).Assembly.GetType("OrchardCore.DisplayManagement.ShapeFactoryExtensions", throwOnError: true); + var createAsyncMethod = GetCreateAsyncActionOverload(shapeFactoryExtensionsType); + var genericCreateAsyncMethod = createAsyncMethod.MakeGenericMethod(typeof(FallbackOnlyShapeViewModel)); + + Assert.NotNull(genericCreateAsyncMethod); + + var shape = await (ValueTask)genericCreateAsyncMethod.Invoke(null, [factory, null]); + var typedShape = Assert.IsAssignableFrom(shape); + + Assert.NotEqual(typeof(FallbackOnlyShapeViewModel), typedShape.GetType()); + Assert.True(typedShape.GetType().Assembly.IsDynamic); + } + + [Fact] + public async Task CreateStronglyTypedShapeUsesAttributedModelWithoutInterception() + { + var factory = _serviceProvider.GetRequiredService(); + var shape = await factory.CreateAsync(model => model.Title = "Attributed"); + var typedShape = Assert.IsAssignableFrom(shape); + + Assert.Same(typeof(AttributedShapeViewModel), typedShape.GetType()); + Assert.Equal(typeof(ShapeFactoryTests).Assembly, typedShape.GetType().Assembly); + Assert.False(typedShape.GetType().Assembly.IsDynamic); + Assert.True(typedShape.GetType().IsPublic); + Assert.Equal("Attributed", typedShape.Title); + } + + [Fact] + public async Task DisplayDriverInitializeUsesGeneratedShapeType() + { + var factory = _serviceProvider.GetRequiredService(); + var shapeResult = new TestDisplayDriver().Build(); + var shape = await BuildShapeAsync(shapeResult, factory); + var typedShape = Assert.IsAssignableFrom(shape); + var generatedShapeType = typedShape.GetType(); + + Assert.IsAssignableFrom(typedShape); + Assert.NotEqual(typeof(TestShapeViewModel), generatedShapeType); + Assert.Equal(typeof(TestShapeViewModel), generatedShapeType.BaseType); + Assert.Equal(typeof(ShapeFactoryTests).Assembly, generatedShapeType.Assembly); + Assert.False(generatedShapeType.Assembly.IsDynamic); + Assert.True(generatedShapeType.IsPublic); + Assert.Equal("Driver", typedShape.Title); + Assert.Equal(10, typedShape.Count); + } + + [Fact] + public async Task DisplayDriverInitializeWithoutInitializerUsesGeneratedShapeType() + { + var factory = _serviceProvider.GetRequiredService(); + var shapeResult = new TestDisplayDriverWithoutInitializer().Build(); + var shape = await BuildShapeAsync(shapeResult, factory); + var typedShape = Assert.IsAssignableFrom(shape); + var generatedShapeType = typedShape.GetType(); + + Assert.IsAssignableFrom(typedShape); + Assert.NotEqual(typeof(TestShapeViewModel), generatedShapeType); + Assert.Equal(typeof(TestShapeViewModel), generatedShapeType.BaseType); + Assert.Equal(typeof(ShapeFactoryTests).Assembly, generatedShapeType.Assembly); + Assert.False(generatedShapeType.Assembly.IsDynamic); + Assert.True(generatedShapeType.IsPublic); + } + + [Fact] + public async Task DisplayDriverInitializeWithoutShapeTypeOrInitializerUsesGeneratedShapeType() + { + var factory = _serviceProvider.GetRequiredService(); + var shapeResult = new TestDisplayDriverWithoutShapeTypeOrInitializer().Build(); + var shape = await BuildShapeAsync(shapeResult, factory); + var typedShape = Assert.IsAssignableFrom(shape); + var generatedShapeType = typedShape.GetType(); + + Assert.IsAssignableFrom(typedShape); + Assert.NotEqual(typeof(TestShapeViewModel), generatedShapeType); + Assert.Equal(typeof(TestShapeViewModel), generatedShapeType.BaseType); + Assert.Equal(typeof(ShapeFactoryTests).Assembly, generatedShapeType.Assembly); + Assert.False(generatedShapeType.Assembly.IsDynamic); + Assert.True(generatedShapeType.IsPublic); + } + + private static MethodInfo GetCreateAsyncActionOverload(Type shapeFactoryExtensionsType) + => shapeFactoryExtensionsType + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .Single(method => method.Name == "CreateAsync" && + method.IsGenericMethodDefinition && + method.GetGenericArguments().Length == 1 && + method.GetParameters().Length == 2 && + method.GetParameters()[0].ParameterType == typeof(IShapeFactory) && + method.GetParameters()[1].ParameterType.IsGenericType && + method.GetParameters()[1].ParameterType.GetGenericTypeDefinition() == typeof(Action<>)); + + private static async Task BuildShapeAsync(ShapeResult shapeResult, IShapeFactory factory) + { + // ShapeResult only populates Shape during the full display pipeline, so this test + // invokes the stored builder and initializer directly to verify generated wrapper usage. + var shapeBuilderField = typeof(ShapeResult).GetField("_shapeBuilder", BindingFlags.Instance | BindingFlags.NonPublic); + var initializingAsyncField = typeof(ShapeResult).GetField("_initializingAsync", BindingFlags.Instance | BindingFlags.NonPublic); + + Assert.NotNull(shapeBuilderField); + Assert.NotNull(initializingAsyncField); + + var shapeBuilder = Assert.IsType>>(shapeBuilderField.GetValue(shapeResult)); + var initializingAsync = (Func)initializingAsyncField.GetValue(shapeResult); + var shape = await shapeBuilder(new BuildDisplayContext(new Shape(), "Detail", string.Empty, factory, null, null)); + + if (initializingAsync is not null) + { + await initializingAsync(shape); + } + + return shape; + } + private sealed class SubShape : Shape { } + + private sealed class TestDisplayDriver : DisplayDriverBase + { + public ShapeResult Build() + => Initialize("TestShapeViewModel_Edit", model => + { + model.Title = "Driver"; + model.Count = 10; + + return ValueTask.CompletedTask; + }); + } + + private sealed class TestDisplayDriverWithoutInitializer : DisplayDriverBase + { + public ShapeResult Build() + => Initialize("TestShapeViewModel_Edit"); + } + + private sealed class TestDisplayDriverWithoutShapeTypeOrInitializer : DisplayDriverBase + { + public ShapeResult Build() + => Initialize(); + } +} + +public class TestShapeViewModel +{ + public string Title { get; set; } + + public int Count { get; set; } +} + +public class FallbackOnlyShapeViewModel +{ + public string Name { get; set; } +} + +[GenerateShape] +public partial class AttributedShapeViewModel +{ + public string Title { get; set; } }