|
13 | 13 | import org.springframework.http.ResponseEntity; |
14 | 14 | import org.springframework.web.client.HttpClientErrorException; |
15 | 15 | import org.springframework.web.client.RestTemplate; |
| 16 | +import java.lang.reflect.Field; |
16 | 17 | import java.time.Instant; |
17 | 18 | import java.util.HashMap; |
18 | 19 | import java.util.Map; |
| 20 | +import java.util.concurrent.ExecutorService; |
| 21 | +import java.util.concurrent.Executors; |
| 22 | +import java.util.concurrent.Future; |
| 23 | +import static java.util.concurrent.TimeUnit.MILLISECONDS; |
| 24 | +import static java.util.concurrent.TimeUnit.SECONDS; |
| 25 | +import static org.awaitility.Awaitility.await; |
19 | 26 | import static org.junit.jupiter.api.Assertions.assertEquals; |
20 | 27 | import static org.junit.jupiter.api.Assertions.assertNotNull; |
21 | 28 | import static org.junit.jupiter.api.Assertions.assertThrows; |
22 | 29 | import static org.junit.jupiter.api.Assertions.assertTrue; |
23 | | -import static org.junit.jupiter.api.Assertions.fail; |
24 | 30 | import static org.mockito.ArgumentMatchers.any; |
25 | 31 | import static org.mockito.ArgumentMatchers.anyString; |
26 | 32 | import static org.mockito.ArgumentMatchers.eq; |
27 | 33 | import static org.mockito.Mockito.mock; |
28 | | -import static org.mockito.Mockito.never; |
29 | 34 | import static org.mockito.Mockito.spy; |
30 | 35 | import static org.mockito.Mockito.times; |
31 | 36 | import static org.mockito.Mockito.verify; |
@@ -97,22 +102,37 @@ void testTooManyRequestsBacksOffAndRetries() { |
97 | 102 | assertTrue(authClient.getExpireTime().isAfter(Instant.now())); |
98 | 103 | } |
99 | 104 |
|
| 105 | + |
100 | 106 | @Test |
101 | | - void testRefreshToken_skipsIfTokenIsValidInitially() { |
102 | | - CrowdStrikeAuthClient client = spy(new CrowdStrikeAuthClient(mockSourceConfig) { |
103 | | - @Override |
104 | | - protected boolean isTokenValid() { |
105 | | - return true; |
106 | | - } |
107 | | - |
108 | | - @Override |
109 | | - protected void getAuthToken() { |
110 | | - fail("getAuthToken should not be called when token is already valid."); |
111 | | - } |
112 | | - }); |
113 | | - |
114 | | - client.refreshToken(); |
115 | | - verify(client, times(1)).isTokenValid(); |
116 | | - verify(client, never()).getAuthToken(); |
| 107 | + void testConcurrentRefreshToken_onlyOneApiCall() throws Exception { |
| 108 | + CrowdStrikeAuthClient client = spy(new CrowdStrikeAuthClient(mockSourceConfig)); |
| 109 | + Field restTemplateField = CrowdStrikeAuthClient.class.getDeclaredField("restTemplate"); |
| 110 | + restTemplateField.setAccessible(true); |
| 111 | + restTemplateField.set(client, restTemplateMock); |
| 112 | + when(restTemplateMock.postForEntity(anyString(), any(HttpEntity.class), eq(Map.class))) |
| 113 | + .thenReturn(ResponseEntity.ok(Map.of( |
| 114 | + "access_token", "mock_access_token", |
| 115 | + "expires_in", 3600 |
| 116 | + ))); |
| 117 | + |
| 118 | + // Launch two parallel refreshToken() calls |
| 119 | + ExecutorService executor = Executors.newFixedThreadPool(2); |
| 120 | + Future<?> firstCall = executor.submit(client::refreshToken); |
| 121 | + Future<?> secondCall = executor.submit(client::refreshToken); |
| 122 | + |
| 123 | + await() |
| 124 | + .atMost(10, SECONDS) |
| 125 | + .pollInterval(10, MILLISECONDS) |
| 126 | + .until(() -> firstCall.isDone() && secondCall.isDone()); |
| 127 | + |
| 128 | + executor.shutdown(); |
| 129 | + |
| 130 | + // Validate only 1 token request is made |
| 131 | + assertNotNull(client.getBearerToken()); |
| 132 | + assertEquals("mock_access_token", client.getBearerToken()); |
| 133 | + assertNotNull(client.getExpireTime()); |
| 134 | + assertTrue(client.getExpireTime().isAfter(Instant.now().minusSeconds(3500))); |
| 135 | + |
| 136 | + verify(restTemplateMock, times(1)).postForEntity(anyString(), any(HttpEntity.class), eq(Map.class)); |
117 | 137 | } |
118 | 138 | } |
0 commit comments