diff --git a/custom_components/reflex_clerk_api/clerk_provider.py b/custom_components/reflex_clerk_api/clerk_provider.py index eccc235..0952524 100644 --- a/custom_components/reflex_clerk_api/clerk_provider.py +++ b/custom_components/reflex_clerk_api/clerk_provider.py @@ -59,6 +59,8 @@ class ClerkState(rx.State): "nbf": {"essential": True}, # "azp": {"essential": False, "values": ["http://localhost:3000", "https://example.com"]}, } + _jwt_validate_leeway_seconds: ClassVar[int] = 60 + """Clock-skew leeway (seconds) for validating JWT claims like exp/nbf.""" @classmethod def register_dependent_handler(cls, handler: EventCallback) -> None: @@ -85,6 +87,29 @@ def set_claims_options(cls, claims_options: dict[str, Any]) -> None: """Set the claims options for the JWT claims validation.""" cls._claims_options = claims_options + @classmethod + def set_jwt_validate_leeway_seconds(cls, seconds: int) -> None: + """Set clock-skew leeway (seconds) for JWT exp/nbf validation. + + Default is 60 seconds. Increase if you see intermittent ExpiredTokenError + due to clock drift between Clerk servers and your backend. + + Args: + seconds: Non-negative integer, max 3600 (1 hour). + + Raises: + ValueError: If seconds is negative or exceeds 3600. + """ + if not isinstance(seconds, int) or isinstance(seconds, bool) or seconds < 0: + raise ValueError( + f"jwt_validate_leeway_seconds must be a non-negative integer, got {seconds!r}" + ) + if seconds > 3600: + raise ValueError( + f"jwt_validate_leeway_seconds exceeds maximum of 3600 (1 hour), got {seconds}" + ) + cls._jwt_validate_leeway_seconds = seconds + @property def client(self) -> clerk_backend_api.Clerk: if self._client is None: @@ -116,9 +141,13 @@ async def set_clerk_session(self, token: str) -> EventType: return ClerkState.clear_clerk_session try: # Validate the token according to the claim options (e.g. iss, exp, nbf, azp.) - decoded.validate() - except (jose_errors.InvalidClaimError, jose_errors.MissingClaimError) as e: - logging.warning(f"JWT token is invalid: {e}") + decoded.validate(leeway=self._jwt_validate_leeway_seconds) + except ( + jose_errors.ExpiredTokenError, + jose_errors.InvalidClaimError, + jose_errors.MissingClaimError, + ) as e: + logging.warning(f"JWT token validation failed: {type(e).__name__}: {e}") return ClerkState.clear_clerk_session async with self: @@ -364,7 +393,7 @@ def add_imports( ) -> rx.ImportDict: addl_imports: rx.ImportDict = { "@clerk/clerk-react": ["useAuth"], - "react": ["useContext", "useEffect"], + "react": ["useContext", "useEffect", "useRef"], "$/utils/context": ["EventLoopContext"], "$/utils/state": ["ReflexEvent"], } @@ -375,28 +404,73 @@ def add_custom_code(self) -> list[str]: return [ """ -function ClerkSessionSynchronizer({ children }) { - const { getToken, isLoaded, isSignedIn } = useAuth() - const [ addEvents, connectErrors ] = useContext(EventLoopContext) - - useEffect(() => { - if (isLoaded && !!addEvents) { - if (isSignedIn) { - getToken().then(token => { - addEvents([ReflexEvent("%s.set_clerk_session", {token})]) - }) - } else { - addEvents([ReflexEvent("%s.clear_clerk_session")]) - } - } - }, [isSignedIn]) +function ClerkSessionSynchronizer({{ children }}) {{ + const {{ getToken, isLoaded, isSignedIn }} = useAuth() + const [ addEvents ] = useContext(EventLoopContext) + // Tracks the last *successfully dispatched* (state, addEvents) pair. Only + // updated after a confirmed dispatch so transient token-fetch failures don't + // poison the dedupe and prevent later retries. + const lastSentRef = useRef({{ stateKey: null, addEvents: null }}) + // Incremented on every effect run with new desired state. In-flight token + // fetches check this on resolve and bail if a newer effect has superseded + // them - prevents stale set_clerk_session dispatches after sign-out. + const requestIdRef = useRef(0) + + useEffect(() => {{ + // Wait for all dependencies to be ready. + if (!isLoaded || !addEvents) return + + // Deduplicate rapid calls, but remain reconnect-safe: + // addEvents identity changes across websocket reconnects, so include it in the key. + const stateKey = isSignedIn ? "signed_in" : "signed_out" + if ( + lastSentRef.current?.stateKey === stateKey && + lastSentRef.current?.addEvents === addEvents + ) return + + const myRequestId = ++requestIdRef.current + + if (!isSignedIn) {{ + // Always run the sign-out path immediately. Any in-flight token fetch + // will see myRequestId !== requestIdRef.current on resolve and drop. + addEvents([ReflexEvent("{state}.clear_clerk_session")]) + lastSentRef.current = {{ stateKey, addEvents }} + return + }} + + // isSignedIn: try to get a fresh token. Retry once after a short delay + // on transient failures before clearing the backend session - clearing + // prematurely forces a logout while Clerk is still signed in. + // Prefer skipCache (avoids near-expiry cached tokens); fall back if the + // installed Clerk version doesn't support that option. + const fetchToken = () => + getToken({{ skipCache: true }}).catch(() => getToken()) + fetchToken() + .catch(() => new Promise(resolve => setTimeout(resolve, 500)).then(fetchToken)) + .then(token => {{ + // Drop if a newer effect run (e.g., sign-out) has superseded us. + if (myRequestId !== requestIdRef.current) return + if (token) {{ + addEvents([ReflexEvent("{state}.set_clerk_session", {{token}})]) + lastSentRef.current = {{ stateKey, addEvents }} + }} else {{ + // Final failure: clear backend session but leave lastSentRef + // unchanged so the next trigger (reconnect, sign-in toggle, etc.) + // re-attempts the sync instead of being deduped away. + addEvents([ReflexEvent("{state}.clear_clerk_session")]) + }} + }}) + .catch(() => {{ + if (myRequestId !== requestIdRef.current) return + addEvents([ReflexEvent("{state}.clear_clerk_session")]) + }}) + }}, [isLoaded, isSignedIn, addEvents, getToken]) return ( - <>{children} + <>{{children}} ) -} -""" - % (clerk_state_name, clerk_state_name) +}} +""".format(state=clerk_state_name) ] diff --git a/pyproject.toml b/pyproject.toml index 85bc7d8..0144cb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ homepage = "https://reflex-clerk-api-demo.adventuresoftim.com" [tool.pytest.ini_options] # addopts = "--headed" +pythonpath = ["custom_components"] [tool.pyright] venvPath = "." diff --git a/tests/test_clerk_provider_unit.py b/tests/test_clerk_provider_unit.py new file mode 100644 index 0000000..9e77572 --- /dev/null +++ b/tests/test_clerk_provider_unit.py @@ -0,0 +1,52 @@ +import asyncio + +import authlib.jose.errors as jose_errors + + +def test_set_clerk_session_expired_token_clears(monkeypatch): + """Expired tokens should not crash the handler; they should clear session.""" + # Import inside the test so the module is importable in different test layouts. + # We need the actual module object (not just ClerkState) to monkeypatch jwt.decode + # where it's used. importlib is required because reflex_clerk_api.clerk_provider + # resolves to the function via __init__.py re-exports. + import importlib + + clerk_provider_module = importlib.import_module("reflex_clerk_api.clerk_provider") + from reflex_clerk_api.clerk_provider import ClerkState + + # Instantiate state in a framework-safe way for tests. + state = ClerkState(_reflex_internal_init=True) # pyright: ignore[reportCallIssue] + + async def fake_get_jwk_keys(self): + return {} + + monkeypatch.setattr(ClerkState, "_get_jwk_keys", fake_get_jwk_keys, raising=True) + + validate_calls: dict[str, object] = {} + + class FakeClaims: + def validate(self, leeway=None): + validate_calls["leeway"] = leeway + raise jose_errors.ExpiredTokenError() + + monkeypatch.setattr( + clerk_provider_module.jwt, + "decode", + lambda *args, **kwargs: FakeClaims(), + raising=True, + ) + + result = asyncio.run( + ClerkState.set_clerk_session.fn(state, token="fake") # pyright: ignore[reportAttributeAccessIssue] + ) + assert validate_calls["leeway"] == 60 + assert result == ClerkState.clear_clerk_session + + +def test_clerk_session_synchronizer_js_contains_reconnect_safe_deps_and_skipcache(): + """String-based regression test for the generated JS.""" + from reflex_clerk_api.clerk_provider import ClerkSessionSynchronizer + + js = ClerkSessionSynchronizer.create().add_custom_code()[0] + assert "[isLoaded, isSignedIn, addEvents, getToken]" in js + assert "skipCache: true" in js