Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ public class ClientSideCredentialAccessBoundaryFactory {
private final Duration refreshMargin;
private RefreshTask refreshTask;
private final Object refreshLock = new byte[0];
private IntermediateCredentials intermediateCredentials = null;
private volatile IntermediateCredentials intermediateCredentials = null;
private final Clock clock;
private final CelCompiler celCompiler;

Expand Down Expand Up @@ -234,8 +234,26 @@ void refreshCredentialsIfRequired() throws IOException {
return;
}

// If a refresh is required, create or retrieve the refresh task.
RefreshTask currentRefreshTask = getOrCreateRefreshTask();
RefreshTask currentRefreshTask;
// Synchronize both the decision to refresh and the task creation to prevent a race
// condition. Without this, multiple threads might concurrently determine a refresh is
// needed (e.g., in ASYNC mode) and return ASYNC. The first thread would start the refresh
// task. If that task completes quickly and clears the `refreshTask` reference, a subsequent
// thread that also determined a refresh was needed would then see `refreshTask` as null
// and incorrectly start a second, redundant refresh task.
synchronized (refreshLock) {
// Re-evaluate the refresh type under the lock, as another thread might have completed
// the refresh while this thread was waiting for the lock.
refreshType = determineRefreshType();

if (refreshType == RefreshType.NONE) {
// No refresh needed, token is still valid.
return;
}

// If a refresh is required, create or retrieve the refresh task.
currentRefreshTask = getOrCreateRefreshTask();
}

// Handle the refresh based on the determined refresh type.
switch (refreshType) {
Expand Down Expand Up @@ -283,16 +301,14 @@ void refreshCredentialsIfRequired() throws IOException {
}
}

// Assumes the caller holds refreshLock.
private RefreshType determineRefreshType() {
AccessToken intermediateAccessToken;
synchronized (refreshLock) {
if (intermediateCredentials == null
|| intermediateCredentials.intermediateAccessToken == null) {
// A blocking refresh is needed if the intermediate access token doesn't exist.
return RefreshType.BLOCKING;
}
intermediateAccessToken = intermediateCredentials.intermediateAccessToken;
if (intermediateCredentials == null
|| intermediateCredentials.intermediateAccessToken == null) {
// A blocking refresh is needed if the intermediate access token doesn't exist.
return RefreshType.BLOCKING;
}
AccessToken intermediateAccessToken = intermediateCredentials.intermediateAccessToken;

Date expirationTime = intermediateAccessToken.getExpirationTime();
if (expirationTime == null) {
Expand Down Expand Up @@ -322,23 +338,22 @@ private RefreshType determineRefreshType() {
* responsibility of the caller to execute it. The task will clear the single flight slot upon
* completion.
*/
// Assumes the caller holds refreshLock.
private RefreshTask getOrCreateRefreshTask() {
synchronized (refreshLock) {
if (refreshTask != null) {
// An existing refresh task is already in progress. Return a NEW RefreshTask instance with
// the existing task, but set isNew to false. This indicates to the caller that a new
// refresh task was NOT created.
return new RefreshTask(refreshTask.task, false);
}
if (refreshTask != null) {
// An existing refresh task is already in progress. Return a NEW RefreshTask instance with
// the existing task, but set isNew to false. This indicates to the caller that a new
// refresh task was NOT created.
return new RefreshTask(refreshTask.task, false);
}

final ListenableFutureTask<IntermediateCredentials> task =
ListenableFutureTask.create(this::fetchIntermediateCredentials);
final ListenableFutureTask<IntermediateCredentials> task =
ListenableFutureTask.create(this::fetchIntermediateCredentials);

// Store the new refresh task in the refreshTask field before returning. This ensures that
// subsequent calls to this method will return the existing task while it's still in progress.
refreshTask = new RefreshTask(task, true);
return refreshTask;
}
// Store the new refresh task in the refreshTask field before returning. This ensures that
// subsequent calls to this method will return the existing task while it's still in progress.
refreshTask = new RefreshTask(task, true);
return refreshTask;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

/**
Expand Down Expand Up @@ -325,7 +324,6 @@ void refreshCredentialsIfRequired_blockingMultiThread() throws IOException, Inte
}

@Test
@Disabled("Flaky test: https://github.com/googleapis/google-cloud-java/issues/12871")
void refreshCredentialsIfRequired_asyncMultiThread() throws IOException, InterruptedException {
final ClientSideCredentialAccessBoundaryFactory factory =
getClientSideCredentialAccessBoundaryFactory(RefreshType.ASYNC);
Expand Down Expand Up @@ -623,10 +621,16 @@ private Clock createMockClock(RefreshType refreshType, GoogleCredentials sourceC
// (within the refresh margin).
mockedTimeInMillis = expirationTimeInMillis - refreshMarginInMillis + 60000;
when(mockClock.currentTimeMillis())
.thenReturn(
mockedTimeInMillis, // First call: Stale (triggers the async refresh)
currentTimeInMillis // Subsequent calls: Fresh (skips redundant refreshes)
);
.thenAnswer(
invocation -> {
// If the async refresh has already been triggered (request count >= 2),
// return the fresh time to skip redundant refreshes.
// Note: 1st request was the initial blocking refresh.
if (mockStsTransportFactory.transport.getRequestCount() >= 2) {
return currentTimeInMillis;
}
return mockedTimeInMillis;
});
break;
case BLOCKING:
// Set mocked time so that the token requires immediate refresh (just after the minimum
Expand Down
Loading