Skip to content

Commit 42a863b

Browse files
committed
chore: experimental TLA+ proof
1 parent f879fde commit 42a863b

7 files changed

Lines changed: 399 additions & 8 deletions

File tree

src/main/java/dev/openfga/sdk/api/auth/OAuth2Client.java

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,16 @@ public class OAuth2Client {
2222
private final Configuration config;
2323
private final Telemetry telemetry;
2424

25+
/**
26+
* Test-only seam: {@code beforeAcquireHook} is invoked on the cold path after the lock-free
27+
* snapshot check fails and before entering the synchronized acquisition gate. It defaults to
28+
* {@code NO_OP_HOOK}. Used by tests to deterministically interleave threads around the
29+
* post-exchange race window. Not part of the public API.
30+
*/
31+
static final Runnable NO_OP_HOOK = () -> {};
32+
33+
volatile Runnable beforeAcquireHook = NO_OP_HOOK;
34+
2535
/**
2636
* Initializes a new instance of the {@link OAuth2Client} class
2737
*
@@ -49,24 +59,55 @@ public OAuth2Client(Configuration configuration, ApiClient apiClient) throws Fga
4959
* snapshot until it expires. Concurrent calls are deduplicated: only one exchange is in flight
5060
* at a time; other callers join the same future rather than issuing redundant requests.
5161
*
62+
* <p>The hot path (valid cached token) is lock-free. The cold path serializes the
63+
* "is snapshot valid? / is there an in-flight exchange? / publish my promise" decision
64+
* under a monitor so that:
65+
* <ul>
66+
* <li>at most one exchange is in flight at a time (a failed exchange clears the gate, so the next caller starts a new one),</li>
67+
* <li>joiners always observe the same in-flight promise that the owner will complete.</li>
68+
* </ul>
69+
* The exchange itself (the HTTP round-trip) runs asynchronously outside the monitor.
70+
*
5271
* @return An access token in a {@link CompletableFuture}
5372
*/
5473
public CompletableFuture<String> getAccessToken() throws FgaInvalidParameterException, ApiException {
74+
// Lock-free hot path: a valid cached token short-circuits everything.
5575
AccessToken current = snapshot.get();
5676
if (current.isValid()) {
5777
return CompletableFuture.completedFuture(current.token());
5878
}
79+
// Cold-path test seam (no-op in production).
80+
beforeAcquireHook.run();
81+
return acquireToken();
82+
}
5983

60-
CompletableFuture<String> promise = new CompletableFuture<>();
61-
if (!inFlight.compareAndSet(null, promise)) {
62-
// Another thread won the race — join its exchange rather than starting a new one.
63-
CompletableFuture<String> existing = inFlight.get();
64-
return existing != null ? existing : getAccessToken();
84+
/**
85+
* Cold path: snapshot is missing or expired. Serialized to guarantee a single in-flight
86+
* exchange and to avoid the join-vs-clear race that an atomic-CAS-only approach is prone to.
87+
*/
88+
private synchronized CompletableFuture<String> acquireToken()
89+
throws FgaInvalidParameterException, ApiException {
90+
// Re-check under the monitor: another thread may have just refreshed the snapshot.
91+
AccessToken current = snapshot.get();
92+
if (current.isValid()) {
93+
return CompletableFuture.completedFuture(current.token());
94+
}
95+
96+
// Join an existing exchange if one is already in flight.
97+
CompletableFuture<String> existing = inFlight.get();
98+
if (existing != null) {
99+
return existing;
65100
}
66101

67-
// This thread owns the exchange. Start it, wiring completion back to `promise`.
102+
// This thread owns the exchange. Publish the promise so concurrent callers can join.
103+
CompletableFuture<String> promise = new CompletableFuture<>();
104+
inFlight.set(promise);
105+
68106
try {
69107
exchangeToken().whenComplete((response, ex) -> {
108+
// Completion usually runs asynchronously, outside the monitor; if the future is already
109+
// complete it runs inline here instead. Either is fine: snapshot, inFlight and promise are
110+
// each thread-safe, and the monitor only guards the decision to *start* an exchange.
70111
if (ex != null) {
71112
inFlight.set(null);
72113
promise.completeExceptionally(ex);
@@ -75,14 +116,13 @@ public CompletableFuture<String> getAccessToken() throws FgaInvalidParameterExce
75116
// Write snapshot before clearing the gate so any new caller that arrives
76117
// after inFlight becomes null immediately sees a valid token.
77118
snapshot.set(new AccessToken(token, Instant.now().plusSeconds(response.getExpiresInSeconds())));
78-
79-
// Clear before completing
80119
inFlight.set(null);
81120
promise.complete(token);
82121
telemetry.metrics().credentialsRequest(1L, new HashMap<>());
83122
}
84123
});
85124
} catch (Exception e) {
125+
// Synchronous failure to even dispatch the request: clear the gate and propagate.
86126
inFlight.set(null);
87127
promise.completeExceptionally(e);
88128
throw e;

src/test/java/dev/openfga/sdk/api/auth/OAuth2ClientTest.java

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.List;
2424
import java.util.concurrent.CountDownLatch;
2525
import java.util.concurrent.TimeUnit;
26+
import java.util.concurrent.atomic.AtomicReference;
2627
import java.util.stream.Stream;
2728
import org.junit.jupiter.api.Test;
2829
import org.junit.jupiter.params.ParameterizedTest;
@@ -209,6 +210,102 @@ void exchangeOAuth2Token_concurrentRequests_singleExchange(WireMockRuntimeInfo w
209210
verify(1, postRequestedFor(urlEqualTo("/oauth/token")));
210211
}
211212

213+
/**
214+
* After a successful exchange, subsequent calls must hit the cached snapshot and avoid the
215+
* network entirely. This guards the lock-free hot path in {@link OAuth2Client#getAccessToken()}.
216+
*/
217+
@Test
218+
void exchangeOAuth2Token_cachedAcrossCalls_noSecondRequest(WireMockRuntimeInfo wm) throws Exception {
219+
stubFor(post(urlEqualTo("/oauth/token"))
220+
.willReturn(ok(String.format("{\"access_token\":\"%s\",\"expires_in\":3600}", ACCESS_TOKEN))));
221+
222+
OAuth2Client client = newOAuth2Client(wm.getHttpBaseUrl(), false);
223+
224+
// Prime the cache.
225+
assertEquals(ACCESS_TOKEN, client.getAccessToken().get());
226+
227+
// Many subsequent calls should all be served from the snapshot.
228+
for (int i = 0; i < 20; i++) {
229+
assertEquals(ACCESS_TOKEN, client.getAccessToken().get());
230+
}
231+
232+
verify(1, postRequestedFor(urlEqualTo("/oauth/token")));
233+
}
234+
235+
/**
236+
* Deterministic regression test for the post-exchange race that the synchronized cold path
237+
* is meant to close.
238+
*
239+
* <p>Race being pinned: thread A reads an invalid snapshot, parks; thread B completes a full
240+
* exchange (writes snapshot, clears the in-flight gate); thread A then reaches the
241+
* acquisition gate. With the original CAS-only logic, A would have started a redundant
242+
* second exchange. With the synchronized re-check, A must observe the freshly-written
243+
* snapshot and return the cached token without contacting the IdP.
244+
*
245+
* <p>Determinism is achieved via a package-private {@code beforeAcquireHook} test seam in
246+
* {@link OAuth2Client}: thread A is parked at the hook between its lock-free snapshot read
247+
* and the synchronized gate; thread B runs end-to-end with the hook disarmed; A is then
248+
* released. No sleeps, no thread-scheduling assumptions.
249+
*
250+
* <p>Asserts: exactly one HTTP exchange total, both threads receive the same token.
251+
*/
252+
@Test
253+
void exchangeOAuth2Token_postCompletionRace_noSecondExchange(WireMockRuntimeInfo wm) throws Exception {
254+
stubFor(post(urlEqualTo("/oauth/token"))
255+
.willReturn(ok(String.format("{\"access_token\":\"%s\",\"expires_in\":3600}", ACCESS_TOKEN))));
256+
257+
OAuth2Client client = newOAuth2Client(wm.getHttpBaseUrl(), false);
258+
259+
CountDownLatch threadAParked = new CountDownLatch(1);
260+
CountDownLatch releaseThreadA = new CountDownLatch(1);
261+
262+
// Arm the hook: the next caller (thread A) will park here right between its lock-free
263+
// snapshot check and entering the synchronized gate.
264+
client.beforeAcquireHook = () -> {
265+
threadAParked.countDown();
266+
try {
267+
if (!releaseThreadA.await(5, TimeUnit.SECONDS)) {
268+
throw new IllegalStateException("thread A was never released");
269+
}
270+
} catch (InterruptedException e) {
271+
Thread.currentThread().interrupt();
272+
throw new RuntimeException(e);
273+
}
274+
};
275+
276+
// Thread A: enters, sees invalid snapshot, parks at the hook.
277+
AtomicReference<String> tokenA = new AtomicReference<>();
278+
AtomicReference<Throwable> failureA = new AtomicReference<>();
279+
Thread a = new Thread(() -> {
280+
try {
281+
tokenA.set(client.getAccessToken().get(5, TimeUnit.SECONDS));
282+
} catch (Throwable t) {
283+
failureA.set(t);
284+
}
285+
}, "race-thread-A");
286+
a.start();
287+
288+
assertTrue(threadAParked.await(2, TimeUnit.SECONDS), "thread A never reached the hook");
289+
290+
// Disarm the hook so thread B (this thread) is *not* trapped, then perform a full
291+
// exchange end-to-end. After this returns: snapshot is valid, inFlight is null.
292+
client.beforeAcquireHook = OAuth2Client.NO_OP_HOOK;
293+
String tokenB = client.getAccessToken().get(5, TimeUnit.SECONDS);
294+
assertEquals(ACCESS_TOKEN, tokenB);
295+
verify(1, postRequestedFor(urlEqualTo("/oauth/token")));
296+
297+
// Now release A. It will enter acquireToken() in *exactly* the post-completion state
298+
// (snapshot valid, inFlight null) that the original CAS-only code mishandled.
299+
releaseThreadA.countDown();
300+
a.join(5_000);
301+
assertFalse(a.isAlive(), "thread A did not finish");
302+
assertNull(failureA.get(), "thread A threw");
303+
assertEquals(ACCESS_TOKEN, tokenA.get());
304+
305+
// The decisive assertion: A must NOT have triggered a second exchange.
306+
verify(1, postRequestedFor(urlEqualTo("/oauth/token")));
307+
}
308+
212309
@Test
213310
public void apiTokenIssuer_invalidScheme() {
214311
// When

tla/.gitignore

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# TLA+ tools jar — download on demand, see README
2+
tla2tools.jar
3+
4+
# TLC working directories
5+
states/
6+
*.old
7+
8+
# Default cfg auto-written by pcal.trans (we use named cfgs instead)
9+
OAuth2Client.cfg
10+

tla/OAuth2Client.tla

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
--------------------------- MODULE OAuth2Client ---------------------------
2+
EXTENDS Naturals, FiniteSets, TLC
3+
4+
CONSTANTS Threads, BUGGY
5+
6+
ASSUME Cardinality(Threads) >= 2
7+
8+
NONE == 0
9+
10+
(* --algorithm OAuth2Client {
11+
variables
12+
valid = FALSE,
13+
inFlight = NONE,
14+
monitor = NONE,
15+
exchanges = 0,
16+
tokens = [t \in Threads |-> 0];
17+
18+
define {
19+
AtMostOneExchange == exchanges <= 1
20+
AllDone == \A t \in Threads : tokens[t] = 1
21+
}
22+
23+
process (Thr \in Threads)
24+
variables joined = NONE;
25+
{
26+
HotPath:
27+
if (valid) {
28+
tokens[self] := 1;
29+
goto Finish;
30+
};
31+
32+
Acquire:
33+
await monitor = NONE;
34+
monitor := self;
35+
36+
UnderMonitor:
37+
if (~BUGGY /\ valid) {
38+
monitor := NONE;
39+
tokens[self] := 1;
40+
goto Finish;
41+
} else if (inFlight # NONE) {
42+
joined := inFlight;
43+
monitor := NONE;
44+
goto AwaitJoined;
45+
} else {
46+
inFlight := self;
47+
exchanges := exchanges + 1;
48+
monitor := NONE;
49+
goto DoExchange;
50+
};
51+
52+
DoExchange:
53+
valid := TRUE;
54+
inFlight := NONE;
55+
tokens[self] := 1;
56+
goto Finish;
57+
58+
AwaitJoined:
59+
await valid;
60+
tokens[self] := 1;
61+
goto Finish;
62+
63+
Finish: skip;
64+
}
65+
}
66+
*)
67+
\* BEGIN TRANSLATION
68+
VARIABLES valid, inFlight, monitor, exchanges, tokens, pc
69+
70+
(* define statement *)
71+
AtMostOneExchange == exchanges <= 1
72+
AllDone == \A t \in Threads : tokens[t] = 1
73+
74+
VARIABLE joined
75+
76+
vars == << valid, inFlight, monitor, exchanges, tokens, pc, joined >>
77+
78+
ProcSet == (Threads)
79+
80+
Init == (* Global variables *)
81+
/\ valid = FALSE
82+
/\ inFlight = NONE
83+
/\ monitor = NONE
84+
/\ exchanges = 0
85+
/\ tokens = [t \in Threads |-> 0]
86+
(* Process Thr *)
87+
/\ joined = [self \in Threads |-> NONE]
88+
/\ pc = [self \in ProcSet |-> "HotPath"]
89+
90+
HotPath(self) == /\ pc[self] = "HotPath"
91+
/\ IF valid
92+
THEN /\ tokens' = [tokens EXCEPT ![self] = 1]
93+
/\ pc' = [pc EXCEPT ![self] = "Finish"]
94+
ELSE /\ pc' = [pc EXCEPT ![self] = "Acquire"]
95+
/\ UNCHANGED tokens
96+
/\ UNCHANGED << valid, inFlight, monitor, exchanges, joined >>
97+
98+
Acquire(self) == /\ pc[self] = "Acquire"
99+
/\ monitor = NONE
100+
/\ monitor' = self
101+
/\ pc' = [pc EXCEPT ![self] = "UnderMonitor"]
102+
/\ UNCHANGED << valid, inFlight, exchanges, tokens, joined >>
103+
104+
UnderMonitor(self) == /\ pc[self] = "UnderMonitor"
105+
/\ IF ~BUGGY /\ valid
106+
THEN /\ monitor' = NONE
107+
/\ tokens' = [tokens EXCEPT ![self] = 1]
108+
/\ pc' = [pc EXCEPT ![self] = "Finish"]
109+
/\ UNCHANGED << inFlight, exchanges, joined >>
110+
ELSE /\ IF inFlight # NONE
111+
THEN /\ joined' = [joined EXCEPT ![self] = inFlight]
112+
/\ monitor' = NONE
113+
/\ pc' = [pc EXCEPT ![self] = "AwaitJoined"]
114+
/\ UNCHANGED << inFlight,
115+
exchanges >>
116+
ELSE /\ inFlight' = self
117+
/\ exchanges' = exchanges + 1
118+
/\ monitor' = NONE
119+
/\ pc' = [pc EXCEPT ![self] = "DoExchange"]
120+
/\ UNCHANGED joined
121+
/\ UNCHANGED tokens
122+
/\ valid' = valid
123+
124+
DoExchange(self) == /\ pc[self] = "DoExchange"
125+
/\ valid' = TRUE
126+
/\ inFlight' = NONE
127+
/\ tokens' = [tokens EXCEPT ![self] = 1]
128+
/\ pc' = [pc EXCEPT ![self] = "Finish"]
129+
/\ UNCHANGED << monitor, exchanges, joined >>
130+
131+
AwaitJoined(self) == /\ pc[self] = "AwaitJoined"
132+
/\ valid
133+
/\ tokens' = [tokens EXCEPT ![self] = 1]
134+
/\ pc' = [pc EXCEPT ![self] = "Finish"]
135+
/\ UNCHANGED << valid, inFlight, monitor, exchanges,
136+
joined >>
137+
138+
Finish(self) == /\ pc[self] = "Finish"
139+
/\ TRUE
140+
/\ pc' = [pc EXCEPT ![self] = "Done"]
141+
/\ UNCHANGED << valid, inFlight, monitor, exchanges, tokens,
142+
joined >>
143+
144+
Thr(self) == HotPath(self) \/ Acquire(self) \/ UnderMonitor(self)
145+
\/ DoExchange(self) \/ AwaitJoined(self) \/ Finish(self)
146+
147+
(* Allow infinite stuttering to prevent deadlock on termination. *)
148+
Terminating == /\ \A self \in ProcSet: pc[self] = "Done"
149+
/\ UNCHANGED vars
150+
151+
Next == (\E self \in Threads: Thr(self))
152+
\/ Terminating
153+
154+
Spec == Init /\ [][Next]_vars
155+
156+
Termination == <>(\A self \in ProcSet: pc[self] = "Done")
157+
158+
\* END TRANSLATION
159+
=============================================================================
160+

0 commit comments

Comments
 (0)