Skip to content

Commit bb21de7

Browse files
committed
fix: #330 address check-then-act race for access token using single-flight pattern in OAuth2Client
1 parent 5a7754e commit bb21de7

4 files changed

Lines changed: 126 additions & 17 deletions

File tree

src/main/java/dev/openfga/sdk/api/auth/OAuth2Client.java

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
import java.time.Instant;
1212
import java.util.HashMap;
1313
import java.util.Map;
14+
import java.util.concurrent.atomic.AtomicReference;
1415
import java.util.concurrent.CompletableFuture;
1516

1617
public class OAuth2Client {
1718
private static final String DEFAULT_API_TOKEN_ISSUER_PATH = "/oauth/token";
1819

1920
private final ApiClient apiClient;
20-
private final AccessToken token = new AccessToken();
21+
private final AtomicReference<TokenSnapshot> snapshot = new AtomicReference<>(TokenSnapshot.EMPTY);
22+
private final AtomicReference<CompletableFuture<String>> inFlight = new AtomicReference<>();
2123
private final CredentialsFlowRequest authRequest;
2224
private final Configuration config;
2325
private final Telemetry telemetry;
@@ -45,26 +47,54 @@ public OAuth2Client(Configuration configuration, ApiClient apiClient) throws Fga
4547
}
4648

4749
/**
48-
* Gets an access token, handling exchange when necessary. The access token is naively cached in memory until it
49-
* expires.
50+
* Gets an access token, handling exchange when necessary. The token is cached as an immutable
51+
* snapshot until it expires. Concurrent calls are deduplicated: only one exchange is in flight
52+
* at a time; other callers join the same future rather than issuing redundant requests.
5053
*
5154
* @return An access token in a {@link CompletableFuture}
5255
*/
5356
public CompletableFuture<String> getAccessToken() throws FgaInvalidParameterException, ApiException {
54-
if (!token.isValid()) {
55-
return exchangeToken().thenCompose(response -> {
56-
token.setToken(response.getAccessToken());
57-
token.setExpiresAt(Instant.now().plusSeconds(response.getExpiresInSeconds()));
58-
59-
Map<Attribute, String> attributesMap = new HashMap<>();
57+
TokenSnapshot current = snapshot.get();
58+
if (current.isValid()) {
59+
return CompletableFuture.completedFuture(current.token());
60+
}
6061

61-
telemetry.metrics().credentialsRequest(1L, attributesMap);
62+
CompletableFuture<String> promise = new CompletableFuture<>();
63+
if (!inFlight.compareAndSet(null, promise)) {
64+
// Another thread won the race — join its exchange rather than starting a new one.
65+
CompletableFuture<String> existing = inFlight.get();
66+
return existing != null ? existing : getAccessToken();
67+
}
6268

63-
return CompletableFuture.completedFuture(token.getToken());
69+
// This thread owns the exchange. Start it, wiring completion back to `promise`.
70+
try {
71+
exchangeToken().whenComplete((response, ex) -> {
72+
if (ex != null) {
73+
inFlight.set(null);
74+
promise.completeExceptionally(ex);
75+
} else {
76+
String token = response.getAccessToken();
77+
// Write snapshot before clearing the gate so any new caller that arrives
78+
// after inFlight becomes null immediately sees a valid token.
79+
snapshot.set(
80+
new TokenSnapshot(
81+
token,
82+
Instant.now().plusSeconds(response.getExpiresInSeconds())));
83+
84+
telemetry.metrics().credentialsRequest(1L, new HashMap<>());
85+
86+
// Clear before completing
87+
inFlight.set(null);
88+
promise.complete(token);
89+
}
6490
});
91+
} catch (Exception e) {
92+
inFlight.set(null);
93+
promise.completeExceptionally(e);
94+
throw e;
6595
}
6696

67-
return CompletableFuture.completedFuture(token.getToken());
97+
return promise;
6898
}
6999

70100
/**
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package dev.openfga.sdk.api.auth;
2+
3+
import static dev.openfga.sdk.util.StringUtil.isNullOrWhitespace;
4+
5+
import dev.openfga.sdk.constants.FgaConstants;
6+
import java.time.Instant;
7+
import java.time.temporal.ChronoUnit;
8+
import java.util.concurrent.ThreadLocalRandom;
9+
10+
/**
11+
* Immutable snapshot of an access token and its expiry time. The snapshot is valid if the token is non-empty
12+
* and the current time is before the expiry time minus a buffer to ensure that callers receive a valid token
13+
* even if there is some clock skew or delay between retrieval and use.
14+
*/
15+
record TokenSnapshot(String token, Instant expiresAt) {
16+
private static final int EXPIRY_BUFFER_SECS = FgaConstants.TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC;
17+
private static final int EXPIRY_JITTER_SECS = FgaConstants.TOKEN_EXPIRY_JITTER_IN_SEC;
18+
19+
static final TokenSnapshot EMPTY = new TokenSnapshot(null, null);
20+
21+
TokenSnapshot {
22+
expiresAt = expiresAt != null ? expiresAt.truncatedTo(ChronoUnit.SECONDS) : null;
23+
}
24+
25+
boolean isValid() {
26+
if (isNullOrWhitespace(token)) {
27+
return false;
28+
}
29+
if (expiresAt == null) {
30+
return true;
31+
}
32+
Instant expiresWithLeeway = expiresAt
33+
.minusSeconds(EXPIRY_BUFFER_SECS)
34+
.minusSeconds(ThreadLocalRandom.current().nextInt(EXPIRY_JITTER_SECS))
35+
.truncatedTo(ChronoUnit.SECONDS);
36+
37+
return Instant.now().truncatedTo(ChronoUnit.SECONDS).isBefore(expiresWithLeeway);
38+
}
39+
}

src/test/java/dev/openfga/sdk/api/auth/AccessTokenTest.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,8 @@ private static Stream<Arguments> expTimeAndResults() {
3838

3939
@MethodSource("expTimeAndResults")
4040
@ParameterizedTest(name = "{0}")
41-
public void testTokenValid(String name, Instant exp, boolean valid) {
42-
AccessToken accessToken = new AccessToken();
43-
accessToken.setToken("token");
44-
accessToken.setExpiresAt(exp);
45-
assertEquals(valid, accessToken.isValid());
41+
void testTokenValid(String name, Instant exp, boolean valid) {
42+
TokenSnapshot snapshot = new TokenSnapshot("token", exp);
43+
assertEquals(valid, snapshot.isValid());
4644
}
4745
}

src/test/java/dev/openfga/sdk/api/auth/OAuth2ClientTest.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
import java.net.URLEncoder;
1919
import java.nio.charset.StandardCharsets;
2020
import java.time.Duration;
21+
import java.util.ArrayList;
22+
import java.util.Collections;
23+
import java.util.List;
24+
import java.util.concurrent.CountDownLatch;
25+
import java.util.concurrent.TimeUnit;
2126
import java.util.stream.Stream;
2227
import org.junit.jupiter.api.Test;
2328
import org.junit.jupiter.params.ParameterizedTest;
@@ -166,6 +171,43 @@ public void exchangeOAuth2TokenWithRetriesFailure(WireMockRuntimeInfo wm) throws
166171
verify(3, postRequestedFor(urlEqualTo("/oauth/token")));
167172
}
168173

174+
@Test
175+
void exchangeOAuth2Token_concurrentRequests_singleExchange(WireMockRuntimeInfo wm) throws Exception {
176+
// Stub with a delay so concurrent threads pile up before the first exchange completes.
177+
stubFor(post(urlEqualTo("/oauth/token"))
178+
.willReturn(ok(String.format("{\"access_token\":\"%s\",\"expires_in\":3600}", ACCESS_TOKEN))
179+
.withFixedDelay(100)));
180+
181+
OAuth2Client client = newOAuth2Client(wm.getHttpBaseUrl(), false);
182+
183+
int threadCount = 5;
184+
CountDownLatch startGate = new CountDownLatch(1);
185+
CountDownLatch done = new CountDownLatch(threadCount);
186+
List<String> tokens = Collections.synchronizedList(new ArrayList<>());
187+
List<Throwable> failures = Collections.synchronizedList(new ArrayList<>());
188+
189+
for (int i = 0; i < threadCount; i++) {
190+
new Thread(() -> {
191+
try {
192+
startGate.await();
193+
tokens.add(client.getAccessToken().get());
194+
} catch (Exception e) {
195+
failures.add(e);
196+
} finally {
197+
done.countDown();
198+
}
199+
}).start();
200+
}
201+
202+
startGate.countDown();
203+
assertTrue(done.await(3, TimeUnit.SECONDS), "threads did not complete in time");
204+
205+
assertEquals(List.of(), failures, "no thread should have thrown");
206+
assertEquals(threadCount, tokens.size(), "all threads should have received a token");
207+
assertTrue(tokens.stream().allMatch(ACCESS_TOKEN::equals), "all threads should have received the same token");
208+
verify(1, postRequestedFor(urlEqualTo("/oauth/token")));
209+
}
210+
169211
@Test
170212
public void apiTokenIssuer_invalidScheme() {
171213
// When

0 commit comments

Comments
 (0)