diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java index fc82c530fc6f..2a66d8b00ed6 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java @@ -71,4 +71,15 @@ public interface ChannelEndpoint { * @return the managed channel for this server */ ManagedChannel getChannel(); + + /** + * Returns the latency tracker for this server. + * + *

Default implementation returns {@code null}. + * + * @return the latency tracker, or {@code null} if latency tracking is not supported + */ + default LatencyTracker getLatencyTracker() { + return null; + } } diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java index 59955ccb4bd2..0324587fc6b5 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java @@ -18,6 +18,7 @@ import com.google.api.core.InternalApi; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Lists; import com.google.common.hash.Hashing; import com.google.protobuf.ByteString; import com.google.spanner.v1.CacheUpdate; @@ -66,6 +67,7 @@ public enum RangeMode { private final ChannelEndpointCache endpointCache; @javax.annotation.Nullable private final EndpointLifecycleManager lifecycleManager; + @javax.annotation.Nullable private final ReplicaSelector replicaSelector; private final NavigableMap ranges = new TreeMap<>(ByteString.unsignedLexicographicalComparator()); private final Map groups = new HashMap<>(); @@ -78,14 +80,22 @@ public enum RangeMode { private volatile int minCacheEntriesForRandomPick = DEFAULT_MIN_ENTRIES_FOR_RANDOM_PICK; public KeyRangeCache(ChannelEndpointCache endpointCache) { - this(endpointCache, null); + this(endpointCache, null, null); } public KeyRangeCache( ChannelEndpointCache endpointCache, @javax.annotation.Nullable EndpointLifecycleManager lifecycleManager) { + this(endpointCache, lifecycleManager, null); + } + + public KeyRangeCache( + ChannelEndpointCache endpointCache, + @javax.annotation.Nullable EndpointLifecycleManager lifecycleManager, + @javax.annotation.Nullable ReplicaSelector replicaSelector) { this.endpointCache = Objects.requireNonNull(endpointCache); this.lifecycleManager = lifecycleManager; + this.replicaSelector = replicaSelector; } @VisibleForTesting @@ -802,6 +812,8 @@ && leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) { return leader(); } } + List eligibleTablets = + replicaSelector != null ? new ArrayList<>(tablets.size()) : null; for (int index = 0; index < tablets.size(); index++) { if (checkedLeader && index == leaderIndex) { continue; @@ -813,7 +825,31 @@ && leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) { if (tablet.shouldSkip(hintBuilder, excludedEndpoints, skippedTabletUids)) { continue; } - return tablet; + if (replicaSelector == null) { + return tablet; + } + eligibleTablets.add(tablet); + } + + if (replicaSelector != null) { + return selectReplicaLocked(eligibleTablets); + } + return null; + } + + private CachedTablet selectReplicaLocked(final List eligibleTablets) { + if (eligibleTablets.isEmpty()) { + return null; + } + List candidates = + Lists.transform(eligibleTablets, tablet -> tablet.endpoint); + ChannelEndpoint selectedEndpoint = replicaSelector.select(candidates); + + // The number of eligible tablets is always small, so a linear search is more efficient. + for (CachedTablet tablet : eligibleTablets) { + if (tablet.endpoint.equals(selectedEndpoint)) { + return tablet; + } } return null; } diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelector.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelector.java index c7cc2012a615..08c8b9ac8a72 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelector.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelector.java @@ -18,11 +18,9 @@ import com.google.api.core.BetaApi; import com.google.api.core.InternalApi; -import com.google.common.base.MoreObjects; import java.util.List; import java.util.Random; import java.util.concurrent.ThreadLocalRandom; -import java.util.function.Function; /** Implementation of {@link ReplicaSelector} using the "Power of 2 Random Choices" strategy. */ @InternalApi @@ -30,8 +28,7 @@ public class PowerOfTwoReplicaSelector implements ReplicaSelector { @Override - public ChannelEndpoint select( - List candidates, Function scoreLookup) { + public ChannelEndpoint select(List candidates) { if (candidates == null || candidates.isEmpty()) { return null; } @@ -49,12 +46,11 @@ public ChannelEndpoint select( ChannelEndpoint c1 = candidates.get(index1); ChannelEndpoint c2 = candidates.get(index2); - Double score1 = scoreLookup.apply(c1); - Double score2 = scoreLookup.apply(c2); + LatencyTracker t1 = c1.getLatencyTracker(); + LatencyTracker t2 = c2.getLatencyTracker(); - // Handle null scores by treating them as Double.MAX_VALUE (lowest priority) - double s1 = MoreObjects.firstNonNull(score1, Double.MAX_VALUE); - double s2 = MoreObjects.firstNonNull(score2, Double.MAX_VALUE); + double s1 = t1 != null ? t1.getScore() : Double.MAX_VALUE; + double s2 = t2 != null ? t2.getScore() : Double.MAX_VALUE; return s1 <= s2 ? c1 : c2; } diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ReplicaSelector.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ReplicaSelector.java index de4f58e50f1e..9aa7f1e4a62a 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ReplicaSelector.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ReplicaSelector.java @@ -19,7 +19,6 @@ import com.google.api.core.BetaApi; import com.google.api.core.InternalApi; import java.util.List; -import java.util.function.Function; /** Interface for selecting a replica from a list of candidates. */ @InternalApi @@ -30,9 +29,7 @@ public interface ReplicaSelector { * Selects a replica from the given list of candidates. * * @param candidates the list of eligible candidates. - * @param scoreLookup a function to look up the latency score for a candidate. * @return the selected candidate, or null if the list is empty. */ - ChannelEndpoint select( - List candidates, Function scoreLookup); + ChannelEndpoint select(List candidates); } diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java index b19123daa704..87be06d0329d 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java @@ -827,11 +827,16 @@ static final class FakeEndpoint implements ChannelEndpoint { private final String address; private final FakeManagedChannel channel = new FakeManagedChannel(); private EndpointHealthState state = EndpointHealthState.READY; + private LatencyTracker latencyTracker = null; FakeEndpoint(String address) { this.address = address; } + void setLatencyTracker(LatencyTracker tracker) { + this.latencyTracker = tracker; + } + @Override public String getAddress() { return address; @@ -852,6 +857,11 @@ public ManagedChannel getChannel() { return channel; } + @Override + public LatencyTracker getLatencyTracker() { + return latencyTracker; + } + void setState(EndpointHealthState state) { this.state = state; channel.setConnectivityState(toConnectivityState(state)); @@ -926,4 +936,88 @@ public String authority() { return "fake"; } } + + @Test + public void usesReplicaSelectorWhenEnabled() { + FakeEndpointCache endpointCache = new FakeEndpointCache(); + + ReplicaSelector mockSelector = + (candidates) -> { + for (ChannelEndpoint candidate : candidates) { + if (candidate.getAddress().equals("server2")) { + return candidate; + } + } + return null; + }; + + KeyRangeCache cache = new KeyRangeCache(endpointCache, null, mockSelector); + cache.addRanges(twoReplicaUpdate()); + + endpointCache.get("server1"); + endpointCache.get("server2"); + + RoutingHint.Builder hint = RoutingHint.newBuilder().setKey(bytes("a")); + ChannelEndpoint server = + cache.fillRoutingHint( + false, + KeyRangeCache.RangeMode.COVERING_SPLIT, + DirectedReadOptions.getDefaultInstance(), + hint); + + assertNotNull(server); + assertEquals("server2", server.getAddress()); + } + + @Test + public void usesLatencyTrackerInSelector() { + FakeEndpointCache endpointCache = new FakeEndpointCache(); + + FakeEndpoint e1 = (FakeEndpoint) endpointCache.get("server1"); + FakeEndpoint e2 = (FakeEndpoint) endpointCache.get("server2"); + + e1.setLatencyTracker( + new LatencyTracker() { + @Override + public double getScore() { + return 100.0; + } + + @Override + public void update(java.time.Duration latency) {} + + @Override + public void recordError(java.time.Duration penalty) {} + }); + + e2.setLatencyTracker( + new LatencyTracker() { + @Override + public double getScore() { + return 10.0; + } + + @Override + public void update(java.time.Duration latency) {} + + @Override + public void recordError(java.time.Duration penalty) {} + }); + + ReplicaSelector selector = new PowerOfTwoReplicaSelector(); + + KeyRangeCache cache = new KeyRangeCache(endpointCache, null, selector); + cache.addRanges(twoReplicaUpdate()); + + RoutingHint.Builder hint = RoutingHint.newBuilder().setKey(bytes("a")); + ChannelEndpoint server = + cache.fillRoutingHint( + false, + KeyRangeCache.RangeMode.COVERING_SPLIT, + DirectedReadOptions.getDefaultInstance(), + hint); + + assertNotNull(server); + assertEquals("server2", server.getAddress()); + } } diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelectorTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelectorTest.java index 424efb363df6..090bab876c19 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelectorTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelectorTest.java @@ -21,9 +21,7 @@ import static org.junit.Assert.assertTrue; import java.util.Arrays; -import java.util.HashMap; import java.util.List; -import java.util.Map; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -33,11 +31,16 @@ public class PowerOfTwoReplicaSelectorTest { private static class TestEndpoint implements ChannelEndpoint { private final String address; + private double score = Double.MAX_VALUE; TestEndpoint(String address) { this.address = address; } + void setScore(double score) { + this.score = score; + } + @Override public String getAddress() { return address; @@ -57,72 +60,118 @@ public boolean isTransientFailure() { public io.grpc.ManagedChannel getChannel() { return null; } + + @Override + public LatencyTracker getLatencyTracker() { + return new LatencyTracker() { + @Override + public double getScore() { + return score; + } + + @Override + public void update(java.time.Duration latency) {} + + @Override + public void recordError(java.time.Duration penalty) {} + }; + } } @Test public void testEmptyList() { PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector(); - assertNull(selector.select(null, endpoint -> 1.0)); - assertNull(selector.select(Arrays.asList(), endpoint -> 1.0)); + assertNull(selector.select(null)); + assertNull(selector.select(Arrays.asList())); } @Test public void testSingleElement() { PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector(); ChannelEndpoint endpoint = new TestEndpoint("a"); - assertEquals(endpoint, selector.select(Arrays.asList(endpoint), e -> 1.0)); + assertEquals(endpoint, selector.select(Arrays.asList(endpoint))); } @Test public void testTwoElementsPicksBetter() { PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector(); - ChannelEndpoint better = new TestEndpoint("better"); - ChannelEndpoint worse = new TestEndpoint("worse"); - - Map scores = new HashMap<>(); - scores.put(better, 10.0); - scores.put(worse, 20.0); + TestEndpoint better = new TestEndpoint("better"); + better.setScore(10.0); + TestEndpoint worse = new TestEndpoint("worse"); + worse.setScore(20.0); List candidates = Arrays.asList(better, worse); for (int i = 0; i < 100; i++) { - assertEquals(better, selector.select(candidates, scores::get)); + assertEquals(better, selector.select(candidates)); } } @Test public void testThreeElementsNeverPicksWorst() { PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector(); - ChannelEndpoint best = new TestEndpoint("best"); - ChannelEndpoint middle = new TestEndpoint("middle"); - ChannelEndpoint worst = new TestEndpoint("worst"); - - Map scores = new HashMap<>(); - scores.put(best, 10.0); - scores.put(middle, 20.0); - scores.put(worst, 30.0); + TestEndpoint best = new TestEndpoint("best"); + best.setScore(10.0); + TestEndpoint middle = new TestEndpoint("middle"); + middle.setScore(20.0); + TestEndpoint worst = new TestEndpoint("worst"); + worst.setScore(30.0); List candidates = Arrays.asList(best, middle, worst); for (int i = 0; i < 100; i++) { - ChannelEndpoint selected = selector.select(candidates, scores::get); + ChannelEndpoint selected = selector.select(candidates); assertTrue("Should not pick worst", selected != worst); } } @Test - public void testNullScoresTreatedAsMax() { + public void testMissingScoresTreatedAsMax() { PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector(); - ChannelEndpoint withScore = new TestEndpoint("withScore"); - ChannelEndpoint withoutScore = new TestEndpoint("withoutScore"); - - Map scores = new HashMap<>(); - scores.put(withScore, 100.0); + TestEndpoint withScore = new TestEndpoint("withScore"); + withScore.setScore(100.0); + TestEndpoint withoutScore = new TestEndpoint("withoutScore"); + // withoutScore has default Double.MAX_VALUE score List candidates = Arrays.asList(withScore, withoutScore); for (int i = 0; i < 100; i++) { - assertEquals(withScore, selector.select(candidates, scores::get)); + assertEquals(withScore, selector.select(candidates)); + } + } + + @Test + public void testNullTrackerTreatedAsMax() { + PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector(); + TestEndpoint withScore = new TestEndpoint("withScore"); + withScore.setScore(100.0); + ChannelEndpoint withoutTracker = + new ChannelEndpoint() { + @Override + public String getAddress() { + return "withoutTracker"; + } + + @Override + public boolean isHealthy() { + return true; + } + + @Override + public boolean isTransientFailure() { + return false; + } + + @Override + public io.grpc.ManagedChannel getChannel() { + return null; + } + }; + + List candidates = Arrays.asList(withScore, withoutTracker); + + for (int i = 0; i < 100; i++) { + assertEquals(withScore, selector.select(candidates)); } } }