Skip to content

Commit ea55529

Browse files
committed
feat(spanner): add shared endpoint cooldowns for location-aware rerouting
1 parent 8095342 commit ea55529

8 files changed

Lines changed: 1372 additions & 276 deletions

File tree

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
/*
2+
* Copyright 2026 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.cloud.spanner.spi.v1;
18+
19+
import com.google.common.annotations.VisibleForTesting;
20+
import java.time.Clock;
21+
import java.time.Duration;
22+
import java.time.Instant;
23+
import java.util.concurrent.ConcurrentHashMap;
24+
import java.util.concurrent.ThreadLocalRandom;
25+
import java.util.function.LongUnaryOperator;
26+
27+
/**
28+
* Tracks short-lived endpoint cooldowns after routed {@code RESOURCE_EXHAUSTED} failures.
29+
*
30+
* <p>This allows later requests to try a different replica instead of immediately routing back to
31+
* the same overloaded endpoint.
32+
*/
33+
final class EndpointOverloadCooldownTracker {
34+
35+
@VisibleForTesting static final Duration DEFAULT_INITIAL_COOLDOWN = Duration.ofSeconds(10);
36+
@VisibleForTesting static final Duration DEFAULT_MAX_COOLDOWN = Duration.ofMinutes(1);
37+
@VisibleForTesting static final Duration DEFAULT_RESET_AFTER = Duration.ofMinutes(10);
38+
39+
@VisibleForTesting
40+
static final class CooldownState {
41+
private final int consecutiveFailures;
42+
private final Instant cooldownUntil;
43+
private final Instant lastFailureAt;
44+
45+
private CooldownState(int consecutiveFailures, Instant cooldownUntil, Instant lastFailureAt) {
46+
this.consecutiveFailures = consecutiveFailures;
47+
this.cooldownUntil = cooldownUntil;
48+
this.lastFailureAt = lastFailureAt;
49+
}
50+
}
51+
52+
private final ConcurrentHashMap<String, CooldownState> entries = new ConcurrentHashMap<>();
53+
private final Duration initialCooldown;
54+
private final Duration maxCooldown;
55+
private final Duration resetAfter;
56+
private final Clock clock;
57+
private final LongUnaryOperator randomLong;
58+
59+
EndpointOverloadCooldownTracker() {
60+
this(
61+
DEFAULT_INITIAL_COOLDOWN,
62+
DEFAULT_MAX_COOLDOWN,
63+
DEFAULT_RESET_AFTER,
64+
Clock.systemUTC(),
65+
bound -> ThreadLocalRandom.current().nextLong(bound));
66+
}
67+
68+
@VisibleForTesting
69+
EndpointOverloadCooldownTracker(
70+
Duration initialCooldown,
71+
Duration maxCooldown,
72+
Duration resetAfter,
73+
Clock clock,
74+
LongUnaryOperator randomLong) {
75+
Duration resolvedInitial =
76+
(initialCooldown == null || initialCooldown.isZero() || initialCooldown.isNegative())
77+
? DEFAULT_INITIAL_COOLDOWN
78+
: initialCooldown;
79+
Duration resolvedMax =
80+
(maxCooldown == null || maxCooldown.isZero() || maxCooldown.isNegative())
81+
? DEFAULT_MAX_COOLDOWN
82+
: maxCooldown;
83+
if (resolvedMax.compareTo(resolvedInitial) < 0) {
84+
resolvedMax = resolvedInitial;
85+
}
86+
this.initialCooldown = resolvedInitial;
87+
this.maxCooldown = resolvedMax;
88+
this.resetAfter =
89+
(resetAfter == null || resetAfter.isZero() || resetAfter.isNegative())
90+
? DEFAULT_RESET_AFTER
91+
: resetAfter;
92+
this.clock = clock == null ? Clock.systemUTC() : clock;
93+
this.randomLong =
94+
randomLong == null ? bound -> ThreadLocalRandom.current().nextLong(bound) : randomLong;
95+
}
96+
97+
boolean isCoolingDown(String address) {
98+
if (address == null || address.isEmpty()) {
99+
return false;
100+
}
101+
Instant now = clock.instant();
102+
CooldownState state = entries.get(address);
103+
if (state == null) {
104+
return false;
105+
}
106+
if (state.cooldownUntil.isAfter(now)) {
107+
return true;
108+
}
109+
if (Duration.between(state.lastFailureAt, now).compareTo(resetAfter) < 0) {
110+
return false;
111+
}
112+
entries.remove(address, state);
113+
CooldownState current = entries.get(address);
114+
return current != null && current.cooldownUntil.isAfter(now);
115+
}
116+
117+
void recordFailure(String address) {
118+
if (address == null || address.isEmpty()) {
119+
return;
120+
}
121+
Instant now = clock.instant();
122+
entries.compute(
123+
address,
124+
(ignored, state) -> {
125+
int consecutiveFailures = 1;
126+
if (state != null
127+
&& Duration.between(state.lastFailureAt, now).compareTo(resetAfter) < 0) {
128+
consecutiveFailures = state.consecutiveFailures + 1;
129+
}
130+
Duration cooldown = cooldownForFailures(consecutiveFailures);
131+
return new CooldownState(consecutiveFailures, now.plus(cooldown), now);
132+
});
133+
}
134+
135+
private Duration cooldownForFailures(int failures) {
136+
Duration cooldown = initialCooldown;
137+
for (int i = 1; i < failures; i++) {
138+
if (cooldown.compareTo(maxCooldown.dividedBy(2)) > 0) {
139+
cooldown = maxCooldown;
140+
break;
141+
}
142+
cooldown = cooldown.multipliedBy(2);
143+
}
144+
long bound = Math.max(1L, cooldown.toMillis() + 1L);
145+
return Duration.ofMillis(randomLong.applyAsLong(bound));
146+
}
147+
148+
@VisibleForTesting
149+
CooldownState getState(String address) {
150+
return entries.get(address);
151+
}
152+
}

java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,11 @@ public GapicSpannerRpc(final SpannerOptions options) {
432432
this.readRetrySettings =
433433
options.getSpannerStubSettings().streamingReadSettings().getRetrySettings();
434434
this.readRetryableCodes =
435-
options.getSpannerStubSettings().streamingReadSettings().getRetryableCodes();
435+
ImmutableSet.<Code>builder()
436+
.addAll(
437+
options.getSpannerStubSettings().streamingReadSettings().getRetryableCodes())
438+
.add(Code.RESOURCE_EXHAUSTED)
439+
.build();
436440
this.executeQueryRetrySettings =
437441
options.getSpannerStubSettings().executeStreamingSqlSettings().getRetrySettings();
438442
this.executeQueryRetryableCodes =

java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java

Lines changed: 60 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import com.google.api.core.InternalApi;
2222
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider;
2323
import com.google.cloud.spanner.XGoogSpannerRequestId;
24+
import com.google.common.annotations.VisibleForTesting;
2425
import com.google.common.cache.Cache;
2526
import com.google.common.cache.CacheBuilder;
2627
import com.google.protobuf.ByteString;
@@ -102,11 +103,20 @@ final class KeyAwareChannel extends ManagedChannel {
102103
.maximumSize(MAX_TRACKED_EXCLUDED_LOGICAL_REQUESTS)
103104
.expireAfterWrite(EXCLUDED_LOGICAL_REQUEST_TTL_MINUTES, TimeUnit.MINUTES)
104105
.build();
106+
private final EndpointOverloadCooldownTracker endpointOverloadCooldowns;
105107

106108
private KeyAwareChannel(
107109
InstantiatingGrpcChannelProvider channelProvider,
108110
@Nullable ChannelEndpointCacheFactory endpointCacheFactory)
109111
throws IOException {
112+
this(channelProvider, endpointCacheFactory, new EndpointOverloadCooldownTracker());
113+
}
114+
115+
private KeyAwareChannel(
116+
InstantiatingGrpcChannelProvider channelProvider,
117+
@Nullable ChannelEndpointCacheFactory endpointCacheFactory,
118+
EndpointOverloadCooldownTracker endpointOverloadCooldowns)
119+
throws IOException {
110120
if (endpointCacheFactory == null) {
111121
this.endpointCache = new GrpcChannelEndpointCache(channelProvider);
112122
} else {
@@ -120,6 +130,7 @@ private KeyAwareChannel(
120130
// would interfere with test assertions.
121131
this.lifecycleManager =
122132
(endpointCacheFactory == null) ? new EndpointLifecycleManager(endpointCache) : null;
133+
this.endpointOverloadCooldowns = endpointOverloadCooldowns;
123134
}
124135

125136
static KeyAwareChannel create(
@@ -129,6 +140,15 @@ static KeyAwareChannel create(
129140
return new KeyAwareChannel(channelProvider, endpointCacheFactory);
130141
}
131142

143+
@VisibleForTesting
144+
static KeyAwareChannel create(
145+
InstantiatingGrpcChannelProvider channelProvider,
146+
@Nullable ChannelEndpointCacheFactory endpointCacheFactory,
147+
EndpointOverloadCooldownTracker endpointOverloadCooldowns)
148+
throws IOException {
149+
return new KeyAwareChannel(channelProvider, endpointCacheFactory, endpointOverloadCooldowns);
150+
}
151+
132152
private static final class ChannelFinderReference extends SoftReference<ChannelFinder> {
133153
final String databaseId;
134154

@@ -321,36 +341,56 @@ void clearTransactionAndChannelAffinity(ByteString transactionId, @Nullable Long
321341

322342
private void maybeExcludeEndpointOnNextCall(
323343
@Nullable ChannelEndpoint endpoint, @Nullable String logicalRequestKey) {
324-
if (endpoint == null || logicalRequestKey == null) {
344+
if (endpoint == null) {
325345
return;
326346
}
327347
String address = endpoint.getAddress();
328-
if (!defaultEndpointAddress.equals(address)) {
329-
excludedEndpointsForLogicalRequest
330-
.asMap()
331-
.compute(
332-
logicalRequestKey,
333-
(ignored, excludedEndpoints) -> {
334-
Set<String> updated =
335-
excludedEndpoints == null ? ConcurrentHashMap.newKeySet() : excludedEndpoints;
336-
updated.add(address);
337-
return updated;
338-
});
348+
if (defaultEndpointAddress.equals(address)) {
349+
return;
350+
}
351+
endpointOverloadCooldowns.recordFailure(address);
352+
if (logicalRequestKey == null) {
353+
return;
339354
}
355+
excludedEndpointsForLogicalRequest
356+
.asMap()
357+
.compute(
358+
logicalRequestKey,
359+
(ignored, excludedEndpoints) -> {
360+
Set<String> updated =
361+
excludedEndpoints == null ? ConcurrentHashMap.newKeySet() : excludedEndpoints;
362+
updated.add(address);
363+
return updated;
364+
});
340365
}
341366

342367
private Predicate<String> consumeExcludedEndpointsForCurrentCall(
343368
@Nullable String logicalRequestKey) {
344-
if (logicalRequestKey == null) {
345-
return address -> false;
369+
Predicate<String> requestScopedExcluded = address -> false;
370+
if (logicalRequestKey != null) {
371+
Set<String> excludedEndpoints =
372+
excludedEndpointsForLogicalRequest.asMap().remove(logicalRequestKey);
373+
if (excludedEndpoints != null && !excludedEndpoints.isEmpty()) {
374+
excludedEndpoints = new HashSet<>(excludedEndpoints);
375+
requestScopedExcluded = excludedEndpoints::contains;
376+
}
346377
}
378+
Predicate<String> finalRequestScopedExcluded = requestScopedExcluded;
379+
return address ->
380+
finalRequestScopedExcluded.test(address)
381+
|| endpointOverloadCooldowns.isCoolingDown(address);
382+
}
383+
384+
@VisibleForTesting
385+
boolean isCoolingDown(String address) {
386+
return endpointOverloadCooldowns.isCoolingDown(address);
387+
}
388+
389+
@VisibleForTesting
390+
boolean hasExcludedEndpointForLogicalRequest(String logicalRequestKey, String address) {
347391
Set<String> excludedEndpoints =
348-
excludedEndpointsForLogicalRequest.asMap().remove(logicalRequestKey);
349-
if (excludedEndpoints == null || excludedEndpoints.isEmpty()) {
350-
return address -> false;
351-
}
352-
excludedEndpoints = new HashSet<>(excludedEndpoints);
353-
return excludedEndpoints::contains;
392+
excludedEndpointsForLogicalRequest.getIfPresent(logicalRequestKey);
393+
return excludedEndpoints != null && excludedEndpoints.contains(address);
354394
}
355395

356396
private boolean isReadOnlyTransaction(ByteString transactionId) {

0 commit comments

Comments
 (0)