diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelFinder.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelFinder.java index edb3da6a7cc4..6e77ebd2692d 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelFinder.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelFinder.java @@ -24,25 +24,23 @@ import com.google.spanner.v1.CommitRequest; import com.google.spanner.v1.DirectedReadOptions; import com.google.spanner.v1.ExecuteSqlRequest; -import com.google.spanner.v1.Group; import com.google.spanner.v1.Mutation; import com.google.spanner.v1.ReadRequest; import com.google.spanner.v1.RoutingHint; -import com.google.spanner.v1.Tablet; import com.google.spanner.v1.TransactionOptions; import com.google.spanner.v1.TransactionSelector; import java.util.ArrayList; -import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ExecutorService; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.Predicate; import javax.annotation.Nullable; @@ -54,6 +52,7 @@ @InternalApi public final class ChannelFinder { private static final Predicate NO_EXCLUDED_ENDPOINTS = address -> false; + private static final int CACHE_UPDATE_DRAIN_BATCH_SIZE = 64; private static final int MAX_CACHE_UPDATE_THREADS = Math.max(2, Runtime.getRuntime().availableProcessors()); private static final ExecutorService CACHE_UPDATE_POOL = createCacheUpdatePool(); @@ -62,8 +61,11 @@ public final class ChannelFinder { private final AtomicLong databaseId = new AtomicLong(); private final KeyRecipeCache recipeCache = new KeyRecipeCache(); private final KeyRangeCache rangeCache; - private final AtomicReference pendingUpdate = new AtomicReference<>(); - private volatile java.util.concurrent.CountDownLatch drainingLatch; + private final ConcurrentLinkedQueue pendingUpdates = + new ConcurrentLinkedQueue<>(); + private final AtomicBoolean drainScheduled = new AtomicBoolean(); + private volatile java.util.concurrent.CountDownLatch drainingLatch = + new java.util.concurrent.CountDownLatch(0); @Nullable private final EndpointLifecycleManager lifecycleManager; @Nullable private final String finderKey; @@ -105,46 +107,43 @@ private static ExecutorService createCacheUpdatePool() { return executor; } + private static final class PendingCacheUpdate { + private final CacheUpdate update; + + private PendingCacheUpdate(CacheUpdate update) { + this.update = update; + } + } + + private boolean isMaterialUpdate(CacheUpdate update) { + return update.getGroupCount() > 0 + || update.getRangeCount() > 0 + || (update.hasKeyRecipes() && update.getKeyRecipes().getRecipeCount() > 0); + } + + private boolean shouldProcessUpdate(CacheUpdate update) { + if (isMaterialUpdate(update)) { + return true; + } + long updateDatabaseId = update.getDatabaseId(); + return updateDatabaseId != 0 && databaseId.get() != updateDatabaseId; + } + public void update(CacheUpdate update) { + Set currentAddresses; synchronized (updateLock) { - long currentId = databaseId.get(); - if (currentId != update.getDatabaseId()) { - if (currentId != 0) { - recipeCache.clear(); - rangeCache.clear(); - } - databaseId.set(update.getDatabaseId()); - } - if (update.hasKeyRecipes()) { - recipeCache.addRecipes(update.getKeyRecipes()); - } - rangeCache.addRanges(update); - - // Notify the lifecycle manager about server addresses so it can create endpoints - // in the background and start probing, and evict stale endpoints atomically. - if (lifecycleManager != null && finderKey != null) { - Set currentAddresses = new HashSet<>(); - for (Group group : update.getGroupList()) { - for (Tablet tablet : group.getTabletsList()) { - String addr = tablet.getServerAddress(); - if (!addr.isEmpty()) { - currentAddresses.add(addr); - } - } - } - // Also include addresses from existing cached tablets not in this update. - currentAddresses.addAll(rangeCache.getActiveAddresses()); - // Atomically ensure endpoints exist and evict stale ones. - lifecycleManager.updateActiveAddresses(finderKey, currentAddresses); - } + applyUpdateLocked(update); + currentAddresses = snapshotActiveAddressesLocked(); } + publishLifecycleUpdate(currentAddresses); } public void updateAsync(CacheUpdate update) { - // Replace any pending update atomically. Each CacheUpdate contains the full current state, - // so intermediate updates can be safely dropped to prevent unbounded queue growth. - if (pendingUpdate.getAndSet(update) == null) { - // No previous pending update means no drain task is scheduled yet — submit one. + if (!shouldProcessUpdate(update)) { + return; + } + pendingUpdates.add(new PendingCacheUpdate(update)); + if (drainScheduled.compareAndSet(false, true)) { java.util.concurrent.CountDownLatch latch = new java.util.concurrent.CountDownLatch(1); drainingLatch = latch; CACHE_UPDATE_POOL.execute( @@ -159,27 +158,89 @@ public void updateAsync(CacheUpdate update) { } private void drainPendingUpdate() { - CacheUpdate toApply; - while ((toApply = pendingUpdate.getAndSet(null)) != null) { - update(toApply); + List batch = new ArrayList<>(CACHE_UPDATE_DRAIN_BATCH_SIZE); + while (true) { + drainBatch(batch); + if (!batch.isEmpty()) { + applyBatch(batch); + batch.clear(); + } + drainScheduled.set(false); + if (pendingUpdates.isEmpty() || !drainScheduled.compareAndSet(false, true)) { + return; + } + } + } + + private void drainBatch(List batch) { + PendingCacheUpdate toApply; + while (batch.size() < CACHE_UPDATE_DRAIN_BATCH_SIZE + && (toApply = pendingUpdates.poll()) != null) { + batch.add(toApply); } } + private void applyBatch(List batch) { + Set currentAddresses; + synchronized (updateLock) { + for (PendingCacheUpdate pendingUpdate : batch) { + applyUpdateLocked(pendingUpdate.update); + } + currentAddresses = snapshotActiveAddressesLocked(); + } + publishLifecycleUpdate(currentAddresses); + } + + private void applyUpdateLocked(CacheUpdate update) { + long currentId = databaseId.get(); + long updateDatabaseId = update.getDatabaseId(); + if (updateDatabaseId != 0 && currentId != updateDatabaseId) { + if (currentId != 0) { + recipeCache.clear(); + rangeCache.clear(); + } + databaseId.set(updateDatabaseId); + } + if (update.hasKeyRecipes()) { + recipeCache.addRecipes(update.getKeyRecipes()); + } + rangeCache.addRanges(update); + } + + @Nullable + private Set snapshotActiveAddressesLocked() { + if (lifecycleManager == null || finderKey == null) { + return null; + } + return rangeCache.getActiveAddresses(); + } + + private void publishLifecycleUpdate(@Nullable Set currentAddresses) { + if (currentAddresses == null) { + return; + } + lifecycleManager.updateActiveAddressesAsync(finderKey, currentAddresses); + } + /** * Test-only hook used by {@link KeyAwareChannel#awaitPendingCacheUpdates()} to wait until the * async cache update worker has finished applying the latest pending update. */ @VisibleForTesting void awaitPendingUpdates() throws InterruptedException { - // Spin until no pending update remains. long deadline = System.nanoTime() + java.util.concurrent.TimeUnit.SECONDS.toNanos(5); - while (pendingUpdate.get() != null && System.nanoTime() < deadline) { - Thread.sleep(1); - } - // Wait for the drain task to fully complete (including the update() call). - java.util.concurrent.CountDownLatch latch = drainingLatch; - if (latch != null) { - latch.await(5, java.util.concurrent.TimeUnit.SECONDS); + while (System.nanoTime() < deadline) { + java.util.concurrent.CountDownLatch latch = drainingLatch; + if (latch != null) { + long remainingNanos = deadline - System.nanoTime(); + if (remainingNanos <= 0) { + break; + } + latch.await(remainingNanos, java.util.concurrent.TimeUnit.NANOSECONDS); + } + if (pendingUpdates.isEmpty() && !drainScheduled.get()) { + return; + } } } diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLifecycleManager.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLifecycleManager.java index 854f4330c183..ae78f07b14a3 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLifecycleManager.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLifecycleManager.java @@ -29,6 +29,8 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; @@ -103,6 +105,11 @@ static final class EndpointState { private final ChannelEndpointCache endpointCache; private final Map endpoints = new ConcurrentHashMap<>(); private final Set transientFailureEvictedAddresses = ConcurrentHashMap.newKeySet(); + private final Map finderGenerations = new ConcurrentHashMap<>(); + private final Map pendingActiveAddressUpdates = + new ConcurrentHashMap<>(); + private final Set queuedFinderKeys = ConcurrentHashMap.newKeySet(); + private final ConcurrentLinkedQueue queuedFinders = new ConcurrentLinkedQueue<>(); /** * Active addresses reported by each ChannelFinder, keyed by database id. @@ -118,8 +125,10 @@ static final class EndpointState { private final Object activeAddressLock = new Object(); + private final ExecutorService activeAddressUpdateExecutor; private final ScheduledExecutorService scheduler; private final AtomicBoolean isShutdown = new AtomicBoolean(false); + private final AtomicBoolean activeAddressDrainScheduled = new AtomicBoolean(false); private final long probeIntervalSeconds; private final Duration idleEvictionDuration; private final Clock clock; @@ -127,6 +136,16 @@ static final class EndpointState { private ScheduledFuture evictionFuture; + private static final class PendingActiveAddressUpdate { + private final Set activeAddresses; + private final long generation; + + private PendingActiveAddressUpdate(Set activeAddresses, long generation) { + this.activeAddresses = activeAddresses; + this.generation = generation; + } + } + EndpointLifecycleManager(ChannelEndpointCache endpointCache) { this( endpointCache, @@ -146,6 +165,13 @@ static final class EndpointState { this.idleEvictionDuration = idleEvictionDuration; this.clock = clock; this.defaultEndpointAddress = endpointCache.defaultChannel().getAddress(); + this.activeAddressUpdateExecutor = + Executors.newSingleThreadExecutor( + r -> { + Thread t = new Thread(r, "spanner-active-address-update"); + t.setDaemon(true); + return t; + }); this.scheduler = Executors.newScheduledThreadPool( 2, @@ -213,6 +239,59 @@ private void clearTransientFailureEvictionMarker(String address) { } } + /** + * Enqueues active-address reconciliation on a dedicated worker so cache-map updates do not block + * on endpoint lifecycle bookkeeping. + */ + void updateActiveAddressesAsync(String finderKey, Set activeAddresses) { + if (isShutdown.get() || finderKey == null || finderKey.isEmpty()) { + return; + } + synchronized (activeAddressLock) { + long generation = finderGenerations.getOrDefault(finderKey, 0L); + pendingActiveAddressUpdates.put( + finderKey, new PendingActiveAddressUpdate(new HashSet<>(activeAddresses), generation)); + if (queuedFinderKeys.add(finderKey)) { + queuedFinders.add(finderKey); + } + } + scheduleActiveAddressDrain(); + } + + private void scheduleActiveAddressDrain() { + if (!activeAddressDrainScheduled.compareAndSet(false, true)) { + return; + } + activeAddressUpdateExecutor.execute(this::drainPendingActiveAddressUpdates); + } + + private void drainPendingActiveAddressUpdates() { + while (true) { + String finderKey = queuedFinders.poll(); + if (finderKey == null) { + activeAddressDrainScheduled.set(false); + if (queuedFinders.isEmpty() || !activeAddressDrainScheduled.compareAndSet(false, true)) { + return; + } + continue; + } + + queuedFinderKeys.remove(finderKey); + PendingActiveAddressUpdate pendingUpdate = pendingActiveAddressUpdates.remove(finderKey); + if (pendingUpdate == null) { + continue; + } + + synchronized (activeAddressLock) { + long currentGeneration = finderGenerations.getOrDefault(finderKey, 0L); + if (currentGeneration != pendingUpdate.generation) { + continue; + } + } + updateActiveAddresses(finderKey, pendingUpdate.activeAddresses); + } + } + /** * Records that real (non-probe) traffic was routed to an endpoint. This refreshes the idle * eviction timer for this endpoint. @@ -295,6 +374,9 @@ void unregisterFinder(String finderKey) { return; } synchronized (activeAddressLock) { + finderGenerations.merge(finderKey, 1L, Long::sum); + pendingActiveAddressUpdates.remove(finderKey); + queuedFinderKeys.remove(finderKey); if (activeAddressesPerFinder.remove(finderKey) == null) { return; } @@ -588,6 +670,7 @@ void shutdown() { } logger.log(Level.FINE, "Shutting down endpoint lifecycle manager"); + activeAddressUpdateExecutor.shutdownNow(); if (evictionFuture != null) { evictionFuture.cancel(false); @@ -602,6 +685,9 @@ void shutdown() { } endpoints.clear(); transientFailureEvictedAddresses.clear(); + pendingActiveAddressUpdates.clear(); + queuedFinderKeys.clear(); + queuedFinders.clear(); scheduler.shutdown(); try { diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ChannelFinderTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ChannelFinderTest.java new file mode 100644 index 000000000000..d73c946cc2ac --- /dev/null +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ChannelFinderTest.java @@ -0,0 +1,311 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.spi.v1; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.protobuf.ByteString; +import com.google.spanner.v1.CacheUpdate; +import com.google.spanner.v1.Group; +import com.google.spanner.v1.Range; +import com.google.spanner.v1.Tablet; +import io.grpc.CallOptions; +import io.grpc.ClientCall; +import io.grpc.ManagedChannel; +import io.grpc.MethodDescriptor; +import java.lang.reflect.Field; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ChannelFinderTest { + + @Test + public void updateAsyncDrainsQueuedUpdatesInOrderWithoutDroppingAny() throws Exception { + ExecutorService executor = cacheUpdatePool(); + int threadCount = maxCacheUpdateThreads(); + CountDownLatch workersStarted = new CountDownLatch(threadCount); + CountDownLatch releaseWorkers = new CountDownLatch(1); + + try { + for (int i = 0; i < threadCount; i++) { + executor.execute( + () -> { + workersStarted.countDown(); + try { + releaseWorkers.await(5, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + } + assertThat(workersStarted.await(5, TimeUnit.SECONDS)).isTrue(); + + ChannelFinder finder = new ChannelFinder(new FakeEndpointCache()); + int updateCount = 64; + for (int i = 0; i < updateCount; i++) { + finder.updateAsync(singleRangeUpdate(i)); + } + + releaseWorkers.countDown(); + finder.awaitPendingUpdates(); + + assertThat(rangeCache(finder).size()).isEqualTo(updateCount); + } finally { + releaseWorkers.countDown(); + } + } + + @Test + public void updateIgnoresZeroDatabaseIdAndKeepsExistingCache() throws Exception { + ChannelFinder finder = new ChannelFinder(new FakeEndpointCache()); + finder.update(singleRangeUpdate(0)); + + finder.update(CacheUpdate.newBuilder().setDatabaseId(0L).build()); + + assertThat(databaseId(finder)).isEqualTo(7L); + assertThat(rangeCache(finder).size()).isEqualTo(1); + } + + @Test + public void updateAsyncSkipsTrulyEmptyUpdateForCurrentDatabase() throws Exception { + ChannelFinder finder = new ChannelFinder(new FakeEndpointCache()); + finder.update(singleRangeUpdate(0)); + + finder.updateAsync(CacheUpdate.newBuilder().setDatabaseId(7L).build()); + finder.awaitPendingUpdates(); + + assertThat(databaseId(finder)).isEqualTo(7L); + assertThat(rangeCache(finder).size()).isEqualTo(1); + } + + @Test + public void updateAsyncProcessesDatabaseTransitionWithoutRangesOrGroups() throws Exception { + ChannelFinder finder = new ChannelFinder(new FakeEndpointCache()); + finder.update(singleRangeUpdate(0)); + + finder.updateAsync(CacheUpdate.newBuilder().setDatabaseId(9L).build()); + finder.awaitPendingUpdates(); + + assertThat(databaseId(finder)).isEqualTo(9L); + assertThat(rangeCache(finder).size()).isEqualTo(0); + } + + @Test + public void updateDoesNotBlockOnLifecycleManagerAddressReconciliation() throws Exception { + BlockingLifecycleManager lifecycleManager = + new BlockingLifecycleManager(new FakeEndpointCache()); + ChannelFinder finder = new ChannelFinder(new FakeEndpointCache(), lifecycleManager, "db-1"); + ExecutorService executor = Executors.newSingleThreadExecutor(); + try { + Future future = executor.submit(() -> finder.update(singleRangeUpdate(0))); + + future.get(1, TimeUnit.SECONDS); + assertThat(lifecycleManager.updateStarted.await(5, TimeUnit.SECONDS)).isTrue(); + assertThat(rangeCache(finder).size()).isEqualTo(1); + + lifecycleManager.releaseUpdate.countDown(); + assertThat(lifecycleManager.updateFinished.await(5, TimeUnit.SECONDS)).isTrue(); + } finally { + lifecycleManager.releaseUpdate.countDown(); + lifecycleManager.shutdown(); + executor.shutdownNow(); + } + } + + private static CacheUpdate singleRangeUpdate(int index) { + String startKey = String.format("k%05d", index); + String limitKey = String.format("k%05d", index + 1); + long groupUid = index + 1L; + return CacheUpdate.newBuilder() + .setDatabaseId(7L) + .addRange( + Range.newBuilder() + .setStartKey(bytes(startKey)) + .setLimitKey(bytes(limitKey)) + .setGroupUid(groupUid) + .setSplitId(groupUid) + .setGeneration(bytes("g"))) + .addGroup( + Group.newBuilder() + .setGroupUid(groupUid) + .setGeneration(bytes("g")) + .addTablets( + Tablet.newBuilder() + .setTabletUid(groupUid) + .setServerAddress("server-" + index + ":1234") + .setIncarnation(bytes("i")) + .setDistance(0))) + .build(); + } + + private static ByteString bytes(String value) { + return ByteString.copyFromUtf8(value); + } + + private static ExecutorService cacheUpdatePool() throws Exception { + Field field = ChannelFinder.class.getDeclaredField("CACHE_UPDATE_POOL"); + field.setAccessible(true); + return (ExecutorService) field.get(null); + } + + private static int maxCacheUpdateThreads() throws Exception { + Field field = ChannelFinder.class.getDeclaredField("MAX_CACHE_UPDATE_THREADS"); + field.setAccessible(true); + return field.getInt(null); + } + + private static long databaseId(ChannelFinder finder) throws Exception { + Field field = ChannelFinder.class.getDeclaredField("databaseId"); + field.setAccessible(true); + return ((AtomicLong) field.get(finder)).get(); + } + + private static KeyRangeCache rangeCache(ChannelFinder finder) throws Exception { + Field field = ChannelFinder.class.getDeclaredField("rangeCache"); + field.setAccessible(true); + return (KeyRangeCache) field.get(finder); + } + + private static final class FakeEndpointCache implements ChannelEndpointCache { + private final Map endpoints = new ConcurrentHashMap<>(); + private final FakeEndpoint defaultEndpoint = new FakeEndpoint("default"); + + @Override + public ChannelEndpoint defaultChannel() { + return defaultEndpoint; + } + + @Override + public ChannelEndpoint get(String address) { + return endpoints.computeIfAbsent(address, FakeEndpoint::new); + } + + @Override + public ChannelEndpoint getIfPresent(String address) { + return endpoints.computeIfAbsent(address, FakeEndpoint::new); + } + + @Override + public void evict(String address) { + endpoints.remove(address); + } + + @Override + public void shutdown() { + endpoints.clear(); + } + } + + private static final class FakeEndpoint implements ChannelEndpoint { + private final String address; + private final ManagedChannel channel = new FakeManagedChannel(); + + private FakeEndpoint(String address) { + this.address = address; + } + + @Override + public String getAddress() { + return address; + } + + @Override + public boolean isHealthy() { + return true; + } + + @Override + public boolean isTransientFailure() { + return false; + } + + @Override + public ManagedChannel getChannel() { + return channel; + } + } + + private static final class FakeManagedChannel extends ManagedChannel { + @Override + public ManagedChannel shutdown() { + return this; + } + + @Override + public ManagedChannel shutdownNow() { + return this; + } + + @Override + public boolean isShutdown() { + return false; + } + + @Override + public boolean isTerminated() { + return false; + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) { + return true; + } + + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, CallOptions callOptions) { + throw new UnsupportedOperationException(); + } + + @Override + public String authority() { + return "fake"; + } + } + + private static final class BlockingLifecycleManager extends EndpointLifecycleManager { + private final CountDownLatch updateStarted = new CountDownLatch(1); + private final CountDownLatch releaseUpdate = new CountDownLatch(1); + private final CountDownLatch updateFinished = new CountDownLatch(1); + + private BlockingLifecycleManager(ChannelEndpointCache endpointCache) { + super(endpointCache); + } + + @Override + void updateActiveAddresses(String finderKey, java.util.Set activeAddresses) { + updateStarted.countDown(); + try { + releaseUpdate.await(5, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } finally { + updateFinished.countDown(); + } + } + } +}