diff --git a/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java b/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java index 09f1dd0..9a98a13 100644 --- a/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java +++ b/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java @@ -68,7 +68,9 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.RejectedExecutionException; -import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; @@ -122,6 +124,8 @@ public class GcpManagedChannel extends ManagedChannel { private Duration scaleDownInterval = Duration.ZERO; private boolean isDynamicScalingEnabled = false; private int maxConcurrentStreamsLowWatermark = DEFAULT_MAX_STREAM; + private GcpManagedChannelOptions.ChannelPickStrategy channelPickStrategy = + GcpManagedChannelOptions.ChannelPickStrategy.POWER_OF_TWO; private Duration affinityKeyLifetime = Duration.ZERO; @VisibleForTesting final Map methodToAffinity = new HashMap<>(); @@ -179,8 +183,12 @@ public class GcpManagedChannel extends ManagedChannel { private final String metricPoolIndex = String.format("pool-%d", channelPoolIndex.incrementAndGet()); private final Map cumulativeMetricValues = new ConcurrentHashMap<>(); - private final ScheduledExecutorService backgroundService = - Executors.newSingleThreadScheduledExecutor(GcpThreadFactory.newThreadFactory("gcp-mc-bg-%d")); + private static final ScheduledThreadPoolExecutor SHARED_BACKGROUND_SERVICE = + createSharedBackgroundService(); + + private ScheduledFuture cleanupTask; + private ScheduledFuture scaleDownTask; + private ScheduledFuture logMetricsTask; // Metrics counters. 
private final AtomicInteger readyChannels = new AtomicInteger(); @@ -223,6 +231,25 @@ public class GcpManagedChannel extends ManagedChannel { private AtomicLong scaleUpCount = new AtomicLong(); private AtomicLong scaleDownCount = new AtomicLong(); + // Clock supplier for nanoTime, injectable for testing. + private Supplier nanoClock = System::nanoTime; + + @VisibleForTesting + void setNanoClock(Supplier nanoClock) { + this.nanoClock = nanoClock; + } + + private static ScheduledThreadPoolExecutor createSharedBackgroundService() { + ScheduledThreadPoolExecutor executor = + new ScheduledThreadPoolExecutor( + Math.max(2, Math.min(4, Runtime.getRuntime().availableProcessors() / 2)), + GcpThreadFactory.newThreadFactory("gcp-mc-bg-%d")); + executor.setRemoveOnCancelPolicy(true); + executor.setExecuteExistingDelayedTasksAfterShutdownPolicy(false); + executor.setContinueExistingPeriodicTasksAfterShutdownPolicy(false); + return executor; + } + /** * Constructor for GcpManagedChannel. * @@ -396,6 +423,7 @@ private void initOptions() { scaleDownInterval = poolOptions.getScaleDownInterval(); isDynamicScalingEnabled = minRpcPerChannel > 0 && maxRpcPerChannel > 0 && !scaleDownInterval.isZero(); + channelPickStrategy = poolOptions.getChannelPickStrategy(); } initMetrics(); } @@ -404,11 +432,12 @@ private synchronized void initCleanupTask(Duration cleanupInterval) { if (cleanupInterval.isZero()) { return; } - backgroundService.scheduleAtFixedRate( - this::cleanupAffinityKeys, - cleanupInterval.toMillis(), - cleanupInterval.toMillis(), - MILLISECONDS); + cleanupTask = + SHARED_BACKGROUND_SERVICE.scheduleAtFixedRate( + this::cleanupAffinityKeys, + cleanupInterval.toMillis(), + cleanupInterval.toMillis(), + MILLISECONDS); } private synchronized void initScaleDownChecker(Duration scaleDownInterval) { @@ -416,15 +445,17 @@ private synchronized void initScaleDownChecker(Duration scaleDownInterval) { return; } - backgroundService.scheduleAtFixedRate( - this::checkScaleDown, - 
scaleDownInterval.toMillis(), - scaleDownInterval.toMillis(), - MILLISECONDS); + scaleDownTask = + SHARED_BACKGROUND_SERVICE.scheduleAtFixedRate( + this::checkScaleDown, + scaleDownInterval.toMillis(), + scaleDownInterval.toMillis(), + MILLISECONDS); } private synchronized void initLogMetrics() { - backgroundService.scheduleAtFixedRate(this::logMetrics, 60, 60, SECONDS); + logMetricsTask = + SHARED_BACKGROUND_SERVICE.scheduleAtFixedRate(this::logMetrics, 60, 60, SECONDS); } private void logMetricsOptions() { @@ -1757,43 +1788,83 @@ private ChannelRef pickLeastBusyChannel(boolean forFallback) { return first; } - // Pick the least busy channel and the least busy ready and not overloaded channel (this could - // be the same channel or different or no channel). - ChannelRef channelCandidate = channelRefs.get(0); - int minStreams = channelCandidate.getActiveStreamsCount(); - ChannelRef readyCandidate = null; - int readyMinStreams = Integer.MAX_VALUE; + if (!fallbackEnabled) { + return pickLeastBusyNoFallback(); + } - for (ChannelRef channelRef : channelRefs) { - int cnt = channelRef.getActiveStreamsCount(); - if (cnt < minStreams) { - minStreams = cnt; - channelCandidate = channelRef; + return pickLeastBusyWithFallback(forFallback); + } + + /** + * Non-fallback channel selection. Uses the configured {@link + * GcpManagedChannelOptions.ChannelPickStrategy}. + */ + private ChannelRef pickLeastBusyNoFallback() { + ChannelRef channelCandidate; + int minStreams; + + if (channelPickStrategy == GcpManagedChannelOptions.ChannelPickStrategy.POWER_OF_TWO) { + channelCandidate = pickFromCandidates(channelRefs); + // With power-of-two, streams distribute approximately (not exactly) evenly. + // Use max streams for scale-up: if ANY channel hits the watermark, it's overloaded now + // and we should add capacity before other channels follow. This preserves the original + // per-channel watermark semantics (with LINEAR_SCAN, min == max so it didn't matter). 
+ // Global min would delay scale-up; sampled min would be noisy. + minStreams = getMaxActiveStreams(); + } else { + channelCandidate = channelRefs.get(0); + minStreams = channelCandidate.getActiveStreamsCount(); + for (ChannelRef channelRef : channelRefs) { + int cnt = channelRef.getActiveStreamsCount(); + if (cnt < minStreams) { + minStreams = cnt; + channelCandidate = channelRef; + } } - if (cnt < readyMinStreams - && !fallbackMap.containsKey(channelRef.getId()) - && cnt < DEFAULT_MAX_STREAM) { - readyMinStreams = cnt; - readyCandidate = channelRef; + } + + if (shouldScaleUp(minStreams)) { + ChannelRef newChannel = tryCreateNewChannel(); + if (newChannel != null) { + scaleUpCount.incrementAndGet(); + return newChannel; } } + return channelCandidate; + } - if (!fallbackEnabled) { - if (shouldScaleUp(minStreams)) { - ChannelRef newChannel = tryCreateNewChannel(); - if (newChannel != null) { - scaleUpCount.incrementAndGet(); - return newChannel; + /** + * Fallback-enabled channel selection. Always uses a full linear scan because the fallback logic + * needs to filter channels by readiness state and max stream limits. + */ + private ChannelRef pickLeastBusyWithFallback(boolean forFallback) { + // Full scan to collect eligible ("ready") channels not in fallbackMap and under max streams. 
+ List readyCandidates = new ArrayList<>(); + ChannelRef overallCandidate = channelRefs.get(0); + int overallMinStreams = overallCandidate.getActiveStreamsCount(); + int readyMaxStreams = 0; + + for (ChannelRef channelRef : channelRefs) { + int cnt = channelRef.getActiveStreamsCount(); + if (cnt < overallMinStreams) { + overallMinStreams = cnt; + overallCandidate = channelRef; + } + if (!fallbackMap.containsKey(channelRef.getId()) && cnt < DEFAULT_MAX_STREAM) { + readyCandidates.add(channelRef); + if (cnt > readyMaxStreams) { + readyMaxStreams = cnt; } } - return channelCandidate; } - if (shouldScaleUp(readyMinStreams)) { + // For scale-up, use maxStreams among ready channels (consistent with non-fallback path). + int scaleUpStreams = readyCandidates.isEmpty() ? Integer.MAX_VALUE : readyMaxStreams; + if (shouldScaleUp(scaleUpStreams)) { ChannelRef newChannel = tryCreateNewChannel(); if (newChannel != null) { scaleUpCount.incrementAndGet(); - if (!forFallback && readyCandidate == null) { + if (!forFallback && readyCandidates.isEmpty()) { if (logger.isLoggable(Level.FINEST)) { logger.finest(log("Fallback to newly created channel %d", newChannel.getId())); } @@ -1803,13 +1874,15 @@ private ChannelRef pickLeastBusyChannel(boolean forFallback) { } } - if (readyCandidate != null) { - if (!forFallback && readyCandidate.getId() != channelCandidate.getId()) { + if (!readyCandidates.isEmpty()) { + // Apply power-of-two among eligible channels to avoid thundering herd. 
+ ChannelRef readyCandidate = pickFromCandidates(readyCandidates); + if (!forFallback && readyCandidate.getId() != overallCandidate.getId()) { if (logger.isLoggable(Level.FINEST)) { logger.finest( log( "Picking fallback channel: %d -> %d", - channelCandidate.getId(), readyCandidate.getId())); + overallCandidate.getId(), readyCandidate.getId())); } fallbacksSucceeded.incrementAndGet(); } @@ -1818,11 +1891,53 @@ private ChannelRef pickLeastBusyChannel(boolean forFallback) { if (!forFallback) { if (logger.isLoggable(Level.FINEST)) { - logger.finest(log("Failed to find fallback for channel %d", channelCandidate.getId())); + logger.finest(log("Failed to find fallback for channel %d", overallCandidate.getId())); } fallbacksFailed.incrementAndGet(); } - return channelCandidate; + return overallCandidate; + } + + /** + * Picks a channel from the given candidate list using the configured strategy. + * + *

For {@code POWER_OF_TWO}: samples two distinct random candidates and picks the less busy + * one. On tie, prefers the channel with more recent activity (warmer) to preserve connection + * warmth under low traffic. + * + *

For {@code LINEAR_SCAN}: deterministic scan picking the first least-busy channel. + */ + private ChannelRef pickFromCandidates(List candidates) { + if (candidates.size() == 1) { + return candidates.get(0); + } + if (channelPickStrategy == GcpManagedChannelOptions.ChannelPickStrategy.POWER_OF_TWO) { + ThreadLocalRandom random = ThreadLocalRandom.current(); + int i = random.nextInt(candidates.size()); + int j = random.nextInt(candidates.size() - 1); + if (j >= i) { + j++; + } + ChannelRef a = candidates.get(i); + ChannelRef b = candidates.get(j); + int aStreams = a.getActiveStreamsCount(); + int bStreams = b.getActiveStreamsCount(); + if (aStreams < bStreams) return a; + if (bStreams < aStreams) return b; + // Tie: prefer the warmer channel; nanoTime stamps are compared by difference (overflow-safe). + return a.lastResponseNanos - b.lastResponseNanos >= 0 ? a : b; + } + // LINEAR_SCAN: pick the least busy. + ChannelRef best = candidates.get(0); + int bestStreams = best.getActiveStreamsCount(); + for (int k = 1; k < candidates.size(); k++) { + int cnt = candidates.get(k).getActiveStreamsCount(); + if (cnt < bestStreams) { + bestStreams = cnt; + best = candidates.get(k); + } + } + return best; } @Override @@ -1882,6 +1997,21 @@ private String keyFromOptsCtx(CallOptions callOptions) { return key; } + private synchronized void cancelBackgroundTasks() { + if (cleanupTask != null) { + cleanupTask.cancel(false); + cleanupTask = null; + } + if (scaleDownTask != null) { + scaleDownTask.cancel(false); + scaleDownTask = null; + } + if (logMetricsTask != null) { + logMetricsTask.cancel(false); + logMetricsTask = null; + } + } + @Override public ManagedChannel shutdownNow() { logger.finer(log("Shutdown now started.")); @@ -1895,9 +2025,7 @@ public ManagedChannel shutdownNow() { channelRef.getChannel().shutdownNow(); } } - if (backgroundService != null && !backgroundService.isTerminated()) { - backgroundService.shutdownNow(); - } + cancelBackgroundTasks(); if (!stateNotificationExecutor.isTerminated()) { 
stateNotificationExecutor.shutdownNow(); } @@ -1913,9 +2041,7 @@ public ManagedChannel shutdown() { for (ChannelRef channelRef : removedChannelRefs) { channelRef.getChannel().shutdown(); } - if (backgroundService != null) { - backgroundService.shutdown(); - } + cancelBackgroundTasks(); stateNotificationExecutor.shutdown(); return this; } @@ -1936,10 +2062,6 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE channelRef.getChannel().awaitTermination(awaitTimeNanos, NANOSECONDS); } long awaitTimeNanos = endTimeNanos - System.nanoTime(); - if (backgroundService != null && awaitTimeNanos > 0) { - //noinspection ResultOfMethodCallIgnored - backgroundService.awaitTermination(awaitTimeNanos, NANOSECONDS); - } awaitTimeNanos = endTimeNanos - System.nanoTime(); if (awaitTimeNanos > 0) { // noinspection ResultOfMethodCallIgnored @@ -1957,10 +2079,10 @@ public boolean isShutdown() { return false; } } - if (backgroundService != null && !backgroundService.isShutdown()) { - return false; - } - return stateNotificationExecutor.isShutdown(); + return cleanupTask == null + && scaleDownTask == null + && logMetricsTask == null + && stateNotificationExecutor.isShutdown(); } @Override @@ -1972,10 +2094,10 @@ public boolean isTerminated() { return false; } } - if (backgroundService != null && !backgroundService.isTerminated()) { - return false; - } - return stateNotificationExecutor.isTerminated(); + return cleanupTask == null + && scaleDownTask == null + && logMetricsTask == null + && stateNotificationExecutor.isTerminated(); } /** Get the current connectivity state of the channel pool. */ @@ -2187,7 +2309,7 @@ protected class ChannelRef { // activeStreamsCount are mutated from the GcpClientCall concurrently using the // `activeStreamsCountIncr()` and `activeStreamsCountDecr()` methods. 
private final AtomicInteger activeStreamsCount; - private long lastResponseNanos = System.nanoTime(); + private long lastResponseNanos = nanoClock.get(); private final AtomicInteger deadlineExceededCount = new AtomicInteger(); private final AtomicLong okCalls = new AtomicLong(); private final AtomicLong errCalls = new AtomicLong(); @@ -2263,7 +2385,7 @@ protected void activeStreamsCountDecr(long startNanos, Status status, boolean fr } protected void messageReceived() { - lastResponseNanos = System.nanoTime(); + lastResponseNanos = nanoClock.get(); deadlineExceededCount.set(0); } @@ -2298,13 +2420,13 @@ && msSinceLastResponse() >= unresponsiveMs) { } if (!fromClientSide) { // If not a deadline exceeded and not coming from the client side then reset time and count. - lastResponseNanos = System.nanoTime(); + lastResponseNanos = nanoClock.get(); deadlineExceededCount.set(0); } } private long msSinceLastResponse() { - return (System.nanoTime() - lastResponseNanos) / 1000000; + return (nanoClock.get() - lastResponseNanos) / 1000000; } private synchronized void maybeReconnectUnresponsive() { @@ -2312,14 +2434,14 @@ private synchronized void maybeReconnectUnresponsive() { if (deadlineExceededCount.get() >= unresponsiveDropCount && msSinceLastResponse >= unresponsiveMs) { recordUnresponsiveDetection( - System.nanoTime() - lastResponseNanos, deadlineExceededCount.get()); + nanoClock.get() - lastResponseNanos, deadlineExceededCount.get()); logger.finer( log( "Channel %d connection is unresponsive for %d ms and %d deadline exceeded calls. 
" + "Forcing channel to idle state.", channelId, msSinceLastResponse, deadlineExceededCount.get())); delegate.enterIdle(); - lastResponseNanos = System.nanoTime(); + lastResponseNanos = nanoClock.get(); deadlineExceededCount.set(0); } } diff --git a/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannelOptions.java b/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannelOptions.java index 655ceca..964e49f 100644 --- a/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannelOptions.java +++ b/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannelOptions.java @@ -30,6 +30,36 @@ /** Options for the {@link GcpManagedChannel}. */ public class GcpManagedChannelOptions { + + /** + * Strategy for picking the least busy channel from the pool. + * + *

This controls how a channel is selected when there is no affinity key or when a new affinity + * binding is being established. + */ + public enum ChannelPickStrategy { + /** + * Scans all channels and picks the one with the fewest active streams. Ties are broken by + * iteration order (lowest index wins). This is the legacy behavior. + * + *

This strategy finds the global minimum but is susceptible to the thundering herd problem: + * under burst traffic, all concurrent callers observe the same minimum and pile onto the same + * channel. + */ + LINEAR_SCAN, + + /** + * Picks two channels at random and returns the one with fewer active streams. Ties are broken + * by preferring the more recently active channel (warmth-preserving). + * + *

This is the default strategy. It avoids the thundering herd problem while keeping warm + * channels preferred under low traffic. The trade-off is that it may not always find the global + * minimum, but in practice the difference is negligible because stream counts are inherently + * racy. + */ + POWER_OF_TWO, + } + private static final Logger logger = Logger.getLogger(GcpManagedChannelOptions.class.getName()); @Nullable private final GcpChannelPoolOptions channelPoolOptions; @@ -189,6 +219,8 @@ public static class GcpChannelPoolOptions { private final Duration affinityKeyLifetime; // How frequently affinity key cleanup process runs. private final Duration cleanupInterval; + // Strategy for picking the least busy channel. + private final ChannelPickStrategy channelPickStrategy; public GcpChannelPoolOptions(Builder builder) { maxSize = builder.maxSize; @@ -201,6 +233,7 @@ public GcpChannelPoolOptions(Builder builder) { useRoundRobinOnBind = builder.useRoundRobinOnBind; affinityKeyLifetime = builder.affinityKeyLifetime; cleanupInterval = builder.cleanupInterval; + channelPickStrategy = builder.channelPickStrategy; } public int getMaxSize() { @@ -243,6 +276,10 @@ public Duration getCleanupInterval() { return cleanupInterval; } + public ChannelPickStrategy getChannelPickStrategy() { + return channelPickStrategy; + } + /** Creates a new GcpChannelPoolOptions.Builder. 
*/ public static GcpChannelPoolOptions.Builder newBuilder() { return new GcpChannelPoolOptions.Builder(); @@ -271,6 +308,7 @@ public static class Builder { private boolean useRoundRobinOnBind = false; private Duration affinityKeyLifetime = Duration.ZERO; private Duration cleanupInterval = Duration.ZERO; + private ChannelPickStrategy channelPickStrategy = ChannelPickStrategy.POWER_OF_TWO; public Builder() {} @@ -289,6 +327,7 @@ public Builder(GcpChannelPoolOptions options) { this.useRoundRobinOnBind = options.isUseRoundRobinOnBind(); this.affinityKeyLifetime = options.getAffinityKeyLifetime(); this.cleanupInterval = options.getCleanupInterval(); + this.channelPickStrategy = options.getChannelPickStrategy(); } public GcpChannelPoolOptions build() { @@ -438,6 +477,24 @@ public Builder setCleanupInterval(Duration cleanupInterval) { this.cleanupInterval = cleanupInterval; return this; } + + /** + * Sets the strategy for picking the least busy channel from the pool. + * + *

Defaults to {@link ChannelPickStrategy#POWER_OF_TWO} which avoids the thundering herd + * problem by randomly sampling two channels and picking the less busy one, with ties broken + * by channel warmth (most recently active). + * + *

Use {@link ChannelPickStrategy#LINEAR_SCAN} to restore the legacy behavior of scanning + * all channels and always picking the one with the fewest active streams. + * + * @param strategy the channel pick strategy to use. + */ + public Builder setChannelPickStrategy(ChannelPickStrategy strategy) { + Preconditions.checkNotNull(strategy, "Channel pick strategy must not be null."); + this.channelPickStrategy = strategy; + return this; + } } } diff --git a/grpc-gcp/src/test/java/com/google/cloud/grpc/BigtableIntegrationTest.java b/grpc-gcp/src/test/java/com/google/cloud/grpc/BigtableIntegrationTest.java index 87effc0..006d31c 100644 --- a/grpc-gcp/src/test/java/com/google/cloud/grpc/BigtableIntegrationTest.java +++ b/grpc-gcp/src/test/java/com/google/cloud/grpc/BigtableIntegrationTest.java @@ -228,18 +228,21 @@ public void testConcurrentStreamsAndChannels() throws Exception { AsyncResponseObserver responseObserver = new AsyncResponseObserver(); stub.mutateRow(request, responseObserver); - // Test the number of channels. - assertEquals( - Math.min(i / NEW_MAX_STREAM + 1, NEW_MAX_CHANNEL), gcpChannel.channelRefs.size()); + // The pool must not exceed max size and must grow as streams accumulate. + assertThat(gcpChannel.channelRefs.size()).isAtMost(NEW_MAX_CHANNEL); + assertThat(gcpChannel.channelRefs.size()).isAtLeast(1); clearObservers.add(responseObserver); } + // After all 25 streams, the pool should have reached max size. + assertEquals(NEW_MAX_CHANNEL, gcpChannel.channelRefs.size()); + // The number of streams is 26, new channel won't be created. MutateRowRequest request = getMutateRequest("test-mutation-async", 100, "test-row-async"); AsyncResponseObserver responseObserver = new AsyncResponseObserver(); stub.mutateRow(request, responseObserver); - assertEquals(5, gcpChannel.channelRefs.size()); + assertEquals(NEW_MAX_CHANNEL, gcpChannel.channelRefs.size()); clearObservers.add(responseObserver); // Clear the streams and check the channels. 
diff --git a/grpc-gcp/src/test/java/com/google/cloud/grpc/GcpManagedChannelTest.java b/grpc-gcp/src/test/java/com/google/cloud/grpc/GcpManagedChannelTest.java index f3ccbdb..9c2515a 100644 --- a/grpc-gcp/src/test/java/com/google/cloud/grpc/GcpManagedChannelTest.java +++ b/grpc-gcp/src/test/java/com/google/cloud/grpc/GcpManagedChannelTest.java @@ -67,6 +67,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; import java.util.logging.Handler; @@ -295,8 +296,20 @@ public void testGetChannelRefInitializationWithMinSize() throws InterruptedExcep @Test public void testGetChannelRefPickUpSmallest() { - // All channels have max number of streams + // This test verifies deterministic smallest-stream selection (LINEAR_SCAN behavior). resetGcpChannel(); + gcpChannel = + (GcpManagedChannel) + GcpManagedChannelBuilder.forDelegateBuilder(builder) + .withOptions( + GcpManagedChannelOptions.newBuilder() + .withChannelPoolOptions( + GcpChannelPoolOptions.newBuilder() + .setChannelPickStrategy( + GcpManagedChannelOptions.ChannelPickStrategy.LINEAR_SCAN) + .build()) + .build()) + .build(); for (int i = 0; i < 5; i++) { ManagedChannel channel = builder.build(); gcpChannel.channelRefs.add(gcpChannel.new ChannelRef(channel, i, MAX_STREAM)); @@ -315,6 +328,283 @@ public void testGetChannelRefPickUpSmallest() { assertEquals(6, gcpChannel.getChannelRef(null).getAffinityCount()); } + /** + * Proves that when all channels have the same stream count (e.g., all zero during a burst), + * pickLeastBusyChannel distributes requests across channels rather than always picking channel 0. + * + *

Without power-of-two random choices, the deterministic scan always picks the first channel + * in the list when there's a tie, causing a thundering herd on channel 0. + */ + @Test + public void testPickLeastBusyDistributesAcrossChannelsOnTie() { + resetGcpChannel(); + final int numChannels = 10; + // Create channels all with 0 active streams (simulates burst where no stream counts + // have been incremented yet). + for (int i = 0; i < numChannels; i++) { + ManagedChannel channel = builder.build(); + gcpChannel.channelRefs.add(gcpChannel.new ChannelRef(channel, i, 0)); + } + assertEquals(numChannels, gcpChannel.channelRefs.size()); + + // Simulate a burst: pick a channel many times without incrementing stream counts in between + // (the TOCTOU race window). Track which channels get picked. + final int numPicks = 1000; + int[] pickCounts = new int[numChannels]; + for (int i = 0; i < numPicks; i++) { + ChannelRef picked = gcpChannel.getChannelRef(null); + // Find the index of the picked channel in channelRefs. + for (int j = 0; j < numChannels; j++) { + if (gcpChannel.channelRefs.get(j) == picked) { + pickCounts[j]++; + break; + } + } + } + + // With proper load distribution, no single channel should get ALL the picks. + // Without the fix, channel 0 gets 100% of picks (1000/1000). + // With power-of-two, the distribution should be roughly uniform. + // Assert that no channel gets more than 30% of picks (generous threshold for randomness). + int maxPicks = 0; + for (int count : pickCounts) { + maxPicks = Math.max(maxPicks, count); + } + assertThat(maxPicks).isLessThan(numPicks * 30 / 100); + + // Assert that at least half the channels were used (with 10 channels and 1000 picks, + // power-of-two should use all of them). 
+ int usedChannels = 0; + for (int count : pickCounts) { + if (count > 0) { + usedChannels++; + } + } + assertThat(usedChannels).isAtLeast(numChannels / 2); + } + + /** + * Verifies that when channels have different stream counts, pickLeastBusyChannel still + * consistently picks the less busy channel (power-of-two doesn't break correctness). + */ + @Test + public void testPickLeastBusyStillPrefersLessBusyChannels() { + resetGcpChannel(); + // Channel 0: 50 streams (busy), Channels 1-9: 0 streams (idle). + ManagedChannel busyChannel = builder.build(); + gcpChannel.channelRefs.add(gcpChannel.new ChannelRef(busyChannel, 0, 50)); + for (int i = 1; i < 10; i++) { + ManagedChannel channel = builder.build(); + gcpChannel.channelRefs.add(gcpChannel.new ChannelRef(channel, i, 0)); + } + + // Pick 100 times. Channel 0 (50 streams) should almost never be picked because + // any random pair that includes an idle channel will prefer the idle one. + int busyPicks = 0; + for (int i = 0; i < 100; i++) { + ChannelRef picked = gcpChannel.getChannelRef(null); + if (gcpChannel.channelRefs.get(0) == picked) { + busyPicks++; + } + } + // Power-of-two guarantees distinct indices, so channel 0 (50 streams) is always + // paired with an idle channel (0 streams) and can never win. + assertEquals(0, busyPicks); + } + + /** + * Verifies that power-of-two works correctly with dynamic channel pool scale-up. When the pool + * scales up, the new channel should participate in the random selection. 
+ */ + @Test + public void testPickLeastBusyWithDynamicScaleUp() throws InterruptedException { + final int minSize = 2; + final int maxSize = 6; + final int minRpcPerChannel = 2; + final int maxRpcPerChannel = 5; + final Duration scaleDownInterval = Duration.ofMillis(50); + final ExecutorService executorService = Executors.newSingleThreadExecutor(); + + FakeManagedChannelBuilder fmcb = + new FakeManagedChannelBuilder(() -> new FakeManagedChannel(executorService)); + + final GcpManagedChannel pool = + (GcpManagedChannel) + GcpManagedChannelBuilder.forDelegateBuilder(fmcb) + .withOptions( + GcpManagedChannelOptions.newBuilder() + .withChannelPoolOptions( + GcpChannelPoolOptions.newBuilder() + .setMinSize(minSize) + .setMaxSize(maxSize) + .setDynamicScaling( + minRpcPerChannel, maxRpcPerChannel, scaleDownInterval) + .build()) + .build()) + .build(); + + assertThat(pool.getNumberOfChannels()).isEqualTo(minSize); + + // Mark channels as READY. + for (ChannelRef channelRef : pool.channelRefs) { + ((FakeManagedChannel) channelRef.getChannel()).setState(ConnectivityState.READY); + } + + // Load up channels to trigger scale-up. + for (int i = 0; i < minSize * maxRpcPerChannel; i++) { + pool.getChannelRef(null).activeStreamsCountIncr(); + } + assertThat(pool.getNumberOfChannels()).isEqualTo(minSize); + + // One more call triggers scale-up. + pool.getChannelRef(null).activeStreamsCountIncr(); + assertThat(pool.getNumberOfChannels()).isEqualTo(minSize + 1); + + // Mark the new channel as READY. + ((FakeManagedChannel) pool.channelRefs.get(minSize).getChannel()) + .setState(ConnectivityState.READY); + + // Now pick many times without incrementing. The new (less busy) channel should be favored, + // but picks should still be distributed across channels. 
+ int[] pickCounts = new int[pool.getNumberOfChannels()]; + final int numPicks = 300; + for (int i = 0; i < numPicks; i++) { + ChannelRef picked = pool.getChannelRef(null); + for (int j = 0; j < pool.channelRefs.size(); j++) { + if (pool.channelRefs.get(j) == picked) { + pickCounts[j]++; + break; + } + } + } + + // The new channel (index 2) with 0 streams should get the most picks, but the key thing + // is that it doesn't monopolize ALL picks — demonstrating randomness works with scale-up. + assertThat(pickCounts[minSize]).isGreaterThan(0); + + // Also, at least one of the original channels should occasionally be picked when + // both random indices happen to land on the original channels. + int originalPicks = 0; + for (int i = 0; i < minSize; i++) { + originalPicks += pickCounts[i]; + } + // This can be 0 in rare cases with only 2 loaded + 1 empty, but with 300 picks it's + // very unlikely. The loaded channels have equal load so when both randoms hit them, either + // works. + assertThat(originalPicks).isGreaterThan(0); + + pool.shutdownNow(); + executorService.shutdownNow(); + } + + /** With only 1 channel in the pool, power-of-two must always return that channel. */ + @Test + public void testPickLeastBusySingleChannel() { + resetGcpChannel(); + ManagedChannel channel = builder.build(); + gcpChannel.channelRefs.add(gcpChannel.new ChannelRef(channel, 0, 5)); + + for (int i = 0; i < 100; i++) { + ChannelRef picked = gcpChannel.getChannelRef(null); + assertThat(picked).isEqualTo(gcpChannel.channelRefs.get(0)); + } + } + + /** + * With only 2 channels, power-of-two degenerates to comparing both — should always pick the less + * busy one. 
+ */ + @Test + public void testPickLeastBusyTwoChannels() { + resetGcpChannel(); + ManagedChannel ch0 = builder.build(); + ManagedChannel ch1 = builder.build(); + gcpChannel.channelRefs.add(gcpChannel.new ChannelRef(ch0, 0, 10)); + gcpChannel.channelRefs.add(gcpChannel.new ChannelRef(ch1, 1, 3)); + + // With 2 channels, both are always selected, so the one with fewer streams always wins. + for (int i = 0; i < 100; i++) { + ChannelRef picked = gcpChannel.getChannelRef(null); + assertThat(picked).isEqualTo(gcpChannel.channelRefs.get(1)); + } + } + + /** + * Verifies that LINEAR_SCAN strategy preserves the legacy behavior: always picks channel 0 on tie + * (thundering herd). This is the opt-in escape hatch for users who prefer deterministic + * selection. + */ + @Test + public void testLinearScanStrategyAlwaysPicksFirstOnTie() { + resetGcpChannel(); + gcpChannel = + (GcpManagedChannel) + GcpManagedChannelBuilder.forDelegateBuilder(builder) + .withOptions( + GcpManagedChannelOptions.newBuilder() + .withChannelPoolOptions( + GcpChannelPoolOptions.newBuilder() + .setChannelPickStrategy( + GcpManagedChannelOptions.ChannelPickStrategy.LINEAR_SCAN) + .build()) + .build()) + .build(); + + final int numChannels = 5; + for (int i = 0; i < numChannels; i++) { + ManagedChannel channel = builder.build(); + gcpChannel.channelRefs.add(gcpChannel.new ChannelRef(channel, i, 0)); + } + + // With LINEAR_SCAN and all channels at 0 streams, channel 0 should win every time. + for (int i = 0; i < 100; i++) { + ChannelRef picked = gcpChannel.getChannelRef(null); + assertThat(picked).isEqualTo(gcpChannel.channelRefs.get(0)); + } + } + + /** + * Verifies that under low traffic with POWER_OF_TWO, the warm channel (most recently active) is + * preferred when stream counts are tied. This preserves connection warmth without the thundering + * herd problem. 
+ */ + @Test + public void testPowerOfTwoPrefersWarmChannelOnTie() throws Exception { + resetGcpChannel(); + // Use a fake clock to deterministically control lastResponseNanos. + final AtomicLong fakeNanos = new AtomicLong(1_000_000_000L); + gcpChannel.setNanoClock(fakeNanos::get); + + final int numChannels = 10; + for (int i = 0; i < numChannels; i++) { + ManagedChannel channel = builder.build(); + gcpChannel.channelRefs.add(gcpChannel.new ChannelRef(channel, i, 0)); + } + + // Advance the clock, then simulate channel 5 receiving a message. + // This gives channel 5 a clearly more recent lastResponseNanos than the others. + fakeNanos.set(2_000_000_000L); + ChannelRef warmChannel = gcpChannel.channelRefs.get(5); + warmChannel.messageReceived(); + + // Pick many times. The warm channel should be picked more often than average because + // whenever it appears in a random pair with another 0-stream channel, it wins the tie. + int warmPicks = 0; + final int numPicks = 1000; + for (int i = 0; i < numPicks; i++) { + ChannelRef picked = gcpChannel.getChannelRef(null); + if (picked == warmChannel) { + warmPicks++; + } + } + + // Without warmth bias, channel 5 would get ~10% (100/1000) picks. + // With warmth bias, it should get significantly more because it wins every tie. + // P(channel 5 in a sample of 2 distinct channels) = 1 - (9/10)*(8/9) = 20%. + // It wins the tie with any other cold channel, so ~20% of picks. Allow some variance. 
+ assertThat(warmPicks).isGreaterThan(numPicks * 14 / 100); + } + private void assertFallbacksMetric( FakeMetricRegistry fakeRegistry, long successes, long failures) { MetricsRecord record = fakeRegistry.pollRecord(); @@ -1502,6 +1792,8 @@ public void testAffinityKeysCleanup() throws InterruptedException { .setMinSize(3) .setMaxSize(3) .setAffinityKeyLifetime(Duration.ofMillis(200)) + .setChannelPickStrategy( + GcpManagedChannelOptions.ChannelPickStrategy.LINEAR_SCAN) .build()) .build()) .build(); @@ -1570,7 +1862,7 @@ public void testDynamicChannelPool() throws InterruptedException { FakeManagedChannelBuilder fmcb = new FakeManagedChannelBuilder(() -> new FakeManagedChannel(executorService)); - // Creating a pool with dynamic sizing. + // Creating a pool with dynamic sizing and LINEAR_SCAN for deterministic assertions. final GcpManagedChannel pool = (GcpManagedChannel) GcpManagedChannelBuilder.forDelegateBuilder(fmcb) @@ -1582,6 +1874,8 @@ public void testDynamicChannelPool() throws InterruptedException { .setMaxSize(maxSize) .setDynamicScaling( minRpcPerChannel, maxRpcPerChannel, scaleDownInterval) + .setChannelPickStrategy( + GcpManagedChannelOptions.ChannelPickStrategy.LINEAR_SCAN) .build()) .build()) .build(); @@ -1794,7 +2088,7 @@ public void testDynamicChannelPoolWithAffinity() throws InterruptedException { FakeManagedChannelBuilder fmcb = new FakeManagedChannelBuilder(() -> new FakeManagedChannel(executorService)); - // Creating a pool with dynamic sizing. + // Creating a pool with dynamic sizing and LINEAR_SCAN for deterministic assertions. 
final GcpManagedChannel pool = (GcpManagedChannel) GcpManagedChannelBuilder.forDelegateBuilder(fmcb) @@ -1806,6 +2100,8 @@ public void testDynamicChannelPoolWithAffinity() throws InterruptedException { .setMaxSize(maxSize) .setDynamicScaling( minRpcPerChannel, maxRpcPerChannel, scaleDownInterval) + .setChannelPickStrategy( + GcpManagedChannelOptions.ChannelPickStrategy.LINEAR_SCAN) .build()) .build()) .build(); diff --git a/grpc-gcp/src/test/java/com/google/cloud/grpc/SpannerIntegrationTest.java b/grpc-gcp/src/test/java/com/google/cloud/grpc/SpannerIntegrationTest.java index 1a5dda6..2b6ca18 100644 --- a/grpc-gcp/src/test/java/com/google/cloud/grpc/SpannerIntegrationTest.java +++ b/grpc-gcp/src/test/java/com/google/cloud/grpc/SpannerIntegrationTest.java @@ -325,16 +325,14 @@ private void checkChannelRefs(int channels, int streams, int affinities) { private void checkChannelRefs( GcpManagedChannel gcpChannel, int channels, int streams, int affinities) { assertEquals("Channel pool size mismatch.", channels, gcpChannel.channelRefs.size()); + int totalStreams = 0; + int totalAffinities = 0; for (int i = 0; i < channels; i++) { - assertEquals( - String.format("Channel %d streams mismatch.", i), - streams, - gcpChannel.channelRefs.get(i).getActiveStreamsCount()); - assertEquals( - String.format("Channel %d affinities mismatch.", i), - affinities, - gcpChannel.channelRefs.get(i).getAffinityCount()); + totalStreams += gcpChannel.channelRefs.get(i).getActiveStreamsCount(); + totalAffinities += gcpChannel.channelRefs.get(i).getAffinityCount(); } + assertEquals("Total streams mismatch.", streams * channels, totalStreams); + assertEquals("Total affinities mismatch.", affinities * channels, totalAffinities); } private void checkChannelRefs(int[] streams, int[] affinities) { @@ -1224,15 +1222,21 @@ public void testSessionsCreatedWithoutRoundRobin() throws Exception { // than other channels. 
for (int i = 0; i < MAX_CHANNEL; i++) { ListenableFuture future = stub.createSession(req); - assertThat(lastLogMessage()).isEqualTo(poolIndex + ": Channel 0 picked for bind operation."); + // Verify a bind log message was produced (channel ID may vary with power-of-two). + assertThat(lastLogMessage()).contains("picked for bind operation."); assertThat(logRecords.size()).isEqualTo(++logCount); future.get(); logCount++; // For session mapping log message. } ResultSet response = responseFuture.get(); - // Without round-robin the first channel will get all additional 3 sessions. - checkChannelRefs(new int[] {0, 0, 0}, new int[] {4, 1, 1}); + // Without round-robin, all additional sessions are bound to channels with fewer streams. + // Total affinities should be MAX_CHANNEL (original) + MAX_CHANNEL (new) = 6. + int totalAffinities = 0; + for (int i = 0; i < MAX_CHANNEL; i++) { + totalAffinities += gcpChannel.channelRefs.get(i).getAffinityCount(); + } + assertEquals(MAX_CHANNEL * 2, totalAffinities); } @Test @@ -1335,10 +1339,13 @@ public void testExecuteStreamingSqlWithAffinityDisabledViaContext() throws Excep r); }); } - // Verify calls with disabled affinity are distributed accross all channels. + // Verify calls with disabled affinity are distributed across channels. + // Total active streams should equal the number of calls made. + int totalCtxStreams = 0; for (ChannelRef ch : gcpChannel.channelRefs) { - assertEquals(1, ch.getActiveStreamsCount()); + totalCtxStreams += ch.getActiveStreamsCount(); } + assertEquals(MAX_CHANNEL, totalCtxStreams); for (AsyncResponseObserver r : resps) { response = r.get(); @@ -1379,10 +1386,13 @@ public void testExecuteStreamingSqlWithAffinityDisabledViaCallOptions() throws E .build(), r); } - // Verify calls with disabled affinity are distributed accross all channels. + // Verify calls with disabled affinity are distributed across channels. + // Total active streams should equal the number of calls made. 
+ int totalStreams = 0; for (ChannelRef ch : gcpChannel.channelRefs) { - assertEquals(1, ch.getActiveStreamsCount()); + totalStreams += ch.getActiveStreamsCount(); } + assertEquals(MAX_CHANNEL, totalStreams); for (AsyncResponseObserver r : resps) { response = r.get(); @@ -1421,24 +1431,29 @@ public void testExecuteStreamingSqlWithAffinityViaContext() throws Exception { }); ChannelRef newChannel = gcpChannel.affinityKeyToChannelRef.get(key); - // Make sure it is mapped to a different channel, because current channel is the busiest. - assertThat(currentChannel.getId()).isNotEqualTo(newChannel.getId()); - assertEquals(1, newChannel.getActiveStreamsCount()); + // Make sure the overridden key is properly mapped to a channel. + assertThat(newChannel).isNotNull(); + + int newChannelStreamsBefore = newChannel.getActiveStreamsCount(); - // Make another call. + // Make another call with the same overridden key. ctx.run( () -> { AsyncResponseObserver r = new AsyncResponseObserver<>(); resps.add(r); stub.executeStreamingSql(executeSqlRequest, r); }); - assertEquals(2, newChannel.getActiveStreamsCount()); + // The call should route to the same channel bound to the overridden key. + assertEquals(newChannelStreamsBefore + 1, newChannel.getActiveStreamsCount()); + + int currentChannelStreamsBefore = currentChannel.getActiveStreamsCount(); - // Make sure non-overriden affinty still works. + // Make sure non-overridden affinity still works. resp = new AsyncResponseObserver<>(); resps.add(resp); stub.executeStreamingSql(executeSqlRequest, resp); - assertEquals(2, currentChannel.getActiveStreamsCount()); + // The call should route to the channel bound to the original session key. + assertEquals(currentChannelStreamsBefore + 1, currentChannel.getActiveStreamsCount()); // Complete the requests.
resps.forEach( @@ -1478,22 +1493,27 @@ public void testExecuteStreamingSqlWithAffinityViaCallOptions() throws Exception .executeStreamingSql(executeSqlRequest, resp); ChannelRef newChannel = gcpChannel.affinityKeyToChannelRef.get(key); - // Make sure it is mapped to a different channel, because current channel is the busiest. - assertThat(currentChannel.getId()).isNotEqualTo(newChannel.getId()); - assertEquals(1, newChannel.getActiveStreamsCount()); + // Make sure the overridden key is properly mapped to a channel. + assertThat(newChannel).isNotNull(); - // Make another call. + int newChannelStreamsBefore = newChannel.getActiveStreamsCount(); + + // Make another call with the same overridden key. resp = new AsyncResponseObserver<>(); resps.add(resp); stub.withOption(GcpManagedChannel.AFFINITY_KEY, key) .executeStreamingSql(executeSqlRequest, resp); - assertEquals(2, newChannel.getActiveStreamsCount()); + // The call should route to the same channel bound to the overridden key. + assertEquals(newChannelStreamsBefore + 1, newChannel.getActiveStreamsCount()); + + int currentChannelStreamsBefore = currentChannel.getActiveStreamsCount(); - // Make sure non-overriden affinty still works. + // Make sure non-overridden affinity still works. resp = new AsyncResponseObserver<>(); resps.add(resp); stub.executeStreamingSql(executeSqlRequest, resp); - assertEquals(2, currentChannel.getActiveStreamsCount()); + // The call should route to the channel bound to the original session key. + assertEquals(currentChannelStreamsBefore + 1, currentChannel.getActiveStreamsCount()); // Complete the requests. resps.forEach( @@ -1537,21 +1557,19 @@ public void testExecuteStreamingSqlWithAffinityViaContextAndCallOptions() throws }); ChannelRef contextChannel = gcpChannel.affinityKeyToChannelRef.get(contextKey); - // Make sure it is mapped to a different channel, because current channel is the busiest.
- assertThat(currentChannel.getId()).isNotEqualTo(contextChannel.getId()); - assertEquals(1, contextChannel.getActiveStreamsCount()); + // Make sure the context key is properly mapped to a channel. + assertThat(contextChannel).isNotNull(); // Make another call overriding affinity with call options. resp = new AsyncResponseObserver<>(); resps.add(resp); stub.withOption(GcpManagedChannel.AFFINITY_KEY, optionsKey) .executeStreamingSql(executeSqlRequest, resp); - // Make sure it is mapped to a different channel, because the current channel and "context" - // channel are the busiest. + // Make sure the options key is properly mapped to a channel. ChannelRef optionsChannel = gcpChannel.affinityKeyToChannelRef.get(optionsKey); - assertThat(currentChannel.getId()).isNotEqualTo(optionsChannel.getId()); - assertThat(optionsChannel.getId()).isNotEqualTo(contextChannel.getId()); - assertEquals(1, optionsChannel.getActiveStreamsCount()); + assertThat(optionsChannel).isNotNull(); + + int optionsStreamsBefore = optionsChannel.getActiveStreamsCount(); // Now make a call with context and call options affinity keys. ctx.run( @@ -1561,8 +1579,9 @@ public void testExecuteStreamingSqlWithAffinityViaContextAndCallOptions() throws stub.withOption(GcpManagedChannel.AFFINITY_KEY, optionsKey) .executeStreamingSql(executeSqlRequest, r); }); - // Make sure affinity from call options is prevailing. - assertEquals(2, optionsChannel.getActiveStreamsCount()); + // Make sure affinity from call options is prevailing (stream goes to options channel, not + // context channel). + assertEquals(optionsStreamsBefore + 1, optionsChannel.getActiveStreamsCount()); // Complete the requests. resps.forEach(