diff --git a/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/SimpleHttpClientFactory.cs b/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/SimpleHttpClientFactory.cs index 0e0b3c9b6f..b12325bff0 100644 --- a/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/SimpleHttpClientFactory.cs +++ b/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/SimpleHttpClientFactory.cs @@ -6,6 +6,7 @@ using System.Net.Http; using System.Net.Security; using System.Security.Cryptography.X509Certificates; +using System.Threading; using Microsoft.Identity.Client.Http; using Microsoft.Identity.Client.ManagedIdentity; @@ -21,11 +22,18 @@ namespace Microsoft.Identity.Client.PlatformsCommon.Shared internal class SimpleHttpClientFactory : IMsalMtlsHttpClientFactory, IMsalSFHttpClientFactory { //Please see (https://aka.ms/msal-httpclient-info) for important information regarding the HttpClient. - private static readonly ConcurrentDictionary s_httpClientPool = new ConcurrentDictionary(); + private static readonly ConcurrentDictionary> s_httpClientPool = + new ConcurrentDictionary>(); private static readonly object s_cacheLock = new object(); + private static int s_httpClientCreationCount; + + // referenced in unit tests + internal static int HttpClientCreationCount => s_httpClientCreationCount; + private static HttpClient CreateHttpClient() { + Interlocked.Increment(ref s_httpClientCreationCount); CheckAndManageCache(); var httpClient = new HttpClient(new HttpClientHandler() @@ -41,6 +49,7 @@ private static HttpClient CreateHttpClient() private static HttpClient CreateMtlsHttpClient(X509Certificate2 bindingCertificate) { #if SUPPORTS_MTLS + Interlocked.Increment(ref s_httpClientCreationCount); CheckAndManageCache(); if (bindingCertificate == null) @@ -63,7 +72,9 @@ private static HttpClient CreateMtlsHttpClient(X509Certificate2 bindingCertifica public HttpClient GetHttpClient() { - return s_httpClientPool.GetOrAdd("non_mtls", CreateHttpClient()); + return s_httpClientPool.GetOrAdd( + "non_mtls", + _ => new Lazy(CreateHttpClient, LazyThreadSafetyMode.ExecutionAndPublication)).Value; } public HttpClient GetHttpClient(X509Certificate2 x509Certificate2) @@ -74,7 +85,9 @@ public HttpClient GetHttpClient(X509Certificate2 x509Certificate2) } string key = x509Certificate2.Thumbprint; - return s_httpClientPool.GetOrAdd(key, CreateMtlsHttpClient(x509Certificate2)); + return s_httpClientPool.GetOrAdd( + key, + _ => new Lazy(() => CreateMtlsHttpClient(x509Certificate2), LazyThreadSafetyMode.ExecutionAndPublication)).Value; } private static void CheckAndManageCache() @@ -88,6 +101,22 @@ private static void CheckAndManageCache() } } + // referenced in unit tests + internal static void ResetStaticStateForTest() + { + lock (s_cacheLock) + { + foreach (Lazy lazy in s_httpClientPool.Values) + { + if (lazy.IsValueCreated) + lazy.Value?.Dispose(); + } + + s_httpClientPool.Clear(); + Interlocked.Exchange(ref s_httpClientCreationCount, 0); + } + } + // This method is used for Service Fabric scenarios where a custom server certificate validation callback is required. // It allows the caller to provide a custom HttpClientHandler with the callback. // The server cert rotates so we need a new HttpClient for each call. @@ -107,8 +136,7 @@ public HttpClient GetHttpClient(Func>(threadCount); + + // Act - call GetHttpClient() concurrently from many threads at once + for (int i = 0; i < threadCount; i++) + { + tasks.Add(Task.Run(() => factory.GetHttpClient())); + } + + HttpClient[] results = await Task.WhenAll(tasks).ConfigureAwait(false); + + int created = SimpleHttpClientFactory.HttpClientCreationCount; + + // Assert - all callers got a non-null client and the same cached instance. + // With Lazy(ExecutionAndPublication) only one HttpClient should + // ever be constructed, regardless of how many threads raced on the same key. + foreach (HttpClient client in results) + { + Assert.IsNotNull(client); + Assert.AreSame(results[0], client, "All concurrent callers should receive the same cached HttpClient instance."); + } + + Assert.AreEqual(1, created, + $"CreateHttpClient was called {created} times across {threadCount} concurrent calls. " + + "Lazy(ExecutionAndPublication) should guarantee exactly one construction per key."); + } + } }