Skip to content

Commit b09ca16

Browse files
authored
fix: attach Authorization header to streaming and ApiExecutor requests (#330) (#331)
* fix: Issue #330 added ApiClient#applyAuthHeader as a single entry point for attaching OAuth2 Authorization header. Includes lazy OAuth2Client initialization on detected CLIENT_CREDENTIALS use. * fix: #330 use setHeader on request builder to address duplicate Authorization header * test: #330 improve coverage on ApiClient changeset * fix: #330 defensively pivot to concurrent map of OAuth2Client to protect multi-tenant credentials * fix: #330 do not keep raw client secret value in cached keys on heap * fix: #330 address check-then-act race for access token using single-flight pattern in OAuth2Client * chore: lint and formatting compliance * fix: #330 address reviewer feedback for managing ApiClient#applyAuthHeader exception propogation * fix: #330 consolidate existing AccessToken rather than creating a forked TokenSnapshot Also improve stream API executor tests per reviewer comments. * fix: #330 address reviewer input to reorder telemetry call for stability during OAuth2 exchange handling * fix: #330 address possible redundant token exchange with refactor to strict mutex; includes regression test
1 parent 2565cc8 commit b09ca16

12 files changed

Lines changed: 723 additions & 90 deletions

src/main/java/dev/openfga/sdk/api/BaseStreamingApi.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,8 @@ protected HttpRequest buildHttpRequest(String method, String path, Object body,
168168
byte[] bodyBytes = objectMapper.writeValueAsBytes(body);
169169
HttpRequest.Builder requestBuilder = ApiClient.requestBuilder(method, path, bodyBytes, configuration);
170170

171+
apiClient.applyAuthHeader(requestBuilder, configuration);
172+
171173
// Apply request interceptors if any
172174
var interceptor = apiClient.getRequestInterceptor();
173175
if (interceptor != null) {

src/main/java/dev/openfga/sdk/api/OpenFgaApi.java

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import static dev.openfga.sdk.util.StringUtil.isNullOrWhitespace;
1616
import static dev.openfga.sdk.util.Validation.assertParamExists;
1717

18-
import dev.openfga.sdk.api.auth.*;
1918
import dev.openfga.sdk.api.client.*;
2019
import dev.openfga.sdk.api.configuration.*;
2120
import dev.openfga.sdk.api.model.BatchCheckRequest;
@@ -69,7 +68,6 @@ public class OpenFgaApi {
6968
private final Configuration configuration;
7069

7170
private final ApiClient apiClient;
72-
private final OAuth2Client oAuth2Client;
7371
private final Telemetry telemetry;
7472

7573
public OpenFgaApi(Configuration configuration) throws FgaInvalidParameterException {
@@ -89,12 +87,6 @@ public OpenFgaApi(Configuration configuration, ApiClient apiClient, Telemetry te
8987
this.configuration = configuration;
9088
this.telemetry = telemetry;
9189

92-
if (configuration.getCredentials().getCredentialsMethod() == CredentialsMethod.CLIENT_CREDENTIALS) {
93-
this.oAuth2Client = new OAuth2Client(configuration, apiClient);
94-
} else {
95-
this.oAuth2Client = null;
96-
}
97-
9890
var defaultHeaders = configuration.getDefaultHeaders();
9991
if (defaultHeaders != null) {
10092
apiClient.addRequestInterceptor(httpRequest -> defaultHeaders.forEach(httpRequest::setHeader));
@@ -1294,10 +1286,7 @@ private HttpRequest buildHttpRequestWithPublisher(
12941286
httpRequest.header("Content-Type", "application/json");
12951287
httpRequest.header("Accept", "application/json");
12961288

1297-
if (configuration.getCredentials().getCredentialsMethod() != CredentialsMethod.NONE) {
1298-
String accessToken = getAccessToken(configuration);
1299-
httpRequest.header("Authorization", "Bearer " + accessToken);
1300-
}
1289+
apiClient.applyAuthHeader(httpRequest, configuration);
13011290

13021291
if (configuration.getUserAgent() != null) {
13031292
httpRequest.header("User-Agent", configuration.getUserAgent());
@@ -1337,29 +1326,4 @@ private String pathWithParams(String basePath, Object... params) {
13371326
}
13381327
return path.toString();
13391328
}
1340-
1341-
/**
1342-
* Get an access token. Expects that configuration is valid (meaning it can
1343-
* pass {@link Configuration#assertValid()}) and expects that if the
1344-
* CredentialsMethod is CLIENT_CREDENTIALS that a valid {@link OAuth2Client}
1345-
* has been initialized. Otherwise, it will throw an IllegalStateException.
1346-
* @throws IllegalStateException when the configuration is invalid
1347-
*/
1348-
private String getAccessToken(Configuration configuration) throws ApiException {
1349-
CredentialsMethod credentialsMethod = configuration.getCredentials().getCredentialsMethod();
1350-
1351-
if (credentialsMethod == CredentialsMethod.API_TOKEN) {
1352-
return configuration.getCredentials().getApiToken().getToken();
1353-
}
1354-
1355-
if (credentialsMethod == CredentialsMethod.CLIENT_CREDENTIALS) {
1356-
try {
1357-
return oAuth2Client.getAccessToken().get();
1358-
} catch (Exception e) {
1359-
throw new ApiException(e);
1360-
}
1361-
}
1362-
1363-
throw new IllegalStateException("Configuration is invalid.");
1364-
}
13651329
}

src/main/java/dev/openfga/sdk/api/auth/AccessToken.java

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,25 @@
55
import dev.openfga.sdk.constants.FgaConstants;
66
import java.time.Instant;
77
import java.time.temporal.ChronoUnit;
8-
import java.util.Random;
9-
10-
class AccessToken {
8+
import java.util.concurrent.ThreadLocalRandom;
9+
10+
/**
11+
* Immutable snapshot of an access token and its expiry time. The snapshot is valid if the token is non-empty
12+
* and the current time is before the expiry time minus a buffer to ensure that callers receive a valid token
13+
* even if there is some clock skew or delay between retrieval and use.
14+
*/
15+
record AccessToken(String token, Instant expiresAt) {
1116
private static final int TOKEN_EXPIRY_BUFFER_THRESHOLD_IN_SEC = FgaConstants.TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC;
1217
// We add some jitter so that token refreshes are less likely to collide
1318
private static final int TOKEN_EXPIRY_JITTER_IN_SEC = FgaConstants.TOKEN_EXPIRY_JITTER_IN_SEC;
1419

15-
private Instant expiresAt;
20+
static final AccessToken EMPTY = new AccessToken(null, null);
1621

17-
private final Random random = new Random();
18-
private String token;
22+
AccessToken {
23+
expiresAt = expiresAt != null ? expiresAt.truncatedTo(ChronoUnit.SECONDS) : null;
24+
}
1925

20-
public boolean isValid() {
26+
boolean isValid() {
2127
if (isNullOrWhitespace(token)) {
2228
return false;
2329
}
@@ -31,24 +37,9 @@ public boolean isValid() {
3137
// to account for multiple calls to `isValid` at the same time and prevent multiple refresh calls
3238
Instant expiresWithLeeway = expiresAt
3339
.minusSeconds(TOKEN_EXPIRY_BUFFER_THRESHOLD_IN_SEC)
34-
.minusSeconds(random.nextInt(TOKEN_EXPIRY_JITTER_IN_SEC))
40+
.minusSeconds(ThreadLocalRandom.current().nextInt(TOKEN_EXPIRY_JITTER_IN_SEC))
3541
.truncatedTo(ChronoUnit.SECONDS);
3642

3743
return Instant.now().truncatedTo(ChronoUnit.SECONDS).isBefore(expiresWithLeeway);
3844
}
39-
40-
public String getToken() {
41-
return token;
42-
}
43-
44-
public void setExpiresAt(Instant expiresAt) {
45-
if (expiresAt != null) {
46-
// Truncate to seconds to zero out the milliseconds to keep comparison simpler
47-
this.expiresAt = expiresAt.truncatedTo(ChronoUnit.SECONDS);
48-
}
49-
}
50-
51-
public void setToken(String token) {
52-
this.token = token;
53-
}
5445
}

src/main/java/dev/openfga/sdk/api/auth/OAuth2Client.java

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,20 @@
44
import dev.openfga.sdk.api.configuration.*;
55
import dev.openfga.sdk.errors.ApiException;
66
import dev.openfga.sdk.errors.FgaInvalidParameterException;
7-
import dev.openfga.sdk.telemetry.Attribute;
87
import dev.openfga.sdk.telemetry.Telemetry;
98
import java.net.URI;
109
import java.net.http.HttpRequest;
1110
import java.time.Instant;
1211
import java.util.HashMap;
13-
import java.util.Map;
1412
import java.util.concurrent.CompletableFuture;
13+
import java.util.concurrent.atomic.AtomicReference;
1514

1615
public class OAuth2Client {
1716
private static final String DEFAULT_API_TOKEN_ISSUER_PATH = "/oauth/token";
1817

1918
private final ApiClient apiClient;
20-
private final AccessToken token = new AccessToken();
19+
private final AtomicReference<AccessToken> snapshot = new AtomicReference<>(AccessToken.EMPTY);
20+
private final AtomicReference<CompletableFuture<String>> inFlight = new AtomicReference<>();
2121
private final CredentialsFlowRequest authRequest;
2222
private final Configuration config;
2323
private final Telemetry telemetry;
@@ -45,26 +45,60 @@ public OAuth2Client(Configuration configuration, ApiClient apiClient) throws Fga
4545
}
4646

4747
/**
48-
* Gets an access token, handling exchange when necessary. The access token is naively cached in memory until it
49-
* expires.
48+
* Gets an access token, handling exchange when necessary. The token is cached as an immutable
49+
* snapshot until it expires. Concurrent calls are deduplicated: only one exchange is in flight
50+
* at a time; other callers join the same future rather than issuing redundant requests.
5051
*
5152
* @return An access token in a {@link CompletableFuture}
5253
*/
5354
public CompletableFuture<String> getAccessToken() throws FgaInvalidParameterException, ApiException {
54-
if (!token.isValid()) {
55-
return exchangeToken().thenCompose(response -> {
56-
token.setToken(response.getAccessToken());
57-
token.setExpiresAt(Instant.now().plusSeconds(response.getExpiresInSeconds()));
58-
59-
Map<Attribute, String> attributesMap = new HashMap<>();
60-
61-
telemetry.metrics().credentialsRequest(1L, attributesMap);
62-
63-
return CompletableFuture.completedFuture(token.getToken());
64-
});
55+
// Fast path (lock-free): return cached token if still valid.
56+
AccessToken current = snapshot.get();
57+
if (current.isValid()) {
58+
return CompletableFuture.completedFuture(current.token());
6559
}
6660

67-
return CompletableFuture.completedFuture(token.getToken());
61+
// Slow path: decide under the lock who starts the exchange.
62+
synchronized (this) {
63+
// Double-check: another thread may have refreshed while we waited.
64+
AccessToken rechecked = snapshot.get();
65+
if (rechecked.isValid()) {
66+
return CompletableFuture.completedFuture(rechecked.token());
67+
}
68+
69+
// Join an existing in-flight exchange.
70+
CompletableFuture<String> existing = inFlight.get();
71+
if (existing != null) {
72+
return existing;
73+
}
74+
75+
// Start a new exchange and publish the future so other callers join it.
76+
CompletableFuture<String> promise = new CompletableFuture<>();
77+
inFlight.set(promise);
78+
79+
try {
80+
exchangeToken().whenComplete((response, ex) -> {
81+
if (ex != null) {
82+
inFlight.set(null);
83+
promise.completeExceptionally(ex);
84+
} else {
85+
String token = response.getAccessToken();
86+
// Write snapshot before clearing the gate so any new caller that arrives
87+
// after inFlight becomes null immediately sees a valid token.
88+
snapshot.set(new AccessToken(token, Instant.now().plusSeconds(response.getExpiresInSeconds())));
89+
inFlight.set(null);
90+
promise.complete(token);
91+
telemetry.metrics().credentialsRequest(1L, new HashMap<>());
92+
}
93+
});
94+
} catch (Exception e) {
95+
inFlight.set(null);
96+
promise.completeExceptionally(e);
97+
throw e;
98+
}
99+
100+
return promise;
101+
}
68102
}
69103

70104
/**

src/main/java/dev/openfga/sdk/api/client/ApiClient.java

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
import com.fasterxml.jackson.databind.ObjectMapper;
88
import com.fasterxml.jackson.databind.SerializationFeature;
99
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
10+
import dev.openfga.sdk.api.auth.OAuth2Client;
11+
import dev.openfga.sdk.api.configuration.ClientCredentials;
1012
import dev.openfga.sdk.api.configuration.Configuration;
13+
import dev.openfga.sdk.api.configuration.Credentials;
14+
import dev.openfga.sdk.api.configuration.CredentialsMethod;
15+
import dev.openfga.sdk.errors.ApiException;
1116
import dev.openfga.sdk.errors.FgaInvalidParameterException;
1217
import dev.openfga.sdk.util.StringUtil;
1318
import java.io.InputStream;
@@ -16,7 +21,14 @@
1621
import java.net.http.HttpClient;
1722
import java.net.http.HttpRequest;
1823
import java.net.http.HttpResponse;
24+
import java.security.MessageDigest;
25+
import java.security.NoSuchAlgorithmException;
1926
import java.time.Duration;
27+
import java.util.Arrays;
28+
import java.util.Objects;
29+
import java.util.concurrent.ConcurrentHashMap;
30+
import java.util.concurrent.ConcurrentMap;
31+
import java.util.concurrent.ExecutionException;
2032
import java.util.function.Consumer;
2133
import org.openapitools.jackson.nullable.JsonNullableModule;
2234

@@ -41,6 +53,7 @@ public class ApiClient {
4153
private Consumer<HttpRequest.Builder> interceptor;
4254
private Consumer<HttpResponse<InputStream>> responseInterceptor;
4355
private Consumer<HttpResponse<String>> asyncResponseInterceptor;
56+
private final ConcurrentMap<CredentialsCacheKey, OAuth2Client> oAuth2Clients = new ConcurrentHashMap<>();
4457

4558
/**
4659
* Create an instance of ApiClient.
@@ -324,4 +337,118 @@ public ApiClient setAsyncResponseInterceptor(Consumer<HttpResponse<String>> inte
324337
public Consumer<HttpResponse<String>> getAsyncResponseInterceptor() {
325338
return asyncResponseInterceptor;
326339
}
340+
341+
/**
342+
* Applies the {@code Authorization: Bearer <token>} header to the request builder based on the
343+
* supplied configuration's {@link Credentials}. This is the single entry point for attaching
344+
* auth to outbound requests across the SDK — every request builder should delegate here.
345+
*
346+
* <ul>
347+
* <li>{@link CredentialsMethod#NONE}: no header is applied.</li>
348+
* <li>{@link CredentialsMethod#API_TOKEN}: the static API token from the configuration is used.</li>
349+
* <li>{@link CredentialsMethod#CLIENT_CREDENTIALS}: an {@link OAuth2Client} performs the
350+
* client-credentials exchange and caches the token on this {@code ApiClient} until expiry.
351+
* The client is lazily created from {@code configuration} on first use.</li>
352+
* </ul>
353+
*
354+
* @param requestBuilder the request builder to mutate.
355+
* @param configuration the configuration that supplies credentials.
356+
* @throws ApiException if CLIENT_CREDENTIALS token exchange fails.
357+
* @throws FgaInvalidParameterException if the configuration is invalid when lazily creating
358+
* an {@link OAuth2Client}.
359+
*/
360+
public void applyAuthHeader(HttpRequest.Builder requestBuilder, Configuration configuration)
361+
throws ApiException, FgaInvalidParameterException {
362+
363+
Credentials credentials = configuration.getCredentials();
364+
if (credentials == null) {
365+
return;
366+
}
367+
368+
CredentialsMethod method = credentials.getCredentialsMethod();
369+
if (method == null || method == CredentialsMethod.NONE) {
370+
return;
371+
}
372+
373+
String accessToken;
374+
switch (method) {
375+
case API_TOKEN:
376+
accessToken = credentials.getApiToken().getToken();
377+
break;
378+
case CLIENT_CREDENTIALS:
379+
try {
380+
accessToken =
381+
ensureOAuth2Client(configuration).getAccessToken().get();
382+
} catch (InterruptedException e) {
383+
Thread.currentThread().interrupt();
384+
throw new ApiException(e);
385+
} catch (ExecutionException e) {
386+
Throwable cause = e.getCause();
387+
if (cause instanceof ApiException) {
388+
throw (ApiException) cause;
389+
}
390+
throw new ApiException(cause != null ? cause : e);
391+
}
392+
break;
393+
default:
394+
throw new IllegalStateException("Unknown credentials method: " + method);
395+
}
396+
397+
requestBuilder.setHeader("Authorization", "Bearer " + accessToken);
398+
}
399+
400+
private OAuth2Client ensureOAuth2Client(Configuration configuration) throws FgaInvalidParameterException {
401+
ClientCredentials cc = configuration.getCredentials().getClientCredentials();
402+
CredentialsCacheKey key = new CredentialsCacheKey(cc);
403+
OAuth2Client existing = oAuth2Clients.get(key);
404+
if (existing != null) {
405+
return existing;
406+
}
407+
OAuth2Client created = new OAuth2Client(configuration, this);
408+
OAuth2Client prior = oAuth2Clients.putIfAbsent(key, created);
409+
return prior != null ? prior : created;
410+
}
411+
412+
private static final class CredentialsCacheKey {
413+
private final String clientId;
414+
private final byte[] clientSecretHash;
415+
private final String apiTokenIssuer;
416+
private final String apiAudience;
417+
private final String scopes;
418+
419+
CredentialsCacheKey(ClientCredentials cc) {
420+
this.clientId = cc.getClientId();
421+
this.clientSecretHash = sha256(cc.getClientSecret());
422+
this.apiTokenIssuer = cc.getApiTokenIssuer();
423+
this.apiAudience = cc.getApiAudience();
424+
this.scopes = cc.getScopes();
425+
}
426+
427+
private static byte[] sha256(String value) {
428+
try {
429+
return MessageDigest.getInstance("SHA-256").digest(value == null ? new byte[0] : value.getBytes(UTF_8));
430+
} catch (NoSuchAlgorithmException e) {
431+
throw new IllegalStateException("SHA-256 not available", e);
432+
}
433+
}
434+
435+
@Override
436+
public boolean equals(Object o) {
437+
if (this == o) return true;
438+
if (!(o instanceof CredentialsCacheKey)) return false;
439+
CredentialsCacheKey that = (CredentialsCacheKey) o;
440+
return Objects.equals(clientId, that.clientId)
441+
&& Arrays.equals(clientSecretHash, that.clientSecretHash)
442+
&& Objects.equals(apiTokenIssuer, that.apiTokenIssuer)
443+
&& Objects.equals(apiAudience, that.apiAudience)
444+
&& Objects.equals(scopes, that.scopes);
445+
}
446+
447+
@Override
448+
public int hashCode() {
449+
int result = Objects.hash(clientId, apiTokenIssuer, apiAudience, scopes);
450+
result = 31 * result + Arrays.hashCode(clientSecretHash);
451+
return result;
452+
}
453+
}
327454
}

0 commit comments

Comments
 (0)