Skip to content

Commit 36e0c6b

Browse files
Add in-process MAA token caching to PopKeyAttestor (#5887)
1 parent d9c6979 commit 36e0c6b

9 files changed

Lines changed: 649 additions & 47 deletions

File tree

src/client/Microsoft.Identity.Client.KeyAttestation/ManagedIdentityAttestationExtensions.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using System;
55
using System.Threading.Tasks;
6+
using Microsoft.Identity.Client.Core;
67

78
namespace Microsoft.Identity.Client.KeyAttestation
89
{
@@ -25,13 +26,14 @@ public static AcquireTokenForManagedIdentityParameterBuilder WithAttestationSupp
2526
throw new ArgumentNullException(nameof(builder));
2627
}
2728

28-
// Set the attestation provider delegate
29-
builder.CommonParameters.AttestationTokenProvider = async (endpoint, keyHandle, clientId, ct) =>
29+
builder.CommonParameters.AttestationTokenProvider = async (endpoint, keyHandle, clientId, keyId, logger, ct) =>
3030
{
3131
var result = await PopKeyAttestor.AttestCredentialGuardAsync(
3232
endpoint,
3333
keyHandle,
3434
clientId,
35+
keyId,
36+
logger,
3537
ct).ConfigureAwait(false);
3638

3739
// Return JWT on success, null for non-attested flow on failure

src/client/Microsoft.Identity.Client.KeyAttestation/PopKeyAttestor.cs

Lines changed: 215 additions & 20 deletions
Large diffs are not rendered by default.

src/client/Microsoft.Identity.Client/ApiConfig/Parameters/AcquireTokenCommonParameters.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
using Microsoft.Identity.Client.Extensibility;
1515
using Microsoft.Identity.Client.Internal;
1616
using Microsoft.Identity.Client.Internal.ClientCredential;
17+
using Microsoft.Identity.Client.Core;
1718
using Microsoft.Identity.Client.ManagedIdentity;
1819
using Microsoft.Identity.Client.TelemetryCore.Internal.Events;
1920
using Microsoft.Identity.Client.Utils;
@@ -47,8 +48,9 @@ internal class AcquireTokenCommonParameters
4748
/// Optional delegate for obtaining attestation JWT for Credential Guard keys.
4849
/// Set by the KeyAttestation package via .WithAttestationSupport().
4950
/// Returns null for non-attested flows.
51+
/// Signature: (endpoint, keyHandle, clientId, keyId, logger, cancellationToken) → JWT or null.
5052
/// </summary>
51-
public Func<string, SafeHandle, string, CancellationToken, Task<string>> AttestationTokenProvider { get; set; }
53+
public Func<string, SafeHandle, string, string, ILoggerAdapter, CancellationToken, Task<string>> AttestationTokenProvider { get; set; }
5254

5355
/// <summary>
5456
/// This tries to see if the token request should be done over mTLS or over normal HTTP

src/client/Microsoft.Identity.Client/ApiConfig/Parameters/AcquireTokenForManagedIdentityParameters.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ internal class AcquireTokenForManagedIdentityParameters : IAcquireTokenParameter
3131
/// <summary>
3232
/// Optional delegate for obtaining attestation JWT for Credential Guard keys.
3333
/// Set by the KeyAttestation package via .WithAttestationSupport().
34+
/// Signature: (endpoint, keyHandle, clientId, keyId, logger, cancellationToken) → JWT or null.
3435
/// </summary>
35-
public Func<string, SafeHandle, string, CancellationToken, Task<string>> AttestationTokenProvider { get; set; }
36+
public Func<string, SafeHandle, string, string, ILoggerAdapter, CancellationToken, Task<string>> AttestationTokenProvider { get; set; }
3637

3738
public void LogParameters(ILoggerAdapter logger)
3839
{

src/client/Microsoft.Identity.Client/ApplicationBase.cs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
using System;
5+
using System.Collections.Concurrent;
56
using System.Collections.Generic;
67
using System.Threading;
78
using System.Threading.Tasks;
@@ -28,6 +29,22 @@ public abstract class ApplicationBase : IApplicationBase
2829
/// </summary>
2930
internal const string DefaultAuthority = "https://login.microsoftonline.com/common/";
3031

32+
// Allows extension packages (e.g. KeyAttestation) to register their own static-cache reset logic
33+
// without introducing a circular dependency. Invoked by ResetStateForTest().
34+
private static readonly ConcurrentBag<Action> s_resetCallbacks = new ConcurrentBag<Action>();
35+
36+
/// <summary>
37+
/// Registers a callback to be invoked by <see cref="ResetStateForTest"/>.
38+
/// Intended for extension packages (e.g. KeyAttestation) that own static caches MSAL cannot reference directly.
39+
/// </summary>
40+
internal static void RegisterResetCallback(Action callback)
41+
{
42+
if (callback is null)
43+
throw new ArgumentNullException(nameof(callback));
44+
45+
s_resetCallbacks.Add(callback);
46+
}
47+
3148
internal IServiceBundle ServiceBundle { get; }
3249

3350
internal ApplicationBase(ApplicationConfiguration config)
@@ -101,6 +118,28 @@ public static void ResetStateForTest()
101118

102119
InMemoryPartitionedAppTokenCacheAccessor.ClearStaticCacheForTest();
103120
InMemoryPartitionedUserTokenCacheAccessor.ClearStaticCacheForTest();
121+
122+
List<Exception> callbackExceptions = null;
123+
foreach (var cb in s_resetCallbacks)
124+
{
125+
try
126+
{
127+
cb();
128+
}
129+
catch (Exception ex)
130+
{
131+
callbackExceptions ??= new List<Exception>();
132+
callbackExceptions.Add(new InvalidOperationException(
133+
"A registered reset callback threw during ResetStateForTest().", ex));
134+
}
135+
}
136+
137+
if (callbackExceptions != null)
138+
{
139+
throw new AggregateException(
140+
"One or more registered reset callbacks failed during ResetStateForTest().",
141+
callbackExceptions);
142+
}
104143
}
105144
}
106145
}

src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using System.Net;
88
using System.Net.Http;
99
using System.Runtime.InteropServices;
10+
using System.Security.Cryptography;
1011
using System.Security.Cryptography.X509Certificates;
1112
using System.Threading;
1213
using System.Threading.Tasks;
@@ -27,7 +28,7 @@ internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity
2728
internal static readonly ICertificateCache s_mtlsCertificateCache = new InMemoryCertificateCache();
2829

2930
private readonly IMtlsCertificateCache _mtlsCache;
30-
private Func<string, SafeHandle, string, CancellationToken, Task<string>> _attestationTokenProvider;
31+
private Func<string, SafeHandle, string, string, ILoggerAdapter, CancellationToken, Task<string>> _attestationTokenProvider;
3132

3233
// used in unit tests
3334
public const string ApiVersionQueryParam = "cred-api-version";
@@ -470,11 +471,18 @@ private async Task<string> GetAttestationJwtAsync(
470471

471472
try
472473
{
473-
// Call the attestation token provider delegate
474+
// Call the attestation token provider delegate.
475+
// Prefer the CNG key name (stable for persisted/KSP keys).
476+
// For ephemeral keys (no name), derive a stable identifier from the public key
477+
// fingerprint so that the same key handle maps to the same cache entry while
478+
// distinct ephemeral keys get distinct entries.
479+
string keyId = rsaCng.Key.KeyName ?? GetPublicKeyFingerprint(rsaCng);
474480
string attestationJwt = await _attestationTokenProvider(
475481
attestationEndpoint.AbsoluteUri,
476482
rsaCng.Key.Handle,
477483
clientId,
484+
keyId,
485+
_requestContext.Logger,
478486
cancellationToken).ConfigureAwait(false);
479487

480488
if (string.IsNullOrWhiteSpace(attestationJwt))
@@ -521,5 +529,22 @@ internal static void ResetCertCacheForTest()
521529
s_mtlsCertificateCache.Clear();
522530
}
523531
}
532+
533+
/// <summary>
534+
/// Computes a stable hex fingerprint of the RSA public key.
535+
/// Used as a cache key for ephemeral CNG keys that have no key name.
536+
/// Compatible with .NET Framework 4.6.2 and netstandard2.0.
537+
/// </summary>
538+
private static string GetPublicKeyFingerprint(RSA rsa)
539+
{
540+
RSAParameters p = rsa.ExportParameters(includePrivateParameters: false);
541+
// Concatenate Modulus + Exponent as a stable, unique representation of the public key.
542+
byte[] combined = new byte[p.Modulus.Length + p.Exponent.Length];
543+
Buffer.BlockCopy(p.Modulus, 0, combined, 0, p.Modulus.Length);
544+
Buffer.BlockCopy(p.Exponent, 0, combined, p.Modulus.Length, p.Exponent.Length);
545+
using var sha256 = SHA256.Create();
546+
byte[] hash = sha256.ComputeHash(combined);
547+
return BitConverter.ToString(hash).Replace("-", string.Empty);
548+
}
524549
}
525550
}

0 commit comments

Comments
 (0)