diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java
index fc82c530fc6f..2a66d8b00ed6 100644
--- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java
+++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java
@@ -71,4 +71,15 @@ public interface ChannelEndpoint {
* @return the managed channel for this server
*/
ManagedChannel getChannel();
+
+ /**
+ * Returns the latency tracker for this server.
+ *
+ *
Default implementation returns {@code null}.
+ *
+ * @return the latency tracker, or {@code null} if latency tracking is not supported
+ */
+ default LatencyTracker getLatencyTracker() {
+ return null;
+ }
}
diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java
index 59955ccb4bd2..0324587fc6b5 100644
--- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java
+++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java
@@ -18,6 +18,7 @@
import com.google.api.core.InternalApi;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Lists;
import com.google.common.hash.Hashing;
import com.google.protobuf.ByteString;
import com.google.spanner.v1.CacheUpdate;
@@ -66,6 +67,7 @@ public enum RangeMode {
private final ChannelEndpointCache endpointCache;
@javax.annotation.Nullable private final EndpointLifecycleManager lifecycleManager;
+ @javax.annotation.Nullable private final ReplicaSelector replicaSelector;
private final NavigableMap ranges =
new TreeMap<>(ByteString.unsignedLexicographicalComparator());
private final Map groups = new HashMap<>();
@@ -78,14 +80,22 @@ public enum RangeMode {
private volatile int minCacheEntriesForRandomPick = DEFAULT_MIN_ENTRIES_FOR_RANDOM_PICK;
public KeyRangeCache(ChannelEndpointCache endpointCache) {
- this(endpointCache, null);
+ this(endpointCache, null, null);
}
public KeyRangeCache(
ChannelEndpointCache endpointCache,
@javax.annotation.Nullable EndpointLifecycleManager lifecycleManager) {
+ this(endpointCache, lifecycleManager, null);
+ }
+
+ public KeyRangeCache(
+ ChannelEndpointCache endpointCache,
+ @javax.annotation.Nullable EndpointLifecycleManager lifecycleManager,
+ @javax.annotation.Nullable ReplicaSelector replicaSelector) {
this.endpointCache = Objects.requireNonNull(endpointCache);
this.lifecycleManager = lifecycleManager;
+ this.replicaSelector = replicaSelector;
}
@VisibleForTesting
@@ -802,6 +812,8 @@ && leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) {
return leader();
}
}
+ List eligibleTablets =
+ replicaSelector != null ? new ArrayList<>(tablets.size()) : null;
for (int index = 0; index < tablets.size(); index++) {
if (checkedLeader && index == leaderIndex) {
continue;
@@ -813,7 +825,31 @@ && leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) {
if (tablet.shouldSkip(hintBuilder, excludedEndpoints, skippedTabletUids)) {
continue;
}
- return tablet;
+ if (replicaSelector == null) {
+ return tablet;
+ }
+ eligibleTablets.add(tablet);
+ }
+
+ if (replicaSelector != null) {
+ return selectReplicaLocked(eligibleTablets);
+ }
+ return null;
+ }
+
+ private CachedTablet selectReplicaLocked(final List eligibleTablets) {
+ if (eligibleTablets.isEmpty()) {
+ return null;
+ }
+ List candidates =
+ Lists.transform(eligibleTablets, tablet -> tablet.endpoint);
+ ChannelEndpoint selectedEndpoint = replicaSelector.select(candidates);
+
+ // The number of eligible tablets is always small, so a linear search is more efficient.
+ for (CachedTablet tablet : eligibleTablets) {
+ if (tablet.endpoint.equals(selectedEndpoint)) {
+ return tablet;
+ }
}
return null;
}
diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelector.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelector.java
index c7cc2012a615..08c8b9ac8a72 100644
--- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelector.java
+++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelector.java
@@ -18,11 +18,9 @@
import com.google.api.core.BetaApi;
import com.google.api.core.InternalApi;
-import com.google.common.base.MoreObjects;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
-import java.util.function.Function;
/** Implementation of {@link ReplicaSelector} using the "Power of 2 Random Choices" strategy. */
@InternalApi
@@ -30,8 +28,7 @@
public class PowerOfTwoReplicaSelector implements ReplicaSelector {
@Override
- public ChannelEndpoint select(
- List candidates, Function scoreLookup) {
+ public ChannelEndpoint select(List candidates) {
if (candidates == null || candidates.isEmpty()) {
return null;
}
@@ -49,12 +46,11 @@ public ChannelEndpoint select(
ChannelEndpoint c1 = candidates.get(index1);
ChannelEndpoint c2 = candidates.get(index2);
- Double score1 = scoreLookup.apply(c1);
- Double score2 = scoreLookup.apply(c2);
+ LatencyTracker t1 = c1.getLatencyTracker();
+ LatencyTracker t2 = c2.getLatencyTracker();
- // Handle null scores by treating them as Double.MAX_VALUE (lowest priority)
- double s1 = MoreObjects.firstNonNull(score1, Double.MAX_VALUE);
- double s2 = MoreObjects.firstNonNull(score2, Double.MAX_VALUE);
+ double s1 = t1 != null ? t1.getScore() : Double.MAX_VALUE;
+ double s2 = t2 != null ? t2.getScore() : Double.MAX_VALUE;
return s1 <= s2 ? c1 : c2;
}
diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ReplicaSelector.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ReplicaSelector.java
index de4f58e50f1e..9aa7f1e4a62a 100644
--- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ReplicaSelector.java
+++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ReplicaSelector.java
@@ -19,7 +19,6 @@
import com.google.api.core.BetaApi;
import com.google.api.core.InternalApi;
import java.util.List;
-import java.util.function.Function;
/** Interface for selecting a replica from a list of candidates. */
@InternalApi
@@ -30,9 +29,7 @@ public interface ReplicaSelector {
* Selects a replica from the given list of candidates.
*
* @param candidates the list of eligible candidates.
- * @param scoreLookup a function to look up the latency score for a candidate.
* @return the selected candidate, or null if the list is empty.
*/
- ChannelEndpoint select(
- List candidates, Function scoreLookup);
+ ChannelEndpoint select(List candidates);
}
diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java
index b19123daa704..87be06d0329d 100644
--- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java
+++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java
@@ -827,11 +827,16 @@ static final class FakeEndpoint implements ChannelEndpoint {
private final String address;
private final FakeManagedChannel channel = new FakeManagedChannel();
private EndpointHealthState state = EndpointHealthState.READY;
+ private LatencyTracker latencyTracker = null;
FakeEndpoint(String address) {
this.address = address;
}
+ void setLatencyTracker(LatencyTracker tracker) {
+ this.latencyTracker = tracker;
+ }
+
@Override
public String getAddress() {
return address;
@@ -852,6 +857,11 @@ public ManagedChannel getChannel() {
return channel;
}
+ @Override
+ public LatencyTracker getLatencyTracker() {
+ return latencyTracker;
+ }
+
void setState(EndpointHealthState state) {
this.state = state;
channel.setConnectivityState(toConnectivityState(state));
@@ -926,4 +936,88 @@ public String authority() {
return "fake";
}
}
+
+ @Test
+ public void usesReplicaSelectorWhenEnabled() {
+ FakeEndpointCache endpointCache = new FakeEndpointCache();
+
+ ReplicaSelector mockSelector =
+ (candidates) -> {
+ for (ChannelEndpoint candidate : candidates) {
+ if (candidate.getAddress().equals("server2")) {
+ return candidate;
+ }
+ }
+ return null;
+ };
+
+ KeyRangeCache cache = new KeyRangeCache(endpointCache, null, mockSelector);
+ cache.addRanges(twoReplicaUpdate());
+
+ endpointCache.get("server1");
+ endpointCache.get("server2");
+
+ RoutingHint.Builder hint = RoutingHint.newBuilder().setKey(bytes("a"));
+ ChannelEndpoint server =
+ cache.fillRoutingHint(
+ false,
+ KeyRangeCache.RangeMode.COVERING_SPLIT,
+ DirectedReadOptions.getDefaultInstance(),
+ hint);
+
+ assertNotNull(server);
+ assertEquals("server2", server.getAddress());
+ }
+
+ @Test
+ public void usesLatencyTrackerInSelector() {
+ FakeEndpointCache endpointCache = new FakeEndpointCache();
+
+ FakeEndpoint e1 = (FakeEndpoint) endpointCache.get("server1");
+ FakeEndpoint e2 = (FakeEndpoint) endpointCache.get("server2");
+
+ e1.setLatencyTracker(
+ new LatencyTracker() {
+ @Override
+ public double getScore() {
+ return 100.0;
+ }
+
+ @Override
+ public void update(java.time.Duration latency) {}
+
+ @Override
+ public void recordError(java.time.Duration penalty) {}
+ });
+
+ e2.setLatencyTracker(
+ new LatencyTracker() {
+ @Override
+ public double getScore() {
+ return 10.0;
+ }
+
+ @Override
+ public void update(java.time.Duration latency) {}
+
+ @Override
+ public void recordError(java.time.Duration penalty) {}
+ });
+
+ ReplicaSelector selector = new PowerOfTwoReplicaSelector();
+
+ KeyRangeCache cache = new KeyRangeCache(endpointCache, null, selector);
+ cache.addRanges(twoReplicaUpdate());
+
+ RoutingHint.Builder hint = RoutingHint.newBuilder().setKey(bytes("a"));
+ ChannelEndpoint server =
+ cache.fillRoutingHint(
+ false,
+ KeyRangeCache.RangeMode.COVERING_SPLIT,
+ DirectedReadOptions.getDefaultInstance(),
+ hint);
+
+ assertNotNull(server);
+ assertEquals("server2", server.getAddress());
+ }
}
diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelectorTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelectorTest.java
index 424efb363df6..090bab876c19 100644
--- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelectorTest.java
+++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelectorTest.java
@@ -21,9 +21,7 @@
import static org.junit.Assert.assertTrue;
import java.util.Arrays;
-import java.util.HashMap;
import java.util.List;
-import java.util.Map;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -33,11 +31,16 @@ public class PowerOfTwoReplicaSelectorTest {
private static class TestEndpoint implements ChannelEndpoint {
private final String address;
+ private double score = Double.MAX_VALUE;
TestEndpoint(String address) {
this.address = address;
}
+ void setScore(double score) {
+ this.score = score;
+ }
+
@Override
public String getAddress() {
return address;
@@ -57,72 +60,118 @@ public boolean isTransientFailure() {
public io.grpc.ManagedChannel getChannel() {
return null;
}
+
+ @Override
+ public LatencyTracker getLatencyTracker() {
+ return new LatencyTracker() {
+ @Override
+ public double getScore() {
+ return score;
+ }
+
+ @Override
+ public void update(java.time.Duration latency) {}
+
+ @Override
+ public void recordError(java.time.Duration penalty) {}
+ };
+ }
}
@Test
public void testEmptyList() {
PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector();
- assertNull(selector.select(null, endpoint -> 1.0));
- assertNull(selector.select(Arrays.asList(), endpoint -> 1.0));
+ assertNull(selector.select(null));
+ assertNull(selector.select(Arrays.asList()));
}
@Test
public void testSingleElement() {
PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector();
ChannelEndpoint endpoint = new TestEndpoint("a");
- assertEquals(endpoint, selector.select(Arrays.asList(endpoint), e -> 1.0));
+ assertEquals(endpoint, selector.select(Arrays.asList(endpoint)));
}
@Test
public void testTwoElementsPicksBetter() {
PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector();
- ChannelEndpoint better = new TestEndpoint("better");
- ChannelEndpoint worse = new TestEndpoint("worse");
-
- Map scores = new HashMap<>();
- scores.put(better, 10.0);
- scores.put(worse, 20.0);
+ TestEndpoint better = new TestEndpoint("better");
+ better.setScore(10.0);
+ TestEndpoint worse = new TestEndpoint("worse");
+ worse.setScore(20.0);
List candidates = Arrays.asList(better, worse);
for (int i = 0; i < 100; i++) {
- assertEquals(better, selector.select(candidates, scores::get));
+ assertEquals(better, selector.select(candidates));
}
}
@Test
public void testThreeElementsNeverPicksWorst() {
PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector();
- ChannelEndpoint best = new TestEndpoint("best");
- ChannelEndpoint middle = new TestEndpoint("middle");
- ChannelEndpoint worst = new TestEndpoint("worst");
-
- Map scores = new HashMap<>();
- scores.put(best, 10.0);
- scores.put(middle, 20.0);
- scores.put(worst, 30.0);
+ TestEndpoint best = new TestEndpoint("best");
+ best.setScore(10.0);
+ TestEndpoint middle = new TestEndpoint("middle");
+ middle.setScore(20.0);
+ TestEndpoint worst = new TestEndpoint("worst");
+ worst.setScore(30.0);
List candidates = Arrays.asList(best, middle, worst);
for (int i = 0; i < 100; i++) {
- ChannelEndpoint selected = selector.select(candidates, scores::get);
+ ChannelEndpoint selected = selector.select(candidates);
assertTrue("Should not pick worst", selected != worst);
}
}
@Test
- public void testNullScoresTreatedAsMax() {
+ public void testMissingScoresTreatedAsMax() {
PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector();
- ChannelEndpoint withScore = new TestEndpoint("withScore");
- ChannelEndpoint withoutScore = new TestEndpoint("withoutScore");
-
- Map scores = new HashMap<>();
- scores.put(withScore, 100.0);
+ TestEndpoint withScore = new TestEndpoint("withScore");
+ withScore.setScore(100.0);
+ TestEndpoint withoutScore = new TestEndpoint("withoutScore");
+ // withoutScore has default Double.MAX_VALUE score
List candidates = Arrays.asList(withScore, withoutScore);
for (int i = 0; i < 100; i++) {
- assertEquals(withScore, selector.select(candidates, scores::get));
+ assertEquals(withScore, selector.select(candidates));
+ }
+ }
+
+ @Test
+ public void testNullTrackerTreatedAsMax() {
+ PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector();
+ TestEndpoint withScore = new TestEndpoint("withScore");
+ withScore.setScore(100.0);
+ ChannelEndpoint withoutTracker =
+ new ChannelEndpoint() {
+ @Override
+ public String getAddress() {
+ return "withoutTracker";
+ }
+
+ @Override
+ public boolean isHealthy() {
+ return true;
+ }
+
+ @Override
+ public boolean isTransientFailure() {
+ return false;
+ }
+
+ @Override
+ public io.grpc.ManagedChannel getChannel() {
+ return null;
+ }
+ };
+
+ List candidates = Arrays.asList(withScore, withoutTracker);
+
+ for (int i = 0; i < 100; i++) {
+ assertEquals(withScore, selector.select(candidates));
}
}
}