Skip to content

Commit 1396402

Browse files
committed
chore(spanner): use ReplicaSelector in KeyRangeCache
Integrates the ReplicaSelector in KeyRangeCache, so it can be used to select the best replica to send the request to. This feature is effectively disabled for production, as the ReplicaSelector is set to null by default.
1 parent 9b455b9 commit 1396402

File tree

6 files changed

+225
-43
lines changed

6 files changed

+225
-43
lines changed

java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,15 @@ public interface ChannelEndpoint {
7171
* @return the managed channel for this server
7272
*/
7373
ManagedChannel getChannel();
74+
75+
/**
76+
* Returns the latency tracker for this server.
77+
*
78+
* <p>Default implementation returns {@code null}.
79+
*
80+
* @return the latency tracker, or {@code null} if latency tracking is not supported
81+
*/
82+
default LatencyTracker getLatencyTracker() {
83+
return null;
84+
}
7485
}

java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import com.google.api.core.InternalApi;
2020
import com.google.common.annotations.VisibleForTesting;
21+
import com.google.common.collect.Lists;
2122
import com.google.common.hash.Hashing;
2223
import com.google.protobuf.ByteString;
2324
import com.google.spanner.v1.CacheUpdate;
@@ -66,6 +67,7 @@ public enum RangeMode {
6667

6768
private final ChannelEndpointCache endpointCache;
6869
@javax.annotation.Nullable private final EndpointLifecycleManager lifecycleManager;
70+
@javax.annotation.Nullable private final ReplicaSelector replicaSelector;
6971
private final NavigableMap<ByteString, CachedRange> ranges =
7072
new TreeMap<>(ByteString.unsignedLexicographicalComparator());
7173
private final Map<Long, CachedGroup> groups = new HashMap<>();
@@ -78,14 +80,22 @@ public enum RangeMode {
7880
private volatile int minCacheEntriesForRandomPick = DEFAULT_MIN_ENTRIES_FOR_RANDOM_PICK;
7981

8082
public KeyRangeCache(ChannelEndpointCache endpointCache) {
81-
this(endpointCache, null);
83+
this(endpointCache, null, null);
8284
}
8385

8486
public KeyRangeCache(
8587
ChannelEndpointCache endpointCache,
8688
@javax.annotation.Nullable EndpointLifecycleManager lifecycleManager) {
89+
this(endpointCache, lifecycleManager, null);
90+
}
91+
92+
public KeyRangeCache(
93+
ChannelEndpointCache endpointCache,
94+
@javax.annotation.Nullable EndpointLifecycleManager lifecycleManager,
95+
@javax.annotation.Nullable ReplicaSelector replicaSelector) {
8796
this.endpointCache = Objects.requireNonNull(endpointCache);
8897
this.lifecycleManager = lifecycleManager;
98+
this.replicaSelector = replicaSelector;
8999
}
90100

91101
@VisibleForTesting
@@ -802,6 +812,7 @@ && leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) {
802812
return leader();
803813
}
804814
}
815+
List<CachedTablet> eligibleTablets = replicaSelector != null ? new ArrayList<>() : null;
805816
for (int index = 0; index < tablets.size(); index++) {
806817
if (checkedLeader && index == leaderIndex) {
807818
continue;
@@ -813,7 +824,31 @@ && leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) {
813824
if (tablet.shouldSkip(hintBuilder, excludedEndpoints, skippedTabletUids)) {
814825
continue;
815826
}
816-
return tablet;
827+
if (replicaSelector == null) {
828+
return tablet;
829+
}
830+
eligibleTablets.add(tablet);
831+
}
832+
833+
if (replicaSelector != null) {
834+
return selectReplicaLocked(eligibleTablets);
835+
}
836+
return null;
837+
}
838+
839+
private CachedTablet selectReplicaLocked(final List<CachedTablet> eligibleTablets) {
840+
if (eligibleTablets.isEmpty()) {
841+
return null;
842+
}
843+
List<ChannelEndpoint> candidates =
844+
Lists.transform(eligibleTablets, tablet -> tablet.endpoint);
845+
ChannelEndpoint selectedEndpoint = replicaSelector.select(candidates);
846+
847+
// The number of eligible tablets is always small, so a linear search is more efficient.
848+
for (CachedTablet tablet : eligibleTablets) {
849+
if (tablet.endpoint.equals(selectedEndpoint)) {
850+
return tablet;
851+
}
817852
}
818853
return null;
819854
}

java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelector.java

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,17 @@
1818

1919
import com.google.api.core.BetaApi;
2020
import com.google.api.core.InternalApi;
21-
import com.google.common.base.MoreObjects;
2221
import java.util.List;
2322
import java.util.Random;
2423
import java.util.concurrent.ThreadLocalRandom;
25-
import java.util.function.Function;
2624

2725
/** Implementation of {@link ReplicaSelector} using the "Power of 2 Random Choices" strategy. */
2826
@InternalApi
2927
@BetaApi
3028
public class PowerOfTwoReplicaSelector implements ReplicaSelector {
3129

3230
@Override
33-
public ChannelEndpoint select(
34-
List<ChannelEndpoint> candidates, Function<ChannelEndpoint, Double> scoreLookup) {
31+
public ChannelEndpoint select(List<ChannelEndpoint> candidates) {
3532
if (candidates == null || candidates.isEmpty()) {
3633
return null;
3734
}
@@ -49,12 +46,11 @@ public ChannelEndpoint select(
4946
ChannelEndpoint c1 = candidates.get(index1);
5047
ChannelEndpoint c2 = candidates.get(index2);
5148

52-
Double score1 = scoreLookup.apply(c1);
53-
Double score2 = scoreLookup.apply(c2);
49+
LatencyTracker t1 = c1.getLatencyTracker();
50+
LatencyTracker t2 = c2.getLatencyTracker();
5451

55-
// Handle null scores by treating them as Double.MAX_VALUE (lowest priority)
56-
double s1 = MoreObjects.firstNonNull(score1, Double.MAX_VALUE);
57-
double s2 = MoreObjects.firstNonNull(score2, Double.MAX_VALUE);
52+
double s1 = t1 != null ? t1.getScore() : Double.MAX_VALUE;
53+
double s2 = t2 != null ? t2.getScore() : Double.MAX_VALUE;
5854

5955
return s1 <= s2 ? c1 : c2;
6056
}

java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ReplicaSelector.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import com.google.api.core.BetaApi;
2020
import com.google.api.core.InternalApi;
2121
import java.util.List;
22-
import java.util.function.Function;
2322

2423
/** Interface for selecting a replica from a list of candidates. */
2524
@InternalApi
@@ -30,9 +29,7 @@ public interface ReplicaSelector {
3029
* Selects a replica from the given list of candidates.
3130
*
3231
* @param candidates the list of eligible candidates.
33-
* @param scoreLookup a function to look up the latency score for a candidate.
3432
* @return the selected candidate, or null if the list is empty.
3533
*/
36-
ChannelEndpoint select(
37-
List<ChannelEndpoint> candidates, Function<ChannelEndpoint, Double> scoreLookup);
34+
ChannelEndpoint select(List<ChannelEndpoint> candidates);
3835
}

java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,11 +827,16 @@ static final class FakeEndpoint implements ChannelEndpoint {
827827
private final String address;
828828
private final FakeManagedChannel channel = new FakeManagedChannel();
829829
private EndpointHealthState state = EndpointHealthState.READY;
830+
private LatencyTracker latencyTracker = null;
830831

831832
FakeEndpoint(String address) {
832833
this.address = address;
833834
}
834835

836+
void setLatencyTracker(LatencyTracker tracker) {
837+
this.latencyTracker = tracker;
838+
}
839+
835840
@Override
836841
public String getAddress() {
837842
return address;
@@ -852,6 +857,11 @@ public ManagedChannel getChannel() {
852857
return channel;
853858
}
854859

860+
@Override
861+
public LatencyTracker getLatencyTracker() {
862+
return latencyTracker;
863+
}
864+
855865
void setState(EndpointHealthState state) {
856866
this.state = state;
857867
channel.setConnectivityState(toConnectivityState(state));
@@ -926,4 +936,88 @@ public String authority() {
926936
return "fake";
927937
}
928938
}
939+
940+
@Test
941+
public void usesReplicaSelectorWhenEnabled() {
942+
FakeEndpointCache endpointCache = new FakeEndpointCache();
943+
944+
ReplicaSelector mockSelector =
945+
(candidates) -> {
946+
for (ChannelEndpoint candidate : candidates) {
947+
if (candidate.getAddress().equals("server2")) {
948+
return candidate;
949+
}
950+
}
951+
return null;
952+
};
953+
954+
KeyRangeCache cache = new KeyRangeCache(endpointCache, null, mockSelector);
955+
cache.addRanges(twoReplicaUpdate());
956+
957+
endpointCache.get("server1");
958+
endpointCache.get("server2");
959+
960+
RoutingHint.Builder hint = RoutingHint.newBuilder().setKey(bytes("a"));
961+
ChannelEndpoint server =
962+
cache.fillRoutingHint(
963+
false,
964+
KeyRangeCache.RangeMode.COVERING_SPLIT,
965+
DirectedReadOptions.getDefaultInstance(),
966+
hint);
967+
968+
assertNotNull(server);
969+
assertEquals("server2", server.getAddress());
970+
}
971+
972+
@Test
973+
public void usesLatencyTrackerInSelector() {
974+
FakeEndpointCache endpointCache = new FakeEndpointCache();
975+
976+
FakeEndpoint e1 = (FakeEndpoint) endpointCache.get("server1");
977+
FakeEndpoint e2 = (FakeEndpoint) endpointCache.get("server2");
978+
979+
e1.setLatencyTracker(
980+
new LatencyTracker() {
981+
@Override
982+
public double getScore() {
983+
return 100.0;
984+
}
985+
986+
@Override
987+
public void update(java.time.Duration latency) {}
988+
989+
@Override
990+
public void recordError(java.time.Duration penalty) {}
991+
});
992+
993+
e2.setLatencyTracker(
994+
new LatencyTracker() {
995+
@Override
996+
public double getScore() {
997+
return 10.0;
998+
}
999+
1000+
@Override
1001+
public void update(java.time.Duration latency) {}
1002+
1003+
@Override
1004+
public void recordError(java.time.Duration penalty) {}
1005+
});
1006+
1007+
ReplicaSelector selector = new PowerOfTwoReplicaSelector();
1008+
1009+
KeyRangeCache cache = new KeyRangeCache(endpointCache, null, selector);
1010+
cache.addRanges(twoReplicaUpdate());
1011+
1012+
RoutingHint.Builder hint = RoutingHint.newBuilder().setKey(bytes("a"));
1013+
ChannelEndpoint server =
1014+
cache.fillRoutingHint(
1015+
false,
1016+
KeyRangeCache.RangeMode.COVERING_SPLIT,
1017+
DirectedReadOptions.getDefaultInstance(),
1018+
hint);
1019+
1020+
assertNotNull(server);
1021+
assertEquals("server2", server.getAddress());
1022+
}
9291023
}

0 commit comments

Comments
 (0)