Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 97 additions & 23 deletions custom_components/reflex_clerk_api/clerk_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class ClerkState(rx.State):
"nbf": {"essential": True},
# "azp": {"essential": False, "values": ["http://localhost:3000", "https://example.com"]},
}
_jwt_validate_leeway_seconds: ClassVar[int] = 60
"""Clock-skew leeway (seconds) for validating JWT claims like exp/nbf."""

@classmethod
def register_dependent_handler(cls, handler: EventCallback) -> None:
Expand All @@ -85,6 +87,29 @@ def set_claims_options(cls, claims_options: dict[str, Any]) -> None:
"""Set the claims options for the JWT claims validation."""
cls._claims_options = claims_options

@classmethod
def set_jwt_validate_leeway_seconds(cls, seconds: int) -> None:
"""Set clock-skew leeway (seconds) for JWT exp/nbf validation.

Default is 60 seconds. Increase if you see intermittent ExpiredTokenError
due to clock drift between Clerk servers and your backend.

Args:
seconds: Non-negative integer, max 3600 (1 hour).

Raises:
ValueError: If seconds is negative or exceeds 3600.
"""
if not isinstance(seconds, int) or isinstance(seconds, bool) or seconds < 0:
raise ValueError(
f"jwt_validate_leeway_seconds must be a non-negative integer, got {seconds!r}"
)
if seconds > 3600:
raise ValueError(
f"jwt_validate_leeway_seconds exceeds maximum of 3600 (1 hour), got {seconds}"
)
cls._jwt_validate_leeway_seconds = seconds

@property
def client(self) -> clerk_backend_api.Clerk:
if self._client is None:
Expand Down Expand Up @@ -116,9 +141,13 @@ async def set_clerk_session(self, token: str) -> EventType:
return ClerkState.clear_clerk_session
try:
# Validate the token according to the claim options (e.g. iss, exp, nbf, azp.)
decoded.validate()
except (jose_errors.InvalidClaimError, jose_errors.MissingClaimError) as e:
logging.warning(f"JWT token is invalid: {e}")
decoded.validate(leeway=self._jwt_validate_leeway_seconds)
except (
jose_errors.ExpiredTokenError,
jose_errors.InvalidClaimError,
jose_errors.MissingClaimError,
) as e:
logging.warning(f"JWT token validation failed: {type(e).__name__}: {e}")
return ClerkState.clear_clerk_session

async with self:
Expand Down Expand Up @@ -364,7 +393,7 @@ def add_imports(
) -> rx.ImportDict:
addl_imports: rx.ImportDict = {
"@clerk/clerk-react": ["useAuth"],
"react": ["useContext", "useEffect"],
"react": ["useContext", "useEffect", "useRef"],
"$/utils/context": ["EventLoopContext"],
"$/utils/state": ["ReflexEvent"],
}
Expand All @@ -375,28 +404,73 @@ def add_custom_code(self) -> list[str]:

return [
"""
function ClerkSessionSynchronizer({ children }) {
const { getToken, isLoaded, isSignedIn } = useAuth()
const [ addEvents, connectErrors ] = useContext(EventLoopContext)

useEffect(() => {
if (isLoaded && !!addEvents) {
if (isSignedIn) {
getToken().then(token => {
addEvents([ReflexEvent("%s.set_clerk_session", {token})])
})
} else {
addEvents([ReflexEvent("%s.clear_clerk_session")])
}
}
}, [isSignedIn])
function ClerkSessionSynchronizer({{ children }}) {{
const {{ getToken, isLoaded, isSignedIn }} = useAuth()
const [ addEvents ] = useContext(EventLoopContext)
// Tracks the last *successfully dispatched* (state, addEvents) pair. Only
// updated after a confirmed dispatch so transient token-fetch failures don't
// poison the dedupe and prevent later retries.
const lastSentRef = useRef({{ stateKey: null, addEvents: null }})
// Incremented on every effect run with new desired state. In-flight token
// fetches check this on resolve and bail if a newer effect has superseded
// them - prevents stale set_clerk_session dispatches after sign-out.
const requestIdRef = useRef(0)

useEffect(() => {{
// Wait for all dependencies to be ready.
if (!isLoaded || !addEvents) return

// Deduplicate rapid calls, but remain reconnect-safe:
// addEvents identity changes across websocket reconnects, so include it in the key.
const stateKey = isSignedIn ? "signed_in" : "signed_out"
if (
lastSentRef.current?.stateKey === stateKey &&
lastSentRef.current?.addEvents === addEvents
) return

const myRequestId = ++requestIdRef.current

Comment on lines +425 to +432
if (!isSignedIn) {{
// Always run the sign-out path immediately. Any in-flight token fetch
// will see myRequestId !== requestIdRef.current on resolve and drop.
addEvents([ReflexEvent("{state}.clear_clerk_session")])
lastSentRef.current = {{ stateKey, addEvents }}
return
}}

// isSignedIn: try to get a fresh token. Retry once after a short delay
// on transient failures before clearing the backend session - clearing
// prematurely forces a logout while Clerk is still signed in.
// Prefer skipCache (avoids near-expiry cached tokens); fall back if the
// installed Clerk version doesn't support that option.
const fetchToken = () =>
getToken({{ skipCache: true }}).catch(() => getToken())
fetchToken()
.catch(() => new Promise(resolve => setTimeout(resolve, 500)).then(fetchToken))
.then(token => {{
// Drop if a newer effect run (e.g., sign-out) has superseded us.
if (myRequestId !== requestIdRef.current) return
if (token) {{
addEvents([ReflexEvent("{state}.set_clerk_session", {{token}})])
lastSentRef.current = {{ stateKey, addEvents }}
Comment on lines +450 to +455
}} else {{
// Final failure: clear backend session but leave lastSentRef
// unchanged so the next trigger (reconnect, sign-in toggle, etc.)
// re-attempts the sync instead of being deduped away.
addEvents([ReflexEvent("{state}.clear_clerk_session")])
}}
}})
.catch(() => {{
if (myRequestId !== requestIdRef.current) return
addEvents([ReflexEvent("{state}.clear_clerk_session")])
}})
}}, [isLoaded, isSignedIn, addEvents, getToken])

return (
<>{children}</>
<>{{children}}</>
)
}
"""
% (clerk_state_name, clerk_state_name)
}}
""".format(state=clerk_state_name)
]


Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ homepage = "https://reflex-clerk-api-demo.adventuresoftim.com"

[tool.pytest.ini_options]
# addopts = "--headed"
pythonpath = ["custom_components"]

[tool.pyright]
venvPath = "."
Expand Down
52 changes: 52 additions & 0 deletions tests/test_clerk_provider_unit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import asyncio

import authlib.jose.errors as jose_errors


def test_set_clerk_session_expired_token_clears(monkeypatch):
"""Expired tokens should not crash the handler; they should clear session."""
# Import inside the test so the module is importable in different test layouts.
# We need the actual module object (not just ClerkState) to monkeypatch jwt.decode
# where it's used. importlib is required because reflex_clerk_api.clerk_provider
# resolves to the function via __init__.py re-exports.
import importlib

clerk_provider_module = importlib.import_module("reflex_clerk_api.clerk_provider")
from reflex_clerk_api.clerk_provider import ClerkState

# Instantiate state in a framework-safe way for tests.
state = ClerkState(_reflex_internal_init=True) # pyright: ignore[reportCallIssue]

async def fake_get_jwk_keys(self):
return {}

monkeypatch.setattr(ClerkState, "_get_jwk_keys", fake_get_jwk_keys, raising=True)

validate_calls: dict[str, object] = {}

class FakeClaims:
def validate(self, leeway=None):
validate_calls["leeway"] = leeway
raise jose_errors.ExpiredTokenError()

monkeypatch.setattr(
clerk_provider_module.jwt,
"decode",
lambda *args, **kwargs: FakeClaims(),
raising=True,
)

result = asyncio.run(
ClerkState.set_clerk_session.fn(state, token="fake") # pyright: ignore[reportAttributeAccessIssue]
)
assert validate_calls["leeway"] == 60
assert result == ClerkState.clear_clerk_session


def test_clerk_session_synchronizer_js_contains_reconnect_safe_deps_and_skipcache():
"""String-based regression test for the generated JS."""
from reflex_clerk_api.clerk_provider import ClerkSessionSynchronizer

js = ClerkSessionSynchronizer.create().add_custom_code()[0]
assert "[isLoaded, isSignedIn, addEvents, getToken]" in js
assert "skipCache: true" in js
Loading