Skip to content

Commit 5d32e74

Browse files
TimChildPJ-Snapclaude
authored
fix: improve auth session sync reliability (#24)
* fix: improve auth session sync reliability - Add configurable JWT validation leeway (default 60s) to handle clock drift between Clerk servers and the backend - Catch ExpiredTokenError alongside InvalidClaimError/MissingClaimError to prevent crashes on expired tokens - Rewrite ClerkSessionSynchronizer JS for reliability: - Use useRef to deduplicate rapid calls while remaining reconnect-safe - Request fresh tokens with skipCache to avoid near-expiry cached tokens - Handle token retrieval failures gracefully (clear session instead of hang) - Include all dependencies in useEffect array ([isLoaded, isSignedIn, addEvents, getToken]) - Add unit tests for expired token handling and JS code correctness - Add pythonpath config for pytest to find custom_components * fix: address review on auth session sync PR JS dedupe + retry fixes (Copilot review): - Only update lastSentRef after a confirmed dispatch so a failed token fetch doesn't poison the dedupe and block later retries. - Add inFlightRef to prevent overlapping getToken calls when the effect re-fires before the in-flight promise resolves. - Retry getToken once after a 500ms delay before clearing the backend session, so transient token-fetch failures don't force a backend logout while Clerk is still signed in. - On final failure, leave lastSentRef unchanged so the next trigger (reconnect, sign-in toggle) re-attempts the sync. Test typecheck fixes: - pyright ignore for the `_reflex_internal_init` Reflex internal init flag. - pyright ignore for accessing `.fn` on the wrapped EventCallback. Co-Authored-By: Paul Johnson <paul.johnson@snaplabs.ai> Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * fix: address Copilot race-condition review on session sync JS - Drop inFlightRef hard gate. It blocked the !isSignedIn branch, so a sign-out occurring during an in-flight token fetch would not dispatch clear_clerk_session and the backend session would stay stale until another trigger arrived. - Replace it with a requestIdRef counter that's incremented on every effect run with a new desired state. The in-flight fetch's then/catch handlers check the captured myRequestId against requestIdRef.current before dispatching - so a getToken() that resolves after sign-out can no longer dispatch a stale set_clerk_session. - Sign-out path now runs unconditionally and immediately. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Paul Johnson <paul.johnson@snaplabs.ai> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent d456dd0 commit 5d32e74

3 files changed

Lines changed: 150 additions & 23 deletions

File tree

custom_components/reflex_clerk_api/clerk_provider.py

Lines changed: 97 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ class ClerkState(rx.State):
5959
"nbf": {"essential": True},
6060
# "azp": {"essential": False, "values": ["http://localhost:3000", "https://example.com"]},
6161
}
62+
_jwt_validate_leeway_seconds: ClassVar[int] = 60
63+
"""Clock-skew leeway (seconds) for validating JWT claims like exp/nbf."""
6264

6365
@classmethod
6466
def register_dependent_handler(cls, handler: EventCallback) -> None:
@@ -85,6 +87,29 @@ def set_claims_options(cls, claims_options: dict[str, Any]) -> None:
8587
"""Set the claims options for the JWT claims validation."""
8688
cls._claims_options = claims_options
8789

90+
@classmethod
91+
def set_jwt_validate_leeway_seconds(cls, seconds: int) -> None:
92+
"""Set clock-skew leeway (seconds) for JWT exp/nbf validation.
93+
94+
Default is 60 seconds. Increase if you see intermittent ExpiredTokenError
95+
due to clock drift between Clerk servers and your backend.
96+
97+
Args:
98+
seconds: Non-negative integer, max 3600 (1 hour).
99+
100+
Raises:
101+
ValueError: If seconds is negative or exceeds 3600.
102+
"""
103+
if not isinstance(seconds, int) or isinstance(seconds, bool) or seconds < 0:
104+
raise ValueError(
105+
f"jwt_validate_leeway_seconds must be a non-negative integer, got {seconds!r}"
106+
)
107+
if seconds > 3600:
108+
raise ValueError(
109+
f"jwt_validate_leeway_seconds exceeds maximum of 3600 (1 hour), got {seconds}"
110+
)
111+
cls._jwt_validate_leeway_seconds = seconds
112+
88113
@property
89114
def client(self) -> clerk_backend_api.Clerk:
90115
if self._client is None:
@@ -116,9 +141,13 @@ async def set_clerk_session(self, token: str) -> EventType:
116141
return ClerkState.clear_clerk_session
117142
try:
118143
# Validate the token according to the claim options (e.g. iss, exp, nbf, azp.)
119-
decoded.validate()
120-
except (jose_errors.InvalidClaimError, jose_errors.MissingClaimError) as e:
121-
logging.warning(f"JWT token is invalid: {e}")
144+
decoded.validate(leeway=self._jwt_validate_leeway_seconds)
145+
except (
146+
jose_errors.ExpiredTokenError,
147+
jose_errors.InvalidClaimError,
148+
jose_errors.MissingClaimError,
149+
) as e:
150+
logging.warning(f"JWT token validation failed: {type(e).__name__}: {e}")
122151
return ClerkState.clear_clerk_session
123152

124153
async with self:
@@ -364,7 +393,7 @@ def add_imports(
364393
) -> rx.ImportDict:
365394
addl_imports: rx.ImportDict = {
366395
"@clerk/clerk-react": ["useAuth"],
367-
"react": ["useContext", "useEffect"],
396+
"react": ["useContext", "useEffect", "useRef"],
368397
"$/utils/context": ["EventLoopContext"],
369398
"$/utils/state": ["ReflexEvent"],
370399
}
@@ -375,28 +404,73 @@ def add_custom_code(self) -> list[str]:
375404

376405
return [
377406
"""
378-
function ClerkSessionSynchronizer({ children }) {
379-
const { getToken, isLoaded, isSignedIn } = useAuth()
380-
const [ addEvents, connectErrors ] = useContext(EventLoopContext)
381-
382-
useEffect(() => {
383-
if (isLoaded && !!addEvents) {
384-
if (isSignedIn) {
385-
getToken().then(token => {
386-
addEvents([ReflexEvent("%s.set_clerk_session", {token})])
387-
})
388-
} else {
389-
addEvents([ReflexEvent("%s.clear_clerk_session")])
390-
}
391-
}
392-
}, [isSignedIn])
407+
function ClerkSessionSynchronizer({{ children }}) {{
408+
const {{ getToken, isLoaded, isSignedIn }} = useAuth()
409+
const [ addEvents ] = useContext(EventLoopContext)
410+
// Tracks the last *successfully dispatched* (state, addEvents) pair. Only
411+
// updated after a confirmed dispatch so transient token-fetch failures don't
412+
// poison the dedupe and prevent later retries.
413+
const lastSentRef = useRef({{ stateKey: null, addEvents: null }})
414+
// Incremented on every effect run with new desired state. In-flight token
415+
// fetches check this on resolve and bail if a newer effect has superseded
416+
// them - prevents stale set_clerk_session dispatches after sign-out.
417+
const requestIdRef = useRef(0)
418+
419+
useEffect(() => {{
420+
// Wait for all dependencies to be ready.
421+
if (!isLoaded || !addEvents) return
422+
423+
// Deduplicate rapid calls, but remain reconnect-safe:
424+
// addEvents identity changes across websocket reconnects, so include it in the key.
425+
const stateKey = isSignedIn ? "signed_in" : "signed_out"
426+
if (
427+
lastSentRef.current?.stateKey === stateKey &&
428+
lastSentRef.current?.addEvents === addEvents
429+
) return
430+
431+
const myRequestId = ++requestIdRef.current
432+
433+
if (!isSignedIn) {{
434+
// Always run the sign-out path immediately. Any in-flight token fetch
435+
// will see myRequestId !== requestIdRef.current on resolve and drop.
436+
addEvents([ReflexEvent("{state}.clear_clerk_session")])
437+
lastSentRef.current = {{ stateKey, addEvents }}
438+
return
439+
}}
440+
441+
// isSignedIn: try to get a fresh token. Retry once after a short delay
442+
// on transient failures before clearing the backend session - clearing
443+
// prematurely forces a logout while Clerk is still signed in.
444+
// Prefer skipCache (avoids near-expiry cached tokens); fall back if the
445+
// installed Clerk version doesn't support that option.
446+
const fetchToken = () =>
447+
getToken({{ skipCache: true }}).catch(() => getToken())
448+
fetchToken()
449+
.catch(() => new Promise(resolve => setTimeout(resolve, 500)).then(fetchToken))
450+
.then(token => {{
451+
// Drop if a newer effect run (e.g., sign-out) has superseded us.
452+
if (myRequestId !== requestIdRef.current) return
453+
if (token) {{
454+
addEvents([ReflexEvent("{state}.set_clerk_session", {{token}})])
455+
lastSentRef.current = {{ stateKey, addEvents }}
456+
}} else {{
457+
// Final failure: clear backend session but leave lastSentRef
458+
// unchanged so the next trigger (reconnect, sign-in toggle, etc.)
459+
// re-attempts the sync instead of being deduped away.
460+
addEvents([ReflexEvent("{state}.clear_clerk_session")])
461+
}}
462+
}})
463+
.catch(() => {{
464+
if (myRequestId !== requestIdRef.current) return
465+
addEvents([ReflexEvent("{state}.clear_clerk_session")])
466+
}})
467+
}}, [isLoaded, isSignedIn, addEvents, getToken])
393468
394469
return (
395-
<>{children}</>
470+
<>{{children}}</>
396471
)
397-
}
398-
"""
399-
% (clerk_state_name, clerk_state_name)
472+
}}
473+
""".format(state=clerk_state_name)
400474
]
401475

402476

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ homepage = "https://reflex-clerk-api-demo.adventuresoftim.com"
4747

4848
[tool.pytest.ini_options]
4949
# addopts = "--headed"
50+
pythonpath = ["custom_components"]
5051

5152
[tool.pyright]
5253
venvPath = "."

tests/test_clerk_provider_unit.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import asyncio
2+
3+
import authlib.jose.errors as jose_errors
4+
5+
6+
def test_set_clerk_session_expired_token_clears(monkeypatch):
7+
"""Expired tokens should not crash the handler; they should clear session."""
8+
# Import inside the test so the module is importable in different test layouts.
9+
# We need the actual module object (not just ClerkState) to monkeypatch jwt.decode
10+
# where it's used. importlib is required because reflex_clerk_api.clerk_provider
11+
# resolves to the function via __init__.py re-exports.
12+
import importlib
13+
14+
clerk_provider_module = importlib.import_module("reflex_clerk_api.clerk_provider")
15+
from reflex_clerk_api.clerk_provider import ClerkState
16+
17+
# Instantiate state in a framework-safe way for tests.
18+
state = ClerkState(_reflex_internal_init=True) # pyright: ignore[reportCallIssue]
19+
20+
async def fake_get_jwk_keys(self):
21+
return {}
22+
23+
monkeypatch.setattr(ClerkState, "_get_jwk_keys", fake_get_jwk_keys, raising=True)
24+
25+
validate_calls: dict[str, object] = {}
26+
27+
class FakeClaims:
28+
def validate(self, leeway=None):
29+
validate_calls["leeway"] = leeway
30+
raise jose_errors.ExpiredTokenError()
31+
32+
monkeypatch.setattr(
33+
clerk_provider_module.jwt,
34+
"decode",
35+
lambda *args, **kwargs: FakeClaims(),
36+
raising=True,
37+
)
38+
39+
result = asyncio.run(
40+
ClerkState.set_clerk_session.fn(state, token="fake") # pyright: ignore[reportAttributeAccessIssue]
41+
)
42+
assert validate_calls["leeway"] == 60
43+
assert result == ClerkState.clear_clerk_session
44+
45+
46+
def test_clerk_session_synchronizer_js_contains_reconnect_safe_deps_and_skipcache():
47+
"""String-based regression test for the generated JS."""
48+
from reflex_clerk_api.clerk_provider import ClerkSessionSynchronizer
49+
50+
js = ClerkSessionSynchronizer.create().add_custom_code()[0]
51+
assert "[isLoaded, isSignedIn, addEvents, getToken]" in js
52+
assert "skipCache: true" in js

0 commit comments

Comments
 (0)