1212import java .lang .ref .WeakReference ;
1313import java .net .*;
1414import java .util .Collections ;
15+ import java .util .HashMap ;
1516import java .util .List ;
17+ import java .util .Map ;
1618
1719import static com .auth0 .jwk .UrlJwkProvider .WELL_KNOWN_JWKS_PATH ;
1820import static org .hamcrest .Matchers .*;
@@ -272,6 +274,7 @@ public Object answer(InvocationOnMock invocation) throws Throwable {
272274 UrlJwkProvider urlJwkProvider = new UrlJwkProvider (url , connectTimeout , readTimeout );
273275 assertThat (urlJwkProvider .proxy , is (nullValue ()));
274276
277+ urlJwkProvider .setCachedJwks (null );
275278 Jwk jwk = urlJwkProvider .get ("NkJCQzIyQzRBMEU4NjhGNUU4MzU4RkY0M0ZDQzkwOUQ0Q0VGNUMwQg" );
276279 assertNotNull (jwk );
277280 assertThat (mockFactory .urlUsed .get (), is (url ));
@@ -280,6 +283,8 @@ public Object answer(InvocationOnMock invocation) throws Throwable {
280283 // Test creation: custom headers
281284 UrlJwkProvider urlJwkProviderWithHeaders = new UrlJwkProvider (url , connectTimeout , readTimeout , null ,
282285 Collections .singletonMap ("Accept" , "application/jwks-set+json" ));
286+
287+ urlJwkProvider .setCachedJwks (null ); // <-- force fetch
283288 Jwk hJwk = urlJwkProviderWithHeaders .get ("NkJCQzIyQzRBMEU4NjhGNUU4MzU4RkY0M0ZDQzkwOUQ0Q0VGNUMwQg" );
284289 assertNotNull (hJwk );
285290 assertThat (mockFactory .urlUsed .get (), is (url ));
@@ -291,6 +296,7 @@ public Object answer(InvocationOnMock invocation) throws Throwable {
291296 UrlJwkProvider pUrlJwkProvider = new UrlJwkProvider (pUrl , connectTimeout , readTimeout , proxy );
292297 assertThat (pUrlJwkProvider .proxy , is (proxy ));
293298
299+ urlJwkProvider .setCachedJwks (null ); // <-- force fetch
294300 Jwk pJwk = pUrlJwkProvider .get ("NkJCQzIyQzRBMEU4NjhGNUU4MzU4RkY0M0ZDQzkwOUQ0Q0VGNUMwQg" );
295301 assertNotNull (pJwk );
296302 assertThat (mockFactory .urlUsed .get (), is (pUrl ));
@@ -317,6 +323,7 @@ public Object answer(InvocationOnMock invocation) throws Throwable {
317323 try {
318324 IOException exception = mock (IOException .class );
319325 when (urlConnection .getInputStream ()).thenThrow (exception );
326+ urlJwkProvider .setCachedJwks (null ); // <-- force fetch
320327 urlJwkProvider .get ("NkJCQzIyQzRBMEU4NjhGNUU4MzU4RkY0M0ZDQzkwOUQ0Q0VGNUMwQg" );
321328 } catch (Exception e ) {
322329 capturedException = e ;
@@ -328,4 +335,73 @@ public Object answer(InvocationOnMock invocation) throws Throwable {
328335 //release
329336 mockFactory .clear ();
330337 }
338+
339+ @ Test
340+ public void shouldCacheJwksAfterFirstFetch () throws Exception {
341+ URL url = getClass ().getResource ("/jwks.json" );
342+ UrlJwkProvider provider = spy (new UrlJwkProvider (url ));
343+
344+ Jwk firstJwk = provider .get (KID );
345+ assertNotNull (firstJwk );
346+
347+ Jwk secondJwk = provider .get (KID );
348+ assertNotNull (secondJwk );
349+
350+ verify (provider , times (1 )).getAll ();
351+ }
352+
353+ @ Test
354+ public void shouldRefreshCacheIfKeyNotFound () throws Exception {
355+ URL url = getClass ().getResource ("/jwks.json" );
356+ UrlJwkProvider provider = spy (new UrlJwkProvider (url ));
357+
358+ // Pre-load a cache with an invalid key (simulate wrong cache)
359+ Map <String , Object > jwkValues = new HashMap <>();
360+ jwkValues .put ("kid" , "wrong-kid" );
361+ jwkValues .put ("kty" , "RSA" );
362+ jwkValues .put ("alg" , "RS256" );
363+ jwkValues .put ("use" , "sig" );
364+ jwkValues .put ("n" , "test-modulus" );
365+ jwkValues .put ("e" , "AQAB" );
366+
367+ List <Jwk > wrongJwks = Collections .singletonList (Jwk .fromValues (jwkValues ));
368+
369+ provider .setCachedJwks (wrongJwks );
370+
371+ // Call with correct key - should miss cache, then refresh
372+ Jwk actualJwk = provider .get (KID );
373+ assertNotNull (actualJwk );
374+
375+ verify (provider , times (1 )).getAll ();
376+ }
377+
378+ @ Test
379+ public void shouldFailIfKeyNotFoundEvenAfterRefresh () throws Exception {
380+ expectedException .expect (SigningKeyNotFoundException .class );
381+
382+ URL url = getClass ().getResource ("/jwks.json" );
383+ UrlJwkProvider provider = spy (new UrlJwkProvider (url ));
384+
385+ // Set empty cache
386+ provider .setCachedJwks (Collections .emptyList ());
387+
388+ // Call with missing key — should refresh, but still fail
389+ provider .get ("wrong-kid" );
390+
391+ verify (provider , times (1 )).getAll (); // Only one refresh
392+ }
393+
394+ @ Test
395+ public void shouldFetchIfCacheIsNull () throws Exception {
396+ UrlJwkProvider provider = spy (new UrlJwkProvider (getClass ().getResource ("/jwks.json" )));
397+
398+ // Ensure cache is unset (null)
399+ provider .setCachedJwks (null );
400+
401+ Jwk jwk = provider .get (KID );
402+ assertNotNull (jwk );
403+
404+ verify (provider , atLeastOnce ()).getAll (); // Should definitely be called
405+ }
406+
331407}
0 commit comments