Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 52 additions & 10 deletions src/main/java/com/auth0/jwk/UrlJwkProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
import java.io.IOException;
import java.io.InputStream;
import java.net.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.concurrent.atomic.AtomicReference;

/**
* Jwk provider that loads them from a {@link URL}
Expand All @@ -20,6 +18,8 @@ public class UrlJwkProvider implements JwkProvider {
@VisibleForTesting
static final String WELL_KNOWN_JWKS_PATH = "/.well-known/jwks.json";

private final AtomicReference<List<Jwk>> cachedJwks = new AtomicReference<>();

final URL url;
final Proxy proxy;
final Map<String, String> headers;
Expand Down Expand Up @@ -103,6 +103,11 @@ public UrlJwkProvider(String domain) {
this(urlForDomain(domain));
}

@VisibleForTesting
void setCachedJwks(List<Jwk> jwks) {
this.cachedJwks.set(jwks);
}

static URL urlForDomain(String domain) {
Util.checkArgument(!Util.isNullOrEmpty(domain), "A domain is required");

Expand Down Expand Up @@ -158,19 +163,56 @@ public List<Jwk> getAll() throws SigningKeyNotFoundException {
return jwks;
}

@Override
public Jwk get(String keyId) throws JwkException {
final List<Jwk> jwks = getAll();
private List<Jwk> getCachedJwks() throws JwkException {
List<Jwk> jwks = cachedJwks.get();
if (jwks == null) {
synchronized (this) {
jwks = cachedJwks.get();
if (jwks == null) {
jwks = getAll();
cachedJwks.set(jwks);
}
}
}
return jwks;
}

private Optional<Jwk> findKey(String keyId) throws JwkException {
List<Jwk> jwks = getCachedJwks();
Optional<Jwk> foundKey = searchKey(jwks, keyId);
if (foundKey.isPresent()) {
return foundKey;
}

// Key not found — refreshing JWKS from remote
synchronized (this) {
List<Jwk> freshJwks = getAll();
cachedJwks.set(freshJwks);

return searchKey(freshJwks, keyId);
}
}

private Optional<Jwk> searchKey(List<Jwk> jwks, String keyId) {
if (keyId == null && jwks.size() == 1) {
return jwks.get(0);
return Optional.of(jwks.get(0));
}
if (keyId != null) {
for (Jwk jwk : jwks) {
if (keyId.equals(jwk.getId())) {
return jwk;
return Optional.of(jwk);
}
}
}
throw new SigningKeyNotFoundException("No key found in " + url.toString() + " with kid " + keyId, null);
return Optional.empty();
}

@Override
public Jwk get(String keyId) throws JwkException {

return findKey(keyId).orElseThrow(() ->
new SigningKeyNotFoundException("No key found in " + url.toString() + " with kid " + keyId, null)
);

}
}
76 changes: 76 additions & 0 deletions src/test/java/com/auth0/jwk/UrlJwkProviderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import java.lang.ref.WeakReference;
import java.net.*;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static com.auth0.jwk.UrlJwkProvider.WELL_KNOWN_JWKS_PATH;
import static org.hamcrest.Matchers.*;
Expand Down Expand Up @@ -272,6 +274,7 @@ public Object answer(InvocationOnMock invocation) throws Throwable {
UrlJwkProvider urlJwkProvider = new UrlJwkProvider(url, connectTimeout, readTimeout);
assertThat(urlJwkProvider.proxy, is(nullValue()));

urlJwkProvider.setCachedJwks(null);
Jwk jwk = urlJwkProvider.get("NkJCQzIyQzRBMEU4NjhGNUU4MzU4RkY0M0ZDQzkwOUQ0Q0VGNUMwQg");
assertNotNull(jwk);
assertThat(mockFactory.urlUsed.get(), is(url));
Expand All @@ -280,6 +283,8 @@ public Object answer(InvocationOnMock invocation) throws Throwable {
// Test creation: custom headers
UrlJwkProvider urlJwkProviderWithHeaders = new UrlJwkProvider(url, connectTimeout, readTimeout, null,
Collections.singletonMap("Accept", "application/jwks-set+json"));

urlJwkProvider.setCachedJwks(null); // <-- force fetch
Jwk hJwk = urlJwkProviderWithHeaders.get("NkJCQzIyQzRBMEU4NjhGNUU4MzU4RkY0M0ZDQzkwOUQ0Q0VGNUMwQg");
assertNotNull(hJwk);
assertThat(mockFactory.urlUsed.get(), is(url));
Expand All @@ -291,6 +296,7 @@ public Object answer(InvocationOnMock invocation) throws Throwable {
UrlJwkProvider pUrlJwkProvider = new UrlJwkProvider(pUrl, connectTimeout, readTimeout, proxy);
assertThat(pUrlJwkProvider.proxy, is(proxy));

urlJwkProvider.setCachedJwks(null); // <-- force fetch
Jwk pJwk = pUrlJwkProvider.get("NkJCQzIyQzRBMEU4NjhGNUU4MzU4RkY0M0ZDQzkwOUQ0Q0VGNUMwQg");
assertNotNull(pJwk);
assertThat(mockFactory.urlUsed.get(), is(pUrl));
Expand All @@ -317,6 +323,7 @@ public Object answer(InvocationOnMock invocation) throws Throwable {
try {
IOException exception = mock(IOException.class);
when(urlConnection.getInputStream()).thenThrow(exception);
urlJwkProvider.setCachedJwks(null); // <-- force fetch
urlJwkProvider.get("NkJCQzIyQzRBMEU4NjhGNUU4MzU4RkY0M0ZDQzkwOUQ0Q0VGNUMwQg");
} catch (Exception e) {
capturedException = e;
Expand All @@ -328,4 +335,73 @@ public Object answer(InvocationOnMock invocation) throws Throwable {
//release
mockFactory.clear();
}

@Test
public void shouldCacheJwksAfterFirstFetch() throws Exception {
URL url = getClass().getResource("/jwks.json");
UrlJwkProvider provider = spy(new UrlJwkProvider(url));

Jwk firstJwk = provider.get(KID);
assertNotNull(firstJwk);

Jwk secondJwk = provider.get(KID);
assertNotNull(secondJwk);

verify(provider, times(1)).getAll();
}

@Test
public void shouldRefreshCacheIfKeyNotFound() throws Exception {
URL url = getClass().getResource("/jwks.json");
UrlJwkProvider provider = spy(new UrlJwkProvider(url));

// Pre-load a cache with an invalid key (simulate wrong cache)
Map<String, Object> jwkValues = new HashMap<>();
jwkValues.put("kid", "wrong-kid");
jwkValues.put("kty", "RSA");
jwkValues.put("alg", "RS256");
jwkValues.put("use", "sig");
jwkValues.put("n", "test-modulus");
jwkValues.put("e", "AQAB");

List<Jwk> wrongJwks = Collections.singletonList(Jwk.fromValues(jwkValues));

provider.setCachedJwks(wrongJwks);

// Call with correct key - should miss cache, then refresh
Jwk actualJwk = provider.get(KID);
assertNotNull(actualJwk);

verify(provider, times(1)).getAll();
}

@Test
public void shouldFailIfKeyNotFoundEvenAfterRefresh() throws Exception {
expectedException.expect(SigningKeyNotFoundException.class);

URL url = getClass().getResource("/jwks.json");
UrlJwkProvider provider = spy(new UrlJwkProvider(url));

// Set empty cache
provider.setCachedJwks(Collections.emptyList());

// Call with missing key — should refresh, but still fail
provider.get("wrong-kid");

verify(provider, times(1)).getAll(); // Only one refresh
}

@Test
public void shouldFetchIfCacheIsNull() throws Exception {
UrlJwkProvider provider = spy(new UrlJwkProvider(getClass().getResource("/jwks.json")));

// Ensure cache is unset (null)
provider.setCachedJwks(null);

Jwk jwk = provider.get(KID);
assertNotNull(jwk);

verify(provider, atLeastOnce()).getAll(); // Should definitely be called
}

}
Loading