Skip to content

Commit 9bb4748

Browse files
committed
Merge branch 'main' of github.com:dex3r/MattSourceGenHelpers
# Conflicts: # MattSourceGenHelpers.Generators/GeneratesMethodGenerator.cs
2 parents 988709b + 1aa7794 commit 9bb4748

6 files changed

Lines changed: 842 additions & 740 deletions
Lines changed: 331 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,331 @@
1+
using Microsoft.CodeAnalysis;
2+
using Microsoft.CodeAnalysis.CSharp;
3+
using Microsoft.CodeAnalysis.CSharp.Syntax;
4+
using Microsoft.CodeAnalysis.Emit;
5+
using System.Collections;
6+
using System.Reflection;
7+
using System.Runtime.Loader;
8+
using System.Text;
9+
10+
namespace MattSourceGenHelpers.Generators;
11+
12+
internal sealed record SwitchBodyData(
13+
IReadOnlyList<(object key, string value)> CasePairs,
14+
bool HasDefaultCase);
15+
16+
internal static class GeneratesMethodExecutionRuntime
17+
{
18+
internal static (string? value, string? error) ExecuteSimpleGeneratorMethod(
19+
IMethodSymbol generatorMethod,
20+
IMethodSymbol partialMethod,
21+
Compilation compilation)
22+
{
23+
IReadOnlyList<IMethodSymbol> allPartials = GetAllUnimplementedPartialMethods(compilation);
24+
return ExecuteGeneratorMethodWithArgs(generatorMethod, allPartials, compilation, null);
25+
}
26+
27+
internal static (SwitchBodyData? record, string? error) ExecuteFluentGeneratorMethod(
28+
IMethodSymbol generatorMethod,
29+
IMethodSymbol partialMethod,
30+
Compilation compilation)
31+
{
32+
IReadOnlyList<IMethodSymbol> allPartials = GetAllUnimplementedPartialMethods(compilation);
33+
CSharpCompilation executableCompilation = BuildExecutionCompilation(allPartials, compilation);
34+
35+
using MemoryStream stream = new();
36+
EmitResult emitResult = executableCompilation.Emit(stream);
37+
if (!emitResult.Success)
38+
{
39+
string errors = string.Join("; ", emitResult.Diagnostics
40+
.Where(diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)
41+
.Select(diagnostic => diagnostic.GetMessage()));
42+
return (null, $"Compilation failed: {errors}");
43+
}
44+
45+
stream.Position = 0;
46+
AssemblyLoadContext? loadContext = null;
47+
try
48+
{
49+
loadContext = new AssemblyLoadContext("__GeneratorExec", isCollectible: true);
50+
loadContext.Resolving += (context, assemblyName) =>
51+
{
52+
PortableExecutableReference? match = compilation.References
53+
.OfType<PortableExecutableReference>()
54+
.FirstOrDefault(reference => reference.FilePath is not null && string.Equals(
55+
Path.GetFileNameWithoutExtension(reference.FilePath),
56+
assemblyName.Name,
57+
StringComparison.OrdinalIgnoreCase));
58+
return match?.FilePath != null ? context.LoadFromAssemblyPath(match.FilePath) : null;
59+
};
60+
61+
Assembly assembly = loadContext.LoadFromStream(stream);
62+
63+
PortableExecutableReference? abstractionsReference = compilation.References
64+
.OfType<PortableExecutableReference>()
65+
.FirstOrDefault(reference => reference.FilePath is not null && string.Equals(
66+
Path.GetFileNameWithoutExtension(reference.FilePath),
67+
"MattSourceGenHelpers.Abstractions",
68+
StringComparison.OrdinalIgnoreCase));
69+
70+
if (abstractionsReference?.FilePath == null)
71+
{
72+
return (null, "Could not find MattSourceGenHelpers.Abstractions reference in compilation");
73+
}
74+
75+
string abstractionsAssemblyPath = ResolveImplementationAssemblyPath(abstractionsReference.FilePath);
76+
Assembly abstractionsAssembly = loadContext.LoadFromAssemblyPath(abstractionsAssemblyPath);
77+
78+
Type? generatorStaticType = abstractionsAssembly.GetType("MattSourceGenHelpers.Abstractions.Generator");
79+
Type? recordingFactoryType = abstractionsAssembly.GetType("MattSourceGenHelpers.Abstractions.RecordingGeneratorsFactory");
80+
if (generatorStaticType == null || recordingFactoryType == null)
81+
{
82+
return (null, "Could not find Generator or RecordingGeneratorsFactory types in Abstractions assembly");
83+
}
84+
85+
object? recordingFactory = Activator.CreateInstance(recordingFactoryType);
86+
PropertyInfo? currentGeneratorProperty = generatorStaticType.GetProperty("CurrentGenerator", BindingFlags.Public | BindingFlags.Static);
87+
currentGeneratorProperty?.SetValue(null, recordingFactory);
88+
89+
string typeName = generatorMethod.ContainingType.ToDisplayString();
90+
Type? loadedType = assembly.GetType(typeName);
91+
if (loadedType == null)
92+
{
93+
return (null, $"Could not find type '{typeName}' in compiled assembly");
94+
}
95+
96+
MethodInfo? generatorMethodInfo = loadedType.GetMethod(generatorMethod.Name, BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public);
97+
if (generatorMethodInfo == null)
98+
{
99+
return (null, $"Could not find method '{generatorMethod.Name}' in type '{typeName}'");
100+
}
101+
102+
generatorMethodInfo.Invoke(null, null);
103+
104+
PropertyInfo? lastRecordProperty = recordingFactoryType.GetProperty("LastRecord");
105+
object? lastRecord = lastRecordProperty?.GetValue(recordingFactory);
106+
if (lastRecord == null)
107+
{
108+
return (null, "RecordingGeneratorsFactory did not produce a record");
109+
}
110+
111+
return (ExtractSwitchBodyData(lastRecord, partialMethod.ReturnType), null);
112+
}
113+
catch (Exception ex)
114+
{
115+
return (null, $"Error executing generator method '{generatorMethod.Name}': {ex.GetBaseException()}");
116+
}
117+
finally
118+
{
119+
loadContext?.Unload();
120+
}
121+
}
122+
123+
internal static (string? value, string? error) ExecuteGeneratorMethodWithArgs(
124+
IMethodSymbol generatorMethod,
125+
IReadOnlyList<IMethodSymbol> allPartialMethods,
126+
Compilation compilation,
127+
object?[]? args)
128+
{
129+
CSharpCompilation executableCompilation = BuildExecutionCompilation(allPartialMethods, compilation);
130+
131+
using MemoryStream stream = new();
132+
EmitResult emitResult = executableCompilation.Emit(stream);
133+
if (!emitResult.Success)
134+
{
135+
string errors = string.Join("; ", emitResult.Diagnostics
136+
.Where(diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)
137+
.Select(diagnostic => diagnostic.GetMessage()));
138+
return (null, $"Compilation failed: {errors}");
139+
}
140+
141+
stream.Position = 0;
142+
AssemblyLoadContext? loadContext = null;
143+
try
144+
{
145+
loadContext = new AssemblyLoadContext("__GeneratorExec", isCollectible: true);
146+
loadContext.Resolving += (context, assemblyName) =>
147+
{
148+
PortableExecutableReference? match = compilation.References
149+
.OfType<PortableExecutableReference>()
150+
.FirstOrDefault(reference => reference.FilePath is not null && string.Equals(
151+
Path.GetFileNameWithoutExtension(reference.FilePath),
152+
assemblyName.Name,
153+
StringComparison.OrdinalIgnoreCase));
154+
return match?.FilePath != null ? context.LoadFromAssemblyPath(match.FilePath) : null;
155+
};
156+
157+
Assembly assembly = loadContext.LoadFromStream(stream);
158+
string typeName = generatorMethod.ContainingType.ToDisplayString();
159+
Type? loadedType = assembly.GetType(typeName);
160+
if (loadedType == null)
161+
{
162+
return (null, $"Could not find type '{typeName}' in compiled assembly");
163+
}
164+
165+
MethodInfo? generatorMethodInfo = loadedType.GetMethod(generatorMethod.Name, BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public);
166+
if (generatorMethodInfo == null)
167+
{
168+
return (null, $"Could not find method '{generatorMethod.Name}' in type '{typeName}'");
169+
}
170+
171+
object?[]? convertedArgs = ConvertArguments(args, generatorMethodInfo);
172+
object? result = generatorMethodInfo.Invoke(null, convertedArgs);
173+
return (result?.ToString(), null);
174+
}
175+
catch (Exception ex)
176+
{
177+
return (null, $"Error executing generator method '{generatorMethod.Name}': {ex.GetBaseException()}");
178+
}
179+
finally
180+
{
181+
loadContext?.Unload();
182+
}
183+
}
184+
185+
internal static IReadOnlyList<IMethodSymbol> GetAllUnimplementedPartialMethods(Compilation compilation)
186+
{
187+
List<IMethodSymbol> methods = new();
188+
foreach (SyntaxTree syntaxTree in compilation.SyntaxTrees)
189+
{
190+
SemanticModel semanticModel = compilation.GetSemanticModel(syntaxTree);
191+
IEnumerable<MethodDeclarationSyntax> partialMethodDeclarations = syntaxTree.GetRoot().DescendantNodes()
192+
.OfType<MethodDeclarationSyntax>()
193+
.Where(method => method.Modifiers.Any(modifier => modifier.IsKind(SyntaxKind.PartialKeyword)));
194+
195+
foreach (MethodDeclarationSyntax declaration in partialMethodDeclarations)
196+
{
197+
if (semanticModel.GetDeclaredSymbol(declaration) is IMethodSymbol symbol && symbol.IsPartialDefinition)
198+
{
199+
methods.Add(symbol);
200+
}
201+
}
202+
}
203+
204+
return methods;
205+
}
206+
207+
private static object?[]? ConvertArguments(object?[]? args, MethodInfo methodInfo)
208+
{
209+
if (args == null || methodInfo.GetParameters().Length == 0)
210+
{
211+
return null;
212+
}
213+
214+
Type parameterType = methodInfo.GetParameters()[0].ParameterType;
215+
return new[] { Convert.ChangeType(args[0], parameterType) };
216+
}
217+
218+
private static SwitchBodyData ExtractSwitchBodyData(object lastRecord, ITypeSymbol returnType)
219+
{
220+
Type recordType = lastRecord.GetType();
221+
PropertyInfo? caseKeysProperty = recordType.GetProperty("CaseKeys");
222+
PropertyInfo? caseValuesProperty = recordType.GetProperty("CaseValues");
223+
PropertyInfo? hasDefaultProperty = recordType.GetProperty("HasDefaultCase");
224+
225+
IList caseKeys = (caseKeysProperty?.GetValue(lastRecord) as IList) ?? new List<object>();
226+
IList caseValues = (caseValuesProperty?.GetValue(lastRecord) as IList) ?? new List<object?>();
227+
bool hasDefaultCase = (bool)(hasDefaultProperty?.GetValue(lastRecord) ?? false);
228+
229+
List<(object key, string value)> pairs = new();
230+
for (int index = 0; index < caseKeys.Count; index++)
231+
{
232+
object key = caseKeys[index]!;
233+
string? value = index < caseValues.Count ? caseValues[index]?.ToString() : null;
234+
pairs.Add((key, GeneratesMethodPatternSourceBuilder.FormatValueAsCSharpLiteral(value, returnType)));
235+
}
236+
237+
return new SwitchBodyData(pairs, hasDefaultCase);
238+
}
239+
240+
private static string ResolveImplementationAssemblyPath(string path)
241+
{
242+
string? directory = Path.GetDirectoryName(path);
243+
string? parentDirectory = directory != null ? Path.GetDirectoryName(directory) : null;
244+
if (directory != null &&
245+
parentDirectory != null &&
246+
string.Equals(Path.GetFileName(directory), "ref", StringComparison.OrdinalIgnoreCase))
247+
{
248+
return Path.Combine(parentDirectory, Path.GetFileName(path));
249+
}
250+
251+
return path;
252+
}
253+
254+
private static CSharpCompilation BuildExecutionCompilation(
255+
IReadOnlyList<IMethodSymbol> allPartialMethods,
256+
Compilation compilation)
257+
{
258+
string dummySource = BuildDummyImplementation(allPartialMethods);
259+
CSharpParseOptions parseOptions = compilation.SyntaxTrees.FirstOrDefault()?.Options as CSharpParseOptions
260+
?? CSharpParseOptions.Default;
261+
262+
return (CSharpCompilation)compilation
263+
.WithOptions(new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary))
264+
.AddSyntaxTrees(CSharpSyntaxTree.ParseText(dummySource, parseOptions));
265+
}
266+
267+
private static string BuildDummyImplementation(IEnumerable<IMethodSymbol> partialMethods)
268+
{
269+
StringBuilder builder = new();
270+
271+
IEnumerable<IGrouping<(string? Namespace, string TypeName, bool IsStatic, TypeKind TypeKind), IMethodSymbol>> groupedMethods = partialMethods.GroupBy(
272+
method => (Namespace: method.ContainingType.ContainingNamespace?.IsGlobalNamespace == false
273+
? method.ContainingType.ContainingNamespace.ToDisplayString()
274+
: null,
275+
TypeName: method.ContainingType.Name,
276+
IsStatic: method.ContainingType.IsStatic,
277+
TypeKind: method.ContainingType.TypeKind));
278+
279+
foreach (IGrouping<(string? Namespace, string TypeName, bool IsStatic, TypeKind TypeKind), IMethodSymbol> typeGroup in groupedMethods)
280+
{
281+
string? namespaceName = typeGroup.Key.Namespace;
282+
if (namespaceName != null)
283+
{
284+
builder.AppendLine($"namespace {namespaceName} {{");
285+
}
286+
287+
string typeKeyword = typeGroup.Key.TypeKind switch
288+
{
289+
TypeKind.Struct => "struct",
290+
_ => "class"
291+
};
292+
293+
string typeModifiers = typeGroup.Key.IsStatic ? "static partial" : "partial";
294+
builder.AppendLine($"{typeModifiers} {typeKeyword} {typeGroup.Key.TypeName} {{");
295+
296+
foreach (IMethodSymbol partialMethod in typeGroup)
297+
{
298+
string accessibility = partialMethod.DeclaredAccessibility switch
299+
{
300+
Accessibility.Public => "public",
301+
Accessibility.Protected => "protected",
302+
Accessibility.Internal => "internal",
303+
Accessibility.ProtectedOrInternal => "protected internal",
304+
Accessibility.ProtectedAndInternal => "private protected",
305+
_ => ""
306+
};
307+
308+
string staticModifier = partialMethod.IsStatic ? "static " : "";
309+
string returnType = partialMethod.ReturnType.ToDisplayString();
310+
string parameters = string.Join(", ", partialMethod.Parameters.Select(parameter => $"{parameter.Type.ToDisplayString()} {parameter.Name}"));
311+
312+
builder.AppendLine($"{accessibility} {staticModifier}partial {returnType} {partialMethod.Name}({parameters}) {{");
313+
if (!partialMethod.ReturnsVoid)
314+
{
315+
builder.AppendLine("return default!;");
316+
}
317+
318+
builder.AppendLine("}");
319+
}
320+
321+
builder.AppendLine("}");
322+
323+
if (namespaceName != null)
324+
{
325+
builder.AppendLine("}");
326+
}
327+
}
328+
329+
return builder.ToString();
330+
}
331+
}

0 commit comments

Comments
 (0)