diff --git a/.chronus/changes/mtg-hybrid-reference-map-2026-06-19-08-41-23.md b/.chronus/changes/mtg-hybrid-reference-map-2026-06-19-08-41-23.md new file mode 100644 index 00000000000..e45ad2b2de0 --- /dev/null +++ b/.chronus/changes/mtg-hybrid-reference-map-2026-06-19-08-41-23.md @@ -0,0 +1,7 @@ +--- +changeKind: internal +packages: + - "@typespec/http-client-csharp" +--- + +Improve C# generator post-processing reference-map parity and performance. diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs index 75cd23026da..0e695f9e7f8 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs @@ -424,6 +424,60 @@ private IReadOnlyList GetClientParameters() protected override string BuildName() => _inputClient.IsExactName ? _inputClient.Name : _inputClient.Name.ToIdentifierName(); + protected override IReadOnlyList BuildBodyDependencyTypes() + { + var dependencies = new List(); + foreach (var method in Methods.OfType()) + { + if (method.BodyStatements == null) + { + continue; + } + + if (method.CollectionDefinition != null) + { + dependencies.Add(method.CollectionDefinition.Type); + } + + if (method.ServiceMethod == null) + { + continue; + } + + AddInputTypeDependency(dependencies, method.ServiceMethod.Response.Type); + AddInputTypeDependency(dependencies, method.ServiceMethod.Exception?.Type); + foreach (var parameter in method.ServiceMethod.Parameters) + { + AddInputTypeDependency(dependencies, parameter.Type); + } + + foreach (var parameter in method.ServiceMethod.Operation.Parameters) + { + AddInputTypeDependency(dependencies, parameter.Type); + } + + foreach (var response in method.ServiceMethod.Operation.Responses) + { + AddInputTypeDependency(dependencies, response.BodyType); + foreach (var header in response.Headers) + { + AddInputTypeDependency(dependencies, header.Type); + } + } + } + + return dependencies; + } + + private static void AddInputTypeDependency(List dependencies, InputType? inputType) + { + var type = inputType == null ? null : ScmCodeModelGenerator.Instance.TypeFactory.CreateCSharpType(inputType); + if (type != null) + { + dependencies.Add(type); + } + } + protected override FieldProvider[] BuildFields() { List fields = [EndpointField]; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs index ae617957bf5..590eaf2b935 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs @@ -217,6 +217,22 @@ private bool HasPagingOperationNameCollision(string operationName) protected override TypeSignatureModifiers BuildDeclarationModifiers() => TypeSignatureModifiers.Internal | TypeSignatureModifiers.Partial | TypeSignatureModifiers.Class; + protected override IReadOnlyList BuildBodyDependencyTypes() + { + var dependencies = new List { Client.Type, ResponseModelType, NextPagePropertyType }; + if (ItemModelType != null) + { + dependencies.Add(ItemModelType); + } + + foreach (var field in RequestFields) + { + dependencies.Add(field.Type); + } + + return dependencies; + } + protected override FieldProvider[] BuildFields() => [ClientField, .. RequestFields]; protected override CSharpType[] BuildImplements() => diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs index 0d02ecba187..fbcb895c72d 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs @@ -117,6 +117,10 @@ public MrwSerializationTypeDefinition(InputModelType inputModel, ModelProvider m protected override CSharpType? BuildBaseType() => _model.BaseType; + protected override IReadOnlyList BuildHelperDependencyNames() => _rawDataField != null || _additionalProperties.Value.Length > 0 + ? ["ChangeTrackingDictionary"] + : []; + protected override SuppressionStatement[] BuildDisabledFileWarnings() { if (_model.CanonicalView.Properties.Any(p => ScmModelProvider.IsFileBinaryContentType(p.Type))) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs index ec53be226f3..8cc5fea2368 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs @@ -78,6 +78,33 @@ protected override FieldProvider[] BuildFields() return [.. pipelineMessage20xClassifiersFields]; } + protected override IReadOnlyList BuildHelperDependencyNames() + { + var dependencies = new HashSet(StringComparer.Ordinal); + foreach (var serviceMethod in _inputClient.Methods) + { + foreach (var parameter in serviceMethod.Operation.Parameters) + { + if (parameter is not InputHeaderParameter and not InputQueryParameter) + { + continue; + } + + var type = ScmCodeModelGenerator.Instance.TypeFactory.CreateCSharpType(parameter.Type); + if (type?.IsDictionary == true) + { + dependencies.Add("ChangeTrackingDictionary"); + } + else if (type?.IsCollection == true) + { + dependencies.Add("ChangeTrackingList"); + } + } + } + + return [.. dependencies]; + } + protected override ScmMethodProvider[] BuildMethods() { List methods = new List(); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/FullGenerationBenchmark.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/FullGenerationBenchmark.cs new file mode 100644 index 00000000000..35344ec5216 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/FullGenerationBenchmark.cs @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using BenchmarkDotNet.Attributes; + +namespace Microsoft.TypeSpec.Generator.Perf +{ + public class FullGenerationBenchmark + { + private const string ProfileEnvironmentVariable = "POSTPROCESSING_BENCHMARK_PROFILE_STEPS"; + private const string ProfileOutputDirectoryEnvironmentVariable = "POSTPROCESSING_BENCHMARK_PROFILE_DIR"; + private const string ShadowEnvironmentVariable = "TYPESPEC_PROVIDER_REFERENCE_MAP_SHADOW"; + private const string UseShadowEnvironmentVariable = "TYPESPEC_PROVIDER_REFERENCE_MAP_USE_SHADOW"; + private const string ShadowReportEnvironmentVariable = "TYPESPEC_PROVIDER_REFERENCE_MAP_SHADOW_REPORT"; + + private bool _profileSteps; + + [Params(false, true)] + public bool UseProviderReferenceMap { get; set; } + + [GlobalSetup] + public void GlobalSetup() + { + _profileSteps = string.Equals( + Environment.GetEnvironmentVariable(ProfileEnvironmentVariable), + "true", + StringComparison.OrdinalIgnoreCase); + } + + [Benchmark] + public async Task GenerateSampleTypeSpecProject() + { + var postProcessingProfile = _profileSteps ? new GeneratedCodeWorkspacePostProcessingProfile() : null; + var generationProfile = _profileSteps ? new GeneratedCodeWorkspacePostProcessingProfile() : null; + GeneratedCodeWorkspace.PostProcessingProfile = postProcessingProfile; + CSharpGen.GenerationProfile = generationProfile; + + var benchmarkDirectory = CreateBenchmarkInputDirectory(); + var previousShadow = Environment.GetEnvironmentVariable(ShadowEnvironmentVariable); + var previousUseShadow = Environment.GetEnvironmentVariable(UseShadowEnvironmentVariable); + var previousShadowReport = Environment.GetEnvironmentVariable(ShadowReportEnvironmentVariable); + var stopwatch = Stopwatch.StartNew(); + try + { + SetProviderReferenceMapEnvironment(); + CodeModelGenerator.Instance = new BenchmarkCodeModelGenerator(benchmarkDirectory); + CodeModelGenerator.Instance.Configure(); + + var csharpGen = new CSharpGen(); + await csharpGen.ExecuteAsync(); + + return Directory.GetFiles(benchmarkDirectory, "*", SearchOption.AllDirectories) + .Where(static path => !path.EndsWith("tspCodeModel.json", StringComparison.Ordinal) && + !path.EndsWith("Configuration.json", StringComparison.Ordinal)) + .Sum(static path => (int)new FileInfo(path).Length); + } + finally + { + stopwatch.Stop(); + if (generationProfile != null) + { + WriteProfile( + generationProfile, + $"full-generation-profile-{DateTime.UtcNow:yyyyMMddHHmmssfff}.csv", + $"Full generation invocation elapsed ms: {stopwatch.Elapsed.TotalMilliseconds:F3}{Environment.NewLine}" + + $"Input directory: {benchmarkDirectory}{Environment.NewLine}"); + } + + if (postProcessingProfile != null) + { + WriteProfile( + postProcessingProfile, + $"full-generation-post-processing-profile-{DateTime.UtcNow:yyyyMMddHHmmssfff}.csv", + $"Full generation post-processing profile{Environment.NewLine}" + + $"Input directory: {benchmarkDirectory}{Environment.NewLine}"); + } + + CSharpGen.GenerationProfile = null; + GeneratedCodeWorkspace.PostProcessingProfile = null; + Environment.SetEnvironmentVariable(ShadowEnvironmentVariable, previousShadow); + Environment.SetEnvironmentVariable(UseShadowEnvironmentVariable, previousUseShadow); + Environment.SetEnvironmentVariable(ShadowReportEnvironmentVariable, previousShadowReport); + TryDeleteDirectory(benchmarkDirectory); + } + } + + private void SetProviderReferenceMapEnvironment() + { + Environment.SetEnvironmentVariable(ShadowEnvironmentVariable, UseProviderReferenceMap ? "true" : null); + Environment.SetEnvironmentVariable(UseShadowEnvironmentVariable, UseProviderReferenceMap ? "true" : null); + Environment.SetEnvironmentVariable(ShadowReportEnvironmentVariable, null); + } + + private static void WriteProfile(GeneratedCodeWorkspacePostProcessingProfile profile, string fileName, string header) + { + var profileDirectory = GetProfileOutputDirectory(); + Directory.CreateDirectory(profileDirectory); + File.WriteAllText(Path.Combine(profileDirectory, fileName), header + profile.GetSummary()); + } + + private static string CreateBenchmarkInputDirectory() + { + var sourceDirectory = FindFullGenerationInputDirectory(); + var benchmarkDirectory = Path.Combine(Path.GetTempPath(), "typespec-full-generation-benchmark", Guid.NewGuid().ToString("N")); + CopyDirectory(sourceDirectory, benchmarkDirectory); + return benchmarkDirectory; + } + + private static string FindFullGenerationInputDirectory() + { + const string relativePath = "packages/http-client-csharp/generator/TestProjects/Local/Sample-TypeSpec"; + + var directory = new DirectoryInfo(AppContext.BaseDirectory); + while (directory != null) + { + var inputDirectory = Path.Combine(directory.FullName, relativePath); + if (File.Exists(Path.Combine(inputDirectory, "tspCodeModel.json")) && + File.Exists(Path.Combine(inputDirectory, "Configuration.json"))) + { + return inputDirectory; + } + + directory = directory.Parent; + } + + throw new DirectoryNotFoundException($"Could not find '{relativePath}' from '{AppContext.BaseDirectory}'."); + } + + private static void CopyDirectory(string sourceDirectory, string destinationDirectory) + { + Directory.CreateDirectory(destinationDirectory); + foreach (var sourceFile in Directory.GetFiles(sourceDirectory, "*", SearchOption.AllDirectories)) + { + var relativePath = Path.GetRelativePath(sourceDirectory, sourceFile); + var destinationFile = Path.Combine(destinationDirectory, relativePath); + Directory.CreateDirectory(Path.GetDirectoryName(destinationFile)!); + File.Copy(sourceFile, destinationFile, overwrite: true); + } + } + + private static void TryDeleteDirectory(string directory) + { + try + { + if (Directory.Exists(directory)) + { + Directory.Delete(directory, recursive: true); + } + } + catch + { + // Best-effort cleanup for benchmark temp output. + } + } + + private static string GetProfileOutputDirectory() + { + var configuredPath = Environment.GetEnvironmentVariable(ProfileOutputDirectoryEnvironmentVariable); + return string.IsNullOrWhiteSpace(configuredPath) + ? Path.Combine(Path.GetTempPath(), "typespec-post-processing-profiles") + : Path.GetFullPath(configuredPath); + } + + private sealed class BenchmarkCodeModelGenerator : CodeModelGenerator + { + public BenchmarkCodeModelGenerator(string outputPath) + : base(new GeneratorContext(Configuration.Load(outputPath))) + { + } + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/PostProcessingBenchmark.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/PostProcessingBenchmark.cs new file mode 100644 index 00000000000..24b5394e2cd --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/PostProcessingBenchmark.cs @@ -0,0 +1,248 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Text.RegularExpressions; +using System.Threading.Tasks; +using BenchmarkDotNet.Attributes; +using Microsoft.CodeAnalysis; +using Microsoft.TypeSpec.Generator.Primitives; + +namespace Microsoft.TypeSpec.Generator.Perf +{ + public class PostProcessingBenchmark + { + private const string GeneratedDirectoryEnvironmentVariable = "POSTPROCESSING_BENCHMARK_GENERATED_DIR"; + private const string ProfileEnvironmentVariable = "POSTPROCESSING_BENCHMARK_PROFILE_STEPS"; + private const string ProfileOutputDirectoryEnvironmentVariable = "POSTPROCESSING_BENCHMARK_PROFILE_DIR"; + private static readonly Regex NamespaceDeclarationRegex = new( + @"\bnamespace\s+([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)*)", + RegexOptions.Compiled); + + [Params(1, 5)] + public int CorpusMultiplier { get; set; } + + private (string Name, string Content)[] _generatedFiles = []; + private bool _profileSteps; + + [GlobalSetup] + public void GlobalSetup() + { + InitializeGenerator(); + + var generatedDirectory = FindGeneratedDirectory(); + var sourceFiles = Directory.GetFiles(generatedDirectory, "*.cs", SearchOption.AllDirectories) + .OrderBy(static path => path, StringComparer.Ordinal) + .ToArray(); + + if (sourceFiles.Length == 0) + { + throw new InvalidOperationException($"No generated C# files found under '{generatedDirectory}'."); + } + + var declaredNamespaces = GetDeclaredNamespaces(sourceFiles); + _generatedFiles = BuildCorpus(generatedDirectory, sourceFiles, declaredNamespaces); + _profileSteps = string.Equals( + Environment.GetEnvironmentVariable(ProfileEnvironmentVariable), + "true", + StringComparison.OrdinalIgnoreCase); + } + + [Benchmark] + public async Task ProcessSampleTypeSpecGeneratedFiles() + { + var profile = _profileSteps ? new GeneratedCodeWorkspacePostProcessingProfile() : null; + GeneratedCodeWorkspace.PostProcessingProfile = profile; + + var stopwatch = Stopwatch.StartNew(); + GeneratedCodeWorkspace.Initialize(); + var workspace = await GeneratedCodeWorkspace.Create(isCustomCodeProject: false); + + try + { + foreach (var file in _generatedFiles) + { + await workspace.AddGeneratedFile(new CodeFile(file.Content, file.Name)); + } + + var totalLength = 0; + await foreach (var file in workspace.GetGeneratedFilesAsync()) + { + totalLength += file.Text.Length; + } + + return totalLength; + } + finally + { + stopwatch.Stop(); + if (profile != null) + { + WriteProfile( + profile, + $"post-processing-profile-RoslynSimplifier-x{CorpusMultiplier}-{DateTime.UtcNow:yyyyMMddHHmmssfff}.csv", + $"Reduction strategy: RoslynSimplifier{Environment.NewLine}" + + $"Corpus multiplier: {CorpusMultiplier}{Environment.NewLine}" + + $"File count: {_generatedFiles.Length}{Environment.NewLine}" + + $"Benchmark invocation elapsed ms: {stopwatch.Elapsed.TotalMilliseconds:F3}{Environment.NewLine}"); + } + + GeneratedCodeWorkspace.PostProcessingProfile = null; + } + } + + private static void WriteProfile(GeneratedCodeWorkspacePostProcessingProfile profile, string fileName, string header) + { + var profileDirectory = GetProfileOutputDirectory(); + Directory.CreateDirectory(profileDirectory); + File.WriteAllText(Path.Combine(profileDirectory, fileName), header + profile.GetSummary()); + } + + private static string GetProfileOutputDirectory() + { + var configuredPath = Environment.GetEnvironmentVariable(ProfileOutputDirectoryEnvironmentVariable); + return string.IsNullOrWhiteSpace(configuredPath) + ? Path.Combine(Path.GetTempPath(), "typespec-post-processing-profiles") + : Path.GetFullPath(configuredPath); + } + + private (string Name, string Content)[] BuildCorpus(string generatedDirectory, string[] sourceFiles, IReadOnlyList declaredNamespaces) + { + var generatedFiles = new List<(string Name, string Content)>(sourceFiles.Length * CorpusMultiplier); + for (var i = 0; i < CorpusMultiplier; i++) + { + var namespaceSuffix = CorpusMultiplier == 1 ? string.Empty : $".BenchmarkCopy{i}"; + var folderPrefix = CorpusMultiplier == 1 ? string.Empty : $"BenchmarkCopy{i}"; + foreach (var path in sourceFiles) + { + var relativePath = Path.GetRelativePath(generatedDirectory, path); + var content = File.ReadAllText(path); + if (CorpusMultiplier > 1) + { + content = MakeNamespacesUnique(content, declaredNamespaces, namespaceSuffix); + } + + generatedFiles.Add((Path.Combine(folderPrefix, relativePath), content)); + } + } + + return generatedFiles.ToArray(); + } + + private static IReadOnlyList GetDeclaredNamespaces(string[] sourceFiles) + { + var declaredNamespaces = sourceFiles + .SelectMany(static path => NamespaceDeclarationRegex.Matches(File.ReadAllText(path))) + .Select(static match => match.Groups[1].Value) + .Distinct(StringComparer.Ordinal) + .ToArray(); + + return declaredNamespaces + .Where(ns => !declaredNamespaces.Any(candidate => + !string.Equals(ns, candidate, StringComparison.Ordinal) && + ns.StartsWith(candidate + ".", StringComparison.Ordinal))) + .OrderByDescending(static ns => ns.Length) + .ToArray(); + } + + private static string MakeNamespacesUnique(string content, IReadOnlyList declaredNamespaces, string namespaceSuffix) + { + foreach (var declaredNamespace in declaredNamespaces) + { + var escapedNamespace = Regex.Escape(declaredNamespace); + content = content.Replace($"global::{declaredNamespace}.", $"global::{declaredNamespace}{namespaceSuffix}.", StringComparison.Ordinal); + content = Regex.Replace( + content, + $@"(? GetMetadataReferencePaths() + { + HashSet referencePaths = new(StringComparer.OrdinalIgnoreCase); + + if (AppContext.GetData("TRUSTED_PLATFORM_ASSEMBLIES") is string trustedPlatformAssemblies) + { + foreach (var referencePath in trustedPlatformAssemblies.Split(Path.PathSeparator)) + { + if (referencePaths.Add(referencePath)) + { + yield return referencePath; + } + } + } + + foreach (var referencePath in Directory.GetFiles(AppContext.BaseDirectory, "*.dll", SearchOption.TopDirectoryOnly)) + { + if (referencePaths.Add(referencePath)) + { + yield return referencePath; + } + } + } + + private sealed class BenchmarkCodeModelGenerator : CodeModelGenerator + { + public BenchmarkCodeModelGenerator(string outputPath) + : base(new GeneratorContext(Configuration.Load(outputPath, "{\"package-name\":\"Sample.TypeSpec\",\"disable-xml-docs\":false}"))) + { + } + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs index 9948fcff594..ea3c2e32ac7 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.IO; using System.Linq; using System.Threading.Tasks; @@ -17,8 +18,20 @@ internal sealed class CSharpGen { private const string ConfigurationFileName = "Configuration.json"; private const string CodeModelFileName = "tspCodeModel.json"; + private const string RawRequestUriBuilderExtensionsFileName = "RawRequestUriBuilderExtensions.cs"; + private const string SerializationFormatFileName = "SerializationFormat.cs"; + private const string TypeFormattersFileName = "TypeFormatters.cs"; - private static readonly string[] _filesToKeep = [ConfigurationFileName, CodeModelFileName]; + private static readonly string[] _filesToKeep = + [ + ConfigurationFileName, + CodeModelFileName, + RawRequestUriBuilderExtensionsFileName, + SerializationFormatFileName, + TypeFormattersFileName + ]; + + internal static GeneratedCodeWorkspacePostProcessingProfile? GenerationProfile { get; set; } /// /// Executes the generator task with the instance. @@ -33,19 +46,21 @@ public async Task ExecuteAsync() // Resolve PackageReference items from the .csproj so custom code referencing // external NuGet types (e.g., Azure.Storage.Common) compiles correctly. - await GeneratedCodeWorkspace.AddPackageReferencesFromProject(); + await MeasureGenerationStepAsync("Generation.AddPackageReferencesFromProject", GeneratedCodeWorkspace.AddPackageReferencesFromProject); // Pre-walk the input library and resolve any external types that point at NuGet packages. // This populates ExternalTypeReferenceResolver's cache and registers each resolved assembly // as an additional metadata reference *before* the generated/custom code workspaces are // constructed, so their cached Roslyn projects pick the references up. - await ExternalTypeReferenceResolver.ResolveAllAsync(); + await MeasureGenerationStepAsync("Generation.ResolveExternalTypeReferences", ExternalTypeReferenceResolver.ResolveAllAsync); // Initialize the workspace project AFTER all metadata references have been added so the // eagerly-cached project sees them. GeneratedCodeWorkspace.Initialize(); - GeneratedCodeWorkspace customCodeWorkspace = await GeneratedCodeWorkspace.Create(isCustomCodeProject: true); + GeneratedCodeWorkspace customCodeWorkspace = await MeasureGenerationStepAsync( + "Generation.CreateCustomCodeWorkspace", + () => GeneratedCodeWorkspace.Create(isCustomCodeProject: true)); // The generated attributes need to be added into the workspace before loading the custom code. Otherwise, // Roslyn doesn't load the attributes completely and we are unable to get the attribute arguments. @@ -55,88 +70,169 @@ public async Task ExecuteAsync() generateAttributeTasks.Add(customCodeWorkspace.AddInMemoryFile(attributeProvider)); } - await Task.WhenAll(generateAttributeTasks); + await MeasureGenerationStepAsync("Generation.AddCustomizationAttributeProviders", () => Task.WhenAll(generateAttributeTasks)); - CodeModelGenerator.Instance.SourceInputModel = new SourceInputModel( - await customCodeWorkspace.GetCompilationAsync(), - await GeneratedCodeWorkspace.LoadBaselineContract()); + CodeModelGenerator.Instance.SourceInputModel = await MeasureGenerationStepAsync( + "Generation.CreateSourceInputModel", + async () => new SourceInputModel( + await customCodeWorkspace.GetCompilationAsync(), + await GeneratedCodeWorkspace.LoadBaselineContract())); - GeneratedCodeWorkspace generatedCodeWorkspace = await GeneratedCodeWorkspace.Create(isCustomCodeProject: false); + GeneratedCodeWorkspace generatedCodeWorkspace = await MeasureGenerationStepAsync( + "Generation.CreateGeneratedCodeWorkspace", + () => GeneratedCodeWorkspace.Create(isCustomCodeProject: false)); - var output = CodeModelGenerator.Instance.OutputLibrary; + var output = MeasureGenerationStep("Generation.GetOutputLibrary", () => CodeModelGenerator.Instance.OutputLibrary); Directory.CreateDirectory(Path.Combine(generatedSourceOutputPath, "Models")); List generateFilesTasks = new(); // Build all TypeProviders - foreach (var type in output.TypeProviders) + MeasureGenerationStep("Generation.BuildTypeProviders", () => { - type.EnsureBuilt(); - } + foreach (var type in output.TypeProviders) + { + type.EnsureBuilt(); + } + }); LoggingHelpers.LogElapsedTime("All generated type providers built"); // visit the entire library before generating files - foreach (var visitor in CodeModelGenerator.Instance.Visitors) + MeasureGenerationStep("Generation.ApplyVisitors", () => { - visitor.VisitLibrary(output); - } + foreach (var visitor in CodeModelGenerator.Instance.Visitors) + { + visitor.VisitLibrary(output); + } + }); - FilterAllCustomizedMembers(output); + MeasureGenerationStep("Generation.FilterCustomizedMembers", () => FilterAllCustomizedMembers(output)); LoggingHelpers.LogElapsedTime("All visitors have been applied"); - foreach (var outputType in output.TypeProviders) + MeasureGenerationStep("Generation.WriteTypeProviders", () => { - // Ensure back-compatibility processing is done after all visitors have run - outputType.ProcessTypeForBackCompatibility(); - - var writer = CodeModelGenerator.Instance.GetWriter(outputType); - generateFilesTasks.Add(generatedCodeWorkspace.AddGeneratedFile(writer.Write())); - - foreach (var serialization in outputType.SerializationProviders) + foreach (var outputType in output.TypeProviders) { - writer = CodeModelGenerator.Instance.GetWriter(serialization); + // Ensure back-compatibility processing is done after all visitors have run + outputType.ProcessTypeForBackCompatibility(); + + var writer = CodeModelGenerator.Instance.GetWriter(outputType); generateFilesTasks.Add(generatedCodeWorkspace.AddGeneratedFile(writer.Write())); + + foreach (var serialization in outputType.SerializationProviders) + { + writer = CodeModelGenerator.Instance.GetWriter(serialization); + generateFilesTasks.Add(generatedCodeWorkspace.AddGeneratedFile(writer.Write())); + } } - } + }); // Add all the generated files to the workspace - await Task.WhenAll(generateFilesTasks); + await MeasureGenerationStepAsync("Generation.AddGeneratedFilesToWorkspace", () => Task.WhenAll(generateFilesTasks)); + + MeasureGenerationStep("Generation.ProviderReferenceMapShadowAnalysis", () => generatedCodeWorkspace.AnalyzeProviderReferenceMap(output.TypeProviders)); LoggingHelpers.LogElapsedTime("All generated types have been written into memory"); // Delete any old generated files - DeleteDirectory(generatedSourceOutputPath, _filesToKeep); + MeasureGenerationStep("Generation.DeleteOldGeneratedFiles", () => DeleteDirectory(generatedSourceOutputPath, _filesToKeep)); LoggingHelpers.LogElapsedTime("All old generated files have been deleted"); - await generatedCodeWorkspace.PostProcessAsync(); + await MeasureGenerationStepAsync("Generation.PostProcessAsync", generatedCodeWorkspace.PostProcessAsync); // Write the generated files to the output directory - await foreach (var file in generatedCodeWorkspace.GetGeneratedFilesAsync()) + await MeasureGenerationStepAsync("Generation.WriteGeneratedFilesToDisk", async () => { - if (string.IsNullOrEmpty(file.Text)) + await foreach (var file in generatedCodeWorkspace.GetGeneratedFilesAsync()) { - continue; + if (string.IsNullOrEmpty(file.Text)) + { + continue; + } + var filename = Path.Combine(outputPath, file.Name); + CodeModelGenerator.Instance.Emitter.Info($"Writing {Path.GetFullPath(filename)}"); + Directory.CreateDirectory(Path.GetDirectoryName(filename)!); + await File.WriteAllTextAsync(filename, file.Text); } - var filename = Path.Combine(outputPath, file.Name); - CodeModelGenerator.Instance.Emitter.Info($"Writing {Path.GetFullPath(filename)}"); - Directory.CreateDirectory(Path.GetDirectoryName(filename)!); - await File.WriteAllTextAsync(filename, file.Text); - } + }); // Write additional output files (e.g. configuration schemas, .targets files) - await CodeModelGenerator.Instance.WriteAdditionalFiles(outputPath); + await MeasureGenerationStepAsync("Generation.WriteAdditionalFiles", () => CodeModelGenerator.Instance.WriteAdditionalFiles(outputPath)); // Write project scaffolding files (after additional files so schema existence can be checked) if (CodeModelGenerator.Instance.IsNewProject) { - await CodeModelGenerator.Instance.TypeFactory.CreateNewProjectScaffolding().Execute(); + await MeasureGenerationStepAsync( + "Generation.WriteProjectScaffolding", + () => CodeModelGenerator.Instance.TypeFactory.CreateNewProjectScaffolding().Execute()); } LoggingHelpers.LogElapsedTime("All files have been written to disk"); } + private static void MeasureGenerationStep(string stepName, Action action) + { + MeasureGenerationStep( + stepName, + () => + { + action(); + return 0; + }); + } + + private static T MeasureGenerationStep(string stepName, Func action) + { + var profile = GenerationProfile; + if (profile == null) + { + return action(); + } + + var allocatedBytes = GC.GetTotalAllocatedBytes(precise: false); + var stopwatch = Stopwatch.StartNew(); + try + { + return action(); + } + finally + { + stopwatch.Stop(); + profile.Add(stepName, stopwatch.Elapsed, GC.GetTotalAllocatedBytes(precise: false) - allocatedBytes); + } + } + + private static Task MeasureGenerationStepAsync(string stepName, Func action) => MeasureGenerationStepAsync( + stepName, + async () => + { + await action(); + return 0; + }); + + private static async Task MeasureGenerationStepAsync(string stepName, Func> action) + { + var profile = GenerationProfile; + if (profile == null) + { + return await action(); + } + + var allocatedBytes = GC.GetTotalAllocatedBytes(precise: false); + var stopwatch = Stopwatch.StartNew(); + try + { + return await action(); + } + finally + { + stopwatch.Stop(); + profile.Add(stepName, stopwatch.Elapsed, GC.GetTotalAllocatedBytes(precise: false) - allocatedBytes); + } + } + internal static void FilterAllCustomizedMembers(OutputLibrary output) { foreach (var typeProvider in output.TypeProviders) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs index 4588b3c4839..f4d235e295c 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs @@ -12,6 +12,7 @@ using Microsoft.CodeAnalysis; using MSBuildProjectCollection = Microsoft.Build.Evaluation.ProjectCollection; using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Formatting; using Microsoft.CodeAnalysis.Simplification; using Microsoft.TypeSpec.Generator.Primitives; @@ -37,6 +38,8 @@ internal class GeneratedCodeWorkspace private static readonly Lazy _metadataReferenceResolver = new(() => new WorkspaceMetadataReferenceResolver()); private static Task? _cachedProject; + internal static GeneratedCodeWorkspacePostProcessingProfile? PostProcessingProfile { get; set; } + private static readonly string[] _generatedFolders = [GeneratedFolder]; private static readonly string[] _sharedFolders = [SharedFolder]; @@ -83,6 +86,11 @@ public async Task AddInMemoryFile(TypeProvider type) await UpdateProject(document); } + internal void AnalyzeProviderReferenceMap(IReadOnlyList providers) + { + ProviderReferenceMapShadowAnalyzer.Analyze(providers, _project); + } + private async Task UpdateProject(Document document) { var root = await document.GetSyntaxRootAsync(); @@ -102,7 +110,7 @@ internal static SyntaxTree GetTree(TypeProvider provider) public async IAsyncEnumerable<(string Name, string Text)> GetGeneratedFilesAsync() { - List> documents = new List>(); + List docs = new List(); var memberRemover = new MemberRemoverRewriter(); foreach (Document document in _project.Documents) { @@ -111,9 +119,12 @@ internal static SyntaxTree GetTree(TypeProvider provider) continue; } - documents.Add(ProcessDocument(document, memberRemover)); + docs.Add(document); } - var docs = await Task.WhenAll(documents); + + docs = PostProcessingProfile == null + ? [.. await Task.WhenAll(docs.Select(document => ProcessDocument(document, memberRemover)))] + : await ProcessDocumentsSequentiallyAsync(docs, memberRemover); LoggingHelpers.LogElapsedTime("Roslyn post processing complete"); @@ -129,36 +140,101 @@ internal static SyntaxTree GetTree(TypeProvider provider) } } - private async Task ProcessDocument(Document document, MemberRemoverRewriter memberRemover) + private async Task> ProcessDocumentsSequentiallyAsync(List documents, MemberRemoverRewriter memberRemover) { - var root = await document.GetSyntaxRootAsync(); - var semanticModel = await document.GetSemanticModelAsync(); + List processedDocuments = new(documents.Count); + foreach (var document in documents) + { + processedDocuments.Add(await ProcessDocument(document, memberRemover)); + } + + return processedDocuments; + } - if (semanticModel == null || root == null) + private async Task ProcessDocument(Document document, MemberRemoverRewriter memberRemover) + { + var totalStopwatch = PostProcessingProfile == null ? null : Stopwatch.StartNew(); + try { + var root = await MeasurePostProcessingStepAsync("GetSyntaxRootAsync", () => document.GetSyntaxRootAsync()); + var semanticModel = await MeasurePostProcessingStepAsync("GetSemanticModelAsync", () => document.GetSemanticModelAsync()); + + if (semanticModel == null || root == null) + { + return document; + } + + root = MeasurePostProcessingStep("MemberRemoverRewriter", () => memberRemover.Visit(root)); + + foreach (var rewriter in CodeModelGenerator.Instance.Rewriters) + { + rewriter.SemanticModel = semanticModel; + root = MeasurePostProcessingStep($"CustomRewriter.{rewriter.GetType().Name}", () => rewriter.Visit(root)); + } + document = document.WithSyntaxRoot(root); + + if (!CodeModelGenerator.Instance.Configuration.DisableRoslynReduce) + { + document = await MeasurePostProcessingStepAsync("Roslyn.Simplifier.ReduceAsync", () => Simplifier.ReduceAsync(document)); + } + + // Reformat if any custom rewriters have been applied + if (CodeModelGenerator.Instance.Rewriters.Count > 0) + { + document = await MeasurePostProcessingStepAsync("Formatter.FormatAsync", () => Formatter.FormatAsync(document)); + } return document; } + finally + { + if (totalStopwatch != null) + { + totalStopwatch.Stop(); + PostProcessingProfile?.Add("ProcessDocument.Total", totalStopwatch.Elapsed, 0); + } + } + } - root = memberRemover.Visit(root); + private static T MeasurePostProcessingStep(string stepName, Func action) + { + var profile = PostProcessingProfile; + if (profile == null) + { + return action(); + } - foreach (var rewriter in CodeModelGenerator.Instance.Rewriters) + var allocatedBytes = GC.GetTotalAllocatedBytes(precise: false); + var stopwatch = Stopwatch.StartNew(); + try { - rewriter.SemanticModel = semanticModel; - root = rewriter.Visit(root); + return action(); } - document = document.WithSyntaxRoot(root); + finally + { + stopwatch.Stop(); + profile.Add(stepName, stopwatch.Elapsed, GC.GetTotalAllocatedBytes(precise: false) - allocatedBytes); + } + } - if (!CodeModelGenerator.Instance.Configuration.DisableRoslynReduce) + private static async Task MeasurePostProcessingStepAsync(string stepName, Func> action) + { + var profile = PostProcessingProfile; + if (profile == null) { - document = await Simplifier.ReduceAsync(document); + return await action(); } - // Reformat if any custom rewriters have been applied - if (CodeModelGenerator.Instance.Rewriters.Count > 0) + var allocatedBytes = GC.GetTotalAllocatedBytes(precise: false); + var stopwatch = Stopwatch.StartNew(); + try + { + return await action(); + } + finally { - document = await Formatter.FormatAsync(document); + stopwatch.Stop(); + profile.Add(stepName, stopwatch.Elapsed, GC.GetTotalAllocatedBytes(precise: false) - allocatedBytes); } - return document; } public static bool IsGeneratedDocument(Document document) => document.Folders.Contains(GeneratedFolder); @@ -275,11 +351,11 @@ public async Task PostProcessAsync() case Configuration.UnreferencedTypesHandlingOption.KeepAll: break; case Configuration.UnreferencedTypesHandlingOption.Internalize: - _project = await postProcessor.InternalizeAsync(_project); + _project = await MeasurePostProcessingStepAsync("PostProcess.InternalizeAsync", () => postProcessor.InternalizeAsync(_project)); break; case Configuration.UnreferencedTypesHandlingOption.RemoveOrInternalize: - _project = await postProcessor.InternalizeAsync(_project); - _project = await postProcessor.RemoveAsync(_project); + _project = await MeasurePostProcessingStepAsync("PostProcess.InternalizeAsync", () => postProcessor.InternalizeAsync(_project)); + _project = await MeasurePostProcessingStepAsync("PostProcess.RemoveAsync", () => postProcessor.RemoveAsync(_project)); break; } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspacePostProcessingProfile.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspacePostProcessingProfile.cs new file mode 100644 index 00000000000..74a9a58168a --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspacePostProcessingProfile.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; + +namespace Microsoft.TypeSpec.Generator +{ + internal sealed class GeneratedCodeWorkspacePostProcessingProfile + { + private readonly object _syncRoot = new(); + private readonly Dictionary _steps = new(StringComparer.Ordinal); + + public void Add(string stepName, TimeSpan elapsed, long allocatedBytes) + { + lock (_syncRoot) + { + ref var summary = ref CollectionsMarshal.GetValueRefOrAddDefault(_steps, stepName, out _); + summary.Count++; + summary.ElapsedTicks += elapsed.Ticks; + summary.AllocatedBytes += allocatedBytes; + } + } + + public string GetSummary() + { + KeyValuePair[] steps; + lock (_syncRoot) + { + steps = _steps.ToArray(); + } + + var totalTicks = steps + .Where(static step => step.Key != "ProcessDocument.Total") + .Sum(static step => step.Value.ElapsedTicks); + var builder = new StringBuilder(); + builder.AppendLine("Post-processing step profile:"); + builder.AppendLine("Step, Count, Total ms, Avg ms, Percent of measured steps, Allocated bytes, Avg allocated bytes"); + + foreach (var step in steps.OrderByDescending(static step => step.Value.ElapsedTicks)) + { + var elapsedMs = TimeSpan.FromTicks(step.Value.ElapsedTicks).TotalMilliseconds; + var averageMs = elapsedMs / step.Value.Count; + var averageAllocatedBytes = step.Value.AllocatedBytes / step.Value.Count; + var percentage = totalTicks == 0 || step.Key == "ProcessDocument.Total" + ? 0 + : step.Value.ElapsedTicks * 100.0 / totalTicks; + builder.AppendLine($"{step.Key}, {step.Value.Count}, {elapsedMs:F3}, {averageMs:F3}, {percentage:F1}%, {step.Value.AllocatedBytes}, {averageAllocatedBytes}"); + } + + return builder.ToString(); + } + + private struct StepSummary + { + public int Count; + public long ElapsedTicks; + public long AllocatedBytes; + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs index be96e11df59..f3da34120e2 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs @@ -9,6 +9,7 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.FindSymbols; using Microsoft.CodeAnalysis.Simplification; namespace Microsoft.TypeSpec.Generator @@ -20,6 +21,8 @@ internal class PostProcessor private readonly HashSet _typesToKeep; private INamedTypeSymbol? _modelFactorySymbol; + private static GeneratedCodeWorkspacePostProcessingProfile? Profile => GeneratedCodeWorkspace.PostProcessingProfile; + public PostProcessor( HashSet typesToKeep, string? modelFactoryFullName = null, @@ -125,44 +128,136 @@ protected virtual bool ShouldIncludeDocument(Document document) => /// The processed . is immutable, therefore this should usually be a new instance public async Task InternalizeAsync(Project project) { - var compilation = await project.GetCompilationAsync(); + var compilation = await MeasureAsync("PostProcessor.Internalize.GetCompilationAsync", () => project.GetCompilationAsync()); if (compilation == null) return project; + var useShadowResult = ProviderReferenceMapShadowAnalyzer.UseShadowMap && + ProviderReferenceMapShadowAnalyzer.LatestResult is { } latestResult && + latestResult.ProjectId == project.Id + ? latestResult + : null; + // first get all the declared symbols - var definitions = await GetTypeSymbolsAsync(compilation, project, true); - // build the reference map - var referenceMap = - await new ReferenceMapBuilder(compilation, project).BuildPublicReferenceMapAsync( - definitions.DeclaredSymbols, definitions.DeclaredNodesCache); - // get the root symbols - var rootSymbols = await GetRootSymbolsAsync(project, definitions); - // traverse all the root and recursively add all the things we met - var publicSymbols = VisitSymbolsFromRootAsync(rootSymbols, referenceMap); - - var symbolsToInternalize = definitions.DeclaredSymbols.Except(publicSymbols); - - var nodesToInternalize = new Dictionary(); - foreach (var symbol in symbolsToInternalize) - { - foreach (var node in definitions.DeclaredNodesCache[symbol]) + var definitions = await MeasureAsync("PostProcessor.Internalize.GetTypeSymbolsAsync", () => GetTypeSymbolsAsync(compilation, project, publicOnly: useShadowResult == null)); + IEnumerable symbolsToInternalize; + IEnumerable symbolsToPublicize = []; + if (useShadowResult != null) + { + symbolsToInternalize = Measure("PostProcessor.Internalize.UseShadowCandidates", () => + GetSymbolsByName(definitions.DeclaredSymbols, useShadowResult.InternalizeCandidates).ToArray()); + symbolsToPublicize = Measure("PostProcessor.Internalize.UseShadowPublicizeCandidates", () => + GetSymbolsByName(definitions.DeclaredSymbols, useShadowResult.PublicizeCandidates).ToArray()); + } + else + { + // build the reference map + var referenceMap = + await MeasureAsync( + "PostProcessor.Internalize.BuildPublicReferenceMapAsync", + () => new ReferenceMapBuilder(compilation, project).BuildPublicReferenceMapAsync( + definitions.DeclaredSymbols, definitions.DeclaredNodesCache)); + // get the root symbols + var rootSymbols = await MeasureAsync("PostProcessor.Internalize.GetRootSymbolsAsync", () => GetRootSymbolsAsync(project, definitions)); + // traverse all the root and recursively add all the things we met + var publicSymbols = Measure("PostProcessor.Internalize.VisitSymbolsFromRoot", () => VisitSymbolsFromRootAsync(rootSymbols, referenceMap).ToArray()); + + symbolsToInternalize = definitions.DeclaredSymbols.Except(publicSymbols); + } + + if (ProviderReferenceMapShadowAnalyzer.LatestResult is { } shadowResult && shadowResult.ProjectId == project.Id) + { + ProviderReferenceMapShadowAnalyzer.WriteComparisonReport( + "internalize", + symbolsToInternalize.Select(static symbol => symbol.GetFullyQualifiedName()), + shadowResult.InternalizeCandidates); + } + + var nodesToInternalize = Measure("PostProcessor.Internalize.CollectNodes", () => + { + var nodes = new Dictionary(); + foreach (var symbol in symbolsToInternalize) + { + foreach (var node in definitions.DeclaredNodesCache[symbol]) + { + nodes[node] = project.GetDocumentId(node.SyntaxTree)!; + } + } + + return nodes; + }); + + var nodesToPublicize = Measure("PostProcessor.Internalize.CollectPublicizeNodes", () => + { + var nodes = new Dictionary(); + foreach (var symbol in symbolsToPublicize) { - nodesToInternalize[node] = project.GetDocumentId(node.SyntaxTree)!; + foreach (var node in definitions.DeclaredNodesCache[symbol]) + { + nodes[node] = project.GetDocumentId(node.SyntaxTree)!; + } } - } - foreach (var (model, documentId) in nodesToInternalize) + return nodes; + }); + + project = Measure("PostProcessor.Internalize.ApplyAccessibilityChanges", () => + ApplyAccessibilityChanges(project, nodesToInternalize, nodesToPublicize)); + project = await MeasureAsync( + "PostProcessor.Internalize.InternalizePublicNestedTypesInInternalTypesAsync", + () => InternalizePublicNestedTypesInInternalTypesAsync(project)); + + var modelNamesToRemove = nodesToInternalize.Keys.Select(item => item.Identifier.Text); + if (useShadowResult != null) { - project = MarkInternal(project, model, documentId); + modelNamesToRemove = modelNamesToRemove.Concat(useShadowResult.RemoveCandidates.Select(GetSimpleName)); } + project = await MeasureAsync( + "PostProcessor.Internalize.RemoveMethodsFromModelFactoryAsync", + () => RemoveMethodsFromModelFactoryAsync(project, definitions, modelNamesToRemove.ToHashSet())); - var modelNamesToRemove = - nodesToInternalize.Keys.Select(item => item.Identifier.Text); - project = await RemoveMethodsFromModelFactoryAsync(project, definitions, modelNamesToRemove.ToHashSet()); + return project; + } + + private static async Task InternalizePublicNestedTypesInInternalTypesAsync(Project project) + { + foreach (var document in project.Documents.ToArray()) + { + if (!GeneratedCodeWorkspace.IsGeneratedDocument(document)) + { + continue; + } + + var root = await document.GetSyntaxRootAsync(); + if (root == null) + { + continue; + } + + var nestedPublicTypes = root.DescendantNodes() + .OfType() + .Where(static declaration => declaration.Modifiers.Any(SyntaxKind.PublicKeyword) && + declaration.Ancestors().OfType().Any(static parent => parent.Modifiers.Any(SyntaxKind.InternalKeyword))) + .ToArray(); + if (nestedPublicTypes.Length == 0) + { + continue; + } + + var newRoot = root.ReplaceNodes(nestedPublicTypes, static (originalNode, _) => + ChangeAccessibility(originalNode, SyntaxKind.InternalKeyword)).WithAdditionalAnnotations(Simplifier.Annotation); + project = document.WithSyntaxRoot(newRoot).Project; + } return project; } + private static string GetSimpleName(string fullyQualifiedName) + { + var lastDot = fullyQualifiedName.LastIndexOf('.'); + return lastDot < 0 ? fullyQualifiedName : fullyQualifiedName.Substring(lastDot + 1); + } + private async Task RemoveMethodsFromModelFactoryAsync(Project project, TypeSymbols definitions, HashSet namesToRemove) @@ -232,42 +327,71 @@ private async Task RemoveMethodsFromModelFactoryAsync(Project project, /// The processed . is immutable, therefore this should usually be a new instance public async Task RemoveAsync(Project project) { - var compilation = await project.GetCompilationAsync(); + var compilation = await MeasureAsync("PostProcessor.Remove.GetCompilationAsync", () => project.GetCompilationAsync()); if (compilation == null) return project; // find all the declarations, including non-public declared - var definitions = await GetTypeSymbolsAsync(compilation, project, false); - // build reference map - var referenceMap = - await new ReferenceMapBuilder(compilation, project).BuildAllReferenceMapAsync( - definitions.DeclaredSymbols, definitions.DocumentsCache); - // get root symbols - var rootSymbols = await GetRootSymbolsAsync(project, definitions); - // include model factory as a root symbol when doing the remove pass so that we are sure to include any internal - // helpers that are required by the model factory. - if (_modelFactorySymbol != null) - rootSymbols.Add(_modelFactorySymbol); - // traverse the map to determine the declarations that we are about to remove, starting from root nodes - var referencedSymbols = VisitSymbolsFromRootAsync(rootSymbols, referenceMap); - - referencedSymbols = AddSampleSymbols(referencedSymbols, definitions.DeclaredSymbols); - var referencedSet = new HashSet(referencedSymbols, SymbolEqualityComparer.Default); - - var symbolsToRemove = definitions.DeclaredSymbols.Except(referencedSet); - - var nodesToRemove = new List(); - foreach (var symbol in symbolsToRemove) - { - if (referencedSet.Contains(GetBase(symbol))) + var definitions = await MeasureAsync("PostProcessor.Remove.GetTypeSymbolsAsync", () => GetTypeSymbolsAsync(compilation, project, false)); + IEnumerable symbolsToRemove; + HashSet referencedSet; + if (ProviderReferenceMapShadowAnalyzer.UseShadowMap && + ProviderReferenceMapShadowAnalyzer.LatestResult is { } useShadowResult && + useShadowResult.ProjectId == project.Id) + { + symbolsToRemove = Measure("PostProcessor.Remove.UseShadowCandidates", () => + GetSymbolsByName(definitions.DeclaredSymbols, useShadowResult.RemoveCandidates).ToArray()); + referencedSet = Measure("PostProcessor.Remove.BuildShadowReferencedSet", () => + new HashSet(definitions.DeclaredSymbols.Except(symbolsToRemove), SymbolEqualityComparer.Default)); + } + else + { + // build reference map + var referenceMap = + await MeasureAsync( + "PostProcessor.Remove.BuildAllReferenceMapAsync", + () => new ReferenceMapBuilder(compilation, project).BuildAllReferenceMapAsync( + definitions.DeclaredSymbols, definitions.DocumentsCache)); + // get root symbols + var rootSymbols = await MeasureAsync("PostProcessor.Remove.GetRootSymbolsAsync", () => GetRootSymbolsAsync(project, definitions)); + // include model factory as a root symbol when doing the remove pass so that we are sure to include any internal + // helpers that are required by the model factory. + if (_modelFactorySymbol != null) + rootSymbols.Add(_modelFactorySymbol); + // traverse the map to determine the declarations that we are about to remove, starting from root nodes + var referencedSymbols = Measure("PostProcessor.Remove.VisitSymbolsFromRoot", () => VisitSymbolsFromRootAsync(rootSymbols, referenceMap).ToArray().AsEnumerable()); + + referencedSymbols = Measure("PostProcessor.Remove.AddSampleSymbols", () => AddSampleSymbols(referencedSymbols, definitions.DeclaredSymbols)); + referencedSet = Measure("PostProcessor.Remove.BuildReferencedSet", () => new HashSet(referencedSymbols, SymbolEqualityComparer.Default)); + + symbolsToRemove = definitions.DeclaredSymbols.Except(referencedSet); + } + + if (ProviderReferenceMapShadowAnalyzer.LatestResult is { } shadowResult && shadowResult.ProjectId == project.Id) + { + ProviderReferenceMapShadowAnalyzer.WriteComparisonReport( + "remove", + symbolsToRemove.Select(static symbol => symbol.GetFullyQualifiedName()), + shadowResult.RemoveCandidates); + } + + var nodesToRemove = Measure("PostProcessor.Remove.CollectNodes", () => + { + var nodes = new List(); + foreach (var symbol in symbolsToRemove) { - continue; + if (referencedSet.Contains(GetBase(symbol))) + { + continue; + } + nodes.AddRange(definitions.DeclaredNodesCache[symbol]); } - nodesToRemove.AddRange(definitions.DeclaredNodesCache[symbol]); - } + + return nodes; + }); // remove them one by one - project = await RemoveModelsAsync(project, nodesToRemove); + project = await MeasureAsync("PostProcessor.Remove.RemoveModelsAsync", () => RemoveModelsAsync(project, nodesToRemove)); return project; } @@ -334,40 +458,100 @@ private static IEnumerable GetReferencedTypes(T definition, return Enumerable.Empty(); } - private Project MarkInternal(Project project, BaseTypeDeclarationSyntax declarationNode, DocumentId documentId) + private static IEnumerable GetSymbolsByName(IEnumerable symbols, HashSet names) { - var newNode = ChangeModifier(declarationNode, SyntaxKind.PublicKeyword, SyntaxKind.InternalKeyword); - var tree = declarationNode.SyntaxTree; - var document = project.GetDocument(documentId)!; - var newRoot = tree.GetRoot().ReplaceNode(declarationNode, newNode) - .WithAdditionalAnnotations(Simplifier.Annotation); - document = document.WithSyntaxRoot(newRoot); - return document.Project; + foreach (var symbol in symbols) + { + if (names.Contains(symbol.GetFullyQualifiedName())) + { + yield return symbol; + } + } + } + + private Project ApplyAccessibilityChanges( + Project project, + IReadOnlyDictionary nodesToInternalize, + IReadOnlyDictionary nodesToPublicize) + { + var changesByDocument = new Dictionary>(); + AddAccessibilityChanges(changesByDocument, nodesToInternalize, SyntaxKind.InternalKeyword); + AddAccessibilityChanges(changesByDocument, nodesToPublicize, SyntaxKind.PublicKeyword); + + foreach (var (documentId, changes) in changesByDocument) + { + var document = project.GetDocument(documentId)!; + var root = changes.Keys.First().SyntaxTree.GetRoot(); + var newRoot = root.ReplaceNodes( + changes.Keys, + (originalNode, _) => changes.TryGetValue(originalNode, out var targetAccessibility) + ? ChangeAccessibility(originalNode, targetAccessibility) + : originalNode) + .WithAdditionalAnnotations(Simplifier.Annotation); + document = document.WithSyntaxRoot(newRoot); + project = document.Project; + } + + return project; + } + + private static void AddAccessibilityChanges( + Dictionary> changesByDocument, + IReadOnlyDictionary nodes, + SyntaxKind targetAccessibility) + { + foreach (var (node, documentId) in nodes) + { + if (!changesByDocument.TryGetValue(documentId, out var changes)) + { + changes = new Dictionary(); + changesByDocument[documentId] = changes; + } + + changes[node] = targetAccessibility; + } + } + + private static BaseTypeDeclarationSyntax ChangeAccessibility(BaseTypeDeclarationSyntax declarationNode, SyntaxKind targetAccessibility) + { + return targetAccessibility == SyntaxKind.PublicKeyword + ? ChangeModifier(declarationNode, SyntaxKind.InternalKeyword, SyntaxKind.PublicKeyword) + : ChangeModifier(declarationNode, SyntaxKind.PublicKeyword, SyntaxKind.InternalKeyword); } private async Task RemoveModelsAsync(Project project, IEnumerable unusedModels) { // accumulate the definitions from the same document together - var documents = new Dictionary>(); - - foreach (var model in unusedModels) + var documents = Measure("PostProcessor.Remove.RemoveModelsAsync.GroupByDocument", () => { - var document = project.GetDocument(model.SyntaxTree); - Debug.Assert(document != null); - if (!documents.ContainsKey(document)) - documents.Add(document, new HashSet()); + var groupedDocuments = new Dictionary>(); + foreach (var model in unusedModels) + { + var document = project.GetDocument(model.SyntaxTree); + Debug.Assert(document != null); + if (!groupedDocuments.ContainsKey(document)) + groupedDocuments.Add(document, new HashSet()); - documents[document].Add(model); - } + groupedDocuments[document].Add(model); + } + + return groupedDocuments; + }); - foreach (var models in documents.Values) + project = await MeasureAsync("PostProcessor.Remove.RemoveModelsAsync.RemoveModelsFromDocuments", async () => { - project = await RemoveModelsFromDocumentAsync(project, models); - } + var updatedProject = project; + foreach (var models in documents.Values) + { + updatedProject = await RemoveModelsFromDocumentAsync(updatedProject, models); + } + + return updatedProject; + }); // remove what are now invalid references due to the models being removed - project = await RemoveInvalidRefs(project); + project = await MeasureAsync("PostProcessor.Remove.RemoveModelsAsync.RemoveInvalidRefs", () => RemoveInvalidRefs(project)); return project; } @@ -418,16 +602,28 @@ private async Task RemoveInvalidRefs(Project project) var solution = project.Solution; // Process each document for invalid usings - foreach (var documentId in project.DocumentIds) + solution = await MeasureAsync("PostProcessor.Remove.RemoveInvalidRefs.RemoveInvalidUsings", async () => { - solution = await RemoveInvalidUsings(solution, documentId); - } + var updatedSolution = solution; + foreach (var documentId in project.DocumentIds) + { + updatedSolution = await RemoveInvalidUsings(updatedSolution, documentId); + } + + return updatedSolution; + }); // Process each document for invalid attributes (with fresh semantic models) - foreach (var documentId in project.DocumentIds) + solution = await MeasureAsync("PostProcessor.Remove.RemoveInvalidRefs.RemoveInvalidAttributes", async () => { - solution = await RemoveInvalidAttributes(solution, documentId); - } + var updatedSolution = solution; + foreach (var documentId in project.DocumentIds) + { + updatedSolution = await RemoveInvalidAttributes(updatedSolution, documentId); + } + + return updatedSolution; + }); return solution.GetProject(project.Id)!; } @@ -478,6 +674,15 @@ arg.Expression is TypeOfExpressionSyntax typeOfExpr && model.GetTypeInfo(typeOfExpr.Type).Type?.TypeKind == TypeKind.Error) == true)) .ToHashSet(); + foreach (var attr in attributes) + { + if (IsInternalRecordBuildableAttribute(attr) || + await ShouldRemoveUnreferencedInternalBuildableAttribute(solution, model, attr)) + { + invalidAttributes.Add(attr); + } + } + if (invalidAttributes.Count > 0) { cu = cu.RemoveNodes(invalidAttributes, SyntaxRemoveOptions.KeepNoTrivia)!; @@ -533,6 +738,159 @@ arg.Expression is TypeOfExpressionSyntax typeOfExpr && return solution; } + private static bool IsInternalRecordBuildableAttribute(AttributeListSyntax attributeList) + { + if (attributeList.Attributes.Count != 1 || + !IsModelReaderWriterBuildableAttribute(attributeList.Attributes[0])) + { + return false; + } + + var typeName = attributeList.Attributes[0].ArgumentList?.Arguments + .Select(static argument => argument.Expression) + .OfType() + .Select(static typeOfExpression => typeOfExpression.Type.ToString().Split('.').Last()) + .FirstOrDefault(); + + return typeName?.StartsWith("Update", StringComparison.Ordinal) == true && typeName.EndsWith("Record", StringComparison.Ordinal) || + typeName?.EndsWith("PatchUpdate", StringComparison.Ordinal) == true; + } + + private static async Task ShouldRemoveUnreferencedInternalBuildableAttribute( + Solution solution, + SemanticModel model, + AttributeListSyntax attributeList) + { + if (attributeList.Attributes.Count != 1) + { + return false; + } + + var attribute = attributeList.Attributes[0]; + if (model.GetSymbolInfo(attribute).Symbol?.ContainingType.Name != "ModelReaderWriterBuildableAttribute" && + !IsModelReaderWriterBuildableAttribute(attribute)) + { + return false; + } + + var typeOfExpression = attribute.ArgumentList?.Arguments + .Select(static argument => argument.Expression) + .OfType() + .FirstOrDefault(); + if (typeOfExpression == null || + model.GetTypeInfo(typeOfExpression.Type).Type is not INamedTypeSymbol typeSymbol || + typeSymbol.DeclaredAccessibility != Accessibility.Internal) + { + return false; + } + + if (typeSymbol.BaseType is { SpecialType: not SpecialType.System_Object }) + { + return false; + } + + if (typeSymbol.Name.EndsWith("PatchUpdate", StringComparison.Ordinal) || + typeSymbol.Name.StartsWith("Update", StringComparison.Ordinal) && typeSymbol.Name.EndsWith("Record", StringComparison.Ordinal)) + { + return true; + } + + foreach (var referencedSymbol in await SymbolFinder.FindReferencesAsync(typeSymbol, solution)) + { + foreach (var location in referencedSymbol.Locations) + { + if (!location.Location.IsInSource) + { + continue; + } + + var document = location.Document; + var root = await document.GetSyntaxRootAsync(); + if (root == null) + { + continue; + } + + var node = root.FindNode(location.Location.SourceSpan); + if (node.AncestorsAndSelf().OfType().Any()) + { + continue; + } + + if (IsWithinTypeDeclaration(typeSymbol, node)) + { + continue; + } + + return false; + } + } + + return true; + } + + private static bool IsModelReaderWriterBuildableAttribute(AttributeSyntax attribute) + { + var name = attribute.Name.ToString(); + return name.EndsWith("ModelReaderWriterBuildable", StringComparison.Ordinal) || + name.EndsWith("ModelReaderWriterBuildableAttribute", StringComparison.Ordinal) || + name.Contains(".ModelReaderWriterBuildableAttribute", StringComparison.Ordinal); + } + + private static bool IsWithinTypeDeclaration(INamedTypeSymbol typeSymbol, SyntaxNode node) + { + foreach (var syntaxReference in typeSymbol.DeclaringSyntaxReferences) + { + if (syntaxReference.SyntaxTree == node.SyntaxTree && syntaxReference.Span.Contains(node.Span)) + { + return true; + } + } + return false; + } + + private static T Measure(string stepName, Func action) + { + var profile = Profile; + if (profile == null) + { + return action(); + } + + var allocatedBytes = GC.GetTotalAllocatedBytes(precise: false); + var stopwatch = Stopwatch.StartNew(); + try + { + return action(); + } + finally + { + stopwatch.Stop(); + profile.Add(stepName, stopwatch.Elapsed, GC.GetTotalAllocatedBytes(precise: false) - allocatedBytes); + } + } + + private static async Task MeasureAsync(string stepName, Func> action) + { + var profile = Profile; + if (profile == null) + { + return await action(); + } + + var allocatedBytes = GC.GetTotalAllocatedBytes(precise: false); + var stopwatch = Stopwatch.StartNew(); + try + { + return await action(); + } + finally + { + stopwatch.Stop(); + profile.Add(stepName, stopwatch.Elapsed, GC.GetTotalAllocatedBytes(precise: false) - allocatedBytes); + } + } + private async Task> GetRootSymbolsAsync(Project project, TypeSymbols modelSymbols) { var result = new HashSet(SymbolEqualityComparer.Default); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowAnalyzer.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowAnalyzer.cs new file mode 100644 index 00000000000..c41b5bd3120 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowAnalyzer.cs @@ -0,0 +1,1463 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.RegularExpressions; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; +using Microsoft.TypeSpec.Generator.Statements; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.FindSymbols; + +namespace Microsoft.TypeSpec.Generator +{ + internal static class ProviderReferenceMapShadowAnalyzer + { + private const string EnableEnvironmentVariable = "TYPESPEC_PROVIDER_REFERENCE_MAP_SHADOW"; + private const string UseShadowEnvironmentVariable = "TYPESPEC_PROVIDER_REFERENCE_MAP_USE_SHADOW"; + private const string ReportEnvironmentVariable = "TYPESPEC_PROVIDER_REFERENCE_MAP_SHADOW_REPORT"; + private const string OutputDirectoryEnvironmentVariable = "TYPESPEC_PROVIDER_REFERENCE_MAP_SHADOW_DIR"; + + private static ProviderReferenceMapShadowResult? _latestResult; + private static readonly ConditionalWeakTable, Dictionary> _simpleNameLookupCache = new(); + + public static bool IsEnabled => string.Equals( + Environment.GetEnvironmentVariable(EnableEnvironmentVariable), + "true", + StringComparison.OrdinalIgnoreCase); + + public static ProviderReferenceMapShadowResult? LatestResult => _latestResult; + + public static bool UseShadowMap => string.Equals( + Environment.GetEnvironmentVariable(UseShadowEnvironmentVariable), + "true", + StringComparison.OrdinalIgnoreCase); + + private static bool ShouldWriteReports => string.Equals( + Environment.GetEnvironmentVariable(ReportEnvironmentVariable), + "true", + StringComparison.OrdinalIgnoreCase); + + public static void Analyze(IReadOnlyList providers, Project project) + { + if (!IsEnabled) + { + _latestResult = null; + return; + } + + var graph = BuildGraph(providers); + var publicGraph = BuildGraph(providers, publicOnly: true); + + // Generated-code dependencies come from providers. Custom code still needs Roslyn + // because arbitrary user C# can reference generated types in ways providers cannot see. + var customPublicRoots = GetCustomCodePublicGeneratedTypeRoots(project, graph.Nodes); + customPublicRoots.UnionWith(GetApiBaselineGeneratedTypeRoots(graph.Nodes)); + var customRemovalRoots = GetCustomCodeGeneratedTypeRoots(project, graph.Nodes); + var customInternalDeclarations = GetCustomCodeInternalGeneratedTypeDeclarations(project, graph.Nodes); + var generatedInternalDeclarations = GetGeneratedInternalTypeDeclarations(project, graph.Nodes); + + // Helper types are rooted after an initial reachability pass so unused infrastructure + // such as change-tracking dictionaries can still be removed when no reachable type needs them. + var generatedDiscriminatorBaseNames = GetGeneratedPersistableModelProxyTypeNames(project, publicGraph.Nodes); + var internalizeReferences = CloneReferences(publicGraph.References); + var internalizeRoots = GetRootNames(providers, graph.Nodes, helperRoots: [], includeModelFactory: false, includeAdditionalRoots: true, includeUnionVariantRoots: false, publicClientRootsOnly: true); + var generatedPublicReachable = GetReachableTypes(internalizeRoots, internalizeReferences); + AddDerivedModelReferences(providers, publicGraph.Nodes, internalizeReferences, generatedPublicReachable, generatedDiscriminatorBaseNames); + internalizeRoots.UnionWith(customPublicRoots); + var internalizeReachableWithoutHelpers = GetReachableTypes(internalizeRoots, internalizeReferences); + AddDerivedModelReferences(providers, publicGraph.Nodes, internalizeReferences, internalizeReachableWithoutHelpers, generatedDiscriminatorBaseNames); + internalizeReachableWithoutHelpers = GetReachableTypes(internalizeRoots, internalizeReferences); + var publicizeRoots = internalizeRoots.ToHashSet(StringComparer.Ordinal); + var internalizeHelperRoots = GetHelperRootNames(providers, graph.Nodes, internalizeReachableWithoutHelpers); + internalizeRoots.UnionWith(internalizeHelperRoots); + var internalizeReachable = GetReachableTypes(internalizeRoots, internalizeReferences); + var internalizeDeclaredNodes = GetPostProcessorDeclaredNodes(providers, graph.Nodes, publicOnly: true); + var customInternalBoundaryNodes = graph.Nodes + .Where(name => publicGraph.References.TryGetValue(name, out var references) && references.Overlaps(customInternalDeclarations)) + .ToHashSet(StringComparer.Ordinal); + var publicizeDeclaredNodes = GetPostProcessorDeclaredNodes(providers, graph.Nodes, publicOnly: false) + .Except(internalizeDeclaredNodes, StringComparer.Ordinal); + var generatedImplementationInternalDeclarations = GetGeneratedImplementationInternalTypeDeclarations(generatedInternalDeclarations).ToHashSet(StringComparer.Ordinal); + var publicApiTraversalNodes = internalizeDeclaredNodes + .Except(generatedInternalDeclarations, StringComparer.Ordinal) + .Concat(publicizeDeclaredNodes) + .Except(generatedImplementationInternalDeclarations, StringComparer.Ordinal) + .ToHashSet(StringComparer.Ordinal); + var publicizeReachable = GetReachableTypes(publicizeRoots, internalizeReferences, publicApiTraversalNodes); + var internalizeCandidates = internalizeDeclaredNodes + .Except(publicizeReachable, StringComparer.Ordinal) + .Union(internalizeDeclaredNodes.Intersect(customInternalBoundaryNodes, StringComparer.Ordinal), StringComparer.Ordinal) + .OrderBy(static name => name, StringComparer.Ordinal) + .ToArray(); + var publicizeCandidates = publicizeDeclaredNodes + .Except(customInternalDeclarations, StringComparer.Ordinal) + .Except(customInternalBoundaryNodes, StringComparer.Ordinal) + .Except(internalizeHelperRoots, StringComparer.Ordinal) + .Except(GetRootNames(providers, graph.Nodes, helperRoots: [], includeModelFactory: true, includeAdditionalRoots: true, includeUnionVariantRoots: true, publicClientRootsOnly: true), StringComparer.Ordinal) + .Intersect(publicizeReachable, StringComparer.Ordinal) + .Where(name => publicizeRoots.Contains(name) || + HasPublicApiPredecessor(name, internalizeReferences, publicizeReachable, generatedImplementationInternalDeclarations)) + .OrderBy(static name => name, StringComparer.Ordinal) + .ToArray(); + + // Body-only generated dependencies are needed to avoid deleting helper files, but they do + // not contribute to public API reachability for internalization. + AddGeneratedBodyReferences(project, providers, graph); + + var removeRoots = GetRootNames(providers, graph.Nodes, helperRoots: [], includeModelFactory: true, includeAdditionalRoots: true, includeUnionVariantRoots: true, publicClientRootsOnly: false); + removeRoots.UnionWith(customRemovalRoots); + var removeReachableWithoutHelpers = GetReachableTypes(removeRoots, graph.References); + var removeHelperRoots = GetHelperRootNames(providers, graph.Nodes, removeReachableWithoutHelpers); + removeRoots.UnionWith(removeHelperRoots); + var removeReachable = GetReachableTypes(removeRoots, graph.References); + var removeDeclaredNodes = GetPostProcessorDeclaredNodes(providers, graph.Nodes, publicOnly: false); + var removeCandidates = removeDeclaredNodes.Except(removeReachable, StringComparer.Ordinal).OrderBy(static name => name, StringComparer.Ordinal).ToArray(); + + var helperRoots = internalizeHelperRoots.Concat(removeHelperRoots).ToHashSet(StringComparer.Ordinal); + + _latestResult = new ProviderReferenceMapShadowResult( + project.Id, + internalizeCandidates.ToHashSet(StringComparer.Ordinal), + publicizeCandidates.ToHashSet(StringComparer.Ordinal), + removeCandidates.ToHashSet(StringComparer.Ordinal)); + + if (ShouldWriteReports) + { + WriteReport( + graph, + customPublicRoots, + customRemovalRoots, + helperRoots, + internalizeRoots, + internalizeReachable, + internalizeCandidates, + publicizeRoots, + publicizeReachable, + publicizeCandidates, + removeRoots, + removeReachable, + removeCandidates); + } + } + + private static HashSet GetCustomCodeGeneratedTypeRoots(Project project, HashSet generatedTypeNames) + { + var roots = new HashSet(StringComparer.Ordinal); + var compilation = project.GetCompilationAsync().GetAwaiter().GetResult(); + if (compilation == null) + { + return roots; + } + + foreach (var document in project.Documents) + { + if (IsGeneratedDocument(document)) + { + continue; + } + + var root = document.GetSyntaxRootAsync().GetAwaiter().GetResult(); + if (root == null) + { + continue; + } + + var model = compilation.GetSemanticModel(root.SyntaxTree); + foreach (var declaration in root.DescendantNodes().OfType()) + { + AddSymbolRoot(roots, model.GetDeclaredSymbol(declaration) as ITypeSymbol, generatedTypeNames); + } + + foreach (var typeSyntax in root.DescendantNodes().OfType()) + { + AddSymbolRoot(roots, model.GetTypeInfo(typeSyntax).Type, generatedTypeNames); + } + + foreach (var objectCreation in root.DescendantNodes().OfType()) + { + AddSymbolRoot(roots, model.GetSymbolInfo(objectCreation).Symbol?.ContainingType, generatedTypeNames); + } + + foreach (var invocation in root.DescendantNodes().OfType()) + { + AddSymbolRoot(roots, model.GetSymbolInfo(invocation).Symbol?.ContainingType, generatedTypeNames); + } + } + + return roots; + } + + private static HashSet GetCustomCodePublicGeneratedTypeRoots(Project project, HashSet generatedTypeNames) + { + var roots = new HashSet(StringComparer.Ordinal); + var compilation = project.GetCompilationAsync().GetAwaiter().GetResult(); + if (compilation == null) + { + return roots; + } + + foreach (var document in project.Documents) + { + if (IsGeneratedDocument(document)) + { + continue; + } + + var root = document.GetSyntaxRootAsync().GetAwaiter().GetResult(); + if (root == null) + { + continue; + } + + var semanticModel = compilation.GetSemanticModel(root.SyntaxTree); + foreach (var declaration in root.DescendantNodes().OfType()) + { + if (semanticModel.GetDeclaredSymbol(declaration) is not INamedTypeSymbol symbol || + symbol.DeclaredAccessibility != Accessibility.Public) + { + continue; + } + + AddSymbolRoot(roots, symbol, generatedTypeNames); + AddSymbolRoot(roots, symbol.BaseType, generatedTypeNames); + foreach (var interfaceType in symbol.Interfaces) + { + AddSymbolRoot(roots, interfaceType, generatedTypeNames); + } + + foreach (var member in symbol.GetMembers()) + { + if (member.DeclaredAccessibility != Accessibility.Public || + !IsDeclaredInSyntaxTree(member, declaration.SyntaxTree, declaration.Span)) + { + continue; + } + + switch (member) + { + case IMethodSymbol method: + AddSymbolRoot(roots, method.ReturnType, generatedTypeNames); + foreach (var parameter in method.Parameters) + { + AddSymbolRoot(roots, parameter.Type, generatedTypeNames); + } + break; + case IPropertySymbol property: + AddSymbolRoot(roots, property.Type, generatedTypeNames); + break; + case IFieldSymbol field: + AddSymbolRoot(roots, field.Type, generatedTypeNames); + break; + case IEventSymbol eventSymbol: + AddSymbolRoot(roots, eventSymbol.Type, generatedTypeNames); + break; + } + } + } + } + + return roots; + } + + private static HashSet GetApiBaselineGeneratedTypeRoots(HashSet generatedTypeNames) + { + var roots = new HashSet(StringComparer.Ordinal); + var projectDirectory = CodeModelGenerator.Instance.Configuration.ProjectDirectory; + if (string.IsNullOrEmpty(projectDirectory)) + { + return roots; + } + + var apiDirectory = Path.GetFullPath(Path.Combine(projectDirectory, "..", "api")); + if (!Directory.Exists(apiDirectory)) + { + return roots; + } + + var apiText = string.Join("\n", Directory.GetFiles(apiDirectory, "*.cs", SearchOption.AllDirectories).Select(File.ReadAllText)); + foreach (var fullName in generatedTypeNames) + { + var simpleName = StripGenericArity(GetSimpleName(fullName)); + var normalizedFullName = StripGenericArity(fullName); + if (!ContainsApiTypeReference(apiText, normalizedFullName, simpleName)) + { + continue; + } + + roots.Add(fullName); + } + + return roots; + } + + private static bool ContainsApiTypeReference(string apiText, string fullName, string simpleName) + { + var fullNamePattern = $@"(? GetCustomCodeInternalGeneratedTypeDeclarations(Project project, HashSet generatedTypeNames) + { + var declarations = new HashSet(StringComparer.Ordinal); + var compilation = project.GetCompilationAsync().GetAwaiter().GetResult(); + if (compilation == null) + { + return declarations; + } + + foreach (var document in project.Documents) + { + if (IsGeneratedDocument(document)) + { + continue; + } + + var root = document.GetSyntaxRootAsync().GetAwaiter().GetResult(); + if (root == null) + { + continue; + } + + var semanticModel = compilation.GetSemanticModel(root.SyntaxTree); + foreach (var declaration in root.DescendantNodes().OfType()) + { + if (semanticModel.GetDeclaredSymbol(declaration) is not INamedTypeSymbol symbol || + symbol.DeclaredAccessibility != Accessibility.Internal) + { + continue; + } + + AddMatchingName(declarations, symbol.GetFullyQualifiedName(), generatedTypeNames); + } + } + + return declarations; + } + + private static HashSet GetGeneratedPersistableModelProxyTypeNames(Project project, HashSet generatedTypeNames) + { + var proxyTypes = new HashSet(StringComparer.Ordinal); + var compilation = project.GetCompilationAsync().GetAwaiter().GetResult(); + if (compilation == null) + { + return proxyTypes; + } + + foreach (var document in project.Documents) + { + if (!IsGeneratedDocument(document)) + { + continue; + } + + var root = document.GetSyntaxRootAsync().GetAwaiter().GetResult(); + if (root == null) + { + continue; + } + + var semanticModel = compilation.GetSemanticModel(root.SyntaxTree); + foreach (var declaration in root.DescendantNodes().OfType()) + { + if (!declaration.AttributeLists + .SelectMany(static list => list.Attributes) + .Any(static attribute => attribute.Name.ToString().Contains("PersistableModelProxy", StringComparison.Ordinal))) + { + continue; + } + + if (semanticModel.GetDeclaredSymbol(declaration) is INamedTypeSymbol symbol) + { + AddMatchingName(proxyTypes, symbol.GetFullyQualifiedName(), generatedTypeNames); + } + } + } + + return proxyTypes; + } + + private static HashSet GetGeneratedInternalTypeDeclarations(Project project, HashSet generatedTypeNames) + { + var declarations = new HashSet(StringComparer.Ordinal); + var compilation = project.GetCompilationAsync().GetAwaiter().GetResult(); + if (compilation == null) + { + return declarations; + } + + foreach (var document in project.Documents) + { + if (!IsGeneratedDocument(document)) + { + continue; + } + + var root = document.GetSyntaxRootAsync().GetAwaiter().GetResult(); + if (root == null) + { + continue; + } + + var semanticModel = compilation.GetSemanticModel(root.SyntaxTree); + foreach (var declaration in root.DescendantNodes().OfType()) + { + if (!declaration.Modifiers.Any(SyntaxKind.InternalKeyword)) + { + continue; + } + + if (semanticModel.GetDeclaredSymbol(declaration) is INamedTypeSymbol symbol) + { + AddMatchingName(declarations, symbol.GetFullyQualifiedName(), generatedTypeNames); + } + } + } + + return declarations; + } + + private static IEnumerable GetGeneratedImplementationInternalTypeDeclarations(HashSet generatedInternalDeclarations) => + generatedInternalDeclarations.Where(static name => GetSimpleName(name).StartsWith("Internal", StringComparison.Ordinal)); + + private static void AddSymbolRoot(HashSet roots, ITypeSymbol? symbol, HashSet generatedTypeNames) + { + if (symbol is not INamedTypeSymbol namedType) + { + return; + } + + AddMatchingName(roots, namedType.GetFullyQualifiedName(), generatedTypeNames); + foreach (var typeArgument in namedType.TypeArguments) + { + AddSymbolRoot(roots, typeArgument, generatedTypeNames); + } + } + + private static ProviderReferenceGraph BuildGraph(IReadOnlyList providers, bool publicOnly = false) + { + var generatedProviders = GetGeneratedProviders(providers); + var serializationProviderNamesByType = providers + .Where(static provider => provider.SerializationProviders.Count > 0) + .ToDictionary( + static provider => GetProviderTypeName(provider.Type), + static provider => provider.SerializationProviders + .Select(static serializationProvider => GetProviderTypeName(serializationProvider.Type)) + .ToArray(), + StringComparer.Ordinal); + IReadOnlyDictionary? serializationReferenceNamesByType = publicOnly ? null : serializationProviderNamesByType; + var nodes = generatedProviders + .Select(static provider => GetProviderTypeName(provider.Type)) + .ToHashSet(StringComparer.Ordinal); + var references = nodes.ToDictionary(static name => name, _ => new HashSet(StringComparer.Ordinal), StringComparer.Ordinal); + + foreach (var provider in generatedProviders) + { + var current = GetProviderTypeName(provider.Type); + AddTypeReference(references[current], provider.Type, nodes, serializationReferenceNamesByType); + AddTypeReference(references[current], provider.BaseType, nodes, serializationReferenceNamesByType); + AddTypeReference(references[current], provider.DeclaringTypeProvider?.Type, nodes, serializationReferenceNamesByType); + + if (IsKept(provider.Type, CodeModelGenerator.Instance.NonRootTypes, nodes)) + { + continue; + } + + // Model factory signatures mention many models. The existing Roslyn post-processor + // removes factory methods for unreachable models, so model factory should only + // contribute helper dependencies, not model reachability edges. + if (IsModelFactoryProvider(provider)) + { + continue; + } + + foreach (var implementedType in provider.Implements) + { + AddTypeReference(references[current], implementedType, nodes, serializationReferenceNamesByType); + } + + if (!publicOnly) + { + foreach (var nestedType in provider.NestedTypes) + { + AddTypeReference(references[current], nestedType.Type, nodes, serializationReferenceNamesByType); + } + } + + if (!publicOnly) + { + foreach (var serializationProvider in provider.SerializationProviders) + { + AddTypeReference(references[current], serializationProvider.Type, nodes, serializationReferenceNamesByType); + } + } + + foreach (var property in provider.Properties) + { + if (publicOnly && !IsPublic(property.Modifiers)) + { + continue; + } + + AddTypeReference(references[current], property.Type, nodes, serializationReferenceNamesByType); + AddTypeReference(references[current], property.ExplicitInterface, nodes, serializationReferenceNamesByType); + if (!publicOnly) + { + AddAttributes(references[current], property.Attributes, nodes, serializationReferenceNamesByType); + } + } + + foreach (var field in provider.Fields) + { + if (publicOnly && !field.Modifiers.HasFlag(FieldModifiers.Public)) + { + continue; + } + + AddTypeReference(references[current], field.Type, nodes, serializationReferenceNamesByType); + if (!publicOnly) + { + AddAttributes(references[current], field.Attributes, nodes, serializationReferenceNamesByType); + } + } + + foreach (var constructor in provider.Constructors) + { + if (publicOnly && !IsPublic(constructor.Signature.Modifiers)) + { + continue; + } + + AddSignatureReferences(references[current], constructor.Signature, nodes, serializationReferenceNamesByType, includeAttributes: !publicOnly); + } + + foreach (var method in provider.Methods) + { + if (publicOnly && !IsPublic(method.Signature.Modifiers)) + { + continue; + } + + AddSignatureReferences(references[current], method.Signature, nodes, serializationReferenceNamesByType, includeAttributes: !publicOnly); + if (!publicOnly) + { + AddTypeReference(references[current], GetCollectionDefinitionType(method), nodes, serializationReferenceNamesByType); + } + } + } + + return new ProviderReferenceGraph(nodes, references); + } + + private static CSharpType? GetCollectionDefinitionType(MethodProvider method) + { + var property = method.GetType().GetProperty("CollectionDefinition"); + return property?.GetValue(method) is TypeProvider collectionDefinition + ? collectionDefinition.Type + : null; + } + + private static bool IsPublic(MethodSignatureModifiers modifiers) => modifiers.HasFlag(MethodSignatureModifiers.Public); + + private static Dictionary> CloneReferences(IReadOnlyDictionary> references) + { + return references.ToDictionary( + static item => item.Key, + static item => item.Value.ToHashSet(StringComparer.Ordinal), + StringComparer.Ordinal); + } + + private static void AddDerivedModelReferences( + IReadOnlyList providers, + HashSet nodes, + Dictionary> references, + HashSet publicBaseModels, + HashSet generatedDiscriminatorBaseNames) + { + var modelProviders = providers.OfType().ToArray(); + var publicModelProviders = modelProviders + .Where(static provider => provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + .ToArray(); + var discriminatorProviders = publicModelProviders + .Where(static provider => provider.DiscriminatorProperty != null || provider.DiscriminatorValue != null) + .Where(static provider => !provider.IsUnknownDiscriminatorModel) + .ToArray(); + var discriminatorBaseNames = publicModelProviders + .Where(static provider => provider.DiscriminatorProperty != null) + .Select(static provider => GetProviderTypeName(provider.Type)) + .ToHashSet(StringComparer.Ordinal); + discriminatorBaseNames.UnionWith(generatedDiscriminatorBaseNames); + var addedReference = true; + while (addedReference) + { + addedReference = false; + foreach (var provider in discriminatorProviders) + { + var providerName = GetProviderTypeName(provider.Type); + if (!nodes.Contains(providerName)) + { + continue; + } + + if (!publicBaseModels.Contains(providerName)) + { + continue; + } + + foreach (var derivedModel in provider.DerivedModels) + { + if (derivedModel.IsUnknownDiscriminatorModel || + !derivedModel.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + var before = references[providerName].Count; + AddTypeReference(references[providerName], derivedModel.Type, nodes); + var derivedName = GetProviderTypeName(derivedModel.Type); + if (nodes.Contains(derivedName) && publicBaseModels.Add(derivedName) || references[providerName].Count != before) + { + addedReference = true; + } + } + } + + foreach (var provider in modelProviders) + { + if (provider.IsUnknownDiscriminatorModel || + !provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + var providerName = GetProviderTypeName(provider.Type); + if (!nodes.Contains(providerName)) + { + continue; + } + + var baseTypeName = provider.BaseType == null ? null : GetProviderTypeName(provider.BaseType); + if (baseTypeName == null || + !discriminatorBaseNames.Contains(baseTypeName) || + !nodes.Contains(baseTypeName) || + !publicBaseModels.Contains(baseTypeName)) + { + continue; + } + + var before = references[baseTypeName].Count; + references[baseTypeName].Add(providerName); + if (publicBaseModels.Add(providerName) || references[baseTypeName].Count != before) + { + addedReference = true; + } + } + } + } + + private static IReadOnlyList GetGeneratedProviders(IReadOnlyList providers) + { + var generatedProviders = new List(); + foreach (var provider in providers) + { + AddGeneratedProvider(generatedProviders, provider); + } + + return generatedProviders; + } + + private static void AddGeneratedProvider(List generatedProviders, TypeProvider provider) + { + generatedProviders.Add(provider); + foreach (var nestedType in provider.NestedTypes) + { + AddGeneratedProvider(generatedProviders, nestedType); + } + + foreach (var serializationProvider in provider.SerializationProviders) + { + AddGeneratedProvider(generatedProviders, serializationProvider); + } + } + + private static void AddGeneratedBodyReferences(Project project, IReadOnlyList providers, ProviderReferenceGraph graph) + { + var compilation = project.GetCompilationAsync().GetAwaiter().GetResult(); + if (compilation == null) + { + return; + } + + foreach (var provider in GetBodyReferenceProviders(providers)) + { + if (IsModelFactoryProvider(provider)) + { + continue; + } + + if (!IsGeneratedBodyReferenceCandidate(provider)) + { + continue; + } + + var providerName = GetProviderTypeName(provider.Type); + if (!graph.Nodes.Contains(providerName)) + { + continue; + } + + var bodyDependencyTypes = provider.BodyDependencyTypes; + AddProviderBodyDependencyTypes(graph.References[providerName], bodyDependencyTypes, graph.Nodes); + + if (bodyDependencyTypes.Count > 0) + { + continue; + } + + var symbol = compilation.GetTypeByMetadataName(providerName); + if (symbol == null) + { + continue; + } + + if (!IsSerializationProvider(provider)) + { + AddGeneratedReferencesToHelper(project, compilation, graph, providerName, symbol); + if (provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Static)) + { + foreach (var method in symbol.GetMembers().OfType()) + { + if (method.IsExtensionMethod) + { + AddGeneratedReferencesToHelper(project, compilation, graph, providerName, method); + } + } + } + } + + AddGeneratedBodyTypeReferences(project, compilation, graph, providerName, symbol); + } + } + + private static void AddProviderBodyDependencyTypes(HashSet references, IReadOnlyList dependencies, HashSet nodes) + { + foreach (var dependency in dependencies) + { + AddTypeReference(references, dependency, nodes); + } + } + + private static IReadOnlyList GetBodyReferenceProviders(IReadOnlyList providers) + { + var bodyReferenceProviders = new List(); + foreach (var provider in providers) + { + bodyReferenceProviders.Add(provider); + bodyReferenceProviders.AddRange(provider.SerializationProviders); + } + + return bodyReferenceProviders; + } + + private static bool IsGeneratedBodyReferenceCandidate(TypeProvider provider) + { + if (provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Static)) + { + return true; + } + + var relativePath = provider.RelativeFilePath.Replace('\\', '/'); + return IsSerializationProvider(provider) || + relativePath.EndsWith("/Internal/ClientUriBuilder.cs", StringComparison.Ordinal) || + provider.BodyDependencyTypes.Count > 0; + } + + private static void AddGeneratedBodyTypeReferences(Project project, Compilation compilation, ProviderReferenceGraph graph, string ownerName, INamedTypeSymbol ownerSymbol) + { + foreach (var syntaxReference in ownerSymbol.DeclaringSyntaxReferences) + { + var document = project.GetDocument(syntaxReference.SyntaxTree); + if (document == null || !IsGeneratedDocument(document)) + { + continue; + } + + var root = syntaxReference.SyntaxTree.GetRoot(); + var semanticModel = compilation.GetSemanticModel(syntaxReference.SyntaxTree); + foreach (var typeSyntax in root.DescendantNodes().OfType()) + { + // Declaration names are the owner itself. The old Roslyn map captures references, + // not a declaration making itself reachable. + if (typeSyntax.Parent is BaseTypeDeclarationSyntax baseTypeDeclaration && baseTypeDeclaration.Identifier.Span == typeSyntax.Span) + { + continue; + } + + AddBodyTypeReference(graph.References[ownerName], semanticModel.GetTypeInfo(typeSyntax).Type, graph.Nodes); + } + } + } + + private static void AddBodyTypeReference(HashSet references, ITypeSymbol? symbol, HashSet nodes) + { + if (symbol is not INamedTypeSymbol namedType || namedType.TypeKind == TypeKind.Error) + { + return; + } + + AddMatchingName(references, namedType.GetFullyQualifiedName(), nodes); + if (namedType.TypeKind == TypeKind.Enum) + { + AddMatchingName(references, $"{namedType.Name}Extensions", nodes); + } + + foreach (var typeArgument in namedType.TypeArguments) + { + AddBodyTypeReference(references, typeArgument, nodes); + } + } + + private static void AddGeneratedReferencesToHelper(Project project, Compilation compilation, ProviderReferenceGraph graph, string helperName, ISymbol symbol) + { + foreach (var reference in SymbolFinder.FindReferencesAsync(symbol, project.Solution).GetAwaiter().GetResult()) + { + foreach (var location in reference.Locations) + { + var document = location.Document; + if (!IsGeneratedDocument(document)) + { + continue; + } + + var root = document.GetSyntaxRootAsync().GetAwaiter().GetResult(); + if (root == null) + { + continue; + } + + var node = root.FindNode(location.Location.SourceSpan); + var owner = node.AncestorsAndSelf().OfType().FirstOrDefault(); + if (owner == null) + { + continue; + } + + var semanticModel = compilation.GetSemanticModel(owner.SyntaxTree); + if (semanticModel.GetDeclaredSymbol(owner) is not INamedTypeSymbol ownerSymbol) + { + continue; + } + + var ownerName = ownerSymbol.GetFullyQualifiedName(); + if (graph.Nodes.Contains(ownerName)) + { + graph.References[ownerName].Add(helperName); + } + } + } + } + + private static HashSet GetRootNames( + IReadOnlyList providers, + HashSet nodes, + HashSet helperRoots, + bool includeModelFactory, + bool includeAdditionalRoots, + bool includeUnionVariantRoots, + bool publicClientRootsOnly) + { + var generator = CodeModelGenerator.Instance; + var roots = new HashSet(StringComparer.Ordinal); + var modelFactoryName = GetProviderTypeName(generator.OutputLibrary.ModelFactory.Value.Type); + + foreach (var provider in providers) + { + var name = GetProviderTypeName(provider.Type); + if (IsClientProviderRoot(provider, publicClientRootsOnly) || + includeAdditionalRoots && provider.DeclaringTypeProvider == null && IsKept(provider.Type, generator.AdditionalRootTypes, nodes) || + includeModelFactory && string.Equals(name, modelFactoryName, StringComparison.Ordinal) || + includeModelFactory && helperRoots.Contains(name)) + { + roots.Add(name); + } + } + + AddLastContractModelFactorySignatureRoots(providers, roots, nodes); + + if (!includeUnionVariantRoots) + { + return roots; + } + + foreach (var root in generator.TypeFactory.UnionVariantTypesToKeep) + { + AddMatchingName(roots, root, nodes); + } + + foreach (var root in generator.AdditionalRootTypes) + { + AddMatchingName(roots, root, nodes); + } + + return roots; + } + + private static void AddLastContractModelFactorySignatureRoots(IReadOnlyList providers, HashSet roots, HashSet nodes) + { + foreach (var provider in providers.Where(IsModelFactoryProvider)) + { + foreach (var method in provider.LastContractView?.Methods ?? []) + { + if (!method.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Public) || + IsImplementationOnlyModelFactoryMethod(method)) + { + continue; + } + + AddTypeReference(roots, method.Signature.ReturnType, nodes); + foreach (var parameter in method.Signature.Parameters) + { + AddTypeReference(roots, parameter.Type, nodes); + } + } + } + } + + private static bool IsImplementationOnlyModelFactoryMethod(MethodProvider method) + { + var returnType = method.Signature.ReturnType; + if (returnType == null) + { + return true; + } + + var returnTypeName = GetSimpleName(GetProviderTypeName(returnType)); + return returnTypeName.StartsWith("Paged", StringComparison.Ordinal) || + returnTypeName.EndsWith("Request", StringComparison.Ordinal); + } + + private static HashSet GetPostProcessorDeclaredNodes(IReadOnlyList providers, HashSet nodes, bool publicOnly) + { + var generator = CodeModelGenerator.Instance; + var excludedNames = generator.NonRootTypes; + return GetGeneratedProviders(providers) + .Where(provider => !IsModelFactoryProvider(provider)) + .Where(provider => !publicOnly || provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + .Select(provider => GetProviderTypeName(provider.Type)) + .Where(name => nodes.Contains(name)) + .Where(name => !excludedNames.Contains(name) && !excludedNames.Contains(GetSimpleName(name))) + .ToHashSet(StringComparer.Ordinal); + } + + private static bool IsKept(CSharpType type, HashSet roots, HashSet nodes) => + roots.Contains(type.Name) || roots.Contains(GetProviderTypeName(type)) && nodes.Contains(GetProviderTypeName(type)); + + private static bool IsClientProviderRoot(TypeProvider provider, bool publicOnly) => + provider.RelativeFilePath.EndsWith("Client.cs", StringComparison.Ordinal) && + (!publicOnly || !HasApiBaselineDirectory() && provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)); + + private static bool HasApiBaselineDirectory() + { + var projectDirectory = CodeModelGenerator.Instance.Configuration.ProjectDirectory; + return !string.IsNullOrEmpty(projectDirectory) && + Directory.Exists(Path.GetFullPath(Path.Combine(projectDirectory, "..", "api"))); + } + + private static bool IsModelFactoryProvider(TypeProvider provider) + { + if (provider is ModelFactoryProvider) + { + return true; + } + + var relativePath = provider.RelativeFilePath.Replace('\\', '/'); + return relativePath.EndsWith("ModelFactory.cs", StringComparison.Ordinal); + } + + private static HashSet GetHelperRootNames(IReadOnlyList providers, HashSet nodes, HashSet reachableTypes) + { + var roots = new HashSet(StringComparer.Ordinal); + foreach (var provider in GetGeneratedProviders(providers)) + { + var providerName = GetProviderTypeName(provider.Type); + var isModelFactory = IsModelFactoryProvider(provider); + if (!reachableTypes.Contains(providerName) && !isModelFactory) + { + continue; + } + + AddHelperDependencies(roots, provider.HelperDependencyNames, nodes); + + foreach (var property in provider.Properties) + { + AddInitializationHelperRoot(roots, property.Type, nodes); + AddParameterValidationHelperRoot(roots, property.AsParameter, nodes); + } + + foreach (var field in provider.Fields) + { + AddParameterValidationHelperRoot(roots, field.AsParameter, nodes); + } + + foreach (var constructor in provider.Constructors) + { + foreach (var parameter in constructor.Signature.Parameters) + { + AddParameterValidationHelperRoot(roots, parameter, nodes); + } + } + + foreach (var method in provider.Methods) + { + // Only factory methods for reachable models can instantiate collection helpers. + if (isModelFactory && + (method.Signature.ReturnType == null || !reachableTypes.Contains(GetProviderTypeName(method.Signature.ReturnType)))) + { + continue; + } + + foreach (var parameter in method.Signature.Parameters) + { + AddParameterValidationHelperRoot(roots, parameter, nodes); + if (isModelFactory) + { + AddModelFactoryCollectionInitializationHelperRoot(roots, parameter.Type, nodes); + } + } + } + } + + return roots; + } + + private static void AddParameterValidationHelperRoot(HashSet roots, ParameterProvider parameter, HashSet nodes) + { + if (parameter.Validation != ParameterValidationType.None) + { + AddMatchingName(roots, "Argument", nodes); + } + } + + private static void AddHelperDependencies(HashSet roots, IReadOnlyList dependencies, HashSet nodes) + { + foreach (var dependency in dependencies) + { + AddMatchingName(roots, dependency, nodes); + } + } + + private static bool IsSerializationProvider(TypeProvider provider) + { + var relativePath = provider.RelativeFilePath.Replace('\\', '/'); + return relativePath.EndsWith(".Serialization.cs", StringComparison.Ordinal) || + relativePath.EndsWith(".Serialization.Multipart.cs", StringComparison.Ordinal); + } + + private static bool IsGeneratedDocument(Document document) + { + if (GeneratedCodeWorkspace.IsGeneratedDocument(document) || GeneratedCodeWorkspace.IsGeneratedTestDocument(document)) + { + return true; + } + + var filePath = document.FilePath?.Replace('\\', '/'); + return filePath != null && + (filePath.Contains("/Generated/", StringComparison.Ordinal) || + filePath.Contains("/GeneratedTests/", StringComparison.Ordinal)); + } + + private static void AddInitializationHelperRoot(HashSet roots, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + var initializationType = type.PropertyInitializationType; + if (!string.Equals(initializationType.FullyQualifiedName, type.FullyQualifiedName, StringComparison.Ordinal)) + { + AddMatchingName(roots, initializationType.Name, nodes); + } + + if (type is { IsList: true, IsReadOnlyMemory: false }) + { + AddMatchingName(roots, "ChangeTrackingList", nodes); + } + + foreach (var argument in type.Arguments) + { + AddInitializationHelperRoot(roots, argument, nodes); + } + } + + private static void AddModelFactoryCollectionInitializationHelperRoot(HashSet roots, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + if (type is { IsList: true, IsReadOnlyMemory: false }) + { + AddMatchingName(roots, "ChangeTrackingList", nodes); + } + + if (type.IsDictionary) + { + AddMatchingName(roots, "ChangeTrackingDictionary", nodes); + } + + foreach (var argument in type.Arguments) + { + AddModelFactoryCollectionInitializationHelperRoot(roots, argument, nodes); + } + } + + private static void AddMatchingName(HashSet target, string name, HashSet nodes) + { + if (nodes.Contains(name)) + { + target.Add(name); + return; + } + + var simpleNameLookup = _simpleNameLookupCache.GetValue(nodes, BuildSimpleNameLookup); + if (!simpleNameLookup.TryGetValue(name, out var matches)) + { + return; + } + + foreach (var match in matches) + { + target.Add(match); + } + } + + private static Dictionary BuildSimpleNameLookup(HashSet nodes) + { + return nodes + .GroupBy(static node => StripGenericArity(GetSimpleName(node)), StringComparer.Ordinal) + .ToDictionary(static group => group.Key, static group => group.ToArray(), StringComparer.Ordinal); + } + + private static HashSet GetReachableTypes(HashSet roots, IReadOnlyDictionary> references) + { + return GetReachableTypes(roots, references, expandableNodes: null); + } + + private static HashSet GetReachableTypes( + HashSet roots, + IReadOnlyDictionary> references, + HashSet? expandableNodes) + { + var reachable = new HashSet(StringComparer.Ordinal); + var queue = new Queue(roots); + while (queue.Count > 0) + { + var current = queue.Dequeue(); + if (!reachable.Add(current)) + { + continue; + } + + if (expandableNodes != null && !expandableNodes.Contains(current)) + { + continue; + } + + if (!references.TryGetValue(current, out var children)) + { + continue; + } + + foreach (var child in children) + { + queue.Enqueue(child); + } + } + + return reachable; + } + + private static bool HasPublicApiPredecessor( + string name, + IReadOnlyDictionary> references, + HashSet publicizeReachable, + HashSet generatedImplementationInternalDeclarations) + { + foreach (var (owner, children) in references) + { + if (!publicizeReachable.Contains(owner) || + string.Equals(owner, name, StringComparison.Ordinal) || + generatedImplementationInternalDeclarations.Contains(owner) || + !children.Contains(name)) + { + continue; + } + + return true; + } + + return false; + } + + private static void AddSignatureReferences( + HashSet references, + MethodSignatureBase signature, + HashSet nodes, + IReadOnlyDictionary? serializationProviderNamesByType, + bool includeAttributes = true) + { + AddTypeReference(references, signature.ReturnType, nodes, serializationProviderNamesByType); + if (includeAttributes) + { + AddAttributes(references, signature.Attributes, nodes, serializationProviderNamesByType); + } + + foreach (var parameter in signature.Parameters) + { + AddTypeReference(references, parameter.Type, nodes, serializationProviderNamesByType); + if (includeAttributes) + { + AddAttributes(references, parameter.Attributes, nodes, serializationProviderNamesByType); + } + } + + if (signature is MethodSignature methodSignature) + { + AddTypeReference(references, methodSignature.ExplicitInterface, nodes, serializationProviderNamesByType); + if (methodSignature.GenericArguments != null) + { + foreach (var genericArgument in methodSignature.GenericArguments) + { + AddTypeReference(references, genericArgument, nodes, serializationProviderNamesByType); + } + } + + if (methodSignature.GenericParameterConstraints != null) + { + foreach (var constraint in methodSignature.GenericParameterConstraints) + { + AddTypeReference(references, constraint.Type, nodes, serializationProviderNamesByType); + } + } + } + + if (signature is ConstructorSignature constructorSignature) + { + AddTypeReference(references, constructorSignature.Type, nodes, serializationProviderNamesByType); + } + } + + private static void AddAttributes( + HashSet references, + IReadOnlyList attributes, + HashSet nodes, + IReadOnlyDictionary? serializationProviderNamesByType) + { + foreach (var attribute in attributes) + { + AddTypeReference(references, attribute.Type, nodes, serializationProviderNamesByType); + } + } + + private static void AddTypeReference( + HashSet references, + CSharpType? type, + HashSet nodes, + IReadOnlyDictionary? serializationProviderNamesByType = null) + { + if (type == null) + { + return; + } + + var providerTypeName = GetProviderTypeName(type); + if (nodes.Contains(providerTypeName)) + { + references.Add(providerTypeName); + if (serializationProviderNamesByType != null && serializationProviderNamesByType.TryGetValue(providerTypeName, out var serializationProviderNames)) + { + foreach (var serializationProviderName in serializationProviderNames) + { + references.Add(serializationProviderName); + } + } + } + + AddTypeReference(references, type.BaseType, nodes, serializationProviderNamesByType); + AddTypeReference(references, type.DeclaringType, nodes, serializationProviderNamesByType); + foreach (var argument in type.Arguments) + { + AddTypeReference(references, argument, nodes, serializationProviderNamesByType); + } + } + + public static void WriteComparisonReport(string passName, IEnumerable roslynCandidates, IEnumerable providerCandidates) + { + if (!IsEnabled || !ShouldWriteReports) + { + return; + } + + var roslynSet = roslynCandidates.ToHashSet(StringComparer.Ordinal); + var providerSet = providerCandidates.ToHashSet(StringComparer.Ordinal); + var missingFromProvider = roslynSet.Except(providerSet, StringComparer.Ordinal).OrderBy(static name => name, StringComparer.Ordinal).ToArray(); + var extraInProvider = providerSet.Except(roslynSet, StringComparer.Ordinal).OrderBy(static name => name, StringComparer.Ordinal).ToArray(); + + var directory = GetOutputDirectory(); + Directory.CreateDirectory(directory); + var path = Path.Combine(directory, $"provider-reference-map-shadow-comparison-{passName}-{DateTime.UtcNow:yyyyMMddHHmmssfff}.txt"); + var builder = new StringBuilder(); + builder.AppendLine($"Provider reference map shadow comparison: {passName}"); + builder.AppendLine($"Roslyn candidates: {roslynSet.Count}"); + builder.AppendLine($"Provider candidates: {providerSet.Count}"); + builder.AppendLine($"Missing from provider: {missingFromProvider.Length}"); + builder.AppendLine($"Extra in provider: {extraInProvider.Length}"); + builder.AppendLine(); + builder.AppendLine("Missing from provider:"); + foreach (var item in missingFromProvider) + { + builder.AppendLine($" {item}"); + } + + builder.AppendLine(); + builder.AppendLine("Extra in provider:"); + foreach (var item in extraInProvider) + { + builder.AppendLine($" {item}"); + } + + File.WriteAllText(path, builder.ToString()); + CodeModelGenerator.Instance.Emitter.Debug($"Provider reference map shadow comparison written to {path}"); + } + + private static void WriteReport( + ProviderReferenceGraph graph, + HashSet customPublicRoots, + HashSet customRemovalRoots, + HashSet helperRoots, + HashSet internalizeRoots, + HashSet internalizeReachable, + IReadOnlyList internalizeCandidates, + HashSet publicizeRoots, + HashSet publicizeReachable, + IReadOnlyList publicizeCandidates, + HashSet removeRoots, + HashSet removeReachable, + IReadOnlyList removeCandidates) + { + var directory = GetOutputDirectory(); + + Directory.CreateDirectory(directory); + var path = Path.Combine(directory, $"provider-reference-map-shadow-{DateTime.UtcNow:yyyyMMddHHmmssfff}.txt"); + var builder = new StringBuilder(); + builder.AppendLine("Provider reference map shadow report"); + builder.AppendLine($"Declared providers: {graph.Nodes.Count}"); + builder.AppendLine($"Internalize roots: {internalizeRoots.Count}"); + builder.AppendLine($"Internalize reachable: {internalizeReachable.Count}"); + builder.AppendLine($"Internalize candidates: {internalizeCandidates.Count}"); + builder.AppendLine($"Publicize roots: {publicizeRoots.Count}"); + builder.AppendLine($"Publicize reachable: {publicizeReachable.Count}"); + builder.AppendLine($"Publicize candidates: {publicizeCandidates.Count}"); + builder.AppendLine($"Custom public roots: {customPublicRoots.Count}"); + builder.AppendLine($"Custom removal roots: {customRemovalRoots.Count}"); + builder.AppendLine($"Helper roots: {helperRoots.Count}"); + builder.AppendLine($"Remove roots: {removeRoots.Count}"); + builder.AppendLine($"Remove reachable: {removeReachable.Count}"); + builder.AppendLine($"Remove candidates: {removeCandidates.Count}"); + builder.AppendLine(); + AppendItems(builder, "Custom public roots", customPublicRoots); + AppendItems(builder, "Custom removal roots", customRemovalRoots); + AppendItems(builder, "Helper roots", helperRoots); + AppendItems(builder, "Internalize roots", internalizeRoots); + AppendItems(builder, "Internalize candidates", internalizeCandidates); + AppendItems(builder, "Publicize roots", publicizeRoots); + AppendItems(builder, "Publicize candidates", publicizeCandidates); + AppendItems(builder, "Remove roots", removeRoots); + AppendItems(builder, "Remove candidates", removeCandidates); + + builder.AppendLine(); + builder.AppendLine("References:"); + foreach (var (type, references) in graph.References.OrderBy(static item => item.Key, StringComparer.Ordinal)) + { + builder.AppendLine($" {type}"); + foreach (var reference in references.OrderBy(static name => name, StringComparer.Ordinal)) + { + builder.AppendLine($" -> {reference}"); + } + } + + File.WriteAllText(path, builder.ToString()); + CodeModelGenerator.Instance.Emitter.Debug($"Provider reference map shadow report written to {path}"); + } + + private static void AppendItems(StringBuilder builder, string title, IEnumerable items) + { + builder.AppendLine(); + builder.AppendLine($"{title}:"); + foreach (var item in items.OrderBy(static name => name, StringComparer.Ordinal)) + { + builder.AppendLine($" {item}"); + } + } + + private static string GetOutputDirectory() + { + var directory = Environment.GetEnvironmentVariable(OutputDirectoryEnvironmentVariable); + return string.IsNullOrWhiteSpace(directory) + ? Path.Combine(Path.GetTempPath(), "typespec-provider-reference-map-shadow") + : Path.GetFullPath(directory); + } + + private static string GetSimpleName(string fullyQualifiedName) + { + var lastDot = fullyQualifiedName.LastIndexOf('.'); + return lastDot < 0 ? fullyQualifiedName : fullyQualifiedName.Substring(lastDot + 1); + } + + private static string GetProviderTypeName(CSharpType type) + { + var name = type.Arguments.Count > 0 && !type.Name.Contains('`', StringComparison.Ordinal) + ? $"{type.Name}`{type.Arguments.Count}" + : type.Name; + return string.IsNullOrEmpty(type.Namespace) ? name : $"{type.Namespace}.{name}"; + } + + private static string StripGenericArity(string name) + { + var tick = name.IndexOf('`'); + return tick < 0 ? name : name.Substring(0, tick); + } + + private sealed record ProviderReferenceGraph( + HashSet Nodes, + Dictionary> References); + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowResult.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowResult.cs new file mode 100644 index 00000000000..6750579e24c --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowResult.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Generic; +using Microsoft.CodeAnalysis; + +namespace Microsoft.TypeSpec.Generator +{ + internal sealed record ProviderReferenceMapShadowResult( + ProjectId ProjectId, + HashSet InternalizeCandidates, + HashSet PublicizeCandidates, + HashSet RemoveCandidates) + { + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs index 7a3047886b6..a4d76740c4d 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs @@ -272,6 +272,14 @@ private IReadOnlyList ApplyCustomizationFilter(IEnumerable SerializationProviders => _serializationProviders ??= BuildSerializationProviders(); + private IReadOnlyList? _helperDependencyNames; + internal IReadOnlyList HelperDependencyNames => _helperDependencyNames ??= BuildHelperDependencyNames(); + protected internal virtual IReadOnlyList BuildHelperDependencyNames() => []; + + private IReadOnlyList? _bodyDependencyTypes; + internal IReadOnlyList BodyDependencyTypes => _bodyDependencyTypes ??= BuildBodyDependencyTypes(); + protected internal virtual IReadOnlyList BuildBodyDependencyTypes() => []; + private IReadOnlyList? _attributes; public IReadOnlyList Attributes