Skip to content

Commit 6c71497

Browse files
committed
Make it work for build-time
1 parent bde8d62 commit 6c71497

4 files changed

Lines changed: 139 additions & 50 deletions

File tree

examples/AspNetCore/WebApi/MinimalOpenApiExample/MinimalOpenApiExample.csproj

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
<TargetFramework>net10.0</TargetFramework>
55
<AssemblyTitle>Example API</AssemblyTitle>
66
<GenerateDocumentationFile>true</GenerateDocumentationFile>
7-
<OpenApiGenerateDocuments>false</OpenApiGenerateDocuments>
8-
<OpenApiGenerateDocumentsOnBuild>false</OpenApiGenerateDocumentsOnBuild>
97
</PropertyGroup>
108

119
<ItemGroup>

src/AspNetCore/WebApi/src/Asp.Versioning.OpenApi/Builder/IEndpointConventionBuilderExtensions.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ namespace Microsoft.AspNetCore.Builder;
99
using Microsoft.AspNetCore.Builder;
1010
using Microsoft.AspNetCore.Http;
1111
using Microsoft.Extensions.DependencyInjection;
12+
using Microsoft.Extensions.Hosting;
1213

1314
/// <summary>
1415
/// Provides extension methods for <see cref="IEndpointConventionBuilder"/>.
@@ -43,12 +44,11 @@ private static void ApplyApiVersioning( EndpointBuilder builder )
4344

4445
private static Task InterceptRequestServices( HttpContext context, RequestDelegate action )
4546
{
46-
if ( context.RequestServices is not AggregateKeyedServiceProvider serviceProvider )
47+
if ( context.RequestServices is not AggregateKeyedServiceProvider )
4748
{
48-
serviceProvider = context.RequestServices.GetRequiredService<AggregateKeyedServiceProvider>();
49+
context.RequestServices = context.RequestServices.GetRequiredService<IHost>().Services;
4950
}
5051

51-
context.RequestServices = serviceProvider;
5252
return action( context );
5353
}
5454
}

src/AspNetCore/WebApi/src/Asp.Versioning.OpenApi/DependencyInjection/AggregateKeyedServiceProvider.cs

Lines changed: 73 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,57 +4,100 @@
44

55
namespace Microsoft.Extensions.DependencyInjection;
66

7-
internal sealed class AggregateKeyedServiceProvider( IServiceProvider parent ) : IKeyedServiceProvider
7+
using Asp.Versioning.ApiExplorer;
8+
using Microsoft.Extensions.DependencyInjection.Extensions;
9+
using Microsoft.Extensions.Hosting;
10+
11+
internal sealed class AggregateKeyedServiceProvider : IKeyedServiceProvider, IDisposable
812
{
9-
private readonly IServiceProvider parent = parent;
10-
private readonly List<IServiceProvider> providers = [];
13+
private readonly IServiceCollection services;
14+
private readonly SemaphoreSlim semaphore = new SemaphoreSlim( 1, 1 );
1115

12-
public object? GetKeyedService( Type serviceType, object? serviceKey )
16+
private IServiceProvider serviceProvider;
17+
private bool initialized;
18+
private int? initializingThreadId;
19+
20+
public AggregateKeyedServiceProvider( IServiceProvider serviceProvider, IServiceCollection services )
1321
{
14-
if ( providers.Count == 0 )
15-
{
16-
return parent.GetKeyedService( serviceType, serviceKey );
17-
}
22+
this.services = services;
23+
this.serviceProvider = serviceProvider;
24+
var lifetime = serviceProvider.GetRequiredService<IHostApplicationLifetime>();
25+
lifetime.ApplicationStarted.Register( () => EnsureInitialized(true) );
26+
}
1827

19-
foreach ( var provider in providers )
28+
private IServiceProvider ServiceProvider
29+
{
30+
get
2031
{
21-
if ( provider.GetKeyedService( serviceType, serviceKey ) is { } service )
22-
{
23-
return service;
24-
}
32+
EnsureInitialized(false);
33+
return serviceProvider;
2534
}
26-
27-
return null;
2835
}
2936

30-
public object GetRequiredKeyedService( Type serviceType, object? serviceKey )
37+
private void EnsureInitialized(bool isReady)
3138
{
32-
if ( providers.Count == 0 )
39+
// If already initialized, we can return immediately.
40+
if ( initialized)
41+
{
42+
return;
43+
}
44+
45+
if ( initializingThreadId.HasValue && Environment.CurrentManagedThreadId == initializingThreadId.Value )
3346
{
34-
return parent.GetRequiredKeyedService( serviceType, serviceKey );
47+
return;
3548
}
3649

37-
for ( int i = 0; i < providers.Count - 1; i++ )
50+
// If a "ready" call entered this call already, ensure that other calls will be blocked until we fully initialize.
51+
semaphore.Wait();
52+
try
3853
{
39-
if ( providers[i].GetKeyedService( serviceType, serviceKey ) is { } service )
54+
if ( initialized || !isReady )
55+
{
56+
return;
57+
}
58+
59+
initializingThreadId = Environment.CurrentManagedThreadId;
60+
var provider = serviceProvider.GetRequiredService<IApiVersionDescriptionProvider>();
61+
62+
var collection = new ServiceCollection();
63+
foreach ( var descriptor in services )
4064
{
41-
return service;
65+
collection.Add( descriptor );
4266
}
67+
68+
var descriptions = provider.ApiVersionDescriptions;
69+
70+
for ( var i = 0; i < descriptions.Count; i++ )
71+
{
72+
var description = descriptions[i];
73+
collection.AddOpenApi( description.GroupName );
74+
}
75+
76+
serviceProvider = collection.BuildServiceProvider();
77+
initialized = true;
78+
initializingThreadId = null;
4379
}
80+
finally
81+
{
82+
semaphore.Release();
83+
}
84+
}
4485

45-
return providers[providers.Count - 1].GetRequiredKeyedService( serviceType, serviceKey );
86+
public object? GetKeyedService( Type serviceType, object? serviceKey )
87+
{
88+
return ServiceProvider.GetKeyedService( serviceType, serviceKey );
89+
}
90+
91+
public object GetRequiredKeyedService( Type serviceType, object? serviceKey )
92+
{
93+
return ServiceProvider.GetRequiredKeyedService( serviceType, serviceKey );
4694
}
4795

4896
public object? GetService( Type serviceType )
49-
=> parent.GetService( serviceType );
97+
=> ServiceProvider.GetService( serviceType );
5098

51-
public void Add( IServiceCollection serviceCollection, IServiceCollection parentServiceCollection )
99+
public void Dispose()
52100
{
53-
foreach ( var descriptor in parentServiceCollection )
54-
{
55-
serviceCollection.Add( descriptor );
56-
}
57-
58-
providers.Add( serviceCollection.BuildServiceProvider() );
101+
semaphore.Dispose();
59102
}
60103
}

src/AspNetCore/WebApi/src/Asp.Versioning.OpenApi/DependencyInjection/IApiVersioningBuilderExtensions.cs

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,18 @@ private static void AddOpenApiServices( IApiVersioningBuilder builder, Assembly[
6565

6666
var services = builder.Services;
6767

68-
services.AddTransient( serviceProvider => NewRequestServices( serviceProvider, services ) );
68+
services.Add( GetDocumentProviderDescriptor() );
69+
70+
var hostDescriptor = services.Single(
71+
s => !s.IsKeyedService &&
72+
s.ServiceType == typeof( IHost ) &&
73+
s.Lifetime == ServiceLifetime.Singleton &&
74+
s.ImplementationInstance is null &&
75+
s.ImplementationType is null &&
76+
s.ImplementationFactory is not null );
77+
var hostDescriptorIndex = services.IndexOf( hostDescriptor );
78+
79+
builder.Services[hostDescriptorIndex] = CreateHostWrapperDescriptor( services, hostDescriptor.ImplementationFactory! );
6980

7081
services.AddSingleton<VersionedOpenApiOptionsFactory>();
7182
services.TryAddEnumerable( Transient<IPostConfigureOptions<OpenApiOptions>, ConfigureOpenApiOptions>() );
@@ -74,6 +85,56 @@ private static void AddOpenApiServices( IApiVersioningBuilder builder, Assembly[
7485
services.TryAddTransient( sp => new XmlCommentsTransformer( sp.GetRequiredService<XmlCommentsFile>() ) );
7586
}
7687

88+
private static ServiceDescriptor GetDocumentProviderDescriptor()
89+
{
90+
var serviceCollection = new ServiceCollection();
91+
serviceCollection.AddOpenApi();
92+
foreach ( var descriptor in serviceCollection )
93+
{
94+
if ( descriptor.ServiceType.FullName == "Microsoft.Extensions.ApiDescriptions.IDocumentProvider" )
95+
{
96+
return descriptor;
97+
}
98+
}
99+
100+
throw new UnreachableException();
101+
}
102+
103+
private static ServiceDescriptor CreateHostWrapperDescriptor( IServiceCollection serviceCollection, Func<IServiceProvider, object> hostFactory )
104+
{
105+
Func<IServiceProvider, object> updatedHostFactory = serviceProvider =>
106+
{
107+
var originalHost = (IHost) hostFactory( serviceProvider );
108+
return new OpenApiHost(originalHost, NewRequestServices(serviceProvider, serviceCollection));
109+
};
110+
111+
return new ServiceDescriptor( typeof( IHost ), updatedHostFactory, ServiceLifetime.Singleton );
112+
}
113+
114+
private sealed class OpenApiHost : IHost
115+
{
116+
private readonly IHost originalHost;
117+
private readonly IServiceProvider customServiceProvider;
118+
119+
public OpenApiHost( IHost originalHost, IServiceProvider customServiceProvider )
120+
{
121+
this.originalHost = originalHost;
122+
this.customServiceProvider = customServiceProvider;
123+
}
124+
125+
public IServiceProvider Services
126+
=> customServiceProvider;
127+
128+
public void Dispose()
129+
=> originalHost.Dispose();
130+
131+
public Task StartAsync( CancellationToken cancellationToken = default )
132+
=> originalHost.StartAsync( cancellationToken );
133+
134+
public Task StopAsync( CancellationToken cancellationToken = default )
135+
=> originalHost.StopAsync( cancellationToken );
136+
}
137+
77138
// NOTE: The calling assembly must be captured at the call site that invokes AddOpenApi. In 99% of the cases that
78139
// should be the entry point to the application. It is technically possible to be invoked from some other assembly -
79140
// perhaps another extension library. If that were to happen, that library must resolve the path on its own and
@@ -90,21 +151,8 @@ private static Assembly[] GetAssemblies( Assembly callingAssembly )
90151
return [.. assemblies];
91152
}
92153

93-
[UnconditionalSuppressMessage( "ILLink", "IL3050" )]
94154
private static AggregateKeyedServiceProvider NewRequestServices( IServiceProvider services, IServiceCollection parentServiceCollection )
95155
{
96-
var provider = services.GetRequiredService<IApiVersionDescriptionProvider>();
97-
var container = new AggregateKeyedServiceProvider( services );
98-
var descriptions = provider.ApiVersionDescriptions;
99-
100-
for ( var i = 0; i < descriptions.Count; i++ )
101-
{
102-
var description = descriptions[i];
103-
var serviceCollection = new ServiceCollection();
104-
serviceCollection.AddOpenApi( description.GroupName );
105-
container.Add( serviceCollection, parentServiceCollection );
106-
}
107-
108-
return container;
156+
return new AggregateKeyedServiceProvider( services, parentServiceCollection );
109157
}
110158
}

0 commit comments

Comments
 (0)