Skip to content

Commit 7170925

Browse files
Copilotgladjohn
andauthored
Add MSAL client metadata headers (x-client-SKU, x-client-VER, x-ms-client-request-id) to IMDS managed identity requests
Agent-Logs-Url: https://github.com/AzureAD/microsoft-authentication-library-for-java/sessions/a002e01d-801c-4ba9-8cf4-0d0ad3b9a0c9 Co-authored-by: gladjohn <90415114+gladjohn@users.noreply.github.com>
1 parent b1d2b77 commit 7170925

2 files changed

Lines changed: 28 additions & 1 deletion

File tree

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IMDSManagedIdentitySource.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ public void createManagedIdentityRequest(String resource) {
7777

7878
managedIdentityRequest.headers = new HashMap<>();
7979
managedIdentityRequest.headers.put("Metadata", "true");
80+
managedIdentityRequest.headers.put(HttpHeaders.PRODUCT_HEADER_NAME, HttpHeaders.PRODUCT_HEADER_VALUE);
81+
managedIdentityRequest.headers.put(HttpHeaders.PRODUCT_VERSION_HEADER_NAME, HttpHeaders.PRODUCT_VERSION_HEADER_VALUE);
82+
managedIdentityRequest.headers.put("x-ms-client-request-id", managedIdentityRequest.requestContext().correlationId());
8083

8184
managedIdentityRequest.queryParameters = new HashMap<>();
8285
managedIdentityRequest.queryParameters.put("api-version", IMDS_API_VERSION);

msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import org.junit.jupiter.params.provider.ValueSource;
1515
import org.mockito.junit.jupiter.MockitoExtension;
1616

17+
import org.mockito.ArgumentCaptor;
18+
1719
import java.net.SocketException;
1820
import java.nio.file.Path;
1921
import java.nio.file.Paths;
@@ -37,6 +39,9 @@
3739
@TestInstance(TestInstance.Lifecycle.PER_METHOD)
3840
class ManagedIdentityTests {
3941

42+
private static final String EXPECTED_SKU = "MSAL.Java";
43+
private static final String TEST_CORRELATION_ID = "00000000-0000-0000-0000-000000000001";
44+
4045
@BeforeAll
4146
static void setupRetryPolicies() {
4247
// Set retry delays to 1ms for faster test execution
@@ -145,6 +150,9 @@ private String configureSourceSpecificParameters(ManagedIdentitySourceType sourc
145150
default:
146151
queryParameters.put("api-version", "2018-02-01");
147152
headers.put("Metadata", "true");
153+
headers.put("x-client-SKU", EXPECTED_SKU);
154+
headers.put("x-client-VER", HttpHeaders.PRODUCT_VERSION_HEADER_VALUE);
155+
headers.put("x-ms-client-request-id", TEST_CORRELATION_ID);
148156
return ManagedIdentityTestConstants.IMDS_ENDPOINT;
149157
}
150158
}
@@ -206,6 +214,7 @@ void initManagedIdentityApplication(ManagedIdentityId idType) {
206214
miApp = ManagedIdentityApplication
207215
.builder(idType)
208216
.httpClient(httpClientMock)
217+
.correlationId(TEST_CORRELATION_ID)
209218
.build();
210219

211220
// ManagedIdentityApplication uses a static token cache, avoid cross test pollution by clearing it
@@ -396,6 +405,7 @@ void managedIdentityTest_WithCapabilitiesOnly(ManagedIdentitySourceType source,
396405
.builder(ManagedIdentityId.systemAssigned())
397406
.httpClient(httpClientMock)
398407
.clientCapabilities(singletonList("cp1"))
408+
.correlationId(TEST_CORRELATION_ID)
399409
.build();
400410

401411
miApp.tokenCache.accessTokens.clear();
@@ -425,6 +435,7 @@ void managedIdentity_ClaimsAndCapabilities(ManagedIdentitySourceType source, Str
425435
.builder(ManagedIdentityId.systemAssigned())
426436
.clientCapabilities(singletonList("cp1"))
427437
.httpClient(httpClientMock)
438+
.correlationId(TEST_CORRELATION_ID)
428439
.build();
429440

430441
// First call, get the token from the identity provider.
@@ -554,6 +565,7 @@ void managedIdentityTest_Retry(ManagedIdentitySourceType source, String endpoint
554565
miApp = ManagedIdentityApplication
555566
.builder(ManagedIdentityId.systemAssigned())
556567
.httpClient(httpClientMock)
568+
.correlationId(TEST_CORRELATION_ID)
557569
.build();
558570

559571
//Several specific 4xx and 5xx errors, such as 500, should trigger MSAL's retry logic
@@ -602,6 +614,7 @@ void managedIdentityTest_RetriesDisabled(ManagedIdentitySourceType source, Strin
602614
.builder(ManagedIdentityId.systemAssigned())
603615
.httpClient(httpClientMock)
604616
.disableInternalRetries()
617+
.correlationId(TEST_CORRELATION_ID)
605618
.build();
606619

607620
//Several specific 4xx and 5xx errors, such as 500, should trigger MSAL's retry logic
@@ -631,6 +644,7 @@ void managedIdentityTest_IMDSRetry() throws Exception {
631644
miApp = ManagedIdentityApplication
632645
.builder(ManagedIdentityId.systemAssigned())
633646
.httpClient(httpClientMock)
647+
.correlationId(TEST_CORRELATION_ID)
634648
.build();
635649

636650
// IMDS has different retry logic for certain status codes, such as 410
@@ -672,6 +686,7 @@ void managedIdentityTest_RetrySucceedsAfterFailure() throws Exception {
672686
miApp = ManagedIdentityApplication
673687
.builder(ManagedIdentityId.systemAssigned())
674688
.httpClient(httpClientMock)
689+
.correlationId(TEST_CORRELATION_ID)
675690
.build();
676691

677692
// First call returns 500, subsequent calls return 200
@@ -685,7 +700,14 @@ void managedIdentityTest_RetrySucceedsAfterFailure() throws Exception {
685700
assertNotNull(result.accessToken());
686701

687702
// Verify that the client was called exactly twice (first attempt + one retry)
688-
verify(httpClientMock, times(2)).send(any());
703+
ArgumentCaptor<HttpRequest> captor = ArgumentCaptor.forClass(HttpRequest.class);
704+
verify(httpClientMock, times(2)).send(captor.capture());
705+
706+
// Verify IMDS client metadata headers on the captured request
707+
HttpRequest capturedRequest = captor.getValue();
708+
assertEquals(EXPECTED_SKU, capturedRequest.headers().get("x-client-SKU"));
709+
assertNotNull(capturedRequest.headers().get("x-client-VER"));
710+
assertDoesNotThrow(() -> UUID.fromString(capturedRequest.headers().get("x-ms-client-request-id")));
689711
}
690712

691713
@Test
@@ -698,6 +720,7 @@ void managedIdentityTest_NonRetryableError() throws Exception {
698720
miApp = ManagedIdentityApplication
699721
.builder(ManagedIdentityId.systemAssigned())
700722
.httpClient(httpClientMock)
723+
.correlationId(TEST_CORRELATION_ID)
701724
.build();
702725

703726
// Use a status code that doesn't trigger retries (499)
@@ -727,6 +750,7 @@ void managedIdentityTest_RetriesARePerRequest() throws Exception {
727750
miApp = ManagedIdentityApplication
728751
.builder(ManagedIdentityId.systemAssigned())
729752
.httpClient(httpClientMock)
753+
.correlationId(TEST_CORRELATION_ID)
730754
.build();
731755

732756
// First call returns 500, subsequent calls return 200

0 commit comments

Comments
 (0)