@@ -44,7 +44,8 @@ def __init__(
4444 require_kid : bool = True ,
4545 keys_provider : Optional [KeysProvider ] = None ,
4646 keys_url : Optional [str ] = None ,
47- cache_time : float = 10800
47+ cache_time : float = 10800 ,
48+ refresh_time : float = 120 ,
4849 ) -> None :
4950 """
5051 Creates a new instance of JWTValidator. This class only supports validating
@@ -72,9 +73,15 @@ def __init__(
7273 keys_url : Optional[str], optional
7374 If provided, keys are obtained from the given URL through HTTP GET.
7475 This parameter is ignored if `keys_provider` is given.
75- cache_time : float, optional
76- If >= 0, JWKS are cached in memory and stored for the given amount in
77- seconds. By default 10800 (3 hours).
76+ cache_time : float
77+ JWKS are cached in memory and stored for the given amount in seconds.
78+ By default 10800 (3 hours). Regardless of this parameter, JWKS are refreshed
79+ automatically if an unknown kid is met and JWKS were last fetched more than
80+ `refresh_time` earlier (in seconds).
81+ refresh_time : float
82+ JWKS are refreshed automatically if an unknown `kid` is encountered, and
83+ JWKS were last fetched more than `refresh_time` seconds ago (by default
84+ 120 seconds)
7885 """
7986 if keys_provider :
8087 pass
@@ -89,26 +96,24 @@ def __init__(
8996 "`authority`, or `keys_provider`."
9097 )
9198
92- if cache_time :
93- keys_provider = CachingKeysProvider (keys_provider , cache_time )
99+ keys_provider = CachingKeysProvider (keys_provider , cache_time , refresh_time )
94100
95101 self ._valid_issuers = list (valid_issuers )
96102 self ._valid_audiences = list (valid_audiences )
97103 self ._algorithms = list (algorithms )
98- self ._keys_provider : KeysProvider = keys_provider
104+ self ._keys_provider = keys_provider
99105 self .require_kid = require_kid
100106 self .logger = get_logger ()
101107
102108 async def get_jwks (self ) -> JWKS :
103109 return await self ._keys_provider .get_keys ()
104110
105111 async def get_jwk (self , kid : str ) -> JWK :
106- jwks = await self .get_jwks ( )
112+ key = await self ._keys_provider . get_key ( kid )
107113
108- for jwk in jwks .keys :
109- if jwk .kid is not None and jwk .kid == kid :
110- return jwk
111- raise InvalidAccessToken ("kid not recognized" )
114+ if key is None :
115+ raise InvalidAccessToken ("kid not recognized" )
116+ return key
112117
113118 def _validate_jwt_by_key (
114119 self , access_token : str , jwk : JWK
0 commit comments