Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Predicate;
import javax.annotation.Nullable;

Expand All @@ -62,8 +63,11 @@ public final class ChannelFinder {
private final AtomicLong databaseId = new AtomicLong();
private final KeyRecipeCache recipeCache = new KeyRecipeCache();
private final KeyRangeCache rangeCache;
// FIFO queue of updates awaiting asynchronous application; drained on CACHE_UPDATE_POOL.
private final ConcurrentLinkedQueue<PendingCacheUpdate> pendingUpdates =
    new ConcurrentLinkedQueue<>();
// True while a drain task is scheduled or running; prevents scheduling duplicate drains.
private final AtomicBoolean drainScheduled = new AtomicBoolean();
// Released when the most recently scheduled drain task completes. Starts pre-released
// (count 0) so awaitPendingUpdates() does not block before any async update is submitted.
private volatile java.util.concurrent.CountDownLatch drainingLatch =
    new java.util.concurrent.CountDownLatch(0);
@Nullable private final EndpointLifecycleManager lifecycleManager;
@Nullable private final String finderKey;

Expand Down Expand Up @@ -105,15 +109,38 @@ private static ExecutorService createCacheUpdatePool() {
return executor;
}

/** Immutable queue entry holding one {@link CacheUpdate} awaiting asynchronous application. */
private static final class PendingCacheUpdate {
  private final CacheUpdate update;

  private PendingCacheUpdate(CacheUpdate wrapped) {
    this.update = wrapped;
  }
}

/**
 * Returns whether {@code update} carries actual cache content: at least one group, at least
 * one range, or a key-recipes message with at least one recipe.
 */
private boolean isMaterialUpdate(CacheUpdate update) {
  if (update.getGroupCount() > 0 || update.getRangeCount() > 0) {
    return true;
  }
  return update.hasKeyRecipes() && update.getKeyRecipes().getRecipeCount() > 0;
}

/**
 * Returns whether {@code update} is worth applying: either it contains material content
 * (groups, ranges, or recipes) or it names a non-zero database id different from the one
 * currently cached (i.e. it signals a database transition).
 */
private boolean shouldProcessUpdate(CacheUpdate update) {
  if (isMaterialUpdate(update)) {
    return true;
  }
  long incomingId = update.getDatabaseId();
  return incomingId != 0 && incomingId != databaseId.get();
}

public void update(CacheUpdate update) {
synchronized (updateLock) {
long currentId = databaseId.get();
if (currentId != update.getDatabaseId()) {
long updateDatabaseId = update.getDatabaseId();
if (updateDatabaseId != 0 && currentId != updateDatabaseId) {
if (currentId != 0) {
recipeCache.clear();
rangeCache.clear();
}
databaseId.set(update.getDatabaseId());
databaseId.set(updateDatabaseId);
}
if (update.hasKeyRecipes()) {
recipeCache.addRecipes(update.getKeyRecipes());
Expand Down Expand Up @@ -141,10 +168,11 @@ public void update(CacheUpdate update) {
}

public void updateAsync(CacheUpdate update) {
// Replace any pending update atomically. Each CacheUpdate contains the full current state,
// so intermediate updates can be safely dropped to prevent unbounded queue growth.
if (pendingUpdate.getAndSet(update) == null) {
// No previous pending update means no drain task is scheduled yet — submit one.
if (!shouldProcessUpdate(update)) {
return;
}
pendingUpdates.add(new PendingCacheUpdate(update));
if (drainScheduled.compareAndSet(false, true)) {
java.util.concurrent.CountDownLatch latch = new java.util.concurrent.CountDownLatch(1);
drainingLatch = latch;
CACHE_UPDATE_POOL.execute(
Expand All @@ -159,9 +187,15 @@ public void updateAsync(CacheUpdate update) {
}

private void drainPendingUpdate() {
CacheUpdate toApply;
while ((toApply = pendingUpdate.getAndSet(null)) != null) {
update(toApply);
while (true) {
PendingCacheUpdate toApply;
while ((toApply = pendingUpdates.poll()) != null) {
update(toApply.update);
}
drainScheduled.set(false);
if (pendingUpdates.isEmpty() || !drainScheduled.compareAndSet(false, true)) {
return;
}
}
}

Expand All @@ -171,15 +205,19 @@ private void drainPendingUpdate() {
*/
@VisibleForTesting
void awaitPendingUpdates() throws InterruptedException {
// Spin until no pending update remains.
long deadline = System.nanoTime() + java.util.concurrent.TimeUnit.SECONDS.toNanos(5);
while (pendingUpdate.get() != null && System.nanoTime() < deadline) {
Thread.sleep(1);
}
// Wait for the drain task to fully complete (including the update() call).
java.util.concurrent.CountDownLatch latch = drainingLatch;
if (latch != null) {
latch.await(5, java.util.concurrent.TimeUnit.SECONDS);
while (System.nanoTime() < deadline) {
java.util.concurrent.CountDownLatch latch = drainingLatch;
if (latch != null) {
long remainingNanos = deadline - System.nanoTime();
if (remainingNanos <= 0) {
break;
}
latch.await(remainingNanos, java.util.concurrent.TimeUnit.NANOSECONDS);
}
if (pendingUpdates.isEmpty() && !drainScheduled.get()) {
return;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
/*
* Copyright 2026 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.spanner.spi.v1;

import static com.google.common.truth.Truth.assertThat;

import com.google.protobuf.ByteString;
import com.google.spanner.v1.CacheUpdate;
import com.google.spanner.v1.Group;
import com.google.spanner.v1.Range;
import com.google.spanner.v1.Tablet;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
import io.grpc.ManagedChannel;
import io.grpc.MethodDescriptor;
import java.lang.reflect.Field;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class ChannelFinderTest {

  /**
   * Regression test for dropped async updates: saturates the shared cache-update pool so the
   * drain task cannot run, enqueues 64 distinct updates while the pool is blocked, then
   * releases the pool and verifies every range made it into the cache (none dropped).
   */
  @Test
  public void updateAsyncDrainsQueuedUpdatesInOrderWithoutDroppingAny() throws Exception {
    ExecutorService executor = cacheUpdatePool();
    int threadCount = maxCacheUpdateThreads();
    CountDownLatch workersStarted = new CountDownLatch(threadCount);
    CountDownLatch releaseWorkers = new CountDownLatch(1);

    try {
      // Occupy every pool thread so drain tasks submitted by updateAsync() stay queued.
      for (int i = 0; i < threadCount; i++) {
        executor.execute(
            () -> {
              workersStarted.countDown();
              try {
                releaseWorkers.await(5, TimeUnit.SECONDS);
              } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
              }
            });
      }
      assertThat(workersStarted.await(5, TimeUnit.SECONDS)).isTrue();

      ChannelFinder finder = new ChannelFinder(new FakeEndpointCache());
      int updateCount = 64;
      // Each update adds one non-overlapping range, so the final cache size counts updates.
      for (int i = 0; i < updateCount; i++) {
        finder.updateAsync(singleRangeUpdate(i));
      }

      releaseWorkers.countDown();
      finder.awaitPendingUpdates();

      assertThat(rangeCache(finder).size()).isEqualTo(updateCount);
    } finally {
      // Unblock the workers even if an assertion above failed; countDown() is idempotent.
      releaseWorkers.countDown();
    }
  }

  /** An update whose database id is 0 (the proto default) must not wipe the existing cache. */
  @Test
  public void updateIgnoresZeroDatabaseIdAndKeepsExistingCache() throws Exception {
    ChannelFinder finder = new ChannelFinder(new FakeEndpointCache());
    finder.update(singleRangeUpdate(0)); // Populates the cache under database id 7.

    finder.update(CacheUpdate.newBuilder().setDatabaseId(0L).build());

    assertThat(databaseId(finder)).isEqualTo(7L);
    assertThat(rangeCache(finder).size()).isEqualTo(1);
  }

  /** A content-free async update naming the CURRENT database id should be skipped entirely. */
  @Test
  public void updateAsyncSkipsTrulyEmptyUpdateForCurrentDatabase() throws Exception {
    ChannelFinder finder = new ChannelFinder(new FakeEndpointCache());
    finder.update(singleRangeUpdate(0)); // Database id 7, one cached range.

    finder.updateAsync(CacheUpdate.newBuilder().setDatabaseId(7L).build());
    finder.awaitPendingUpdates();

    assertThat(databaseId(finder)).isEqualTo(7L);
    assertThat(rangeCache(finder).size()).isEqualTo(1);
  }

  /** A content-free update naming a NEW database id must still be processed (cache reset). */
  @Test
  public void updateAsyncProcessesDatabaseTransitionWithoutRangesOrGroups() throws Exception {
    ChannelFinder finder = new ChannelFinder(new FakeEndpointCache());
    finder.update(singleRangeUpdate(0)); // Database id 7, one cached range.

    finder.updateAsync(CacheUpdate.newBuilder().setDatabaseId(9L).build());
    finder.awaitPendingUpdates();

    assertThat(databaseId(finder)).isEqualTo(9L);
    assertThat(rangeCache(finder).size()).isEqualTo(0);
  }

  /**
   * Builds a CacheUpdate for database id 7 containing one range
   * [k{index}, k{index+1}) owned by group/split/tablet uid {index + 1}, so successive
   * indices produce non-overlapping ranges.
   */
  private static CacheUpdate singleRangeUpdate(int index) {
    String startKey = String.format("k%05d", index);
    String limitKey = String.format("k%05d", index + 1);
    long groupUid = index + 1L;
    return CacheUpdate.newBuilder()
        .setDatabaseId(7L)
        .addRange(
            Range.newBuilder()
                .setStartKey(bytes(startKey))
                .setLimitKey(bytes(limitKey))
                .setGroupUid(groupUid)
                .setSplitId(groupUid)
                .setGeneration(bytes("g")))
        .addGroup(
            Group.newBuilder()
                .setGroupUid(groupUid)
                .setGeneration(bytes("g"))
                .addTablets(
                    Tablet.newBuilder()
                        .setTabletUid(groupUid)
                        .setServerAddress("server-" + index + ":1234")
                        .setIncarnation(bytes("i"))
                        .setDistance(0)))
        .build();
  }

  private static ByteString bytes(String value) {
    return ByteString.copyFromUtf8(value);
  }

  /** Reads ChannelFinder's private static CACHE_UPDATE_POOL via reflection. */
  private static ExecutorService cacheUpdatePool() throws Exception {
    Field field = ChannelFinder.class.getDeclaredField("CACHE_UPDATE_POOL");
    field.setAccessible(true);
    return (ExecutorService) field.get(null);
  }

  /** Reads ChannelFinder's private static MAX_CACHE_UPDATE_THREADS via reflection. */
  private static int maxCacheUpdateThreads() throws Exception {
    Field field = ChannelFinder.class.getDeclaredField("MAX_CACHE_UPDATE_THREADS");
    field.setAccessible(true);
    return field.getInt(null);
  }

  /** Reads the finder's private databaseId field via reflection. */
  private static long databaseId(ChannelFinder finder) throws Exception {
    Field field = ChannelFinder.class.getDeclaredField("databaseId");
    field.setAccessible(true);
    return ((AtomicLong) field.get(finder)).get();
  }

  /** Reads the finder's private rangeCache field via reflection. */
  private static KeyRangeCache rangeCache(ChannelFinder finder) throws Exception {
    Field field = ChannelFinder.class.getDeclaredField("rangeCache");
    field.setAccessible(true);
    return (KeyRangeCache) field.get(finder);
  }

  /** In-memory endpoint cache that lazily creates a FakeEndpoint per address. */
  private static final class FakeEndpointCache implements ChannelEndpointCache {
    private final Map<String, FakeEndpoint> endpoints = new ConcurrentHashMap<>();
    private final FakeEndpoint defaultEndpoint = new FakeEndpoint("default");

    @Override
    public ChannelEndpoint defaultChannel() {
      return defaultEndpoint;
    }

    @Override
    public ChannelEndpoint get(String address) {
      return endpoints.computeIfAbsent(address, FakeEndpoint::new);
    }

    // NOTE(review): despite the "IfPresent" name this creates a missing endpoint —
    // presumably to keep the fake null-free; confirm against the ChannelEndpointCache
    // contract before relying on absent-returns-null semantics.
    @Override
    public ChannelEndpoint getIfPresent(String address) {
      return endpoints.computeIfAbsent(address, FakeEndpoint::new);
    }

    @Override
    public void evict(String address) {
      endpoints.remove(address);
    }

    @Override
    public void shutdown() {
      endpoints.clear();
    }
  }

  /** Always-healthy endpoint stub backed by a no-op ManagedChannel. */
  private static final class FakeEndpoint implements ChannelEndpoint {
    private final String address;
    private final ManagedChannel channel = new FakeManagedChannel();

    private FakeEndpoint(String address) {
      this.address = address;
    }

    @Override
    public String getAddress() {
      return address;
    }

    @Override
    public boolean isHealthy() {
      return true;
    }

    @Override
    public boolean isTransientFailure() {
      return false;
    }

    @Override
    public ManagedChannel getChannel() {
      return channel;
    }
  }

  /** Minimal ManagedChannel stub; newCall() throws because no RPCs are made in these tests. */
  private static final class FakeManagedChannel extends ManagedChannel {
    @Override
    public ManagedChannel shutdown() {
      return this;
    }

    @Override
    public ManagedChannel shutdownNow() {
      return this;
    }

    @Override
    public boolean isShutdown() {
      return false;
    }

    @Override
    public boolean isTerminated() {
      return false;
    }

    @Override
    public boolean awaitTermination(long timeout, TimeUnit unit) {
      return true;
    }

    @Override
    public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
        MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) {
      throw new UnsupportedOperationException();
    }

    @Override
    public String authority() {
      return "fake";
    }
  }
}
Loading