Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,15 @@ public interface ChannelEndpoint {
* @return the managed channel for this server
*/
ManagedChannel getChannel();

/**
* Returns the latency tracker for this server.
*
* <p>Default implementation returns {@code null}.
*
* @return the latency tracker, or {@code null} if latency tracking is not supported
*/
default LatencyTracker getLatencyTracker() {
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.google.api.core.InternalApi;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import com.google.common.hash.Hashing;
import com.google.protobuf.ByteString;
import com.google.spanner.v1.CacheUpdate;
Expand Down Expand Up @@ -66,6 +67,7 @@ public enum RangeMode {

private final ChannelEndpointCache endpointCache;
@javax.annotation.Nullable private final EndpointLifecycleManager lifecycleManager;
@javax.annotation.Nullable private final ReplicaSelector replicaSelector;
private final NavigableMap<ByteString, CachedRange> ranges =
new TreeMap<>(ByteString.unsignedLexicographicalComparator());
private final Map<Long, CachedGroup> groups = new HashMap<>();
Expand All @@ -78,14 +80,22 @@ public enum RangeMode {
private volatile int minCacheEntriesForRandomPick = DEFAULT_MIN_ENTRIES_FOR_RANDOM_PICK;

public KeyRangeCache(ChannelEndpointCache endpointCache) {
this(endpointCache, null);
this(endpointCache, null, null);
}

public KeyRangeCache(
ChannelEndpointCache endpointCache,
@javax.annotation.Nullable EndpointLifecycleManager lifecycleManager) {
this(endpointCache, lifecycleManager, null);
}

public KeyRangeCache(
ChannelEndpointCache endpointCache,
@javax.annotation.Nullable EndpointLifecycleManager lifecycleManager,
@javax.annotation.Nullable ReplicaSelector replicaSelector) {
this.endpointCache = Objects.requireNonNull(endpointCache);
this.lifecycleManager = lifecycleManager;
this.replicaSelector = replicaSelector;
}

@VisibleForTesting
Expand Down Expand Up @@ -802,6 +812,8 @@ && leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) {
return leader();
}
}
List<CachedTablet> eligibleTablets =
replicaSelector != null ? new ArrayList<>(tablets.size()) : null;
for (int index = 0; index < tablets.size(); index++) {
if (checkedLeader && index == leaderIndex) {
continue;
Expand All @@ -813,7 +825,31 @@ && leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) {
if (tablet.shouldSkip(hintBuilder, excludedEndpoints, skippedTabletUids)) {
continue;
}
return tablet;
if (replicaSelector == null) {
return tablet;
}
eligibleTablets.add(tablet);
}

if (replicaSelector != null) {
return selectReplicaLocked(eligibleTablets);
}
return null;
}

private CachedTablet selectReplicaLocked(final List<CachedTablet> eligibleTablets) {
if (eligibleTablets.isEmpty()) {
return null;
}
List<ChannelEndpoint> candidates =
Lists.transform(eligibleTablets, tablet -> tablet.endpoint);
ChannelEndpoint selectedEndpoint = replicaSelector.select(candidates);

// The number of eligible tablets is always small, so a linear search is more efficient.
for (CachedTablet tablet : eligibleTablets) {
if (tablet.endpoint.equals(selectedEndpoint)) {
return tablet;
}
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,17 @@

import com.google.api.core.BetaApi;
import com.google.api.core.InternalApi;
import com.google.common.base.MoreObjects;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Function;

/** Implementation of {@link ReplicaSelector} using the "Power of 2 Random Choices" strategy. */
@InternalApi
@BetaApi
public class PowerOfTwoReplicaSelector implements ReplicaSelector {

@Override
public ChannelEndpoint select(
List<ChannelEndpoint> candidates, Function<ChannelEndpoint, Double> scoreLookup) {
public ChannelEndpoint select(List<ChannelEndpoint> candidates) {
if (candidates == null || candidates.isEmpty()) {
return null;
}
Expand All @@ -49,12 +46,11 @@ public ChannelEndpoint select(
ChannelEndpoint c1 = candidates.get(index1);
ChannelEndpoint c2 = candidates.get(index2);

Double score1 = scoreLookup.apply(c1);
Double score2 = scoreLookup.apply(c2);
LatencyTracker t1 = c1.getLatencyTracker();
LatencyTracker t2 = c2.getLatencyTracker();

// Handle null scores by treating them as Double.MAX_VALUE (lowest priority)
double s1 = MoreObjects.firstNonNull(score1, Double.MAX_VALUE);
double s2 = MoreObjects.firstNonNull(score2, Double.MAX_VALUE);
double s1 = t1 != null ? t1.getScore() : Double.MAX_VALUE;
double s2 = t2 != null ? t2.getScore() : Double.MAX_VALUE;

return s1 <= s2 ? c1 : c2;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import com.google.api.core.BetaApi;
import com.google.api.core.InternalApi;
import java.util.List;
import java.util.function.Function;

/** Interface for selecting a replica from a list of candidates. */
@InternalApi
Expand All @@ -30,9 +29,7 @@ public interface ReplicaSelector {
* Selects a replica from the given list of candidates.
*
* @param candidates the list of eligible candidates.
* @param scoreLookup a function to look up the latency score for a candidate.
* @return the selected candidate, or null if the list is empty.
*/
ChannelEndpoint select(
List<ChannelEndpoint> candidates, Function<ChannelEndpoint, Double> scoreLookup);
ChannelEndpoint select(List<ChannelEndpoint> candidates);
}
Original file line number Diff line number Diff line change
Expand Up @@ -827,11 +827,16 @@ static final class FakeEndpoint implements ChannelEndpoint {
private final String address;
private final FakeManagedChannel channel = new FakeManagedChannel();
private EndpointHealthState state = EndpointHealthState.READY;
private LatencyTracker latencyTracker = null;

FakeEndpoint(String address) {
this.address = address;
}

void setLatencyTracker(LatencyTracker tracker) {
this.latencyTracker = tracker;
}

@Override
public String getAddress() {
return address;
Expand All @@ -852,6 +857,11 @@ public ManagedChannel getChannel() {
return channel;
}

@Override
public LatencyTracker getLatencyTracker() {
return latencyTracker;
}

void setState(EndpointHealthState state) {
this.state = state;
channel.setConnectivityState(toConnectivityState(state));
Expand Down Expand Up @@ -926,4 +936,88 @@ public String authority() {
return "fake";
}
}

@Test
public void usesReplicaSelectorWhenEnabled() {
FakeEndpointCache endpointCache = new FakeEndpointCache();

ReplicaSelector mockSelector =
(candidates) -> {
for (ChannelEndpoint candidate : candidates) {
if (candidate.getAddress().equals("server2")) {
return candidate;
}
}
return null;
};

KeyRangeCache cache = new KeyRangeCache(endpointCache, null, mockSelector);
cache.addRanges(twoReplicaUpdate());

endpointCache.get("server1");
endpointCache.get("server2");

RoutingHint.Builder hint = RoutingHint.newBuilder().setKey(bytes("a"));
ChannelEndpoint server =
cache.fillRoutingHint(
false,
KeyRangeCache.RangeMode.COVERING_SPLIT,
DirectedReadOptions.getDefaultInstance(),
hint);

assertNotNull(server);
assertEquals("server2", server.getAddress());
}

@Test
public void usesLatencyTrackerInSelector() {
FakeEndpointCache endpointCache = new FakeEndpointCache();

FakeEndpoint e1 = (FakeEndpoint) endpointCache.get("server1");
FakeEndpoint e2 = (FakeEndpoint) endpointCache.get("server2");

e1.setLatencyTracker(
new LatencyTracker() {
@Override
public double getScore() {
return 100.0;
}

@Override
public void update(java.time.Duration latency) {}

@Override
public void recordError(java.time.Duration penalty) {}
});

e2.setLatencyTracker(
new LatencyTracker() {
@Override
public double getScore() {
return 10.0;
}

@Override
public void update(java.time.Duration latency) {}

@Override
public void recordError(java.time.Duration penalty) {}
});

ReplicaSelector selector = new PowerOfTwoReplicaSelector();

KeyRangeCache cache = new KeyRangeCache(endpointCache, null, selector);
cache.addRanges(twoReplicaUpdate());

RoutingHint.Builder hint = RoutingHint.newBuilder().setKey(bytes("a"));
ChannelEndpoint server =
cache.fillRoutingHint(
false,
KeyRangeCache.RangeMode.COVERING_SPLIT,
DirectedReadOptions.getDefaultInstance(),
hint);

assertNotNull(server);
assertEquals("server2", server.getAddress());
}
}
Loading
Loading