diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelector.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelector.java index c7cc2012a615..df975e4834e8 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelector.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelector.java @@ -19,6 +19,7 @@ import com.google.api.core.BetaApi; import com.google.api.core.InternalApi; import com.google.common.base.MoreObjects; +import com.google.common.base.Preconditions; import java.util.List; import java.util.Random; import java.util.concurrent.ThreadLocalRandom; @@ -29,6 +30,19 @@ @BetaApi public class PowerOfTwoReplicaSelector implements ReplicaSelector { + public static final double DEFAULT_EPSILON = 0.1; + + private final double epsilon; + + public PowerOfTwoReplicaSelector() { + this(DEFAULT_EPSILON); + } + + public PowerOfTwoReplicaSelector(double epsilon) { + Preconditions.checkArgument(epsilon >= 0.0 && epsilon <= 1.0, "epsilon must be in [0, 1]"); + this.epsilon = epsilon; + } + @Override public ChannelEndpoint select( List candidates, Function scoreLookup) { @@ -40,6 +54,11 @@ public ChannelEndpoint select( } Random random = ThreadLocalRandom.current(); + + // Epsilon-greedy exploration: with probability epsilon, pick a random candidate. + if (random.nextDouble() < epsilon) { + return candidates.get(random.nextInt(candidates.size())); + } int index1 = random.nextInt(candidates.size()); int index2 = random.nextInt(candidates.size() - 1); if (index2 >= index1) { diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelectorTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelectorTest.java index 424efb363df6..3cdef1d3042c 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelectorTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/PowerOfTwoReplicaSelectorTest.java @@ -75,7 +75,8 @@ public void testSingleElement() { @Test public void testTwoElementsPicksBetter() { - PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector(); + // Use epsilon=0.0 to test pure Po2RC behavior + PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector(0.0); ChannelEndpoint better = new TestEndpoint("better"); ChannelEndpoint worse = new TestEndpoint("worse"); @@ -92,7 +93,8 @@ public void testTwoElementsPicksBetter() { @Test public void testThreeElementsNeverPicksWorst() { - PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector(); + // Use epsilon=0.0 to test pure Po2RC behavior + PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector(0.0); ChannelEndpoint best = new TestEndpoint("best"); ChannelEndpoint middle = new TestEndpoint("middle"); ChannelEndpoint worst = new TestEndpoint("worst"); @@ -112,7 +114,8 @@ public void testThreeElementsNeverPicksWorst() { @Test public void testNullScoresTreatedAsMax() { - PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector(); + // Use epsilon=0.0 to test pure Po2RC behavior + PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector(0.0); ChannelEndpoint withScore = new TestEndpoint("withScore"); ChannelEndpoint withoutScore = new TestEndpoint("withoutScore"); @@ -125,4 +128,27 @@ public void testNullScoresTreatedAsMax() { assertEquals(withScore, selector.select(candidates, scores::get)); } } + + @Test + public void testEpsilonExploration() { + // Set epsilon to 1.0 to force 100% exploration + PowerOfTwoReplicaSelector selector = new PowerOfTwoReplicaSelector(1.0); + ChannelEndpoint best = new TestEndpoint("best"); + ChannelEndpoint worst = new TestEndpoint("worst"); + + Map scores = new HashMap<>(); + scores.put(best, 10.0); + scores.put(worst, 20.0); + + List candidates = Arrays.asList(best, worst); + + boolean pickedWorst = false; + for (int i = 0; i < 100; i++) { + if (selector.select(candidates, scores::get) == worst) { + pickedWorst = true; + break; + } + } + assertTrue("Should occasionally pick worst with epsilon=1.0", pickedWorst); + } }