From d9a46a199e7c5a9df5855be93ffd273a12f6f5a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 10 Apr 2026 16:55:42 +0200 Subject: [PATCH] chore(spanner): use ReplicaSelector in KeyRangeCache Integrates the ReplicaSelector in KeyRangeCache, so it can be used to select the best replica to send the request to. This feature is effectively disabled for production, as the ReplicaSelector is set to null by default. --- .../cloud/spanner/spi/v1/ChannelEndpoint.java | 11 ++ .../cloud/spanner/spi/v1/KeyRangeCache.java | 40 ++++++- .../spi/v1/PowerOfTwoReplicaSelector.java | 14 +-- .../cloud/spanner/spi/v1/ReplicaSelector.java | 5 +- .../spanner/spi/v1/KeyRangeCacheTest.java | 94 ++++++++++++++++ .../spi/v1/PowerOfTwoReplicaSelectorTest.java | 105 +++++++++++++----- 6 files changed, 226 insertions(+), 43 deletions(-) diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java index fc82c530fc6f..2a66d8b00ed6 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java @@ -71,4 +71,15 @@ public interface ChannelEndpoint { * @return the managed channel for this server */ ManagedChannel getChannel(); + + /** + * Returns the latency tracker for this server. + * + * <p>Default implementation returns {@code null}. + * + * @return the latency tracker, or {@code null} if latency tracking is not supported + */ + default LatencyTracker getLatencyTracker() { + return null; + } } diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java index 59955ccb4bd2..0324587fc6b5 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java @@ -18,6 +18,7 @@ import com.google.api.core.InternalApi; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Lists; import com.google.common.hash.Hashing; import com.google.protobuf.ByteString; import com.google.spanner.v1.CacheUpdate; @@ -66,6 +67,7 @@ public enum RangeMode { private final ChannelEndpointCache endpointCache; @javax.annotation.Nullable private final EndpointLifecycleManager lifecycleManager; + @javax.annotation.Nullable private final ReplicaSelector replicaSelector; private final NavigableMap ranges = new TreeMap<>(ByteString.unsignedLexicographicalComparator()); private final Map groups = new HashMap<>(); @@ -78,14 +80,22 @@ public enum RangeMode { private volatile int minCacheEntriesForRandomPick = DEFAULT_MIN_ENTRIES_FOR_RANDOM_PICK; public KeyRangeCache(ChannelEndpointCache endpointCache) { - this(endpointCache, null); + this(endpointCache, null, null); } public KeyRangeCache( ChannelEndpointCache endpointCache, @javax.annotation.Nullable EndpointLifecycleManager lifecycleManager) { + this(endpointCache, lifecycleManager, null); + } + + public KeyRangeCache( + ChannelEndpointCache endpointCache, + @javax.annotation.Nullable EndpointLifecycleManager lifecycleManager, + @javax.annotation.Nullable ReplicaSelector replicaSelector) { this.endpointCache 
= Objects.requireNonNull(endpointCache); this.lifecycleManager = lifecycleManager; + this.replicaSelector = replicaSelector; } @VisibleForTesting @@ -802,6 +812,8 @@ && leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) { return leader(); } } + List<CachedTablet> eligibleTablets = + replicaSelector != null ? new ArrayList<>(tablets.size()) : null; for (int index = 0; index < tablets.size(); index++) { if (checkedLeader && index == leaderIndex) { continue; @@ -813,7 +825,31 @@ && leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) { if (tablet.shouldSkip(hintBuilder, excludedEndpoints, skippedTabletUids)) { continue; } - return tablet; + if (replicaSelector == null) { + return tablet; + } + eligibleTablets.add(tablet); + } + + if (replicaSelector != null) { + return selectReplicaLocked(eligibleTablets); + } + return null; + } + + private CachedTablet selectReplicaLocked(final List<CachedTablet> eligibleTablets) { + if (eligibleTablets.isEmpty()) { + return null; + } + List<ChannelEndpoint> candidates = + Lists.transform(eligibleTablets, tablet -> tablet.endpoint); + ChannelEndpoint selectedEndpoint = replicaSelector.select(candidates); + + // The number of eligible tablets is always small, so a linear search is more efficient. 
+ for (CachedTablet tablet : eligibleTablets) { + if (tablet.endpoint.equals(selectedEndpoint)) { + return tablet; + } } return null; } diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelector.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelector.java index c7cc2012a615..08c8b9ac8a72 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelector.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelector.java @@ -18,11 +18,9 @@ import com.google.api.core.BetaApi; import com.google.api.core.InternalApi; -import com.google.common.base.MoreObjects; import java.util.List; import java.util.Random; import java.util.concurrent.ThreadLocalRandom; -import java.util.function.Function; /** Implementation of {@link ReplicaSelector} using the "Power of 2 Random Choices" strategy. */ @InternalApi @@ -30,8 +28,7 @@ public class PowerOfTwoReplicaSelector implements ReplicaSelector { @Override - public ChannelEndpoint select( - List<ChannelEndpoint> candidates, Function<ChannelEndpoint, Double> scoreLookup) { + public ChannelEndpoint select(List<ChannelEndpoint> candidates) { if (candidates == null || candidates.isEmpty()) { return null; } @@ -49,12 +46,11 @@ public ChannelEndpoint select( ChannelEndpoint c1 = candidates.get(index1); ChannelEndpoint c2 = candidates.get(index2); - Double score1 = scoreLookup.apply(c1); - Double score2 = scoreLookup.apply(c2); + LatencyTracker t1 = c1.getLatencyTracker(); + LatencyTracker t2 = c2.getLatencyTracker(); - // Handle null scores by treating them as Double.MAX_VALUE (lowest priority) - double s1 = MoreObjects.firstNonNull(score1, Double.MAX_VALUE); - double s2 = MoreObjects.firstNonNull(score2, Double.MAX_VALUE); + double s1 = t1 != null ? t1.getScore() : Double.MAX_VALUE; + double s2 = t2 != null ? t2.getScore() : Double.MAX_VALUE; return s1 <= s2 ? 
c1 : c2; } diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ReplicaSelector.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ReplicaSelector.java index de4f58e50f1e..9aa7f1e4a62a 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ReplicaSelector.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ReplicaSelector.java @@ -19,7 +19,6 @@ import com.google.api.core.BetaApi; import com.google.api.core.InternalApi; import java.util.List; -import java.util.function.Function; /** Interface for selecting a replica from a list of candidates. */ @InternalApi @@ -30,9 +29,7 @@ public interface ReplicaSelector { * Selects a replica from the given list of candidates. * * @param candidates the list of eligible candidates. - * @param scoreLookup a function to look up the latency score for a candidate. * @return the selected candidate, or null if the list is empty. 
*/ - ChannelEndpoint select( - List<ChannelEndpoint> candidates, Function<ChannelEndpoint, Double> scoreLookup); + ChannelEndpoint select(List<ChannelEndpoint> candidates); } diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java index b19123daa704..87be06d0329d 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java @@ -827,11 +827,16 @@ static final class FakeEndpoint implements ChannelEndpoint { private final String address; private final FakeManagedChannel channel = new FakeManagedChannel(); private EndpointHealthState state = EndpointHealthState.READY; + private LatencyTracker latencyTracker = null; FakeEndpoint(String address) { this.address = address; } + void setLatencyTracker(LatencyTracker tracker) { + this.latencyTracker = tracker; + } + @Override public String getAddress() { return address; @@ -852,6 +857,11 @@ public ManagedChannel getChannel() { return channel; } + @Override + public LatencyTracker getLatencyTracker() { + return latencyTracker; + } + void setState(EndpointHealthState state) { this.state = state; channel.setConnectivityState(toConnectivityState(state)); @@ -926,4 +936,88 @@ public String authority() { return "fake"; } } + + @Test + public void usesReplicaSelectorWhenEnabled() { + FakeEndpointCache endpointCache = new FakeEndpointCache(); + + ReplicaSelector mockSelector = + (candidates) -> { + for (ChannelEndpoint candidate : candidates) { + if (candidate.getAddress().equals("server2")) { + return candidate; + } + } + return null; + }; + + KeyRangeCache cache = new KeyRangeCache(endpointCache, null, mockSelector); + cache.addRanges(twoReplicaUpdate()); + + endpointCache.get("server1"); + endpointCache.get("server2"); + + RoutingHint.Builder hint = 
RoutingHint.newBuilder().setKey(bytes("a")); + ChannelEndpoint server = + cache.fillRoutingHint( + false, + KeyRangeCache.RangeMode.COVERING_SPLIT, + DirectedReadOptions.getDefaultInstance(), + hint); + + assertNotNull(server); + assertEquals("server2", server.getAddress()); + } + + @Test + public void usesLatencyTrackerInSelector() { + FakeEndpointCache endpointCache = new FakeEndpointCache(); + + FakeEndpoint e1 = (FakeEndpoint) endpointCache.get("server1"); + FakeEndpoint e2 = (FakeEndpoint) endpointCache.get("server2"); + + e1.setLatencyTracker( + new LatencyTracker() { + @Override + public double getScore() { + return 100.0; + } + + @Override + public void update(java.time.Duration latency) {} + + @Override + public void recordError(java.time.Duration penalty) {} + }); + + e2.setLatencyTracker( + new LatencyTracker() { + @Override + public double getScore() { + return 10.0; + } + + @Override + public void update(java.time.Duration latency) {} + + @Override + public void recordError(java.time.Duration penalty) {} + }); + + ReplicaSelector selector = new PowerOfTwoReplicaSelector(); + + KeyRangeCache cache = new KeyRangeCache(endpointCache, null, selector); + cache.addRanges(twoReplicaUpdate()); + + RoutingHint.Builder hint = RoutingHint.newBuilder().setKey(bytes("a")); + ChannelEndpoint server = + cache.fillRoutingHint( + false, + KeyRangeCache.RangeMode.COVERING_SPLIT, + DirectedReadOptions.getDefaultInstance(), + hint); + + assertNotNull(server); + assertEquals("server2", server.getAddress()); + } } diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelectorTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelectorTest.java index 424efb363df6..090bab876c19 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelectorTest.java +++ 
b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelectorTest.java @@ -21,9 +21,7 @@ import static org.junit.Assert.assertTrue; import java.util.Arrays; -import java.util.HashMap; import java.util.List; -import java.util.Map; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -33,11 +31,16 @@ public class PowerOfTwoReplicaSelectorTest { private static class TestEndpoint implements ChannelEndpoint { private final String address; + private double score = Double.MAX_VALUE; TestEndpoint(String address) { this.address = address; } + void setScore(double score) { + this.score = score; + } + @Override public String getAddress() { return address; @@ -57,72 +60,118 @@ public boolean isTransientFailure() { public io.grpc.ManagedChannel getChannel() { return null; } + + @Override + public LatencyTracker getLatencyTracker() { + return new LatencyTracker() { + @Override + public double getScore() { + return score; + } + + @Override + public void update(java.time.Duration latency) {} + + @Override + public void recordError(java.time.Duration penalty) {} + }; + } } @Test public void testEmptyList() { PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector(); - assertNull(selector.select(null, endpoint -> 1.0)); - assertNull(selector.select(Arrays.asList(), endpoint -> 1.0)); + assertNull(selector.select(null)); + assertNull(selector.select(Arrays.asList())); } @Test public void testSingleElement() { PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector(); ChannelEndpoint endpoint = new TestEndpoint("a"); - assertEquals(endpoint, selector.select(Arrays.asList(endpoint), e -> 1.0)); + assertEquals(endpoint, selector.select(Arrays.asList(endpoint))); } @Test public void testTwoElementsPicksBetter() { PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector(); - ChannelEndpoint better = new TestEndpoint("better"); - ChannelEndpoint worse = new 
TestEndpoint("worse"); - - Map<ChannelEndpoint, Double> scores = new HashMap<>(); - scores.put(better, 10.0); - scores.put(worse, 20.0); + TestEndpoint better = new TestEndpoint("better"); + better.setScore(10.0); + TestEndpoint worse = new TestEndpoint("worse"); + worse.setScore(20.0); List<ChannelEndpoint> candidates = Arrays.asList(better, worse); for (int i = 0; i < 100; i++) { - assertEquals(better, selector.select(candidates, scores::get)); + assertEquals(better, selector.select(candidates)); } } @Test public void testThreeElementsNeverPicksWorst() { PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector(); - ChannelEndpoint best = new TestEndpoint("best"); - ChannelEndpoint middle = new TestEndpoint("middle"); - ChannelEndpoint worst = new TestEndpoint("worst"); - - Map<ChannelEndpoint, Double> scores = new HashMap<>(); - scores.put(best, 10.0); - scores.put(middle, 20.0); - scores.put(worst, 30.0); + TestEndpoint best = new TestEndpoint("best"); + best.setScore(10.0); + TestEndpoint middle = new TestEndpoint("middle"); + middle.setScore(20.0); + TestEndpoint worst = new TestEndpoint("worst"); + worst.setScore(30.0); List<ChannelEndpoint> candidates = Arrays.asList(best, middle, worst); for (int i = 0; i < 100; i++) { - ChannelEndpoint selected = selector.select(candidates, scores::get); + ChannelEndpoint selected = selector.select(candidates); assertTrue("Should not pick worst", selected != worst); } } @Test - public void testNullScoresTreatedAsMax() { + public void testMissingScoresTreatedAsMax() { PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector(); - ChannelEndpoint withScore = new TestEndpoint("withScore"); - ChannelEndpoint withoutScore = new TestEndpoint("withoutScore"); - - Map<ChannelEndpoint, Double> scores = new HashMap<>(); - scores.put(withScore, 100.0); + TestEndpoint withScore = new TestEndpoint("withScore"); + withScore.setScore(100.0); + TestEndpoint withoutScore = new TestEndpoint("withoutScore"); + // withoutScore has default Double.MAX_VALUE score List<ChannelEndpoint> candidates = Arrays.asList(withScore, withoutScore); for (int i = 
0; i < 100; i++) { - assertEquals(withScore, selector.select(candidates, scores::get)); + assertEquals(withScore, selector.select(candidates)); + } + } + + @Test + public void testNullTrackerTreatedAsMax() { + PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector(); + TestEndpoint withScore = new TestEndpoint("withScore"); + withScore.setScore(100.0); + ChannelEndpoint withoutTracker = + new ChannelEndpoint() { + @Override + public String getAddress() { + return "withoutTracker"; + } + + @Override + public boolean isHealthy() { + return true; + } + + @Override + public boolean isTransientFailure() { + return false; + } + + @Override + public io.grpc.ManagedChannel getChannel() { + return null; + } + }; + + List<ChannelEndpoint> candidates = Arrays.asList(withScore, withoutTracker); + + for (int i = 0; i < 100; i++) { + assertEquals(withScore, selector.select(candidates)); } } }