@@ -18,55 +18,68 @@ class JWTService:
1818 def __init__ (self , jwt_config : JWTConfiguration ) -> None :
1919 self .jwt_config = jwt_config
2020
21- self .jwks_client = (
22- PyJWKClient (self .jwt_config .jwk_url ) if self .jwt_config .jwk_url else None
23- )
24- self .leeway = self .jwt_config .leeway
21+ def get_jwks_client (self , jwt_config : JWTConfiguration ) -> t .Optional [PyJWKClient ]:
22+ jwks_client = PyJWKClient (jwt_config .jwk_url ) if jwt_config .jwk_url else None
23+ return jwks_client
2524
26- def get_leeway (self ) -> timedelta :
27- if self .leeway is None :
25+ def get_leeway (self , jwt_config : JWTConfiguration ) -> timedelta :
26+ if jwt_config .leeway is None :
2827 return timedelta (seconds = 0 )
29- elif isinstance (self .leeway , (int , float )):
30- return timedelta (seconds = self .leeway )
31- elif isinstance (self .leeway , timedelta ):
32- return self .leeway
28+ elif isinstance (jwt_config .leeway , (int , float )):
29+ return timedelta (seconds = jwt_config .leeway )
30+ elif isinstance (jwt_config .leeway , timedelta ):
31+ return jwt_config .leeway
3332
34- def get_verifying_key (self , token : t .Any ) -> bytes :
33+ def get_verifying_key (self , token : t .Any , jwt_config : JWTConfiguration ) -> bytes :
3534 if self .jwt_config .algorithm .startswith ("HS" ):
36- return self . jwt_config .signing_secret_key .encode ()
35+ return jwt_config .signing_secret_key .encode ()
3736
38- if self .jwks_client :
37+ jwks_client = self .get_jwks_client (jwt_config )
38+ if jwks_client :
3939 try :
40- p_jwk = self . jwks_client .get_signing_key_from_jwt (token )
40+ p_jwk = jwks_client .get_signing_key_from_jwt (token )
4141 return p_jwk .key # type:ignore[no-any-return]
4242 except PyJWKClientError as ex :
4343 raise JWTTokenException ("Token is invalid or expired" ) from ex
4444
45- return self .jwt_config .verifying_secret_key .encode ()
45+ return jwt_config .verifying_secret_key .encode ()
46+
47+ def _merge_configurations (self , ** jwt_config : t .Any ) -> JWTConfiguration :
48+ jwt_config_default = self .jwt_config .dict ()
49+ jwt_config_default .update (jwt_config )
50+ return JWTConfiguration (** jwt_config_default )
4651
4752 def sign (
48- self , payload : dict , headers : t .Optional [t .Dict [str , t .Any ]] = None
53+ self ,
54+ payload : dict ,
55+ headers : t .Optional [t .Dict [str , t .Any ]] = None ,
56+ ** jwt_config : t .Any ,
4957 ) -> str :
5058 """
5159 Returns an encoded token for the given payload dictionary.
5260 """
53-
54- jwt_payload = Token (jwt_config = self . jwt_config ).build (payload .copy ())
61+ _jwt_config = self . _merge_configurations ( ** jwt_config )
62+ jwt_payload = Token (jwt_config = _jwt_config ).build (payload .copy ())
5563
5664 return jwt .encode (
5765 jwt_payload ,
58- self . jwt_config .signing_secret_key ,
59- algorithm = self . jwt_config .algorithm ,
60- json_encoder = self . jwt_config .json_encoder ,
66+ _jwt_config .signing_secret_key ,
67+ algorithm = _jwt_config .algorithm ,
68+ json_encoder = _jwt_config .json_encoder ,
6169 headers = headers ,
6270 )
6371
6472 async def sign_async (
65- self , payload : dict , headers : t .Optional [t .Dict [str , t .Any ]] = None
73+ self ,
74+ payload : dict ,
75+ headers : t .Optional [t .Dict [str , t .Any ]] = None ,
76+ ** jwt_config : t .Any ,
6677 ) -> str :
67- return await anyio .to_thread .run_sync (self .sign , payload , headers )
78+ return await anyio .to_thread .run_sync (self .sign , payload , headers , ** jwt_config )
6879
69- def decode (self , token : str , verify : bool = True ) -> t .Dict [str , t .Any ]:
80+ def decode (
81+ self , token : str , verify : bool = True , ** jwt_config : t .Any
82+ ) -> t .Dict [str , t .Any ]:
7083 """
7184 Performs a validation of the given token and returns its payload
7285 dictionary.
@@ -75,15 +88,16 @@ def decode(self, token: str, verify: bool = True) -> t.Dict[str, t.Any]:
7588 signature check fails, or if its 'exp' claim indicates it has expired.
7689 """
7790 try :
91+ _jwt_config = self ._merge_configurations (** jwt_config )
7892 return jwt .decode ( # type: ignore[no-any-return]
7993 token ,
80- self .get_verifying_key (token ),
81- algorithms = [self . jwt_config .algorithm ],
82- audience = self . jwt_config .audience ,
83- issuer = self . jwt_config .issuer ,
84- leeway = self .get_leeway (),
94+ self .get_verifying_key (token , _jwt_config ),
95+ algorithms = [_jwt_config .algorithm ],
96+ audience = _jwt_config .audience ,
97+ issuer = _jwt_config .issuer ,
98+ leeway = self .get_leeway (_jwt_config ),
8599 options = {
86- "verify_aud" : self . jwt_config .audience is not None ,
100+ "verify_aud" : _jwt_config .audience is not None ,
87101 "verify_signature" : verify ,
88102 },
89103 )
@@ -92,5 +106,7 @@ def decode(self, token: str, verify: bool = True) -> t.Dict[str, t.Any]:
92106 except InvalidTokenError as ex :
93107 raise JWTTokenException ("Token is invalid or expired" ) from ex
94108
95- async def decode_async (self , token : str , verify : bool = True ) -> t .Dict [str , t .Any ]:
96- return await anyio .to_thread .run_sync (self .decode , token , verify )
109+ async def decode_async (
110+ self , token : str , verify : bool = True , ** jwt_config : t .Any
111+ ) -> t .Dict [str , t .Any ]:
112+ return await anyio .to_thread .run_sync (self .decode , token , verify , ** jwt_config )
0 commit comments