Skip to content

Commit 1e21c7e

Browse files
authored
Feat: Added Cache Implementation (#221)
2 parents 1f31987 + 09f09b9 commit 1e21c7e

File tree

2 files changed

+128
-10
lines changed

2 files changed

+128
-10
lines changed

src/main/java/com/auth0/jwk/UrlJwkProvider.java

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@
66
import java.io.IOException;
77
import java.io.InputStream;
88
import java.net.*;
9-
import java.util.ArrayList;
10-
import java.util.Collections;
11-
import java.util.List;
12-
import java.util.Map;
9+
import java.util.*;
10+
import java.util.concurrent.atomic.AtomicReference;
1311

1412
/**
1513
* Jwk provider that loads them from a {@link URL}
@@ -20,6 +18,8 @@ public class UrlJwkProvider implements JwkProvider {
2018
@VisibleForTesting
2119
static final String WELL_KNOWN_JWKS_PATH = "/.well-known/jwks.json";
2220

21+
private final AtomicReference<List<Jwk>> cachedJwks = new AtomicReference<>();
22+
2323
final URL url;
2424
final Proxy proxy;
2525
final Map<String, String> headers;
@@ -103,6 +103,11 @@ public UrlJwkProvider(String domain) {
103103
this(urlForDomain(domain));
104104
}
105105

106+
@VisibleForTesting
107+
void setCachedJwks(List<Jwk> jwks) {
108+
this.cachedJwks.set(jwks);
109+
}
110+
106111
static URL urlForDomain(String domain) {
107112
Util.checkArgument(!Util.isNullOrEmpty(domain), "A domain is required");
108113

@@ -158,19 +163,56 @@ public List<Jwk> getAll() throws SigningKeyNotFoundException {
158163
return jwks;
159164
}
160165

161-
@Override
162-
public Jwk get(String keyId) throws JwkException {
163-
final List<Jwk> jwks = getAll();
166+
private List<Jwk> getCachedJwks() throws JwkException {
167+
List<Jwk> jwks = cachedJwks.get();
168+
if (jwks == null) {
169+
synchronized (this) {
170+
jwks = cachedJwks.get();
171+
if (jwks == null) {
172+
jwks = getAll();
173+
cachedJwks.set(jwks);
174+
}
175+
}
176+
}
177+
return jwks;
178+
}
179+
180+
private Optional<Jwk> findKey(String keyId) throws JwkException {
181+
List<Jwk> jwks = getCachedJwks();
182+
Optional<Jwk> foundKey = searchKey(jwks, keyId);
183+
if (foundKey.isPresent()) {
184+
return foundKey;
185+
}
186+
187+
// Key not found — refreshing JWKS from remote
188+
synchronized (this) {
189+
List<Jwk> freshJwks = getAll();
190+
cachedJwks.set(freshJwks);
191+
192+
return searchKey(freshJwks, keyId);
193+
}
194+
}
195+
196+
private Optional<Jwk> searchKey(List<Jwk> jwks, String keyId) {
164197
if (keyId == null && jwks.size() == 1) {
165-
return jwks.get(0);
198+
return Optional.of(jwks.get(0));
166199
}
167200
if (keyId != null) {
168201
for (Jwk jwk : jwks) {
169202
if (keyId.equals(jwk.getId())) {
170-
return jwk;
203+
return Optional.of(jwk);
171204
}
172205
}
173206
}
174-
throw new SigningKeyNotFoundException("No key found in " + url.toString() + " with kid " + keyId, null);
207+
return Optional.empty();
208+
}
209+
210+
@Override
211+
public Jwk get(String keyId) throws JwkException {
212+
213+
return findKey(keyId).orElseThrow(() ->
214+
new SigningKeyNotFoundException("No key found in " + url.toString() + " with kid " + keyId, null)
215+
);
216+
175217
}
176218
}

src/test/java/com/auth0/jwk/UrlJwkProviderTest.java

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
import java.lang.ref.WeakReference;
1313
import java.net.*;
1414
import java.util.Collections;
15+
import java.util.HashMap;
1516
import java.util.List;
17+
import java.util.Map;
1618

1719
import static com.auth0.jwk.UrlJwkProvider.WELL_KNOWN_JWKS_PATH;
1820
import static org.hamcrest.Matchers.*;
@@ -272,6 +274,7 @@ public Object answer(InvocationOnMock invocation) throws Throwable {
272274
UrlJwkProvider urlJwkProvider = new UrlJwkProvider(url, connectTimeout, readTimeout);
273275
assertThat(urlJwkProvider.proxy, is(nullValue()));
274276

277+
urlJwkProvider.setCachedJwks(null);
275278
Jwk jwk = urlJwkProvider.get("NkJCQzIyQzRBMEU4NjhGNUU4MzU4RkY0M0ZDQzkwOUQ0Q0VGNUMwQg");
276279
assertNotNull(jwk);
277280
assertThat(mockFactory.urlUsed.get(), is(url));
@@ -280,6 +283,8 @@ public Object answer(InvocationOnMock invocation) throws Throwable {
280283
// Test creation: custom headers
281284
UrlJwkProvider urlJwkProviderWithHeaders = new UrlJwkProvider(url, connectTimeout, readTimeout, null,
282285
Collections.singletonMap("Accept", "application/jwks-set+json"));
286+
287+
urlJwkProvider.setCachedJwks(null); // <-- force fetch
283288
Jwk hJwk = urlJwkProviderWithHeaders.get("NkJCQzIyQzRBMEU4NjhGNUU4MzU4RkY0M0ZDQzkwOUQ0Q0VGNUMwQg");
284289
assertNotNull(hJwk);
285290
assertThat(mockFactory.urlUsed.get(), is(url));
@@ -291,6 +296,7 @@ public Object answer(InvocationOnMock invocation) throws Throwable {
291296
UrlJwkProvider pUrlJwkProvider = new UrlJwkProvider(pUrl, connectTimeout, readTimeout, proxy);
292297
assertThat(pUrlJwkProvider.proxy, is(proxy));
293298

299+
urlJwkProvider.setCachedJwks(null); // <-- force fetch
294300
Jwk pJwk = pUrlJwkProvider.get("NkJCQzIyQzRBMEU4NjhGNUU4MzU4RkY0M0ZDQzkwOUQ0Q0VGNUMwQg");
295301
assertNotNull(pJwk);
296302
assertThat(mockFactory.urlUsed.get(), is(pUrl));
@@ -317,6 +323,7 @@ public Object answer(InvocationOnMock invocation) throws Throwable {
317323
try {
318324
IOException exception = mock(IOException.class);
319325
when(urlConnection.getInputStream()).thenThrow(exception);
326+
urlJwkProvider.setCachedJwks(null); // <-- force fetch
320327
urlJwkProvider.get("NkJCQzIyQzRBMEU4NjhGNUU4MzU4RkY0M0ZDQzkwOUQ0Q0VGNUMwQg");
321328
} catch (Exception e) {
322329
capturedException = e;
@@ -328,4 +335,73 @@ public Object answer(InvocationOnMock invocation) throws Throwable {
328335
//release
329336
mockFactory.clear();
330337
}
338+
339+
@Test
340+
public void shouldCacheJwksAfterFirstFetch() throws Exception {
341+
URL url = getClass().getResource("/jwks.json");
342+
UrlJwkProvider provider = spy(new UrlJwkProvider(url));
343+
344+
Jwk firstJwk = provider.get(KID);
345+
assertNotNull(firstJwk);
346+
347+
Jwk secondJwk = provider.get(KID);
348+
assertNotNull(secondJwk);
349+
350+
verify(provider, times(1)).getAll();
351+
}
352+
353+
@Test
354+
public void shouldRefreshCacheIfKeyNotFound() throws Exception {
355+
URL url = getClass().getResource("/jwks.json");
356+
UrlJwkProvider provider = spy(new UrlJwkProvider(url));
357+
358+
// Pre-load a cache with an invalid key (simulate wrong cache)
359+
Map<String, Object> jwkValues = new HashMap<>();
360+
jwkValues.put("kid", "wrong-kid");
361+
jwkValues.put("kty", "RSA");
362+
jwkValues.put("alg", "RS256");
363+
jwkValues.put("use", "sig");
364+
jwkValues.put("n", "test-modulus");
365+
jwkValues.put("e", "AQAB");
366+
367+
List<Jwk> wrongJwks = Collections.singletonList(Jwk.fromValues(jwkValues));
368+
369+
provider.setCachedJwks(wrongJwks);
370+
371+
// Call with correct key - should miss cache, then refresh
372+
Jwk actualJwk = provider.get(KID);
373+
assertNotNull(actualJwk);
374+
375+
verify(provider, times(1)).getAll();
376+
}
377+
378+
@Test
379+
public void shouldFailIfKeyNotFoundEvenAfterRefresh() throws Exception {
380+
expectedException.expect(SigningKeyNotFoundException.class);
381+
382+
URL url = getClass().getResource("/jwks.json");
383+
UrlJwkProvider provider = spy(new UrlJwkProvider(url));
384+
385+
// Set empty cache
386+
provider.setCachedJwks(Collections.emptyList());
387+
388+
// Call with missing key — should refresh, but still fail
389+
provider.get("wrong-kid");
390+
391+
verify(provider, times(1)).getAll(); // Only one refresh
392+
}
393+
394+
@Test
395+
public void shouldFetchIfCacheIsNull() throws Exception {
396+
UrlJwkProvider provider = spy(new UrlJwkProvider(getClass().getResource("/jwks.json")));
397+
398+
// Ensure cache is unset (null)
399+
provider.setCachedJwks(null);
400+
401+
Jwk jwk = provider.get(KID);
402+
assertNotNull(jwk);
403+
404+
verify(provider, atLeastOnce()).getAll(); // Should definitely be called
405+
}
406+
331407
}

0 commit comments

Comments
 (0)