1414import org .junit .jupiter .params .provider .ValueSource ;
1515import org .mockito .junit .jupiter .MockitoExtension ;
1616
17+ import org .mockito .ArgumentCaptor ;
18+
1719import java .net .SocketException ;
1820import java .nio .file .Path ;
1921import java .nio .file .Paths ;
3739@ TestInstance (TestInstance .Lifecycle .PER_METHOD )
3840class ManagedIdentityTests {
3941
42+ private static final String EXPECTED_SKU = "MSAL.Java" ;
43+ private static final String TEST_CORRELATION_ID = "00000000-0000-0000-0000-000000000001" ;
44+
4045 @ BeforeAll
4146 static void setupRetryPolicies () {
4247 // Set retry delays to 1ms for faster test execution
@@ -145,6 +150,9 @@ private String configureSourceSpecificParameters(ManagedIdentitySourceType sourc
145150 default :
146151 queryParameters .put ("api-version" , "2018-02-01" );
147152 headers .put ("Metadata" , "true" );
153+ headers .put ("x-client-SKU" , EXPECTED_SKU );
154+ headers .put ("x-client-VER" , HttpHeaders .PRODUCT_VERSION_HEADER_VALUE );
155+ headers .put ("x-ms-client-request-id" , TEST_CORRELATION_ID );
148156 return ManagedIdentityTestConstants .IMDS_ENDPOINT ;
149157 }
150158 }
@@ -206,6 +214,7 @@ void initManagedIdentityApplication(ManagedIdentityId idType) {
206214 miApp = ManagedIdentityApplication
207215 .builder (idType )
208216 .httpClient (httpClientMock )
217+ .correlationId (TEST_CORRELATION_ID )
209218 .build ();
210219
211220 // ManagedIdentityApplication uses a static token cache, avoid cross test pollution by clearing it
@@ -396,6 +405,7 @@ void managedIdentityTest_WithCapabilitiesOnly(ManagedIdentitySourceType source,
396405 .builder (ManagedIdentityId .systemAssigned ())
397406 .httpClient (httpClientMock )
398407 .clientCapabilities (singletonList ("cp1" ))
408+ .correlationId (TEST_CORRELATION_ID )
399409 .build ();
400410
401411 miApp .tokenCache .accessTokens .clear ();
@@ -425,6 +435,7 @@ void managedIdentity_ClaimsAndCapabilities(ManagedIdentitySourceType source, Str
425435 .builder (ManagedIdentityId .systemAssigned ())
426436 .clientCapabilities (singletonList ("cp1" ))
427437 .httpClient (httpClientMock )
438+ .correlationId (TEST_CORRELATION_ID )
428439 .build ();
429440
430441 // First call, get the token from the identity provider.
@@ -554,6 +565,7 @@ void managedIdentityTest_Retry(ManagedIdentitySourceType source, String endpoint
554565 miApp = ManagedIdentityApplication
555566 .builder (ManagedIdentityId .systemAssigned ())
556567 .httpClient (httpClientMock )
568+ .correlationId (TEST_CORRELATION_ID )
557569 .build ();
558570
559571 //Several specific 4xx and 5xx errors, such as 500, should trigger MSAL's retry logic
@@ -602,6 +614,7 @@ void managedIdentityTest_RetriesDisabled(ManagedIdentitySourceType source, Strin
602614 .builder (ManagedIdentityId .systemAssigned ())
603615 .httpClient (httpClientMock )
604616 .disableInternalRetries ()
617+ .correlationId (TEST_CORRELATION_ID )
605618 .build ();
606619
607620 //Several specific 4xx and 5xx errors, such as 500, should trigger MSAL's retry logic
@@ -631,6 +644,7 @@ void managedIdentityTest_IMDSRetry() throws Exception {
631644 miApp = ManagedIdentityApplication
632645 .builder (ManagedIdentityId .systemAssigned ())
633646 .httpClient (httpClientMock )
647+ .correlationId (TEST_CORRELATION_ID )
634648 .build ();
635649
636650 // IMDS has different retry logic for certain status codes, such as 410
@@ -672,6 +686,7 @@ void managedIdentityTest_RetrySucceedsAfterFailure() throws Exception {
672686 miApp = ManagedIdentityApplication
673687 .builder (ManagedIdentityId .systemAssigned ())
674688 .httpClient (httpClientMock )
689+ .correlationId (TEST_CORRELATION_ID )
675690 .build ();
676691
677692 // First call returns 500, subsequent calls return 200
@@ -685,7 +700,14 @@ void managedIdentityTest_RetrySucceedsAfterFailure() throws Exception {
685700 assertNotNull (result .accessToken ());
686701
687702 // Verify that the client was called exactly twice (first attempt + one retry)
688- verify (httpClientMock , times (2 )).send (any ());
703+ ArgumentCaptor <HttpRequest > captor = ArgumentCaptor .forClass (HttpRequest .class );
704+ verify (httpClientMock , times (2 )).send (captor .capture ());
705+
706+ // Verify IMDS client metadata headers on the captured request
707+ HttpRequest capturedRequest = captor .getValue ();
708+ assertEquals (EXPECTED_SKU , capturedRequest .headers ().get ("x-client-SKU" ));
709+ assertNotNull (capturedRequest .headers ().get ("x-client-VER" ));
710+ assertDoesNotThrow (() -> UUID .fromString (capturedRequest .headers ().get ("x-ms-client-request-id" )));
689711 }
690712
691713 @ Test
@@ -698,6 +720,7 @@ void managedIdentityTest_NonRetryableError() throws Exception {
698720 miApp = ManagedIdentityApplication
699721 .builder (ManagedIdentityId .systemAssigned ())
700722 .httpClient (httpClientMock )
723+ .correlationId (TEST_CORRELATION_ID )
701724 .build ();
702725
703726 // Use a status code that doesn't trigger retries (499)
@@ -727,6 +750,7 @@ void managedIdentityTest_RetriesARePerRequest() throws Exception {
727750 miApp = ManagedIdentityApplication
728751 .builder (ManagedIdentityId .systemAssigned ())
729752 .httpClient (httpClientMock )
753+ .correlationId (TEST_CORRELATION_ID )
730754 .build ();
731755
732756 // First call returns 500, subsequent calls return 200
0 commit comments