Skip to content
This repository was archived by the owner on May 12, 2026. It is now read-only.

Commit 664754e

Browse files
igorbernstein2lesv
authored andcommitted
Add caching for JWT tokens (#151)
* Add caching for JWT tokens * code style * Add tests * refresh token 5 minutes early
1 parent 2f10be4 commit 664754e

3 files changed

Lines changed: 159 additions & 2 deletions

File tree

oauth2_http/java/com/google/auth/oauth2/ServiceAccountJwtAccessCredentials.java

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@
4747
import com.google.common.annotations.VisibleForTesting;
4848
import com.google.common.base.MoreObjects;
4949

50+
import com.google.common.base.Throwables;
51+
import com.google.common.base.Ticker;
52+
import com.google.common.cache.CacheBuilder;
53+
import com.google.common.cache.CacheLoader;
54+
import com.google.common.cache.LoadingCache;
55+
import com.google.common.util.concurrent.UncheckedExecutionException;
5056
import java.io.IOException;
5157
import java.io.InputStream;
5258
import java.io.ObjectInputStream;
@@ -61,7 +67,9 @@
6167
import java.util.List;
6268
import java.util.Map;
6369
import java.util.Objects;
70+
import java.util.concurrent.ExecutionException;
6471
import java.util.concurrent.Executor;
72+
import java.util.concurrent.TimeUnit;
6573

6674
/**
6775
* Service Account credentials for calling Google APIs using a JWT directly for access.
@@ -73,12 +81,16 @@ public class ServiceAccountJwtAccessCredentials extends Credentials
7381

7482
private static final long serialVersionUID = -7274955171379494197L;
7583
static final String JWT_ACCESS_PREFIX = OAuth2Utils.BEARER_PREFIX;
84+
@VisibleForTesting
85+
static final long LIFE_SPAN_SECS = TimeUnit.HOURS.toSeconds(1);
7686

7787
private final String clientId;
7888
private final String clientEmail;
7989
private final PrivateKey privateKey;
8090
private final String privateKeyId;
8191
private final URI defaultAudience;
92+
private transient LoadingCache<URI, String> tokenCache;
93+
8294

8395
// Until we expose this to the users it can remain transient and non-serializable
8496
@VisibleForTesting
@@ -119,6 +131,7 @@ public ServiceAccountJwtAccessCredentials(String clientId, String clientEmail,
119131
this.privateKey = Preconditions.checkNotNull(privateKey);
120132
this.privateKeyId = privateKeyId;
121133
this.defaultAudience = defaultAudience;
134+
this.tokenCache = createCache();
122135
}
123136

124137
/**
@@ -228,6 +241,28 @@ public static ServiceAccountJwtAccessCredentials fromStream(InputStream credenti
228241
+ " Expecting '%s'.", fileType, SERVICE_ACCOUNT_FILE_TYPE));
229242
}
230243

244+
private LoadingCache<URI, String> createCache() {
245+
return CacheBuilder.newBuilder()
246+
.maximumSize(100)
247+
.expireAfterWrite(LIFE_SPAN_SECS - 300, TimeUnit.SECONDS)
248+
.ticker(
249+
new Ticker() {
250+
@Override
251+
public long read() {
252+
return TimeUnit.MILLISECONDS.toNanos(clock.currentTimeMillis());
253+
}
254+
}
255+
)
256+
.build(
257+
new CacheLoader<URI, String>() {
258+
@Override
259+
public String load(URI key) throws Exception {
260+
return generateJwtAccess(key);
261+
}
262+
}
263+
);
264+
}
265+
231266
@Override
232267
public String getAuthenticationType() {
233268
return "JWTAccess";
@@ -275,10 +310,25 @@ public Map<String, List<String>> getRequestMetadata(URI uri) throws IOException
275310
*/
276311
@Override
277312
public void refresh() {
313+
tokenCache.invalidateAll();
278314
}
279315

280316
private String getJwtAccess(URI uri) throws IOException {
317+
try {
318+
return tokenCache.get(uri);
319+
} catch (ExecutionException e) {
320+
Throwables.propagateIfPossible(e.getCause(), IOException.class);
321+
// Should never happen
322+
throw new IllegalStateException("generateJwtAccess threw an unexpected checked exception", e.getCause());
323+
324+
} catch (UncheckedExecutionException e) {
325+
Throwables.propagateIfPossible(e);
326+
// Should never happen
327+
throw new IllegalStateException("generateJwtAccess threw an unchecked exception that couldn't be rethrown", e);
328+
}
329+
}
281330

331+
private String generateJwtAccess(URI uri) throws IOException {
282332
JsonWebSignature.Header header = new JsonWebSignature.Header();
283333
header.setAlgorithm("RS256");
284334
header.setType("JWT");
@@ -291,7 +341,7 @@ private String getJwtAccess(URI uri) throws IOException {
291341
payload.setSubject(clientEmail);
292342
payload.setAudience(uri.toString());
293343
payload.setIssuedAtTimeSeconds(currentTime / 1000);
294-
payload.setExpirationTimeSeconds(currentTime / 1000 + 3600);
344+
payload.setExpirationTimeSeconds(currentTime / 1000 + LIFE_SPAN_SECS);
295345

296346
JsonFactory jsonFactory = OAuth2Utils.JSON_FACTORY;
297347

@@ -369,6 +419,7 @@ public boolean equals(Object obj) {
369419
private void readObject(ObjectInputStream input) throws IOException, ClassNotFoundException {
370420
input.defaultReadObject();
371421
clock = Clock.SYSTEM;
422+
tokenCache = createCache();
372423
}
373424

374425
public static Builder newBuilder() {

oauth2_http/javatests/com/google/auth/oauth2/ServiceAccountJwtAccessCredentialsTest.java

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import static org.junit.Assert.assertArrayEquals;
3535
import static org.junit.Assert.assertEquals;
3636
import static org.junit.Assert.assertFalse;
37+
import static org.junit.Assert.assertNotEquals;
3738
import static org.junit.Assert.assertNotNull;
3839
import static org.junit.Assert.assertNull;
3940
import static org.junit.Assert.assertSame;
@@ -45,6 +46,7 @@
4546
import com.google.api.client.json.webtoken.JsonWebSignature;
4647
import com.google.api.client.util.Clock;
4748
import com.google.auth.Credentials;
49+
import com.google.auth.TestClock;
4850
import com.google.auth.http.AuthHttpConstants;
4951
import com.google.auth.oauth2.GoogleCredentialsTest.MockHttpTransportFactory;
5052

@@ -62,6 +64,7 @@
6264
import java.security.SignatureException;
6365
import java.util.List;
6466
import java.util.Map;
67+
import java.util.concurrent.TimeUnit;
6568

6669
/**
6770
* Test case for {@link ServiceAccountCredentials}.
@@ -224,6 +227,54 @@ public void getRequestMetadata_blocking_noURI_throws() throws IOException {
224227
}
225228
}
226229

230+
@Test
231+
public void getRequestMetadata_blocking_cached() throws IOException {
232+
TestClock testClock = new TestClock();
233+
234+
PrivateKey privateKey = ServiceAccountCredentials.privateKeyFromPkcs8(SA_PRIVATE_KEY_PKCS8);
235+
ServiceAccountJwtAccessCredentials credentials = ServiceAccountJwtAccessCredentials.newBuilder()
236+
.setClientId(SA_CLIENT_ID)
237+
.setClientEmail(SA_CLIENT_EMAIL)
238+
.setPrivateKey(privateKey)
239+
.setPrivateKeyId(SA_PRIVATE_KEY_ID)
240+
.build();
241+
credentials.clock = testClock;
242+
243+
Map<String, List<String>> metadata1 = credentials.getRequestMetadata(CALL_URI);
244+
245+
// Fast forward time a little
246+
long lifeSpanMs = TimeUnit.SECONDS.toMillis(10);
247+
testClock.setCurrentTime(lifeSpanMs);
248+
249+
Map<String, List<String>> metadata2 = credentials.getRequestMetadata(CALL_URI);
250+
251+
assertEquals(metadata1, metadata2);
252+
}
253+
254+
@Test
255+
public void getRequestMetadata_blocking_cache_expired() throws IOException {
256+
TestClock testClock = new TestClock();
257+
258+
PrivateKey privateKey = ServiceAccountCredentials.privateKeyFromPkcs8(SA_PRIVATE_KEY_PKCS8);
259+
ServiceAccountJwtAccessCredentials credentials = ServiceAccountJwtAccessCredentials.newBuilder()
260+
.setClientId(SA_CLIENT_ID)
261+
.setClientEmail(SA_CLIENT_EMAIL)
262+
.setPrivateKey(privateKey)
263+
.setPrivateKeyId(SA_PRIVATE_KEY_ID)
264+
.build();
265+
credentials.clock = testClock;
266+
267+
Map<String, List<String>> metadata1 = credentials.getRequestMetadata(CALL_URI);
268+
269+
// Fast forward time past the expiration
270+
long lifeSpanMs = TimeUnit.SECONDS.toMillis(ServiceAccountJwtAccessCredentials.LIFE_SPAN_SECS);
271+
testClock.setCurrentTime(lifeSpanMs);
272+
273+
Map<String, List<String>> metadata2 = credentials.getRequestMetadata(CALL_URI);
274+
275+
assertNotEquals(metadata1, metadata2);
276+
}
277+
227278
@Test
228279
public void getRequestMetadata_async_hasJwtAccess() throws IOException {
229280
PrivateKey privateKey = ServiceAccountCredentials.privateKeyFromPkcs8(SA_PRIVATE_KEY_PKCS8);
@@ -278,6 +329,60 @@ public void getRequestMetadata_async_noURI_exception() throws IOException {
278329
assertNotNull(callback.exception);
279330
}
280331

332+
@Test
333+
public void getRequestMetadata_async_cache_expired() throws IOException {
334+
TestClock testClock = new TestClock();
335+
336+
PrivateKey privateKey = ServiceAccountCredentials.privateKeyFromPkcs8(SA_PRIVATE_KEY_PKCS8);
337+
ServiceAccountJwtAccessCredentials credentials = ServiceAccountJwtAccessCredentials.newBuilder()
338+
.setClientId(SA_CLIENT_ID)
339+
.setClientEmail(SA_CLIENT_EMAIL)
340+
.setPrivateKey(privateKey)
341+
.setPrivateKeyId(SA_PRIVATE_KEY_ID)
342+
.build();
343+
credentials.clock = testClock;
344+
MockExecutor executor = new MockExecutor();
345+
346+
MockRequestMetadataCallback callback1 = new MockRequestMetadataCallback();
347+
credentials.getRequestMetadata(CALL_URI, executor, callback1);
348+
349+
// Fast forward time past the expiration
350+
long lifeSpanMs = TimeUnit.SECONDS.toMillis(ServiceAccountJwtAccessCredentials.LIFE_SPAN_SECS);
351+
testClock.setCurrentTime(lifeSpanMs);
352+
353+
MockRequestMetadataCallback callback2 = new MockRequestMetadataCallback();
354+
credentials.getRequestMetadata(CALL_URI, executor, callback2);
355+
356+
assertNotEquals(callback1.metadata, callback2.metadata);
357+
}
358+
359+
@Test
360+
public void getRequestMetadata_async_cached() throws IOException {
361+
TestClock testClock = new TestClock();
362+
363+
PrivateKey privateKey = ServiceAccountCredentials.privateKeyFromPkcs8(SA_PRIVATE_KEY_PKCS8);
364+
ServiceAccountJwtAccessCredentials credentials = ServiceAccountJwtAccessCredentials.newBuilder()
365+
.setClientId(SA_CLIENT_ID)
366+
.setClientEmail(SA_CLIENT_EMAIL)
367+
.setPrivateKey(privateKey)
368+
.setPrivateKeyId(SA_PRIVATE_KEY_ID)
369+
.build();
370+
credentials.clock = testClock;
371+
MockExecutor executor = new MockExecutor();
372+
373+
MockRequestMetadataCallback callback1 = new MockRequestMetadataCallback();
374+
credentials.getRequestMetadata(CALL_URI, executor, callback1);
375+
376+
// Fast forward time a little
377+
long lifeSpanMs = TimeUnit.SECONDS.toMillis(10);
378+
testClock.setCurrentTime(lifeSpanMs);
379+
380+
MockRequestMetadataCallback callback2 = new MockRequestMetadataCallback();
381+
credentials.getRequestMetadata(CALL_URI, executor, callback2);
382+
383+
assertEquals(callback1.metadata, callback2.metadata);
384+
}
385+
281386
@Test
282387
public void getAccount_sameAs() throws IOException {
283388
PrivateKey privateKey = ServiceAccountCredentials.privateKeyFromPkcs8(SA_PRIVATE_KEY_PKCS8);
@@ -466,6 +571,7 @@ public void serialize() throws IOException, ClassNotFoundException {
466571
.build();
467572
ServiceAccountJwtAccessCredentials deserializedCredentials =
468573
serializeAndDeserialize(credentials);
574+
verifyJwtAccess(deserializedCredentials.getRequestMetadata(), SA_CLIENT_EMAIL, CALL_URI, SA_PRIVATE_KEY_ID);
469575
assertEquals(credentials, deserializedCredentials);
470576
assertEquals(credentials.hashCode(), deserializedCredentials.hashCode());
471577
assertEquals(credentials.toString(), deserializedCredentials.toString());

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
4747
<java.version>1.6</java.version>
4848
<project.google.http.version>1.19.0</project.google.http.version>
49-
<project.junit.version>4.8.2</project.junit.version>
49+
<project.junit.version>4.12</project.junit.version>
5050
<project.guava.version>19.0</project.guava.version>
5151
<project.appengine.version>1.9.34</project.appengine.version>
5252
</properties>

0 commit comments

Comments
 (0)