Skip to content

Commit 069e2ee

Browse files
Improve JWKS automatic rotation
1 parent 4537386 commit 069e2ee

12 files changed

Lines changed: 216 additions & 31 deletions

.flake8

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ max-complexity = 18
66
select = B,C,E,F,W,T4,B9
77
per-file-ignores =
88
guardpost/__init__.py:F401
9+
tests/test_jwks.py:E501

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [1.0.1] - 2023-03-20 :sun_with_face:
9+
- Improves the automatic rotation of `JWKS`: when validating `JWTs`, `JWKS` are
10+
refreshed automatically if an unknown `kid` is encountered, and `JWKS` were
11+
last fetched more than `refresh_time` seconds ago (by default 120 seconds).
12+
- Corrects an inconsistency in how `claims` are read in the `User` class.
13+
814
## [1.0.0] - 2023-01-07 :star:
915
- Adds built-in support for dependency injection, using the new `ContainerProtocol`
1016
in `rodi` v2.

guardpost/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.0"
1+
__version__ = "1.0.1"

guardpost/authentication.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,14 @@ def __init__(
2626

2727
@property
2828
def sub(self) -> Optional[str]:
29-
return self["sub"]
29+
return self.get("sub")
3030

3131
def is_authenticated(self) -> bool:
3232
return bool(self.authentication_mode)
3333

34+
def get(self, key: str):
35+
return self.claims.get(key)
36+
3437
def __getitem__(self, item):
3538
return self.claims[item]
3639

@@ -44,15 +47,15 @@ def has_claim_value(self, name: str, value: str) -> bool:
4447
class User(Identity):
4548
@property
4649
def id(self) -> Optional[str]:
47-
return self["id"] or self.sub
50+
return self.get("id") or self.sub
4851

4952
@property
5053
def name(self) -> Optional[str]:
51-
return self["name"]
54+
return self.get("name")
5255

5356
@property
5457
def email(self) -> Optional[str]:
55-
return self["email"]
58+
return self.get("email")
5659

5760

5861
class AuthenticationHandler(ABC):

guardpost/authorization.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ def _get_message(forced_failure, failed_requirements):
8686

8787

8888
class AuthorizationContext:
89-
9089
__slots__ = ("identity", "requirements", "_succeeded", "_failed_forced")
9190

9291
def __init__(self, identity: Identity, requirements: Sequence[Requirement]):
@@ -222,7 +221,6 @@ async def _handle_with_policy(self, policy: Policy, identity: Identity, scope: A
222221
with AuthorizationContext(
223222
identity, list(self._get_requirements(policy, scope))
224223
) as context:
225-
226224
for requirement in context.requirements:
227225
if _is_async_handler(type(requirement)): # type: ignore
228226
await requirement.handle(context)

guardpost/jwks/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ def from_dict(cls, value) -> "JWK":
7070
class JWKS:
7171
keys: List[JWK]
7272

73+
def update(self, new_set: "JWKS"):
74+
self.keys = list({key.kid: key for key in self.keys + new_set.keys}.values())
75+
7376
@classmethod
7477
def from_dict(cls, value) -> "JWKS":
7578
if "keys" not in value:

guardpost/jwks/caching.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
import time
22
from typing import Optional
33

4-
from . import JWKS, KeysProvider
4+
from . import JWK, JWKS, KeysProvider
55

66

77
class CachingKeysProvider(KeysProvider):
88
"""
99
Kind of KeysProvider that can cache the result of another KeysProvider.
1010
"""
1111

12-
def __init__(self, keys_provider: KeysProvider, cache_time: float) -> None:
12+
def __init__(
13+
self, keys_provider: KeysProvider, cache_time: float, refresh_time: float = 120
14+
) -> None:
1315
"""
1416
Creates a new instance of CachingKeysProvider bound to a given KeysProvider,
1517
and caching its result up to an optional amount of seconds described by
1618
cache_time. Expiration is disabled if `cache_time` <= 0.
19+
JWKS are refreshed anyway if an unknown `kid` is encountered and the set was
20+
fetched more than `refresh_time` seconds ago.
1721
"""
1822
super().__init__()
1923

@@ -22,6 +26,7 @@ def __init__(self, keys_provider: KeysProvider, cache_time: float) -> None:
2226

2327
self._keys: Optional[JWKS] = None
2428
self._cache_time = cache_time
29+
self._refresh_time = refresh_time
2530
self._last_fetch_time: float = 0
2631
self._keys_provider = keys_provider
2732

@@ -34,6 +39,14 @@ async def _fetch_keys(self) -> JWKS:
3439
self._last_fetch_time = time.time()
3540
return self._keys
3641

42+
async def _refresh_keys(self) -> JWKS:
43+
new_set = await self._fetch_keys()
44+
if self._keys is None: # pragma: no cover
45+
self._keys = new_set
46+
else:
47+
self._keys.update(new_set)
48+
return self._keys
49+
3750
async def get_keys(self) -> JWKS:
3851
if self._keys is not None:
3952
if self._cache_time > 0 and (
@@ -43,3 +56,27 @@ async def get_keys(self) -> JWKS:
4356
else:
4457
return self._keys
4558
return await self._fetch_keys()
59+
60+
async def get_key(self, kid: str) -> Optional[JWK]:
61+
"""
62+
Tries to get a JWK by kid. If the JWK is not found and the last time the keys
63+
were fetched is older than `refresh_time` (default 120 seconds), it fetches
64+
again the JWKS from the source.
65+
"""
66+
jwks = await self.get_keys()
67+
68+
for jwk in jwks.keys.copy():
69+
if jwk.kid is not None and jwk.kid == kid:
70+
return jwk
71+
72+
if (
73+
self._refresh_time > 0
74+
and time.time() - self._last_fetch_time >= self._refresh_time
75+
):
76+
jwks = await self._refresh_keys()
77+
78+
for jwk in jwks.keys.copy():
79+
if jwk.kid is not None and jwk.kid == kid:
80+
return jwk
81+
82+
return None

guardpost/jwts/__init__.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ def __init__(
4444
require_kid: bool = True,
4545
keys_provider: Optional[KeysProvider] = None,
4646
keys_url: Optional[str] = None,
47-
cache_time: float = 10800
47+
cache_time: float = 10800,
48+
refresh_time: float = 120,
4849
) -> None:
4950
"""
5051
Creates a new instance of JWTValidator. This class only supports validating
@@ -72,9 +73,15 @@ def __init__(
7273
keys_url : Optional[str], optional
7374
If provided, keys are obtained from the given URL through HTTP GET.
7475
This parameter is ignored if `keys_provider` is given.
75-
cache_time : float, optional
76-
If >= 0, JWKS are cached in memory and stored for the given amount in
77-
seconds. By default 10800 (3 hours).
76+
cache_time : float
77+
JWKS are cached in memory and stored for the given amount in seconds.
78+
By default 10800 (3 hours). Regardless of this parameter, JWKS are refreshed
79+
automatically if an unknown kid is met and JWKS were last fetched more than
80+
`refresh_time` earlier (in seconds).
81+
refresh_time : float
82+
JWKS are refreshed automatically if an unknown `kid` is encountered, and
83+
JWKS were last fetched more than `refresh_time` seconds ago (by default
84+
120 seconds)
7885
"""
7986
if keys_provider:
8087
pass
@@ -89,26 +96,24 @@ def __init__(
8996
"`authority`, or `keys_provider`."
9097
)
9198

92-
if cache_time:
93-
keys_provider = CachingKeysProvider(keys_provider, cache_time)
99+
keys_provider = CachingKeysProvider(keys_provider, cache_time, refresh_time)
94100

95101
self._valid_issuers = list(valid_issuers)
96102
self._valid_audiences = list(valid_audiences)
97103
self._algorithms = list(algorithms)
98-
self._keys_provider: KeysProvider = keys_provider
104+
self._keys_provider = keys_provider
99105
self.require_kid = require_kid
100106
self.logger = get_logger()
101107

102108
async def get_jwks(self) -> JWKS:
103109
return await self._keys_provider.get_keys()
104110

105111
async def get_jwk(self, kid: str) -> JWK:
106-
jwks = await self.get_jwks()
112+
key = await self._keys_provider.get_key(kid)
107113

108-
for jwk in jwks.keys:
109-
if jwk.kid is not None and jwk.kid == kid:
110-
return jwk
111-
raise InvalidAccessToken("kid not recognized")
114+
if key is None:
115+
raise InvalidAccessToken("kid not recognized")
116+
return key
112117

113118
def _validate_jwt_by_key(
114119
self, access_token: str, jwk: JWK

tests/test_authentication.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ async def authenticate(self, context: Request):
112112

113113
@pytest.mark.asyncio
114114
async def test_strategy_throws_for_missing_context():
115-
116115
strategy = AuthenticationStrategy()
117116

118117
with raises(ValueError):

tests/test_common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,6 @@ def test_authorization_strategy_set_default_fluent():
147147

148148

149149
def test_unauthorized_error_supports_error_and_description():
150-
151150
error = UnauthorizedError(
152151
None,
153152
[],

0 commit comments

Comments
 (0)